From c1b238ece31c13f40ea38ab5213e9607148e9f87 Mon Sep 17 00:00:00 2001 From: Craig Wilson Date: Fri, 30 Apr 2010 20:35:23 -0500 Subject: [PATCH] wip --- source/MongoDB/Linq/ExecutionBuilder.cs | 266 +++++++++++++++++- .../Linq/Expressions/MongoExpressionType.cs | 3 +- .../Expressions/MongoExpressionVisitor.cs | 29 +- .../Linq/Expressions/NamedValueExpression.cs | 31 ++ source/MongoDB/Linq/MongoQueryObject.cs | 6 - source/MongoDB/Linq/MongoQueryProvider.cs | 157 ++++++----- .../Linq/Translators/ExpressionReplacer.cs | 38 +++ .../Translators/MongoQueryObjectBuilder.cs | 4 - .../Linq/Translators/NamedValueGatherer.cs | 29 ++ .../MongoDB/Linq/Translators/Parameterizer.cs | 86 ++++++ .../MongoDB/Linq/Translators/QueryBinder.cs | 8 +- source/MongoDB/MongoDB.csproj | 4 + 12 files changed, 563 insertions(+), 98 deletions(-) create mode 100644 source/MongoDB/Linq/Expressions/NamedValueExpression.cs create mode 100644 source/MongoDB/Linq/Translators/ExpressionReplacer.cs create mode 100644 source/MongoDB/Linq/Translators/NamedValueGatherer.cs create mode 100644 source/MongoDB/Linq/Translators/Parameterizer.cs diff --git a/source/MongoDB/Linq/ExecutionBuilder.cs b/source/MongoDB/Linq/ExecutionBuilder.cs index 2941e35a..44848a15 100644 --- a/source/MongoDB/Linq/ExecutionBuilder.cs +++ b/source/MongoDB/Linq/ExecutionBuilder.cs @@ -6,33 +6,297 @@ using System.Text; using MongoDB.Linq.Expressions; +using MongoDB.Linq.Translators; namespace MongoDB.Linq { - public class ExecutionBuilder : MongoExpressionVisitor + internal class ExecutionBuilder : MongoExpressionVisitor { private List _initializers; private bool _isTop; + private int _lookup; private int _numCursors; private Expression _provider; private MemberInfo _receivingMember; + private Scope _scope; private List _variables; public Expression Build(Expression expression, Expression provider) { + _initializers = new List(); + _variables = new List(); + _isTop = true; _provider = provider; + return Build(expression); + } + + protected override MemberBinding VisitBinding(MemberBinding binding) + { + var save = _receivingMember; + _receivingMember = binding.Member; + var result = base.VisitBinding(binding); + _receivingMember = save; + return result; + } + + protected override Expression VisitClientJoin(ClientJoinExpression clientJoin) + { + var innerKey = MakeJoinKey(clientJoin.InnerKey); + var outerKey = MakeJoinKey(clientJoin.OuterKey); + + var pairConstructor = typeof(KeyValuePair<,>).MakeGenericType(innerKey.Type, clientJoin.Projection.Projector.Type).GetConstructor(new[] { innerKey.Type, clientJoin.Projection.Projector.Type }); + var constructPair = Expression.New(pairConstructor, innerKey, clientJoin.Projection.Projector); + var newProjection = new ProjectionExpression(clientJoin.Projection.Source, constructPair); + + int lookupIndex = _lookup++; + var execution = ExecuteProjection(newProjection); + + var pair = Expression.Parameter(constructPair.Type, "pair"); + + if (clientJoin.Projection.Projector.NodeType == (ExpressionType)MongoExpressionType.OuterJoined) + { + var lambda = Expression.Lambda( + Expression.NotEqual( + Expression.PropertyOrField(pair, "Value"), + Expression.Constant(null, clientJoin.Projection.Projector.Type)), + pair); + execution = Expression.Call(typeof(Enumerable), "Where", new[] { pair.Type }, execution, lambda); + } + + var keySelector = Expression.Lambda(Expression.PropertyOrField(pair, "Key"), pair); + var elementSelector = Expression.Lambda(Expression.PropertyOrField(pair, "Value"), pair); + var toLookup = Expression.Call(typeof(Enumerable), "ToLookup", new [] { pair.Type, outerKey.Type, clientJoin.Projection.Projector.Type }, execution, keySelector, elementSelector); + + var lookup = Expression.Parameter(toLookup.Type, "lookup" + lookupIndex); + var prop = lookup.Type.GetProperty("Item"); + Expression access = Expression.Call(lookup, prop.GetGetMethod(), Visit(outerKey)); + if(clientJoin.Projection.Aggregator != null) + access = new ExpressionReplacer().Replace(clientJoin.Projection.Aggregator.Body, clientJoin.Projection.Aggregator.Parameters[0], access); + + _variables.Add(lookup); + _initializers.Add(toLookup); + + return access; + } + + protected override Expression VisitProjection(ProjectionExpression projection) + { + if (_isTop) + { + _isTop = false; + return ExecuteProjection(projection); + } + else + { + return BuildInner(projection); + } + } + + private Expression AddVariables(Expression expression) + { + if (_variables.Count > 0) + { + var expressions = new List(); + for (int i = 0, n = _variables.Count; i < n; i++) + expressions.Add(MakeAssign(_variables[i], _initializers[i])); + + var sequence = MakeSequence(expressions); + + var nulls = _variables.Select(v => Expression.Constant(null, v.Type)).ToArray(); + expression = Expression.Invoke(Expression.Lambda(sequence, _variables.ToArray()), nulls); + } + return expression; + } + + private Expression Build(Expression expression) + { expression = Visit(expression); expression = AddVariables(expression); return expression; } + private Expression BuildInner(Expression expression) + { + var builder = new ExecutionBuilder(); + builder._scope = _scope; + builder._receivingMember = _receivingMember; + builder._numCursors = _numCursors; + builder._lookup = _lookup; + return builder.Build(expression); + } + + private Expression ExecuteProjection(ProjectionExpression projection) + { + projection = (ProjectionExpression)new Parameterizer().Parameterize(projection); + + if(_scope != null) + projection = (ProjectionExpression)new OuterParameterizer().Parameterize(projection, _scope.Alias); + + var saveScope = _scope; + var document = Expression.Parameter(projection.Projector.Type, "d" + (_numCursors++)); + _scope = new Scope(_scope, document, projection.Source.Alias, projection.Source.Fields); + var projector = Expression.Lambda(Visit(projection.Projector), document); + _scope = saveScope; + + var queryObject = new MongoQueryObjectBuilder().Build(projection); + queryObject.Projector = new ProjectionBuilder().Build(queryObject, projector); + + var namedValues = new NamedValueGatherer().Gather(projection.Source); + var names = namedValues.Select(v => v.Name).ToArray(); + var values = namedValues.Select(v => Expression.Convert(Visit(v.Value), typeof(object))).ToArray(); + + Expression result = Expression.Call( + _provider, + "ExecuteQueryObject", + new[] { queryObject.DocumentType, queryObject.Projector.Type }, + Expression.Constant(queryObject, typeof(MongoQueryObject))); + + if(projection.Aggregator != null) + result = new ExpressionReplacer().Replace(projection.Aggregator.Body, projection.Aggregator.Parameters[0], result); + + return result; + } + + public static T Assign(ref T variable, T value) + { + variable = value; + return value; + } + public static object Sequence(params object[] values) + { + return values[values.Length - 1]; + } + + private static Expression MakeAssign(ParameterExpression variable, Expression value) + { + return Expression.Call(typeof(ExecutionBuilder), "Assign", new[] { variable.Type }, variable, value); + } + private static Expression MakeJoinKey(IList key) + { + if (key.Count == 1) + return key[0]; + else + { + return Expression.New( + typeof(CompoundKey).GetConstructors()[0], + Expression.NewArrayInit(typeof(object), key.Select(k => (Expression)Expression.Convert(k, typeof(object))))); + } + } + + private static Expression MakeSequence(IList expressions) + { + var last = expressions[expressions.Count - 1]; + return Expression.Convert(Expression.Call(typeof(ExecutionBuilder), "Sequence", null, Expression.NewArrayInit(typeof(object), expressions)), last.Type); + } private class Scope { + private ParameterExpression _document; + private Scope _outer; + private Dictionary _nameMap; + + internal Alias Alias { get; private set; } + + public Scope(Scope outer, ParameterExpression document, Alias alias, IEnumerable fields) + { + _outer = outer; + _document = document; + Alias = alias; + _nameMap = fields.Select((f, i) => new { f, i }).ToDictionary(x => x.f.Name, x => x.i); + } + + public bool TryGetValue(FieldExpression field, out ParameterExpression document, out int ordinal) + { + for (Scope s = this; s != null; s = s._outer) + { + if (field.Alias == s.Alias && _nameMap.TryGetValue(field.Name, out ordinal)) + { + document = _document; + return true; + } + } + document = null; + ordinal = 0; + return false; + } + } + + private class OuterParameterizer : MongoExpressionVisitor + { + private int _paramIndex; + private Alias _outerAlias; + private Dictionary _map; + + public Expression Parameterize(Expression expression, Alias outerAlias) + { + _outerAlias = outerAlias; + return Visit(expression); + } + + protected override Expression VisitProjection(ProjectionExpression projection) + { + SelectExpression select = (SelectExpression)Visit(projection.Source); + if (select != projection.Source) + return new ProjectionExpression(select, projection.Projector, projection.Aggregator); + return projection; + } + + protected override Expression VisitField(FieldExpression field) + { + if (field.Alias == _outerAlias) + { + NamedValueExpression nv; + if (!_map.TryGetValue(field, out nv)) + { + nv = new NamedValueExpression("n" + (_paramIndex++), field); + _map.Add(field, nv); + } + return nv; + } + return field; + } + } + + private class CompoundKey : IEquatable + { + private object[] _values; + private int _hashCode; + + public CompoundKey(params object[] values) + { + _values = values; + for (int i = 0, n = values.Length; i < n; i++) + { + object value = values[i]; + if (value != null) + _hashCode ^= (value.GetHashCode() + i); + } + } + + public override int GetHashCode() + { + return _hashCode; + } + + public override bool Equals(object obj) + { + return base.Equals(obj); + } + public bool Equals(CompoundKey other) + { + if (other == null || other._values.Length != _values.Length) + return false; + for (int i = 0, n = other._values.Length; i < n; i++) + { + if (!object.Equals(_values[i], other._values[i])) + return false; + } + return true; + } } } } \ No newline at end of file diff --git a/source/MongoDB/Linq/Expressions/MongoExpressionType.cs b/source/MongoDB/Linq/Expressions/MongoExpressionType.cs index 38dd7b35..f3ff6a63 100644 --- a/source/MongoDB/Linq/Expressions/MongoExpressionType.cs +++ b/source/MongoDB/Linq/Expressions/MongoExpressionType.cs @@ -17,6 +17,7 @@ internal enum MongoExpressionType Aggregate, AggregateSubquery, Scalar, - OuterJoined + OuterJoined, + NamedValue } } \ No newline at end of file diff --git a/source/MongoDB/Linq/Expressions/MongoExpressionVisitor.cs b/source/MongoDB/Linq/Expressions/MongoExpressionVisitor.cs index 470789fb..0f3ff63a 100644 --- a/source/MongoDB/Linq/Expressions/MongoExpressionVisitor.cs +++ b/source/MongoDB/Linq/Expressions/MongoExpressionVisitor.cs @@ -35,6 +35,8 @@ protected override Expression Visit(Expression exp) return VisitClientJoin((ClientJoinExpression)exp); case MongoExpressionType.OuterJoined: return VisitOuterJoined((OuterJoinedExpression)exp); + case MongoExpressionType.NamedValue: + return VisitNamedValue((NamedValueExpression)exp); default: return base.Visit(exp); } @@ -92,18 +94,9 @@ protected virtual Expression VisitJoin(JoinExpression join) return join; } - protected virtual Expression VisitSelect(SelectExpression select) + protected virtual Expression VisitNamedValue(NamedValueExpression namedValue) { - var from = VisitSource(select.From); - var where = Visit(select.Where); - var groupBy = Visit(select.GroupBy); - var orderBy = VisitOrderBy(select.OrderBy); - var skip = Visit(select.Skip); - var take = Visit(select.Take); - var fields = VisitFieldDeclarationList(select.Fields); - if (from != select.From || where != select.Where || orderBy != select.OrderBy || groupBy != select.GroupBy || skip != select.Skip || take != select.Take || fields != select.Fields) - return new SelectExpression(select.Alias, fields, from, where, orderBy, groupBy, select.IsDistinct, skip, take); - return select; + return namedValue; } protected virtual Expression VisitProjection(ProjectionExpression projection) @@ -152,6 +145,20 @@ protected virtual Expression VisitScalar(ScalarExpression scalar) return scalar; } + protected virtual Expression VisitSelect(SelectExpression select) + { + var from = VisitSource(select.From); + var where = Visit(select.Where); + var groupBy = Visit(select.GroupBy); + var orderBy = VisitOrderBy(select.OrderBy); + var skip = Visit(select.Skip); + var take = Visit(select.Take); + var fields = VisitFieldDeclarationList(select.Fields); + if (from != select.From || where != select.Where || orderBy != select.OrderBy || groupBy != select.GroupBy || skip != select.Skip || take != select.Take || fields != select.Fields) + return new SelectExpression(select.Alias, fields, from, where, orderBy, groupBy, select.IsDistinct, skip, take); + return select; + } + protected virtual Expression VisitSource(Expression source) { return Visit(source); diff --git a/source/MongoDB/Linq/Expressions/NamedValueExpression.cs b/source/MongoDB/Linq/Expressions/NamedValueExpression.cs new file mode 100644 index 00000000..75313caa --- /dev/null +++ b/source/MongoDB/Linq/Expressions/NamedValueExpression.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Text; + +namespace MongoDB.Linq.Expressions +{ + internal class NamedValueExpression : MongoExpression + { + private readonly string _name; + private readonly Expression _value; + + public string Name + { + get { return _name; } + } + public Expression Value + { + get { return _value; } + } + + public NamedValueExpression(string name, Expression value) + : base(MongoExpressionType.NamedValue, value.Type) + { + _name = name; + _value = value; + } + + } +} diff --git a/source/MongoDB/Linq/MongoQueryObject.cs b/source/MongoDB/Linq/MongoQueryObject.cs index f83d9c09..28716f37 100644 --- a/source/MongoDB/Linq/MongoQueryObject.cs +++ b/source/MongoDB/Linq/MongoQueryObject.cs @@ -12,12 +12,6 @@ internal class MongoQueryObject private Document _query; private Document _sort; - /// - /// Gets or sets the aggregator. - /// - /// The aggregator. - public LambdaExpression Aggregator { get; set; } - /// /// Gets or sets the name of the collection. /// diff --git a/source/MongoDB/Linq/MongoQueryProvider.cs b/source/MongoDB/Linq/MongoQueryProvider.cs index e0675e17..9a139659 100644 --- a/source/MongoDB/Linq/MongoQueryProvider.cs +++ b/source/MongoDB/Linq/MongoQueryProvider.cs @@ -106,8 +106,31 @@ public TResult Execute(Expression expression) /// public object Execute(Expression expression) { - var queryObject = GetQueryObject(expression); - return ExecuteQueryObject(queryObject); + var plan = BuildExecutionPlan(expression); + + var lambda = expression as LambdaExpression; + if (lambda != null) + { + var fn = Expression.Lambda(lambda.Type, plan, lambda.Parameters); + return fn.Compile(); + } + else + { + var efn = Expression.Lambda>(Expression.Convert(plan, typeof(object))); + var fn = efn.Compile(); + return fn(); + } + } + + /// + /// Gets the query object. + /// + /// The expression. + /// + internal MongoQueryObject GetQueryObject(Expression expression) + { + var projection = Translate(expression); + return new MongoQueryObjectBuilder().Build(projection); } /// @@ -115,47 +138,52 @@ public object Execute(Expression expression) /// /// The query object. /// - internal object ExecuteQueryObject(MongoQueryObject queryObject){ + internal IEnumerable ExecuteQueryObject(MongoQueryObject queryObject){ if (queryObject.IsCount) - return ExecuteCount(queryObject); + return ExecuteCount(queryObject); else if (queryObject.IsMapReduce) - return ExecuteMapReduce(queryObject); - return ExecuteFind(queryObject); + return ExecuteMapReduce(queryObject); + return ExecuteFind(queryObject); } - /// - /// Gets the query object. - /// - /// The expression. - /// - internal MongoQueryObject GetQueryObject(Expression expression) + private Expression BuildExecutionPlan(Expression expression) { - var projection = expression as ProjectionExpression; - if(projection == null) - { - expression = PartialEvaluator.Evaluate(expression, CanBeEvaluatedLocally); - - expression = new FieldBinder().Bind(expression); - expression = new QueryBinder(this, expression).Bind(expression); - expression = new AggregateRewriter().Rewrite(expression); - expression = new RedundantFieldRemover().Remove(expression); - expression = new RedundantSubqueryRemover().Remove(expression); - expression = new RedundantJoinRemover().Remove(expression); - - expression = new ClientJoinProjectionRewriter().Rewrite(expression); - expression = new RedundantFieldRemover().Remove(expression); - expression = new RedundantSubqueryRemover().Remove(expression); - expression = new RedundantJoinRemover().Remove(expression); - - expression = new OrderByRewriter().Rewrite(expression); - expression = new RedundantFieldRemover().Remove(expression); - expression = new RedundantSubqueryRemover().Remove(expression); - expression = new RedundantJoinRemover().Remove(expression); - - projection = (ProjectionExpression)expression; - } + var lambda = expression as LambdaExpression; + if (lambda != null) + expression = lambda.Body; - return new MongoQueryObjectBuilder().Build(projection); + var projection = Translate(expression); + + var rootQueryable = new RootQueryableFinder().Find(expression); + var provider = Expression.Convert( + Expression.Property(rootQueryable, typeof(IQueryable).GetProperty("Provider")), + typeof(MongoQueryProvider)); + + return new ExecutionBuilder().Build(projection, provider); + } + + private ProjectionExpression Translate(Expression expression) + { + expression = PartialEvaluator.Evaluate(expression, CanBeEvaluatedLocally); + + expression = new FieldBinder().Bind(expression); + expression = new QueryBinder(this, expression).Bind(expression); + expression = new AggregateRewriter().Rewrite(expression); + expression = new RedundantFieldRemover().Remove(expression); + expression = new RedundantSubqueryRemover().Remove(expression); + expression = new RedundantJoinRemover().Remove(expression); + + expression = new ClientJoinProjectionRewriter().Rewrite(expression); + expression = new RedundantFieldRemover().Remove(expression); + expression = new RedundantSubqueryRemover().Remove(expression); + expression = new RedundantJoinRemover().Remove(expression); + + expression = new OrderByRewriter().Rewrite(expression); + expression = new RedundantFieldRemover().Remove(expression); + expression = new RedundantSubqueryRemover().Remove(expression); + expression = new RedundantJoinRemover().Remove(expression); + + return (ProjectionExpression)expression; } /// @@ -192,23 +220,20 @@ private bool CanBeEvaluatedLocally(Expression expression) /// /// The query object. /// - private object ExecuteCount(MongoQueryObject queryObject) + private IEnumerable ExecuteCount(MongoQueryObject queryObject) { var miGetCollection = typeof(IMongoDatabase).GetMethods().Where(m => m.Name == "GetCollection" && m.GetGenericArguments().Length == 1 && m.GetParameters().Length == 1).Single().MakeGenericMethod(queryObject.DocumentType); var collection = miGetCollection.Invoke(queryObject.Database, new[] { queryObject.CollectionName }); + IEnumerable documents; if (queryObject.Query == null) - return Convert.ToInt32(collection.GetType().GetMethod("Count", Type.EmptyTypes).Invoke(collection, null)); + documents = new[] { (TDocument)collection.GetType().GetMethod("Count", Type.EmptyTypes).Invoke(collection, null) }; + documents = new[] { (TDocument)collection.GetType().GetMethod("Count", new[] { typeof(object) }).Invoke(collection, new[] { queryObject.Query }) }; - return Convert.ToInt32(collection.GetType().GetMethod("Count", new[] { typeof(object) }).Invoke(collection, new[] { queryObject.Query })); + return Project(documents, (Func)queryObject.Projector.Compile()); } - /// - /// Executes the select. - /// - /// The query object. - /// - private object ExecuteFind(MongoQueryObject queryObject) + private IEnumerable ExecuteFind(MongoQueryObject queryObject) { var miGetCollection = typeof(IMongoDatabase).GetMethods().Where(m => m.Name == "GetCollection" && m.GetGenericArguments().Length == 1 && m.GetParameters().Length == 1).Single().MakeGenericMethod(queryObject.DocumentType); var collection = miGetCollection.Invoke(queryObject.Database, new[] { queryObject.CollectionName }); @@ -231,16 +256,11 @@ private object ExecuteFind(MongoQueryObject queryObject) cursorType.GetMethod("Limit").Invoke(cursor, new object[] { queryObject.NumberToLimit }); cursorType.GetMethod("Skip").Invoke(cursor, new object[] { queryObject.NumberToSkip }); - var executor = GetExecutor(queryObject.DocumentType, queryObject.Projector, queryObject.Aggregator, true); - return executor.Compile().DynamicInvoke(cursor.GetType().GetProperty("Documents").GetValue(cursor, null)); + var documents = (IEnumerable)cursor.GetType().GetProperty("Documents").GetValue(cursor, null); + return Project(documents, (Func)queryObject.Projector.Compile()); } - /// - /// Executes the map reduce. - /// - /// The query object. - /// - private object ExecuteMapReduce(MongoQueryObject queryObject) + private IEnumerable ExecuteMapReduce(MongoQueryObject queryObject) { var miGetCollection = typeof(IMongoDatabase).GetMethods().Where(m => m.Name == "GetCollection" && m.GetGenericArguments().Length == 1 && m.GetParameters().Length == 1).Single().MakeGenericMethod(queryObject.DocumentType); var collection = miGetCollection.Invoke(queryObject.Database, new[] { queryObject.CollectionName }); @@ -251,40 +271,35 @@ private object ExecuteMapReduce(MongoQueryObject queryObject) mapReduce.Finalize = new Code(queryObject.FinalizerFunction); mapReduce.Query = queryObject.Query; - if(queryObject.Sort != null) + if (queryObject.Sort != null) mapReduce.Sort = queryObject.Sort; mapReduce.Limit = queryObject.NumberToLimit; if (queryObject.NumberToSkip != 0) throw new InvalidQueryException("MapReduce queries do no support Skips."); - var executor = GetExecutor(typeof(Document), queryObject.Projector, queryObject.Aggregator, true); - return executor.Compile().DynamicInvoke(mapReduce.Documents); + var documents = (IEnumerable)mapReduce.Documents; + return Project(documents, (Func)queryObject.Projector.Compile()); } - /// - /// Gets the executor. - /// - /// Type of the document. - /// The projector. - /// The aggregator. - /// if set to true [box return]. - /// - private static LambdaExpression GetExecutor(Type documentType, LambdaExpression projector, LambdaExpression aggregator, bool boxReturn) + private IEnumerable Project(IEnumerable documents, Func projector) + { + foreach (var doc in documents) + { + yield return projector(doc); + } + } + + private static LambdaExpression GetExecutor(Type documentType, LambdaExpression projector, bool boxReturn) { var documents = Expression.Parameter(typeof(IEnumerable<>).MakeGenericType(documentType), "documents"); Expression body = Expression.New(typeof(ProjectionReader<,>).MakeGenericType(documentType, projector.Body.Type).GetConstructors()[0], documents, projector); - if (aggregator != null) - body = Expression.Invoke(aggregator, body); if (boxReturn && body.Type != typeof(object)) body = Expression.Convert(body, typeof(object)); return Expression.Lambda(body, documents); } - /// - /// attempt to isolate a sub-expression that accesses a Query object - /// private class RootQueryableFinder : MongoExpressionVisitor { private Expression _root; diff --git a/source/MongoDB/Linq/Translators/ExpressionReplacer.cs b/source/MongoDB/Linq/Translators/ExpressionReplacer.cs new file mode 100644 index 00000000..983c14fe --- /dev/null +++ b/source/MongoDB/Linq/Translators/ExpressionReplacer.cs @@ -0,0 +1,38 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Text; + +using MongoDB.Linq.Expressions; + +namespace MongoDB.Linq.Translators +{ + internal class ExpressionReplacer : MongoExpressionVisitor + { + private Expression _replaceWith; + private Expression _searchFor; + + public Expression Replace(Expression expression, Expression searchFor, Expression replaceWith) + { + this._searchFor = searchFor; + this._replaceWith = replaceWith; + return Visit(expression); + } + + public Expression ReplaceAll(Expression expression, Expression[] searchFor, Expression[] replaceWith) + { + for (int i = 0, n = searchFor.Length; i < n; i++) + expression = Replace(expression, searchFor[i], replaceWith[i]); + return expression; + } + + protected override Expression Visit(Expression exp) + { + if (exp == _searchFor) + return _replaceWith; + return base.Visit(exp); + } + + } +} diff --git a/source/MongoDB/Linq/Translators/MongoQueryObjectBuilder.cs b/source/MongoDB/Linq/Translators/MongoQueryObjectBuilder.cs index e3bb32fd..af0bf29b 100644 --- a/source/MongoDB/Linq/Translators/MongoQueryObjectBuilder.cs +++ b/source/MongoDB/Linq/Translators/MongoQueryObjectBuilder.cs @@ -80,10 +80,6 @@ protected override Expression VisitSelect(SelectExpression select) protected override Expression VisitProjection(ProjectionExpression projection) { Visit(projection.Source); - - _queryObject.Projector = new ProjectionBuilder().Build(_queryObject, projection.Projector); - _queryObject.Aggregator = projection.Aggregator; - return projection; } diff --git a/source/MongoDB/Linq/Translators/NamedValueGatherer.cs b/source/MongoDB/Linq/Translators/NamedValueGatherer.cs new file mode 100644 index 00000000..4a37427d --- /dev/null +++ b/source/MongoDB/Linq/Translators/NamedValueGatherer.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Text; + +using MongoDB.Linq.Expressions; + +namespace MongoDB.Linq.Translators +{ + internal class NamedValueGatherer : MongoExpressionVisitor + { + private List _namedValues; + + public ReadOnlyCollection Gather(Expression expr) + { + _namedValues = new List(); + Visit(expr); + return _namedValues.AsReadOnly(); + } + + protected override Expression VisitNamedValue(NamedValueExpression value) + { + _namedValues.Add(value); + return value; + } + } +} diff --git a/source/MongoDB/Linq/Translators/Parameterizer.cs b/source/MongoDB/Linq/Translators/Parameterizer.cs new file mode 100644 index 00000000..bee7b4c7 --- /dev/null +++ b/source/MongoDB/Linq/Translators/Parameterizer.cs @@ -0,0 +1,86 @@ +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Text; + +using MongoDB.Linq.Expressions; + +namespace MongoDB.Linq.Translators +{ + internal class Parameterizer : MongoExpressionVisitor + { + private int _paramIndex; + private Dictionary _map; + private Dictionary _pmap; + + public Expression Parameterize(Expression expression) + { + return Visit(expression); + } + + protected override Expression VisitConstant(ConstantExpression c) + { + if (c.Value != null && !IsNumeric(c.Value.GetType())) + { + NamedValueExpression nv; + if (!_map.TryGetValue(c.Value, out nv)) + { + var name = "p" + (_paramIndex++); + nv = new NamedValueExpression(name, c); + _map.Add(c.Value, nv); + } + return nv; + } + return c; + } + + protected override Expression VisitParameter(ParameterExpression p) + { + return GetNamedValue(p); + } + + protected override Expression VisitProjection(ProjectionExpression projection) + { + SelectExpression select = (SelectExpression)Visit(projection.Source); + if (select != projection.Source) + return new ProjectionExpression(select, projection.Projector, projection.Aggregator); + return projection; + } + + private Expression GetNamedValue(Expression e) + { + NamedValueExpression nv; + if (!_pmap.TryGetValue(e, out nv)) + { + string name = "p" + (_paramIndex++); + nv = new NamedValueExpression(name, e); + _pmap.Add(e, nv); + } + return nv; + } + + private static bool IsNumeric(Type type) + { + switch (Type.GetTypeCode(type)) + { + case TypeCode.Boolean: + case TypeCode.Byte: + case TypeCode.Decimal: + case TypeCode.Double: + case TypeCode.Int16: + case TypeCode.Int32: + case TypeCode.Int64: + case TypeCode.SByte: + case TypeCode.Single: + case TypeCode.UInt16: + case TypeCode.UInt32: + case TypeCode.UInt64: + return true; + default: + return false; + } + } + } +} \ No newline at end of file diff --git a/source/MongoDB/Linq/Translators/QueryBinder.cs b/source/MongoDB/Linq/Translators/QueryBinder.cs index 664869b4..c1fb8ff0 100644 --- a/source/MongoDB/Linq/Translators/QueryBinder.cs +++ b/source/MongoDB/Linq/Translators/QueryBinder.cs @@ -363,12 +363,12 @@ private Expression BindJoin(Type resultType, Expression outerSource, Expression { var outerProjection = VisitSequence(outerSource); var innerProjection = VisitSequence(innerSource); - map[outerKey.Parameters[0]] = outerProjection.Projector; + _map[outerKey.Parameters[0]] = outerProjection.Projector; var outerKeyExpression = Visit(outerKey.Body); - map[innerKey.Parameters[0]] = innerProjection.Projector; + _map[innerKey.Parameters[0]] = innerProjection.Projector; var innerKeyExpression = Visit(innerKey.Body); - map[resultSelector.Parameters[0]] = outerProjection.Projector; - map[resultSelector.Parameters[1]] = innerProjection.Projector; + _map[resultSelector.Parameters[0]] = outerProjection.Projector; + _map[resultSelector.Parameters[1]] = innerProjection.Projector; var resultExpression = Visit(resultSelector.Body); var join = new JoinExpression(JoinType.InnerJoin, outerProjection.Source, innerProjection.Source, Expression.Equal(outerKeyExpression, innerKeyExpression)); var alias = new Alias(); diff --git a/source/MongoDB/MongoDB.csproj b/source/MongoDB/MongoDB.csproj index 95af55c1..2bbfe8b3 100644 --- a/source/MongoDB/MongoDB.csproj +++ b/source/MongoDB/MongoDB.csproj @@ -128,7 +128,11 @@ + + + +