Skip to content

Commit

Permalink
make double-encoding prevention generic for any Codec
Browse files Browse the repository at this point in the history
  • Loading branch information
lhotari committed Apr 4, 2013
1 parent 00bcdae commit a73b371
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 76 deletions.
Expand Up @@ -19,6 +19,7 @@

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Set;

import org.springframework.util.ReflectionUtils;

Expand All @@ -31,12 +32,18 @@ public class DefaultGrailsCodecClass extends AbstractInjectableGrailsClass imple
public static final String CODEC = "Codec";
private Closure<?> encodeMethod;
private Closure<?> decodeMethod;

private static EncodingStateLookup encodingStateLookup=null;

public static void setEncodingStateLookup(EncodingStateLookup lookup) {
encodingStateLookup = lookup;
}

public DefaultGrailsCodecClass(Class<?> clazz) {
super(clazz, CODEC);

encodeMethod = getMethodOrClosureMethod(clazz, "encode");
decodeMethod = getMethodOrClosureMethod(clazz, "decode");
encodeMethod = getMethodOrClosureMethod(clazz, "encode", true);
decodeMethod = getMethodOrClosureMethod(clazz, "decode", false);
}

public Closure<?> getDecodeMethod() {
Expand All @@ -47,31 +54,101 @@ public Closure<?> getEncodeMethod() {
return encodeMethod;
}

private static class MethodCallerClosure extends Closure {
private static abstract class AbstractCallingClosure extends Closure<Object> implements Encoder {
private static final long serialVersionUID = 1L;
Method method;
public MethodCallerClosure(Object owner, Method method) {
private String codecName;
private boolean encode;

public AbstractCallingClosure(Object owner, String codecName, boolean encode) {
super(owner);
this.method = method;
maximumNumberOfParameters = 1;
parameterTypes = new Class[]{Object.class};
this.codecName = codecName;
this.encode = encode;
}

public Method getMethod() {
return method;
}

protected Object doCall(Object arguments) {
return ReflectionUtils.invokeMethod(method, !Modifier.isStatic(method.getModifiers()) ? getOwner() : null, (Object[])arguments);
}
protected abstract Object callMethod(Object argument);

@Override
public Object call(Object... args) {
return doCall(args);
}

protected Object doCall(Object[] args) {
Object target=null;
if(args != null && args.length > 0)
target=args[0];
if(target==null) {
return null;
}
if (encode) {
return encode(target);
} else {
return callMethod(target);
}
}

public String getCodecName() {
return codecName;
}

public CharSequence encode(Object target) {
if (target instanceof Encodeable) {
return ((Encodeable)target).encode(this);
}

String targetSrc = String.valueOf(target);
if(targetSrc.length() == 0) {
return "";
}
EncodingState encodingState=encodingStateLookup.lookup();
if(encodingState != null) {
Set<String> tags = encodingState.getEncodingTagsFor(targetSrc);
if(tags != null && tags.contains(codecName)) {
return targetSrc;
}
}
String encoded = String.valueOf(callMethod(targetSrc));
if(encodingState != null)
encodingState.registerEncodedWith(codecName, encoded);
return encoded;
}

public void markEncoded(CharSequence string) {
EncodingState encodingState=encodingStateLookup.lookup();
if(encodingState != null) {
encodingState.registerEncodedWith(codecName, string);
}
}
}

private static class MethodCallerClosure extends AbstractCallingClosure {
private static final long serialVersionUID = 1L;
Method method;
public MethodCallerClosure(Object owner, String codecName, boolean encode, Method method) {
super(owner, codecName, encode);
this.method = method;
}

protected Object callMethod(Object argument) {
return ReflectionUtils.invokeMethod(method, !Modifier.isStatic(method.getModifiers()) ? getOwner() : null, argument);
}
}

private Closure<?> getMethodOrClosureMethod(Class<?> clazz, String methodName) {
private static class ClosureCallerClosure extends AbstractCallingClosure {
private static final long serialVersionUID = 1L;
Closure<?> closure;
public ClosureCallerClosure(Object owner, String codecName, boolean encode, Closure<?> closure) {
super(owner, codecName, encode);
this.closure = closure;
}

protected Object callMethod(Object argument) {
return closure.call(new Object[]{argument});
}
}

private Closure<?> getMethodOrClosureMethod(Class<?> clazz, String methodName, boolean encode) {
Closure<?> closure = (Closure<?>) getPropertyOrStaticPropertyOrFieldValue(methodName, Closure.class);
if (closure == null) {
Method method = ReflectionUtils.findMethod(clazz, methodName, null);
Expand All @@ -82,9 +159,11 @@ private Closure<?> getMethodOrClosureMethod(Class<?> clazz, String methodName) {
} else {
owner=getReferenceInstance();
}
return new MethodCallerClosure(owner, method);
return new MethodCallerClosure(owner, getName(), encode, method);
}
return null;
} else {
return new ClosureCallerClosure(clazz, getName(), encode, closure);
}
return closure;
}
}
@@ -0,0 +1,5 @@
package org.codehaus.groovy.grails.commons;

public interface Encodeable {
public CharSequence encode(Encoder encoder);
}
@@ -0,0 +1,7 @@
package org.codehaus.groovy.grails.commons;

public interface Encoder {
public String getCodecName();
public CharSequence encode(Object o);
public void markEncoded(CharSequence string);
}
@@ -0,0 +1,9 @@
package org.codehaus.groovy.grails.commons;

import java.util.Set;

public interface EncodingState {
public Set<String> getEncodingTagsFor(CharSequence string);
public boolean isEncodedWith(String encoding, CharSequence string);
public void registerEncodedWith(String encoding, CharSequence escaped);
}
@@ -0,0 +1,5 @@
package org.codehaus.groovy.grails.commons;

public interface EncodingStateLookup {
public EncodingState lookup();
}
Expand Up @@ -14,11 +14,7 @@
*/
package org.codehaus.groovy.grails.plugins.codecs;

import java.util.Set;

import org.codehaus.groovy.grails.web.servlet.GrailsApplicationAttributes;
import org.codehaus.groovy.grails.web.servlet.mvc.GrailsWebRequest;
import org.codehaus.groovy.grails.web.util.StreamCharBuffer;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.util.HtmlUtils;
Expand All @@ -30,46 +26,9 @@
* @since 1.1
*/
public class HTMLCodec {
private static class HTMLEncoder implements StreamCharBuffer.Encoder {
public String getCodecName() {
return "HTML";
}

public Object encode(Object o) {
return HTMLCodec.encode(o);
}

public void markEncoded(String string) {
GrailsWebRequest webRequest = GrailsWebRequest.lookup();
if (webRequest != null) {
webRequest.registerEncodedWith("HTML", string);
}
}
}

private static HTMLEncoder encoderInstance=new HTMLEncoder();

public static CharSequence encode(Object target) {
if (target != null) {
if (target instanceof StreamCharBuffer) {
return ((StreamCharBuffer)target).encodeToBuffer(encoderInstance);
}

String targetSrc = String.valueOf(target);
if(targetSrc.length() == 0) {
return "";
}
GrailsWebRequest webRequest=GrailsWebRequest.lookup();
if(webRequest != null) {
Set<String> tags = webRequest.getEncodingTagsFor(targetSrc);
if(tags != null && tags.contains("HTML")) {
return targetSrc;
}
}
String escaped = HtmlUtils.htmlEscape(targetSrc);
if(webRequest != null)
webRequest.registerEncodedWith("HTML", escaped);
return escaped;
return HtmlUtils.htmlEscape(target.toString());
}
return null;
}
Expand Down
Expand Up @@ -14,7 +14,6 @@ class StreamCharBufferSpec extends Specification {
StreamCharBuffer buffer
CodecPrintWriter codecOut
GrailsPrintWriter out
StreamCharBuffer.Encoder htmlEncoder

def setup() {
buffer=new StreamCharBuffer()
Expand All @@ -27,7 +26,6 @@ class StreamCharBufferSpec extends Specification {
GrailsWebUtil.bindMockWebRequest()
new CodecsGrailsPlugin().configureCodecMethods(codecClass)
codecOut=new CodecPrintWriter(grailsApplication, out, HTMLCodec)
htmlEncoder = StreamCharBuffer.createEncoder(codecClass.name, codecClass.encodeMethod)
}

def "stream char buffer should support encoding"() {
Expand Down
Expand Up @@ -28,6 +28,9 @@
import javax.servlet.http.HttpServletResponse;

import org.codehaus.groovy.grails.commons.ControllerArtefactHandler;
import org.codehaus.groovy.grails.commons.DefaultGrailsCodecClass;
import org.codehaus.groovy.grails.commons.EncodingState;
import org.codehaus.groovy.grails.commons.EncodingStateLookup;
import org.codehaus.groovy.grails.commons.GrailsApplication;
import org.codehaus.groovy.grails.commons.GrailsControllerClass;
import org.codehaus.groovy.grails.web.binding.GrailsDataBinder;
Expand All @@ -53,7 +56,7 @@
* @author Graeme Rocher
* @since 0.4
*/
public class GrailsWebRequest extends DispatcherServletWebRequest implements ParameterInitializationCallback {
public class GrailsWebRequest extends DispatcherServletWebRequest implements ParameterInitializationCallback, EncodingState {

private GrailsApplicationAttributes attributes;
private GrailsParameterMap params;
Expand Down Expand Up @@ -361,7 +364,7 @@ private Set<Integer> getIdentityHashCodesForEncoding(String encoding) {
return identityHashCodes;
}

public Set<String> getEncodingTagsFor(String string) {
public Set<String> getEncodingTagsFor(CharSequence string) {
int identityHashCode = System.identityHashCode(string);
Set<String> result=null;
for(Map.Entry<String, Set<Integer>> entry : encodingTagIdentityHashCodes.entrySet()) {
Expand All @@ -379,11 +382,22 @@ public Set<String> getEncodingTagsFor(String string) {
return result;
}

public boolean isEncodedWith(String encoding, String string) {
public boolean isEncodedWith(String encoding, CharSequence string) {
return getIdentityHashCodesForEncoding(encoding).contains(System.identityHashCode(string));
}

public void registerEncodedWith(String encoding, String escaped) {
public void registerEncodedWith(String encoding, CharSequence escaped) {
getIdentityHashCodesForEncoding(encoding).add(System.identityHashCode(escaped));
}

private static final class DefaultEncodingStateLookup implements EncodingStateLookup {
public EncodingState lookup() {
return GrailsWebRequest.lookup();
}
}

static {
DefaultGrailsCodecClass.setEncodingStateLookup(new DefaultEncodingStateLookup());
}

}
Expand Up @@ -6,13 +6,13 @@
import java.io.IOException;
import java.io.Writer;

import org.codehaus.groovy.grails.commons.Encoder;
import org.codehaus.groovy.grails.commons.GrailsApplication;
import org.codehaus.groovy.grails.commons.GrailsCodecClass;
import org.codehaus.groovy.runtime.GStringImpl;

public class CodecPrintWriter extends GrailsPrintWriter {
private Closure<?> encodeClosure;
private StreamCharBuffer.Encoder encoder;

public CodecPrintWriter(GrailsApplication grailsApplication, Writer out, Class<?> codecClass) {
super(out);
Expand Down Expand Up @@ -40,7 +40,6 @@ private void initEncode(GrailsApplication grailsApplication, Class<?> codecClass
if (grailsApplication != null && codecClass != null) {
GrailsCodecClass codecArtefact = (GrailsCodecClass) grailsApplication.getArtefact("Codec", codecClass.getName());
encodeClosure = codecArtefact.getEncodeMethod();
encoder=StreamCharBuffer.createEncoder(codecArtefact.getName(), encodeClosure);
}
}

Expand Down
Expand Up @@ -39,6 +39,8 @@

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.codehaus.groovy.grails.commons.Encodeable;
import org.codehaus.groovy.grails.commons.Encoder;
import org.codehaus.groovy.grails.web.servlet.mvc.GrailsWebRequest;

/**
Expand Down Expand Up @@ -223,7 +225,7 @@
*
* @author Lari Hotari, Sagire Software Oy
*/
public class StreamCharBuffer implements Writable, CharSequence, Externalizable {
public class StreamCharBuffer implements Writable, CharSequence, Externalizable, Encodeable {
static final long serialVersionUID = 5486972234419632945L;
private static final Log log=LogFactory.getLog(StreamCharBuffer.class);

Expand Down Expand Up @@ -2131,23 +2133,17 @@ public static interface EncoderWriter {
public void write(Encoder encoder, StreamCharBuffer subBuffer) throws IOException;
}

public static interface Encoder {
public String getCodecName();
public Object encode(Object o);
public void markEncoded(String string);
}

public static Encoder createEncoder(final String codecName, final Closure<?> encodeClosure) {
return new Encoder() {
public String getCodecName() {
return codecName;
}

public Object encode(Object o) {
return encodeClosure.call(o);
public CharSequence encode(Object o) {
return (CharSequence)encodeClosure.call(o);
}

public void markEncoded(String string) {
public void markEncoded(CharSequence string) {
GrailsWebRequest webRequest = GrailsWebRequest.lookup();
if (webRequest != null) {
webRequest.registerEncodedWith(getCodecName(), string);
Expand All @@ -2163,4 +2159,8 @@ public EncodingTagsResolver getTagResolver() {
public void setTagResolver(EncodingTagsResolver tagResolver) {
this.tagsResolver = tagResolver;
}

public CharSequence encode(Encoder encoder) {
return encodeToBuffer(encoder);
}
}

0 comments on commit a73b371

Please sign in to comment.