Skip to content

Commit

Permalink
more progress on #64
Browse files Browse the repository at this point in the history
  • Loading branch information
rikimaru0345 committed Aug 27, 2019
1 parent 137c048 commit 02c3dcc
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 81 deletions.
50 changes: 16 additions & 34 deletions samples/LiveTesting/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ partial class Program
static Guid staticGuid = Guid.Parse("39b29409-880f-42a4-a4ae-2752d97886fa");




public class NullableWrapper
{
public Test? TestStruct;
Expand All @@ -44,17 +42,7 @@ public NullableWrapper(Test? nullableStruct)
TestStruct = nullableStruct;
}
}

public class NormalWrapper
{
public readonly Test Struct;

public NormalWrapper(Test nullableStruct)
{
Struct = nullableStruct;
}
}


public struct Test : IEquatable<Test>
{
public decimal Value;
Expand Down Expand Up @@ -98,7 +86,7 @@ public bool Equals(SubStruct other)
}
}

public struct NameAge : IEquatable<NameAge>
public readonly struct NameAge
{
public readonly string Name;
public readonly int? Age;
Expand All @@ -108,17 +96,6 @@ public NameAge(string name, int? age)
Name = name;
Age = age;
}

public override bool Equals(object obj)
{
return obj is NameAge age && Equals(age);
}

public bool Equals(NameAge other)
{
return Name == other.Name &&
EqualityComparer<int?>.Default.Equals(Age, other.Age);
}
}


Expand Down Expand Up @@ -280,9 +257,9 @@ static unsafe void Main(string[] args)



/*
var eqOpNullableTest = StructEquality<Test?>.EqualFunction;

var name = "abc";
var a = new NullableWrapper(new Test { Value = 2, SubStruct = new SubStruct(new NameAge(name, 8), 1111, 3) });
var b = new NullableWrapper(new Test { Value = 2, SubStruct = new SubStruct(new NameAge(name, 8), 5, 3) });
Expand All @@ -292,16 +269,21 @@ static unsafe void Main(string[] args)
var ab = eqOpNullableTest(ref a.TestStruct, ref b.TestStruct);
var bc = eqOpNullableTest(ref b.TestStruct, ref c.TestStruct);

MicroBenchmark.Run(2,
MicroBenchmark.Run(2, new[]{
new BenchJob(".Value -> IEquatable<T>.Equals()", () => b.TestStruct.Value.Equals(c.TestStruct.Value) ),
new BenchJob("object.Equals()", () => object.Equals(b.TestStruct, c.TestStruct) ),
new BenchJob("Nullable<T>.Equals()", () => b.TestStruct.Equals(c.TestStruct) ),
new BenchJob("StructEquality<T>.AreEqual()", () => StructEquality<Test?>.AreEqual(ref b.TestStruct, ref c.TestStruct)),
new BenchJob("object.Equals(left, right)", () => object.Equals(b, c) )
//new BenchJob("left.Equals(right)", () => b.TestStruct.Equals(c.TestStruct) ),
);
Console.ReadLine();
});
Console.WriteLine();
Console.WriteLine();
Console.WriteLine("done!");
Console.ReadLine();
*/

SaveDelegateIL(StructEquality<Test?>.Lambda);
//SaveDelegateIL(StructEquality<Test?>.Lambda);


var config = new SerializerConfig { DefaultTargets = TargetMember.AllFields, PreserveReferences = false };
Expand All @@ -312,7 +294,7 @@ static unsafe void Main(string[] args)

var ceras = new CerasSerializer(config);

var obj = new NullableWrapper(new Test { Value = 123.456M });
var obj = new NullableWrapper(new Test { Value = 2.34M, SubStruct = new SubStruct(new NameAge("riki", 5), 6, 7) });

