Skip to content

Commit

Permalink
chore: refactor collection support
Browse files Browse the repository at this point in the history
  • Loading branch information
latonz committed Jun 15, 2023
1 parent b3e1657 commit 791dcc8
Show file tree
Hide file tree
Showing 27 changed files with 439 additions and 287 deletions.
40 changes: 40 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.Enumerables;

public record CollectionInfo(
CollectionType Type,
CollectionType ImplementedTypes,
ITypeSymbol EnumeratedType,
bool CountIsKnown,
bool HasImplicitCollectionAddMethod,
bool IsImmutableCollectionType
)
{
public (ITypeSymbol, ITypeSymbol)? GetDictionaryKeyValueTypes(MappingBuilderContext ctx, ITypeSymbol t)
{
if (t.ImplementsGeneric(ctx.Types.Get(typeof(IDictionary<,>)), out var dictionaryImpl))
{
return (dictionaryImpl.TypeArguments[0], dictionaryImpl.TypeArguments[1]);
}

if (t.ImplementsGeneric(ctx.Types.Get(typeof(IReadOnlyDictionary<,>)), out var readOnlyDictionaryImpl))
{
return (readOnlyDictionaryImpl.TypeArguments[0], readOnlyDictionaryImpl.TypeArguments[1]);
}

return null;
}

public (ITypeSymbol, ITypeSymbol)? GetEnumeratedKeyValueTypes(WellKnownTypes types)
{
if (
EnumeratedType is not INamedTypeSymbol namedEnumeratedType
|| !SymbolEqualityComparer.Default.Equals(namedEnumeratedType.ConstructedFrom, types.Get(typeof(KeyValuePair<,>)))
)
return null;

return (namedEnumeratedType.TypeArguments[0], namedEnumeratedType.TypeArguments[1]);
}
}
162 changes: 162 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.Enumerables;

