Skip to content

Commit

Permalink
feat: add mapping to immutable sequences (#305)
Browse files Browse the repository at this point in the history
  • Loading branch information
TimothyMakkison committed Apr 12, 2023
1 parent 393f78d commit 8c8a338
Show file tree
Hide file tree
Showing 26 changed files with 975 additions and 70 deletions.
Expand Up @@ -27,7 +27,7 @@ public static class DictionaryMappingBuilder
.OfType<IPropertySymbol>()
.Any(x => !x.IsStatic && !x.IsIndexer && !x.IsWriteOnly && x.Type.SpecialType == SpecialType.System_Int32);

var targetDictionarySymbol = ctx.Types.Dictionary.Construct(keyMapping.TargetType, valueMapping.TargetType);
var targetDictionarySymbol = ctx.Types.DictionaryT.Construct(keyMapping.TargetType, valueMapping.TargetType);
ctx.ObjectFactories.TryFindObjectFactory(ctx.Source, ctx.Target, out var dictionaryObjectFactory);
return new ForEachSetDictionaryMapping(
ctx.Source,
Expand All @@ -47,7 +47,7 @@ public static class DictionaryMappingBuilder
return null;
}

if (!ctx.Target.ImplementsGeneric(ctx.Types.IDictionary, out _))
if (!ctx.Target.ImplementsGeneric(ctx.Types.IDictionaryT, out _))
return null;

return new ForEachSetDictionaryMapping(
Expand All @@ -61,7 +61,7 @@ public static class DictionaryMappingBuilder

public static IExistingTargetMapping? TryBuildExistingTargetMapping(MappingBuilderContext ctx)
{
if (!ctx.Target.ImplementsGeneric(ctx.Types.IDictionary, out _))
if (!ctx.Target.ImplementsGeneric(ctx.Types.IDictionaryT, out _))
return null;

if (BuildKeyValueMapping(ctx) is not var (keyMapping, valueMapping))
Expand Down Expand Up @@ -98,19 +98,19 @@ private static bool IsDictionaryType(MappingBuilderContext ctx, ITypeSymbol symb
if (symbol is not INamedTypeSymbol namedSymbol)
return false;

return SymbolEqualityComparer.Default.Equals(namedSymbol.ConstructedFrom, ctx.Types.Dictionary)
|| SymbolEqualityComparer.Default.Equals(namedSymbol.ConstructedFrom, ctx.Types.IDictionary)
|| SymbolEqualityComparer.Default.Equals(namedSymbol.ConstructedFrom, ctx.Types.IReadOnlyDictionary);
return SymbolEqualityComparer.Default.Equals(namedSymbol.ConstructedFrom, ctx.Types.DictionaryT)
|| SymbolEqualityComparer.Default.Equals(namedSymbol.ConstructedFrom, ctx.Types.IDictionaryT)
|| SymbolEqualityComparer.Default.Equals(namedSymbol.ConstructedFrom, ctx.Types.IReadOnlyDictionaryT);
}

private static (ITypeSymbol, ITypeSymbol)? GetDictionaryKeyValueTypes(MappingBuilderContext ctx, ITypeSymbol t)
{
if (t.ImplementsGeneric(ctx.Types.IDictionary, out var dictionaryImpl))
if (t.ImplementsGeneric(ctx.Types.IDictionaryT, out var dictionaryImpl))
{
return (dictionaryImpl.TypeArguments[0], dictionaryImpl.TypeArguments[1]);
}

if (t.ImplementsGeneric(ctx.Types.IReadOnlyDictionary, out var readOnlyDictionaryImpl))
if (t.ImplementsGeneric(ctx.Types.IReadOnlyDictionaryT, out var readOnlyDictionaryImpl))
{
return (readOnlyDictionaryImpl.TypeArguments[0], readOnlyDictionaryImpl.TypeArguments[1]);
}
Expand All @@ -120,13 +120,13 @@ private static (ITypeSymbol, ITypeSymbol)? GetDictionaryKeyValueTypes(MappingBui

private static (ITypeSymbol, ITypeSymbol)? GetEnumerableKeyValueTypes(MappingBuilderContext ctx, ITypeSymbol t)
{
if (!t.ImplementsGeneric(ctx.Types.IEnumerable, out var enumerableImpl))
if (!t.ImplementsGeneric(ctx.Types.IEnumerableT, out var enumerableImpl))
return null;

if (enumerableImpl.TypeArguments[0] is not INamedTypeSymbol enumeratedType)
return null;

if (!SymbolEqualityComparer.Default.Equals(enumeratedType.ConstructedFrom, ctx.Types.KeyValuePair))
if (!SymbolEqualityComparer.Default.Equals(enumeratedType.ConstructedFrom, ctx.Types.KeyValuePairT))
return null;

return (enumeratedType.TypeArguments[0], enumeratedType.TypeArguments[1]);
Expand Down
@@ -1,3 +1,4 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Abstractions;
using Riok.Mapperly.Descriptors.Mappings;
Expand All @@ -12,6 +13,14 @@ public static class EnumerableMappingBuilder
private const string SelectMethodName = nameof(Enumerable.Select);
private const string ToArrayMethodName = nameof(Enumerable.ToArray);
private const string ToListMethodName = nameof(Enumerable.ToList);
private const string AddValueMethodName = nameof(ICollection<object>.Add);

private const string ToImmutableArrayMethodName = nameof(ImmutableArray.ToImmutableArray);
private const string ToImmutableListMethodName = nameof(ImmutableList.ToImmutableList);
private const string ToImmutableHashSetMethodName = nameof(ImmutableHashSet.ToImmutableHashSet);
private const string CreateRangeQueueMethodName = nameof(ImmutableQueue.CreateRange);
private const string CreateRangeStackMethodName = nameof(ImmutableStack.CreateRange);
private const string ToImmutableSortedSetMethodName = nameof(ImmutableSortedSet.ToImmutableSortedSet);

public static TypeMapping? TryBuildMapping(MappingBuilderContext ctx)
{
Expand All @@ -23,13 +32,13 @@ public static class EnumerableMappingBuilder

// if element mapping is synthetic
// and target is an IEnumerable, there is no mapping needed at all.
if (elementMapping.IsSynthetic && SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.IEnumerable))
if (elementMapping.IsSynthetic && SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.IEnumerableT))
return new CastMapping(ctx.Source, ctx.Target);

// if source is an array and target is an array or IReadOnlyCollection faster mappings can be applied
if (!ctx.IsExpression
&& ctx.Source.IsArrayType()
&& (ctx.Target.IsArrayType() || SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.IReadOnlyCollection)))
&& (ctx.Target.IsArrayType() || SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.IReadOnlyCollectionT)))
{
// if element mapping is synthetic
// a single Array.Clone / cast mapping call should be sufficient and fast,
Expand All @@ -48,6 +57,12 @@ public static class EnumerableMappingBuilder
if (canMapWithLinq)
return BuildLinqMapping(ctx, elementMapping, collectMethodName);

// try linq mapping: x.Select(Map).ToImmutableArray/ToImmutableList
// if that doesn't work do a foreach with add calls
var immutableLinqMapping = TryBuildImmutableLinqMapping(ctx, elementMapping);
if (immutableLinqMapping is not null)
return immutableLinqMapping;

return ctx.IsExpression
? null
: BuildCustomTypeMapping(ctx, elementMapping);
Expand All @@ -58,14 +73,17 @@ public static class EnumerableMappingBuilder
if (BuildElementMapping(ctx) is not { } elementMapping)
return null;

if (ctx.Target.ImplementsGeneric(ctx.Types.Stack, out _))
if (ctx.Target.ImplementsGeneric(ctx.Types.StackT, out _))
return new ForEachAddEnumerableExistingTargetMapping(ctx.Source, ctx.Target, elementMapping, nameof(Stack<object>.Push));

if (ctx.Target.ImplementsGeneric(ctx.Types.Queue, out _))
if (ctx.Target.ImplementsGeneric(ctx.Types.QueueT, out _))
return new ForEachAddEnumerableExistingTargetMapping(ctx.Source, ctx.Target, elementMapping, nameof(Queue<object>.Enqueue));

if (ctx.Target.ImplementsGeneric(ctx.Types.ICollection, out _))
return new ForEachAddEnumerableExistingTargetMapping(ctx.Source, ctx.Target, elementMapping, nameof(ICollection<object>.Add));
// create a foreach loop with add calls if source is not an array
// and ICollection.Add(T): void is implemented and not explicit
// ensures add is not called and immutable types
if (!ctx.Target.IsArrayType() && ctx.Target.HasImplicitInterfaceMethod(ctx.Types.ICollectionT, AddValueMethodName))
return new ForEachAddEnumerableExistingTargetMapping(ctx.Source, ctx.Target, elementMapping, AddValueMethodName);

return null;
}
Expand All @@ -90,11 +108,11 @@ public static class EnumerableMappingBuilder
{
var collectMethod = collectMethodName == null
? null
: ResolveLinqMethod(ctx, collectMethodName);
: ResolveStaticMethod(ctx.Types.Enumerable, collectMethodName);

var selectMethod = elementMapping.IsSynthetic
? null
: ResolveLinqMethod(ctx, SelectMethodName);
: ResolveStaticMethod(ctx.Types.Enumerable, SelectMethodName);

return new LinqEnumerableMapping(ctx.Source, ctx.Target, elementMapping, selectMethod, collectMethod);
}
Expand All @@ -109,14 +127,17 @@ public static class EnumerableMappingBuilder
return null;
}

if (ctx.Target.ImplementsGeneric(ctx.Types.Stack, out _))
if (ctx.Target.ImplementsGeneric(ctx.Types.StackT, out _))
return new ForEachAddEnumerableMapping(ctx.Source, ctx.Target, elementMapping, objectFactory, nameof(Stack<object>.Push));

if (ctx.Target.ImplementsGeneric(ctx.Types.Queue, out _))
if (ctx.Target.ImplementsGeneric(ctx.Types.QueueT, out _))
return new ForEachAddEnumerableMapping(ctx.Source, ctx.Target, elementMapping, objectFactory, nameof(Queue<object>.Enqueue));

if (ctx.Target.ImplementsGeneric(ctx.Types.ICollection, out _))
return new ForEachAddEnumerableMapping(ctx.Source, ctx.Target, elementMapping, objectFactory, nameof(ICollection<object>.Add));
// create a foreach loop with add calls if source is not an array
// and ICollection.Add(T): void is implemented and not explicit
// ensures add is not called and immutable types
if (!ctx.Target.IsArrayType() && ctx.Target.HasImplicitInterfaceMethod(ctx.Types.ICollectionT, AddValueMethodName))
return new ForEachAddEnumerableMapping(ctx.Source, ctx.Target, elementMapping, objectFactory, AddValueMethodName);

return null;
}
Expand All @@ -128,45 +149,80 @@ private static (bool CanMapWithLinq, string? CollectMethod) ResolveCollectMethod
return (true, ToArrayMethodName);

