Skip to content

Commit

Permalink
Merge pull request #936 from riganti/feature/lambdas-type-inference
Browse files Browse the repository at this point in the history
Add support for basic type inference within lambda functions
  • Loading branch information
acizmarik committed May 17, 2021
2 parents 305bc98 + f4ced59 commit 1385534
Show file tree
Hide file tree
Showing 15 changed files with 683 additions and 24 deletions.
111 changes: 111 additions & 0 deletions src/DotVVM.Framework.Tests.Common/Binding/BindingCompilationTests.cs
Expand Up @@ -278,6 +278,54 @@ public void BindingCompiler_Valid_LambdaParameterType(string expr, params Type[]
Assert.AreEqual(parameterTypes[index++], paramType);
}

[TestMethod]
[DataRow("GetFirstGenericArgType(Tuple)", typeof(int))]
[DataRow("Enumerable.Where(LongArray, item => item % 2 == 0)", typeof(long))]
[DataRow("Enumerable.Select(LongArray, item => -item)", typeof(long), typeof(long))]
[DataRow("Enumerable.Select(Enumerable.Where(LongArray, item => item % 2 == 0), item => -item)", typeof(long), typeof(long))]
public void BindingCompiler_RegularGenericMethodsInference(string expr, params Type[] instantiations)
{
var viewModel = new TestViewModel() { StringProp = "abc" };
var binding = ExecuteBinding(expr, new[] { viewModel }, null, new[] { new NamespaceImport("System.Linq") });
var genericArgs = binding.GetType().GetGenericArguments();

for (var argIndex = 0; argIndex < genericArgs.Length; argIndex++)
Assert.AreEqual(instantiations[argIndex], genericArgs[argIndex]);
}

[TestMethod]
[DataRow("LongArray.Where(item => item % 2 == 0)", typeof(long))]
[DataRow("LongArray.Select(item => -item)", typeof(long), typeof(long))]
[DataRow("LongArray.Where(item => item % 2 == 0).Select(item => -item)", typeof(long), typeof(long))]
public void BindingCompiler_ExtensionGenericMethodsInference(string expr, params Type[] instantiations)
{
var viewModel = new TestViewModel() { StringProp = "abc" };
var binding = ExecuteBinding(expr, new[] { viewModel }, null, new[] { new NamespaceImport("System.Linq") });
var genericArgs = binding.GetType().GetGenericArguments();

for (var argIndex = 0; argIndex < genericArgs.Length; argIndex++)
Assert.AreEqual(instantiations[argIndex], genericArgs[argIndex]);
}

[TestMethod]
[DataRow("LongArray.All(item => item % 2 == 0)", typeof(bool))]
[DataRow("LongArray.Any(item => item % 2 == 0)", typeof(bool))]
[DataRow("LongArray.Concat(LongArray).ToArray()", typeof(long[]))]
[DataRow("LongArray.FirstOrDefault(item => item % 2 == 0)", typeof(long))]
[DataRow("LongArray.LastOrDefault(item => item % 2 == 0)", typeof(long))]
[DataRow("LongArray.Max(item => -item)", typeof(long))]
[DataRow("LongArray.Min(item => -item)", typeof(long))]
[DataRow("LongArray.OrderBy(item => item).ToArray()", typeof(long[]))]
[DataRow("LongArray.OrderByDescending(item => item).ToArray()", typeof(long[]))]
[DataRow("LongArray.Select(item => -item).ToArray()", typeof(long[]))]
[DataRow("LongArray.Where(item => item % 2 == 0).ToArray()", typeof(long[]))]
public void BindingCompiler_LinqMethodsInference(string expr, Type resultType)
{
var viewModel = new TestViewModel() { StringProp = "abc" };
var result = ExecuteBinding(expr, new[] { viewModel }, null, new[] { new NamespaceImport("System.Linq") });
Assert.AreEqual(resultType, result.GetType());
}

[TestMethod]
[DataRow("(int arg, float arg) => ;", DisplayName = "Can not use same identifier for multiple parameters")]
[DataRow("(object _this) => ;", DisplayName = "Can not use already used identifiers for parameters")]
Expand All @@ -287,6 +335,51 @@ public void BindingCompiler_Invalid_LambdaParameters(string expr)
Assert.ThrowsException<AggregateException>(() => ExecuteBinding(expr, viewModel));
}

