Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Null propagation now passes nulls into methods instead of skipping execution #817

Merged
merged 8 commits into from Feb 19, 2021
Expand Up @@ -629,6 +629,7 @@ class TestViewModel

public long[] LongArray => new long[] { 1, 2, long.MaxValue };
public TestViewModel2[] VmArray => new TestViewModel2[] { new TestViewModel2() };
public int[] IntArray { get; set; }

public string SetStringProp(string a, int b)
{
Expand Down
109 changes: 103 additions & 6 deletions src/DotVVM.Framework.Tests.Common/Binding/NullPropagationTests.cs
Expand Up @@ -21,17 +21,38 @@ public class NullPropagationTests
private LambdaExpression[] ExpressionFragments = new LambdaExpression[] {
Create((TestViewModel t) => t.EnumProperty - 1),
Create((TestViewModel t) => t.StringProp),
Create((TestViewModel t) => t.LongArray),
Create((TestViewModel t) => t.LongProperty),
Create((TestViewModel t) => t.VmArray),
Create((TestViewModel t) => t.VmArray[0]),
Create((TestViewModel2[] t, int b) => t[b & 0]),
Create((long[] t) => t.Length),
Create((long[] t) => t[0]),
Create((TestViewModel2 t) => t.Enum),
Create((TestEnum t) => (int)t),
Create((TestEnum t) => (TestEnum?)t),
Create((TestViewModel a, TestViewModel b) => a.BoolMethodExecuted ? a : b),
Create((int a, TestViewModel b, TestViewModel c) => a == 0 ? b : c),
Create((string b) => b ?? "NULL STRING"),
Create((TestViewModel a, int b) => a.Identity(b)),
Create((TestViewModel a, DateTime b) => a.Identity<DateTime?>(b)),
Create((TestViewModel a, string b) => a.Identity(b)),
Create((TestViewModel a, TestViewModel b) => a.Identity(b)),
Create((TestViewModel a, char b) => a.GetCharCode(b)),
Create((string a) => a.Length > 0 ? a[0] : 'f'),
Create((string a, int b) => a.Length > b ? a[b] : 'l'),
// Create((string a, int b) => a[b]),
Create((int a, int b) => a+b),
Create((int a) => a+ 1.0),
Create((int a) => a + 1.0),
Create((long a) => (int)a),
Create((int a, double b) => a * b),
Create((TimeSpan span) => span.TotalMilliseconds),
Create((TestViewModel vm) => vm.DateFrom), // DateTime?
Create((DateTime? vm) => vm.HasValue),
Create((DateTime? vm) => vm.Value),
Create((DateTime? vm) => vm.GetValueOrDefault()),
Create((DateTime d) => d.Day),
Create((DateTime d) => d.ToString()),
Create((double a) => TimeSpan.FromSeconds(a)),
Create((TestViewModel vm) => (TimeSpan)vm.Time),
Create((TestViewModel vm, TimeSpan time, int integer, double number) => new TestViewModel{ EnumProperty = TestEnum.B, BoolMethodExecuted = vm.BoolMethodExecuted, StringProp = time + ": " + vm.StringProp, StringProp2 = integer + vm.StringProp2 + number, TestViewModel2 = vm.TestViewModel2}),
Expand Down Expand Up @@ -96,7 +117,7 @@ void AddFragment(LambdaExpression fragment, int maxCount)
while (unusedFragments.Any())
{
var possibleOnes = unusedFragments.Where(f => f.Parameters.All(p => typeSources.ContainsKey(p.Type))).ToArray();
Assert.IsFalse(possibleOnes.Length == 0);
Assert.IsFalse(possibleOnes.Length == 0, $"Can not continue from {string.Join(", ", typeSources.Select(t => t.Key.Name))} to {string.Join(", ", unusedFragments.AsEnumerable())}");
foreach (var fragment in possibleOnes)
{
AddFragment(fragment, 4);
Expand All @@ -122,20 +143,38 @@ private void TestExpression(Random rnd, Expression expression, ParameterExpressi
Func<TestViewModel[], object> compile(Expression e) =>
Expression.Lambda<Func<TestViewModel[], object>>(Expression.Convert(e, typeof(object)), parameter).Compile();

var withNullCheks = compile(ExpressionNullPropagationVisitor.PropagateNulls(expr, _ => true));
var withNullChecks = compile(ExpressionNullPropagationVisitor.PropagateNulls(expr, _ => true));
var withoutNullChecks = compile(expr);

var args = Enumerable.Repeat(new TestViewModel { StringProp = "ll", StringProp2 = "pp", TestViewModel2 = new TestViewModel2() }, count).ToArray();
var args = Enumerable.Repeat(new TestViewModel { StringProp = "ll", StringProp2 = "pp", TestViewModel2 = new TestViewModel2(), DateFrom = DateTime.Parse("2020-03-29") }, count).ToArray();
var settings = DefaultSerializerSettingsProvider.Instance.Settings;
Assert.AreEqual(JsonConvert.SerializeObject(withNullCheks(args), settings), JsonConvert.SerializeObject(withoutNullChecks(args), settings));
Assert.AreEqual(JsonConvert.SerializeObject(withNullChecks(args), settings), JsonConvert.SerializeObject(withoutNullChecks(args), settings));

foreach (var i in Enumerable.Range(0, args.Length).Shuffle(rnd))
{
args[i] = null;
withNullCheks(args);
try
{
withNullChecks(args);
}
catch (NullReferenceException ex)
{
// we might get exceptions from B(null), ...
// none of the methods we call actually throws, so it always has to be ours friendly NRE
Assert.IsTrue(ex.Message.StartsWith("Binding expression"));
// and it must only occur for value types
Assert.IsTrue(new [] { "System.Int", "System.TimeSpan", "System.Char", "System.Double" }.Any(ex.Message.Contains));
}
}
}

private object EvalExpression<T>(Expression<Func<T, object>> a, T val)
{
var nullChecked = ExpressionNullPropagationVisitor.PropagateNulls(a.Body, _ => true);
var d = a.Update(body: nullChecked, a.Parameters).Compile();
return d(val);
}

class ReplacerVisitor : ExpressionVisitor
{
private readonly ParameterExpression originalParameter;
Expand Down Expand Up @@ -195,6 +234,64 @@ public void BindingNullPropagation_1()
TestExpression(random, expr, viewModel);
}
}

private static T Identity<T>(T x) => x;

[TestMethod]
public void MethodArgument_Null_ValueType()
{
var e = Assert.ThrowsException<NullReferenceException>(() =>
EvalExpression<TestViewModel>(v => Identity(v.LongProperty), null));
Assert.AreEqual("Binding expression 'v.LongProperty' of type 'System.Int64' has evaluated to null.", e.Message);
}

[TestMethod]
public void MethodArgument_Null_NullableType()
{
Assert.IsNull(EvalExpression<TestViewModel>(v => Identity(v.DateFrom), null));
Assert.IsNull(EvalExpression<TestViewModel>(v => Identity<long?>(v.LongProperty), null));
Assert.IsNull(EvalExpression<TestViewModel>(v => Identity<DateTime?>(v.DateFrom.Value), null));
}

[TestMethod]
public void MethodArgument_Null_RefType()
{
Assert.IsNull(EvalExpression<TestViewModel>(v => Identity(v), null));
Assert.IsNull(EvalExpression<TestViewModel>(v => Identity(v.LongArray), null));
}

[TestMethod]
public void Operator()
{
Assert.IsNull(EvalExpression<TestViewModel>(v => v.LongArray[0] + 1, null));
Assert.IsNull(EvalExpression<TestViewModel>(v => v.LongArray[0] + v.TestViewModel2.MyProperty, new TestViewModel()));
Assert.IsNull(EvalExpression<TestViewModel>(v => v.TestViewModel2B.ChildObject.SomeString.Length + v.TestViewModel2.MyProperty, new TestViewModel()));
Assert.AreEqual(2L, EvalExpression<TestViewModel>(v => v.LongArray[0] + 1, new TestViewModel()));
}

[TestMethod]
public void StringConcat()
{
Assert.AreEqual("abc", EvalExpression<TestViewModel>(v => v.StringProp + "abc", null));
}

[TestMethod]
public void Indexer()
{
Assert.IsNull(EvalExpression<int[]>(v => v[0] + 1, null));
Assert.IsNull(EvalExpression<TestViewModel>(v => v.IntArray[0] + 1, null));
Assert.IsNull(EvalExpression<TestViewModel>(v => v.IntArray[0] + 1, new TestViewModel { IntArray = null }));
Assert.IsNull(EvalExpression<TestViewModel>(v => v.TestViewModel2.Collection[0].StringValue.Length + 5, new TestViewModel { IntArray = null }));
Assert.IsNull(EvalExpression<TestViewModel>(v => v.TestViewModel2.Collection[0].StringValue.Length + 5, new TestViewModel { IntArray = null }));
}

[TestMethod]
public void Coalesce()
{
Assert.AreEqual(1, EvalExpression<object>(v => v ?? 1, null));
Assert.AreEqual(1, EvalExpression<object>(v => (v ?? 1) ?? 2, null));
Assert.AreEqual(1, EvalExpression<int?>(v => v ?? null, 1));
}
}

public static class EnumerableExtensions
Expand Down
Expand Up @@ -21,41 +21,11 @@ protected override Expression VisitMember(MemberExpression node)
{
if (node.Expression?.Type?.IsNullable() == true)
{
if (node.Member.Name == "Value") return node.Expression;
if (node.Member.Name == "Value") return Visit(node.Expression);
else return base.VisitMember(node);
}
else return CheckForNull(Visit(node.Expression), expr =>
Expression.MakeMemberAccess(expr, node.Member));

//var left = Visit(node.Expression);
//var nullableType = Nullable.GetUnderlyingType(left.Type);
//if ((!left.Type.GetTypeInfo().IsValueType || nullableType != null) && (CanBeNull == null || CanBeNull(left)))
//{
// var wrapNullable = node.Type.GetTypeInfo().IsValueType;
// var restype = wrapNullable ? typeof(Nullable<>).MakeGenericType(node.Type) : node.Type;
// var variable = Expression.Parameter(left.Type);
// Expression body;
// Expression condition;
// if (nullableType != null)
// {
// condition = Expression.Property(variable, "HasValue");
// body = node.Update(Expression.Convert(variable, node.Expression.Type));
// }
// else
// {
// condition = Expression.NotEqual(variable, Expression.Constant(null, variable.Type));
// body = node.Update(variable);
// }
// if (wrapNullable)
// {
// body = Expression.Convert(body, restype);
// }
// return Expression.Block(new[] { variable },
// Expression.Assign(variable, left),
// Expression.IfThenElse(condition, body, Expression.Default(restype))
// );
//}
//else return base.VisitMember(node);
}

protected override Expression VisitLambda<T>(Expression<T> expression)
Expand Down Expand Up @@ -96,26 +66,23 @@ Expression createExpr(Expression left)
// this should only be ParameterExpression
else return createExpr(node.Left);
}
else if (node.NodeType == ExpressionType.ArrayIndex)
{
return CheckForNull(base.Visit(node.Left), left2 => createExpr(left2));
}
else
{
if (true)
{
var left = Visit(node.Left);
var right = Visit(node.Right);
var nullable = left.Type.IsNullable() ? left.Type : right.Type;
left = TypeConversion.ImplicitConversion(left, nullable);
right = TypeConversion.ImplicitConversion(right, nullable);

if (right != null && left != null)
return Expression.MakeBinary(node.NodeType, left, right, left.Type.IsNullable() && node.NodeType != ExpressionType.Equal && node.NodeType != ExpressionType.NotEqual, node.Method);
else return CheckForNull(base.Visit(node.Left), left2 =>
createExpr(left2),
checkReferenceTypes: false);
}
else
{

}
var left = Visit(node.Left);
var right = Visit(node.Right);
var nullable = left.Type.IsNullable() ? left.Type : right.Type;
left = TypeConversion.ImplicitConversion(left, nullable);
right = TypeConversion.ImplicitConversion(right, nullable);

if (right != null && left != null)
return Expression.MakeBinary(node.NodeType, left, right, left.Type.IsNullable() && node.NodeType != ExpressionType.Equal && node.NodeType != ExpressionType.NotEqual, node.Method);
else return CheckForNull(base.Visit(node.Left), left2 =>
createExpr(left2),
checkReferenceTypes: false);
}
}

