Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions Orm/Xtensive.Orm.Tests/Linq/MemberCompilerProviderTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public partial class MemberCompilerProviderTest
{
private readonly string[] dummy = new string[10];

private static Func<string, string[],string> GetCompilerForMethod(
private static IMemberCompilerProvider<string>.BoundCompiler GetCompilerForMethod(
IMemberCompilerProvider<string> provider, Type source, string methodName)
{
if (source.IsGenericTypeDefinition)
Expand All @@ -31,23 +31,23 @@ private static Func<string, string[],string> GetCompilerForMethod(
if (mi.IsGenericMethodDefinition)
mi = mi.MakeGenericMethod(typeof(object));
var result = provider.GetCompiler(mi);
Assert.IsNotNull(result);
Assert.IsFalse(result.IsNull);
return result;
}

private static Func<string, string[], string> GetCompilerForCtor(
private static IMemberCompilerProvider<string>.BoundCompiler GetCompilerForCtor(
IMemberCompilerProvider<string> provider, Type source)
{
if (source.IsGenericTypeDefinition)
source = source.MakeGenericType(typeof(object));

var ci = source.GetConstructors().First();
var result = provider.GetCompiler(ci);
Assert.IsNotNull(result);
Assert.IsFalse(result.IsNull);
return result;
}

private static Func<string, string[], string> GetCompilerForField(
private static IMemberCompilerProvider<string>.BoundCompiler GetCompilerForField(
IMemberCompilerProvider<string> provider, Type source, string fieldName)
{
if (source.IsGenericTypeDefinition)
Expand All @@ -56,7 +56,7 @@ private static Func<string, string[], string> GetCompilerForField(
var fi = source.GetField(fieldName);
Assert.IsNotNull(fi);
var result = provider.GetCompiler(fi);
Assert.IsNotNull(result);
Assert.IsFalse(result.IsNull);
return result;
}

Expand All @@ -71,7 +71,7 @@ public void MethodsTest()
foreach (string s2 in new[]{"Generic", "NonGeneric"}) {
string method = s1 + s2 + "Method";
var d = GetCompilerForMethod(provider, t, method);
Assert.AreEqual(t.Name + "." + method, d(null, dummy));
Assert.AreEqual(t.Name + "." + method, d.Invoke(null, dummy));
}
}

Expand All @@ -86,7 +86,7 @@ public void PropertiesTest()
foreach (string s2 in new[] { "InstanceProperty", "StaticProperty", "Item" }) {
string method = s1 + s2;
var d = GetCompilerForMethod(provider, t, method);
Assert.AreEqual(t.Name + "." + method, d(null, dummy));
Assert.AreEqual(t.Name + "." + method, d.Invoke(null, dummy));
}
}

Expand All @@ -99,7 +99,7 @@ public void FieldsTest()
foreach (var t in new[]{typeof(NonGenericTarget), typeof(GenericTarget<>)})
foreach (string s in new[] {"InstanceField", "StaticField"}) {
var d = GetCompilerForField(provider, t, s);
Assert.AreEqual(t.Name + "." + s, d(null, dummy));
Assert.AreEqual(t.Name + "." + s, d.Invoke(null, dummy));
}
}

Expand All @@ -110,7 +110,7 @@ public void CtorsTest()
provider.RegisterCompilers(typeof(CtorCompiler));
foreach (var t in new[]{typeof(NonGenericTarget), typeof(GenericTarget<>)}) {
var d = GetCompilerForCtor(provider, t);
Assert.AreEqual(t.Name + Reflection.WellKnown.CtorName, d(null, dummy));
Assert.AreEqual(t.Name + Reflection.WellKnown.CtorName, d.Invoke(null, dummy));
}
}

Expand All @@ -128,8 +128,8 @@ public void GenericFindTest()
.MakeGenericMethod(typeof(string));

var d = provider.GetCompiler(mi);
Assert.IsNotNull(d);
Assert.AreEqual("OK", d(null, dummy));
Assert.IsFalse(d.IsNull);
Assert.AreEqual("OK", d.Invoke(null, dummy));
}

[Test]
Expand Down Expand Up @@ -160,7 +160,7 @@ public void ConflictKeepOldTest()
provider.RegisterCompilers(typeof(ConflictCompiler1));
provider.RegisterCompilers(typeof(ConflictCompiler2), ConflictHandlingMethod.KeepOld);
var d = GetCompilerForMethod(provider, typeof(ConflictTarget), "ConflictMethod");
Assert.AreEqual("Compiler1", d(null, dummy));
Assert.AreEqual("Compiler1", d.Invoke(null, dummy));
}

