Skip to content

Commit

Permalink
fix: support queryable projection mappings for derived type mappings (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
latonz committed May 8, 2023
1 parent e5d9bb2 commit 98a44a0
Show file tree
Hide file tree
Showing 19 changed files with 344 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/pull_request_template.md
@@ -1,4 +1,4 @@
# Mapperly PR
# <Title>

## Description

Expand Down
Expand Up @@ -18,21 +18,24 @@ public static class DerivedTypeMappingBuilder
return null;

var derivedTypeMappings = BuildDerivedTypeMappings(ctx, configs);
return new DerivedTypeMapping(ctx.Source, ctx.Target, derivedTypeMappings);
return ctx.IsExpression
? new DerivedTypeIfExpressionMapping(ctx.Source, ctx.Target, derivedTypeMappings)
: new DerivedTypeSwitchMapping(ctx.Source, ctx.Target, derivedTypeMappings);
}

private static IReadOnlyDictionary<ITypeSymbol, ITypeMapping> BuildDerivedTypeMappings(
private static IReadOnlyCollection<ITypeMapping> BuildDerivedTypeMappings(
MappingBuilderContext ctx,
IEnumerable<MapDerivedType> configs
IReadOnlyCollection<MapDerivedType> configs
)
{
var derivedTypeMappings = new Dictionary<ITypeSymbol, ITypeMapping>(SymbolEqualityComparer.Default);
var derivedTypeMappingSourceTypes = new HashSet<ITypeSymbol>(SymbolEqualityComparer.Default);
var derivedTypeMappings = new List<ITypeMapping>(configs.Count);

foreach (var config in configs)
{
// set reference types non-nullable as they can never be null when type-switching.
var sourceType = config.SourceType.WithNullableAnnotation(NullableAnnotation.NotAnnotated);
if (derivedTypeMappings.ContainsKey(sourceType))
if (!derivedTypeMappingSourceTypes.Add(sourceType))
{
ctx.ReportDiagnostic(DiagnosticDescriptors.DerivedSourceTypeDuplicated, sourceType);
continue;
Expand All @@ -58,7 +61,7 @@ IEnumerable<MapDerivedType> configs
continue;
}

derivedTypeMappings.Add(sourceType, mapping);
derivedTypeMappings.Add(mapping);
}

return derivedTypeMappings;
Expand Down
@@ -0,0 +1,49 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Riok.Mapperly.Emit.SyntaxFactoryHelper;

namespace Riok.Mapperly.Descriptors.Mappings;

/// <summary>
/// A derived type mapping maps one base type or interface to another
/// by implementing a if with instance checks over known types and performs the provided mapping for each type.
/// </summary>
public class DerivedTypeIfExpressionMapping : TypeMapping
{
private readonly IReadOnlyCollection<ITypeMapping> _typeMappings;

public DerivedTypeIfExpressionMapping(ITypeSymbol sourceType, ITypeSymbol targetType, IReadOnlyCollection<ITypeMapping> typeMappings)
: base(sourceType, targetType)
{
_typeMappings = typeMappings;
}

public override ExpressionSyntax Build(TypeMappingBuildContext ctx)
{
// source is A x ? MapToA(x) : <other cases>
var typeExpressions = _typeMappings
.Reverse()
.Aggregate<ITypeMapping, ExpressionSyntax>(DefaultLiteral(), (aggregate, current) => BuildConditional(ctx, aggregate, current));

// cast to target type, to ensure the compiler picks the correct type
// (B)(<ifs...>
return CastExpression(FullyQualifiedIdentifier(TargetType), ParenthesizedExpression(typeExpressions));
}

private ConditionalExpressionSyntax BuildConditional(TypeMappingBuildContext ctx, ExpressionSyntax notMatched, ITypeMapping mapping)
{
// cannot use is pattern matching is operator due to expression limitations
// use is with a cast instead
// source is A ? MapToB((A)x) : <other cases>
var castedSourceContext = ctx.WithSource(
ParenthesizedExpression(CastExpression(FullyQualifiedIdentifier(mapping.SourceType), ctx.Source))
);
return ConditionalExpression(
BinaryExpression(SyntaxKind.IsExpression, ctx.Source, FullyQualifiedIdentifier(mapping.SourceType)),
mapping.Build(castedSourceContext),
notMatched
);
}
}
Expand Up @@ -9,13 +9,13 @@ namespace Riok.Mapperly.Descriptors.Mappings;
/// A derived type mapping maps one base type or interface to another
/// by implementing a type switch over known types and performs the provided mapping for each type.
/// </summary>
public class DerivedTypeMapping : TypeMapping
public class DerivedTypeSwitchMapping : TypeMapping
{
private const string GetTypeMethodName = "GetType";

private readonly IReadOnlyDictionary<ITypeSymbol, ITypeMapping> _typeMappings;
private readonly IReadOnlyCollection<ITypeMapping> _typeMappings;

public DerivedTypeMapping(ITypeSymbol sourceType, ITypeSymbol targetType, IReadOnlyDictionary<ITypeSymbol, ITypeMapping> typeMappings)
public DerivedTypeSwitchMapping(ITypeSymbol sourceType, ITypeSymbol targetType, IReadOnlyCollection<ITypeMapping> typeMappings)
: base(sourceType, targetType)
{
_typeMappings = typeMappings;
Expand All @@ -35,7 +35,9 @@ public override ExpressionSyntax Build(TypeMappingBuildContext ctx)

// source switch { A x => MapToA(x), B x => MapToB(x) }
var (typeArmContext, typeArmVariableName) = ctx.WithNewSource();
var arms = _typeMappings.Select(x => BuildSwitchArm(typeArmVariableName, x.Key, x.Value.Build(typeArmContext))).Append(fallbackArm);
var arms = _typeMappings
.Select(x => BuildSwitchArm(typeArmVariableName, x.SourceType, x.Build(typeArmContext)))
.Append(fallbackArm);
return SwitchExpression(ctx.Source).WithArms(CommaSeparatedList(arms, true));
}

Expand Down
@@ -0,0 +1,9 @@
namespace Riok.Mapperly.IntegrationTests.Dto
{
public abstract class TestObjectDtoProjectionBaseType
{
public int Id { get; set; }

public int BaseValue { get; set; }
}
}
@@ -0,0 +1,7 @@
namespace Riok.Mapperly.IntegrationTests.Dto
{
public class TestObjectDtoProjectionTypeA : TestObjectDtoProjectionBaseType
{
public int ValueA { get; set; }
}
}
@@ -0,0 +1,7 @@
namespace Riok.Mapperly.IntegrationTests.Dto
{
public class TestObjectDtoProjectionTypeB : TestObjectDtoProjectionBaseType
{
public int ValueB { get; set; }
}
}
Expand Up @@ -10,6 +10,8 @@ public static partial class ProjectionMapper
{
public static partial IQueryable<TestObjectDtoProjection> ProjectToDto(this IQueryable<TestObjectProjection> q);

public static partial IQueryable<TestObjectDtoProjectionBaseType> ProjectToDto(this IQueryable<TestObjectProjectionBaseType> q);

// disable obsolete warning, as the obsolete attribute should still be tested.
#pragma warning disable CS0618
[MapperIgnore(nameof(TestObjectDtoProjection.IgnoredStringValue))]
Expand All @@ -23,5 +25,9 @@ public static partial class ProjectionMapper
{
return new TestObjectDtoManuallyMappedProjection(100) { StringValue = str, };
}

[MapDerivedType(typeof(TestObjectProjectionTypeA), typeof(TestObjectDtoProjectionTypeA))]
[MapDerivedType(typeof(TestObjectProjectionTypeB), typeof(TestObjectDtoProjectionTypeB))]
private static partial TestObjectDtoProjectionBaseType MapDerived(TestObjectProjectionBaseType source);
}
}
@@ -0,0 +1,9 @@
namespace Riok.Mapperly.IntegrationTests.Models
{
public abstract class TestObjectProjectionBaseType
{
public int Id { get; set; }

public int BaseValue { get; set; }
}
}
@@ -0,0 +1,7 @@
namespace Riok.Mapperly.IntegrationTests.Models
{
public class TestObjectProjectionTypeA : TestObjectProjectionBaseType
{
public int ValueA { get; set; }
}
}
@@ -0,0 +1,7 @@
namespace Riok.Mapperly.IntegrationTests.Models
{
public class TestObjectProjectionTypeB : TestObjectProjectionBaseType
{
public int ValueB { get; set; }
}
}
44 changes: 37 additions & 7 deletions test/Riok.Mapperly.IntegrationTests/ProjectionMapperTest.cs
@@ -1,4 +1,5 @@
using System;
using System.Linq;
using System.Threading.Tasks;
using Riok.Mapperly.IntegrationTests.Mapper;
using Riok.Mapperly.IntegrationTests.Models;
Expand All @@ -23,7 +24,32 @@ public Task SnapshotGeneratedSource()

#if NET7_0_OR_GREATER
[Fact]
public async Task ProjectionShouldTranslateToQuery()
public Task ProjectionShouldTranslateToQuery()
{
return RunWithDatabase(async ctx =>
{
var query = ctx.Objects.ProjectToDto();
await Verifier.Verify(query.ToQueryString(), "sql").UseTextForParameters("query");
var objects = await query.ToListAsync();
await Verifier.Verify(objects).UseTextForParameters("result");
});
}

[Fact]
public Task DerivedTypesProjectionShouldTranslateToQuery()
{
return RunWithDatabase(async ctx =>
{
var query = ctx.BaseTypeObjects.OrderBy(x => x.BaseValue).ProjectToDto();
await Verifier.Verify(query.ToQueryString(), "sql").UseTextForParameters("query");
var objects = await query.ToListAsync();
await Verifier.Verify(objects).UseTextForParameters("result");
});
}

private async Task RunWithDatabase(Func<ProjectionDbContext, Task> action)
{
await using var connection = new SqliteConnection("Data Source=:memory:");
await connection.OpenAsync();
Expand All @@ -33,13 +59,10 @@ public async Task ProjectionShouldTranslateToQuery()
await using var ctx = new ProjectionDbContext(options);
await ctx.Database.EnsureCreatedAsync();
ctx.Objects.Add(CreateObject());
ctx.BaseTypeObjects.Add(new TestObjectProjectionTypeA { BaseValue = 10, ValueA = 10 });
ctx.BaseTypeObjects.Add(new TestObjectProjectionTypeB { BaseValue = 20, ValueB = 20 });
await ctx.SaveChangesAsync();

var query = ctx.Objects.ProjectToDto();
await Verifier.Verify(query.ToQueryString(), "sql").UseTextForParameters("query");

var objects = await query.ToListAsync();
await Verifier.Verify(objects).UseTextForParameters("result");
await action(ctx);
}

private TestObjectProjection CreateObject()
Expand Down Expand Up @@ -82,6 +105,7 @@ public ProjectionDbContext(DbContextOptions options)
: base(options) { }

public DbSet<TestObjectProjection> Objects { get; set; } = null!;
public DbSet<TestObjectProjectionBaseType> BaseTypeObjects { get; set; } = null!;

protected override void OnModelCreating(ModelBuilder modelBuilder)
{
Expand All @@ -92,6 +116,12 @@ protected override void OnModelCreating(ModelBuilder modelBuilder)
modelBuilder.Entity<IdObject>().HasKey(p => p.IdValue);
modelBuilder.Entity<InheritanceSubObject>().HasKey(p => p.SubIntValue);
modelBuilder.Entity<TestObjectNested>().HasKey(p => p.IntValue);

modelBuilder
.Entity<TestObjectProjectionBaseType>()
.HasDiscriminator<string>("type")
.HasValue<TestObjectProjectionTypeA>("A")
.HasValue<TestObjectProjectionTypeB>("B");
}
}
#endif
Expand Down
Expand Up @@ -10,6 +10,13 @@ public static partial class ProjectionMapper
#nullable enable
}

