Skip to content

Commit

Permalink
Merge pull request #817 from riganti/null-propagation-args
Browse files Browse the repository at this point in the history
Null propagation now passes nulls into methods instead of skipping execution
  • Loading branch information
quigamdev committed Feb 19, 2021
2 parents aab5747 + e498916 commit 1bac555
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 83 deletions.
Expand Up @@ -629,6 +629,7 @@ class TestViewModel

public long[] LongArray => new long[] { 1, 2, long.MaxValue };
public TestViewModel2[] VmArray => new TestViewModel2[] { new TestViewModel2() };
public int[] IntArray { get; set; }

public string SetStringProp(string a, int b)
{
Expand Down
109 changes: 103 additions & 6 deletions src/DotVVM.Framework.Tests.Common/Binding/NullPropagationTests.cs
Expand Up @@ -21,17 +21,38 @@ public class NullPropagationTests
private LambdaExpression[] ExpressionFragments = new LambdaExpression[] {
Create((TestViewModel t) => t.EnumProperty - 1),
Create((TestViewModel t) => t.StringProp),
Create((TestViewModel t) => t.LongArray),
Create((TestViewModel t) => t.LongProperty),
Create((TestViewModel t) => t.VmArray),
Create((TestViewModel t) => t.VmArray[0]),
Create((TestViewModel2[] t, int b) => t[b & 0]),
Create((long[] t) => t.Length),
Create((long[] t) => t[0]),
Create((TestViewModel2 t) => t.Enum),
Create((TestEnum t) => (int)t),
Create((TestEnum t) => (TestEnum?)t),
Create((TestViewModel a, TestViewModel b) => a.BoolMethodExecuted ? a : b),
Create((int a, TestViewModel b, TestViewModel c) => a == 0 ? b : c),
Create((string b) => b ?? "NULL STRING"),
Create((TestViewModel a, int b) => a.Identity(b)),
Create((TestViewModel a, DateTime b) => a.Identity<DateTime?>(b)),
Create((TestViewModel a, string b) => a.Identity(b)),
Create((TestViewModel a, TestViewModel b) => a.Identity(b)),
Create((TestViewModel a, char b) => a.GetCharCode(b)),
Create((string a) => a.Length > 0 ? a[0] : 'f'),
Create((string a, int b) => a.Length > b ? a[b] : 'l'),
// Create((string a, int b) => a[b]),
Create((int a, int b) => a+b),
Create((int a) => a+ 1.0),
Create((int a) => a + 1.0),
Create((long a) => (int)a),
Create((int a, double b) => a * b),
Create((TimeSpan span) => span.TotalMilliseconds),
Create((TestViewModel vm) => vm.DateFrom), // DateTime?
Create((DateTime? vm) => vm.HasValue),
Create((DateTime? vm) => vm.Value),
Create((DateTime? vm) => vm.GetValueOrDefault()),
Create((DateTime d) => d.Day),
Create((DateTime d) => d.ToString()),
Create((double a) => TimeSpan.FromSeconds(a)),
Create((TestViewModel vm) => (TimeSpan)vm.Time),
Create((TestViewModel vm, TimeSpan time, int integer, double number) => new TestViewModel{ EnumProperty = TestEnum.B, BoolMethodExecuted = vm.BoolMethodExecuted, StringProp = time + ": " + vm.StringProp, StringProp2 = integer + vm.StringProp2 + number, TestViewModel2 = vm.TestViewModel2}),
Expand Down Expand Up @@ -96,7 +117,7 @@ void AddFragment(LambdaExpression fragment, int maxCount)
while (unusedFragments.Any())
{
var possibleOnes = unusedFragments.Where(f => f.Parameters.All(p => typeSources.ContainsKey(p.Type))).ToArray();
Assert.IsFalse(possibleOnes.Length == 0);
Assert.IsFalse(possibleOnes.Length == 0, $"Can not continue from {string.Join(", ", typeSources.Select(t => t.Key.Name))} to {string.Join(", ", unusedFragments.AsEnumerable())}");
foreach (var fragment in possibleOnes)
{
AddFragment(fragment, 4);
Expand All @@ -122,20 +143,38 @@ private void TestExpression(Random rnd, Expression expression, ParameterExpressi
Func<TestViewModel[], object> compile(Expression e) =>
Expression.Lambda<Func<TestViewModel[], object>>(Expression.Convert(e, typeof(object)), parameter).Compile();

var withNullCheks = compile(ExpressionNullPropagationVisitor.PropagateNulls(expr, _ => true));
var withNullChecks = compile(ExpressionNullPropagationVisitor.PropagateNulls(expr, _ => true));
var withoutNullChecks = compile(expr);

var args = Enumerable.Repeat(new TestViewModel { StringProp = "ll", StringProp2 = "pp", TestViewModel2 = new TestViewModel2() }, count).ToArray();
var args = Enumerable.Repeat(new TestViewModel { StringProp = "ll", StringProp2 = "pp", TestViewModel2 = new TestViewModel2(), DateFrom = DateTime.Parse("2020-03-29") }, count).ToArray();
var settings = DefaultSerializerSettingsProvider.Instance.Settings;
Assert.AreEqual(JsonConvert.SerializeObject(withNullCheks(args), settings), JsonConvert.SerializeObject(withoutNullChecks(args), settings));
Assert.AreEqual(JsonConvert.SerializeObject(withNullChecks(args), settings), JsonConvert.SerializeObject(withoutNullChecks(args), settings));

foreach (var i in Enumerable.Range(0, args.Length).Shuffle(rnd))
{
args[i] = null;
withNullCheks(args);
try
{
withNullChecks(args);
}
catch (NullReferenceException ex)
{
// we might get exceptions from B(null), ...
// none of the methods we call actually throws, so it always has to be ours friendly NRE
Assert.IsTrue(ex.Message.StartsWith("Binding expression"));
// and it must only occur for value types
Assert.IsTrue(new [] { "System.Int", "System.TimeSpan", "System.Char", "System.Double" }.Any(ex.Message.Contains));
}
}
}

private object EvalExpression<T>(Expression<Func<T, object>> a, T val)
{
var nullChecked = ExpressionNullPropagationVisitor.PropagateNulls(a.Body, _ => true);
var d = a.Update(body: nullChecked, a.Parameters).Compile();
return d(val);
}

class ReplacerVisitor : ExpressionVisitor
{
private readonly ParameterExpression originalParameter;
Expand Down Expand Up @@ -195,6 +234,64 @@ public void BindingNullPropagation_1()
TestExpression(random, expr, viewModel);
}
}

private static T Identity<T>(T x) => x;

[TestMethod]
public void MethodArgument_Null_ValueType()
{
var e = Assert.ThrowsException<NullReferenceException>(() =>
EvalExpression<TestViewModel>(v => Identity(v.LongProperty), null));
Assert.AreEqual("Binding expression 'v.LongProperty' of type 'System.Int64' has evaluated to null.", e.Message);
}

[TestMethod]
public void MethodArgument_Null_NullableType()
{
Assert.IsNull(EvalExpression<TestViewModel>(v => Identity(v.DateFrom), null));
Assert.IsNull(EvalExpression<TestViewModel>(v => Identity<long?>(v.LongProperty), null));
Assert.IsNull(EvalExpression<TestViewModel>(v => Identity<DateTime?>(v.DateFrom.Value), null));
}

[TestMethod]
public void MethodArgument_Null_RefType()
{
Assert.IsNull(EvalExpression<TestViewModel>(v => Identity(v), null));
Assert.IsNull(EvalExpression<TestViewModel>(v => Identity(v.LongArray), null));
}

[TestMethod]
public void Operator()
{
Assert.IsNull(EvalExpression<TestViewModel>(v => v.LongArray[0] + 1, null));
Assert.IsNull(EvalExpression<TestViewModel>(v => v.LongArray[0] + v.TestViewModel2.MyProperty, new TestViewModel()));
Assert.IsNull(EvalExpression<TestViewModel>(v => v.TestViewModel2B.ChildObject.SomeString.Length + v.TestViewModel2.MyProperty, new TestViewModel()));
Assert.AreEqual(2L, EvalExpression<TestViewModel>(v => v.LongArray[0] + 1, new TestViewModel()));
}

[TestMethod]
public void StringConcat()
{
Assert.AreEqual("abc", EvalExpression<TestViewModel>(v => v.StringProp + "abc", null));
}

[TestMethod]
public void Indexer()
{
Assert.IsNull(EvalExpression<int[]>(v => v[0] + 1, null));
Assert.IsNull(EvalExpression<TestViewModel>(v => v.IntArray[0] + 1, null));
Assert.IsNull(EvalExpression<TestViewModel>(v => v.IntArray[0] + 1, new TestViewModel { IntArray = null }));
Assert.IsNull(EvalExpression<TestViewModel>(v => v.TestViewModel2.Collection[0].StringValue.Length + 5, new TestViewModel { IntArray = null }));
Assert.IsNull(EvalExpression<TestViewModel>(v => v.TestViewModel2.Collection[0].StringValue.Length + 5, new TestViewModel { IntArray = null }));
}

[TestMethod]
public void Coalesce()
{
Assert.AreEqual(1, EvalExpression<object>(v => v ?? 1, null));
Assert.AreEqual(1, EvalExpression<object>(v => (v ?? 1) ?? 2, null));
Assert.AreEqual(1, EvalExpression<int?>(v => v ?? null, 1));
}
}

public static class EnumerableExtensions
Expand Down
Expand Up @@ -21,41 +21,11 @@ protected override Expression VisitMember(MemberExpression node)
{
if (node.Expression?.Type?.IsNullable() == true)
{
if (node.Member.Name == "Value") return node.Expression;
if (node.Member.Name == "Value") return Visit(node.Expression);
else return base.VisitMember(node);
}
else return CheckForNull(Visit(node.Expression), expr =>
Expression.MakeMemberAccess(expr, node.Member));

//var left = Visit(node.Expression);
//var nullableType = Nullable.GetUnderlyingType(left.Type);
//if ((!left.Type.GetTypeInfo().IsValueType || nullableType != null) && (CanBeNull == null || CanBeNull(left)))
//{
// var wrapNullable = node.Type.GetTypeInfo().IsValueType;
// var restype = wrapNullable ? typeof(Nullable<>).MakeGenericType(node.Type) : node.Type;
// var variable = Expression.Parameter(left.Type);
// Expression body;
// Expression condition;
// if (nullableType != null)
// {
// condition = Expression.Property(variable, "HasValue");
// body = node.Update(Expression.Convert(variable, node.Expression.Type));
// }
// else
// {
// condition = Expression.NotEqual(variable, Expression.Constant(null, variable.Type));
// body = node.Update(variable);
// }
// if (wrapNullable)
// {
// body = Expression.Convert(body, restype);
// }
// return Expression.Block(new[] { variable },
// Expression.Assign(variable, left),
// Expression.IfThenElse(condition, body, Expression.Default(restype))
// );
//}
//else return base.VisitMember(node);
}

protected override Expression VisitLambda<T>(Expression<T> expression)
Expand Down Expand Up @@ -96,26 +66,23 @@ Expression createExpr(Expression left)
// this should only be ParameterExpression
else return createExpr(node.Left);
}
else if (node.NodeType == ExpressionType.ArrayIndex)
{
return CheckForNull(base.Visit(node.Left), left2 => createExpr(left2));
}
else
{
if (true)
{
var left = Visit(node.Left);
var right = Visit(node.Right);
var nullable = left.Type.IsNullable() ? left.Type : right.Type;
left = TypeConversion.ImplicitConversion(left, nullable);
right = TypeConversion.ImplicitConversion(right, nullable);

if (right != null && left != null)
return Expression.MakeBinary(node.NodeType, left, right, left.Type.IsNullable() && node.NodeType != ExpressionType.Equal && node.NodeType != ExpressionType.NotEqual, node.Method);
else return CheckForNull(base.Visit(node.Left), left2 =>
createExpr(left2),
checkReferenceTypes: false);
}
else
{

}
var left = Visit(node.Left);
var right = Visit(node.Right);
var nullable = left.Type.IsNullable() ? left.Type : right.Type;
left = TypeConversion.ImplicitConversion(left, nullable);
right = TypeConversion.ImplicitConversion(right, nullable);

if (right != null && left != null)
return Expression.MakeBinary(node.NodeType, left, right, left.Type.IsNullable() && node.NodeType != ExpressionType.Equal && node.NodeType != ExpressionType.NotEqual, node.Method);
else return CheckForNull(base.Visit(node.Left), left2 =>
createExpr(left2),
checkReferenceTypes: false);
}
}