Expand All @@ -129,11 +96,7 @@ protected override Expression VisitUnary(UnaryExpression node)
protected override Expression VisitInvocation(InvocationExpression node)
{
return CheckForNull(Visit(node.Expression), target =>
CheckForNulls(node.Arguments.Select(Visit).ToArray(), args =>
Expression.Invoke(target, args),
suppressThisOne: (arg, i) => node.Arguments[i].Type.IsNullable() || !arg.Type.GetTypeInfo().IsValueType),
suppress: node.Expression.Type.IsNullable()
);
Expression.Invoke(target, UnwrapNullableTypes(node.Arguments)));
}

protected override Expression VisitConditional(ConditionalExpression node)
Expand All @@ -154,42 +117,46 @@ protected override Expression VisitConditional(ConditionalExpression node)
protected override Expression VisitIndex(IndexExpression node)
{
return CheckForNull(Visit(node.Object), target =>
CheckForNulls(node.Arguments.Select(Visit).ToArray(), args =>
Expression.MakeIndex(target, node.Indexer, args),
suppressThisOne: (arg, i) => node.Arguments[i].Type.IsNullable() || !arg.Type.GetTypeInfo().IsValueType),
suppress: node.Object.Type.IsNullable()
Expression.MakeIndex(target, node.Indexer, UnwrapNullableTypes(node.Arguments))
);
}

