Skip to content

Commit

Permalink
Preserve custom operators
Browse files Browse the repository at this point in the history
This will keep custom operators on marked types whenever System.Linq.Expressions
is used, and the operator input types are marked.

The behavior is enabled by default, and can be disabled by passing
--disable-operator-discovery.

Addresses dotnet#1821
  • Loading branch information
sbomer committed Jun 30, 2021
1 parent f549b4e commit 5656b89
Show file tree
Hide file tree
Showing 11 changed files with 457 additions and 0 deletions.
180 changes: 180 additions & 0 deletions src/linker/Linker.Steps/DiscoverCustomOperatorsHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Mono.Cecil;

namespace Mono.Linker.Steps
{
public class DiscoverOperatorsHandler : IMarkHandler
{
LinkContext _context;
bool markOperators;
HashSet<TypeDefinition> _trackedTypesWithOperators;
Dictionary<TypeDefinition, HashSet<MethodDefinition>> _pendingOperatorsForType;

public void Initialize (LinkContext context, MarkContext markContext)
{
_context = context;
_trackedTypesWithOperators = new HashSet<TypeDefinition> ();
_pendingOperatorsForType = new Dictionary<TypeDefinition, HashSet<MethodDefinition>> ();
markContext.RegisterMarkTypeAction (ProcessType);
}

void ProcessType (TypeDefinition type)
{
CheckForLinqExpressions (type);

if (_pendingOperatorsForType.TryGetValue (type, out var pendingOperators)) {
foreach (var customOperator in pendingOperators)
MarkOperator (customOperator);
_pendingOperatorsForType.Remove (type);
}

if (ProcessCustomOperators (type, mark: markOperators) && !markOperators)
_trackedTypesWithOperators.Add (type);
}

void CheckForLinqExpressions (TypeDefinition type)
{
if (markOperators)
return;

if (type.Namespace != "System.Linq.Expressions" || type.Name != "Expression")
return;

markOperators = true;

foreach (var markedType in _trackedTypesWithOperators)
ProcessCustomOperators (markedType, mark: true);

_trackedTypesWithOperators.Clear ();
}

void MarkOperator (MethodDefinition method)
{
_context.Annotations.Mark (method, new DependencyInfo (DependencyKind.PreservedOperator, method.DeclaringType));
}

bool ProcessCustomOperators (TypeDefinition type, bool mark)
{
if (!type.HasMethods)
return false;

bool hasCustomOperators = false;
foreach (var method in type.Methods) {
if (!IsOperator (method, out var otherType))
continue;

if (!mark)
return true;

hasCustomOperators = true;

if (otherType == null || _context.Annotations.IsMarked (otherType)) {
MarkOperator (method);
continue;
}

// Wait until otherType gets marked to mark the operator.
if (!_pendingOperatorsForType.TryGetValue (otherType, out var pendingOperators)) {
pendingOperators = new HashSet<MethodDefinition> ();
_pendingOperatorsForType.Add (otherType, pendingOperators);
}
pendingOperators.Add (method);
}
return hasCustomOperators;
}

TypeDefinition _int32;
TypeDefinition Int32 {
get {
if (_int32 == null)
_int32 = BCL.FindPredefinedType ("System", "Int32", _context);
return _int32;
}
}

bool IsOperator (MethodDefinition method, out TypeDefinition otherType)
{
otherType = null;

if (!method.IsStatic || !method.IsPublic || !method.IsSpecialName || !method.Name.StartsWith ("op_"))
return false;

var operatorName = method.Name.Substring (3);
var self = method.DeclaringType;

switch (operatorName) {
// Unary operators
case "UnaryPlus":
case "UnaryNegation":
case "LogicalNot":
case "OnesComplement":
case "Increment":
case "Decrement":
case "True":
case "False":
// Parameter type of a unary operator must be the declaring type
if (method.Parameters.Count != 1 || _context.TryResolve (method.Parameters[0].ParameterType) != self)
return false;
// ++ and -- must return the declaring type
if (operatorName is "Increment" or "Decrement" && _context.TryResolve (method.ReturnType) != self)
return false;
return true;
// Binary operators
case "Addition":
case "Subtraction":
case "Multiply":
case "Division":
case "Modulus":
case "BitwiseAnd":
case "BitwiseOr":
case "ExclusiveOr":
// take int as right
case "LeftShift":
case "RightShift":
case "Equality":
case "Inequality":
case "LessThan":
case "GreaterThan":
case "LessThanOrEqual":
case "GreaterThanOrEqual":
if (method.Parameters.Count != 2)
return false;
var left = _context.TryResolve (method.Parameters[0].ParameterType);
var right = _context.TryResolve (method.Parameters[1].ParameterType);
if (left == null || right == null)
return false;
// << and >> must take the declaring type and int
if (operatorName is "LeftShift" or "RightShift" && (left != self || right != Int32))
return false;
// At least one argument must be the declaring type
if (left != self && right != self)
return false;
if (left != self)
otherType = left;
if (right != self)
otherType = right;
return true;
// Conversion operators
case "Implicit":
case "Explicit":
if (method.Parameters.Count != 1)
return false;
var source = _context.TryResolve (method.Parameters[0].ParameterType);
var target = _context.TryResolve (method.ReturnType);
if (source == null || target == null)
return false;
// Exactly one of source/target must be the declaring type
if (source == self == (target == self))
return false;
otherType = source == self ? target : source;
return true;
default:
return false;
}
}
}
}
2 changes: 2 additions & 0 deletions src/linker/Linker/DependencyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ public enum DependencyKind
XmlSerialized = 84, // entry type or member for XML serialization
SerializedRecursiveType = 85, // recursive type kept due to serialization handling
SerializedMember = 86, // field or property kept on a type for serialization

