Skip to content

Commit

Permalink
Some custom aggregate translations
Browse files Browse the repository at this point in the history
* string.Join (SQL Server and SQLite)
* string.Concat (SQL Server and SQLite)
* Standard deviation and variance (SQL Server)

Closes dotnet#2981
Closes dotnet#28104
  • Loading branch information
roji committed May 26, 2022
1 parent 4cb70fd commit 287ecfd
Show file tree
Hide file tree
Showing 9 changed files with 855 additions and 12 deletions.
316 changes: 306 additions & 10 deletions src/EFCore.SqlServer/Extensions/SqlServerDbFunctionsExtensions.cs

Large diffs are not rendered by default.

Expand Up @@ -21,10 +21,14 @@ public SqlServerAggregateMethodCallTranslatorProvider(RelationalAggregateMethodC
: base(dependencies)
{
var sqlExpressionFactory = dependencies.SqlExpressionFactory;
var typeMappingSource = dependencies.RelationalTypeMappingSource;

AddTranslators(
new IAggregateMethodCallTranslator[]
{
new SqlServerLongCountMethodTranslator(sqlExpressionFactory)
new SqlServerLongCountMethodTranslator(sqlExpressionFactory),
new SqlServerStatisticsAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource),
new SqlServerStringAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource)
});
}
}
@@ -0,0 +1,90 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public class SqlServerStatisticsAggregateMethodTranslator : IAggregateMethodCallTranslator
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly RelationalTypeMapping _doubleTypeMapping;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public SqlServerStatisticsAggregateMethodTranslator(
ISqlExpressionFactory sqlExpressionFactory,
IRelationalTypeMappingSource typeMappingSource)
{
_sqlExpressionFactory = sqlExpressionFactory;
_doubleTypeMapping = typeMappingSource.FindMapping(typeof(double))!;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual SqlExpression? Translate(
MethodInfo method, EnumerableExpression source, IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
// Docs: https://docs.microsoft.com/sql/t-sql/functions/aggregate-functions-transact-sql

if (method.DeclaringType != typeof(SqlServerDbFunctionsExtensions)
|| source.Selector is not SqlExpression sqlExpression)
{
return null;
}

var functionName = method.Name switch
{
nameof(SqlServerDbFunctionsExtensions.StandardDeviationSample) => "STDEV",
nameof(SqlServerDbFunctionsExtensions.StandardDeviationPopulation) => "STDEVP",
nameof(SqlServerDbFunctionsExtensions.VarianceSample) => "VAR",
nameof(SqlServerDbFunctionsExtensions.VariancePopulation) => "VARP",
_ => null
};

if (functionName is null)
{
return null;
}

if (source.Predicate != null)
{
if (sqlExpression is SqlFragmentExpression)
{
sqlExpression = _sqlExpressionFactory.Constant(1);
}

sqlExpression = _sqlExpressionFactory.Case(
new List<CaseWhenClause> { new(source.Predicate, sqlExpression) },
elseResult: null);
}

if (source.IsDistinct)
{
sqlExpression = new DistinctExpression(sqlExpression);
}

return _sqlExpressionFactory.Function(
functionName,
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
typeof(double),
_doubleTypeMapping);
}
}
@@ -0,0 +1,124 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public class SqlServerStringAggregateMethodTranslator : IAggregateMethodCallTranslator
{
private static readonly MethodInfo StringConcatMethod
= typeof(string).GetRuntimeMethod(nameof(string.Concat), new[] { typeof(IEnumerable<string>) })!;

private static readonly MethodInfo StringJoinMethod
= typeof(string).GetRuntimeMethod(nameof(string.Join), new[] { typeof(string), typeof(IEnumerable<string>) })!;

private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly IRelationalTypeMappingSource _typeMappingSource;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public SqlServerStringAggregateMethodTranslator(
ISqlExpressionFactory sqlExpressionFactory,
IRelationalTypeMappingSource typeMappingSource)
{
_sqlExpressionFactory = sqlExpressionFactory;
_typeMappingSource = typeMappingSource;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual SqlExpression? Translate(
MethodInfo method, EnumerableExpression source, IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
// Docs: https://docs.microsoft.com/sql/t-sql/functions/string-agg-transact-sql

if (source.Selector is not SqlExpression sqlExpression
|| (method != StringJoinMethod && method != StringConcatMethod))
{
return null;
}

// STRING_AGG enlarges the return type size (e.g. for input VARCHAR(5), it returns VARCHAR(8000)).
// See https://docs.microsoft.com/sql/t-sql/functions/string-agg-transact-sql#return-types
var resultTypeMapping = sqlExpression.TypeMapping;
if (resultTypeMapping?.Size != null)
{
if (resultTypeMapping.IsUnicode && resultTypeMapping.Size < 8000)
{
resultTypeMapping = _typeMappingSource.FindMapping(
typeof(string),
resultTypeMapping.StoreTypeNameBase,
unicode: true,
size: 8000);
}
else if (!resultTypeMapping.IsUnicode && resultTypeMapping.Size < 4000)
{
resultTypeMapping = _typeMappingSource.FindMapping(
typeof(string),
resultTypeMapping.StoreTypeNameBase,
unicode: false,
size: 4000);
}
}

// STRING_AGG filters out nulls, but string.Join treats them as empty strings; coalesce unless we know we're aggregating over
// a non-nullable column.
if (sqlExpression is not ColumnExpression { IsNullable: false })
{
sqlExpression = _sqlExpressionFactory.Coalesce(
sqlExpression,
_sqlExpressionFactory.Constant(string.Empty, typeof(string)));
}

if (source.Predicate != null)
{
if (sqlExpression is SqlFragmentExpression)
{
sqlExpression = _sqlExpressionFactory.Constant(1);
}

sqlExpression = _sqlExpressionFactory.Case(
new List<CaseWhenClause> { new(source.Predicate, sqlExpression) },
elseResult: null);
}

if (source.IsDistinct)
{
sqlExpression = new DistinctExpression(sqlExpression);
}

// STRING_AGG returns null when there are no rows (or non-null values), but string.Join returns an empty string.
return
_sqlExpressionFactory.Coalesce(
_sqlExpressionFactory.Function(
"STRING_AGG",
new[]
{
sqlExpression,
_sqlExpressionFactory.ApplyTypeMapping(
method == StringJoinMethod ? arguments[0] : _sqlExpressionFactory.Constant(string.Empty, typeof(string)),
sqlExpression.TypeMapping)
},
nullable: true,
argumentsPropagateNullability: new[] { false, true },
typeof(string)),
_sqlExpressionFactory.Constant(string.Empty, typeof(string)),
resultTypeMapping);
}
}
Expand Up @@ -25,7 +25,8 @@ public SqliteAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCall
AddTranslators(
new IAggregateMethodCallTranslator[]
{
new SqliteQueryableAggregateMethodTranslator(sqlExpressionFactory)
new SqliteQueryableAggregateMethodTranslator(sqlExpressionFactory),
new SqliteStringAggregateMethodTranslator(sqlExpressionFactory)
});
}
}
@@ -0,0 +1,94 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public class SqliteStringAggregateMethodTranslator : IAggregateMethodCallTranslator
{
private static readonly MethodInfo StringConcatMethod
= typeof(string).GetRuntimeMethod(nameof(string.Concat), new[] { typeof(IEnumerable<string>) })!;

private static readonly MethodInfo StringJoinMethod
= typeof(string).GetRuntimeMethod(nameof(string.Join), new[] { typeof(string), typeof(IEnumerable<string>) })!;

private readonly ISqlExpressionFactory _sqlExpressionFactory;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public SqliteStringAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
=> _sqlExpressionFactory = sqlExpressionFactory;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual SqlExpression? Translate(
MethodInfo method, EnumerableExpression source, IReadOnlyList<SqlExpression> arguments,
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
// Docs: https://sqlite.org/lang_aggfunc.html#group_concat

if (source.Selector is not SqlExpression sqlExpression
|| (method != StringJoinMethod && method != StringConcatMethod))
{
return null;
}

// group_concat filters out nulls, but string.Join treats them as empty strings; coalesce unless we know we're aggregating over
// a non-nullable column.
if (sqlExpression is not ColumnExpression { IsNullable: false })
{
sqlExpression = _sqlExpressionFactory.Coalesce(
sqlExpression,
_sqlExpressionFactory.Constant(string.Empty, typeof(string)));
}

if (source.Predicate != null)
{
if (sqlExpression is SqlFragmentExpression)
{
sqlExpression = _sqlExpressionFactory.Constant(1);
}

sqlExpression = _sqlExpressionFactory.Case(
new List<CaseWhenClause> { new(source.Predicate, sqlExpression) },
elseResult: null);
}

if (source.IsDistinct)
{
sqlExpression = new DistinctExpression(sqlExpression);
}

// group_concat returns null when there are no rows (or non-null values), but string.Join returns an empty string.
return _sqlExpressionFactory.Coalesce(
_sqlExpressionFactory.Function(
"group_concat",
new[]
{
sqlExpression,
_sqlExpressionFactory.ApplyTypeMapping(
method == StringJoinMethod ? arguments[0] : _sqlExpressionFactory.Constant(string.Empty, typeof(string)),
sqlExpression.TypeMapping)
},
nullable: true,
argumentsPropagateNullability: new[] { false, true },
typeof(string)),
_sqlExpressionFactory.Constant(string.Empty, typeof(string)),
sqlExpression.TypeMapping);
}
}

0 comments on commit 287ecfd

Please sign in to comment.