From eb99839cdad40fd32ed8cdaa2deb4766d1f56183 Mon Sep 17 00:00:00 2001 From: Craig Wilson Date: Mon, 31 May 2010 17:50:37 -0500 Subject: [PATCH] fixed issue with a flattening projection and still maintaining the original hierarchy. --- .../Linq/MongoQueryProviderTests.cs | 14 ++++++++++++++ .../Linq/Expressions/MongoExpressionExtensions.cs | 8 ++++++++ source/MongoDB/Linq/MongoQueryProvider.cs | 5 ++++- source/MongoDB/Linq/Translators/FieldBinder.cs | 12 +++++++++++- .../MapReduceFinalizerFunctionBuilder.cs | 2 ++ .../Translators/MapReduceMapFunctionBuilder.cs | 2 ++ .../Translators/MapReduceReduceFunctionBuilder.cs | 2 ++ .../Linq/Translators/MongoQueryObjectBuilder.cs | 2 +- 8 files changed, 44 insertions(+), 3 deletions(-) diff --git a/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryProviderTests.cs b/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryProviderTests.cs index b6bcee76..8d489b41 100644 --- a/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryProviderTests.cs +++ b/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryProviderTests.cs @@ -317,6 +317,20 @@ public void ProjectionWithConstraints() Assert.AreEqual(new Document("Age", new Document().Merge(Op.GreaterThan(21)).Merge(Op.LessThan(42))), queryObject.Query); } + [Test] + public void ProjectionWithLocalCreation_ChildobjectShouldNotBeNull() + { + var people = Collection.Linq() + .Select(p => new PersonWrapper(p, p.FirstName)); + + var queryObject = ((IMongoQueryable)people).GetQueryObject(); + Assert.AreEqual(0, queryObject.Fields.Count()); + Assert.AreEqual(0, queryObject.NumberToLimit); + Assert.AreEqual(0, queryObject.NumberToSkip); + Assert.AreEqual(0, queryObject.Query.Count); + + } + [Test] public void Regex_IsMatch() { diff --git a/source/MongoDB/Linq/Expressions/MongoExpressionExtensions.cs b/source/MongoDB/Linq/Expressions/MongoExpressionExtensions.cs index f72ca5fe..8c718ed5 100644 --- a/source/MongoDB/Linq/Expressions/MongoExpressionExtensions.cs +++ b/source/MongoDB/Linq/Expressions/MongoExpressionExtensions.cs @@ -8,6 +8,14 @@ namespace MongoDB.Linq.Expressions { internal static class MongoExpressionExtensions { + public static bool HasSelectAllField(this IEnumerable fields) + { + if (fields == null) + return true; + + return fields.Any(f => f.Name == "*"); + } + public static SelectExpression AddField(this SelectExpression select, FieldDeclaration field) { List fields = new List(select.Fields); diff --git a/source/MongoDB/Linq/MongoQueryProvider.cs b/source/MongoDB/Linq/MongoQueryProvider.cs index fdd5f033..ed471429 100644 --- a/source/MongoDB/Linq/MongoQueryProvider.cs +++ b/source/MongoDB/Linq/MongoQueryProvider.cs @@ -164,9 +164,12 @@ private Expression BuildExecutionPlan(Expression expression) private ProjectionExpression Translate(Expression expression) { + var rootQueryable = new RootQueryableFinder().Find(expression); + var elementType = ((IQueryable)((ConstantExpression)rootQueryable).Value).ElementType; + expression = PartialEvaluator.Evaluate(expression, CanBeEvaluatedLocally); - expression = new FieldBinder().Bind(expression); + expression = new FieldBinder().Bind(expression, elementType); expression = new QueryBinder(this, expression).Bind(expression); expression = new AggregateRewriter().Rewrite(expression); expression = new RedundantFieldRemover().Remove(expression); diff --git a/source/MongoDB/Linq/Translators/FieldBinder.cs b/source/MongoDB/Linq/Translators/FieldBinder.cs index 53e9f957..86fce6fd 100644 --- a/source/MongoDB/Linq/Translators/FieldBinder.cs +++ b/source/MongoDB/Linq/Translators/FieldBinder.cs @@ -17,11 +17,13 @@ internal class FieldBinder : ExpressionVisitor private Alias _alias; private FieldFinder _finder; + private Type _elementType; - public Expression Bind(Expression expression) + public Expression Bind(Expression expression, Type elementType) { _alias = new Alias(); _finder = new FieldFinder(); + _elementType = elementType; return Visit(expression); } @@ -37,6 +39,14 @@ protected override Expression Visit(Expression exp) return base.Visit(exp); } + protected override Expression VisitParameter(ParameterExpression p) + { + if(p.Type == _elementType) + return new FieldExpression(p, _alias, "*"); + + return base.VisitParameter(p); + } + private class FieldFinder : ExpressionVisitor { private Stack _fieldParts; diff --git a/source/MongoDB/Linq/Translators/MapReduceFinalizerFunctionBuilder.cs b/source/MongoDB/Linq/Translators/MapReduceFinalizerFunctionBuilder.cs index e6d1f279..b9d4bce2 100644 --- a/source/MongoDB/Linq/Translators/MapReduceFinalizerFunctionBuilder.cs +++ b/source/MongoDB/Linq/Translators/MapReduceFinalizerFunctionBuilder.cs @@ -57,6 +57,8 @@ protected override ReadOnlyCollection VisitFieldDeclarationLis for (int i = 0, n = fields.Count; i < n; i++) { _currentAggregateName = fields[i].Name; + if (_currentAggregateName == "*") + continue; Visit(fields[i].Expression); } diff --git a/source/MongoDB/Linq/Translators/MapReduceMapFunctionBuilder.cs b/source/MongoDB/Linq/Translators/MapReduceMapFunctionBuilder.cs index 2b8c30c7..95dac4f7 100644 --- a/source/MongoDB/Linq/Translators/MapReduceMapFunctionBuilder.cs +++ b/source/MongoDB/Linq/Translators/MapReduceMapFunctionBuilder.cs @@ -66,6 +66,8 @@ protected override ReadOnlyCollection VisitFieldDeclarationLis for (int i = 0, n = fields.Count; i < n; i++) { _currentAggregateName = fields[i].Name; + if (_currentAggregateName == "*") + continue; Visit(fields[i].Expression); } diff --git a/source/MongoDB/Linq/Translators/MapReduceReduceFunctionBuilder.cs b/source/MongoDB/Linq/Translators/MapReduceReduceFunctionBuilder.cs index 57527b0e..264c63f0 100644 --- a/source/MongoDB/Linq/Translators/MapReduceReduceFunctionBuilder.cs +++ b/source/MongoDB/Linq/Translators/MapReduceReduceFunctionBuilder.cs @@ -76,6 +76,8 @@ protected override ReadOnlyCollection VisitFieldDeclarationLis for (int i = 0, n = fields.Count; i < n; i++) { _currentAggregateName = fields[i].Name; + if (_currentAggregateName == "*") + continue; Visit(fields[i].Expression); } diff --git a/source/MongoDB/Linq/Translators/MongoQueryObjectBuilder.cs b/source/MongoDB/Linq/Translators/MongoQueryObjectBuilder.cs index de3a4561..6f4710a0 100644 --- a/source/MongoDB/Linq/Translators/MongoQueryObjectBuilder.cs +++ b/source/MongoDB/Linq/Translators/MongoQueryObjectBuilder.cs @@ -49,7 +49,7 @@ protected override Expression VisitSelect(SelectExpression select) _queryObject.ReduceFunction = new MapReduceReduceFunctionBuilder().Build(select.Fields); _queryObject.FinalizerFunction = new MapReduceFinalizerFunctionBuilder().Build(select.Fields); } - else if(!_queryAttributes.IsCount) + else if(!_queryAttributes.IsCount && !select.Fields.HasSelectAllField()) { var fieldGatherer = new FieldGatherer(); foreach (var field in select.Fields)