Skip to content

Commit

Permalink
feat(enum): add ignore case option for enum mappings
Browse files Browse the repository at this point in the history
  • Loading branch information
latonz committed Feb 16, 2022
1 parent baf98fd commit d59cbee
Show file tree
Hide file tree
Showing 16 changed files with 276 additions and 72 deletions.
5 changes: 5 additions & 0 deletions src/Riok.Mapperly.Abstractions/MapEnumAttribute.cs
Expand Up @@ -20,4 +20,9 @@ public MapEnumAttribute(EnumMappingStrategy strategy)
/// The strategy to be used to map enums.
/// </summary>
public EnumMappingStrategy Strategy { get; }

/// <summary>
/// Whether the case should be ignored during mappings.
/// </summary>
public bool IgnoreCase { get; set; } = true;
}
5 changes: 5 additions & 0 deletions src/Riok.Mapperly.Abstractions/MapperAttribute.cs
Expand Up @@ -23,6 +23,11 @@ public sealed class MapperAttribute : Attribute
/// </summary>
public EnumMappingStrategy EnumMappingStrategy { get; set; } = EnumMappingStrategy.ByValue;

/// <summary>
/// Whether the case should be ignored for enum mappings.
/// </summary>
public bool EnumMappingIgnoreCase { get; set; } = true;

/// <summary>
/// Whether to always deep copy objects.
/// Eg. when the type <c>Person[]</c> should be mapped to the same type <c>Person[]</c>,
Expand Down
10 changes: 8 additions & 2 deletions src/Riok.Mapperly/Descriptors/DescriptorBuilder.cs
Expand Up @@ -21,6 +21,8 @@ public class DescriptorBuilder
ImplicitCastMappingBuilder.TryBuildMapping,
ParseMappingBuilder.TryBuildMapping,
CtorMappingBuilder.TryBuildMapping,
StringToEnumMappingBuilder.TryBuildMapping,
EnumToStringMappingBuilder.TryBuildMapping,
EnumMappingBuilder.TryBuildMapping,
ExplicitCastMappingBuilder.TryBuildMapping,
ToStringMappingBuilder.TryBuildMapping,
Expand All @@ -32,7 +34,9 @@ public class DescriptorBuilder
private readonly SyntaxNode _mapperSyntax;
private readonly MapperDescriptor _mapperDescriptor;

// default configurations
// default configurations, used a configuration is needed but no configuration is provided by the user
// these are the default configurations registered for each configuration attribute.
// Usually these are derived from the mapper attribute or default values.
private readonly Dictionary<Type, Attribute> _defaultConfigurations = new();

// this includes mappings to build and already built mappings
Expand Down Expand Up @@ -78,7 +82,9 @@ private MapperAttribute Configure()
_mapperDescriptor.Name = mapperAttribute.ImplementationName ?? BuildName();
_mapperDescriptor.InstanceName = mapperAttribute.InstanceName;

_defaultConfigurations.Add(typeof(MapEnumAttribute), new MapEnumAttribute(mapperAttribute.EnumMappingStrategy));
_defaultConfigurations.Add(
typeof(MapEnumAttribute),
new MapEnumAttribute(mapperAttribute.EnumMappingStrategy) { IgnoreCase = mapperAttribute.EnumMappingIgnoreCase });
return mapperAttribute;
}

