Skip to content

Commit

Permalink
feat: ISet and IReadOnlySet support (#502)
Browse files Browse the repository at this point in the history
  • Loading branch information
latonz committed Jun 15, 2023
1 parent b77bd67 commit 9550ae4
Show file tree
Hide file tree
Showing 27 changed files with 1,010 additions and 16 deletions.
Expand Up @@ -15,6 +15,7 @@ public static class EnumerableMappingBuilder
private const string SelectMethodName = nameof(Enumerable.Select);
private const string ToArrayMethodName = nameof(Enumerable.ToArray);
private const string ToListMethodName = nameof(Enumerable.ToList);
private const string ToHashSetMethodName = "ToHashSet";
private const string AddMethodName = nameof(ICollection<object>.Add);

private const string ToImmutableArrayMethodName = nameof(ImmutableArray.ToImmutableArray);
Expand Down Expand Up @@ -67,8 +68,8 @@ public static class EnumerableMappingBuilder
return immutableLinqMapping;

// if target is a type that takes IEnumerable in its constructor
if (HasEnumerableConstructor(ctx, elementMapping.TargetType))
return BuildLinqConstructorMapping(ctx, elementMapping);
if (GetTypeConstructableFromEnumerable(ctx, elementMapping.TargetType) is { } constructableType)
return BuildLinqConstructorMapping(ctx, constructableType, elementMapping);

return ctx.IsExpression ? null : BuildCustomTypeMapping(ctx, elementMapping);
}
Expand Down Expand Up @@ -124,22 +125,32 @@ private static LinqEnumerableMapping BuildLinqMapping(MappingBuilderContext ctx,
return new LinqEnumerableMapping(ctx.Source, ctx.Target, elementMapping, selectMethod, collectMethod);
}

private static bool HasEnumerableConstructor(MappingBuilderContext ctx, ITypeSymbol typeSymbol)
private static INamedTypeSymbol? GetTypeConstructableFromEnumerable(MappingBuilderContext ctx, ITypeSymbol typeSymbol)
{
if (ctx.Target is not INamedTypeSymbol namedType)
return false;
return null;

var typedEnumerable = ctx.Types.Get(typeof(IEnumerable<>)).Construct(typeSymbol);

return namedType.Constructors.Any(
var hasCtor = namedType.Constructors.Any(
m => m.Parameters.Length == 1 && SymbolEqualityComparer.Default.Equals(m.Parameters[0].Type, typedEnumerable)
);
if (hasCtor)
return namedType;

if (ctx.CollectionInfos!.Target.Type is CollectionType.ISet or CollectionType.IReadOnlySet)
return ctx.Types.Get(typeof(HashSet<>)).Construct(typeSymbol);

return null;
}

private static LinqConstructorMapping BuildLinqConstructorMapping(MappingBuilderContext ctx, ITypeMapping elementMapping)
private static LinqConstructorMapping BuildLinqConstructorMapping(
MappingBuilderContext ctx,
INamedTypeSymbol targetTypeToConstruct,
ITypeMapping elementMapping
)
{
var selectMethod = elementMapping.IsSynthetic ? null : ctx.Types.Get(typeof(Enumerable)).GetStaticGenericMethod(SelectMethodName);
return new LinqConstructorMapping(ctx.Source, ctx.Target, elementMapping, selectMethod);
return new LinqConstructorMapping(ctx.Source, ctx.Target, targetTypeToConstruct, elementMapping, selectMethod);
}

private static ExistingTargetMappingMethodWrapper? BuildCustomTypeMapping(MappingBuilderContext ctx, ITypeMapping elementMapping)
Expand Down Expand Up @@ -189,6 +200,17 @@ private static (bool CanMapWithLinq, string? CollectMethod) ResolveCollectMethod
if ((targetIsReadOnlyCollection || targetIsIEnumerable) && ctx.CollectionInfos.Source.CountIsKnown)
return (true, ToArrayMethodName);

// if target is Set
// and ToHashSet is supported (only supported for .NET5+)
// use ToHashSet
if (
ctx.CollectionInfos.Target.Type is CollectionType.ISet or CollectionType.IReadOnlySet or CollectionType.HashSet
&& GetToHashSetLinqCollectMethod(ctx.Types) is { } toHashSetMethod
)
{
return (true, toHashSetMethod.Name);
}

// if target is a IReadOnlyCollection<T>, IEnumerable<T>, IList<T>, List<T> or ICollection<T> with ToList()
return
targetIsReadOnlyCollection
Expand Down Expand Up @@ -234,4 +256,7 @@ or CollectionType.ICollection

return null;
}

