Skip to content

Commit

Permalink
chore: refactor collection support
Browse files Browse the repository at this point in the history
  • Loading branch information
latonz committed Jun 14, 2023
1 parent f1f2369 commit 87834ce
Show file tree
Hide file tree
Showing 9 changed files with 322 additions and 148 deletions.
40 changes: 40 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.Enumerables;

public record CollectionInfo(
CollectionType Type,
CollectionType ImplementedTypes,
ITypeSymbol EnumeratedType,
bool CountIsKnown,
bool HasValidAddMethod,
bool IsImmutableCollectionType
)
{
public (ITypeSymbol, ITypeSymbol)? GetDictionaryKeyValueTypes(MappingBuilderContext ctx, ITypeSymbol t)
{
if (t.ImplementsGeneric(ctx.Types.Get(typeof(IDictionary<,>)), out var dictionaryImpl))
{
return (dictionaryImpl.TypeArguments[0], dictionaryImpl.TypeArguments[1]);
}

if (t.ImplementsGeneric(ctx.Types.Get(typeof(IReadOnlyDictionary<,>)), out var readOnlyDictionaryImpl))
{
return (readOnlyDictionaryImpl.TypeArguments[0], readOnlyDictionaryImpl.TypeArguments[1]);
}

return null;
}

public (ITypeSymbol, ITypeSymbol)? GetEnumeratedKeyValueTypes(WellKnownTypes types)
{
if (
EnumeratedType is not INamedTypeSymbol namedEnumeratedType
|| !SymbolEqualityComparer.Default.Equals(namedEnumeratedType.ConstructedFrom, types.Get(typeof(KeyValuePair<,>)))
)
return null;

return (namedEnumeratedType.TypeArguments[0], namedEnumeratedType.TypeArguments[1]);
}
}
165 changes: 165 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.Enumerables;

