diff --git a/hessian-lite/src/main/java/com/alibaba/com/caucho/hessian/io/JavaDeserializer.java b/hessian-lite/src/main/java/com/alibaba/com/caucho/hessian/io/JavaDeserializer.java index bd217c4e2bcc..d5de4f37f0fa 100644 --- a/hessian-lite/src/main/java/com/alibaba/com/caucho/hessian/io/JavaDeserializer.java +++ b/hessian-lite/src/main/java/com/alibaba/com/caucho/hessian/io/JavaDeserializer.java @@ -70,6 +70,7 @@ public class JavaDeserializer extends AbstractMapDeserializer { private Method _readResolve; private Constructor _constructor; private Object []_constructorArgs; + private boolean compatibleConstructNPE = true; public JavaDeserializer(Class cl) { @@ -267,15 +268,49 @@ protected Object instantiate() throws Exception { try { - if (_constructor != null) - return _constructor.newInstance(_constructorArgs); - else - return _type.newInstance(); + return _constructor == null ? _type.newInstance() : construct(); } catch (Exception e) { throw new HessianProtocolException("'" + _type.getName() + "' could not be instantiated", e); } } + protected Object construct() throws Exception { + InvocationTargetException ex; + try { + return _constructor.newInstance(_constructorArgs); + } catch (InvocationTargetException e) { + if (!compatibleConstructNPE + || !(e.getTargetException() instanceof NullPointerException)) { + throw e; + } + + ex = e; + } + + Class[] types = _constructor.getParameterTypes(); + Object[] args = new Object[types.length]; + System.arraycopy(_constructorArgs, 0, args, 0, types.length); + try { + for (int i = 0; i < types.length; i++) { + if (args[i] == null) { + try { + Constructor ctor = types[i].getDeclaredConstructor(new Class[0]); + if (!ctor.isAccessible()) ctor.setAccessible(true); + args[i] = ctor.newInstance(new Object[0]); + } catch (Exception e) { + } + } + } + Object ret = _constructor.newInstance(args); + _constructorArgs = args; + return ret; + } catch (Throwable t) { + } + + compatibleConstructNPE = false; + throw ex; + } + /** * Creates a map of the classes fields. */ diff --git a/hessian-lite/src/test/java/com/alibaba/com/caucho/hessian/io/JavaDeserializerTest.java b/hessian-lite/src/test/java/com/alibaba/com/caucho/hessian/io/JavaDeserializerTest.java new file mode 100644 index 000000000000..68e153e36746 --- /dev/null +++ b/hessian-lite/src/test/java/com/alibaba/com/caucho/hessian/io/JavaDeserializerTest.java @@ -0,0 +1,81 @@ +package com.alibaba.com.caucho.hessian.io; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.sql.SQLException; + +import org.junit.Test; + +import com.alibaba.com.caucho.hessian.io.model.ConstructAlwaysNPE; +import com.alibaba.com.caucho.hessian.io.model.ConstructNPE; + +public class JavaDeserializerTest { + + /** + * #210 + * @see org.springframework.jdbc.UncategorizedSQLException + */ + @Test + public void testConstructorNPE() throws Exception { + String sql = "select * from demo"; + SQLException sqlEx = new SQLException("just a sql exception"); + ConstructNPE normalNPE = new ConstructNPE("junit", sql, sqlEx); + ConstructAlwaysNPE alwaysNPE = new ConstructAlwaysNPE("junit", sql, sqlEx); + + ByteArrayOutputStream bout = new ByteArrayOutputStream(); + Hessian2Output out = new Hessian2Output(bout); + out.writeObject(normalNPE); + out.writeObject(alwaysNPE); + out.flush(); + + SerializerFactory factory = new SerializerFactory(); + for (int repeat = 0; repeat < 2; repeat++) { + Hessian2Input input = new Hessian2Input(new ByteArrayInputStream(bout.toByteArray())); + input.setSerializerFactory(factory); + + assertDesEquals(normalNPE, (ConstructNPE) input.readObject()); + assertCompatibleConstructNPE(factory.getDeserializer(normalNPE.getClass()), true); + + try { + input.readObject(); + fail("must be always throw NullPointerException"); + } catch (HessianProtocolException e) { + assertEquals(InvocationTargetException.class, e.getCause().getClass()); + assertEquals(NullPointerException.class, e.getCause().getCause().getClass()); + } + assertCompatibleConstructNPE(factory.getDeserializer(alwaysNPE.getClass()), false); + } + } + + private void assertDesEquals(ConstructNPE expected, ConstructNPE actual) { + assertEquals(expected.getMessage(), actual.getMessage()); + assertEquals(expected.getCause().getClass(), actual.getCause().getClass()); + assertEquals(expected.getSql(), actual.getSql()); + } + + private void assertCompatibleConstructNPE(Deserializer deserializer, boolean compatible) throws Exception { + assertEquals(JavaDeserializer.class, deserializer.getClass()); + assertEquals(compatible, getFieldValue(deserializer, "compatibleConstructNPE")); + Object[] args = (Object[]) getFieldValue(deserializer, "_constructorArgs"); + for (int i = 0; i < args.length; i++) { + if (compatible) { + assertNotNull(args[i]); + } else { + assertNull(args[i]); + } + } + } + + public Object getFieldValue(Object bean, String fieldName) throws Exception { + Field field = bean.getClass().getDeclaredField(fieldName); + if (!field.isAccessible()) field.setAccessible(true); + return field.get(bean); + } +} diff --git a/hessian-lite/src/test/java/com/alibaba/com/caucho/hessian/io/model/ConstructAlwaysNPE.java b/hessian-lite/src/test/java/com/alibaba/com/caucho/hessian/io/model/ConstructAlwaysNPE.java new file mode 100644 index 000000000000..68a08f4cd7ea --- /dev/null +++ b/hessian-lite/src/test/java/com/alibaba/com/caucho/hessian/io/model/ConstructAlwaysNPE.java @@ -0,0 +1,26 @@ +package com.alibaba.com.caucho.hessian.io.model; + +import java.sql.SQLException; + +/** + * #210 + */ +public class ConstructAlwaysNPE extends RuntimeException { + private static final long serialVersionUID = 1L; + private final String sql; + + public ConstructAlwaysNPE(String task, String sql, SQLException ex) { + super(task + "; uncategorized SQLException for SQL [" + sql + "]; SQL state [" + + ex.getSQLState() + "]; error code [" + ex.getErrorCode() + "]; " + ex.getMessage(), ex); + if (sql.length() == 0) throw new NullPointerException("sql=" + sql); + this.sql = sql; + } + + public SQLException getSQLException() { + return (SQLException) getCause(); + } + + public String getSql() { + return this.sql; + } +} diff --git a/hessian-lite/src/test/java/com/alibaba/com/caucho/hessian/io/model/ConstructNPE.java b/hessian-lite/src/test/java/com/alibaba/com/caucho/hessian/io/model/ConstructNPE.java new file mode 100644 index 000000000000..f3c7c09469c5 --- /dev/null +++ b/hessian-lite/src/test/java/com/alibaba/com/caucho/hessian/io/model/ConstructNPE.java @@ -0,0 +1,26 @@ +package com.alibaba.com.caucho.hessian.io.model; + +import java.sql.SQLException; + +/** + * #210 + * @see org.springframework.jdbc.UncategorizedSQLException + */ +public class ConstructNPE extends RuntimeException { + private static final long serialVersionUID = 1L; + private final String sql; + + public ConstructNPE(String task, String sql, SQLException ex) { + super(task + "; uncategorized SQLException for SQL [" + sql + "]; SQL state [" + + ex.getSQLState() + "]; error code [" + ex.getErrorCode() + "]; " + ex.getMessage(), ex); + this.sql = sql; + } + + public SQLException getSQLException() { + return (SQLException) getCause(); + } + + public String getSql() { + return this.sql; + } +}