Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for basic type inference within lambda functions #936

Merged
merged 23 commits into from May 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f914f2c
Added basic support for type inference
acizmarik Feb 4, 2021
fb230f2
Fixed comparing instatiated with open-generic types
acizmarik Feb 5, 2021
8e95140
Added more test for Inferer
acizmarik Feb 5, 2021
9f2f219
Fixed issue with accessing parameters out of range when probing candi…
acizmarik Feb 9, 2021
48bda1a
Adjusted sample to showcase type inference
acizmarik Feb 9, 2021
b99d542
Merged changes from main
acizmarik Mar 23, 2021
b79d913
Inferer improvements, out-of-order arguments resolving support
acizmarik Mar 26, 2021
26cfda9
Added more tests for type inference
acizmarik Mar 26, 2021
b4b2656
Merge branch 'main' into feature/lambdas-type-inference
acizmarik Apr 12, 2021
a7f6f34
Fixed issue when compiling lambdas to Actions
acizmarik Apr 14, 2021
e21521b
Added lambda to Action tests
acizmarik Apr 14, 2021
bfd5789
Type inferer improvements and fixes for System.Action
acizmarik Apr 19, 2021
f818585
Removed unsupported test cases
acizmarik Apr 19, 2021
8aaa9b3
Refactored changes
acizmarik Apr 20, 2021
429023a
Extension methods candidates are now filtered based on the `this` exp…
acizmarik Apr 20, 2021
124ca49
Added disambiguation strategy for lambdas that differ in return types
acizmarik Apr 20, 2021
a6737b4
Added more tests
acizmarik Apr 20, 2021
bf9034e
Fixed issue with resolving generics
acizmarik May 13, 2021
59d8d65
Removed unnecessary access modifiers
acizmarik May 13, 2021
ee14b14
Added support for System.Predicate
acizmarik May 13, 2021
b37a79a
Merged remote changes
acizmarik May 13, 2021
ff68763
Removed obsolete code
acizmarik May 14, 2021
f4ced59
Added check for delegate namespace and missing variant with backtick
acizmarik May 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
111 changes: 111 additions & 0 deletions src/DotVVM.Framework.Tests.Common/Binding/BindingCompilationTests.cs
Expand Up @@ -196,6 +196,54 @@ public void BindingCompiler_Valid_LambdaParameterType(string expr, params Type[]
Assert.AreEqual(parameterTypes[index++], paramType);
}

[TestMethod]
[DataRow("GetFirstGenericArgType(Tuple)", typeof(int))]
[DataRow("Enumerable.Where(LongArray, item => item % 2 == 0)", typeof(long))]
[DataRow("Enumerable.Select(LongArray, item => -item)", typeof(long), typeof(long))]
[DataRow("Enumerable.Select(Enumerable.Where(LongArray, item => item % 2 == 0), item => -item)", typeof(long), typeof(long))]
public void BindingCompiler_RegularGenericMethodsInference(string expr, params Type[] instantiations)
{
var viewModel = new TestViewModel() { StringProp = "abc" };
var binding = ExecuteBinding(expr, new[] { viewModel }, null, new[] { new NamespaceImport("System.Linq") });
var genericArgs = binding.GetType().GetGenericArguments();

for (var argIndex = 0; argIndex < genericArgs.Length; argIndex++)
Assert.AreEqual(instantiations[argIndex], genericArgs[argIndex]);
}

[TestMethod]
[DataRow("LongArray.Where(item => item % 2 == 0)", typeof(long))]
[DataRow("LongArray.Select(item => -item)", typeof(long), typeof(long))]
[DataRow("LongArray.Where(item => item % 2 == 0).Select(item => -item)", typeof(long), typeof(long))]
public void BindingCompiler_ExtensionGenericMethodsInference(string expr, params Type[] instantiations)
{
var viewModel = new TestViewModel() { StringProp = "abc" };
var binding = ExecuteBinding(expr, new[] { viewModel }, null, new[] { new NamespaceImport("System.Linq") });
var genericArgs = binding.GetType().GetGenericArguments();

for (var argIndex = 0; argIndex < genericArgs.Length; argIndex++)
Assert.AreEqual(instantiations[argIndex], genericArgs[argIndex]);
}