public static partial global::System.Linq.IQueryable<global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionBaseType> ProjectToDto(this global::System.Linq.IQueryable<global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionBaseType> q)
{
#nullable disable
return System.Linq.Queryable.Select(q, x => (global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionBaseType)(x is global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeA ? new global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionTypeA() { ValueA = ((global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeA)x).ValueA, Id = ((global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeA)x).Id, BaseValue = ((global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeA)x).BaseValue } : x is global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeB ? new global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionTypeB() { ValueB = ((global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeB)x).ValueB, Id = ((global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeB)x).Id, BaseValue = ((global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeB)x).BaseValue } : default));
#nullable enable
}

private static partial global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjection ProjectToDto(this global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjection testObject)
{
var target = new global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjection(testObject.CtorValue)
Expand Down Expand Up @@ -60,6 +67,34 @@ public static partial class ProjectionMapper
return target;
}

private static partial global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionBaseType MapDerived(global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionBaseType source)
{
return source switch
{
global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeA x => MapToTestObjectDtoProjectionTypeA(x),
global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeB x => MapToTestObjectDtoProjectionTypeB(x),
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionBaseType as there is no known derived type mapping", nameof(source)),
};
}

private static global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionTypeA MapToTestObjectDtoProjectionTypeA(global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeA source)
{
var target = new global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionTypeA();
target.ValueA = source.ValueA;
target.Id = source.Id;
target.BaseValue = source.BaseValue;
return target;
}

