Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Operator Overloading #893

Merged
merged 10 commits into from
May 13, 2021
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,4 @@ project.lock.json

.idea
BenchmarkDotNet.Artifacts*
.vscode
75 changes: 75 additions & 0 deletions Jint.Tests/Runtime/MethodAmbiguityTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using Jint.Native;
using System;
using Xunit;

namespace Jint.Tests.Runtime
{
public class MethodAmbiguityTests : IDisposable
{
private readonly Engine _engine;

public MethodAmbiguityTests()
{
_engine = new Engine(cfg => cfg
.AllowOperatorOverloading())
.SetValue("log", new Action<object>(Console.WriteLine))
.SetValue("throws", new Func<Action, Exception>(Assert.Throws<Exception>))
.SetValue("assert", new Action<bool>(Assert.True))
.SetValue("assertFalse", new Action<bool>(Assert.False))
.SetValue("equal", new Action<object, object>(Assert.Equal))
.SetValue("TestClass", typeof(TestClass))
.SetValue("ChildTestClass", typeof(ChildTestClass))
;
}

void IDisposable.Dispose()
{
}

private void RunTest(string source)
{
_engine.Execute(source);
}

public class TestClass
{
public int TestMethod(double a, string b, double c) => 0;
public int TestMethod(double a, double b, double c) => 1;
public int TestMethod(TestClass a, string b, double c) => 2;
public int TestMethod(TestClass a, TestClass b, double c) => 3;
public int TestMethod(TestClass a, TestClass b, TestClass c) => 4;
public int TestMethod(TestClass a, double b, string c) => 5;
public int TestMethod(ChildTestClass a, double b, string c) => 6;
public int TestMethod(ChildTestClass a, string b, JsValue c) => 7;

public static implicit operator TestClass(double i) => new TestClass();
public static implicit operator double(TestClass tc) => 0;
public static explicit operator string(TestClass tc) => "";
}

public class ChildTestClass : TestClass { }

[Fact]
public void BestMatchingMethodShouldBeCalled()
{
RunTest(@"
var tc = new TestClass();
var cc = new ChildTestClass();

equal(0, tc.TestMethod(0, '', 0));
equal(1, tc.TestMethod(0, 0, 0));
equal(2, tc.TestMethod(tc, '', 0));
equal(3, tc.TestMethod(tc, tc, 0));
equal(4, tc.TestMethod(tc, tc, tc));
equal(5, tc.TestMethod(tc, tc, ''));
equal(5, tc.TestMethod(0, 0, ''));

equal(6, tc.TestMethod(cc, 0, ''));
equal(1, tc.TestMethod(cc, 0, 0));
equal(6, tc.TestMethod(cc, cc, ''));
equal(6, tc.TestMethod(cc, 0, tc));
equal(7, tc.TestMethod(cc, '', {}));
");
}
}
}
261 changes: 261 additions & 0 deletions Jint.Tests/Runtime/OperatorOverloadingTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
using System;
using Xunit;

namespace Jint.Tests.Runtime
{
public class OperatorOverloadingTests : IDisposable
{
private readonly Engine _engine;

public OperatorOverloadingTests()
{
_engine = new Engine(cfg => cfg
.AllowOperatorOverloading())
.SetValue("log", new Action<object>(Console.WriteLine))
.SetValue("assert", new Action<bool>(Assert.True))
.SetValue("assertFalse", new Action<bool>(Assert.False))
.SetValue("equal", new Action<object, object>(Assert.Equal))
.SetValue("Vector2", typeof(Vector2))
.SetValue("Vector3", typeof(Vector3))
.SetValue("Vector2Child", typeof(Vector2Child))
;
}

void IDisposable.Dispose()
{
}

private void RunTest(string source)
{
_engine.Execute(source);
}

public class Vector2
{
public double X { get; }
public double Y { get; }
public double SqrMagnitude => X * X + Y * Y;
public double Magnitude => Math.Sqrt(SqrMagnitude);

public Vector2(double x, double y)
{
X = x;
Y = y;
}

public static Vector2 operator +(Vector2 left, Vector2 right) => new Vector2(left.X + right.X, left.Y + right.Y);
public static Vector2 operator +(Vector2 left, double right) => new Vector2(left.X + right, left.Y + right);
public static Vector2 operator +(string left, Vector2 right) => new Vector2(right.X, right.Y);
public static Vector2 operator +(double left, Vector2 right) => new Vector2(right.X + left, right.Y + left);
public static Vector2 operator *(Vector2 left, double right) => new Vector2(left.X * right, left.Y * right);
public static Vector2 operator /(Vector2 left, double right) => new Vector2(left.X / right, left.Y / right);

public static bool operator >(Vector2 left, Vector2 right) => left.Magnitude > right.Magnitude;
public static bool operator <(Vector2 left, Vector2 right) => left.Magnitude < right.Magnitude;
public static bool operator >=(Vector2 left, Vector2 right) => left.Magnitude >= right.Magnitude;
public static bool operator <=(Vector2 left, Vector2 right) => left.Magnitude <= right.Magnitude;
public static Vector2 operator %(Vector2 left, Vector2 right) => new Vector2(left.X % right.X, left.Y % right.Y);
public static double operator &(Vector2 left, Vector2 right) => left.X * right.X + left.Y * right.Y;
public static Vector2 operator |(Vector2 left, Vector2 right) => right * ((left & right) / right.SqrMagnitude);


public static double operator +(Vector2 operand) => operand.Magnitude;
public static Vector2 operator -(Vector2 operand) => new Vector2(-operand.X, -operand.Y);
public static bool operator !(Vector2 operand) => operand.Magnitude == 0;
public static Vector2 operator ~(Vector2 operand) => new Vector2(operand.Y, operand.X);
public static Vector2 operator ++(Vector2 operand) => new Vector2(operand.X + 1, operand.Y + 1);
public static Vector2 operator --(Vector2 operand) => new Vector2(operand.X - 1, operand.Y - 1);

public static implicit operator Vector3(Vector2 val) => new Vector3(val.X, val.Y, 0);
public static bool operator !=(Vector2 left, Vector2 right) => !(left == right);
public static bool operator ==(Vector2 left, Vector2 right) => left.X == right.X && left.Y == right.Y;
public override bool Equals(object obj) => ReferenceEquals(this, obj);
public override int GetHashCode() => X.GetHashCode() + Y.GetHashCode();
}

public class Vector2Child : Vector2
{
public Vector2Child(double x, double y) : base(x, y) { }

public static Vector2Child operator +(Vector2Child left, double right) => new Vector2Child(left.X + 2 * right, left.Y + 2 * right);
}

public class Vector3
{
public double X { get; }
public double Y { get; }
public double Z { get; }

public Vector3(double x, double y, double z)
{
X = x;
Y = y;
Z = z;
}

public static Vector3 operator +(Vector3 left, double right) => new Vector3(left.X + right, left.Y + right, left.Z + right);
public static Vector3 operator +(double left, Vector3 right) => new Vector3(right.X + left, right.Y + left, right.Z + left);
public static Vector3 operator +(Vector3 left, Vector3 right) => new Vector3(left.X + right.X, left.Y + right.Y, left.Z + right.Z);
}

[Fact]
public void OperatorOverloading_BinaryOperators()
{
RunTest(@"
var v1 = new Vector2(1, 2);
var v2 = new Vector2(3, 4);
var n = 6;

var r1 = v1 + v2;
equal(4, r1.X);
equal(6, r1.Y);

var r2 = n + v1;
equal(7, r2.X);
equal(8, r2.Y);

var r3 = v1 + n;
equal(7, r3.X);
equal(8, r3.Y);

var r4 = v1 * n;
equal(6, r4.X);
equal(12, r4.Y);

var r5 = v1 / n;
equal(1 / 6, r5.X);
equal(2 / 6, r5.Y);

var r6 = v2 % new Vector2(2, 3);
equal(1, r6.X);
equal(1, r6.Y);

var r7 = v2 & v1;
equal(11, r7);

var r8 = new Vector2(3, 4) | new Vector2(2, 0);
equal(3, r8.X);
equal(0, r8.Y);


var vSmall = new Vector2(3, 4);
var vBig = new Vector2(4, 4);

assert(vSmall < vBig);
assert(vSmall <= vBig);
assert(vSmall <= vSmall);
assert(vBig > vSmall);
assert(vBig >= vSmall);
assert(vBig >= vBig);

assertFalse(vSmall > vSmall);
assertFalse(vSmall < vSmall);
assertFalse(vSmall > vBig);
assertFalse(vSmall >= vBig);
assertFalse(vBig < vBig);
assertFalse(vBig > vBig);
assertFalse(vBig < vSmall);
assertFalse(vBig <= vSmall);
");
}

[Fact]
public void OperatorOverloading_ShouldCoerceTypes()
{
RunTest(@"
var v1 = new Vector2(1, 2);
var v2 = new Vector3(4, 5, 6);
var res = v1 + v2;
equal(5, res.X);
equal(7, res.Y);
equal(6, res.Z);
");
}

[Fact]
public void OperatorOverloading_ShouldWorkForEqualityButNotForStrictEquality()
{
RunTest(@"
var v1 = new Vector2(1, 2);
var v2 = new Vector2(1, 2);
assert(v1 == v2);
assertFalse(v1 != v2);
assert(v1 !== v2);
assertFalse(v1 === v2);


var z1 = new Vector3(1, 2, 3);
var z2 = new Vector3(1, 2, 3);
assertFalse(z1 == z2);
");
}

[Fact]
public void OperatorOverloading_UnaryOperators()
{
RunTest(@"
var v0 = new Vector2(0, 0);
var v = new Vector2(3, 4);
var rv = -v;
var bv = ~v;

assert(!v0);
assertFalse(!v);

equal(0, +v0);
equal(5, +v);
equal(5, +rv);
equal(-3, rv.X);
equal(-4, rv.Y);

equal(4, bv.X);
equal(3, bv.Y);
");
}

[Fact]
public void OperatorOverloading_IncrementOperatorShouldWork()
{
RunTest(@"
var v = new Vector2(3, 22);
var original = v;
var pre = ++v;
var post = v++;

equal(3, original.X);
equal(4, pre.X);
equal(4, post.X);
equal(5, v.X);

var decPre = --v;
var decPost = v--;

equal(4, decPre.X);
equal(4, decPost.X);
equal(3, v.X);
");
}

[Fact]
public void OperatorOverloading_ShouldWorkOnDerivedClasses()
{
RunTest(@"
var v1 = new Vector2Child(1, 2);
var v2 = new Vector2Child(3, 4);
var n = 5;

var v1v2 = v1 + v2;
var v1n = v1 + n;

// Uses the (Vector2 + Vector2) operator on the parent class
equal(4, v1v2.X);
equal(6, v1v2.Y);

// Uses the (Vector2Child + double) operator on the child class
equal(11, v1n.X);
equal(12, v1n.Y);
");
}

}
}
6 changes: 6 additions & 0 deletions Jint/Extensions/ReflectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ internal static IEnumerable<MethodInfo> GetExtensionMethods(this Type type)
.Where(m => m.IsExtensionMethod());
}

internal static IEnumerable<MethodInfo> GetOperatorOverloadMethods(this Type type)
{
return type.GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy)
.Where(m => m.IsSpecialName);
}

private static bool IsExtensionMethod(this MethodBase methodInfo)
{
return methodInfo.IsDefined(typeof(ExtensionAttribute), true);
Expand Down
11 changes: 10 additions & 1 deletion Jint/Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public sealed class Options
private DebuggerStatementHandling _debuggerStatementHandling;
private bool _allowClr;
private bool _allowClrWrite = true;
private bool _allowOperatorOverloading;
private readonly List<IObjectConverter> _objectConverters = new();
private Func<Engine, object, ObjectInstance> _wrapObjectHandler;
private MemberAccessorDelegate _memberAccessor;
Expand Down Expand Up @@ -127,7 +128,7 @@ PropertyDescriptor CreateMethodInstancePropertyDescriptor(ClrFunctionInstance cl
JsValue key = overloads.Key;
PropertyDescriptor descriptorWithFallback = null;
PropertyDescriptor descriptorWithoutFallback = null;

if (prototype.HasOwnProperty(key) && prototype.GetOwnProperty(key).Value is ClrFunctionInstance clrFunctionInstance)
{
descriptorWithFallback = CreateMethodInstancePropertyDescriptor(clrFunctionInstance);
Expand Down Expand Up @@ -210,6 +211,12 @@ public Options AllowClrWrite(bool allow = true)
return this;
}

public Options AllowOperatorOverloading(bool allow = true)
{
_allowOperatorOverloading = allow;
return this;
}

/// <summary>
/// Exceptions thrown from CLR code are converted to JavaScript errors and
/// can be used in at try/catch statement. By default these exceptions are bubbled
Expand Down Expand Up @@ -335,6 +342,8 @@ internal void Apply(Engine engine)

internal bool _IsClrWriteAllowed => _allowClrWrite;

internal bool _IsOperatorOverloadingAllowed => _allowOperatorOverloading;

internal Predicate<Exception> _ClrExceptionsHandler => _clrExceptionsHandler;

internal List<Assembly> _LookupAssemblies => _lookupAssemblies;
Expand Down