[Test]
Expand All @@ -170,7 +170,7 @@ public void ConflictOverwriteTest()
provider.RegisterCompilers(typeof(ConflictCompiler1));
provider.RegisterCompilers(typeof(ConflictCompiler2), ConflictHandlingMethod.Overwrite);
var d = GetCompilerForMethod(provider, typeof(ConflictTarget), "ConflictMethod");
Assert.AreEqual("Compiler2", d(null, dummy));
Assert.AreEqual("Compiler2", d.Invoke(null, dummy));
}

[Test]
Expand All @@ -189,7 +189,7 @@ public void NonPublicGetterTest()
.GetProperty("InternalProperty", BindingFlags.Instance | BindingFlags.NonPublic);
Assert.IsNotNull(property);
var result = provider.GetCompiler(property);
Assert.IsNull(result);
Assert.IsTrue(result.IsNull);
}

[Test]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
// Created by: Denis Krjuchkov
// Created: 2009.02.09

using System;
using System.Collections.Generic;
using System.Reflection;

namespace Xtensive.Orm.Linq.MemberCompilation
Expand Down Expand Up @@ -36,12 +34,19 @@ public interface IMemberCompilerProvider<T> : IMemberCompilerProvider
/// <param name="conflictHandlingMethod">Conflict handling method.</param>
void RegisterCompilers(IEnumerable<KeyValuePair<MemberInfo, Func<MemberInfo, T, T[], T>>> compilerDefinitions, ConflictHandlingMethod conflictHandlingMethod);

readonly struct BoundCompiler(Func<MemberInfo, T, T[], T> compiler, MemberInfo memberInfo)
{
public bool IsNull => compiler == null;

public T Invoke(T arg2, T[] arg3) => compiler(memberInfo, arg2, arg3);
}

/// <summary>
/// Finds compiler for specified <see cref="MemberInfo"/>.
/// </summary>
/// <param name="target"><see cref="MemberInfo"/> to search compiler for.</param>
/// <returns>compiler associated with <see cref="MethodInfo"/>
/// or <see langword="null"/> if compiler is not found.</returns>
Func<T, T[], T> GetCompiler(MemberInfo target);
BoundCompiler GetCompiler(MemberInfo target);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
// Created by: Denis Krjuchkov
// Created: 2009.02.09

using System;
using System.Collections.Generic;
using System.Linq;
using System.Collections.Frozen;
using System.Reflection;
using Xtensive.Core;
using Xtensive.Reflection;
Expand Down Expand Up @@ -36,24 +34,25 @@ public CompilerKey(MemberInfo memberInfo)
}
}

private readonly Dictionary<CompilerKey, Delegate> compilers
= new Dictionary<CompilerKey, Delegate>();
private IDictionary<CompilerKey, Delegate> compilers = new Dictionary<CompilerKey, Delegate>();

public Type ExpressionType => typeof(T);

public Delegate GetUntypedCompiler(MemberInfo target)
public override void Lock(bool recursive)
{
ArgumentNullException.ThrowIfNull(target);

return compilers.GetValueOrDefault(GetCompilerKey(target));
base.Lock(recursive);
compilers = compilers.ToFrozenDictionary();
}