// if the target is an IEnumerable<T> don't collect at all.
if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.IEnumerable))
if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.IEnumerableT))
return (true, null);

// if the target is IReadOnlyCollection<T>
// and the count of the source is known (array, IReadOnlyCollection<T>, ICollection<T>) we collect to array
// for performance/space reasons
var targetIsReadOnlyCollection = SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.IReadOnlyCollection);
var targetIsReadOnlyCollection = SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.IReadOnlyCollectionT);
var sourceCountIsKnown =
ctx.Source.IsArrayType()
|| ctx.Source.ImplementsGeneric(ctx.Types.IReadOnlyCollection, out _)
|| ctx.Source.ImplementsGeneric(ctx.Types.ICollection, out _);
|| ctx.Source.ImplementsGeneric(ctx.Types.IReadOnlyCollectionT, out _)
|| ctx.Source.ImplementsGeneric(ctx.Types.ICollectionT, out _);
if (targetIsReadOnlyCollection && sourceCountIsKnown)
return (true, ToArrayMethodName);

// if target is a IReadOnlyCollection<T>, IList<T>, List<T> or ICollection<T> with ToList()
return targetIsReadOnlyCollection
|| SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.IReadOnlyList)
|| SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.IList)
|| SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.List)
|| SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ICollection)
|| SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.IReadOnlyListT)
|| SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.IListT)
|| SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ListT)
|| SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ICollectionT)
? (true, ToListMethodName)
: (false, null);
}