Expand All @@ -129,11 +96,7 @@ protected override Expression VisitUnary(UnaryExpression node)
protected override Expression VisitInvocation(InvocationExpression node)
{
return CheckForNull(Visit(node.Expression), target =>
CheckForNulls(node.Arguments.Select(Visit).ToArray(), args =>
Expression.Invoke(target, args),
suppressThisOne: (arg, i) => node.Arguments[i].Type.IsNullable() || !arg.Type.GetTypeInfo().IsValueType),
suppress: node.Expression.Type.IsNullable()
);
Expression.Invoke(target, UnwrapNullableTypes(node.Arguments)));
}

protected override Expression VisitConditional(ConditionalExpression node)
Expand All @@ -154,42 +117,46 @@ protected override Expression VisitConditional(ConditionalExpression node)
protected override Expression VisitIndex(IndexExpression node)
{
return CheckForNull(Visit(node.Object), target =>
CheckForNulls(node.Arguments.Select(Visit).ToArray(), args =>
Expression.MakeIndex(target, node.Indexer, args),
suppressThisOne: (arg, i) => node.Arguments[i].Type.IsNullable() || !arg.Type.GetTypeInfo().IsValueType),
suppress: node.Object.Type.IsNullable()
Expression.MakeIndex(target, node.Indexer, UnwrapNullableTypes(node.Arguments))
);
}