private static IMethodSymbol? GetToHashSetLinqCollectMethod(WellKnownTypes wellKnownTypes) =>
wellKnownTypes.Get(typeof(Enumerable)).GetStaticGenericMethod(ToHashSetMethodName);
}
12 changes: 10 additions & 2 deletions src/Riok.Mapperly/Descriptors/Mappings/LinqConstructorMapping.cs
Expand Up @@ -10,12 +10,20 @@ namespace Riok.Mapperly.Descriptors.Mappings;
/// </summary>
public class LinqConstructorMapping : TypeMapping
{
private readonly INamedTypeSymbol _targetTypeToConstruct;
private readonly ITypeMapping _elementMapping;
private readonly IMethodSymbol? _selectMethod;

public LinqConstructorMapping(ITypeSymbol sourceType, ITypeSymbol targetType, ITypeMapping elementMapping, IMethodSymbol? selectMethod)
public LinqConstructorMapping(
ITypeSymbol sourceType,
ITypeSymbol targetType,
INamedTypeSymbol targetTypeToConstruct,
ITypeMapping elementMapping,
IMethodSymbol? selectMethod
)
: base(sourceType, targetType)
{
_targetTypeToConstruct = targetTypeToConstruct;
_elementMapping = elementMapping;
_selectMethod = selectMethod;
}
Expand All @@ -37,6 +45,6 @@ public override ExpressionSyntax Build(TypeMappingBuildContext ctx)
mappedSource = _elementMapping.Build(ctx);
}

return CreateInstance(TargetType, mappedSource);
return CreateInstance(_targetTypeToConstruct, mappedSource);
}
}
9 changes: 9 additions & 0 deletions test/Riok.Mapperly.IntegrationTests/BaseMapperTest.cs
Expand Up @@ -104,6 +104,15 @@ public static TestObject NewTestObj()
{ "2", "2" },
{ "3", "3" }
}.ToImmutableSortedDictionary(),
ExistingISet = { "1", "2", "3", },
ExistingHashSet = { "1", "2", "3", },
ExistingSortedSet = { "1", "2", "3", },
ISet = new HashSet<string> { "1", "2", "3", },
#if NET5_0_OR_GREATER
IReadOnlySet = new HashSet<string> { "1", "2", "3", },
#endif
HashSet = new HashSet<string> { "1", "2", "3", },
SortedSet = new SortedSet<string> { "1", "2", "3", }
};
}

Expand Down
16 changes: 16 additions & 0 deletions test/Riok.Mapperly.IntegrationTests/Dto/TestObjectDto.cs
Expand Up @@ -73,6 +73,22 @@ public TestObjectDto(int ctorValue, int unknownValue = 10, int ctorValue2 = 100)

public ImmutableSortedDictionary<int, int> ImmutableSortedDictionaryValue { get; set; } = ImmutableSortedDictionary<int, int>.Empty;

public ISet<int> ExistingISet { get; } = new HashSet<int>();

public HashSet<int> ExistingHashSet { get; } = new HashSet<int>();

public SortedSet<int> ExistingSortedSet { get; } = new SortedSet<int>();

public ISet<int> ISet { get; set; } = new HashSet<int>();

#if NET5_0_OR_GREATER
public IReadOnlySet<int> IReadOnlySet { get; set; } = new HashSet<int>();
#endif

public HashSet<int> HashSet { get; set; } = new HashSet<int>();

public SortedSet<int> SortedSet { get; set; } = new SortedSet<int>();

public TestEnumDtoByValue EnumValue { get; set; }

public TestEnumDtoByName EnumName { get; set; }
Expand Down
16 changes: 16 additions & 0 deletions test/Riok.Mapperly.IntegrationTests/Models/TestObject.cs
Expand Up @@ -71,6 +71,22 @@ public TestObject(int ctorValue, int unknownValue = 10, int ctorValue2 = 100)
public ImmutableSortedDictionary<string, string> ImmutableSortedDictionaryValue { get; set; } =
ImmutableSortedDictionary<string, string>.Empty;

public ISet<string> ExistingISet { get; } = new HashSet<string>();

public HashSet<string> ExistingHashSet { get; } = new HashSet<string>();