private static global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionTypeB MapToTestObjectDtoProjectionTypeB(global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeB source)
{
var target = new global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionTypeB();
target.ValueB = source.ValueB;
target.Id = source.Id;
target.BaseValue = source.BaseValue;
return target;
}

private static global::Riok.Mapperly.IntegrationTests.Dto.TestObjectNestedDto MapToTestObjectNestedDto(global::Riok.Mapperly.IntegrationTests.Models.TestObjectNested source)
{
var target = new global::Riok.Mapperly.IntegrationTests.Dto.TestObjectNestedDto();
Expand Down
Expand Up @@ -15,6 +15,15 @@ public static partial class ProjectionMapper
#nullable enable
}

public static partial global::System.Linq.IQueryable<global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionBaseType> ProjectToDto(this global::System.Linq.IQueryable<global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionBaseType> q)
{
#nullable disable
return System.Linq.Queryable.Select(q, x => (global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionBaseType)(x is global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeA ? new global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionTypeA()
{ValueA = ((global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeA)x).ValueA, Id = ((global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeA)x).Id, BaseValue = ((global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeA)x).BaseValue} : x is global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeB ? new global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionTypeB()
{ValueB = ((global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeB)x).ValueB, Id = ((global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeB)x).Id, BaseValue = ((global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeB)x).BaseValue} : default));
#nullable enable
}

