Skip to content

Commit

Permalink
fix: handle nullable correctly in nullable disabled contexts for gene…
Browse files Browse the repository at this point in the history
…rics and arrays (#1029)
  • Loading branch information
latonz committed Jan 2, 2024
1 parent 89403e0 commit 132b8cc
Show file tree
Hide file tree
Showing 41 changed files with 445 additions and 227 deletions.
Expand Up @@ -101,7 +101,7 @@ ITypeSymbol enumeratedType
type,
typeInfo,
implementedTypes,
enumeratedType,
symbolAccessor.UpgradeNullable(enumeratedType),
FindCountProperty(symbolAccessor, type, typeInfo),
HasValidAddMethod(wellKnownTypes, type, typeInfo, implementedTypes),
collectionTypeInfo?.Immutable == true
Expand All @@ -112,8 +112,8 @@ ITypeSymbol enumeratedType
{
// if type is array return element type
// otherwise using the IEnumerable element type can erase the null annotation for external types
if (type.IsArrayType())
return ((IArrayTypeSymbol)type).ElementType;
if (type.IsArrayType(out var arraySymbol))
return arraySymbol.ElementType;

if (type.ImplementsGeneric(types.Get(typeof(IEnumerable<>)), out var enumerableIntf))
return enumerableIntf.TypeArguments[0];
Expand All @@ -133,6 +133,7 @@ ITypeSymbol enumeratedType
{
return namedType.TypeArguments[0];
}

// Memory<> or ReadOnlyMemory<> etc, get the type symbol
if (
SymbolEqualityComparer.Default.Equals(type.OriginalDefinition, types.Get(typeof(Memory<>)))
Expand Down
Expand Up @@ -209,7 +209,7 @@ out sourcePath
if (mappableField == default)
return false;

sourcePath = new MemberPath(new[] { new FieldMember(mappableField) });
sourcePath = new MemberPath(new[] { new FieldMember(mappableField, ctx.BuilderContext.SymbolAccessor) });
return true;
}
}
Expand Up @@ -62,10 +62,10 @@ bool duplicatedSourceTypesAllowed
var derivedTypeMappingSourceTypes = new HashSet<ITypeSymbol>(SymbolEqualityComparer.Default);
var derivedTypeMappings = new List<TMapping>(configs.Count);
Func<ITypeSymbol, bool> isAssignableToSource = ctx.Source is ITypeParameterSymbol sourceTypeParameter
? t => ctx.SymbolAccessor.DoesTypeSatisfyTypeParameterConstraints(sourceTypeParameter, t, ctx.Source.NullableAnnotation)
? t => ctx.SymbolAccessor.DoesTypeSatisfyTypeParameterConstraints(sourceTypeParameter, t)
: t => ctx.SymbolAccessor.HasImplicitConversion(t, ctx.Source);
Func<ITypeSymbol, bool> isAssignableToTarget = ctx.Target is ITypeParameterSymbol targetTypeParameter
? t => ctx.SymbolAccessor.DoesTypeSatisfyTypeParameterConstraints(targetTypeParameter, t, ctx.Target.NullableAnnotation)
? t => ctx.SymbolAccessor.DoesTypeSatisfyTypeParameterConstraints(targetTypeParameter, t)
: t => ctx.SymbolAccessor.HasImplicitConversion(t, ctx.Target);

foreach (var config in configs)
Expand Down
Expand Up @@ -159,7 +159,9 @@ INewInstanceMapping valueMapping

if (fromEnumerableCtor != null)
{
var constructedDictionary = dictionaryType.Construct(keyMapping.TargetType, valueMapping.TargetType);
var constructedDictionary = dictionaryType
.Construct(keyMapping.TargetType, valueMapping.TargetType)
.WithNullableAnnotation(NullableAnnotation.NotAnnotated);
return new CtorMapping(ctx.Source, constructedDictionary);
}

Expand Down
Expand Up @@ -186,13 +186,13 @@ or CollectionType.IEnumerable
if (elementMapping.IsSynthetic)
return null;

// upgrade nullability of element type
var targetType = ((INamedTypeSymbol)ctx.CollectionInfos!.Target.Type.OriginalDefinition).Construct(elementMapping.TargetType);
var targetTypeToInstantiate = ctx.Types.Get(typeof(List<>)).Construct(elementMapping.TargetType);
var targetTypeToInstantiate = ctx.Types.Get(typeof(List<>))
.Construct(elementMapping.TargetType)
.WithNullableAnnotation(NullableAnnotation.NotAnnotated);

return new ForEachAddEnumerableMapping(
ctx.Source,
targetType,
ctx.CollectionInfos!.Target.Type,
elementMapping,
AddMethodName,
targetTypeToInstantiate,
Expand All @@ -207,12 +207,7 @@ private static NewInstanceMapping BuildArrayToArrayMapping(MappingBuilderContext
// use a for loop mapping otherwise.
if (!elementMapping.IsSynthetic)
{
// upgrade nullability of element type
var targetType =
ctx.CollectionInfos!.Target.CollectionType == CollectionType.Array
? ctx.Types.GetArrayType(elementMapping.TargetType)
: ((INamedTypeSymbol)ctx.Target).ConstructedFrom.Construct(elementMapping.TargetType);
return new ArrayForMapping(ctx.Source, targetType, elementMapping, elementMapping.TargetType);
return new ArrayForMapping(ctx.Source, ctx.Target, elementMapping, elementMapping.TargetType);
}

return ctx.MapperConfiguration.UseDeepCloning
Expand All @@ -226,18 +221,12 @@ private static NewInstanceMapping BuildArrayToArrayMapping(MappingBuilderContext
if (elementMapping.IsSynthetic)
return null;

// upgrade nullability of element type
var targetType =
ctx.CollectionInfos!.Target.CollectionType == CollectionType.Array
? ctx.Types.GetArrayType(elementMapping.TargetType)
: ((INamedTypeSymbol)ctx.Target).ConstructedFrom.Construct(elementMapping.TargetType);

return new ArrayForEachMapping(
ctx.Source,
targetType,
ctx.Target,
elementMapping,
elementMapping.TargetType,
ctx.CollectionInfos.Source.CountPropertyName!
ctx.CollectionInfos!.Source.CountPropertyName!
);
}

Expand Down
Expand Up @@ -41,7 +41,7 @@ public static class ParseMappingBuilder
return null;

// otherwise try to find parse method ignoring the nullability
parseMethod = parseMethodCandidates.FirstOrDefault(x => SymbolEqualityComparer.Default.Equals(x.ReturnType, nonNullableTarget));
parseMethod = parseMethodCandidates.Find(x => SymbolEqualityComparer.Default.Equals(x.ReturnType, nonNullableTarget));
return parseMethod == null ? null : new StaticMethodMapping(parseMethod);
}
}
@@ -1,3 +1,4 @@
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Abstractions;
using Riok.Mapperly.Descriptors.Enumerables;
using Riok.Mapperly.Descriptors.Enumerables.EnsureCapacity;
Expand Down Expand Up @@ -175,7 +176,9 @@ private static NewInstanceMapping BuildToArrayOrMap(MappingBuilderContext ctx, I

private static NewInstanceMapping? BuildSpanToList(MappingBuilderContext ctx, INewInstanceMapping elementMapping)
{
var typedList = ctx.Types.Get(typeof(List<>)).Construct(elementMapping.TargetType);
var typedList = ctx.Types.Get(typeof(List<>))
.Construct(elementMapping.TargetType)
.WithNullableAnnotation(NullableAnnotation.NotAnnotated);
if (ctx.FindOrBuildMapping(ctx.Source, typedList) is not { } listMapping)
return null;

Expand Down
@@ -1,16 +1,27 @@
using System.Diagnostics;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.Mappings.ExistingTarget;

/// <summary>
/// A default implementation of <see cref="IExistingTargetMapping"/>.
/// </summary>
public abstract class ExistingTargetMapping(ITypeSymbol sourceType, ITypeSymbol targetType) : IExistingTargetMapping
public abstract class ExistingTargetMapping : IExistingTargetMapping
{
public ITypeSymbol SourceType { get; } = sourceType;
protected ExistingTargetMapping(ITypeSymbol sourceType, ITypeSymbol targetType)
{
Debug.Assert(sourceType.IsNullableUpgraded());
Debug.Assert(targetType.IsNullableUpgraded());

public ITypeSymbol TargetType { get; } = targetType;
SourceType = sourceType;
TargetType = targetType;
}

public ITypeSymbol SourceType { get; }

public ITypeSymbol TargetType { get; }

public virtual bool CallableByOtherMappings => true;

Expand Down
2 changes: 1 addition & 1 deletion src/Riok.Mapperly/Descriptors/Mappings/MethodMapping.cs
Expand Up @@ -54,7 +54,7 @@ ITypeSymbol targetType
ReferenceHandlerParameter = referenceHandlerParameter;
_partialMethodDefinition = method;
_methodName = method.Name;
_returnType = method.ReturnType.UpgradeNullable();
_returnType = method.ReturnsVoid ? method.ReturnType : targetType;
}

protected bool IsExtensionMethod { get; }
Expand Down
16 changes: 13 additions & 3 deletions src/Riok.Mapperly/Descriptors/Mappings/NewInstanceMapping.cs
@@ -1,16 +1,26 @@
using System.Diagnostics;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.Mappings;

/// <inheritdoc cref="INewInstanceMapping"/>
[DebuggerDisplay("{GetType().Name}({SourceType} => {TargetType})")]
public abstract class NewInstanceMapping(ITypeSymbol sourceType, ITypeSymbol targetType) : INewInstanceMapping
public abstract class NewInstanceMapping : INewInstanceMapping
{
public ITypeSymbol SourceType { get; } = sourceType;
protected NewInstanceMapping(ITypeSymbol sourceType, ITypeSymbol targetType)
{
Debug.Assert(sourceType.IsNullableUpgraded());
Debug.Assert(targetType.IsNullableUpgraded());

public ITypeSymbol TargetType { get; } = targetType;
SourceType = sourceType;
TargetType = targetType;
}

public ITypeSymbol SourceType { get; }

public ITypeSymbol TargetType { get; }

public virtual MappingBodyBuildingPriority BodyBuildingPriority => MappingBodyBuildingPriority.Default;

Expand Down
@@ -1,6 +1,7 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Emit;
using Riok.Mapperly.Helpers;
using Riok.Mapperly.Symbols;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Riok.Mapperly.Emit.Syntax.SyntaxFactoryHelper;
Expand All @@ -17,6 +18,7 @@ public class UserDefinedNewInstanceGenericTypeMapping(
IMethodSymbol method,
GenericMappingTypeParameters typeParameters,
MappingMethodParameters parameters,
ITypeSymbol targetType,
bool enableReferenceHandling,
NullFallbackValue nullArm,
ITypeSymbol objectType
Expand All @@ -25,6 +27,7 @@ ITypeSymbol objectType
method,
parameters.Source,
parameters.ReferenceHandler,
targetType,
enableReferenceHandling,
nullArm,
objectType
Expand All @@ -39,7 +42,7 @@ protected override ExpressionSyntax BuildTargetType()
{
// typeof(TTarget) or typeof(<ReturnType>)
var targetTypeName = TypeParameters.TargetType ?? TargetType;
return TypeOfExpression(FullyQualifiedIdentifier(targetTypeName));
return TypeOfExpression(FullyQualifiedIdentifier(targetTypeName.NonNullable()));
}

protected override ExpressionSyntax? BuildSwitchArmWhenClause(ExpressionSyntax targetType, RuntimeTargetTypeMapping mapping)
Expand Down
@@ -1,7 +1,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Abstractions.ReferenceHandling;
using Riok.Mapperly.Helpers;
using Riok.Mapperly.Symbols;

namespace Riok.Mapperly.Descriptors.Mappings.UserMappings;
Expand All @@ -13,8 +12,9 @@ public class UserDefinedNewInstanceMethodMapping(
IMethodSymbol method,
MethodParameter sourceParameter,
MethodParameter? referenceHandlerParameter,
ITypeSymbol targetType,
bool enableReferenceHandling
) : MethodMapping(method, sourceParameter, referenceHandlerParameter, method.ReturnType.UpgradeNullable()), IDelegateUserMapping
) : MethodMapping(method, sourceParameter, referenceHandlerParameter, targetType), IDelegateUserMapping
{
public IMethodSymbol Method { get; } = method;

Expand Down
Expand Up @@ -17,10 +17,11 @@ public abstract class UserDefinedNewInstanceRuntimeTargetTypeMapping(
IMethodSymbol method,
MethodParameter sourceParameter,
MethodParameter? referenceHandlerParameter,
ITypeSymbol targetType,
bool enableReferenceHandling,
NullFallbackValue nullArm,
ITypeSymbol objectType
) : MethodMapping(method, sourceParameter, referenceHandlerParameter, method.ReturnType), IUserMapping
) : MethodMapping(method, sourceParameter, referenceHandlerParameter, targetType), IUserMapping
{
private const string IsAssignableFromMethodName = nameof(Type.IsAssignableFrom);
private const string GetTypeMethodName = nameof(GetType);
Expand Down Expand Up @@ -54,21 +55,21 @@ public override IEnumerable<StatementSyntax> BuildBody(TypeMappingBuildContext c
ctx = ctx.WithRefHandler(referenceHandlerName);
}

var targetType = BuildTargetType();
var targetTypeExpr = BuildTargetType();

// _ => throw new ArgumentException(msg, nameof(ctx.Source)),
var sourceType = Invocation(MemberAccess(ctx.Source, GetTypeMethodName));
var fallbackArm = SwitchArm(
DiscardPattern(),
ThrowArgumentExpression(
InterpolatedString($"Cannot map {sourceType} to {targetType} as there is no known type mapping"),
InterpolatedString($"Cannot map {sourceType} to {targetTypeExpr} as there is no known type mapping"),
ctx.Source
)
);

// source switch { A x when targetType.IsAssignableFrom(typeof(ADto)) => MapToADto(x), B x when targetType.IsAssignableFrom(typeof(BDto)) => MapToBDto(x) }
var (typeArmContext, typeArmVariableName) = ctx.WithNewScopedSource();
var arms = _mappings.Select(x => BuildSwitchArm(typeArmContext, typeArmVariableName, x, targetType));
var arms = _mappings.Select(x => BuildSwitchArm(typeArmContext, typeArmVariableName, x, targetTypeExpr));

// null => default / throw
arms = arms.Append(SwitchArm(ConstantPattern(NullLiteral()), NullSubstitute(TargetType, ctx.Source, nullArm)));
Expand Down
Expand Up @@ -14,13 +14,15 @@ public class UserDefinedNewInstanceRuntimeTargetTypeParameterMapping(
IMethodSymbol method,
RuntimeTargetTypeMappingMethodParameters parameters,
bool enableReferenceHandling,
ITypeSymbol targetType,
NullFallbackValue nullArm,
ITypeSymbol objectType
)
: UserDefinedNewInstanceRuntimeTargetTypeMapping(
method,
parameters.Source,
parameters.ReferenceHandler,
targetType,
enableReferenceHandling,
nullArm,
objectType
Expand Down
@@ -1,7 +1,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Descriptors.Mappings.ExistingTarget;
using Riok.Mapperly.Helpers;
using Riok.Mapperly.Symbols;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Riok.Mapperly.Emit.Syntax.SyntaxFactoryHelper;
Expand All @@ -17,7 +16,7 @@ public class UserImplementedExistingTargetMethodMapping(
MethodParameter sourceParameter,
MethodParameter targetParameter,
MethodParameter? referenceHandlerParameter
) : ExistingTargetMapping(method.Parameters[0].Type.UpgradeNullable(), targetParameter.Type.UpgradeNullable()), IUserMapping
) : ExistingTargetMapping(method.Parameters[0].Type, targetParameter.Type), IUserMapping
{
public IMethodSymbol Method { get; } = method;

Expand Down
@@ -1,6 +1,5 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Helpers;
using Riok.Mapperly.Symbols;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Riok.Mapperly.Emit.Syntax.SyntaxFactoryHelper;
Expand All @@ -14,8 +13,9 @@ public class UserImplementedMethodMapping(
string? receiver,
IMethodSymbol method,
MethodParameter sourceParameter,
ITypeSymbol targetType,
MethodParameter? referenceHandlerParameter
) : NewInstanceMapping(method.Parameters[0].Type.UpgradeNullable(), method.ReturnType.UpgradeNullable()), IUserMapping
) : NewInstanceMapping(sourceParameter.Type, targetType), IUserMapping
{
public IMethodSymbol Method { get; } = method;

Expand Down
Expand Up @@ -13,11 +13,7 @@ public class GenericSourceObjectFactory(SymbolAccessor symbolAccessor, IMethodSy
{
public override bool CanCreateType(ITypeSymbol sourceType, ITypeSymbol targetTypeToCreate) =>
SymbolEqualityComparer.Default.Equals(Method.ReturnType, targetTypeToCreate)
&& SymbolAccessor.DoesTypeSatisfyTypeParameterConstraints(
Method.TypeParameters[0],
sourceType,
Method.Parameters[0].Type.NullableAnnotation
);
&& SymbolAccessor.DoesTypeSatisfyTypeParameterConstraints(Method.TypeParameters[0], sourceType);

protected override ExpressionSyntax BuildCreateType(ITypeSymbol sourceType, ITypeSymbol targetTypeToCreate, ExpressionSyntax source) =>
GenericInvocation(Method.Name, new[] { NonNullableIdentifier(sourceType) }, source);
Expand Down
Expand Up @@ -10,16 +10,8 @@ public class GenericSourceTargetObjectFactory(SymbolAccessor symbolAccessor, IMe
private readonly int _targetTypeParameterIndex = (sourceTypeParameterIndex + 1) % 2;

public override bool CanCreateType(ITypeSymbol sourceType, ITypeSymbol targetTypeToCreate) =>
SymbolAccessor.DoesTypeSatisfyTypeParameterConstraints(
Method.TypeParameters[sourceTypeParameterIndex],
sourceType,
Method.Parameters[0].Type.NullableAnnotation
)
&& SymbolAccessor.DoesTypeSatisfyTypeParameterConstraints(
Method.TypeParameters[_targetTypeParameterIndex],
targetTypeToCreate,
Method.ReturnType.NullableAnnotation
);
SymbolAccessor.DoesTypeSatisfyTypeParameterConstraints(Method.TypeParameters[sourceTypeParameterIndex], sourceType)
&& SymbolAccessor.DoesTypeSatisfyTypeParameterConstraints(Method.TypeParameters[_targetTypeParameterIndex], targetTypeToCreate);

protected override ExpressionSyntax BuildCreateType(ITypeSymbol sourceType, ITypeSymbol targetTypeToCreate, ExpressionSyntax source)
{
Expand Down
Expand Up @@ -12,11 +12,7 @@ namespace Riok.Mapperly.Descriptors.ObjectFactories;
public class GenericTargetObjectFactory(SymbolAccessor symbolAccessor, IMethodSymbol method) : ObjectFactory(symbolAccessor, method)
{
public override bool CanCreateType(ITypeSymbol sourceType, ITypeSymbol targetTypeToCreate) =>
SymbolAccessor.DoesTypeSatisfyTypeParameterConstraints(
Method.TypeParameters[0],
targetTypeToCreate,
Method.ReturnType.NullableAnnotation
);
SymbolAccessor.DoesTypeSatisfyTypeParameterConstraints(Method.TypeParameters[0], targetTypeToCreate);

protected override ExpressionSyntax BuildCreateType(ITypeSymbol sourceType, ITypeSymbol targetTypeToCreate, ExpressionSyntax source) =>
GenericInvocation(Method.Name, new[] { NonNullableIdentifier(targetTypeToCreate) });
Expand Down

0 comments on commit 132b8cc

Please sign in to comment.