diff --git a/samples/LiveTesting/Program.cs b/samples/LiveTesting/Program.cs index eef135e..f42ccb3 100644 --- a/samples/LiveTesting/Program.cs +++ b/samples/LiveTesting/Program.cs @@ -33,8 +33,6 @@ partial class Program static Guid staticGuid = Guid.Parse("39b29409-880f-42a4-a4ae-2752d97886fa"); - - public class NullableWrapper { public Test? TestStruct; @@ -44,17 +42,7 @@ public NullableWrapper(Test? nullableStruct) TestStruct = nullableStruct; } } - - public class NormalWrapper - { - public readonly Test Struct; - - public NormalWrapper(Test nullableStruct) - { - Struct = nullableStruct; - } - } - + public struct Test : IEquatable { public decimal Value; @@ -98,7 +86,7 @@ public bool Equals(SubStruct other) } } - public struct NameAge : IEquatable + public readonly struct NameAge { public readonly string Name; public readonly int? Age; @@ -108,17 +96,6 @@ public NameAge(string name, int? age) Name = name; Age = age; } - - public override bool Equals(object obj) - { - return obj is NameAge age && Equals(age); - } - - public bool Equals(NameAge other) - { - return Name == other.Name && - EqualityComparer.Default.Equals(Age, other.Age); - } } @@ -280,9 +257,9 @@ static unsafe void Main(string[] args) + /* var eqOpNullableTest = StructEquality.EqualFunction; - var name = "abc"; var a = new NullableWrapper(new Test { Value = 2, SubStruct = new SubStruct(new NameAge(name, 8), 1111, 3) }); var b = new NullableWrapper(new Test { Value = 2, SubStruct = new SubStruct(new NameAge(name, 8), 5, 3) }); @@ -292,16 +269,21 @@ static unsafe void Main(string[] args) var ab = eqOpNullableTest(ref a.TestStruct, ref b.TestStruct); var bc = eqOpNullableTest(ref b.TestStruct, ref c.TestStruct); - - MicroBenchmark.Run(2, + + MicroBenchmark.Run(2, new[]{ + new BenchJob(".Value -> IEquatable.Equals()", () => b.TestStruct.Value.Equals(c.TestStruct.Value) ), + new BenchJob("object.Equals()", () => object.Equals(b.TestStruct, c.TestStruct) ), + new BenchJob("Nullable.Equals()", () => b.TestStruct.Equals(c.TestStruct) ), new BenchJob("StructEquality.AreEqual()", () => StructEquality.AreEqual(ref b.TestStruct, ref c.TestStruct)), - new BenchJob("object.Equals(left, right)", () => object.Equals(b, c) ) - //new BenchJob("left.Equals(right)", () => b.TestStruct.Equals(c.TestStruct) ), - ); - Console.ReadLine(); + }); + Console.WriteLine(); + Console.WriteLine(); + Console.WriteLine("done!"); + Console.ReadLine(); + */ - SaveDelegateIL(StructEquality.Lambda); + //SaveDelegateIL(StructEquality.Lambda); var config = new SerializerConfig { DefaultTargets = TargetMember.AllFields, PreserveReferences = false }; @@ -312,7 +294,7 @@ static unsafe void Main(string[] args) var ceras = new CerasSerializer(config); - var obj = new NullableWrapper(new Test { Value = 123.456M }); + var obj = new NullableWrapper(new Test { Value = 2.34M, SubStruct = new SubStruct(new NameAge("riki", 5), 6, 7) }); var data = ceras.Serialize(obj); var clone = ceras.Deserialize(data); diff --git a/src/Ceras/Formatters/DynamicFormatter/DynamicFormatterHelpers.cs b/src/Ceras/Formatters/DynamicFormatter/DynamicFormatterHelpers.cs index ed726a5..1c28a3e 100644 --- a/src/Ceras/Formatters/DynamicFormatter/DynamicFormatterHelpers.cs +++ b/src/Ceras/Formatters/DynamicFormatter/DynamicFormatterHelpers.cs @@ -29,9 +29,7 @@ internal static void EmitReadonlyWriteBack(Type type, ReadonlyFieldHandling read { // Value types are simple. // Either they match perfectly -> do nothing - // Or the values are not the same -> either throw an exception of do a forced overwrite - - + // Or the values are not the same -> throw exception or forced overwrite Expression onMismatch; if (readonlyFieldHandling == ReadonlyFieldHandling.ForcedOverwrite) @@ -41,7 +39,7 @@ internal static void EmitReadonlyWriteBack(Type type, ReadonlyFieldHandling read onMismatch = Throw(Constant(new CerasException($"The value-type in field '{fieldInfo.Name}' does not match the expected value, but the field is readonly and overwriting is not allowed in the configuration. Make the field writeable or enable 'ForcedOverwrite' in the serializer settings to allow Ceras to overwrite the readonly-field."))); block.Add(IfThenElse( - test: Equal(tempStore, MakeMemberAccess(refValueArg, fieldInfo)), + test: StructEquality.IsStructEqual(tempStore, Field(refValueArg, fieldInfo)), ifTrue: Empty(), ifFalse: onMismatch )); diff --git a/src/Ceras/Helpers/StructEquality.cs b/src/Ceras/Helpers/StructEquality.cs index 715b02b..099518f 100644 --- a/src/Ceras/Helpers/StructEquality.cs +++ b/src/Ceras/Helpers/StructEquality.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; +using System.Runtime.CompilerServices; using System.Text; using System.Threading.Tasks; @@ -14,89 +15,134 @@ namespace Ceras.Helpers public delegate bool EqualsDelegate(ref TStruct left, ref TStruct right); static class StructEquality { - public static EqualsDelegate EqualFunction { get; } - public static LambdaExpression Lambda { get; } - - public static bool AreEqual(ref T left, ref T right) => EqualFunction(ref left, ref right); + public static EqualsDelegate EqualFunction { get; } + public static LambdaExpression Lambda { get; } static StructEquality() { if (!typeof(T).IsValueType || typeof(T).IsPrimitive) throw new InvalidOperationException("T must be a non-primitive value type (a struct)"); - (EqualFunction, Lambda) = GenerateEq(); + Lambda = GenerateEqualityExpression(); + EqualFunction = (EqualsDelegate)Lambda.Compile(); } - static (EqualsDelegate, LambdaExpression) GenerateEq() + static LambdaExpression GenerateEqualityExpression() { var type = typeof(T); - var left = Parameter(type.MakeByRefType(), "left"); var right = Parameter(type.MakeByRefType(), "right"); - var methodEnd = Label(typeof(bool), "methodEnd"); - var body = new List(); - foreach (var f in type.GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic)) - { - var fLeft = Field(left, f); - var fRight = Field(right, f); - - // "if( left.f != right.f ) return false;" - body.Add(IfThen( - Not(AreFieldsEqual(f.FieldType, fLeft, fRight)), - Return(methodEnd, Constant(false)) - )); - } + var isEqExp = StructEquality.IsStructEqual(left, right); + + var delType = typeof(EqualsDelegate<>).MakeGenericType(type); + return Lambda(delType, isEqExp, left, right); + } + } + + static class StructEquality + { + internal static readonly bool resolveLambda = true; + internal static readonly bool useIEquatable = false; - // "return true;" - body.Add(Label(methodEnd, Constant(true))); + // Knowing that is a struct, compare all of its fields one-by-one + [MethodImpl(MethodImplOptions.Synchronized)] + internal static Expression IsStructEqual(Expression left, Expression right) + { + if (left.Type != right.Type) + throw new InvalidOperationException("left and right expressions have the same type"); + + var type = left.Type; // Expression automatically strips 'byRef' + + if (!type.IsValueType || type.IsPrimitive) + throw new InvalidOperationException("T must be a non-primitive value type (a struct)"); - var lambda = Expression.Lambda( - delegateType: typeof(EqualsDelegate<>).MakeGenericType(type), - body: Block(body), - left, right); - var del = lambda.Compile(); + /* + * return ( + * (left.f1 == right.f1) && + * (left.f2 == right.f2) && + * (left.f3 == right.f3) && + * ... + * ); + */ + var fields = type.GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); - return ((EqualsDelegate)del, lambda); + var fieldEqualities = fields.Select(f => IsFieldEqual(Field(left, f), Field(right, f))); + var andAll = fieldEqualities.Aggregate((a, b) => AndAlso(a, b)); + + return andAll; } - static Expression AreFieldsEqual(Type fieldType, MemberExpression leftField, MemberExpression rightField) + // Given the fields (left, right) of unknown types, determine the right comparison method (ReferenceEquals, Equals, RecurseIntoStruct) + static Expression IsFieldEqual(MemberExpression left, MemberExpression right) { + var fieldType = ((FieldInfo)left.Member).FieldType; + + // // ReferenceTypes: ReferenceEqual() if (!fieldType.IsValueType) - return ReferenceEqual(leftField, rightField); + return ReferenceEqual(left, right); + // // Primitives: Equal() if (fieldType.IsPrimitive) - return Equal(leftField, rightField); + return Equal(left, right); + + // + // Nullable: compare directly + if (Nullable.GetUnderlyingType(fieldType) != null) + return IsStructEqual(left, right); - // Try custom equality operator if it exists - try + // + // IEquatable: strongly typed custom implementation + // Not preferable because it takes the parameter by-value, instead of by-reference + if (useIEquatable) { - var customEq = Equal(leftField, rightField); - return customEq; + var equatable = typeof(IEquatable<>).MakeGenericType(fieldType); + if (equatable.IsAssignableFrom(fieldType)) + { + var typedEquals = fieldType.GetMethod(nameof(IEquatable.Equals), new Type[] { fieldType }); + return Call(left, typedEquals, right); + } } - catch { } - // Structs: Recurse into AreEqual() - const bool resolveLambda = true; + // !! What if a user defines 'bool Equals(MyStruct other)' + // !! but doesn't mark it as implementing 'IEquatable' ?? + // !! If they only override Equals() that'd be really bad. + // + // override Equals() + // try { return Equal(leftField, rightField); } catch { } + + // + // Structs: recurse into AreEqual (either by call, or by unpacking the lambda) + return GetStructEquality(left, right); + } + + // Create a 'method call' or 'lambda invoke' to compare the fields of the two given structs + static Expression GetStructEquality(MemberExpression left, MemberExpression right) + { + var fieldType = ((FieldInfo)left.Member).FieldType; + + // Resolving lambda gives an improvement from 20x slower -> 3-5x slower if (resolveLambda) { + // Get and unpack: + // 'StructEquality.Lambda' var lambdaProp = typeof(StructEquality<>).MakeGenericType(fieldType).GetProperty(nameof(Lambda)); var eqLambda = (LambdaExpression)lambdaProp.GetValue(null); - return Invoke(eqLambda, leftField, rightField); + return Invoke(eqLambda, left, right); } else { - var eqMethod = typeof(StructEquality<>).MakeGenericType(fieldType).GetMethod(nameof(AreEqual), BindingFlags.Static | BindingFlags.Public); - return Call(method: eqMethod, arg0: leftField, arg1: rightField); + // Call: + // 'StructEquality.AreEqual(ref left, ref right); + var eqMethod = typeof(StructEquality<>).MakeGenericType(fieldType).GetMethod(nameof(StructEquality.AreEqual), BindingFlags.Static | BindingFlags.Public); + return Call(method: eqMethod, arg0: left, arg1: right); } } - } - }