Expand Down
77 changes: 39 additions & 38 deletions src/Riok.Mapperly/Descriptors/MappingBuilder/EnumMappingBuilder.cs
@@ -1,4 +1,3 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Abstractions;
using Riok.Mapperly.Descriptors.TypeMappings;
Expand All @@ -11,61 +10,63 @@ public static class EnumMappingBuilder
{
public static TypeMapping? TryBuildMapping(MappingBuilderContext ctx)
{
var sourceIsEnum = TryGetEnumType(ctx.Source, out var sourceEnumType);
var targetIsEnum = TryGetEnumType(ctx.Target, out var targetEnumType);
var sourceIsEnum = ctx.Source.TryGetEnumUnderlyingType(out var sourceEnumType);
var targetIsEnum = ctx.Target.TryGetEnumUnderlyingType(out var targetEnumType);

// none is an enum
if (!sourceIsEnum && !targetIsEnum)
return null;

if (sourceIsEnum && targetIsEnum)
// one is an enum, other may be an underlying type (eg. int)
if (!sourceIsEnum || !targetIsEnum)
{
if (SymbolEqualityComparer.Default.Equals(ctx.Source, ctx.Target))
return new DirectAssignmentMapping(ctx.Source);

var config = ctx.GetConfigurationOrDefault<MapEnumAttribute>();
return config.Strategy switch
{
EnumMappingStrategy.ByName => BuildNameMapping(ctx),
_ => new NullDelegateMapping(ctx.Source, ctx.Target, new CastMapping(ctx.Source.NonNullable(), ctx.Target.NonNullable())),
};
return ctx.FindOrBuildMapping(sourceEnumType ?? ctx.Source, targetEnumType ?? ctx.Target) is { } delegateMapping
? new EnumDelegateMapping(ctx.Source, ctx.Target, delegateMapping)
: null;
}

// to string => use an optimized method of Enum.ToString which would use slow reflection
// use Enum.ToString as fallback (for ex. for flags)
if (sourceIsEnum && ctx.Target.SpecialType == SpecialType.System_String)
return new EnumToStringMapping(ctx.Source, ctx.Target, ctx.Source.GetMembers().OfType<IFieldSymbol>());

// from string => use an optimized method of Enum.Parse which would use slow reflection
// however we currently don't support all features of Enum.Parse yet (ex. flags)
// therefore we use Enum.Parse as fallback.
if (targetIsEnum && ctx.Source.SpecialType == SpecialType.System_String)
return new EnumFromStringMapping(ctx.Source, ctx.Target, ctx.Target.GetMembers().OfType<IFieldSymbol>());
// since enums are immutable they can be directly assigned if they are of the same type
if (SymbolEqualityComparer.Default.Equals(ctx.Source, ctx.Target))
return new NullDelegateMapping(ctx.Source, ctx.Target, new DirectAssignmentMapping(ctx.Source.NonNullable()));

return ctx.FindOrBuildMapping(sourceEnumType ?? ctx.Source, targetEnumType ?? ctx.Target) is { } delegateMapping
? new EnumDelegateMapping(ctx.Source, ctx.Target, delegateMapping)
: null;
// map enums by strategy
var config = ctx.GetConfigurationOrDefault<MapEnumAttribute>();
return config.Strategy switch
{
EnumMappingStrategy.ByName => BuildNameMapping(ctx, config.IgnoreCase),
_ => new NullDelegateMapping(ctx.Source, ctx.Target, new CastMapping(ctx.Source.NonNullable(), ctx.Target.NonNullable())),
};
}

private static TypeMapping BuildNameMapping(MappingBuilderContext ctx)
private static TypeMapping BuildNameMapping(MappingBuilderContext ctx, bool ignoreCase)
{
var memberNames = ctx.Source.GetMembers().OfType<IFieldSymbol>().Select(x => x.Name)
.Intersect(ctx.Target.GetMembers().OfType<IFieldSymbol>().Select(x => x.Name))
.ToList();
var targetFieldsByName = ctx.Target.GetMembers().OfType<IFieldSymbol>().ToDictionary(x => x.Name);
Func<IFieldSymbol, IFieldSymbol?> getTargetField;
if (ignoreCase)
{
var targetFieldsByNameIgnoreCase = targetFieldsByName
.DistinctBy(x => x.Key, StringComparer.OrdinalIgnoreCase)
.ToDictionary(x => x.Key, x => x.Value, StringComparer.OrdinalIgnoreCase);
getTargetField = source => targetFieldsByName.GetValueOrDefault(source.Name) ?? targetFieldsByNameIgnoreCase.GetValueOrDefault(source.Name);
}
else
{
getTargetField = source => targetFieldsByName.GetValueOrDefault(source.Name);
}

var enumMemberMappings = ctx.Source.GetMembers().OfType<IFieldSymbol>()
.Select(x => (Source: x, Target: getTargetField(x)))
.Where(x => x.Target != null)
.ToDictionary(x => x.Source.Name, x => x.Target!.Name);

if (memberNames.Count == 0)
if (enumMemberMappings.Count == 0)
{
ctx.ReportDiagnostic(
DiagnosticDescriptors.EnumNameMappingNoOverlappingValuesFound,
ctx.Source,
ctx.Target);
}

return new EnumNameMapping(ctx.Source, ctx.Target, memberNames);
}

private static bool TryGetEnumType(ITypeSymbol t, [NotNullWhen(true)] out INamedTypeSymbol? enumType)
{
enumType = (t.NonNullable() as INamedTypeSymbol)?.EnumUnderlyingType;
return enumType != null;
return new EnumNameMapping(ctx.Source, ctx.Target, enumMemberMappings);
}
}
@@ -0,0 +1,18 @@
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Descriptors.TypeMappings;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.MappingBuilder;