protected override Expression VisitMethodCall(MethodCallExpression node)
{
return CheckForNull(Visit(node.Object), target =>
CheckForNulls(node.Arguments.Select(Visit).ToArray(), args =>
Expression.Call(target, node.Method, args),
suppressThisOne: (arg, i) => node.Arguments[i].Type.IsNullable() || !arg.Type.GetTypeInfo().IsValueType),
Expression.Call(target, node.Method, UnwrapNullableTypes(node.Arguments)),
suppress: node.Object?.Type?.IsNullable() ?? true
);
}

protected override Expression VisitNew(NewExpression node)
{
return CheckForNulls(node.Arguments.Select(Visit).ToArray(), args =>
Expression.New(node.Constructor, args),
suppressThisOne: (arg, i) => node.Arguments[i].Type.IsNullable() || !arg.Type.GetTypeInfo().IsValueType);
return Expression.New(node.Constructor, UnwrapNullableTypes(node.Arguments));
}

protected Expression CheckForNulls(Expression[] parameters, Func<Expression[], Expression> callback, Func<Expression, int, bool> suppressThisOne = null)
protected Expression[] UnwrapNullableTypes(IEnumerable<Expression> uncheckedArguments) =>
uncheckedArguments.Select(UnwrapNullableType).ToArray();
protected Expression UnwrapNullableType(Expression expression) =>
UnwrapNullableType(Visit(expression), expression.Type, expression.ToString());
protected Expression UnwrapNullableType(Expression expression, Type expectedType, string formattedExpression)
{
if (parameters.Length == 0) return callback(new Expression[0]);
var list = new List<Expression>();
Func<Expression, Expression> cc = e => { list.Add(e); return callback(list.ToArray()); };
for (var i = parameters.Length - 1; i >= 1; i--)
if (expression.Type == expectedType)
return expression;
else if (expression.Type == typeof(Nullable<>).MakeGenericType(expectedType))
{
var iCopy = i;
var ccc = cc;
cc = e => { list.Add(e); return CheckForNull(parameters[iCopy], ccc, suppress: suppressThisOne?.Invoke(parameters[iCopy], iCopy) ?? false); };
var tmp = Expression.Parameter(expression.Type);
var nreCtor = typeof(NullReferenceException).GetConstructor(new [] { typeof(string) });
return Expression.Block(new [] { tmp },
Expression.Assign(tmp, expression),
Expression.Condition(
Expression.Property(tmp, "HasValue"),
Expression.Property(tmp, "Value"),
Expression.Throw(Expression.New(nreCtor, Expression.Constant($"Binding expression '{formattedExpression}' of type '{expectedType}' has evaluated to null.")), expectedType)
)
);
}
return CheckForNull(parameters[0], cc, suppress: suppressThisOne?.Invoke(parameters[0], 0) ?? false);
else
throw new Exception($"Type mismatch: {expectedType} was expected, got {expression.Type}");
}