protected override Expression VisitMethodCall(MethodCallExpression node)
{
return CheckForNull(Visit(node.Object), target =>
CheckForNulls(node.Arguments.Select(Visit).ToArray(), args =>
Expression.Call(target, node.Method, args),
suppressThisOne: (arg, i) => node.Arguments[i].Type.IsNullable() || !arg.Type.GetTypeInfo().IsValueType),
Expression.Call(target, node.Method, UnwrapNullableTypes(node.Arguments)),
suppress: node.Object?.Type?.IsNullable() ?? true
);
}

protected override Expression VisitNew(NewExpression node)
{
return CheckForNulls(node.Arguments.Select(Visit).ToArray(), args =>
Expression.New(node.Constructor, args),
suppressThisOne: (arg, i) => node.Arguments[i].Type.IsNullable() || !arg.Type.GetTypeInfo().IsValueType);
return Expression.New(node.Constructor, UnwrapNullableTypes(node.Arguments));
}

protected Expression CheckForNulls(Expression[] parameters, Func<Expression[], Expression> callback, Func<Expression, int, bool> suppressThisOne = null)
protected Expression[] UnwrapNullableTypes(IEnumerable<Expression> uncheckedArguments) =>
uncheckedArguments.Select(UnwrapNullableType).ToArray();
protected Expression UnwrapNullableType(Expression expression) =>
UnwrapNullableType(Visit(expression), expression.Type, expression.ToString());
protected Expression UnwrapNullableType(Expression expression, Type expectedType, string formattedExpression)
{
if (parameters.Length == 0) return callback(new Expression[0]);
var list = new List<Expression>();
Func<Expression, Expression> cc = e => { list.Add(e); return callback(list.ToArray()); };
for (var i = parameters.Length - 1; i >= 1; i--)
if (expression.Type == expectedType)
return expression;
else if (expression.Type == typeof(Nullable<>).MakeGenericType(expectedType))
{
var iCopy = i;
var ccc = cc;
cc = e => { list.Add(e); return CheckForNull(parameters[iCopy], ccc, suppress: suppressThisOne?.Invoke(parameters[iCopy], iCopy) ?? false); };
var tmp = Expression.Parameter(expression.Type);
var nreCtor = typeof(NullReferenceException).GetConstructor(new [] { typeof(string) });
return Expression.Block(new [] { tmp },
Expression.Assign(tmp, expression),
Expression.Condition(
Expression.Property(tmp, "HasValue"),
Expression.Property(tmp, "Value"),
Expression.Throw(Expression.New(nreCtor, Expression.Constant($"Binding expression '{formattedExpression}' of type '{expectedType}' has evaluated to null.")), expectedType)
)
);
}
return CheckForNull(parameters[0], cc, suppress: suppressThisOne?.Invoke(parameters[0], 0) ?? false);
else
throw new Exception($"Type mismatch: {expectedType} was expected, got {expression.Type}");
}