public static class EnumToStringMappingBuilder
{
public static TypeMapping? TryBuildMapping(MappingBuilderContext ctx)
{
if (ctx.Target.SpecialType != SpecialType.System_String || !ctx.Source.IsEnum())
return null;

// to string => use an optimized method of Enum.ToString which would use slow reflection
// use Enum.ToString as fallback (for ex. for flags)
return new EnumToStringMapping(ctx.Source, ctx.Target, ctx.Source.GetMembers().OfType<IFieldSymbol>());
}
}
@@ -0,0 +1,22 @@
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Abstractions;
using Riok.Mapperly.Descriptors.TypeMappings;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.MappingBuilder;

public static class StringToEnumMappingBuilder
{
public static TypeMapping? TryBuildMapping(MappingBuilderContext ctx)
{
if (ctx.Source.SpecialType != SpecialType.System_String || !ctx.Target.IsEnum())
return null;

// from string => use an optimized method of Enum.Parse which would use slow reflection
// however we currently don't support all features of Enum.Parse yet (ex. flags)
// therefore we use Enum.Parse as fallback.
var members = ctx.Target.GetMembers().OfType<IFieldSymbol>();
var config = ctx.GetConfigurationOrDefault<MapEnumAttribute>();
return new EnumFromStringMapping(ctx.Source, ctx.Target, members, config.IgnoreCase);
}
}
@@ -1,5 +1,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Helpers;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Riok.Mapperly.Emit.SyntaxFactoryHelper;

Expand All @@ -8,50 +9,81 @@ namespace Riok.Mapperly.Descriptors.TypeMappings;
/// <summary>
/// Represents a mapping from a string to an enum.
/// Uses a switch expression for performance reasons (in comparison to <see cref="Enum.Parse(System.Type,string)"/>).
/// Only supports defined enum names (must match exactly) and no flags.
/// </summary>
public class EnumFromStringMapping : MethodMapping
{
private const string EnumClassName = "Enum";
private const string ParseMethodName = "Parse";
private const string IgnoreCaseSwitchDesignatedVariableName = "s";
private const string StringEqualsMethodName = nameof(string.Equals);
private const string StringComparisonFullName = "System.StringComparison.OrdinalIgnoreCase";

private readonly IEnumerable<IFieldSymbol> _enumMembers;
private readonly bool _ignoreCase;

public EnumFromStringMapping(
ITypeSymbol sourceType,
ITypeSymbol targetType,
IEnumerable<IFieldSymbol> enumMembers)
IEnumerable<IFieldSymbol> enumMembers,
bool ignoreCase)
: base(sourceType, targetType)
{
_enumMembers = enumMembers;
_ignoreCase = ignoreCase;
}