public SortedSet<string> ExistingSortedSet { get; } = new SortedSet<string>();

public ISet<string> ISet { get; set; } = new HashSet<string>();

#if NET5_0_OR_GREATER
public IReadOnlySet<string> IReadOnlySet { get; set; } = new HashSet<string>();
#endif

public HashSet<string> HashSet { get; set; } = new HashSet<string>();

public SortedSet<string> SortedSet { get; set; } = new SortedSet<string>();

public TestEnum EnumValue { get; set; }

public TestEnum EnumName { get; set; }
Expand Down
Expand Up @@ -99,6 +99,36 @@
2: 2,
3: 3
},
ExistingISet: [
1,
2,
3
],
ExistingHashSet: [
1,
2,
3
],
ExistingSortedSet: [
1,
2,
3
],
ISet: [
1,
2,
3
],
HashSet: [
1,
2,
3
],
SortedSet: [
1,
2,
3
],
EnumValue: Value10,
EnumName: Value10,
EnumRawValue: Value20,
Expand Down
Expand Up @@ -69,6 +69,24 @@ public static partial class DeepCloningMapper
target.ImmutableSortedSetValue = global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet(src.ImmutableSortedSetValue);
target.ImmutableDictionaryValue = global::System.Collections.Immutable.ImmutableDictionary.ToImmutableDictionary(src.ImmutableDictionaryValue);
target.ImmutableSortedDictionaryValue = global::System.Collections.Immutable.ImmutableSortedDictionary.ToImmutableSortedDictionary(src.ImmutableSortedDictionaryValue);
foreach (var item in src.ExistingISet)
{
target.ExistingISet.Add(item);
}

foreach (var item1 in src.ExistingHashSet)
{
target.ExistingHashSet.Add(item1);
}

foreach (var item2 in src.ExistingSortedSet)
{
target.ExistingSortedSet.Add(item2);
}

target.ISet = global::System.Linq.Enumerable.ToHashSet(src.ISet);
target.HashSet = global::System.Linq.Enumerable.ToHashSet(src.HashSet);
target.SortedSet = new global::System.Collections.Generic.SortedSet<string>(src.SortedSet);
target.EnumValue = src.EnumValue;
target.EnumName = src.EnumName;
target.EnumRawValue = src.EnumRawValue;
Expand All @@ -94,4 +112,4 @@ public static partial class DeepCloningMapper
return target;
}
}
}
}
Expand Up @@ -108,6 +108,36 @@
2: 2,
3: 3
},
ExistingISet: [
1,
2,
3
],
ExistingHashSet: [
1,
2,
3
],
ExistingSortedSet: [
1,
2,
3
],
ISet: [
1,
2,
3
],
HashSet: [
1,
2,
3
],
SortedSet: [
1,
2,
3
],
EnumValue: DtoValue1,
EnumName: Value10,
EnumRawValue: 20,
Expand Down
Expand Up @@ -109,6 +109,24 @@ public partial class TestMapper
target.ImmutableSortedSetValue = global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet(global::System.Linq.Enumerable.Select(testObject.ImmutableSortedSetValue, x => ParseableInt(x)));
target.ImmutableDictionaryValue = global::System.Collections.Immutable.ImmutableDictionary.ToImmutableDictionary(testObject.ImmutableDictionaryValue, x => ParseableInt(x.Key), x => ParseableInt(x.Value));
target.ImmutableSortedDictionaryValue = global::System.Collections.Immutable.ImmutableSortedDictionary.ToImmutableSortedDictionary(testObject.ImmutableSortedDictionaryValue, x => ParseableInt(x.Key), x => ParseableInt(x.Value));
foreach (var item in testObject.ExistingISet)
{
target.ExistingISet.Add(ParseableInt(item));
}

foreach (var item1 in testObject.ExistingHashSet)
{
target.ExistingHashSet.Add(ParseableInt(item1));
}

foreach (var item2 in testObject.ExistingSortedSet)
{
target.ExistingSortedSet.Add(ParseableInt(item2));
}

