Permalink
Browse files

Support boost by constant float.

  • Loading branch information...
1 parent 1635b16 commit 4288da3064b766e7326af76b79d24678e40f7550 @chriseldredge chriseldredge committed May 24, 2012
@@ -1,4 +1,7 @@
+using System;
using System.Linq;
+using Lucene.Net.Analysis;
+using Lucene.Net.Analysis.Standard;
using NUnit.Framework;
namespace Lucene.Net.Linq.Tests.Integration
@@ -12,27 +15,74 @@ public class BoostTests : IntegrationTestBase
public void AddDocuments()
{
AddDocument(new SampleDocument { Name = "sample", Id = "0", Scalar = 0});
- AddDocument(new SampleDocument { Name = "sample", Id = "1", Scalar = 1});
+ AddDocument(new SampleDocument { Name = "other", Id = "1", Scalar = 1});
documents = provider.AsQueryable<SampleDocument>();
}
[Test]
- public void Single()
+ public void Boost()
{
- var first = documents.Where(d => d.Name == "sample").Boost(d => d.Scalar);
+ var query = from d in documents where d.Name == "sample" || d.Id.Boost(0) == "0" select d;
+
+ Assert.That(query.First().Name, Is.EqualTo("sample"));
+ }
+
+ [Test]
+ public void Boost_MethodCall()
+ {
+ var query = from d in documents where d.Name == "sample" || d.Name.Boost(2f).StartsWith("other") select d;
+
+ Assert.That(query.First().Name, Is.EqualTo("other"));
+ }
+
+ [Test]
+ public void Boost_BinaryExpression()
+ {
+ var query = from d in documents where d.Name == "other" || (d.Id == "0").Boost(100) select d;
+
+ Assert.That(query.First().Name, Is.EqualTo("sample"));
+ }
+
+ [Test]
+ public void Boost_CompoundBinaryExpression()
+ {
+ var query = from d in documents where (d.Name == "other" || (d.Id == "1")).Boost(0) || d.Name == "sample" select d;
+
+ Assert.That(query.First().Name, Is.EqualTo("sample"));
+ }
+
+ [Test]
+ public void Dynamic_Single()
+ {
+ var first = documents.Boost(d => d.Scalar);
Assert.That(first.ToList()[0].Id, Is.EqualTo("1"));
}
[Test]
- public void Multiple()
+ public void Dynamic_Multiple()
{
AddDocument(new SampleDocument { Name = "sample", Id = "33", Scalar = 1 });
var first = documents.Where(d => d.Name == "sample").Boost(d => d.Id.Length).Boost(d => d.Scalar);
Assert.That(first.ToList()[0].Id, Is.EqualTo("33"));
}
+
+ [Test]
+ public void ExtensionMethodThrowsWhenInvoked()
+ {
+ TestDelegate call = () => "hello".Boost(0f);
+
+ Assert.That(call, Throws.InvalidOperationException);
+ }
+
+ [Test]
+ public void ThrowsWhenNotCalledOnQueryField()
+ {
+ TestDelegate call = () => (from d in documents orderby d.Name.Boost(5f) select d).ToList();
+ Assert.That(call, Throws.InstanceOf<NotSupportedException>());
+ }
}
}
@@ -17,12 +17,12 @@ class TaggedDocument
}
[Test]
- public void EnumerableContains()
+ public void Enumerable_Contains()
{
using (var session = provider.OpenSession<TaggedDocument>())
{
- session.Add(new TaggedDocument { Name = "First", Tags = new[] {"a", "b"}});
- session.Add(new TaggedDocument { Name = "Second", Tags = new[] {"c", "d"}});
+ session.Add(new TaggedDocument { Name = "First", Tags = new[] { "a", "b" } });
+ session.Add(new TaggedDocument { Name = "Second", Tags = new[] { "c", "d" } });
session.Commit();
}
@@ -77,6 +77,7 @@
<Compile Include="Mapping\FieldMappingInfoBuilderDateFormatTests.cs" />
<Compile Include="Mapping\FieldMappingInfoBuilderTimeSpanTests.cs" />
<Compile Include="Mapping\NumericReflectionFieldMapperTests.cs" />
+ <Compile Include="Transformation\TreeVisitors\BoostMethodCallTreeVisitorTests.cs" />
<Compile Include="Transformation\TreeVisitors\CompareCallToBinaryExpressionTreeVisitorTests.cs" />
<Compile Include="Transformation\TreeVisitors\LuceneExtensionMethodCallTreeVisitorTests.cs" />
<Compile Include="Transformation\TreeVisitors\MethodCallToBinaryExpressionTreeVisitorTests.cs" />
@@ -0,0 +1,96 @@
+using System;
+using System.Linq.Expressions;
+using Lucene.Net.Linq.Expressions;
+using Lucene.Net.Linq.Transformation.TreeVisitors;
+using Lucene.Net.Search;
+using NUnit.Framework;
+using Remotion.Linq;
+
+namespace Lucene.Net.Linq.Tests.Transformation.TreeVisitors
+{
+ [TestFixture]
+ public class BoostMethodCallTreeVisitorTests
+ {
+ private BoostMethodCallTreeVisitor visitor;
+
+ [SetUp]
+ public void SetUp()
+ {
+ visitor = new BoostMethodCallTreeVisitor(0);
+ }
+
+ [Test]
+ public void Stage0_Transform()
+ {
+ var methodInfo = ReflectionUtility.GetMethod(() => LuceneMethods.Boost<string>(null, 0f));
+ var fieldExpression = new LuceneQueryFieldExpression(typeof (string), "Name");
+ const float boostAmount = 5.5f;
+
+ // LuceneField(Name).Boost(5.5)
+ var call = Expression.Call(methodInfo, fieldExpression, Expression.Constant(boostAmount));
+
+ var result = visitor.VisitExpression(call);
+
+ Assert.That(result, Is.SameAs(fieldExpression));
+ Assert.That(((LuceneQueryFieldExpression)result).Boost, Is.EqualTo(boostAmount));
+ }
+
+ [Test]
+ public void Stage1_Transform()
+ {
+ visitor = new BoostMethodCallTreeVisitor(1);
+ var methodInfo = ReflectionUtility.GetMethod(() => false.Boost(0f));
+ var fieldExpression = new LuceneQueryFieldExpression(typeof(string), "Name");
+ var query = new LuceneQueryExpression(fieldExpression, Expression.Constant("foo"), BooleanClause.Occur.SHOULD);
+
+ const float boostAmount = 0.5f;
+
+ // (LuceneQuery[Default](+Name:"foo")).Boost(0.5f)
+ var call = Expression.Call(methodInfo, query, Expression.Constant(boostAmount));
+
+ var result = visitor.VisitExpression(call);
+
+ Assert.That(result, Is.SameAs(query));
+ Assert.That(((LuceneQueryExpression)result).Boost, Is.EqualTo(boostAmount));
+ }
+
+ [Test]
+ public void Stage0_IgnoresNonLuceneQueryFieldExpression()
+ {
+ var methodInfo = ReflectionUtility.GetMethod(() => LuceneMethods.Boost<string>(null, 0f));
+
+ // "hello".Boost(5.5)
+ var expr = Expression.Call(methodInfo, Expression.Constant("hello"), Expression.Constant(5.5f));
+
+ var result = visitor.VisitExpression(expr);
+
+ Assert.That(result, Is.SameAs(expr));
+ }
+
+ [Test]
+ public void Stage1_ThrowsWhenNotOnQueryField()
+ {
+ visitor = new BoostMethodCallTreeVisitor(1);
+ var methodInfo = ReflectionUtility.GetMethod(() => LuceneMethods.Boost<string>(null, 0f));
+
+ // "hello".Boost(5.5)
+ var expr = Expression.Call(methodInfo, Expression.Constant("hello"), Expression.Constant(5.5f));
+
+ TestDelegate call = () => visitor.VisitExpression(expr);
+
+ Assert.That(call, Throws.InstanceOf<NotSupportedException>());
+ }
+
+ [Test]
+ public void IgnoresUnrelatedMethodCalls()
+ {
+ var methodInfo = ReflectionUtility.GetMethod(() => string.IsNullOrEmpty("a"));
+
+ var expr = Expression.Call(methodInfo, Expression.Constant("hello"));
+
+ var result = visitor.VisitExpression(expr);
+
+ Assert.That(result, Is.SameAs(expr));
+ }
+ }
+}
@@ -0,0 +1,45 @@
+using System.Linq.Expressions;
+using Remotion.Linq.Clauses.Expressions;
+using Remotion.Linq.Parsing;
+
+namespace Lucene.Net.Linq.Expressions
+{
+ internal class BoostBinaryExpression : ExtensionExpression
+ {
+ public const ExpressionType ExpressionType = (ExpressionType)150006;
+
+ private readonly BinaryExpression expression;
+ private readonly float boost;
+
+ public BoostBinaryExpression(BinaryExpression expression, float boost)
+ : base(expression.Type, ExpressionType)
+ {
+ this.expression = expression;
+ this.boost = boost;
+ }
+
+ public BinaryExpression BinaryExpression
+ {
+ get { return expression; }
+ }
+
+ public float Boost
+ {
+ get { return boost; }
+ }
+
+ public override string ToString()
+ {
+ return string.Format("{0}^{1}", BinaryExpression, Boost);
+ }
+
+ protected override Expression VisitChildren(ExpressionTreeVisitor visitor)
+ {
+ var newExpression = visitor.VisitExpression(BinaryExpression);
+
+ if (ReferenceEquals(BinaryExpression, newExpression)) return this;
+
+ return new BoostBinaryExpression((BinaryExpression) newExpression, Boost);
+ }
+ }
+}
@@ -45,6 +45,12 @@ public BooleanClause.Occur Occur
get { return occur; }
}
+ public float Boost
+ {
+ get { return field.Boost; }
+ set { field.Boost = value; }
+ }
+
public QueryType QueryType
{
get { return queryType; }
@@ -60,7 +66,7 @@ protected override Expression VisitChildren(ExpressionTreeVisitor visitor)
public override string ToString()
{
- return string.Format("LuceneQuery[{0}]({1}{2}:{3})", QueryType, Occur, QueryField.FieldName, pattern);
+ return string.Format("LuceneQuery[{0}]({1}{2}:{3}){4}", QueryType, Occur, QueryField.FieldName, pattern, Boost - 1.0f < 0.01f ? "" : "^" + Boost);
}
}
@@ -14,6 +14,7 @@ internal LuceneQueryFieldExpression(Type type, string fieldName)
: base(type, ExpressionType)
{
this.fieldName = fieldName;
+ Boost = 1;
}
internal LuceneQueryFieldExpression(Type type, ExpressionType expressionType, string fieldName)
@@ -29,6 +30,7 @@ protected override Expression VisitChildren(ExpressionTreeVisitor visitor)
}
public string FieldName { get { return fieldName; } }
+ public float Boost { get; set; }
public bool Equals(LuceneQueryFieldExpression other)
{
@@ -62,7 +64,12 @@ public override int GetHashCode()
public override string ToString()
{
- return "LuceneField(" + fieldName + ")";
+ var s = "LuceneField(" + fieldName + ")";
+ if (Math.Abs(Boost - 1.0f) > 0.01f)
+ {
+ return s + "^" + Boost;
+ }
+ return s;
}
}
}
@@ -60,6 +60,7 @@
<Compile Include="Context.cs" />
<Compile Include="Converters\DateTimeConverter.cs" />
<Compile Include="Converters\FormatConverter.cs" />
+ <Compile Include="Expressions\BoostBinaryExpression.cs" />
<Compile Include="Expressions\LuceneQueryAnyFieldExpression.cs" />
<Compile Include="Expressions\LuceneCompositeOrderingExpression.cs" />
<Compile Include="Expressions\LuceneQueryExpression.cs" />
@@ -85,6 +86,7 @@
<Compile Include="Search\FieldComparator.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="Search\Function\DelegatingCustomScoreQuery.cs" />
+ <Compile Include="Transformation\TreeVisitors\BoostMethodCallTreeVisitor.cs" />
<Compile Include="Transformation\TreeVisitors\LuceneExtensionMethodCallTreeVisitor.cs" />
<Compile Include="Transformation\TreeVisitors\SubQueryContainsTreeVisitor.cs" />
<Compile Include="Translation\LuceneQueryModel.cs" />
@@ -14,6 +14,6 @@
<tags>lucene.net lucene linq odata search nosql</tags>
<summary>Provides LINQ IQueryable interface over a Lucene.Net index.</summary>
<description>Execute LINQ queries on Lucene.Net complete with object to Document mapping.</description>
- <releaseNotes>Adds custom query boost functions. Adds support for First(), FirstOrDefault() and Any() queries.</releaseNotes>
+ <releaseNotes>Adds custom query boost functions. Adds support for First(), FirstOrDefault(), Last, LastOrDefault() and Any() queries.</releaseNotes>
</metadata>
</package>
@@ -62,6 +62,7 @@ public LuceneDataProvider(Directory directory, Analyzer analyzer, Version versio
this.version = version;
queryParser = QueryParser.CreateDefault();
+
context = new Context(this.directory, this.analyzer, this.version, indexWriter, transactionLock);
}
@@ -13,7 +13,7 @@ namespace Lucene.Net.Linq
/// </summary>
public static class LuceneMethods
{
- private const string UnreachableCode = "Unreachable code. This method should have been translated and not directly invoked.";
+ private const string UnreachableCode = "Unreachable code. This method should have been translated within a LINQ expression and not directly invoked.";
/// <summary>
/// Expression to be used in orderby clauses to sort results by score.
@@ -22,7 +22,7 @@ public static class LuceneMethods
/// </summary>
public static Expression Score<T>(this T mappedDocument)
{
- throw new NotImplementedException(UnreachableCode);
+ throw new InvalidOperationException(UnreachableCode);
}
///<summary>
@@ -31,7 +31,7 @@ public static Expression Score<T>(this T mappedDocument)
///</summary>
public static string AnyField<T>(this T mappedDocument)
{
- throw new NotImplementedException(UnreachableCode);
+ throw new InvalidOperationException(UnreachableCode);
}
/// <summary>
@@ -48,5 +48,13 @@ public static IQueryable<T> Boost<T>(this IQueryable<T> source, Func<T, float> b
return source;
}
+
+ /// <summary>
+ /// Applies a boost to a property in a where clause.
+ /// </summary>
+ public static T Boost<T>(this T property, float boostAmount)
+ {
+ throw new InvalidOperationException(UnreachableCode);
+ }
}
}
Oops, something went wrong.

0 comments on commit 4288da3

Please sign in to comment.