Skip to content

Commit

Permalink
feat: optimize enum ByValueCheckDefined strategy and support flags (#510
Browse files Browse the repository at this point in the history
)
  • Loading branch information
latonz committed Jun 21, 2023
1 parent 66124f8 commit 602c79d
Show file tree
Hide file tree
Showing 34 changed files with 474 additions and 256 deletions.
2 changes: 1 addition & 1 deletion docs/docs/02-configuration/04-enum.mdx
Expand Up @@ -88,7 +88,7 @@ public partial class CarMapper
}
```

### Strict enum mappings
## Strict enum mappings

To enforce strict enum mappings
(all source enum values have to be mapped to a target enum value
Expand Down
Expand Up @@ -72,6 +72,7 @@ private static TypeMapping BuildCastMappingAndDiagnostic(MappingBuilderContext c
.OfType<IFieldSymbol>()
.Where(x => !explicitMappingTargetNames.Contains(x.Name))
.ToDictionary(field => field.Name, field => field.ConstantValue);
var targetMemberNames = ctx.Target.GetMembers().OfType<IFieldSymbol>().Select(x => x.Name).ToHashSet();

var missingTargetValues = targetValues.Where(
field =>
Expand All @@ -95,7 +96,14 @@ private static TypeMapping BuildCastMappingAndDiagnostic(MappingBuilderContext c
checkTargetDefined = true;
}

var castFallbackMapping = new EnumCastMapping(ctx.Types.Get<Enum>(), ctx.Source, ctx.Target, checkTargetDefined, fallbackMapping);
var checkDefinedMode = checkTargetDefined switch
{
false => EnumCastMapping.CheckDefinedMode.NoCheck,
_ when ctx.Target.HasAttribute(ctx.Types.Get<FlagsAttribute>()) => EnumCastMapping.CheckDefinedMode.Flags,
_ => EnumCastMapping.CheckDefinedMode.Value,
};

var castFallbackMapping = new EnumCastMapping(ctx.Source, ctx.Target, checkDefinedMode, targetMemberNames, fallbackMapping);
if (explicitMappings.Count == 0)
return castFallbackMapping;

Expand Down
62 changes: 50 additions & 12 deletions src/Riok.Mapperly/Descriptors/Mappings/Enums/EnumCastMapping.cs
Expand Up @@ -5,36 +5,74 @@

namespace Riok.Mapperly.Descriptors.Mappings.Enums;

/// <summary>
/// An enum cast mapping which casts the source to the target type and optionally checks whether the target is defined.
/// If it is not defined an optional fallback value is used.
/// </summary>
public class EnumCastMapping : CastMapping
{
private const string IsDefinedMethodName = nameof(Enum.IsDefined);

private readonly ITypeSymbol _enumType;
private readonly bool _checkDefined;
private readonly CheckDefinedMode _checkDefinedMode;
private readonly IReadOnlyCollection<string> _targetEnumMemberNames;
private readonly EnumFallbackValueMapping _fallback;

public enum CheckDefinedMode
{
/// <summary>
/// No check is performed at all, the value is just casted.
/// </summary>
NoCheck,

/// <summary>
/// It is checked if the casted value is defined in the target enum.
/// </summary>
Value,

/// <summary>
/// It is checked if the casted value is a defined flags combination of the target enum.
/// </summary>
Flags
}

public EnumCastMapping(
ITypeSymbol enumType,
ITypeSymbol sourceType,
ITypeSymbol targetType,
bool checkDefined,
CheckDefinedMode checkDefinedMode,
IReadOnlyCollection<string> targetEnumMemberNames,
EnumFallbackValueMapping fallback
)
: base(sourceType, targetType)
{
_enumType = enumType;
_checkDefined = checkDefined;
_checkDefinedMode = checkDefinedMode;
_targetEnumMemberNames = targetEnumMemberNames;
_fallback = fallback;
}

public override ExpressionSyntax Build(TypeMappingBuildContext ctx)
{
var casted = base.Build(ctx);
if (!_checkDefined)
if (_checkDefinedMode == CheckDefinedMode.NoCheck)
return casted;

var isDefinedMethod = MemberAccess(FullyQualifiedIdentifier(_enumType), IsDefinedMethodName);
var isDefined = Invocation(isDefinedMethod, TypeOfExpression(FullyQualifiedIdentifier(TargetType)), casted);
return ConditionalExpression(isDefined, casted, _fallback.Build(ctx));
var valueDefinedCondition = BuildIsDefinedCondition(casted);
return ConditionalExpression(valueDefinedCondition, casted, _fallback.Build(ctx));
}

private ExpressionSyntax BuildIsDefinedCondition(ExpressionSyntax convertedSourceValue)
{
var allEnumMembers = _targetEnumMemberNames.Select(x => MemberAccess(FullyQualifiedIdentifier(TargetType), x));
return _checkDefinedMode switch
{
// (TargetEnum)v is TargetEnum.A or TargetEnum.B or ...
CheckDefinedMode.Value
=> IsPatternExpression(convertedSourceValue, OrPattern(allEnumMembers)),

// (TargetEnum)v == ((TargetEnum)v & (TargetEnum.A | TargetEnum.B | ...))
CheckDefinedMode.Flags
=> Equal(
convertedSourceValue,
ParenthesizedExpression(BitwiseAnd(convertedSourceValue, ParenthesizedExpression(BitwiseOr(allEnumMembers))))
),
_ => throw new ArgumentOutOfRangeException($"{nameof(_checkDefinedMode)} has an unknown value {_checkDefinedMode}")
};
}
}
37 changes: 27 additions & 10 deletions src/Riok.Mapperly/Emit/SyntaxFactoryHelper.cs
Expand Up @@ -39,18 +39,29 @@ public static SyntaxToken Accessibility(Accessibility accessibility)
};
}

public static BinaryExpressionSyntax Coalesce(ExpressionSyntax expr, ExpressionSyntax coalesceExpr)
{
return BinaryExpression(SyntaxKind.CoalesceExpression, expr, coalesceExpr);
}
public static BinaryExpressionSyntax Coalesce(ExpressionSyntax expr, ExpressionSyntax coalesceExpr) =>
SyntaxFactory.BinaryExpression(SyntaxKind.CoalesceExpression, expr, coalesceExpr);

public static ExpressionSyntax Or(IEnumerable<ExpressionSyntax?> values) =>
values.WhereNotNull().Aggregate((a, b) => BinaryExpression(SyntaxKind.LogicalOrExpression, a, b));
public static ExpressionSyntax Or(IEnumerable<ExpressionSyntax?> values) => BinaryExpression(SyntaxKind.LogicalOrExpression, values);

public static ExpressionSyntax And(params ExpressionSyntax?[] values) => And((IEnumerable<ExpressionSyntax?>)values);

public static ExpressionSyntax And(IEnumerable<ExpressionSyntax?> values) =>
values.WhereNotNull().Aggregate((a, b) => BinaryExpression(SyntaxKind.LogicalAndExpression, a, b));
public static ExpressionSyntax And(IEnumerable<ExpressionSyntax?> values) => BinaryExpression(SyntaxKind.LogicalAndExpression, values);

public static ExpressionSyntax BitwiseAnd(params ExpressionSyntax?[] values) =>
BinaryExpression(SyntaxKind.BitwiseAndExpression, values);

public static ExpressionSyntax BitwiseOr(IEnumerable<ExpressionSyntax?> values) =>
BinaryExpression(SyntaxKind.BitwiseOrExpression, values);

public static PatternSyntax OrPattern(IEnumerable<ExpressionSyntax?> values) =>
values
.WhereNotNull()
.Select<ExpressionSyntax, PatternSyntax>(ConstantPattern)
.Aggregate((left, right) => BinaryPattern(SyntaxKind.OrPattern, left, right));

public static ExpressionSyntax Equal(ExpressionSyntax left, ExpressionSyntax right) =>
BinaryExpression(SyntaxKind.EqualsExpression, left, right);

public static ExpressionSyntax IfNoneNull(params (ITypeSymbol Type, ExpressionSyntax Access)[] values)
{
Expand All @@ -65,10 +76,10 @@ public static ExpressionSyntax IfAnyNull(params (ITypeSymbol Type, ExpressionSyn
}

public static BinaryExpressionSyntax IsNull(ExpressionSyntax expression) =>
BinaryExpression(SyntaxKind.EqualsExpression, expression, NullLiteral());
SyntaxFactory.BinaryExpression(SyntaxKind.EqualsExpression, expression, NullLiteral());

public static BinaryExpressionSyntax IsNotNull(ExpressionSyntax expression) =>
BinaryExpression(SyntaxKind.NotEqualsExpression, expression, NullLiteral());
SyntaxFactory.BinaryExpression(SyntaxKind.NotEqualsExpression, expression, NullLiteral());

public static ExpressionSyntax NullSubstitute(ITypeSymbol t, ExpressionSyntax argument, NullFallbackValue nullFallbackValue)
{
Expand Down Expand Up @@ -389,6 +400,12 @@ public static SeparatedSyntaxList<T> CommaSeparatedList<T>(IEnumerable<T> nodes,
public static IReadOnlyCollection<StatementSyntax> SingleStatement(ExpressionSyntax expression) =>
new[] { ExpressionStatement(expression) };

private static ExpressionSyntax BinaryExpression(SyntaxKind kind, params ExpressionSyntax?[] values) =>
BinaryExpression(kind, (IEnumerable<ExpressionSyntax?>)values);

private static ExpressionSyntax BinaryExpression(SyntaxKind kind, IEnumerable<ExpressionSyntax?> values) =>
values.WhereNotNull().Aggregate((left, right) => SyntaxFactory.BinaryExpression(kind, left, right));

private static InterpolatedStringTextSyntax InterpolatedStringText(string text) =>
SyntaxFactory.InterpolatedStringText(
Token(SyntaxTriviaList.Empty, SyntaxKind.InterpolatedStringTextToken, text, text, SyntaxTriviaList.Empty)
Expand Down
1 change: 1 addition & 0 deletions test/Riok.Mapperly.IntegrationTests/BaseMapperTest.cs
Expand Up @@ -53,6 +53,7 @@ public static TestObject NewTestObj()
IntValue = 10,
EnumName = TestEnum.Value10,
EnumValue = TestEnum.Value10,
FlagsEnumValue = TestFlagsEnum.V1 | TestFlagsEnum.V4,
IntInitOnlyValue = 3,
RequiredValue = 4,
NestedNullable = new TestObjectNested { IntValue = 100, },
Expand Down
14 changes: 14 additions & 0 deletions test/Riok.Mapperly.IntegrationTests/Dto/TestFlagsEnumDto.cs
@@ -0,0 +1,14 @@
using System;

namespace Riok.Mapperly.IntegrationTests.Dto
{
[Flags]
public enum TestFlagsEnumDto
{
V1 = 1 << 0,
V2 = 1 << 1,

// use another name to test mapping by value
V3 = 1 << 2,
}
}
2 changes: 2 additions & 0 deletions test/Riok.Mapperly.IntegrationTests/Dto/TestObjectDto.cs
Expand Up @@ -91,6 +91,8 @@ public TestObjectDto(int ctorValue, int unknownValue = 10, int ctorValue2 = 100)

public TestEnumDtoByValue EnumValue { get; set; }

public TestFlagsEnumDto FlagsEnumValue { get; set; }

public TestEnumDtoByName EnumName { get; set; }

public byte EnumRawValue { get; set; }
Expand Down
26 changes: 23 additions & 3 deletions test/Riok.Mapperly.IntegrationTests/Mapper/StaticTestMapper.cs
Expand Up @@ -65,9 +65,6 @@ public static TestObjectDto MapToDto(TestObject src)
[MapperIgnoreSource(nameof(TestObjectDto.IgnoredIntValue))]
public static partial TestObject MapFromDto(TestObjectDto dto);

[MapEnum(EnumMappingStrategy.ByName)]
public static partial TestEnumDtoByName MapToEnumDtoByName(TestEnum v);

[MapperIgnoreTarget(nameof(TestObjectDto.IgnoredIntValue))]
[MapperIgnoreSource(nameof(TestObject.IgnoredStringValue))]
public static partial void UpdateDto(TestObject source, TestObjectDto target);
Expand All @@ -88,5 +85,28 @@ public static TestObjectDto MapToDto(TestObject src)
public static partial object? MapNullableWithRuntimeTargetType(object? source, Type targetType);

public static partial TTarget MapGeneric<TSource, TTarget>(TSource source);

[MapEnum(EnumMappingStrategy.ByName)]
public static partial TestEnumDtoByName MapToEnumDtoByName(TestEnum v);

[MapEnum(EnumMappingStrategy.ByName)]
[MapEnumValue(TestEnumDtoAdditionalValue.Value40, TestEnum.Value30)]
public static partial TestEnum MapToEnumByNameWithExplicit(TestEnumDtoAdditionalValue v);

[MapEnum(EnumMappingStrategy.ByValue)]
[MapEnumValue(TestEnumDtoAdditionalValue.Value40, TestEnum.Value30)]
public static partial TestEnum MapToEnumByValueWithExplicit(TestEnumDtoAdditionalValue v);

[MapEnum(EnumMappingStrategy.ByValueCheckDefined)]
public static partial TestEnum MapToEnumByValueCheckDefined(TestEnumDtoByValue v);

[MapEnum(EnumMappingStrategy.ByValueCheckDefined, FallbackValue = TestEnum.Value10)]
public static partial TestEnum MapToEnumByValueCheckDefinedWithFallback(TestEnumDtoByValue v);

[MapEnum(EnumMappingStrategy.ByValueCheckDefined)]
public static partial TestFlagsEnum MapToFlagsEnumByValueCheckDefined(TestFlagsEnumDto v);

[MapEnum(EnumMappingStrategy.ByName, FallbackValue = TestEnum.Value10)]
public static partial TestEnum MapToEnumByNameWithFallback(TestEnumDtoByName v);
}
}
23 changes: 3 additions & 20 deletions test/Riok.Mapperly.IntegrationTests/Mapper/TestMapper.cs
Expand Up @@ -59,30 +59,13 @@ public TestObjectDto MapToDto(TestObject src)
[MapperIgnoreSource(nameof(TestObjectDto.IgnoredIntValue))]
public partial TestObject MapFromDto(TestObjectDto dto);

[MapEnum(EnumMappingStrategy.ByName)]
public partial TestEnumDtoByName MapToEnumDtoByName(TestEnum v);

[MapEnum(EnumMappingStrategy.ByName)]
[MapEnumValue(TestEnumDtoAdditionalValue.Value40, TestEnum.Value30)]
public partial TestEnum MapToEnumByNameWithExplicit(TestEnumDtoAdditionalValue v);

[MapEnum(EnumMappingStrategy.ByValue)]
[MapEnumValue(TestEnumDtoAdditionalValue.Value40, TestEnum.Value30)]
public partial TestEnum MapToEnumByValueWithExplicit(TestEnumDtoAdditionalValue v);

[MapEnum(EnumMappingStrategy.ByValueCheckDefined)]
public partial TestEnum MapToEnumByValueCheckDefined(TestEnumDtoByValue v);

[MapEnum(EnumMappingStrategy.ByValueCheckDefined, FallbackValue = TestEnum.Value10)]
public partial TestEnum MapToEnumByValueCheckDefinedWithFallback(TestEnumDtoByValue v);

[MapEnum(EnumMappingStrategy.ByName, FallbackValue = TestEnum.Value10)]
public partial TestEnum MapToEnumByNameWithFallback(TestEnumDtoByName v);

[MapperIgnoreTarget(nameof(TestObjectDto.IgnoredIntValue))]
[MapperIgnoreSource(nameof(TestObject.IgnoredStringValue))]
public partial void UpdateDto(TestObject source, TestObjectDto target);

[MapEnum(EnumMappingStrategy.ByName)]
public partial TestEnumDtoByName MapToEnumDtoByName(TestEnum v);

private partial int PrivateDirectInt(int value);
}
}
12 changes: 12 additions & 0 deletions test/Riok.Mapperly.IntegrationTests/Models/TestFlagsEnum.cs
@@ -0,0 +1,12 @@
using System;

namespace Riok.Mapperly.IntegrationTests.Models
{
[Flags]
public enum TestFlagsEnum
{
V1 = 1 << 0,
V2 = 1 << 1,
V4 = 1 << 2,
}
}
2 changes: 2 additions & 0 deletions test/Riok.Mapperly.IntegrationTests/Models/TestObject.cs
Expand Up @@ -89,6 +89,8 @@ public TestObject(int ctorValue, int unknownValue = 10, int ctorValue2 = 100)

public TestEnum EnumValue { get; set; }

public TestFlagsEnum FlagsEnumValue { get; set; }

public TestEnum EnumName { get; set; }

public TestEnum EnumRawValue { get; set; }
Expand Down
Expand Up @@ -130,6 +130,7 @@
3
],
EnumValue: Value10,
FlagsEnumValue: V1, V4,
EnumName: Value10,
EnumRawValue: Value20,
EnumStringValue: Value30,
Expand Down
Expand Up @@ -88,6 +88,7 @@ public static partial class DeepCloningMapper
target.HashSet = global::System.Linq.Enumerable.ToHashSet(src.HashSet);
target.SortedSet = new global::System.Collections.Generic.SortedSet<string>(src.SortedSet);
target.EnumValue = src.EnumValue;
target.FlagsEnumValue = src.FlagsEnumValue;
target.EnumName = src.EnumName;
target.EnumRawValue = src.EnumRawValue;
target.EnumStringValue = src.EnumStringValue;
Expand Down
Expand Up @@ -139,6 +139,7 @@
3
],
EnumValue: DtoValue1,
FlagsEnumValue: V1, V3,
EnumName: Value10,
EnumRawValue: 20,
EnumStringValue: Value30,
Expand Down

0 comments on commit 602c79d

Please sign in to comment.