Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor collections support #501

Merged
merged 1 commit into from
Jun 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.Enumerables;

public record CollectionInfo(
CollectionType Type,
CollectionType ImplementedTypes,
ITypeSymbol EnumeratedType,
bool CountIsKnown,
bool HasImplicitCollectionAddMethod,
bool IsImmutableCollectionType
)
{
public (ITypeSymbol, ITypeSymbol)? GetDictionaryKeyValueTypes(MappingBuilderContext ctx, ITypeSymbol t)
{
if (t.ImplementsGeneric(ctx.Types.Get(typeof(IDictionary<,>)), out var dictionaryImpl))
{
return (dictionaryImpl.TypeArguments[0], dictionaryImpl.TypeArguments[1]);
}

if (t.ImplementsGeneric(ctx.Types.Get(typeof(IReadOnlyDictionary<,>)), out var readOnlyDictionaryImpl))
{
return (readOnlyDictionaryImpl.TypeArguments[0], readOnlyDictionaryImpl.TypeArguments[1]);
}

return null;
}

public (ITypeSymbol, ITypeSymbol)? GetEnumeratedKeyValueTypes(WellKnownTypes types)
{
if (
EnumeratedType is not INamedTypeSymbol namedEnumeratedType
|| !SymbolEqualityComparer.Default.Equals(namedEnumeratedType.ConstructedFrom, types.Get(typeof(KeyValuePair<,>)))
)
return null;

Check warning on line 36 in src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfo.cs

View check run for this annotation

Codecov / codecov/patch

src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfo.cs#L36

Added line #L36 was not covered by tests

return (namedEnumeratedType.TypeArguments[0], namedEnumeratedType.TypeArguments[1]);
}
}
146 changes: 146 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.Enumerables;

