From 3c567ce024ceede4713d054c7e898c8acd77d45b Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Mon, 22 Jan 2024 10:28:56 +0100 Subject: [PATCH] Use provider discriminator values for OfType() Fixes #32865 --- ...lationalSqlTranslatingExpressionVisitor.cs | 19 +++++++++++++++++-- .../Query/SqlExpressionFactory.cs | 13 +++++++++++-- .../Query/TPHInheritanceQuerySqlServerTest.cs | 16 ++++++++++++++++ 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 9d854645ae7..026f9cfea1c 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -1147,13 +1147,28 @@ SqlExpression GeneratePredicateTpt(StructuralTypeProjectionExpression entityProj { var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList(); var discriminatorColumn = BindProperty(typeReference, discriminatorProperty); + + // Apply any value conversion to the discriminator values. + // Note that this is important also to get the correct SqlConstantExpression.Type, which needs to be the provider type + // rather than the model type; this is in line with how we translate constants everywhere else, and is important in order + // for comparison logic between constants to function correctly (see #32865). + var converter = discriminatorColumn.TypeMapping?.Converter?.ConvertToProvider; + return concreteEntityTypes.Count == 1 ? _sqlExpressionFactory.Equal( discriminatorColumn, - _sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue())) + _sqlExpressionFactory.Constant(GetDiscriminatorValue(concreteEntityTypes[0]))) : _sqlExpressionFactory.In( discriminatorColumn, - concreteEntityTypes.Select(et => _sqlExpressionFactory.Constant(et.GetDiscriminatorValue())).ToArray()); + concreteEntityTypes.Select(et => _sqlExpressionFactory.Constant(GetDiscriminatorValue(et))).ToArray()); + + object? GetDiscriminatorValue(IEntityType entityType) + => entityType.GetDiscriminatorValue() switch + { + object value when converter is not null => converter(value), + object value => value, + null => null + }; } return _sqlExpressionFactory.Constant(true); diff --git a/src/EFCore.Relational/Query/SqlExpressionFactory.cs b/src/EFCore.Relational/Query/SqlExpressionFactory.cs index 2f83856b7e8..d603cd9d93a 100644 --- a/src/EFCore.Relational/Query/SqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/SqlExpressionFactory.cs @@ -806,14 +806,23 @@ private void AddConditions(SelectExpression selectExpression, IEntityType entity { var discriminatorColumn = GetMappedProjection(selectExpression).BindProperty(discriminatorProperty); var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList(); + var converter = discriminatorColumn.TypeMapping?.Converter?.ConvertToProvider; var predicate = concreteEntityTypes.Count == 1 - ? (SqlExpression)Equal(discriminatorColumn, Constant(concreteEntityTypes[0].GetDiscriminatorValue())) - : In(discriminatorColumn, concreteEntityTypes.Select(et => Constant(et.GetDiscriminatorValue())).ToArray()); + ? (SqlExpression)Equal(discriminatorColumn, Constant(GetDiscriminatorValue(concreteEntityTypes[0]))) + : In(discriminatorColumn, concreteEntityTypes.Select(et => Constant(GetDiscriminatorValue(et))).ToArray()); selectExpression.ApplyPredicate(predicate); // If discriminator predicate is added then it will also serve as condition for existence of dependents in table sharing return; + + object? GetDiscriminatorValue(IEntityType entityType) + => entityType.GetDiscriminatorValue() switch + { + object value when converter is not null => converter(value), + object value => value, + null => null + }; } // Keyless entities cannot be table sharing diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/TPHInheritanceQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/TPHInheritanceQuerySqlServerTest.cs index eadf91c0291..872d485c96c 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/TPHInheritanceQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/TPHInheritanceQuerySqlServerTest.cs @@ -226,6 +226,22 @@ public override async Task Can_use_of_type_rose(bool async) """); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task OfType_on_enum_discriminator_with_Where_on_same_discriminator(bool async) + { + await AssertQuery( + async, + ss => ss.Set().OfType().Where(p => p.Genus == PlantGenus.Rose)); + + AssertSql( + """ +SELECT [p].[Species], [p].[CountryId], [p].[Genus], [p].[Name], [p].[HasThorns] +FROM [Plants] AS [p] +WHERE [p].[Genus] = 0 +"""); + } + public override async Task Can_query_all_animals(bool async) { await base.Can_query_all_animals(async);