private int tmpCounter;
Expand Down
11 changes: 7 additions & 4 deletions src/DotVVM.Samples.Tests/Feature/PostbackConcurrencyTests.cs
Expand Up @@ -3,6 +3,7 @@
using DotVVM.Testing.Abstractions;
using Riganti.Selenium.Core;
using Riganti.Selenium.Core.Abstractions.Attributes;
using Riganti.Selenium.DotVVM;
using Xunit;
using Xunit.Abstractions;

Expand Down Expand Up @@ -40,6 +41,7 @@ public void Feature_PostbackConcurrency_DefaultMode(string longActionSelector, s
{
RunInAllBrowsers(browser => {
browser.NavigateToUrl(SamplesRouteUrls.FeatureSamples_PostbackConcurrency_DefaultMode);
browser.WaitUntilDotvvmInited();
// try the long action interrupted by the short one
browser.Single(longActionSelector).Click();
Expand All @@ -52,11 +54,12 @@ public void Feature_PostbackConcurrency_DefaultMode(string longActionSelector, s
// the postback index should be 1 now (because of short action)
AssertUI.InnerTextEquals(postbackIndexSpan, "1");
AssertUI.InnerTextEquals(lastActionSpan, "short");
// the result of the long action should be canceled, the counter shouldn't increase
browser.Wait(6000);
AssertUI.InnerTextEquals(postbackIndexSpan, "1");
AssertUI.InnerTextEquals(lastActionSpan, "short");
// the result of the long action should be canceled, the counter shouldn't increase
browser.WaitFor(()=> {
AssertUI.InnerTextEquals(postbackIndexSpan, "1");
AssertUI.InnerTextEquals(lastActionSpan, "short");
},3000);
});
}

Expand Down

0 comments on commit 1bac555

Please sign in to comment.