private int tmpCounter;
Expand Down
11 changes: 7 additions & 4 deletions src/DotVVM.Samples.Tests/Feature/PostbackConcurrencyTests.cs
Expand Up @@ -3,6 +3,7 @@
using DotVVM.Testing.Abstractions;
using Riganti.Selenium.Core;
using Riganti.Selenium.Core.Abstractions.Attributes;
using Riganti.Selenium.DotVVM;
using Xunit;
using Xunit.Abstractions;

Expand Down Expand Up @@ -40,6 +41,7 @@ public void Feature_PostbackConcurrency_DefaultMode(string longActionSelector, s
{
RunInAllBrowsers(browser => {
browser.NavigateToUrl(SamplesRouteUrls.FeatureSamples_PostbackConcurrency_DefaultMode);
browser.WaitUntilDotvvmInited();
// try the long action interrupted by the short one
browser.Single(longActionSelector).Click();
Expand All @@ -52,11 +54,12 @@ public void Feature_PostbackConcurrency_DefaultMode(string longActionSelector, s
// the postback index should be 1 now (because of short action)
AssertUI.InnerTextEquals(postbackIndexSpan, "1");
AssertUI.InnerTextEquals(lastActionSpan, "short");
// the result of the long action should be canceled, the counter shouldn't increase
browser.Wait(6000);
AssertUI.InnerTextEquals(postbackIndexSpan, "1");
AssertUI.InnerTextEquals(lastActionSpan, "short");
// the result of the long action should be canceled, the counter shouldn't increase
browser.WaitFor(()=> {
AssertUI.InnerTextEquals(postbackIndexSpan, "1");
AssertUI.InnerTextEquals(lastActionSpan, "short");
},3000);
});
}

Expand Down