[TestMethod]
[DataRow("(TestViewModel vm) => vm.IntProp = 11")]
[DataRow("(TestViewModel vm) => vm.GetEnum()")]
[DataRow("(TestViewModel vm) => ()")]
[DataRow("(TestViewModel vm) => ;")]
public void BindingCompiler_Valid_LambdaToAction(string expr)
{
var viewModel = new TestViewModel();
var binding = ExecuteBinding(expr, new[] { viewModel }, null, expectedType: typeof(Action<TestViewModel>)) as Action<TestViewModel>;
Assert.AreEqual(typeof(Action<TestViewModel>), binding.GetType());
binding.Invoke(viewModel);
}

[TestMethod]
[DataRow("List.RemoveAll(item => item % 2 != 0)")]
public void BindingCompiler_Valid_LambdaToPredicate(string expr)
{
var viewModel = new TestViewModel() { List = new List<int>() { 1, 2, 3 } };
var removedCount = ExecuteBinding(expr, new[] { viewModel }, null, expectedType: typeof(int));
Assert.AreEqual(2, removedCount);
CollectionAssert.AreEqual(new List<int> { 2 }, viewModel.List);
}

[TestMethod]
[DataRow("ActionInvoker(arg => StringProp = arg)")]
[DataRow("ActionInvoker(arg => StringProp = ActionInvoker(innerArg => StringProp = innerArg))")]
[DataRow("Action2Invoker((arg1, arg2) => StringProp = arg1 + arg2)")]

public void BindingCompiler_Valid_ParameterLambdaToAction(string expr)
{
var viewModel = new TestLambdaCompilation();
var result = ExecuteBinding(expr, viewModel);
Assert.AreEqual("Action", result);
}

[TestMethod]
[DataRow("DelegateInvoker(arg => StringProp = arg)")]
[DataRow("DelegateInvoker(arg => arg + arg)")]
public void BindingCompiler_Valid_LambdaParameter_PreferFunc(string expr)
{
var viewModel = new TestLambdaCompilation();
var result = ExecuteBinding(expr, viewModel);
Assert.AreEqual("Func", result);
}

[TestMethod]
public void BindingCompiler_Valid_ExtensionMethods()
{
Expand Down Expand Up @@ -738,6 +831,8 @@ class TestViewModel
public DateTime? DateTo { get; set; }
public object Time { get; set; } = TimeSpan.FromSeconds(5);
public Guid GuidProp { get; set; }
public Tuple<int, bool> Tuple { get; set; }
public List<int> List { get; set; }

public long LongProperty { get; set; }

Expand Down Expand Up @@ -765,6 +860,11 @@ public Type GetType<T>(T param)
return typeof(T);
}

public Type GetFirstGenericArgType<T, U>(Tuple<T, U> param)
{
return typeof(T);
}

public string Cat<T>(T obj, string str = null)
{
return obj.ToString() + (str ?? StringProp);
Expand Down Expand Up @@ -798,6 +898,17 @@ public async Task<string> GetStringPropAsync()
public int MethodWithOverloads(int a, int b) => a + b;
}

class TestLambdaCompilation
{
public string StringProp { get; set; }

public string DelegateInvoker(Func<string, string> func) { func(default); return "Func"; }
public string DelegateInvoker(Action<string> action) { action(default); return "Action"; }

public string ActionInvoker(Action<string> action) { action(default); return "Action"; }
public string Action2Invoker(Action<string, string> action) { action(default, default); return "Action"; }
}

class TestViewModel2
{
public int MyProperty { get; set; }
Expand Down
Expand Up @@ -26,7 +26,7 @@ public BindingExpressionBuilder(CompiledAssemblyCache compiledAssemblyCache, Ext
this.extensionMethodsCache = extensionMethodsCache;
}

public Expression Parse(string expression, DataContextStack dataContexts, BindingParserOptions options, params KeyValuePair<string, Expression>[] additionalSymbols)
public Expression Parse(string expression, DataContextStack dataContexts, BindingParserOptions options, Type expectedType = null, params KeyValuePair<string, Expression>[] additionalSymbols)
{
try
{
Expand Down Expand Up @@ -54,7 +54,7 @@ public Expression Parse(string expression, DataContextStack dataContexts, Bindin
symbols = symbols.AddSymbols(options.ExtensionParameters.Select(p => CreateParameter(dataContexts, p.Identifier, p)));
symbols = symbols.AddSymbols(additionalSymbols);

var visitor = new ExpressionBuildingVisitor(symbols, memberExpressionFactory);
var visitor = new ExpressionBuildingVisitor(symbols, memberExpressionFactory, expectedType);
visitor.Scope = symbols.Resolve(options.ScopeParameter);
return visitor.Visit(node);
}
Expand Down Expand Up @@ -143,10 +143,10 @@ public static Expression ParseWithLambdaConversion(this IBindingExpressionBuilde
extensionParameter: new TypeConversion.MagicLambdaConversionExtensionParameter(index, p.Name, p.ParameterType)))
))
.ToArray();
return builder.Parse(expression, dataContexts, options, additionalSymbols.Concat(delegateSymbols).ToArray());
return builder.Parse(expression, dataContexts, options, expectedType, additionalSymbols.Concat(delegateSymbols).ToArray());
}
else
return builder.Parse(expression, dataContexts, options, additionalSymbols);
return builder.Parse(expression, dataContexts, options, expectedType, additionalSymbols);
}
}
}
Expand Up @@ -7,8 +7,10 @@
using DotVVM.Framework.Utils;
using System.Linq;
using System.Threading.Tasks;
using DotVVM.Framework.Compilation.Inference;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;