PreservedOperator = 87 // operator method preserved on a type
}

public readonly struct DependencyInfo : IEquatable<DependencyInfo>
Expand Down
9 changes: 9 additions & 0 deletions src/linker/Linker/Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,12 @@ protected int SetupContext (ILogger customLogger = null)

continue;

case "--disable-operator-discovery":
if (!GetBoolParam (token, l => context.DisableOperatorDiscovery = l))
return -1;

continue;

case "--ignore-descriptors":
if (!GetBoolParam (token, l => context.IgnoreDescriptors = l))
return -1;
Expand Down Expand Up @@ -732,6 +738,9 @@ protected int SetupContext (ILogger customLogger = null)
if (!context.DisableSerializationDiscovery)
p.MarkHandlers.Add (new DiscoverSerializationHandler ());

if (!context.DisableOperatorDiscovery)
p.MarkHandlers.Add (new DiscoverOperatorsHandler ());

foreach (string custom_step in custom_steps) {
if (!AddCustomStep (p, custom_step))
return -1;
Expand Down
2 changes: 2 additions & 0 deletions src/linker/Linker/LinkContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ public bool IgnoreUnresolved {

public bool DisableSerializationDiscovery { get; set; }

public bool DisableOperatorDiscovery { get; set; }

public bool IgnoreDescriptors { get; set; }

public bool IgnoreSubstitutions { get; set; }
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Mono.Linker.Tests.Cases.Expectations.Assertions;
using Mono.Linker.Tests.Cases.Expectations.Metadata;

namespace Mono.Linker.Tests.Cases.LinqExpressions
{
[SetupLinkerArgument ("--disable-operator-discovery")]
public class CanDisableOperatorDiscovery
{
public static void Main ()
{
var c = new CustomOperators ();
var expression = typeof (System.Linq.Expressions.Expression);
c = -c;
var t = typeof (TargetType);
}

[KeptMember (".ctor()")]
class CustomOperators
{
[Kept]
public static CustomOperators operator - (CustomOperators c) => null;

public static CustomOperators operator + (CustomOperators c) => null;
public static CustomOperators operator + (CustomOperators left, CustomOperators right) => null;
public static explicit operator TargetType (CustomOperators self) => null;
}

[Kept]
class TargetType { }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
using Mono.Linker.Tests.Cases.Expectations.Assertions;
using Mono.Linker.Tests.Cases.Expectations.Metadata;

namespace Mono.Linker.Tests.Cases.LinqExpressions
{
public class CanPreserveCustomOperators
{
public static void Main ()
{
var t = typeof (CustomOperators);
var expression = typeof (System.Linq.Expressions.Expression);

var t3 = typeof (TargetTypeImplicit);
var t4 = typeof (SourceTypeImplicit);
var t5 = typeof (TargetTypeExplicit);
var t6 = typeof (SourceTypeExplicit);
}

class CustomOperators
{
// Unary operators
[Kept]
public static CustomOperators operator + (CustomOperators c) => null;
[Kept]
public static CustomOperators operator - (CustomOperators c) => null;
[Kept]
public static CustomOperators operator ! (CustomOperators c) => null;
[Kept]
public static CustomOperators operator ~ (CustomOperators c) => null;
[Kept]
public static CustomOperators operator ++ (CustomOperators c) => null;
[Kept]
public static CustomOperators operator -- (CustomOperators c) => null;
[Kept]
public static bool operator true (CustomOperators c) => true;
[Kept]
public static bool operator false (CustomOperators c) => true;

// Binary operators
[Kept]
public static CustomOperators operator + (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator - (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator * (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator / (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator % (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator & (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator | (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator ^ (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator << (CustomOperators value, int shift) => null;
[Kept]
public static CustomOperators operator >> (CustomOperators value, int shift) => null;
[Kept]
public static CustomOperators operator == (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator != (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator < (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator > (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator <= (CustomOperators left, CustomOperators right) => null;
[Kept]
public static CustomOperators operator >= (CustomOperators left, CustomOperators right) => null;

// conversion operators
[Kept]
public static implicit operator TargetTypeImplicit (CustomOperators self) => null;
[Kept]
public static implicit operator CustomOperators (SourceTypeImplicit other) => null;
[Kept]
public static explicit operator TargetTypeExplicit (CustomOperators self) => null;
[Kept]
public static explicit operator CustomOperators (SourceTypeExplicit other) => null;
}

[Kept]
class TargetTypeImplicit { }
[Kept]
class SourceTypeImplicit { }
[Kept]
class TargetTypeExplicit { }
[Kept]
class SourceTypeExplicit { }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Mono.Linker.Tests.Cases.Expectations.Assertions;

namespace Mono.Linker.Tests.Cases.LinqExpressions
{
public class CanRemoveMethodsNamedLikeCustomOperators
{
public static void Main ()
{
var t = typeof (FakeOperators);
var expression = typeof (System.Linq.Expressions.Expression);
var t1 = typeof (SubtractionType);
var t2 = typeof (TargetType);
}

public class FakeOperators
{
[Kept]
public static FakeOperators operator - (FakeOperators f) => null;

public static FakeOperators op_UnaryPlus (FakeOperators f) => null;
public static FakeOperators op_Addition (FakeOperators left, FakeOperators right) => null;
public static FakeOperators op_Subtraction (FakeOperators left, SubtractionType right) => null;
public static TargetType op_Explicit (FakeOperators self) => null;
}

[Kept]
public class SubtractionType { }
[Kept]
public class TargetType { }
}
}
Loading

0 comments on commit 5656b89

Please sign in to comment.