Skip to content

Commit

Permalink
Copy LambdaDeserializer from scala-java8-compat to scala.runtime
Browse files Browse the repository at this point in the history
The original file is here:
https://github.com/scala/scala-java8-compat/blob/master/src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala

The only difference is the package name (changed to scala.runtime).

This commit is a cherry-pick of retronym's c0732e6
  • Loading branch information
retronym authored and lrytz committed Jun 30, 2015
1 parent 640ffe7 commit b0b73c5
Show file tree
Hide file tree
Showing 3 changed files with 327 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ abstract class BCodeHelpers extends BCodeIdiomatic with BytecodeWriters {
* cache = new java.util.HashMap()
* $deserializeLambdaCache$ = cache
* }
* return scala.compat.java8.runtime.LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), cache, l);
* return scala.runtime.LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), cache, l);
* }
*/
def addLambdaDeserialize(clazz: Symbol, jclass: asm.ClassVisitor): Unit = {
Expand Down Expand Up @@ -731,7 +731,7 @@ abstract class BCodeHelpers extends BCodeIdiomatic with BytecodeWriters {
mv.visitMethodInsn(INVOKESTATIC, "java/lang/invoke/MethodHandles", "lookup", "()Ljava/lang/invoke/MethodHandles$Lookup;", false)
mv.visitVarInsn(ALOAD, 1)
mv.visitVarInsn(ALOAD, 0)
mv.visitMethodInsn(INVOKESTATIC, "scala/compat/java8/runtime/LambdaDeserializer", "deserializeLambda", "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/util/Map;Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", false)
mv.visitMethodInsn(INVOKESTATIC, "scala/runtime/LambdaDeserializer", "deserializeLambda", "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/util/Map;Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;", false)
mv.visitInsn(ARETURN)
mv.visitEnd()
}
Expand Down
132 changes: 132 additions & 0 deletions src/library/scala/runtime/LambdaDeserializer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package scala.runtime

import java.lang.invoke._

/**
* This class is only intended to be called by synthetic `$deserializeLambda$` method that the Scala 2.12
* compiler will add to classes hosting lambdas.
*
* It is not intended to be consumed directly.
*/
object LambdaDeserializer {
/**
* Deserialize a lambda by calling `LambdaMetafactory.altMetafactory` to spin up a lambda class
* and instantiating this class with the captured arguments.
*
* A cache may be provided to ensure that subsequent deserialization of the same lambda expression
* is cheap, it amounts to a reflective call to the constructor of the previously created class.
* However, deserialization of the same lambda expression is not guaranteed to use the same class,
* concurrent deserialization of the same lambda expression may spin up more than one class.
*
* Assumptions:
* - No additional marker interfaces are required beyond `{java.io,scala.}Serializable`. These are
* not stored in `SerializedLambda`, so we can't reconstitute them.
* - No additional bridge methods are passed to `altMetafactory`. Again, these are not stored.
*
* @param lookup The factory for method handles. Must have access to the implementation method, the
* functional interface class, and `java.io.Serializable` or `scala.Serializable` as
* required.
* @param cache A cache used to avoid spinning up a class for each deserialization of a given lambda. May be `null`
* @param serialized The lambda to deserialize. Note that this is typically created by the `readResolve`
* member of the anonymous class created by `LambdaMetaFactory`.
* @return An instance of the functional interface
*/
def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = {
def slashDot(name: String) = name.replaceAll("/", ".")
val loader = lookup.lookupClass().getClassLoader
val implClass = loader.loadClass(slashDot(serialized.getImplClass))

def makeCallSite: CallSite = {
import serialized._
def parseDescriptor(s: String) =
MethodType.fromMethodDescriptorString(s, loader)

val funcInterfaceSignature = parseDescriptor(getFunctionalInterfaceMethodSignature)
val instantiated = parseDescriptor(getInstantiatedMethodType)
val functionalInterfaceClass = loader.loadClass(slashDot(getFunctionalInterfaceClass))

val implMethodSig = parseDescriptor(getImplMethodSignature)
// Construct the invoked type from the impl method type. This is the type of a factory
// that will be generated by the meta-factory. It is a method type, with param types
// coming form the types of the captures, and return type being the functional interface.
val invokedType: MethodType = {
// 1. Add receiver for non-static impl methods
val withReceiver = getImplMethodKind match {
case MethodHandleInfo.REF_invokeStatic | MethodHandleInfo.REF_newInvokeSpecial =>
implMethodSig
case _ =>
implMethodSig.insertParameterTypes(0, implClass)
}
// 2. Remove lambda parameters, leaving only captures. Note: the receiver may be a lambda parameter,
// such as in `Function<Object, String> s = Object::toString`
val lambdaArity = funcInterfaceSignature.parameterCount()
val from = withReceiver.parameterCount() - lambdaArity
val to = withReceiver.parameterCount()

// 3. Drop the lambda return type and replace with the functional interface.
withReceiver.dropParameterTypes(from, to).changeReturnType(functionalInterfaceClass)
}

// Lookup the implementation method
val implMethod: MethodHandle = try {
findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig)
} catch {
case e: ReflectiveOperationException => throw new IllegalArgumentException("Illegal lambda deserialization", e)
}

val flags: Int = LambdaMetafactory.FLAG_SERIALIZABLE | LambdaMetafactory.FLAG_MARKERS
val isScalaFunction = functionalInterfaceClass.getName.startsWith("scala.Function")
val markerInterface: Class[_] = loader.loadClass(if (isScalaFunction) ScalaSerializable else JavaIOSerializable)

LambdaMetafactory.altMetafactory(
lookup, getFunctionalInterfaceMethodName, invokedType,

/* samMethodType = */ funcInterfaceSignature,
/* implMethod = */ implMethod,
/* instantiatedMethodType = */ instantiated,
/* flags = */ flags.asInstanceOf[AnyRef],
/* markerInterfaceCount = */ 1.asInstanceOf[AnyRef],
/* markerInterfaces[0] = */ markerInterface,
/* bridgeCount = */ 0.asInstanceOf[AnyRef]
)
}

val key = serialized.getImplMethodName + " : " + serialized.getImplMethodSignature
val factory: MethodHandle = if (cache == null) {
makeCallSite.getTarget
} else cache.get(key) match {
case null =>
val callSite = makeCallSite
val temp = callSite.getTarget
cache.put(key, temp)
temp
case target => target
}

val captures = Array.tabulate(serialized.getCapturedArgCount)(n => serialized.getCapturedArg(n))
factory.invokeWithArguments(captures: _*)
}

private val ScalaSerializable = "scala.Serializable"

private val JavaIOSerializable = {
// We could actually omit this marker interface as LambdaMetaFactory will add it if
// the FLAG_SERIALIZABLE is set and of the provided markers extend it. But the code
// is cleaner if we uniformly add a single marker, so I'm leaving it in place.
"java.io.Serializable"
}

private def findMember(lookup: MethodHandles.Lookup, kind: Int, owner: Class[_],
name: String, signature: MethodType): MethodHandle = {
kind match {
case MethodHandleInfo.REF_invokeStatic =>
lookup.findStatic(owner, name, signature)
case MethodHandleInfo.REF_newInvokeSpecial =>
lookup.findConstructor(owner, signature)
case MethodHandleInfo.REF_invokeVirtual | MethodHandleInfo.REF_invokeInterface =>
lookup.findVirtual(owner, name, signature)
case MethodHandleInfo.REF_invokeSpecial =>
lookup.findSpecial(owner, name, signature, owner)
}
}
}
193 changes: 193 additions & 0 deletions test/junit/scala/runtime/LambdaDeserializerTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
package scala.runtime;