public static class CollectionInfoBuilder
{
private record CollectionTypeInfo(CollectionType CollectionType, Type? ReflectionType = null, string? TypeFullName = null)
{
public INamedTypeSymbol? GetTypeSymbol(WellKnownTypes types)
{
if (ReflectionType != null)
return types.Get(ReflectionType);

if (TypeFullName != null)
return types.TryGet(TypeFullName);

throw new InvalidOperationException("One type needs to be set for each collection type");
}
}

private static readonly IReadOnlyCollection<CollectionTypeInfo> _collectionTypeInfos = new[]
{
new CollectionTypeInfo(CollectionType.List, typeof(List<>)),
new CollectionTypeInfo(CollectionType.Stack, typeof(Stack<>)),
new CollectionTypeInfo(CollectionType.Queue, typeof(Queue<>)),
new CollectionTypeInfo(CollectionType.IReadOnlyCollection, typeof(IReadOnlyCollection<>)),
new CollectionTypeInfo(CollectionType.IList, typeof(IList<>)),
new CollectionTypeInfo(CollectionType.IReadOnlyList, typeof(IReadOnlyList<>)),
new CollectionTypeInfo(CollectionType.ICollection, typeof(ICollection<>)),
new CollectionTypeInfo(CollectionType.HashSet, typeof(HashSet<>)),
new CollectionTypeInfo(CollectionType.SortedSet, typeof(SortedSet<>)),
new CollectionTypeInfo(CollectionType.ISet, typeof(ISet<>)),
new CollectionTypeInfo(CollectionType.IDictionary, typeof(IDictionary<,>)),
new CollectionTypeInfo(CollectionType.IReadOnlyDictionary, typeof(IReadOnlyDictionary<,>)),
new CollectionTypeInfo(CollectionType.Dictionary, typeof(Dictionary<,>)),
new CollectionTypeInfo(CollectionType.ImmutableArray, typeof(ImmutableArray<>)),
new CollectionTypeInfo(CollectionType.ImmutableList, typeof(ImmutableList<>)),
new CollectionTypeInfo(CollectionType.IImmutableList, typeof(IImmutableList<>)),
new CollectionTypeInfo(CollectionType.ImmutableHashSet, typeof(ImmutableHashSet<>)),
new CollectionTypeInfo(CollectionType.IImmutableSet, typeof(IImmutableSet<>)),
new CollectionTypeInfo(CollectionType.ImmutableSortedSet, typeof(ImmutableSortedSet<>)),
new CollectionTypeInfo(CollectionType.ImmutableQueue, typeof(ImmutableQueue<>)),
new CollectionTypeInfo(CollectionType.IImmutableQueue, typeof(IImmutableQueue<>)),
new CollectionTypeInfo(CollectionType.IImmutableQueue, typeof(IImmutableQueue<>)),
new CollectionTypeInfo(CollectionType.ImmutableStack, typeof(ImmutableStack<>)),
new CollectionTypeInfo(CollectionType.IImmutableStack, typeof(IImmutableStack<>)),
new CollectionTypeInfo(CollectionType.IImmutableDictionary, typeof(IImmutableDictionary<,>)),
new CollectionTypeInfo(CollectionType.ImmutableDictionary, typeof(ImmutableDictionary<,>)),
new CollectionTypeInfo(CollectionType.IEnumerable, typeof(IEnumerable<>)),
new CollectionTypeInfo(CollectionType.IReadOnlySet, TypeFullName: "System.Collections.Generic.IReadOnlySet`1"),
};

public static CollectionInfos? Build(WellKnownTypes wellKnownTypes, ITypeSymbol source, ITypeSymbol target)
{
if (Build(wellKnownTypes, source) is not { } sourceInfo)
return null;

if (Build(wellKnownTypes, target) is not { } targetInfo)
return null;

return new CollectionInfos(sourceInfo, targetInfo);
}

private static CollectionInfo? Build(WellKnownTypes wellKnownTypes, ITypeSymbol type)
{
var enumeratedType = GetEnumeratedType(wellKnownTypes, type);
if (enumeratedType == null)
return null;

var collectionType = GetCollectionType(wellKnownTypes, type);
return new CollectionInfo(
collectionType,
GetImplementedCollectionTypes(wellKnownTypes, type),
enumeratedType,
IsCountKnown(wellKnownTypes, type),
HasValidAddMethod(wellKnownTypes, type),
IsImmutable(collectionType)
);
}

private static ITypeSymbol? GetEnumeratedType(WellKnownTypes types, ITypeSymbol type)
{
return type.ImplementsGeneric(types.Get(typeof(IEnumerable<>)), out var enumerableIntf) ? enumerableIntf.TypeArguments[0] : null;
}

private static bool HasValidAddMethod(WellKnownTypes types, ITypeSymbol t)
{
return t.HasImplicitGenericImplementation(types.Get(typeof(ICollection<>)), nameof(ICollection<object>.Add))
|| t.HasImplicitGenericImplementation(types.Get(typeof(ISet<>)), nameof(ISet<object>.Add));
}

private static bool IsCountKnown(WellKnownTypes types, ITypeSymbol t)
{
var intType = types.Get<int>();
return t.GetAccessibleMappableMembers()
.Any(
x =>
x.Name is nameof(ICollection<object>.Count) or nameof(Array.Length)
&& SymbolEqualityComparer.IncludeNullability.Equals(intType, x.Type)
);
}

private static bool IsImmutable(CollectionType t)
{
switch (t)
{
case CollectionType.ImmutableArray:
case CollectionType.ImmutableList:
case CollectionType.IImmutableList:
case CollectionType.ImmutableHashSet:
case CollectionType.IImmutableSet:
case CollectionType.ImmutableSortedSet:
case CollectionType.ImmutableQueue:
case CollectionType.IImmutableQueue:
case CollectionType.ImmutableStack:
case CollectionType.IImmutableStack:
case CollectionType.ImmutableDictionary:
case CollectionType.IImmutableDictionary:
return true;

default:
return false;
}
}

private static CollectionType GetCollectionType(WellKnownTypes types, ITypeSymbol type)
{
if (type.IsArrayType())
return CollectionType.Array;

foreach (var typeInfo in _collectionTypeInfos)
{
if (typeInfo.GetTypeSymbol(types) is not { } typeSymbol)
continue;

if (SymbolEqualityComparer.Default.Equals(type.OriginalDefinition, typeSymbol))
return typeInfo.CollectionType;
}

return CollectionType.None;
}

private static CollectionType GetImplementedCollectionTypes(WellKnownTypes types, ITypeSymbol type)
{
var implementedCollectionTypes = type.IsArrayType() ? CollectionType.Array : CollectionType.None;

foreach (var typeInfo in _collectionTypeInfos)
{
if (typeInfo.GetTypeSymbol(types) is not { } typeSymbol)
continue;

if (type.ImplementsGeneric(typeSymbol, out _))
{
implementedCollectionTypes |= typeInfo.CollectionType;
}
}

return implementedCollectionTypes;
}
}
3 changes: 3 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfos.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
namespace Riok.Mapperly.Descriptors.Enumerables;

public record CollectionInfos(CollectionInfo Source, CollectionInfo Target);
43 changes: 43 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/CollectionType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
namespace Riok.Mapperly.Descriptors.Enumerables;

