Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Operator Overloading #893

Merged
merged 10 commits into from
May 13, 2021
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,4 @@ project.lock.json

.idea
BenchmarkDotNet.Artifacts*
.vscode
232 changes: 232 additions & 0 deletions Jint.Tests/Runtime/OperatorOverloadingTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
using System;
using Xunit;

namespace Jint.Tests.Runtime
{
public class OperatorOverloadingTests : IDisposable
{
private readonly Engine _engine;

public OperatorOverloadingTests()
{
_engine = new Engine(cfg => cfg
.AllowOperatorOverloading())
.SetValue("log", new Action<object>(Console.WriteLine))
.SetValue("assert", new Action<bool>(Assert.True))
.SetValue("assertFalse", new Action<bool>(Assert.False))
.SetValue("equal", new Action<object, object>(Assert.Equal))
.SetValue("Vector2", typeof(Vector2))
.SetValue("Vector3", typeof(Vector3))
;
}

void IDisposable.Dispose()
{
}

private void RunTest(string source)
{
_engine.Execute(source);
}

public class Vector2
{
public double X { get; }
public double Y { get; }
public double SqrMagnitude => X * X + Y * Y;
public double Magnitude => Math.Sqrt(SqrMagnitude);

public Vector2(double x, double y)
{
X = x;
Y = y;
}

public static Vector2 operator +(Vector2 left, Vector2 right) => new Vector2(left.X + right.X, left.Y + right.Y);
public static Vector2 operator +(Vector2 left, double right) => new Vector2(left.X + right, left.Y + right);
public static Vector2 operator +(string left, Vector2 right) => new Vector2(right.X, right.Y);
public static Vector2 operator +(double left, Vector2 right) => new Vector2(right.X + left, right.Y + left);
public static Vector2 operator *(Vector2 left, double right) => new Vector2(left.X * right, left.Y * right);
public static Vector2 operator /(Vector2 left, double right) => new Vector2(left.X / right, left.Y / right);

public static bool operator >(Vector2 left, Vector2 right) => left.Magnitude > right.Magnitude;
public static bool operator <(Vector2 left, Vector2 right) => left.Magnitude < right.Magnitude;
public static bool operator >=(Vector2 left, Vector2 right) => left.Magnitude >= right.Magnitude;
public static bool operator <=(Vector2 left, Vector2 right) => left.Magnitude <= right.Magnitude;
public static Vector2 operator %(Vector2 left, Vector2 right) => new Vector2(left.X % right.X, left.Y % right.Y);
public static double operator &(Vector2 left, Vector2 right) => left.X * right.X + left.Y * right.Y;
public static Vector2 operator |(Vector2 left, Vector2 right) => right * ((left & right) / right.SqrMagnitude);


public static double operator +(Vector2 operand) => operand.Magnitude;
public static Vector2 operator -(Vector2 operand) => new Vector2(-operand.X, -operand.Y);
public static bool operator !(Vector2 operand) => operand.Magnitude == 0;
public static Vector2 operator ~(Vector2 operand) => new Vector2(operand.Y, operand.X);
public static Vector2 operator ++(Vector2 operand) => new Vector2(operand.X + 1, operand.Y + 1);
public static Vector2 operator --(Vector2 operand) => new Vector2(operand.X - 1, operand.Y - 1);

public static implicit operator Vector3(Vector2 val) => new Vector3(val.X, val.Y, 0);
public static bool operator !=(Vector2 left, Vector2 right) => !(left == right);
public static bool operator ==(Vector2 left, Vector2 right) => left.X == right.X && left.Y == right.Y;
public override bool Equals(object obj) => ReferenceEquals(this, obj);
public override int GetHashCode() => X.GetHashCode() + Y.GetHashCode();
}

public class Vector3
{
public double X { get; }
public double Y { get; }
public double Z { get; }

public Vector3(double x, double y, double z)
{
X = x;
Y = y;
Z = z;
}

public static Vector3 operator +(Vector3 left, double right) => new Vector3(left.X + right, left.Y + right, left.Z + right);
public static Vector3 operator +(double left, Vector3 right) => new Vector3(right.X + left, right.Y + left, right.Z + left);
public static Vector3 operator +(Vector3 left, Vector3 right) => new Vector3(left.X + right.X, left.Y + right.Y, left.Z + right.Z);
}

[Fact]
public void OperatorOverloading_BinaryOperators()
{
RunTest(@"
var v1 = new Vector2(1, 2);
var v2 = new Vector2(3, 4);
var n = 6;

var r1 = v1 + v2;
equal(4, r1.X);
equal(6, r1.Y);

var r2 = n + v1;
equal(7, r2.X);
equal(8, r2.Y);

var r3 = v1 + n;
equal(7, r3.X);
equal(8, r3.Y);

var r4 = v1 * n;
equal(6, r4.X);
equal(12, r4.Y);

var r5 = v1 / n;
equal(1 / 6, r5.X);
equal(2 / 6, r5.Y);

var r6 = v2 % new Vector2(2, 3);
equal(1, r6.X);
equal(1, r6.Y);

var r7 = v2 & v1;
equal(11, r7);

var r8 = new Vector2(3, 4) | new Vector2(2, 0);
equal(3, r8.X);
equal(0, r8.Y);


var vSmall = new Vector2(3, 4);
var vBig = new Vector2(4, 4);

assert(vSmall < vBig);
assert(vSmall <= vBig);
assert(vSmall <= vSmall);
assert(vBig > vSmall);
assert(vBig >= vSmall);
assert(vBig >= vBig);

assertFalse(vSmall > vSmall);
assertFalse(vSmall < vSmall);
assertFalse(vSmall > vBig);
assertFalse(vSmall >= vBig);
assertFalse(vBig < vBig);
assertFalse(vBig > vBig);
assertFalse(vBig < vSmall);
assertFalse(vBig <= vSmall);
");
}

[Fact]
public void OperatorOverloading_ShouldCoerceTypes()
{
RunTest(@"
var v1 = new Vector2(1, 2);
var v2 = new Vector3(4, 5, 6);
var res = v1 + v2;
equal(5, res.X);
equal(7, res.Y);
equal(6, res.Z);
");
}

[Fact]
public void OperatorOverloading_ShouldWorkForEqualityButNotForStrictEquality()
{
RunTest(@"
var v1 = new Vector2(1, 2);
var v2 = new Vector2(1, 2);
assert(v1 == v2);
assertFalse(v1 != v2);
assert(v1 !== v2);
assertFalse(v1 === v2);


var z1 = new Vector3(1, 2, 3);
var z2 = new Vector3(1, 2, 3);
assertFalse(z1 == z2);
");
}

[Fact]
public void OperatorOverloading_UnaryOperators()
{
RunTest(@"
var v0 = new Vector2(0, 0);
var v = new Vector2(3, 4);
var rv = -v;
var bv = ~v;

assert(!v0);
assertFalse(!v);

equal(0, +v0);
equal(5, +v);
equal(5, +rv);
equal(-3, rv.X);
equal(-4, rv.Y);

equal(4, bv.X);
equal(3, bv.Y);
");
}

[Fact]
public void OperatorOverloading_IncrementOperatorShouldWork()
{
RunTest(@"
var v = new Vector2(3, 22);
var original = v;
var pre = ++v;
var post = v++;

equal(3, original.X);
equal(4, pre.X);
equal(4, post.X);
equal(5, v.X);

var decPre = --v;
var decPost = v--;

equal(4, decPre.X);
equal(4, decPost.X);
equal(3, v.X);
");
}

}
}
11 changes: 10 additions & 1 deletion Jint/Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public sealed class Options
private DebuggerStatementHandling _debuggerStatementHandling;
private bool _allowClr;
private bool _allowClrWrite = true;
private bool _allowOperatorOverloading;
private readonly List<IObjectConverter> _objectConverters = new();
private Func<Engine, object, ObjectInstance> _wrapObjectHandler;
private MemberAccessorDelegate _memberAccessor;
Expand Down Expand Up @@ -127,7 +128,7 @@ PropertyDescriptor CreateMethodInstancePropertyDescriptor(ClrFunctionInstance cl
JsValue key = overloads.Key;
PropertyDescriptor descriptorWithFallback = null;
PropertyDescriptor descriptorWithoutFallback = null;

if (prototype.HasOwnProperty(key) && prototype.GetOwnProperty(key).Value is ClrFunctionInstance clrFunctionInstance)
{
descriptorWithFallback = CreateMethodInstancePropertyDescriptor(clrFunctionInstance);
Expand Down Expand Up @@ -210,6 +211,12 @@ public Options AllowClrWrite(bool allow = true)
return this;
}

public Options AllowOperatorOverloading(bool allow = true)
{
_allowOperatorOverloading = allow;
return this;
}

/// <summary>
/// Exceptions thrown from CLR code are converted to JavaScript errors and
/// can be used in at try/catch statement. By default these exceptions are bubbled
Expand Down Expand Up @@ -335,6 +342,8 @@ internal void Apply(Engine engine)

internal bool _IsClrWriteAllowed => _allowClrWrite;

internal bool _IsOperatorOverloadingAllowed => _allowOperatorOverloading;

internal Predicate<Exception> _ClrExceptionsHandler => _clrExceptionsHandler;

internal List<Assembly> _LookupAssemblies => _lookupAssemblies;
Expand Down
40 changes: 33 additions & 7 deletions Jint/Runtime/Interop/DefaultTypeConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Dynamic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Jint.Extensions;
Expand All @@ -16,8 +17,10 @@ public class DefaultTypeConverter : ITypeConverter

#if NETSTANDARD
private static readonly ConcurrentDictionary<(Type Source, Type Target), bool> _knownConversions = new ConcurrentDictionary<(Type Source, Type Target), bool>();
private static readonly ConcurrentDictionary<(Type Source, Type Target), MethodInfo> _knownCastOperators = new ConcurrentDictionary<(Type Source, Type Target), MethodInfo>();
#else
private static readonly ConcurrentDictionary<string, bool> _knownConversions = new ConcurrentDictionary<string, bool>();
private static readonly ConcurrentDictionary<string, MethodInfo> _knownCastOperators = new ConcurrentDictionary<string, MethodInfo>();
#endif

private static readonly Type nullableType = typeof(Nullable<>);
Expand All @@ -28,7 +31,7 @@ public class DefaultTypeConverter : ITypeConverter
private static readonly Type engineType = typeof(Engine);
private static readonly Type typeType = typeof(Type);

private static readonly MethodInfo convertChangeType = typeof(Convert).GetMethod("ChangeType", new [] { objectType, typeType, typeof(IFormatProvider) });
private static readonly MethodInfo convertChangeType = typeof(Convert).GetMethod("ChangeType", new[] { objectType, typeType, typeof(IFormatProvider) });
private static readonly MethodInfo jsValueFromObject = jsValueType.GetMethod(nameof(JsValue.FromObject));
private static readonly MethodInfo jsValueToObject = jsValueType.GetMethod(nameof(JsValue.ToObject));

Expand Down Expand Up @@ -76,7 +79,7 @@ public virtual object Convert(object value, Type type, IFormatProvider formatPro
// is the javascript value an ICallable instance ?
if (valueType == iCallableType)
{
var function = (Func<JsValue, JsValue[], JsValue>)value;
var function = (Func<JsValue, JsValue[], JsValue>) value;

if (typeof(Delegate).IsAssignableFrom(type) && !type.IsAbstract)
{
Expand Down Expand Up @@ -119,11 +122,11 @@ public virtual object Convert(object value, Type type, IFormatProvider formatPro
Expression.Convert(
Expression.Call(
null,
convertChangeType,
Expression.Call(callExpression, jsValueToObject),
Expression.Constant(method.ReturnType),
convertChangeType,
Expression.Call(callExpression, jsValueToObject),
Expression.Constant(method.ReturnType),
Expression.Constant(System.Globalization.CultureInfo.InvariantCulture, typeof(IFormatProvider))
),
),
method.ReturnType
),
new ReadOnlyCollection<ParameterExpression>(@params)).Compile();
Expand Down Expand Up @@ -168,7 +171,7 @@ public virtual object Convert(object value, Type type, IFormatProvider formatPro
}

// reference types - return null if no valid constructor is found
if(!type.IsValueType)
if (!type.IsValueType)
{
var found = false;
foreach (var constructor in constructors)
Expand Down Expand Up @@ -210,6 +213,29 @@ public virtual object Convert(object value, Type type, IFormatProvider formatPro
return obj;
}

if (_engine.Options._IsOperatorOverloadingAllowed)
{
#if NETSTANDARD
var key = (valueType, type);
#else
var key = $"{valueType}->{type}";
#endif

var castOperator = _knownCastOperators.GetOrAdd(key, _ =>
valueType
.GetMethods(BindingFlags.Public | BindingFlags.Static)
.Concat(type.GetMethods(BindingFlags.Public | BindingFlags.Static))
.FirstOrDefault(m =>
m.IsSpecialName
&& type.IsAssignableFrom(m.ReturnType)
&& (m.Name == "op_Implicit" || m.Name == "op_Explicit")));

if (castOperator != null)
{
return castOperator.Invoke(null, new[] { value });
}
}

return System.Convert.ChangeType(value, type, formatProvider);
}

Expand Down