private static IMethodSymbol? ResolveLinqMethod(MappingBuilderContext ctx, string methodName)
private static LinqEnumerableMapping? TryBuildImmutableLinqMapping(
MappingBuilderContext ctx,
ITypeMapping elementMapping)
{
var collectMethod = ResolveImmutableCollectMethod(ctx);
if (collectMethod is null)
return null;

var selectMethod = elementMapping.IsSynthetic
? null
: ResolveStaticMethod(ctx.Types.Enumerable, SelectMethodName);

return new LinqEnumerableMapping(ctx.Source, ctx.Target, elementMapping, selectMethod, collectMethod);
}

private static IMethodSymbol? ResolveImmutableCollectMethod(MappingBuilderContext ctx)
{
if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableArrayT))
return ResolveStaticMethod(ctx.Types.ImmutableArray, ToImmutableArrayMethodName);

if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableListT))
return ResolveStaticMethod(ctx.Types.ImmutableList, ToImmutableListMethodName);

if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableHashSetT))
return ResolveStaticMethod(ctx.Types.ImmutableHashSet, ToImmutableHashSetMethodName);

if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableQueueT))
return ResolveStaticMethod(ctx.Types.ImmutableQueue, CreateRangeQueueMethodName);

if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableStackT))
return ResolveStaticMethod(ctx.Types.ImmutableStack, CreateRangeStackMethodName);