private static partial global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjection ProjectToDto(this global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjection testObject)
{
var target = new global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjection(testObject.CtorValue)
Expand Down Expand Up @@ -62,6 +71,34 @@ public static partial class ProjectionMapper
return target;
}

private static partial global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionBaseType MapDerived(global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionBaseType source)
{
return source switch
{
global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeA x => MapToTestObjectDtoProjectionTypeA(x),
global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeB x => MapToTestObjectDtoProjectionTypeB(x),
_ => throw new System.ArgumentException($"Cannot map {source.GetType()} to Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionBaseType as there is no known derived type mapping", nameof(source)),
};
}

private static global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionTypeA MapToTestObjectDtoProjectionTypeA(global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeA source)
{
var target = new global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionTypeA();
target.ValueA = source.ValueA;
target.Id = source.Id;
target.BaseValue = source.BaseValue;
return target;
}

private static global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionTypeB MapToTestObjectDtoProjectionTypeB(global::Riok.Mapperly.IntegrationTests.Models.TestObjectProjectionTypeB source)
{
var target = new global::Riok.Mapperly.IntegrationTests.Dto.TestObjectDtoProjectionTypeB();
target.ValueB = source.ValueB;
target.Id = source.Id;
target.BaseValue = source.BaseValue;
return target;
}

private static global::Riok.Mapperly.IntegrationTests.Dto.TestObjectNestedDto MapToTestObjectNestedDto(global::Riok.Mapperly.IntegrationTests.Models.TestObjectNested source)
{
var target = new global::Riok.Mapperly.IntegrationTests.Dto.TestObjectNestedDto();
Expand Down
@@ -0,0 +1,3 @@
SELECT "b"."type" = 'A', "b"."ValueA", "b"."Id", "b"."BaseValue", "b"."type" = 'B', "b"."ValueB"
FROM "BaseTypeObjects" AS "b"
ORDER BY "b"."BaseValue"
@@ -0,0 +1,12 @@
[
{
ValueA: 10,
Id: 1,
BaseValue: 10
},
{
ValueB: 20,
Id: 2,
BaseValue: 20
}
]

0 comments on commit 98a44a0

Please sign in to comment.