import org.junit.Assert;
import org.junit.Test;

import java.io.Serializable;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;

public final class LambdaDeserializerTest {
private LambdaHost lambdaHost = new LambdaHost();

@Test
public void serializationPrivate() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByPrivateImplMethod();
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
}

@Test
public void serializationStatic() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
}

@Test
public void serializationVirtualMethodReference() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByVirtualMethodReference();
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
}

@Test
public void serializationInterfaceMethodReference() {
F1<I, Object> f1 = lambdaHost.lambdaBackedByInterfaceMethodReference();
I i = new I() {
};
Assert.assertEquals(f1.apply(i), reconstitute(f1).apply(i));
}

@Test
public void serializationStaticMethodReference() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticMethodReference();
Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true));
}

@Test
public void serializationNewInvokeSpecial() {
F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
Assert.assertEquals(f1.apply(), reconstitute(f1).apply());
}

@Test
public void uncached() {
F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
F0<Object> reconstituted1 = reconstitute(f1);
F0<Object> reconstituted2 = reconstitute(f1);
Assert.assertNotEquals(reconstituted1.getClass(), reconstituted2.getClass());
}

@Test
public void cached() {
HashMap<String, MethodHandle> cache = new HashMap<>();
F0<Object> f1 = lambdaHost.lambdaBackedByConstructorCall();
F0<Object> reconstituted1 = reconstitute(f1, cache);
F0<Object> reconstituted2 = reconstitute(f1, cache);
Assert.assertEquals(reconstituted1.getClass(), reconstituted2.getClass());
}