public static class CollectionInfoBuilder
{
private static readonly IReadOnlyDictionary<Type, CollectionType> _collectionTypeMappings = new Dictionary<Type, CollectionType>
{
[typeof(List<>)] = CollectionType.List,
[typeof(Stack<>)] = CollectionType.Stack,
[typeof(Queue<>)] = CollectionType.Queue,
[typeof(IReadOnlyCollection<>)] = CollectionType.IReadOnlyCollection,
[typeof(IList<>)] = CollectionType.IList,
[typeof(IReadOnlyList<>)] = CollectionType.IReadOnlyList,
[typeof(ICollection<>)] = CollectionType.ICollection,
[typeof(HashSet<>)] = CollectionType.HashSet,
[typeof(SortedSet<>)] = CollectionType.SortedSet,
[typeof(ISet<>)] = CollectionType.ISet,
[typeof(IDictionary<,>)] = CollectionType.IDictionary,
[typeof(IReadOnlyDictionary<,>)] = CollectionType.IReadOnlyDictionary,
[typeof(Dictionary<,>)] = CollectionType.Dictionary,
[typeof(ImmutableArray<>)] = CollectionType.ImmutableArray,
[typeof(ImmutableList<>)] = CollectionType.ImmutableList,
[typeof(IImmutableList<>)] = CollectionType.IImmutableList,
[typeof(ImmutableHashSet<>)] = CollectionType.ImmutableHashSet,
[typeof(IImmutableSet<>)] = CollectionType.IImmutableSet,
[typeof(ImmutableSortedSet<>)] = CollectionType.ImmutableSortedSet,
[typeof(ImmutableQueue<>)] = CollectionType.ImmutableQueue,
[typeof(IImmutableQueue<>)] = CollectionType.IImmutableQueue,
[typeof(ImmutableStack<>)] = CollectionType.ImmutableStack,
[typeof(IImmutableStack<>)] = CollectionType.IImmutableStack,
[typeof(IImmutableDictionary<,>)] = CollectionType.IImmutableDictionary,
[typeof(ImmutableDictionary<,>)] = CollectionType.ImmutableDictionary,
[typeof(IEnumerable<>)] = CollectionType.IEnumerable,
};

private static readonly IReadOnlyDictionary<string, CollectionType> _newCollectionTypeMappings = new Dictionary<string, CollectionType>
{
["System.Collections.Generic.IReadOnlySet"] = CollectionType.IReadOnlySet,
};

public static CollectionInfos? Build(WellKnownTypes wellKnownTypes, ITypeSymbol source, ITypeSymbol target)
{
var collectionTypes = GetCollectionTypeMappings(wellKnownTypes);

if (Build(wellKnownTypes, collectionTypes, source) is not { } sourceInfo)
return null;

if (Build(wellKnownTypes, collectionTypes, target) is not { } targetInfo)
return null;

return new CollectionInfos(sourceInfo, targetInfo);
}

private static CollectionInfo? Build(
WellKnownTypes wellKnownTypes,
IReadOnlyCollection<(INamedTypeSymbol Type, CollectionType CollectionType)> collectionTypes,
ITypeSymbol type
)
{
var enumeratedType = GetEnumeratedType(wellKnownTypes, type);
if (enumeratedType == null)
return null;

var collectionType = GetCollectionType(collectionTypes, type);
return new CollectionInfo(
collectionType,
GetImplementedCollectionTypes(collectionTypes, type),
enumeratedType,
IsCountKnown(wellKnownTypes, type),
HasValidAddMethod(wellKnownTypes, type),
IsImmutable(collectionType)
);
}

private static ITypeSymbol? GetEnumeratedType(WellKnownTypes types, ITypeSymbol type)
{
return type.ImplementsGeneric(types.Get(typeof(IEnumerable<>)), out var enumerableIntf) ? enumerableIntf.TypeArguments[0] : null;
}

private static bool HasValidAddMethod(WellKnownTypes types, ITypeSymbol t)
{
return t.HasImplicitGenericImplementation(types.Get(typeof(ICollection<>)), nameof(ICollection<object>.Add))
|| t.HasImplicitGenericImplementation(types.Get(typeof(ISet<>)), nameof(ISet<object>.Add));
}

private static bool IsCountKnown(WellKnownTypes types, ITypeSymbol t)
{
var intType = types.Get<int>();
return t.GetAccessibleMappableMembers()
.Any(
x =>
x.Name is nameof(ICollection<object>.Count) or nameof(Array.Length)
&& SymbolEqualityComparer.IncludeNullability.Equals(intType, x.Type)
);
}

private static bool IsImmutable(CollectionType t)
{
switch (t)
{
case CollectionType.ImmutableArray:
case CollectionType.ImmutableList:
case CollectionType.IImmutableList:
case CollectionType.ImmutableHashSet:
case CollectionType.IImmutableSet:
case CollectionType.ImmutableSortedSet:
case CollectionType.ImmutableQueue:
case CollectionType.IImmutableQueue:
case CollectionType.ImmutableStack:
case CollectionType.IImmutableStack:
return true;
default:
return false;
}
}

private static CollectionType GetCollectionType(
IEnumerable<(INamedTypeSymbol Type, CollectionType CollectionType)> collectionTypes,
ITypeSymbol type
)
{
if (type.IsArrayType())
return CollectionType.Array;

foreach (var (collectionTypeSymbol, collectionType) in collectionTypes)
{
if (SymbolEqualityComparer.Default.Equals(type.OriginalDefinition, collectionTypeSymbol))
return collectionType;
}

return CollectionType.None;
}

private static CollectionType GetImplementedCollectionTypes(
IEnumerable<(INamedTypeSymbol Type, CollectionType CollectionType)> collectionTypes,
ITypeSymbol type
)
{
var implementedCollectionTypes = type.IsArrayType() ? CollectionType.Array : CollectionType.None;

foreach (var (collectionTypeSymbol, collectionType) in collectionTypes)
{
if (type.ImplementsGeneric(collectionTypeSymbol, out _))
{
implementedCollectionTypes |= collectionType;
}
}

return implementedCollectionTypes;
}

private static IReadOnlyCollection<(INamedTypeSymbol Type, CollectionType CollectionType)> GetCollectionTypeMappings(
WellKnownTypes types
)
{
return _newCollectionTypeMappings
.Select(x => (Type: types.TryGet(x.Key), CollectionType: x.Value))
.Where(x => x.Type != null)!
.Concat(_collectionTypeMappings.Select(x => (types.Get(x.Key), x.Value)))
.ToList();
}
}
3 changes: 3 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfos.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
namespace Riok.Mapperly.Descriptors.Enumerables;

public record CollectionInfos(CollectionInfo Source, CollectionInfo Target);
43 changes: 43 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/CollectionType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
namespace Riok.Mapperly.Descriptors.Enumerables;