[TestMethod]
[DataRow("LongArray.All(item => item % 2 == 0)", typeof(bool))]
[DataRow("LongArray.Any(item => item % 2 == 0)", typeof(bool))]
[DataRow("LongArray.Concat(LongArray).ToArray()", typeof(long[]))]
[DataRow("LongArray.FirstOrDefault(item => item % 2 == 0)", typeof(long))]
[DataRow("LongArray.LastOrDefault(item => item % 2 == 0)", typeof(long))]
[DataRow("LongArray.Max(item => -item)", typeof(long))]
[DataRow("LongArray.Min(item => -item)", typeof(long))]
[DataRow("LongArray.OrderBy(item => item).ToArray()", typeof(long[]))]
[DataRow("LongArray.OrderByDescending(item => item).ToArray()", typeof(long[]))]
[DataRow("LongArray.Select(item => -item).ToArray()", typeof(long[]))]
[DataRow("LongArray.Where(item => item % 2 == 0).ToArray()", typeof(long[]))]
public void BindingCompiler_LinqMethodsInference(string expr, Type resultType)
{
var viewModel = new TestViewModel() { StringProp = "abc" };
var result = ExecuteBinding(expr, new[] { viewModel }, null, new[] { new NamespaceImport("System.Linq") });
Assert.AreEqual(resultType, result.GetType());
}

[TestMethod]
[DataRow("(int arg, float arg) => ;", DisplayName = "Can not use same identifier for multiple parameters")]
[DataRow("(object _this) => ;", DisplayName = "Can not use already used identifiers for parameters")]
Expand All @@ -205,6 +253,51 @@ public void BindingCompiler_Invalid_LambdaParameters(string expr)
Assert.ThrowsException<AggregateException>(() => ExecuteBinding(expr, viewModel));
}

[TestMethod]
[DataRow("(TestViewModel vm) => vm.IntProp = 11")]
[DataRow("(TestViewModel vm) => vm.GetEnum()")]
[DataRow("(TestViewModel vm) => ()")]
[DataRow("(TestViewModel vm) => ;")]
public void BindingCompiler_Valid_LambdaToAction(string expr)
{
var viewModel = new TestViewModel();
var binding = ExecuteBinding(expr, new[] { viewModel }, null, expectedType: typeof(Action<TestViewModel>)) as Action<TestViewModel>;
Assert.AreEqual(typeof(Action<TestViewModel>), binding.GetType());
binding.Invoke(viewModel);
}

[TestMethod]
[DataRow("List.RemoveAll(item => item % 2 != 0)")]
public void BindingCompiler_Valid_LambdaToPredicate(string expr)
{
var viewModel = new TestViewModel() { List = new List<int>() { 1, 2, 3 } };
var removedCount = ExecuteBinding(expr, new[] { viewModel }, null, expectedType: typeof(int));
Assert.AreEqual(2, removedCount);
CollectionAssert.AreEqual(new List<int> { 2 }, viewModel.List);
}

[TestMethod]
[DataRow("ActionInvoker(arg => StringProp = arg)")]
[DataRow("ActionInvoker(arg => StringProp = ActionInvoker(innerArg => StringProp = innerArg))")]
[DataRow("Action2Invoker((arg1, arg2) => StringProp = arg1 + arg2)")]

public void BindingCompiler_Valid_ParameterLambdaToAction(string expr)
{
var viewModel = new TestLambdaCompilation();
var result = ExecuteBinding(expr, viewModel);
Assert.AreEqual("Action", result);
}

[TestMethod]
[DataRow("DelegateInvoker(arg => StringProp = arg)")]
[DataRow("DelegateInvoker(arg => arg + arg)")]
public void BindingCompiler_Valid_LambdaParameter_PreferFunc(string expr)
{
var viewModel = new TestLambdaCompilation();
var result = ExecuteBinding(expr, viewModel);
Assert.AreEqual("Func", result);
}