public static class CollectionInfoBuilder
{
private record CollectionTypeInfo(
CollectionType CollectionType,
Type? ReflectionType = null,
string? TypeFullName = null,
bool Immutable = false
)
{
public INamedTypeSymbol? GetTypeSymbol(WellKnownTypes types)
{
if (ReflectionType != null)
return types.Get(ReflectionType);

if (TypeFullName != null)
return types.TryGet(TypeFullName);

throw new InvalidOperationException("One type needs to be set for each collection type");

Check warning on line 24 in src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs

View check run for this annotation

Codecov / codecov/patch

src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs#L24

Added line #L24 was not covered by tests
}
}

private static readonly CollectionTypeInfo _collectionTypeInfoArray = new(CollectionType.Array);

private static readonly IReadOnlyCollection<CollectionTypeInfo> _collectionTypeInfos = new[]
{
new CollectionTypeInfo(CollectionType.IEnumerable, typeof(IEnumerable<>)),
new CollectionTypeInfo(CollectionType.List, typeof(List<>)),
new CollectionTypeInfo(CollectionType.Stack, typeof(Stack<>)),
new CollectionTypeInfo(CollectionType.Queue, typeof(Queue<>)),
new CollectionTypeInfo(CollectionType.IReadOnlyCollection, typeof(IReadOnlyCollection<>)),
new CollectionTypeInfo(CollectionType.IList, typeof(IList<>)),
new CollectionTypeInfo(CollectionType.IReadOnlyList, typeof(IReadOnlyList<>)),
new CollectionTypeInfo(CollectionType.ICollection, typeof(ICollection<>)),
new CollectionTypeInfo(CollectionType.HashSet, typeof(HashSet<>)),
new CollectionTypeInfo(CollectionType.SortedSet, typeof(SortedSet<>)),
new CollectionTypeInfo(CollectionType.ISet, typeof(ISet<>)),
new CollectionTypeInfo(CollectionType.IReadOnlySet, TypeFullName: "System.Collections.Generic.IReadOnlySet`1"),
new CollectionTypeInfo(CollectionType.IDictionary, typeof(IDictionary<,>)),
new CollectionTypeInfo(CollectionType.IReadOnlyDictionary, typeof(IReadOnlyDictionary<,>)),
new CollectionTypeInfo(CollectionType.Dictionary, typeof(Dictionary<,>)),
new CollectionTypeInfo(CollectionType.ImmutableArray, typeof(ImmutableArray<>), Immutable: true),
new CollectionTypeInfo(CollectionType.ImmutableList, typeof(ImmutableList<>), Immutable: true),
new CollectionTypeInfo(CollectionType.IImmutableList, typeof(IImmutableList<>), Immutable: true),
new CollectionTypeInfo(CollectionType.ImmutableHashSet, typeof(ImmutableHashSet<>), Immutable: true),
new CollectionTypeInfo(CollectionType.IImmutableSet, typeof(IImmutableSet<>), Immutable: true),
new CollectionTypeInfo(CollectionType.ImmutableSortedSet, typeof(ImmutableSortedSet<>), Immutable: true),
new CollectionTypeInfo(CollectionType.ImmutableQueue, typeof(ImmutableQueue<>), Immutable: true),
new CollectionTypeInfo(CollectionType.IImmutableQueue, typeof(IImmutableQueue<>), Immutable: true),
new CollectionTypeInfo(CollectionType.IImmutableQueue, typeof(IImmutableQueue<>), Immutable: true),
new CollectionTypeInfo(CollectionType.ImmutableStack, typeof(ImmutableStack<>), Immutable: true),
new CollectionTypeInfo(CollectionType.IImmutableStack, typeof(IImmutableStack<>), Immutable: true),
new CollectionTypeInfo(CollectionType.IImmutableDictionary, typeof(IImmutableDictionary<,>), Immutable: true),
new CollectionTypeInfo(CollectionType.ImmutableDictionary, typeof(ImmutableDictionary<,>), Immutable: true),
};

public static CollectionInfos? Build(WellKnownTypes wellKnownTypes, ITypeSymbol source, ITypeSymbol target)
{
if (Build(wellKnownTypes, source) is not { } sourceInfo)
return null;

if (Build(wellKnownTypes, target) is not { } targetInfo)
return null;

return new CollectionInfos(sourceInfo, targetInfo);
}

private static CollectionInfo? Build(WellKnownTypes wellKnownTypes, ITypeSymbol type)
{
var enumeratedType = GetEnumeratedType(wellKnownTypes, type);
if (enumeratedType == null)
return null;

var collectionTypeInfo = GetCollectionTypeInfo(wellKnownTypes, type);
return new CollectionInfo(
collectionTypeInfo?.CollectionType ?? CollectionType.None,
GetImplementedCollectionTypes(wellKnownTypes, type),
enumeratedType,
IsCountKnown(wellKnownTypes, type),
HasValidAddMethod(wellKnownTypes, type),
collectionTypeInfo?.Immutable == true
);
}

private static ITypeSymbol? GetEnumeratedType(WellKnownTypes types, ITypeSymbol type)
{
return type.ImplementsGeneric(types.Get(typeof(IEnumerable<>)), out var enumerableIntf) ? enumerableIntf.TypeArguments[0] : null;
}

private static bool HasValidAddMethod(WellKnownTypes types, ITypeSymbol t)
{
return t.HasImplicitGenericImplementation(types.Get(typeof(ICollection<>)), nameof(ICollection<object>.Add))
|| t.HasImplicitGenericImplementation(types.Get(typeof(ISet<>)), nameof(ISet<object>.Add));
}

private static bool IsCountKnown(WellKnownTypes types, ITypeSymbol t)
{
var intType = types.Get<int>();
return t.GetAccessibleMappableMembers()
.Any(
x =>
x.Name is nameof(ICollection<object>.Count) or nameof(Array.Length)
&& SymbolEqualityComparer.IncludeNullability.Equals(intType, x.Type)
);
}

private static CollectionTypeInfo? GetCollectionTypeInfo(WellKnownTypes types, ITypeSymbol type)
{
if (type.IsArrayType())
return _collectionTypeInfoArray;

foreach (var typeInfo in _collectionTypeInfos)
{
if (typeInfo.GetTypeSymbol(types) is not { } typeSymbol)
continue;

Check warning on line 120 in src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs

View check run for this annotation

Codecov / codecov/patch

src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs#L120

Added line #L120 was not covered by tests

if (SymbolEqualityComparer.Default.Equals(type.OriginalDefinition, typeSymbol))
return typeInfo;
}

return null;
}

private static CollectionType GetImplementedCollectionTypes(WellKnownTypes types, ITypeSymbol type)
{
var implementedCollectionTypes = type.IsArrayType() ? CollectionType.Array : CollectionType.None;

foreach (var typeInfo in _collectionTypeInfos)
{
if (typeInfo.GetTypeSymbol(types) is not { } typeSymbol)
continue;

Check warning on line 136 in src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs

View check run for this annotation

Codecov / codecov/patch

src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfoBuilder.cs#L136

Added line #L136 was not covered by tests

if (type.ImplementsGeneric(typeSymbol, out _))
{
implementedCollectionTypes |= typeInfo.CollectionType;
}
}

return implementedCollectionTypes;
}
}
3 changes: 3 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/CollectionInfos.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
namespace Riok.Mapperly.Descriptors.Enumerables;

public record CollectionInfos(CollectionInfo Source, CollectionInfo Target);
43 changes: 43 additions & 0 deletions src/Riok.Mapperly/Descriptors/Enumerables/CollectionType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
namespace Riok.Mapperly.Descriptors.Enumerables;