namespace DotVVM.Framework.Compilation.Binding
{
Expand All @@ -17,16 +19,21 @@ public class ExpressionBuildingVisitor : BindingParserNodeVisitor<Expression>
public TypeRegistry Registry { get; set; }
public Expression? Scope { get; set; }
public bool ResolveOnlyTypeName { get; set; }
public Type? ExpectedType { get; set; }
public ImmutableDictionary<string, ParameterExpression> Variables { get; set; } =
ImmutableDictionary<string, ParameterExpression>.Empty;

private TypeInferer inferer;
private int expressionDepth;
private List<Exception>? currentErrors;
private readonly MemberExpressionFactory memberExpressionFactory;

public ExpressionBuildingVisitor(TypeRegistry registry, MemberExpressionFactory memberExpressionFactory)
public ExpressionBuildingVisitor(TypeRegistry registry, MemberExpressionFactory memberExpressionFactory, Type? expectedType = null)
{
Registry = registry;
ExpectedType = expectedType;
this.memberExpressionFactory = memberExpressionFactory;
this.inferer = new TypeInferer();
}

[return: MaybeNull]
Expand Down Expand Up @@ -90,13 +97,15 @@ public override Expression Visit(BindingParserNode node)
var errors = currentErrors;
try
{
expressionDepth++;
ThrowIfNotTypeNameRelevant(node);
return base.Visit(node);
}
finally
{
currentErrors = errors;
Registry = regBackup;
expressionDepth--;
}
}