[TestMethod]
public void BindingCompiler_Valid_ExtensionMethods()
{
Expand Down Expand Up @@ -655,6 +748,8 @@ class TestViewModel
public DateTime? DateTo { get; set; }
public object Time { get; set; } = TimeSpan.FromSeconds(5);
public Guid GuidProp { get; set; }
public Tuple<int, bool> Tuple { get; set; }
public List<int> List { get; set; }

public long LongProperty { get; set; }

Expand All @@ -681,6 +776,11 @@ public Type GetType<T>(T param)
return typeof(T);
}

public Type GetFirstGenericArgType<T, U>(Tuple<T, U> param)
{
return typeof(T);
}

public string Cat<T>(T obj, string str = null)
{
return obj.ToString() + (str ?? StringProp);
Expand Down Expand Up @@ -714,6 +814,17 @@ public async Task<string> GetStringPropAsync()
public int MethodWithOverloads(int a, int b) => a + b;
}

class TestLambdaCompilation
{
public string StringProp { get; set; }

public string DelegateInvoker(Func<string, string> func) { func(default); return "Func"; }
public string DelegateInvoker(Action<string> action) { action(default); return "Action"; }

public string ActionInvoker(Action<string> action) { action(default); return "Action"; }
public string Action2Invoker(Action<string, string> action) { action(default, default); return "Action"; }
}

class TestViewModel2
{
public int MyProperty { get; set; }
Expand Down
Expand Up @@ -26,7 +26,7 @@ public BindingExpressionBuilder(CompiledAssemblyCache compiledAssemblyCache, Ext
this.extensionMethodsCache = extensionMethodsCache;
}

public Expression Parse(string expression, DataContextStack dataContexts, BindingParserOptions options, params KeyValuePair<string, Expression>[] additionalSymbols)
public Expression Parse(string expression, DataContextStack dataContexts, BindingParserOptions options, Type expectedType = null, params KeyValuePair<string, Expression>[] additionalSymbols)
{
try
{
Expand Down Expand Up @@ -54,7 +54,7 @@ public Expression Parse(string expression, DataContextStack dataContexts, Bindin
symbols = symbols.AddSymbols(options.ExtensionParameters.Select(p => CreateParameter(dataContexts, p.Identifier, p)));
symbols = symbols.AddSymbols(additionalSymbols);

var visitor = new ExpressionBuildingVisitor(symbols, memberExpressionFactory);
var visitor = new ExpressionBuildingVisitor(symbols, memberExpressionFactory, expectedType);
visitor.Scope = symbols.Resolve(options.ScopeParameter);
return visitor.Visit(node);
}
Expand Down Expand Up @@ -143,10 +143,10 @@ public static Expression ParseWithLambdaConversion(this IBindingExpressionBuilde
extensionParameter: new TypeConversion.MagicLambdaConversionExtensionParameter(index, p.Name, p.ParameterType)))
))
.ToArray();
return builder.Parse(expression, dataContexts, options, additionalSymbols.Concat(delegateSymbols).ToArray());
return builder.Parse(expression, dataContexts, options, expectedType, additionalSymbols.Concat(delegateSymbols).ToArray());
}
else
return builder.Parse(expression, dataContexts, options, additionalSymbols);
return builder.Parse(expression, dataContexts, options, expectedType, additionalSymbols);
}
}
}
Expand Up @@ -7,8 +7,10 @@
using DotVVM.Framework.Utils;
using System.Linq;
using System.Threading.Tasks;
using DotVVM.Framework.Compilation.Inference;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;