target.ISet = global::System.Linq.Enumerable.ToHashSet(global::System.Linq.Enumerable.Select(testObject.ISet, x => ParseableInt(x)));
target.HashSet = global::System.Linq.Enumerable.ToHashSet(global::System.Linq.Enumerable.Select(testObject.HashSet, x => ParseableInt(x)));
target.SortedSet = new global::System.Collections.Generic.SortedSet<int>(global::System.Linq.Enumerable.Select(testObject.SortedSet, x => ParseableInt(x)));
target.EnumValue = (global::Riok.Mapperly.IntegrationTests.Dto.TestEnumDtoByValue)testObject.EnumValue;
target.EnumName = MapToEnumDtoByName(testObject.EnumName);
target.EnumRawValue = (byte)testObject.EnumRawValue;
Expand Down Expand Up @@ -167,6 +185,24 @@ public partial class TestMapper
target.ImmutableSortedSetValue = global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet(global::System.Linq.Enumerable.Select(dto.ImmutableSortedSetValue, x => x.ToString()));
target.ImmutableDictionaryValue = global::System.Collections.Immutable.ImmutableDictionary.ToImmutableDictionary(dto.ImmutableDictionaryValue, x => x.Key.ToString(), x => x.Value.ToString());
target.ImmutableSortedDictionaryValue = global::System.Collections.Immutable.ImmutableSortedDictionary.ToImmutableSortedDictionary(dto.ImmutableSortedDictionaryValue, x => x.Key.ToString(), x => x.Value.ToString());
foreach (var item in dto.ExistingISet)
{
target.ExistingISet.Add(item.ToString());
}

foreach (var item1 in dto.ExistingHashSet)
{
target.ExistingHashSet.Add(item1.ToString());
}

foreach (var item2 in dto.ExistingSortedSet)
{
target.ExistingSortedSet.Add(item2.ToString());
}

target.ISet = global::System.Linq.Enumerable.ToHashSet(global::System.Linq.Enumerable.Select(dto.ISet, x => x.ToString()));
target.HashSet = global::System.Linq.Enumerable.ToHashSet(global::System.Linq.Enumerable.Select(dto.HashSet, x => x.ToString()));
target.SortedSet = new global::System.Collections.Generic.SortedSet<string>(global::System.Linq.Enumerable.Select(dto.SortedSet, x => x.ToString()));
target.EnumValue = MapToEnumByValueCheckDefined(dto.EnumValue);
target.EnumName = MapToEnumByNameWithFallback(dto.EnumName);
target.EnumRawValue = (global::Riok.Mapperly.IntegrationTests.Models.TestEnum)dto.EnumRawValue;
Expand Down Expand Up @@ -282,6 +318,24 @@ public partial class TestMapper
target.ImmutableSortedSetValue = global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet(global::System.Linq.Enumerable.Select(source.ImmutableSortedSetValue, x => ParseableInt(x)));
target.ImmutableDictionaryValue = global::System.Collections.Immutable.ImmutableDictionary.ToImmutableDictionary(source.ImmutableDictionaryValue, x => ParseableInt(x.Key), x => ParseableInt(x.Value));
target.ImmutableSortedDictionaryValue = global::System.Collections.Immutable.ImmutableSortedDictionary.ToImmutableSortedDictionary(source.ImmutableSortedDictionaryValue, x => ParseableInt(x.Key), x => ParseableInt(x.Value));
foreach (var item in source.ExistingISet)
{
target.ExistingISet.Add(ParseableInt(item));
}

foreach (var item1 in source.ExistingHashSet)
{
target.ExistingHashSet.Add(ParseableInt(item1));
}

foreach (var item2 in source.ExistingSortedSet)
{
target.ExistingSortedSet.Add(ParseableInt(item2));
}

target.ISet = global::System.Linq.Enumerable.ToHashSet(global::System.Linq.Enumerable.Select(source.ISet, x => ParseableInt(x)));
target.HashSet = global::System.Linq.Enumerable.ToHashSet(global::System.Linq.Enumerable.Select(source.HashSet, x => ParseableInt(x)));
target.SortedSet = new global::System.Collections.Generic.SortedSet<int>(global::System.Linq.Enumerable.Select(source.SortedSet, x => ParseableInt(x)));
target.EnumValue = (global::Riok.Mapperly.IntegrationTests.Dto.TestEnumDtoByValue)source.EnumValue;
target.EnumName = MapToEnumDtoByName(source.EnumName);
target.EnumRawValue = (byte)source.EnumRawValue;
Expand Down Expand Up @@ -381,4 +435,4 @@ private string MapToString1(global::Riok.Mapperly.IntegrationTests.Dto.TestEnumD
return target;
}
}
}
}

0 comments on commit 9550ae4

Please sign in to comment.