[Flags]
public enum CollectionType
{
None = 0,
Array = 1 << 0,
IEnumerable = 1 << 1,

// collections
List = 1 << 2,
Stack = 1 << 3,
Queue = 1 << 4,
IReadOnlyCollection = 1 << 5,
IList = 1 << 6,
IReadOnlyList = 1 << 7,
ICollection = 1 << 8,

// sets
HashSet = 1 << 9,
SortedSet = 1 << 10,
IReadOnlySet = 1 << 11,
ISet = 1 << 12,

// dictionaries
IDictionary = 1 << 13,
IReadOnlyDictionary = 1 << 14,
Dictionary = 1 << 15,

// immutable
ImmutableArray = 1 << 16,
ImmutableList = 1 << 17,
IImmutableList = 1 << 18,
ImmutableHashSet = 1 << 19,
IImmutableSet = 1 << 20,
ImmutableSortedSet = 1 << 21,
ImmutableQueue = 1 << 22,
IImmutableQueue = 1 << 23,
ImmutableStack = 1 << 24,
IImmutableStack = 1 << 25,
ImmutableDictionary = 1 << 26,
IImmutableDictionary = 1 << 27,
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
namespace Riok.Mapperly.Descriptors.Enumerables.EnsureCapacity;

/// <summary>
/// Generates an <see cref="EnsureCapacity?"/> of types <see cref="EnsureCapacityNonEnumerated"/> or <see cref="EnsureCapacityMember"/> depending on type information.
/// Generates an <see cref="EnsureCapacity"/> of types <see cref="EnsureCapacityNonEnumerated"/> or <see cref="EnsureCapacityMember"/> depending on type information.
/// </summary>
public static class EnsureCapacityBuilder
{
Expand All @@ -14,28 +14,27 @@ public static class EnsureCapacityBuilder
private const string LengthPropertyName = nameof(Array.Length);
private const string TryGetNonEnumeratedCountMethodName = "TryGetNonEnumeratedCount";

public static EnsureCapacity? TryBuildEnsureCapacity(ITypeSymbol sourceType, ITypeSymbol targetType, WellKnownTypes types)
public static EnsureCapacity? TryBuildEnsureCapacity(MappingBuilderContext ctx)
{
var capacityMethod = targetType
var capacityMethod = ctx.Target
.GetAllMethods(EnsureCapacityName)
.OfType<IMethodSymbol>()
.FirstOrDefault(x => x.Parameters.Length == 1 && x.Parameters[0].Type.SpecialType == SpecialType.System_Int32 && !x.IsStatic);

// if EnsureCapacity is not available then return null
if (capacityMethod == null)
return null;

// if target does not have a count then return null
if (!TryGetNonEnumeratedCount(targetType, types, out var targetSizeProperty))
if (!TryGetNonEnumeratedCount(ctx.Target, ctx.Types, out var targetSizeProperty))
return null;

// if target and source count are known then create a simple EnsureCapacity statement
if (TryGetNonEnumeratedCount(sourceType, types, out var sourceSizeProperty))
if (TryGetNonEnumeratedCount(ctx.Source, ctx.Types, out var sourceSizeProperty))
return new EnsureCapacityMember(targetSizeProperty, sourceSizeProperty);

sourceType.ImplementsGeneric(types.Get(typeof(IEnumerable<>)), out var iEnumerable);
ctx.Source.ImplementsGeneric(ctx.Types.Get(typeof(IEnumerable<>)), out var iEnumerable);

var nonEnumeratedCountMethod = types
var nonEnumeratedCountMethod = ctx.Types
.Get(typeof(Enumerable))
.GetMembers(TryGetNonEnumeratedCountMethodName)
.OfType<IMethodSymbol>()
Expand All @@ -47,7 +46,7 @@ public static class EnsureCapacityBuilder
if (nonEnumeratedCountMethod == null)
return null;

// if source does not have a count use GetNonEnumeratedCount, calling EnusureCapacity if count is available
// if source does not have a count use GetNonEnumeratedCount, calling EnsureCapacity if count is available
var typedNonEnumeratedCount = nonEnumeratedCountMethod.Construct(iEnumerable!.TypeArguments.ToArray());
return new EnsureCapacityNonEnumerated(targetSizeProperty, typedNonEnumeratedCount);
}
Expand Down
4 changes: 4 additions & 0 deletions src/Riok.Mapperly/Descriptors/MappingBuilderContext.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Diagnostics;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Configuration;
using Riok.Mapperly.Descriptors.Enumerables;
using Riok.Mapperly.Descriptors.Mappings;
using Riok.Mapperly.Descriptors.Mappings.ExistingTarget;
using Riok.Mapperly.Descriptors.Mappings.UserMappings;
Expand All @@ -14,6 +15,7 @@ namespace Riok.Mapperly.Descriptors;
public class MappingBuilderContext : SimpleMappingBuilderContext
{
private readonly IMethodSymbol? _userSymbol;
private CollectionInfos? _collectionInfos;

public MappingBuilderContext(
SimpleMappingBuilderContext parentCtx,
Expand All @@ -40,6 +42,8 @@ protected MappingBuilderContext(MappingBuilderContext ctx, IMethodSymbol? userSy

public ITypeSymbol Target { get; }

public CollectionInfos? CollectionInfos => _collectionInfos ??= CollectionInfoBuilder.Build(Types, Source, Target);

/// <summary>
/// Whether the current mapping code is generated for a <see cref="System.Linq.Expressions.Expression"/>.
/// </summary>
Expand Down
Loading

0 comments on commit 87834ce

Please sign in to comment.