namespace DotVVM.Framework.Compilation.Binding
{
Expand All @@ -17,16 +19,21 @@ public class ExpressionBuildingVisitor : BindingParserNodeVisitor<Expression>
public TypeRegistry Registry { get; set; }
public Expression? Scope { get; set; }
public bool ResolveOnlyTypeName { get; set; }
public Type? ExpectedType { get; set; }
public ImmutableDictionary<string, ParameterExpression> Variables { get; set; } =
ImmutableDictionary<string, ParameterExpression>.Empty;

private TypeInferer inferer;
private int expressionDepth;
private List<Exception>? currentErrors;
private readonly MemberExpressionFactory memberExpressionFactory;

public ExpressionBuildingVisitor(TypeRegistry registry, MemberExpressionFactory memberExpressionFactory)
public ExpressionBuildingVisitor(TypeRegistry registry, MemberExpressionFactory memberExpressionFactory, Type? expectedType = null)
{
Registry = registry;
ExpectedType = expectedType;
this.memberExpressionFactory = memberExpressionFactory;
this.inferer = new TypeInferer();
}

[return: MaybeNull]
Expand Down Expand Up @@ -90,13 +97,15 @@ public override Expression Visit(BindingParserNode node)
var errors = currentErrors;
try
{
expressionDepth++;
ThrowIfNotTypeNameRelevant(node);
return base.Visit(node);
}
finally
{
currentErrors = errors;
Registry = regBackup;
expressionDepth--;
}
}