public override IEnumerable<StatementSyntax> BuildBody(ExpressionSyntax source)
{
// fallback switch arm: _ => (TargetType)Enum.Parse(typeof(TargetType), source)
// fallback switch arm: _ => (TargetType)Enum.Parse(typeof(TargetType), source, ignoreCase)
var enumParseInvocation = Invocation(
MemberAccess(EnumClassName, ParseMethodName),
TypeOfExpression(IdentifierName(TargetType.ToDisplayString())), source);
TypeOfExpression(IdentifierName(TargetType.ToDisplayString())), source, BooleanLiteral(_ignoreCase));
var fallbackArm = SwitchExpressionArm(
DiscardPattern(),
CastExpression(IdentifierName(TargetType.ToDisplayString()), enumParseInvocation));

// switch for each name to the enum value
// eg: nameof(Enum1.Value1) => Enum1.Value1,
var arms = _enumMembers.Select(BuildArm)
.Append(fallbackArm);
var arms = _ignoreCase
? _enumMembers.Select(BuildArmIgnoreCase)
: _enumMembers.Select(BuildArm);
arms = arms.Append(fallbackArm);

var switchExpr = SwitchExpression(source)
.WithArms(CommaSeparatedList(arms, true));

yield return ReturnStatement(switchExpr);
}

private SwitchExpressionArmSyntax BuildArm(IFieldSymbol field)
private SwitchExpressionArmSyntax BuildArmIgnoreCase(IFieldSymbol field)
{
// { } s
var pattern = RecursivePattern()
.WithPropertyPatternClause(PropertyPatternClause())
.WithDesignation(SingleVariableDesignation(Identifier(IgnoreCaseSwitchDesignatedVariableName)));

// source.Value1
var typeMemberAccess = MemberAccess(
IdentifierName(field.ContainingType.WithNullableAnnotation(NullableAnnotation.None).ToDisplayString()),
field.Name);

// when s.Equals(nameof(source.Value1), StringComparison.OrdinalIgnoreCase)
var whenClause = WhenClause(
Invocation(
MemberAccess(IgnoreCaseSwitchDesignatedVariableName, StringEqualsMethodName),
NameOf(typeMemberAccess),
IdentifierName(StringComparisonFullName)));

// { } s when s.Equals(nameof(source.Value1), StringComparison.OrdinalIgnoreCase) => source.Value1;
return SwitchExpressionArm(pattern, typeMemberAccess)
.WithWhenClause(whenClause);
}

private SwitchExpressionArmSyntax BuildArm(IFieldSymbol field)
{
// nameof(source.Value1) => source.Value1;
var typeMemberAccess = MemberAccess(
IdentifierName(field.ContainingType.NonNullable().ToDisplayString()),
field.Name);
var pattern = ConstantPattern(NameOf(typeMemberAccess));
return SwitchExpressionArm(pattern, typeMemberAccess);
}
Expand Down
16 changes: 8 additions & 8 deletions src/Riok.Mapperly/Descriptors/TypeMappings/EnumNameMapping.cs
Expand Up @@ -12,15 +12,15 @@ namespace Riok.Mapperly.Descriptors.TypeMappings;
/// </summary>
public class EnumNameMapping : MethodMapping
{
private readonly IReadOnlyCollection<string> _enumMemberNames;
private readonly IReadOnlyDictionary<string, string> _enumMemberMappings;

public EnumNameMapping(
ITypeSymbol source,
ITypeSymbol target,
IReadOnlyCollection<string> enumMemberNames)
IReadOnlyDictionary<string, string> enumMemberMappings)
: base(source, target)
{
_enumMemberNames = enumMemberNames;
_enumMemberMappings = enumMemberMappings;
}