[Flags]
public enum CollectionType
{
None = 0,
Array = 1 << 0,
IEnumerable = 1 << 1,

// collections
List = 1 << 2,
Stack = 1 << 3,
Queue = 1 << 4,
IReadOnlyCollection = 1 << 5,
IList = 1 << 6,
IReadOnlyList = 1 << 7,
ICollection = 1 << 8,

// sets
HashSet = 1 << 9,
SortedSet = 1 << 10,
IReadOnlySet = 1 << 11,
ISet = 1 << 12,

// dictionaries
IDictionary = 1 << 13,
IReadOnlyDictionary = 1 << 14,
Dictionary = 1 << 15,

// immutable
ImmutableArray = 1 << 16,
ImmutableList = 1 << 17,
IImmutableList = 1 << 18,
ImmutableHashSet = 1 << 19,
IImmutableSet = 1 << 20,
ImmutableSortedSet = 1 << 21,
ImmutableQueue = 1 << 22,
IImmutableQueue = 1 << 23,
ImmutableStack = 1 << 24,
IImmutableStack = 1 << 25,
ImmutableDictionary = 1 << 26,
IImmutableDictionary = 1 << 27,
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
namespace Riok.Mapperly.Descriptors.Enumerables.EnsureCapacity;

/// <summary>
/// Generates an <see cref="EnsureCapacity?"/> of types <see cref="EnsureCapacityNonEnumerated"/> or <see cref="EnsureCapacityMember"/> depending on type information.
/// Generates an <see cref="EnsureCapacity"/> of types <see cref="EnsureCapacityNonEnumerated"/> or <see cref="EnsureCapacityMember"/> depending on type information.
/// </summary>
public static class EnsureCapacityBuilder
{
Expand All @@ -14,28 +14,27 @@ public static class EnsureCapacityBuilder
private const string LengthPropertyName = nameof(Array.Length);
private const string TryGetNonEnumeratedCountMethodName = "TryGetNonEnumeratedCount";

public static EnsureCapacity? TryBuildEnsureCapacity(ITypeSymbol sourceType, ITypeSymbol targetType, WellKnownTypes types)
public static EnsureCapacity? TryBuildEnsureCapacity(MappingBuilderContext ctx)
{
var capacityMethod = targetType
var capacityMethod = ctx.Target
.GetAllMethods(EnsureCapacityName)
.OfType<IMethodSymbol>()
.FirstOrDefault(x => x.Parameters.Length == 1 && x.Parameters[0].Type.SpecialType == SpecialType.System_Int32 && !x.IsStatic);

// if EnsureCapacity is not available then return null
if (capacityMethod == null)
return null;

// if target does not have a count then return null
if (!TryGetNonEnumeratedCount(targetType, types, out var targetSizeProperty))
if (!TryGetNonEnumeratedCount(ctx.Target, ctx.Types, out var targetSizeProperty))
return null;

// if target and source count are known then create a simple EnsureCapacity statement
if (TryGetNonEnumeratedCount(sourceType, types, out var sourceSizeProperty))
if (TryGetNonEnumeratedCount(ctx.Source, ctx.Types, out var sourceSizeProperty))
return new EnsureCapacityMember(targetSizeProperty, sourceSizeProperty);

sourceType.ImplementsGeneric(types.Get(typeof(IEnumerable<>)), out var iEnumerable);
ctx.Source.ImplementsGeneric(ctx.Types.Get(typeof(IEnumerable<>)), out var iEnumerable);

var nonEnumeratedCountMethod = types
var nonEnumeratedCountMethod = ctx.Types
.Get(typeof(Enumerable))
.GetMembers(TryGetNonEnumeratedCountMethodName)
.OfType<IMethodSymbol>()
Expand All @@ -47,7 +46,7 @@ public static class EnsureCapacityBuilder
if (nonEnumeratedCountMethod == null)
return null;

// if source does not have a count use GetNonEnumeratedCount, calling EnusureCapacity if count is available
// if source does not have a count use GetNonEnumeratedCount, calling EnsureCapacity if count is available
var typedNonEnumeratedCount = nonEnumeratedCountMethod.Construct(iEnumerable!.TypeArguments.ToArray());
return new EnsureCapacityNonEnumerated(targetSizeProperty, typedNonEnumeratedCount);
}
Expand Down
4 changes: 4 additions & 0 deletions src/Riok.Mapperly/Descriptors/MappingBuilderContext.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Diagnostics;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Configuration;
using Riok.Mapperly.Descriptors.Enumerables;
using Riok.Mapperly.Descriptors.Mappings;
using Riok.Mapperly.Descriptors.Mappings.ExistingTarget;
using Riok.Mapperly.Descriptors.Mappings.UserMappings;
Expand All @@ -14,6 +15,7 @@ namespace Riok.Mapperly.Descriptors;
public class MappingBuilderContext : SimpleMappingBuilderContext
{
private readonly IMethodSymbol? _userSymbol;
private CollectionInfos? _collectionInfos;

public MappingBuilderContext(
SimpleMappingBuilderContext parentCtx,
Expand All @@ -40,6 +42,8 @@ protected MappingBuilderContext(MappingBuilderContext ctx, IMethodSymbol? userSy

public ITypeSymbol Target { get; }

public CollectionInfos? CollectionInfos => _collectionInfos ??= CollectionInfoBuilder.Build(Types, Source, Target);

/// <summary>
/// Whether the current mapping code is generated for a <see cref="System.Linq.Expressions.Expression"/>.
/// </summary>
Expand Down
Loading