if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableSortedSetT))
return ResolveStaticMethod(ctx.Types.ImmutableSortedSet, ToImmutableSortedSetMethodName);

return null;
}

private static IMethodSymbol? ResolveStaticMethod(INamedTypeSymbol namedType, string methodName)
{
var method = ctx.Types.Enumerable
.GetMembers(methodName)
.OfType<IMethodSymbol>()
.FirstOrDefault(m =>
m.IsStatic
&& m.IsGenericMethod);
var method = namedType.GetMembers(methodName)
.OfType<IMethodSymbol>()
.FirstOrDefault(m => m.IsStatic && m.IsGenericMethod);

return method;
}

private static ITypeSymbol? GetEnumeratedType(MappingBuilderContext ctx, ITypeSymbol type)
{
return type.ImplementsGeneric(ctx.Types.IEnumerable, out var enumerableIntf)
return type.ImplementsGeneric(ctx.Types.IEnumerableT, out var enumerableIntf)
? enumerableIntf.TypeArguments[0]
: null;
}
Expand Down
Expand Up @@ -12,10 +12,10 @@ public static class QueryableMappingBuilder
if (!ctx.IsConversionEnabled(MappingConversionType.Queryable))
return null;

if (!ctx.Source.ImplementsGeneric(ctx.Types.IQueryable, out var sourceQueryable))
if (!ctx.Source.ImplementsGeneric(ctx.Types.IQueryableT, out var sourceQueryable))
return null;

if (!ctx.Target.ImplementsGeneric(ctx.Types.IQueryable, out var targetQueryable))
if (!ctx.Target.ImplementsGeneric(ctx.Types.IQueryableT, out var targetQueryable))
return null;

var sourceType = sourceQueryable.TypeArguments[0];
Expand Down

0 comments on commit 8c8a338

Please sign in to comment.