Expand Down Expand Up @@ -233,10 +242,31 @@ protected override Expression VisitFunctionCall(FunctionCallBindingParserNode no
{
var target = HandleErrors(node.TargetExpression, Visit);
var args = new Expression[node.ArgumentExpressions.Count];
for (int i = 0; i < args.Length; i++)

inferer.BeginFunctionCall(target as MethodGroupExpression, args.Length);

var lambdaNodeIndices = new List<int>();
// Initially process all nodes that are not lambdas
for (var i = 0; i < args.Length; i++)
{
if (node.ArgumentExpressions[i] is LambdaBindingParserNode)
{
lambdaNodeIndices.Add(i);
continue;
}

args[i] = HandleErrors(node.ArgumentExpressions[i], Visit)!;
inferer.SetArgument(args[i], i);
}
// Subsequently process all lambdas
foreach (var index in lambdaNodeIndices)
{
inferer.SetProbedArgumentIndex(index);
args[index] = HandleErrors(node.ArgumentExpressions[index], Visit)!;
inferer.SetArgument(args[index], index);
}

inferer.EndFunctionCall();
ThrowOnErrors();

return memberExpressionFactory.Call(target, args);
Expand Down Expand Up @@ -315,8 +345,21 @@ protected override Expression VisitLambda(LambdaBindingParserNode node)
{
// Create lambda definition
var lambdaParameters = new ParameterExpression[node.ParameterExpressions.Count];

// Apply information from type inference if available
var hintType = (expressionDepth == 1) ? ExpectedType : null;
var typeInferenceData = inferer.Infer(hintType).Lambda(node.ParameterExpressions.Count);
if (typeInferenceData.Result)
{
for (var paramIndex = 0; paramIndex < typeInferenceData.Parameters!.Length; paramIndex++)
{
var currentParamType = typeInferenceData.Parameters[paramIndex];
node.ParameterExpressions[paramIndex].SetResolvedType(currentParamType);
}
}

for (var i = 0; i < lambdaParameters.Length; i++)
lambdaParameters[i] = (ParameterExpression)HandleErrors(node.ParameterExpressions[i], Visit);
lambdaParameters[i] = (ParameterExpression)HandleErrors(node.ParameterExpressions[i], Visit)!;

// Make sure that parameter identifiers are distinct
if (lambdaParameters.GroupBy(param => param.Name).Any(group => group.Count() > 1))
Expand All @@ -337,16 +380,49 @@ protected override Expression VisitLambda(LambdaBindingParserNode node)
var body = Visit(node.BodyExpression);

ThrowOnErrors();
return Expression.Lambda(body, lambdaParameters);
return CreateLambdaExpression(body, lambdaParameters, typeInferenceData.Type);
}

protected override Expression VisitLambdaParameter(LambdaParameterBindingParserNode node)
{
if (node.Type == null)
if (node.Type == null && node.ResolvedType == null)
throw new BindingCompilationException($"Could not infer type of parameter.", node);

var parameterType = Visit(node.Type).Type;
return Expression.Parameter(parameterType, node.Name.ToDisplayString());
if (node.ResolvedType != null)
{
// Type was not specified but was infered
return Expression.Parameter(node.ResolvedType, node.Name.ToDisplayString());
}
else
{
// Type was specified and needs to be obtained from binding node
var parameterType = Visit(node.Type).Type;
return Expression.Parameter(parameterType, node.Name.ToDisplayString());
}
}

private Expression CreateLambdaExpression(Expression body, ParameterExpression[] parameters, Type? delegateType)
{
if (delegateType != null && delegateType.Namespace == "System")
{
if (delegateType.Name == "Action" || delegateType.Name == $"Action`{parameters.Length}")
{
// We must validate that lambda body contains a valid statement
if ((body.NodeType != ExpressionType.Default) && (body.NodeType != ExpressionType.Block) && (body.NodeType != ExpressionType.Call) && (body.NodeType != ExpressionType.Assign))
throw new DotvvmCompilationException($"Only method invocations and assignments can be used as statements.");

// Make sure the result type will be void by adding an empty expression
return Expression.Lambda(Expression.Block(body, Expression.Empty()), parameters);
}
else if (delegateType.Name == "Predicate`1")
{
var type = delegateType.GetGenericTypeDefinition().MakeGenericType(parameters.Single().Type);
return Expression.Lambda(type, body, parameters);
}
}

// Assume delegate is a System.Func<...>
return Expression.Lambda(body, parameters);
}

protected override Expression VisitBlock(BlockBindingParserNode node)
Expand Down
Expand Up @@ -50,11 +50,13 @@ public Expression GetMember(Expression target, string name, Type[] typeArguments
.Where(m => ((isGeneric && m is TypeInfo) ? genericName : name) == m.Name)
.ToArray();

var isExtension = false;
if (members.Length == 0)
{
// We did not find any match in regular methods => try extension methods
var extensions = GetAllExtensionMethods().Where(m => m.Name == name).ToArray();
var extensions = GetAllExtensionMethods().Where(m => m.Name == name && ExtensionMethodsFilter(target, m)).ToArray();
members = extensions;
isExtension = true;

if (members.Length == 0 && throwExceptions)
throw new Exception($"Could not find { (isStatic ? "static" : "instance") } member { name } on type { type.FullName }.");
Expand Down Expand Up @@ -84,7 +86,29 @@ public Expression GetMember(Expression target, string name, Type[] typeArguments
: new StaticClassIdentifierExpression(nonGenericType.UnderlyingSystemType);
}
}
return new MethodGroupExpression() { MethodName = name, Target = target, TypeArgs = typeArguments };

var candidates = members.Cast<MethodInfo>().ToList();
return new MethodGroupExpression() { MethodName = name, Target = target, TypeArgs = typeArguments, Candidates = candidates, HasExtensionCandidates = isExtension };
}

private bool ExtensionMethodsFilter(Expression target, MethodInfo method)
{
var thisType = method.GetParameters().First().ParameterType;
if (thisType.IsGenericType)
{
if (thisType.ContainsGenericParameters)
{
return ReflectionUtils.IsAssignableToGenericType(target.Type, thisType.GetGenericTypeDefinition(), out _);
}
else
{
return thisType.IsAssignableFrom(target.Type);
}
}
else
{
return thisType.IsAssignableFrom(target.Type);
}
}

private Expression GetDotvvmPropertyMember(Expression target, string name)
Expand Down
Expand Up @@ -18,6 +18,8 @@ public class MethodGroupExpression : Expression
public Expression Target { get; set; }
public string MethodName { get; set; }
public Type[] TypeArgs { get; set; }
public List<MethodInfo> Candidates { get; set; }
public bool HasExtensionCandidates { get; set; }
public bool IsStatic => Target is StaticClassIdentifierExpression;

private static MethodInfo CreateDelegateMethodInfo = typeof(Delegate).GetMethod("CreateDelegate", new[] { typeof(Type), typeof(object), typeof(MethodInfo) });
Expand Down

0 comments on commit 1385534

Please sign in to comment.