public Func<T, T[], T> GetCompiler(MemberInfo target)
public Delegate GetUntypedCompiler(MemberInfo target)
{
var compiler = (Func<MemberInfo, T, T[], T>) GetUntypedCompiler(target);
return compiler.Bind(target);
ArgumentNullException.ThrowIfNull(target);
return compilers.TryGetValue(GetCompilerKey(target), out var v) ? v : null;
}

public IMemberCompilerProvider<T>.BoundCompiler GetCompiler(MemberInfo target) =>
new((Func<MemberInfo, T, T[], T>) GetUntypedCompiler(target), target);

public void RegisterCompilers(Type compilerContainer)
{
RegisterCompilers(compilerContainer, ConflictHandlingMethod.Default);
Expand Down
12 changes: 6 additions & 6 deletions Orm/Xtensive.Orm/Orm/Linq/Translator.Expressions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -361,18 +361,18 @@ protected override Expression VisitMember(MemberExpression ma)
// Reflected type doesn't have custom compiler defined, so falling back to base class compiler
var declaringType = memberInfo.DeclaringType;
var reflectedType = memberInfo.ReflectedType;
if (customCompiler == null && declaringType != reflectedType && declaringType.IsAssignableFrom(reflectedType)) {
if (customCompiler.IsNull && declaringType != reflectedType && declaringType.IsAssignableFrom(reflectedType)) {
var root = declaringType;
var current = reflectedType;
while (current != root && customCompiler == null) {
while (current != root && customCompiler.IsNull) {
current = current.BaseType;
var member = current.GetProperty(memberInfo.Name, BindingFlags.Instance|BindingFlags.Public|BindingFlags.NonPublic);
customCompiler = context.CustomCompilerProvider.GetCompiler(member);
}
}

if (customCompiler != null) {
var expression = customCompiler.Invoke(sourceExpression, Array.Empty<Expression>());
if (!customCompiler.IsNull) {
var expression = customCompiler.Invoke(sourceExpression, []);
if (expression == null) {
if (reflectedType.IsInterface)
return Visit(BuildInterfaceExpression(ma));
Expand Down Expand Up @@ -435,7 +435,7 @@ protected override Expression VisitMethodCall(MethodCallExpression mc)
using (CreateScope(new TranslatorState(State) { IsTailMethod = mc == context.Query && mc.IsQuery() })) {
var method = mc.Method;
var customCompiler = context.CustomCompilerProvider.GetCompiler(method);
if (customCompiler != null) {
if (!customCompiler.IsNull) {
return Visit(customCompiler.Invoke(mc.Object, mc.Arguments.ToArray()));
}

Expand Down Expand Up @@ -1841,7 +1841,7 @@ private Expression BuildExpression(MemberExpression ma, IEnumerable<PropertyInfo
Expression current = Expression.Constant(defaultValue, propertyType);
foreach (var field in fields) {
var compiler = context.CustomCompilerProvider.GetCompiler(field);
if (compiler == null)
if (compiler.IsNull)
continue;
var expression = compiler.Invoke(Expression.TypeAs(ma.Expression, field.ReflectedType), null);
current = Expression.Condition(Expression.TypeIs(ma.Expression, field.ReflectedType), expression, current);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,10 @@ private static SqlExpression OracleBlobCompare(SqlExpression left, SqlExpression

private SqlExpression CompileMember(MemberInfo member, SqlExpression instance, params SqlExpression[] arguments)
{
var memberCompiler = memberCompilerProvider.GetCompiler(member)
?? throw new NotSupportedException(string.Format(Strings.ExMemberXIsNotSupported, member.GetFullName(true)));
var memberCompiler = memberCompilerProvider.GetCompiler(member);
if (memberCompiler.IsNull) {
throw new NotSupportedException(string.Format(Strings.ExMemberXIsNotSupported, member.GetFullName(true)));
}
return memberCompiler.Invoke(instance, arguments);
}

Expand Down