public override IEnumerable<StatementSyntax> BuildBody(ExpressionSyntax source)
Expand All @@ -31,8 +31,8 @@ public override IEnumerable<StatementSyntax> BuildBody(ExpressionSyntax source)
ThrowArgumentOutOfRangeException(source));

// switch for each name to the enum value
// eg: nameof(Enum1.Value1) => Enum2.Value1,
var arms = _enumMemberNames
// eg: Enum1.Value1 => Enum2.Value1,
var arms = _enumMemberMappings
.Select(BuildArm)
.Append(fallbackArm);

Expand All @@ -42,10 +42,10 @@ public override IEnumerable<StatementSyntax> BuildBody(ExpressionSyntax source)
yield return ReturnStatement(switchExpr);
}

private SwitchExpressionArmSyntax BuildArm(string memberName)
private SwitchExpressionArmSyntax BuildArm(KeyValuePair<string, string> sourceTargetField)
{
var sourceMember = MemberAccess(SourceType.ToDisplayString(), memberName);
var targetMember = MemberAccess(TargetType.ToDisplayString(), memberName);
var sourceMember = MemberAccess(SourceType.ToDisplayString(), sourceTargetField.Key);
var targetMember = MemberAccess(TargetType.ToDisplayString(), sourceTargetField.Value);
var pattern = ConstantPattern(sourceMember);
return SwitchExpressionArm(pattern, targetMember);
}
Expand Down
6 changes: 3 additions & 3 deletions src/Riok.Mapperly/Emit/SyntaxFactoryHelper.cs
Expand Up @@ -48,7 +48,7 @@ public static ExpressionSyntax NullSubstitute(ITypeSymbol t, ExpressionSyntax ar
return DefaultLiteral();

if (t.SpecialType == SpecialType.System_String)
return LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(string.Empty));
return StringLiteral(string.Empty);

return t.HasAccessibleParameterlessConstructor()
? CreateInstance(t)
Expand All @@ -71,8 +71,8 @@ public static LiteralExpressionSyntax NullLiteral()
public static LiteralExpressionSyntax StringLiteral(string content) =>
LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(content));

public static LiteralExpressionSyntax NumericLiteral(int v)
=> LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0));
public static LiteralExpressionSyntax BooleanLiteral(bool b)
=> LiteralExpression(b ? SyntaxKind.TrueLiteralExpression : SyntaxKind.FalseLiteralExpression);

public static AttributeListSyntax ReturnNotNullIfNotNullAttribute(string paramName)
{
Expand Down
10 changes: 4 additions & 6 deletions src/Riok.Mapperly/Helpers/DictionaryExtensions.cs
Expand Up @@ -14,13 +14,11 @@ public static class DictionaryExtensions
return true;
}

public static TValue GetOrAdd<TKey, TValue>(
public static TValue? GetValueOrDefault<TKey, TValue>(
this IDictionary<TKey, TValue> dict,
TKey key,
Func<TValue> valueProvider)
TKey key)
{
return dict.TryGetValue(key, out var value)
? value
: dict[key] = valueProvider();
dict.TryGetValue(key, out var value);
return value;
}
}
7 changes: 5 additions & 2 deletions src/Riok.Mapperly/Helpers/EnumerableExtensions.cs
Expand Up @@ -13,9 +13,12 @@ public static HashSet<T> ToHashSet<T>(this IEnumerable<T> enumerable)
return set;
}

public static IEnumerable<T> DistinctBy<T, TProp>(this IEnumerable<T> enumerable, Func<T, TProp> selector)
public static IEnumerable<T> DistinctBy<T, TProp>(
this IEnumerable<T> enumerable,
Func<T, TProp> selector,
IEqualityComparer<TProp>? equalityComparer = null)
{
var set = new HashSet<TProp>();
var set = new HashSet<TProp>(equalityComparer);
foreach (var item in enumerable)
{
if (set.Add(selector(item)))
Expand Down

0 comments on commit d59cbee

Please sign in to comment.