Skip to content

Commit

Permalink
Use provider discriminator values for OfType()
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Jan 22, 2024
1 parent af05058 commit 3c567ce
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1147,13 +1147,28 @@ SqlExpression GeneratePredicateTpt(StructuralTypeProjectionExpression entityProj
{
var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList();
var discriminatorColumn = BindProperty(typeReference, discriminatorProperty);

// Apply any value conversion to the discriminator values.
// Note that this is important also to get the correct SqlConstantExpression.Type, which needs to be the provider type
// rather than the model type; this is in line with how we translate constants everywhere else, and is important in order
// for comparison logic between constants to function correctly (see #32865).
var converter = discriminatorColumn.TypeMapping?.Converter?.ConvertToProvider;

return concreteEntityTypes.Count == 1
? _sqlExpressionFactory.Equal(
discriminatorColumn,
_sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue()))
_sqlExpressionFactory.Constant(GetDiscriminatorValue(concreteEntityTypes[0])))
: _sqlExpressionFactory.In(
discriminatorColumn,
concreteEntityTypes.Select(et => _sqlExpressionFactory.Constant(et.GetDiscriminatorValue())).ToArray());
concreteEntityTypes.Select(et => _sqlExpressionFactory.Constant(GetDiscriminatorValue(et))).ToArray());

object? GetDiscriminatorValue(IEntityType entityType)
=> entityType.GetDiscriminatorValue() switch
{
object value when converter is not null => converter(value),
object value => value,
null => null
};
}

return _sqlExpressionFactory.Constant(true);
Expand Down
13 changes: 11 additions & 2 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -806,14 +806,23 @@ private void AddConditions(SelectExpression selectExpression, IEntityType entity
{
var discriminatorColumn = GetMappedProjection(selectExpression).BindProperty(discriminatorProperty);
var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList();
var converter = discriminatorColumn.TypeMapping?.Converter?.ConvertToProvider;
var predicate = concreteEntityTypes.Count == 1
? (SqlExpression)Equal(discriminatorColumn, Constant(concreteEntityTypes[0].GetDiscriminatorValue()))
: In(discriminatorColumn, concreteEntityTypes.Select(et => Constant(et.GetDiscriminatorValue())).ToArray());
? (SqlExpression)Equal(discriminatorColumn, Constant(GetDiscriminatorValue(concreteEntityTypes[0])))
: In(discriminatorColumn, concreteEntityTypes.Select(et => Constant(GetDiscriminatorValue(et))).ToArray());

selectExpression.ApplyPredicate(predicate);

// If discriminator predicate is added then it will also serve as condition for existence of dependents in table sharing
return;

object? GetDiscriminatorValue(IEntityType entityType)
=> entityType.GetDiscriminatorValue() switch
{
object value when converter is not null => converter(value),
object value => value,
null => null
};
}

// Keyless entities cannot be table sharing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,22 @@ FROM [Plants] AS [p]
""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task OfType_on_enum_discriminator_with_Where_on_same_discriminator(bool async)
{
await AssertQuery(
async,
ss => ss.Set<Plant>().OfType<Rose>().Where(p => p.Genus == PlantGenus.Rose));

AssertSql(
"""
SELECT [p].[Species], [p].[CountryId], [p].[Genus], [p].[Name], [p].[HasThorns]
FROM [Plants] AS [p]
WHERE [p].[Genus] = 0
""");
}

public override async Task Can_query_all_animals(bool async)
{
await base.Can_query_all_animals(async);
Expand Down

0 comments on commit 3c567ce

Please sign in to comment.