@Test
public void cachedStatic() {
HashMap<String, MethodHandle> cache = new HashMap<>();
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
// Check that deserialization of a static lambda always returns the
// same instance.
Assert.assertSame(reconstitute(f1, cache), reconstitute(f1, cache));

// (as is the case with regular invocation.)
Assert.assertSame(f1, lambdaHost.lambdaBackedByStaticImplMethod());
}

@Test
public void implMethodNameChanged() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
SerializedLambda sl = writeReplace(f1);
checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature()));
}

@Test
public void implMethodSignatureChanged() {
F1<Boolean, String> f1 = lambdaHost.lambdaBackedByStaticImplMethod();
SerializedLambda sl = writeReplace(f1);
checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer")));
}

private void checkIllegalAccess(SerializedLambda serialized) {
try {
LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, serialized);
throw new AssertionError();
} catch (IllegalArgumentException iae) {
if (!iae.getMessage().contains("Illegal lambda deserialization")) {
Assert.fail("Unexpected message: " + iae.getMessage());
}
}
}

private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) {
Object[] captures = new Object[sl.getCapturedArgCount()];
for (int i = 0; i < captures.length; i++) {
captures[i] = sl.getCapturedArg(i);
}
return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(),
sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature,
sl.getInstantiatedMethodType(), captures);
}

private Class<?> loadClass(String className) {
try {
return Class.forName(className.replace('/', '.'));
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
}
private <A, B> A reconstitute(A f1) {
return reconstitute(f1, null);
}

@SuppressWarnings("unchecked")
private <A, B> A reconstitute(A f1, java.util.HashMap<String, MethodHandle> cache) {
try {
return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, writeReplace(f1));
} catch (Exception e) {
throw new RuntimeException(e);
}
}

private <A> SerializedLambda writeReplace(A f1) {
try {
Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace");
writeReplace.setAccessible(true);
return (SerializedLambda) writeReplace.invoke(f1);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}


interface F1<A, B> extends Serializable {
B apply(A a);
}

interface F0<A> extends Serializable {
A apply();
}

class LambdaHost {
public F1<Boolean, String> lambdaBackedByPrivateImplMethod() {
int local = 42;
return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString();
}

@SuppressWarnings("Convert2MethodRef")
public F1<Boolean, String> lambdaBackedByStaticImplMethod() {
return (b) -> String.valueOf(b);
}

public F1<Boolean, String> lambdaBackedByStaticMethodReference() {
return String::valueOf;
}

public F1<Boolean, String> lambdaBackedByVirtualMethodReference() {
return Object::toString;
}

public F1<I, Object> lambdaBackedByInterfaceMethodReference() {
return I::i;
}

public F0<Object> lambdaBackedByConstructorCall() {
return String::new;
}

public static MethodHandles.Lookup lookup() {
return MethodHandles.lookup();
}
}

interface I {
default String i() { return "i"; };
}

0 comments on commit b0b73c5

Please sign in to comment.