diff --git a/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryTests.cs b/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryTests.cs index 59b55073..08e3e09e 100644 --- a/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryTests.cs +++ b/source/MongoDB.Tests/IntegrationTests/Linq/MongoQueryTests.cs @@ -350,5 +350,16 @@ public void Complex_Addition() Assert.AreEqual(1, people.Count); } + + [Test] + public void Join() + { + var people = Enumerable.ToList( + from p in collection.Linq() + join op in collection.Linq() on p.PrimaryAddress equals op.PrimaryAddress + select p); + + Assert.AreEqual(0, people.Count); + } } } \ No newline at end of file diff --git a/source/MongoDB/Linq/ExecutionBuilder.cs b/source/MongoDB/Linq/ExecutionBuilder.cs index 44848a15..522bc082 100644 --- a/source/MongoDB/Linq/ExecutionBuilder.cs +++ b/source/MongoDB/Linq/ExecutionBuilder.cs @@ -18,7 +18,6 @@ internal class ExecutionBuilder : MongoExpressionVisitor private int _numCursors; private Expression _provider; private MemberInfo _receivingMember; - private Scope _scope; private List _variables; @@ -80,6 +79,11 @@ protected override Expression VisitClientJoin(ClientJoinExpression clientJoin) return access; } + protected override Expression VisitField(FieldExpression field) + { + return Visit(field.Expression); + } + protected override Expression VisitProjection(ProjectionExpression projection) { if (_isTop) @@ -119,7 +123,6 @@ private Expression Build(Expression expression) private Expression BuildInner(Expression expression) { var builder = new ExecutionBuilder(); - builder._scope = _scope; builder._receivingMember = _receivingMember; builder._numCursors = _numCursors; builder._lookup = _lookup; @@ -128,33 +131,17 @@ private Expression BuildInner(Expression expression) private Expression ExecuteProjection(ProjectionExpression projection) { - projection = (ProjectionExpression)new Parameterizer().Parameterize(projection); - - if(_scope != null) - projection = (ProjectionExpression)new OuterParameterizer().Parameterize(projection, _scope.Alias); - - var saveScope = _scope; - var document = Expression.Parameter(projection.Projector.Type, "d" + (_numCursors++)); - _scope = new Scope(_scope, document, projection.Source.Alias, projection.Source.Fields); - var projector = Expression.Lambda(Visit(projection.Projector), document); - _scope = saveScope; - + var projection = base.VisitProjection(projection); var queryObject = new MongoQueryObjectBuilder().Build(projection); - queryObject.Projector = new ProjectionBuilder().Build(queryObject, projector); - - var namedValues = new NamedValueGatherer().Gather(projection.Source); - var names = namedValues.Select(v => v.Name).ToArray(); - var values = namedValues.Select(v => Expression.Convert(Visit(v.Value), typeof(object))).ToArray(); + queryObject.Projector = new ProjectionBuilder().Build(projection.Projector, queryObject.DocumentType, "d" + (_numCursors++), queryObject.IsMapReduce); + queryObject.Aggregator = projection.Aggregator; Expression result = Expression.Call( _provider, "ExecuteQueryObject", - new[] { queryObject.DocumentType, queryObject.Projector.Type }, + Type.EmptyTypes, Expression.Constant(queryObject, typeof(MongoQueryObject))); - if(projection.Aggregator != null) - result = new ExpressionReplacer().Replace(projection.Aggregator.Body, projection.Aggregator.Parameters[0], result); - return result; } @@ -192,74 +179,6 @@ private static Expression MakeSequence(IList expressions) return Expression.Convert(Expression.Call(typeof(ExecutionBuilder), "Sequence", null, Expression.NewArrayInit(typeof(object), expressions)), last.Type); } - private class Scope - { - private ParameterExpression _document; - private Scope _outer; - private Dictionary _nameMap; - - internal Alias Alias { get; private set; } - - public Scope(Scope outer, ParameterExpression document, Alias alias, IEnumerable fields) - { - _outer = outer; - _document = document; - Alias = alias; - _nameMap = fields.Select((f, i) => new { f, i }).ToDictionary(x => x.f.Name, x => x.i); - } - - public bool TryGetValue(FieldExpression field, out ParameterExpression document, out int ordinal) - { - for (Scope s = this; s != null; s = s._outer) - { - if (field.Alias == s.Alias && _nameMap.TryGetValue(field.Name, out ordinal)) - { - document = _document; - return true; - } - } - document = null; - ordinal = 0; - return false; - } - } - - private class OuterParameterizer : MongoExpressionVisitor - { - private int _paramIndex; - private Alias _outerAlias; - private Dictionary _map; - - public Expression Parameterize(Expression expression, Alias outerAlias) - { - _outerAlias = outerAlias; - return Visit(expression); - } - - protected override Expression VisitProjection(ProjectionExpression projection) - { - SelectExpression select = (SelectExpression)Visit(projection.Source); - if (select != projection.Source) - return new ProjectionExpression(select, projection.Projector, projection.Aggregator); - return projection; - } - - protected override Expression VisitField(FieldExpression field) - { - if (field.Alias == _outerAlias) - { - NamedValueExpression nv; - if (!_map.TryGetValue(field, out nv)) - { - nv = new NamedValueExpression("n" + (_paramIndex++), field); - _map.Add(field, nv); - } - return nv; - } - return field; - } - } - private class CompoundKey : IEquatable { private object[] _values; diff --git a/source/MongoDB/Linq/MongoQueryObject.cs b/source/MongoDB/Linq/MongoQueryObject.cs index 28716f37..f83d9c09 100644 --- a/source/MongoDB/Linq/MongoQueryObject.cs +++ b/source/MongoDB/Linq/MongoQueryObject.cs @@ -12,6 +12,12 @@ internal class MongoQueryObject private Document _query; private Document _sort; + /// + /// Gets or sets the aggregator. + /// + /// The aggregator. + public LambdaExpression Aggregator { get; set; } + /// /// Gets or sets the name of the collection. /// diff --git a/source/MongoDB/Linq/MongoQueryProvider.cs b/source/MongoDB/Linq/MongoQueryProvider.cs index 9a139659..02bed016 100644 --- a/source/MongoDB/Linq/MongoQueryProvider.cs +++ b/source/MongoDB/Linq/MongoQueryProvider.cs @@ -138,12 +138,12 @@ internal MongoQueryObject GetQueryObject(Expression expression) /// /// The query object. /// - internal IEnumerable ExecuteQueryObject(MongoQueryObject queryObject){ + internal object ExecuteQueryObject(MongoQueryObject queryObject){ if (queryObject.IsCount) - return ExecuteCount(queryObject); + return ExecuteCount(queryObject); else if (queryObject.IsMapReduce) - return ExecuteMapReduce(queryObject); - return ExecuteFind(queryObject); + return ExecuteMapReduce(queryObject); + return ExecuteFind(queryObject); } private Expression BuildExecutionPlan(Expression expression) @@ -220,20 +220,18 @@ private bool CanBeEvaluatedLocally(Expression expression) /// /// The query object. /// - private IEnumerable ExecuteCount(MongoQueryObject queryObject) + private object ExecuteCount(MongoQueryObject queryObject) { var miGetCollection = typeof(IMongoDatabase).GetMethods().Where(m => m.Name == "GetCollection" && m.GetGenericArguments().Length == 1 && m.GetParameters().Length == 1).Single().MakeGenericMethod(queryObject.DocumentType); var collection = miGetCollection.Invoke(queryObject.Database, new[] { queryObject.CollectionName }); - IEnumerable documents; if (queryObject.Query == null) - documents = new[] { (TDocument)collection.GetType().GetMethod("Count", Type.EmptyTypes).Invoke(collection, null) }; - documents = new[] { (TDocument)collection.GetType().GetMethod("Count", new[] { typeof(object) }).Invoke(collection, new[] { queryObject.Query }) }; + return Convert.ToInt32(collection.GetType().GetMethod("Count", Type.EmptyTypes).Invoke(collection, null)); - return Project(documents, (Func)queryObject.Projector.Compile()); + return Convert.ToInt32(collection.GetType().GetMethod("Count", new[] { typeof(object) }).Invoke(collection, new[] { queryObject.Query })); } - private IEnumerable ExecuteFind(MongoQueryObject queryObject) + private object ExecuteFind(MongoQueryObject queryObject) { var miGetCollection = typeof(IMongoDatabase).GetMethods().Where(m => m.Name == "GetCollection" && m.GetGenericArguments().Length == 1 && m.GetParameters().Length == 1).Single().MakeGenericMethod(queryObject.DocumentType); var collection = miGetCollection.Invoke(queryObject.Database, new[] { queryObject.CollectionName }); @@ -256,11 +254,11 @@ private bool CanBeEvaluatedLocally(Expression expression) cursorType.GetMethod("Limit").Invoke(cursor, new object[] { queryObject.NumberToLimit }); cursorType.GetMethod("Skip").Invoke(cursor, new object[] { queryObject.NumberToSkip }); - var documents = (IEnumerable)cursor.GetType().GetProperty("Documents").GetValue(cursor, null); - return Project(documents, (Func)queryObject.Projector.Compile()); + var executor = GetExecutor(queryObject.DocumentType, queryObject.Projector, queryObject.Aggregator, true); + return executor.Compile().DynamicInvoke(cursor.GetType().GetProperty("Documents").GetValue(cursor, null)); } - private IEnumerable ExecuteMapReduce(MongoQueryObject queryObject) + private object ExecuteMapReduce(MongoQueryObject queryObject) { var miGetCollection = typeof(IMongoDatabase).GetMethods().Where(m => m.Name == "GetCollection" && m.GetGenericArguments().Length == 1 && m.GetParameters().Length == 1).Single().MakeGenericMethod(queryObject.DocumentType); var collection = miGetCollection.Invoke(queryObject.Database, new[] { queryObject.CollectionName }); @@ -271,35 +269,41 @@ private bool CanBeEvaluatedLocally(Expression expression) mapReduce.Finalize = new Code(queryObject.FinalizerFunction); mapReduce.Query = queryObject.Query; - if (queryObject.Sort != null) + if(queryObject.Sort != null) mapReduce.Sort = queryObject.Sort; mapReduce.Limit = queryObject.NumberToLimit; if (queryObject.NumberToSkip != 0) throw new InvalidQueryException("MapReduce queries do no support Skips."); - var documents = (IEnumerable)mapReduce.Documents; - return Project(documents, (Func)queryObject.Projector.Compile()); - } - - private IEnumerable Project(IEnumerable documents, Func projector) - { - foreach (var doc in documents) - { - yield return projector(doc); - } + var executor = GetExecutor(typeof(Document), queryObject.Projector, queryObject.Aggregator, true); + return executor.Compile().DynamicInvoke(mapReduce.Documents); } - private static LambdaExpression GetExecutor(Type documentType, LambdaExpression projector, bool boxReturn) + private static LambdaExpression GetExecutor(Type documentType, LambdaExpression projector, LambdaExpression aggregator, bool boxReturn) { var documents = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(documentType), "documents"); - Expression body = Expression.New(typeof(ProjectionReader<,>).MakeGenericType(documentType, projector.Body.Type).GetConstructors()[0], documents, projector); + Expression body = Expression.Call( + typeof(MongoQueryProvider), + "Project", + new[] { documentType, projector.Body.Type }, + documents, + projector); + if (aggregator != null) + body = Expression.Invoke(aggregator, body); + if (boxReturn && body.Type != typeof(object)) body = Expression.Convert(body, typeof(object)); return Expression.Lambda(body, documents); } + private static IEnumerable Project(IEnumerable documents, Func projector) + { + foreach (var doc in documents) + yield return projector(doc); + } + private class RootQueryableFinder : MongoExpressionVisitor { private Expression _root; diff --git a/source/MongoDB/Linq/ProjectionReader.cs b/source/MongoDB/Linq/ProjectionReader.cs deleted file mode 100644 index b3ce21b8..00000000 --- a/source/MongoDB/Linq/ProjectionReader.cs +++ /dev/null @@ -1,70 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Linq.Expressions; -using System.Reflection; -using System.Text; - -namespace MongoDB.Linq -{ - internal class ProjectionReader : IEnumerable - { - private Enumerator _enumerator; - - public ProjectionReader(IEnumerable documents, Func projector) - { - _enumerator = new Enumerator(documents.GetEnumerator(), projector); - } - - public IEnumerator GetEnumerator() - { - var e = _enumerator; - if (e == null) - throw new InvalidOperationException("Cannot enumerate more than once."); - _enumerator = null; - return e; - } - - System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() - { - return GetEnumerator(); - } - - private class Enumerator : IEnumerator, IDisposable - { - private IEnumerator _cursorEnumerator; - private Func _projector; - - public TResult Current - { - get { return _projector(_cursorEnumerator.Current); } - } - - object System.Collections.IEnumerator.Current - { - get { return Current; } - } - - public Enumerator(IEnumerator enumerator, Func projector) - { - _cursorEnumerator = enumerator; - _projector = projector; - } - - public void Dispose() - { - _cursorEnumerator.Dispose(); - } - - public bool MoveNext() - { - return _cursorEnumerator.MoveNext(); - } - - public void Reset() - { - _cursorEnumerator.Reset(); - } - } - } -} diff --git a/source/MongoDB/Linq/Translators/ClientJoinProjectionRewriter.cs b/source/MongoDB/Linq/Translators/ClientJoinProjectionRewriter.cs index f929cf0b..523357c8 100644 --- a/source/MongoDB/Linq/Translators/ClientJoinProjectionRewriter.cs +++ b/source/MongoDB/Linq/Translators/ClientJoinProjectionRewriter.cs @@ -81,7 +81,7 @@ protected override Expression VisitSubquery(SubqueryExpression subquery) private bool CanJoinOnClient(SelectExpression select) { return !select.IsDistinct - && select.GroupBy != null + && select.GroupBy == null && !new AggregateChecker().HasAggregates(select); } diff --git a/source/MongoDB/Linq/Translators/Parameterizer.cs b/source/MongoDB/Linq/Translators/Parameterizer.cs deleted file mode 100644 index bee7b4c7..00000000 --- a/source/MongoDB/Linq/Translators/Parameterizer.cs +++ /dev/null @@ -1,86 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Collections.ObjectModel; -using System.Linq; -using System.Linq.Expressions; -using System.Text; - -using MongoDB.Linq.Expressions; - -namespace MongoDB.Linq.Translators -{ - internal class Parameterizer : MongoExpressionVisitor - { - private int _paramIndex; - private Dictionary _map; - private Dictionary _pmap; - - public Expression Parameterize(Expression expression) - { - return Visit(expression); - } - - protected override Expression VisitConstant(ConstantExpression c) - { - if (c.Value != null && !IsNumeric(c.Value.GetType())) - { - NamedValueExpression nv; - if (!_map.TryGetValue(c.Value, out nv)) - { - var name = "p" + (_paramIndex++); - nv = new NamedValueExpression(name, c); - _map.Add(c.Value, nv); - } - return nv; - } - return c; - } - - protected override Expression VisitParameter(ParameterExpression p) - { - return GetNamedValue(p); - } - - protected override Expression VisitProjection(ProjectionExpression projection) - { - SelectExpression select = (SelectExpression)Visit(projection.Source); - if (select != projection.Source) - return new ProjectionExpression(select, projection.Projector, projection.Aggregator); - return projection; - } - - private Expression GetNamedValue(Expression e) - { - NamedValueExpression nv; - if (!_pmap.TryGetValue(e, out nv)) - { - string name = "p" + (_paramIndex++); - nv = new NamedValueExpression(name, e); - _pmap.Add(e, nv); - } - return nv; - } - - private static bool IsNumeric(Type type) - { - switch (Type.GetTypeCode(type)) - { - case TypeCode.Boolean: - case TypeCode.Byte: - case TypeCode.Decimal: - case TypeCode.Double: - case TypeCode.Int16: - case TypeCode.Int32: - case TypeCode.Int64: - case TypeCode.SByte: - case TypeCode.Single: - case TypeCode.UInt16: - case TypeCode.UInt32: - case TypeCode.UInt64: - return true; - default: - return false; - } - } - } -} \ No newline at end of file diff --git a/source/MongoDB/Linq/Translators/ProjectionBuilder.cs b/source/MongoDB/Linq/Translators/ProjectionBuilder.cs index 7ddb8d97..13f65c34 100644 --- a/source/MongoDB/Linq/Translators/ProjectionBuilder.cs +++ b/source/MongoDB/Linq/Translators/ProjectionBuilder.cs @@ -11,7 +11,7 @@ namespace MongoDB.Linq.Translators { internal class ProjectionBuilder : MongoExpressionVisitor { - private MongoQueryObject _queryObject; + private bool _isMapReduce; private ParameterExpression _document; private GroupingKeyDeterminer _determiner; @@ -20,34 +20,33 @@ public ProjectionBuilder() _determiner = new GroupingKeyDeterminer(); } - public LambdaExpression Build(MongoQueryObject queryObject, Expression projector) + public LambdaExpression Build(Expression projector, Type documentType, string parameterName, bool isMapReduce) { - _queryObject = queryObject; - if (_queryObject.IsMapReduce) - _document = Expression.Parameter(typeof(Document), "document"); + _isMapReduce = isMapReduce; + if (_isMapReduce) + _document = Expression.Parameter(typeof(Document), parameterName); else - _document = Expression.Parameter(queryObject.DocumentType, "document"); + _document = Expression.Parameter(documentType, parameterName); - var body = Visit(projector); - return Expression.Lambda(body, _document); + return Expression.Lambda(Visit(projector), _document); } protected override Expression VisitField(FieldExpression field) { - if (_queryObject.IsMapReduce) + if (_isMapReduce) { var parts = field.Name.Split('.'); bool isGroupingField = _determiner.IsGroupingKey(field); Expression current; - if(parts.Contains("Key") && isGroupingField) + if (parts.Contains("Key") && isGroupingField) current = _document; else { current = Expression.Call( _document, "Get", - new [] { typeof(Document) }, + new[] { typeof(Document) }, Expression.Constant("value")); } @@ -73,10 +72,7 @@ protected override Expression VisitField(FieldExpression field) protected override Expression VisitParameter(ParameterExpression p) { - if (p.Type == _document.Type) - return _document; - - return p; + return _document; } private class GroupingKeyDeterminer : MongoExpressionVisitor diff --git a/source/MongoDB/Linq/Translators/QueryBinder.cs b/source/MongoDB/Linq/Translators/QueryBinder.cs index c1fb8ff0..797b57f5 100644 --- a/source/MongoDB/Linq/Translators/QueryBinder.cs +++ b/source/MongoDB/Linq/Translators/QueryBinder.cs @@ -363,19 +363,36 @@ private Expression BindJoin(Type resultType, Expression outerSource, Expression { var outerProjection = VisitSequence(outerSource); var innerProjection = VisitSequence(innerSource); - _map[outerKey.Parameters[0]] = outerProjection.Projector; - var outerKeyExpression = Visit(outerKey.Body); - _map[innerKey.Parameters[0]] = innerProjection.Projector; - var innerKeyExpression = Visit(innerKey.Body); - _map[resultSelector.Parameters[0]] = outerProjection.Projector; - _map[resultSelector.Parameters[1]] = innerProjection.Projector; + + var join = Expression.Call( //change this to a client-side join... + typeof(Enumerable), + "Join", + new [] { outerProjection.Projector.Type, innerProjection.Projector.Type, outerKey.Body.Type, resultSelector.Body.Type }, + outerProjection, + innerProjection, + outerKey, + innerKey, + resultSelector); + var resultExpression = Visit(resultSelector.Body); - var join = new JoinExpression(JoinType.InnerJoin, outerProjection.Source, innerProjection.Source, Expression.Equal(outerKeyExpression, innerKeyExpression)); var alias = new Alias(); - var fieldProjections = _projector.ProjectFields(resultExpression, alias, outerProjection.Source.Alias, innerProjection.Source.Alias); + var fieldProjection = _projector.ProjectFields(resultExpression, outerProjection.Source.Alias, innerProjection.Source.Alias); return new ProjectionExpression( - new SelectExpression(alias, fieldProjections.Fields, join, null), - fieldProjections.Projector); + new SelectExpression(alias, fieldProjection.Fields, join, null), + fieldProjection.Projector); + //_map[outerKey.Parameters[0]] = outerProjection.Projector; + //var outerKeyExpression = Visit(outerKey.Body); + //_map[innerKey.Parameters[0]] = innerProjection.Projector; + //var innerKeyExpression = Visit(innerKey.Body); + //_map[resultSelector.Parameters[0]] = outerProjection.Projector; + //_map[resultSelector.Parameters[1]] = innerProjection.Projector; + //var resultExpression = Visit(resultSelector.Body); + //var join = new JoinExpression(JoinType.InnerJoin, outerProjection.Source, innerProjection.Source, Expression.Equal(outerKeyExpression, innerKeyExpression)); + //var alias = new Alias(); + //var fieldProjections = _projector.ProjectFields(resultExpression, alias, outerProjection.Source.Alias, innerProjection.Source.Alias); + //return new ProjectionExpression( + // new SelectExpression(alias, fieldProjections.Fields, join, null), + // fieldProjections.Projector); } private Expression BindOrderBy(Type resultType, Expression source, LambdaExpression orderSelector, OrderType orderType) diff --git a/source/MongoDB/MongoDB.csproj b/source/MongoDB/MongoDB.csproj index 2bbfe8b3..d413289e 100644 --- a/source/MongoDB/MongoDB.csproj +++ b/source/MongoDB/MongoDB.csproj @@ -130,7 +130,6 @@ - @@ -180,7 +179,6 @@ -