Expand Down Expand Up @@ -213,10 +222,31 @@ protected override Expression VisitFunctionCall(FunctionCallBindingParserNode no
{
var target = HandleErrors(node.TargetExpression, Visit);
var args = new Expression[node.ArgumentExpressions.Count];
for (int i = 0; i < args.Length; i++)

inferer.BeginFunctionCall(target as MethodGroupExpression, args.Length);

var lambdaNodeIndices = new List<int>();
// Initially process all nodes that are not lambdas
for (var i = 0; i < args.Length; i++)
{
if (node.ArgumentExpressions[i] is LambdaBindingParserNode)
{
lambdaNodeIndices.Add(i);
continue;
}

args[i] = HandleErrors(node.ArgumentExpressions[i], Visit)!;
inferer.SetArgument(args[i], i);
}
// Subsequently process all lambdas
foreach (var index in lambdaNodeIndices)
{
inferer.SetProbedArgumentIndex(index);
args[index] = HandleErrors(node.ArgumentExpressions[index], Visit)!;
inferer.SetArgument(args[index], index);
}

inferer.EndFunctionCall();
ThrowOnErrors();

return memberExpressionFactory.Call(target, args);
Expand Down Expand Up @@ -295,8 +325,21 @@ protected override Expression VisitLambda(LambdaBindingParserNode node)
{
// Create lambda definition
var lambdaParameters = new ParameterExpression[node.ParameterExpressions.Count];

// Apply information from type inference if available
var hintType = (expressionDepth == 1) ? ExpectedType : null;
var typeInferenceData = inferer.Infer(hintType).Lambda(node.ParameterExpressions.Count);
if (typeInferenceData.Result)
{
for (var paramIndex = 0; paramIndex < typeInferenceData.Parameters!.Length; paramIndex++)
{
var currentParamType = typeInferenceData.Parameters[paramIndex];
node.ParameterExpressions[paramIndex].SetResolvedType(currentParamType);
}
}

for (var i = 0; i < lambdaParameters.Length; i++)
lambdaParameters[i] = (ParameterExpression)HandleErrors(node.ParameterExpressions[i], Visit);
lambdaParameters[i] = (ParameterExpression)HandleErrors(node.ParameterExpressions[i], Visit)!;

// Make sure that parameter identifiers are distinct
if (lambdaParameters.GroupBy(param => param.Name).Any(group => group.Count() > 1))
Expand All @@ -317,16 +360,49 @@ protected override Expression VisitLambda(LambdaBindingParserNode node)
var body = Visit(node.BodyExpression);

ThrowOnErrors();
return Expression.Lambda(body, lambdaParameters);
return CreateLambdaExpression(body, lambdaParameters, typeInferenceData.Type);
}

protected override Expression VisitLambdaParameter(LambdaParameterBindingParserNode node)
{
if (node.Type == null)
if (node.Type == null && node.ResolvedType == null)
throw new BindingCompilationException($"Could not infer type of parameter.", node);

var parameterType = Visit(node.Type).Type;
return Expression.Parameter(parameterType, node.Name.ToDisplayString());
if (node.ResolvedType != null)
{
// Type was not specified but was infered
return Expression.Parameter(node.ResolvedType, node.Name.ToDisplayString());
}
else
{
// Type was specified and needs to be obtained from binding node
var parameterType = Visit(node.Type).Type;
return Expression.Parameter(parameterType, node.Name.ToDisplayString());
}
}

private Expression CreateLambdaExpression(Expression body, ParameterExpression[] parameters, Type? delegateType)
{
if (delegateType != null && delegateType.Namespace == "System")
{
if (delegateType.Name == "Action" || delegateType.Name == $"Action`{parameters.Length}")
{
// We must validate that lambda body contains a valid statement
if ((body.NodeType != ExpressionType.Default) && (body.NodeType != ExpressionType.Block) && (body.NodeType != ExpressionType.Call) && (body.NodeType != ExpressionType.Assign))
throw new DotvvmCompilationException($"Only method invocations and assignments can be used as statements.");

// Make sure the result type will be void by adding an empty expression
return Expression.Lambda(Expression.Block(body, Expression.Empty()), parameters);
}
else if (delegateType.Name == "Predicate`1")
{
var type = delegateType.GetGenericTypeDefinition().MakeGenericType(parameters.Single().Type);
return Expression.Lambda(type, body, parameters);
}
}

// Assume delegate is a System.Func<...>
return Expression.Lambda(body, parameters);
}

protected override Expression VisitBlock(BlockBindingParserNode node)
Expand Down
Expand Up @@ -50,11 +50,13 @@ public Expression GetMember(Expression target, string name, Type[] typeArguments
.Where(m => ((isGeneric && m is TypeInfo) ? genericName : name) == m.Name)
.ToArray();

var isExtension = false;
if (members.Length == 0)
{
// We did not find any match in regular methods => try extension methods
var extensions = GetAllExtensionMethods().Where(m => m.Name == name).ToArray();
var extensions = GetAllExtensionMethods().Where(m => m.Name == name && ExtensionMethodsFilter(target, m)).ToArray();
members = extensions;
isExtension = true;

if (members.Length == 0 && throwExceptions)
throw new Exception($"Could not find { (isStatic ? "static" : "instance") } member { name } on type { type.FullName }.");
Expand Down Expand Up @@ -84,7 +86,29 @@ public Expression GetMember(Expression target, string name, Type[] typeArguments
: new StaticClassIdentifierExpression(nonGenericType.UnderlyingSystemType);
}
}
return new MethodGroupExpression() { MethodName = name, Target = target, TypeArgs = typeArguments };

var candidates = members.Cast<MethodInfo>().ToList();
return new MethodGroupExpression() { MethodName = name, Target = target, TypeArgs = typeArguments, Candidates = candidates, HasExtensionCandidates = isExtension };
}

private bool ExtensionMethodsFilter(Expression target, MethodInfo method)
{
var thisType = method.GetParameters().First().ParameterType;
if (thisType.IsGenericType)
{
if (thisType.ContainsGenericParameters)
{
return ReflectionUtils.IsAssignableToGenericType(target.Type, thisType.GetGenericTypeDefinition(), out _);
}
else
{
return thisType.IsAssignableFrom(target.Type);
}
}
else
{
return thisType.IsAssignableFrom(target.Type);
}
}

private Expression GetDotvvmPropertyMember(Expression target, string name)
Expand Down
Expand Up @@ -18,6 +18,8 @@ public class MethodGroupExpression : Expression
public Expression Target { get; set; }
public string MethodName { get; set; }
public Type[] TypeArgs { get; set; }
public List<MethodInfo> Candidates { get; set; }
public bool HasExtensionCandidates { get; set; }
public bool IsStatic => Target is StaticClassIdentifierExpression;

private static MethodInfo CreateDelegateMethodInfo = typeof(Delegate).GetMethod("CreateDelegate", new[] { typeof(Type), typeof(object), typeof(MethodInfo) });
Expand Down