[Flags]
public enum CollectionType
{
None = 0,
Array = 1 << 0,
IEnumerable = 1 << 1,

// collections
List = 1 << 2,
Stack = 1 << 3,
Queue = 1 << 4,
IReadOnlyCollection = 1 << 5,
IList = 1 << 6,
IReadOnlyList = 1 << 7,
ICollection = 1 << 8,

// sets
HashSet = 1 << 9,
SortedSet = 1 << 10,
IReadOnlySet = 1 << 11,
ISet = 1 << 12,

// dictionaries
IDictionary = 1 << 13,
IReadOnlyDictionary = 1 << 14,
Dictionary = 1 << 15,

// immutable
ImmutableArray = 1 << 16,
ImmutableList = 1 << 17,
IImmutableList = 1 << 18,
ImmutableHashSet = 1 << 19,
IImmutableSet = 1 << 20,
ImmutableSortedSet = 1 << 21,
ImmutableQueue = 1 << 22,
IImmutableQueue = 1 << 23,
ImmutableStack = 1 << 24,
IImmutableStack = 1 << 25,
ImmutableDictionary = 1 << 26,
IImmutableDictionary = 1 << 27,
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
namespace Riok.Mapperly.Descriptors.Enumerables.EnsureCapacity;

/// <summary>
/// Generates an <see cref="EnsureCapacity?"/> of types <see cref="EnsureCapacityNonEnumerated"/> or <see cref="EnsureCapacityMember"/> depending on type information.
/// Generates an <see cref="EnsureCapacity"/> of types <see cref="EnsureCapacityNonEnumerated"/> or <see cref="EnsureCapacityMember"/> depending on type information.
/// </summary>
public static class EnsureCapacityBuilder
{
Expand All @@ -14,28 +14,27 @@ public static class EnsureCapacityBuilder
private const string LengthPropertyName = nameof(Array.Length);
private const string TryGetNonEnumeratedCountMethodName = "TryGetNonEnumeratedCount";

public static EnsureCapacity? TryBuildEnsureCapacity(ITypeSymbol sourceType, ITypeSymbol targetType, WellKnownTypes types)
public static EnsureCapacity? TryBuildEnsureCapacity(MappingBuilderContext ctx)
{
var capacityMethod = targetType
var capacityMethod = ctx.Target
.GetAllMethods(EnsureCapacityName)
.OfType<IMethodSymbol>()
.FirstOrDefault(x => x.Parameters.Length == 1 && x.Parameters[0].Type.SpecialType == SpecialType.System_Int32 && !x.IsStatic);

// if EnsureCapacity is not available then return null
if (capacityMethod == null)
return null;

// if target does not have a count then return null
if (!TryGetNonEnumeratedCount(targetType, types, out var targetSizeProperty))
if (!TryGetNonEnumeratedCount(ctx.Target, ctx.Types, out var targetSizeProperty))
return null;

// if target and source count are known then create a simple EnsureCapacity statement
if (TryGetNonEnumeratedCount(sourceType, types, out var sourceSizeProperty))
if (TryGetNonEnumeratedCount(ctx.Source, ctx.Types, out var sourceSizeProperty))
return new EnsureCapacityMember(targetSizeProperty, sourceSizeProperty);

sourceType.ImplementsGeneric(types.Get(typeof(IEnumerable<>)), out var iEnumerable);
ctx.Source.ImplementsGeneric(ctx.Types.Get(typeof(IEnumerable<>)), out var iEnumerable);

var nonEnumeratedCountMethod = types
var nonEnumeratedCountMethod = ctx.Types
.Get(typeof(Enumerable))
.GetMembers(TryGetNonEnumeratedCountMethodName)
.OfType<IMethodSymbol>()
Expand All @@ -47,7 +46,7 @@ public static class EnsureCapacityBuilder
if (nonEnumeratedCountMethod == null)
return null;

// if source does not have a count use GetNonEnumeratedCount, calling EnusureCapacity if count is available
// if source does not have a count use GetNonEnumeratedCount, calling EnsureCapacity if count is available
var typedNonEnumeratedCount = nonEnumeratedCountMethod.Construct(iEnumerable!.TypeArguments.ToArray());
return new EnsureCapacityNonEnumerated(targetSizeProperty, typedNonEnumeratedCount);
}
Expand Down
4 changes: 4 additions & 0 deletions src/Riok.Mapperly/Descriptors/MappingBuilderContext.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Diagnostics;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Configuration;
using Riok.Mapperly.Descriptors.Enumerables;
using Riok.Mapperly.Descriptors.Mappings;
using Riok.Mapperly.Descriptors.Mappings.ExistingTarget;
using Riok.Mapperly.Descriptors.Mappings.UserMappings;
Expand All @@ -14,6 +15,7 @@ namespace Riok.Mapperly.Descriptors;
public class MappingBuilderContext : SimpleMappingBuilderContext
{
private readonly IMethodSymbol? _userSymbol;
private CollectionInfos? _collectionInfos;

public MappingBuilderContext(
SimpleMappingBuilderContext parentCtx,
Expand All @@ -40,6 +42,8 @@ protected MappingBuilderContext(MappingBuilderContext ctx, IMethodSymbol? userSy

public ITypeSymbol Target { get; }

public CollectionInfos? CollectionInfos => _collectionInfos ??= CollectionInfoBuilder.Build(Types, Source, Target);

/// <summary>
/// Whether the current mapping code is generated for a <see cref="System.Linq.Expressions.Expression"/>.
/// </summary>
Expand Down
Loading

0 comments on commit 791dcc8

Please sign in to comment.