var data = ceras.Serialize(obj);
var clone = ceras.Deserialize<NullableWrapper>(data);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ internal static void EmitReadonlyWriteBack(Type type, ReadonlyFieldHandling read
{
// Value types are simple.
// Either they match perfectly -> do nothing
// Or the values are not the same -> either throw an exception of do a forced overwrite


// Or the values are not the same -> throw exception or forced overwrite

Expression onMismatch;
if (readonlyFieldHandling == ReadonlyFieldHandling.ForcedOverwrite)
Expand All @@ -41,7 +39,7 @@ internal static void EmitReadonlyWriteBack(Type type, ReadonlyFieldHandling read
onMismatch = Throw(Constant(new CerasException($"The value-type in field '{fieldInfo.Name}' does not match the expected value, but the field is readonly and overwriting is not allowed in the configuration. Make the field writeable or enable 'ForcedOverwrite' in the serializer settings to allow Ceras to overwrite the readonly-field.")));

block.Add(IfThenElse(
test: Equal(tempStore, MakeMemberAccess(refValueArg, fieldInfo)),
test: StructEquality.IsStructEqual(tempStore, Field(refValueArg, fieldInfo)),
ifTrue: Empty(),
ifFalse: onMismatch
));
Expand Down
132 changes: 89 additions & 43 deletions src/Ceras/Helpers/StructEquality.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;

Expand All @@ -14,89 +15,134 @@ namespace Ceras.Helpers
public delegate bool EqualsDelegate<TStruct>(ref TStruct left, ref TStruct right);
static class StructEquality<T>
{
public static EqualsDelegate<T> EqualFunction { get; }
public static LambdaExpression Lambda { get; }


public static bool AreEqual(ref T left, ref T right) => EqualFunction(ref left, ref right);

public static EqualsDelegate<T> EqualFunction { get; }
public static LambdaExpression Lambda { get; }

static StructEquality()
{
if (!typeof(T).IsValueType || typeof(T).IsPrimitive)
throw new InvalidOperationException("T must be a non-primitive value type (a struct)");

(EqualFunction, Lambda) = GenerateEq();
Lambda = GenerateEqualityExpression();
EqualFunction = (EqualsDelegate<T>)Lambda.Compile();
}

static (EqualsDelegate<T>, LambdaExpression) GenerateEq()
static LambdaExpression GenerateEqualityExpression()
{
var type = typeof(T);

var left = Parameter(type.MakeByRefType(), "left");
var right = Parameter(type.MakeByRefType(), "right");
var methodEnd = Label(typeof(bool), "methodEnd");
var body = new List<Expression>();

foreach (var f in type.GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic))
{
var fLeft = Field(left, f);
var fRight = Field(right, f);

// "if( left.f != right.f ) return false;"
body.Add(IfThen(
Not(AreFieldsEqual(f.FieldType, fLeft, fRight)),
Return(methodEnd, Constant(false))
));
}
var isEqExp = StructEquality.IsStructEqual(left, right);

var delType = typeof(EqualsDelegate<>).MakeGenericType(type);
return Lambda(delType, isEqExp, left, right);
}
}

static class StructEquality
{
internal static readonly bool resolveLambda = true;
internal static readonly bool useIEquatable = false;

// "return true;"
body.Add(Label(methodEnd, Constant(true)));
// Knowing that <T> is a struct, compare all of its fields one-by-one
[MethodImpl(MethodImplOptions.Synchronized)]
internal static Expression IsStructEqual(Expression left, Expression right)
{
if (left.Type != right.Type)
throw new InvalidOperationException("left and right expressions have the same type");

var type = left.Type; // Expression automatically strips 'byRef'

if (!type.IsValueType || type.IsPrimitive)
throw new InvalidOperationException("T must be a non-primitive value type (a struct)");

var lambda = Expression.Lambda(
delegateType: typeof(EqualsDelegate<>).MakeGenericType(type),
body: Block(body),
left, right);
var del = lambda.Compile();
/*
* return (
* (left.f1 == right.f1) &&
* (left.f2 == right.f2) &&
* (left.f3 == right.f3) &&
* ...
* );
*/
var fields = type.GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);

return ((EqualsDelegate<T>)del, lambda);
var fieldEqualities = fields.Select(f => IsFieldEqual(Field(left, f), Field(right, f)));
var andAll = fieldEqualities.Aggregate((a, b) => AndAlso(a, b));

return andAll;
}

static Expression AreFieldsEqual(Type fieldType, MemberExpression leftField, MemberExpression rightField)
// Given the fields (left, right) of unknown types, determine the right comparison method (ReferenceEquals, Equals, RecurseIntoStruct)
static Expression IsFieldEqual(MemberExpression left, MemberExpression right)
{
var fieldType = ((FieldInfo)left.Member).FieldType;

//
// ReferenceTypes: ReferenceEqual()
if (!fieldType.IsValueType)
return ReferenceEqual(leftField, rightField);
return ReferenceEqual(left, right);

//
// Primitives: Equal()
if (fieldType.IsPrimitive)
return Equal(leftField, rightField);
return Equal(left, right);

//
// Nullable: compare directly
if (Nullable.GetUnderlyingType(fieldType) != null)
return IsStructEqual(left, right);

// Try custom equality operator if it exists
try
//
// IEquatable<T>: strongly typed custom implementation
// Not preferable because it takes the parameter by-value, instead of by-reference
if (useIEquatable)
{
var customEq = Equal(leftField, rightField);
return customEq;
var equatable = typeof(IEquatable<>).MakeGenericType(fieldType);
if (equatable.IsAssignableFrom(fieldType))
{
var typedEquals = fieldType.GetMethod(nameof(IEquatable<int>.Equals), new Type[] { fieldType });
return Call(left, typedEquals, right);
}
}
catch { }

// Structs: Recurse into AreEqual()

const bool resolveLambda = true;
// !! What if a user defines 'bool Equals(MyStruct other)'
// !! but doesn't mark it as implementing 'IEquatable<MyStruct>' ??
// !! If they only override Equals() that'd be really bad.
//
// override Equals()
// try { return Equal(leftField, rightField); } catch { }


//
// Structs: recurse into AreEqual (either by call, or by unpacking the lambda)
return GetStructEquality(left, right);
}

// Create a 'method call' or 'lambda invoke' to compare the fields of the two given structs
static Expression GetStructEquality(MemberExpression left, MemberExpression right)
{
var fieldType = ((FieldInfo)left.Member).FieldType;

// Resolving lambda gives an improvement from 20x slower -> 3-5x slower
if (resolveLambda)
{
// Get and unpack:
// 'StructEquality<fieldType>.Lambda'
var lambdaProp = typeof(StructEquality<>).MakeGenericType(fieldType).GetProperty(nameof(Lambda));
var eqLambda = (LambdaExpression)lambdaProp.GetValue(null);
return Invoke(eqLambda, leftField, rightField);
return Invoke(eqLambda, left, right);
}
else
{
var eqMethod = typeof(StructEquality<>).MakeGenericType(fieldType).GetMethod(nameof(AreEqual), BindingFlags.Static | BindingFlags.Public);
return Call(method: eqMethod, arg0: leftField, arg1: rightField);
// Call:
// 'StructEquality<fieldType>.AreEqual(ref left, ref right);
var eqMethod = typeof(StructEquality<>).MakeGenericType(fieldType).GetMethod(nameof(StructEquality<int>.AreEqual), BindingFlags.Static | BindingFlags.Public);
return Call(method: eqMethod, arg0: left, arg1: right);
}
}

}

}

0 comments on commit 02c3dcc

Please sign in to comment.