diff --git a/src/SimpleInjector.Tests.Unit/RegisterCollectionTests.cs b/src/SimpleInjector.Tests.Unit/RegisterCollectionTests.cs index d0376d6cc..906364e74 100644 --- a/src/SimpleInjector.Tests.Unit/RegisterCollectionTests.cs +++ b/src/SimpleInjector.Tests.Unit/RegisterCollectionTests.cs @@ -1661,6 +1661,24 @@ public void GetInstance_TypeDependingOnICollection_InjectsTheRegisteredCollectio AssertThat.IsInstanceOfType(typeof(PluginImpl2), collection.Second()); } + [TestMethod] + public void GetInstance_TypeDependingOnCollection_InjectsTheRegisteredCollection() + { + // Arrange + var container = ContainerFactory.New(); + + container.Collection.Register(new[] { typeof(PluginImpl), typeof(PluginImpl2) }); + + // Act + Collection collection = + container.GetInstance>>().Dependency; + + // Assert + Assert.AreEqual(2, collection.Count); + AssertThat.IsInstanceOfType(typeof(PluginImpl), collection.First()); + AssertThat.IsInstanceOfType(typeof(PluginImpl2), collection.Second()); + } + [TestMethod] public void GetInstance_TypeDependingOnICollection_InjectsTheRegisteredCollectionOfDecorators() { @@ -1714,6 +1732,23 @@ public void GetInstance_TypeDependingOnIList_InjectsTheRegisteredList() AssertThat.IsInstanceOfType(typeof(PluginImpl2), list[1]); } + [TestMethod] + public void GetInstance_TypeDependingOnList_InjectsTheRegisteredList() + { + // Arrange + var container = ContainerFactory.New(); + + container.Collection.Register(new[] { typeof(PluginImpl), typeof(PluginImpl2) }); + + // Act + List list = container.GetInstance>>().Dependency; + + // Assert + Assert.AreEqual(2, list.Count); + AssertThat.IsInstanceOfType(typeof(PluginImpl), list[0]); + AssertThat.IsInstanceOfType(typeof(PluginImpl2), list[1]); + } + [TestMethod] public void GetInstance_TypeDependingOnIList_InjectsTheRegisteredListOfDecorators() { @@ -1806,6 +1841,49 @@ public void GetInstance_CalledMultipleTimesOnContainerControlledCollection_Injec "array has been cached."); } + [TestMethod] + public void GetInstance_CalledMultipleTimesOnContainerControlledCollection_InjectsANewListOnEachRequest() + { + // Arrange + var container = ContainerFactory.New(); + + container.RegisterSingleton(); + + container.Collection.Register(new[] { typeof(ConcreteCommand) }); + + // Act + var injectedList = container.GetInstance>>().Dependency; + + injectedList[0] = null; + + injectedList = container.GetInstance>>().Dependency; + + // Assert + Assert.IsNotNull(injectedList[0], + "The element in the array is expected NOT to be null. When it is null, it means that the " + + "array has been cached."); + } + + [TestMethod] + public void GetInstance_AResolvedCollectionOfT_CanNotBeChanged() + { + // Arrange + var container = ContainerFactory.New(); + + container.RegisterSingleton(); + + container.Collection.Register(new[] { typeof(ConcreteCommand) }); + + var collection = container.GetInstance>(); + + // Act + Action action = () => collection[0] = new ConcreteCommand(); + + // Assert + Assert.ThrowsException(action, + "Changing the collection should be blocked by Simple Injector."); + } + [TestMethod] public void GetInstance_CalledMultipleTimesOnContainerControlledSingletons_StillInjectsANewArrayOnEachRequest() { @@ -1815,17 +1893,118 @@ public void GetInstance_CalledMultipleTimesOnContainerControlledSingletons_Still container.Collection.Register(new ConcreteCommand()); // Act - var composite = container.GetInstance(); + ICommand[] commands = container.GetInstance>().Dependency; - composite.Commands[0] = null; + commands[0] = null; - composite = container.GetInstance(); + commands = container.GetInstance>().Dependency; // Assert - Assert.IsNotNull(composite.Commands[0], + Assert.IsNotNull(commands[0], "The element in the array is expected NOT to be null. When it is null, it means that the " + "array has been cached."); } + + [TestMethod] + public void GetInstance_ResolvingACollectionOfT_IsSingleton() + { + // Arrange + var container = ContainerFactory.New(); + + container.Collection.Register(new ConcreteCommand()); + + // Act + Collection commands1 = container.GetInstance>(); + Collection commands2 = container.GetInstance>(); + + // Assert + Assert.AreSame(commands1, commands2, + "Collection is just a wrapper for a container controlled collection IList. And should therefore be a singleton."); + + Assert.AreSame(Lifestyle.Singleton, container.GetRegistration(typeof(Collection)).Lifestyle); + } + + public class MyList : IList + { + public T this[int index] + { + get => default(T); + set { } + } + + public int Count => throw new NotImplementedException(); + + public bool IsReadOnly => throw new NotImplementedException(); + + public void Add(T item) + { + throw new NotImplementedException(); + } + + public void Clear() + { + throw new NotImplementedException(); + } + + public bool Contains(T item) + { + throw new NotImplementedException(); + } + + public void CopyTo(T[] array, int arrayIndex) + { + throw new NotImplementedException(); + } + + public IEnumerator GetEnumerator() + { + throw new NotImplementedException(); + } + + public int IndexOf(T item) + { + throw new NotImplementedException(); + } + + public void Insert(int index, T item) + { + throw new NotImplementedException(); + } + + public bool Remove(T item) + { + throw new NotImplementedException(); + } + + public void RemoveAt(int index) + { + throw new NotImplementedException(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + throw new NotImplementedException(); + } + } + + [TestMethod] + public void GetInstance_CollectionOfT_FunctionsAsAStream() + { + // Arrange + var container = ContainerFactory.New(); + + container.Collection.Register(typeof(ConcreteCommand)); + + // Act + Collection commands = container.GetInstance>(); + + // Assert + Assert.AreNotSame(commands[0], commands[0], + "Requesting an instance from the collection thould cause a callback into the Container " + + "causing the type to be resolved again."); + + Assert.AreSame(Lifestyle.Singleton, container.GetRegistration(typeof(Collection)).Lifestyle); + } [TestMethod] public void GetInstance_CalledMultipleTimesOnContainerUncontrolledCollection_StillInjectsANewArrayOnEachRequest() @@ -1866,7 +2045,41 @@ public void GetRegistration_RequestingArrayRegistrationContainerControlledCollec var registration = container.GetRegistration(typeof(ICommand[])); // Assert - Assert.AreEqual(Lifestyle.Transient, registration.Lifestyle); + Assert.AreEqual(Lifestyle.Transient, registration.Lifestyle, + "Array must be resolved as transient, because it is a mutable type."); + } + + [TestMethod] + public void GetRegistration_RequestingListRegistrationContainerControlledCollection_HasTheTransientLifestyle() + { + // Arrange + var container = ContainerFactory.New(); + + container.Collection.Register(new[] { typeof(ConcreteCommand) }); + + // Act + var registration = container.GetRegistration(typeof(List)); + + // Assert + Assert.AreEqual(Lifestyle.Transient, registration.Lifestyle, + "List must be resolved as transient, because it is a mutable type."); + } + + [TestMethod] + public void GetRegistration_RequestingCollectionRegistrationContainerControlledCollection_HasTheTransientLifestyle() + { + // Arrange + var container = ContainerFactory.New(); + + container.Collection.Register(new[] { typeof(ConcreteCommand) }); + + // Act + var registration = container.GetRegistration(typeof(List)); + + // Assert + Assert.AreEqual(Lifestyle.Transient, registration.Lifestyle, + "Although Collection technically doesn't have to a a transient, we chose to keep it that way. " + + "See #545 for more details."); } [TestMethod] @@ -2271,7 +2484,7 @@ public void CollectionRegister_SupplyingRegistrationsForVariantAbstractionWithOn Assert.AreEqual(2, deriveds.Count(), "Two registrations were made for IContra."); Assert.AreEqual(0, bases.Count(), "No registrations were made for IContra."); } - + private static void Assert_IsNotAMutableCollection(IEnumerable collection) { string assertMessage = "The container should wrap mutable types to make it impossible for " + diff --git a/src/SimpleInjector/Container.Resolving.cs b/src/SimpleInjector/Container.Resolving.cs index 8c3fa3a68..b2fd76f76 100644 --- a/src/SimpleInjector/Container.Resolving.cs +++ b/src/SimpleInjector/Container.Resolving.cs @@ -24,6 +24,7 @@ namespace SimpleInjector { using System; using System.Collections.Generic; + using System.Collections.ObjectModel; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Linq.Expressions; @@ -40,6 +41,7 @@ namespace SimpleInjector public partial class Container : IServiceProvider { private static readonly MethodInfo EnumerableToArrayMethod = typeof(Enumerable).GetMethod("ToArray"); + private static readonly MethodInfo EnumerableToListMethod = typeof(Enumerable).GetMethod("ToList"); private readonly Dictionary> resolveUnregisteredTypeRegistrations = new Dictionary>(); @@ -355,11 +357,17 @@ private InstanceProducer BuildInstanceProducerForType(InjectionConsume { return this.TryBuildInstanceProducerThroughUnregisteredTypeResolution(serviceType) ?? - this.TryBuildArrayInstanceProducer(serviceType) ?? - this.TryBuildInstanceProducerForCollection(serviceType) ?? + this.TryBuildInstanceProducerForCollectionType(serviceType) ?? tryBuildInstanceProducerForConcreteType(); } + private InstanceProducer TryBuildInstanceProducerForCollectionType(Type serviceType) + { + return + this.TryBuildInstanceProducerForMutableCollection(serviceType) ?? + this.TryBuildInstanceProducerForStream(serviceType); + } + // Instead of wrapping the complete method in a lock, we lock inside the individual methods. We // don't want to hold a lock while calling back into user code, because who knows what the user // is doing there. We don't want a dead lock. @@ -417,47 +425,51 @@ private InstanceProducer TryGetInstanceProducerThroughResolveUnregisteredTypeEve } } - private InstanceProducer TryBuildArrayInstanceProducer(Type serviceType) + private InstanceProducer TryBuildInstanceProducerForMutableCollection(Type serviceType) { if (serviceType.IsArray) { - Type elementType = serviceType.GetElementType(); - - // We don't auto-register collections for ambiguous types. - if (elementType.IsValueType() || Types.IsAmbiguousType(elementType)) - { - return null; - } + return this.BuildInstanceProducerForMutableCollectionType( + serviceType, serviceType.GetElementType(), MutableCollectionType.Array); + } + else if (typeof(List<>).IsGenericTypeDefinitionOf(serviceType)) + { + return this.BuildInstanceProducerForMutableCollectionType( + serviceType, serviceType.GetGenericArguments().FirstOrDefault(), MutableCollectionType.List); + } + else + { + return null; + } + } - // GetAllInstances locks the container - bool isContainerControlledCollection = - this.GetAllInstances(elementType) is IContainerControlledCollection; + private enum MutableCollectionType { Array, List }; - if (isContainerControlledCollection) - { - return this.BuildArrayProducerFromControlledCollection(serviceType, elementType); - } - else - { - return this.BuildArrayProducerFromUncontrolledCollection(serviceType, elementType); - } + private InstanceProducer BuildInstanceProducerForMutableCollectionType( + Type serviceType, Type elementType, MutableCollectionType type) + { + // We don't auto-register collections for ambiguous types. + if (Types.IsAmbiguousOrValueType(elementType)) + { + return null; } - return null; + // GetAllInstances locks the container + if (this.GetAllInstances(elementType) is IContainerControlledCollection) + { + return this.BuildMutableCollectionProducerFromControlledCollection(serviceType, elementType, type); + } + else + { + return this.BuildMutableCollectionProducerFromUncontrolledCollection(serviceType, elementType); + } } - private InstanceProducer BuildArrayProducerFromControlledCollection(Type serviceType, Type elementType) + private InstanceProducer BuildMutableCollectionProducerFromControlledCollection( + Type serviceType, Type elementType, MutableCollectionType collectionType) { - var arrayMethod = EnumerableToArrayMethod.MakeGenericMethod(elementType); - - IEnumerable singletonCollection = this.GetAllInstances(elementType); - - var collectionExpression = Expression.Constant( - singletonCollection, - typeof(IEnumerable<>).MakeGenericType(elementType)); - - // Build the call "Enumerable.ToArray(collection)". - var arrayExpression = Expression.Call(arrayMethod, collectionExpression); + Expression expression = + BuildMutableCollectionExpressionFromControlledCollection(serviceType, elementType); // Technically, we could determine the longest lifestyle out of the elements of the collection, // instead of using Transient here. This would make it less likely for the user to get false @@ -465,38 +477,60 @@ private InstanceProducer BuildArrayProducerFromControlledCollection(Type service // longest lifestyle might cause the array to be cached in a way that is incorrect, because // who knows what kind of lifestyles the used created. Registration registration = - new ExpressionRegistration(arrayExpression, serviceType, Lifestyle.Transient, this); + new ExpressionRegistration(expression, serviceType, Lifestyle.Transient, this); + + return new InstanceProducer(serviceType, registration) + { + IsContainerAutoRegistered = !this.GetAllInstances(elementType).Any() + }; + } - var producer = new InstanceProducer(serviceType, registration); + private Expression BuildMutableCollectionExpressionFromControlledCollection( + Type serviceType, Type elementType) + { + var streamExpression = Expression.Constant( + value: this.GetAllInstances(elementType), + type: typeof(IEnumerable<>).MakeGenericType(elementType)); - if (!singletonCollection.Any()) + if (serviceType.IsArray) { - producer.IsContainerAutoRegistered = true; + // builds: Enumerable.ToArray(collection) + return Expression.Call( + EnumerableToArrayMethod.MakeGenericMethod(elementType), + streamExpression); } + else + { + // builds: new List(collection) + var listConstructor = typeof(List<>).MakeGenericType(elementType) + .GetConstructor(new[] { typeof(IEnumerable<>).MakeGenericType(elementType) }); - return producer; + return Expression.New(listConstructor, streamExpression); + } } - private InstanceProducer BuildArrayProducerFromUncontrolledCollection(Type serviceType, Type elementType) + private InstanceProducer BuildMutableCollectionProducerFromUncontrolledCollection( + Type serviceType, Type elementType) { - var arrayMethod = EnumerableToArrayMethod.MakeGenericMethod(elementType); - var enumerableProducer = this.GetRegistration(typeof(IEnumerable<>).MakeGenericType(elementType)); - var enumerableExpression = enumerableProducer.BuildExpression(); + Expression enumerableExpression = enumerableProducer.BuildExpression(); - var arrayExpression = Expression.Call(arrayMethod, enumerableExpression); + var expression = Expression.Call( + method: serviceType.IsArray + ? EnumerableToArrayMethod.MakeGenericMethod(elementType) + : EnumerableToListMethod.MakeGenericMethod(elementType), + arg0: enumerableExpression); Registration registration = - new ExpressionRegistration(arrayExpression, serviceType, Lifestyle.Transient, this); - - var producer = new InstanceProducer(serviceType, registration); + new ExpressionRegistration(expression, serviceType, Lifestyle.Transient, this); - producer.IsContainerAutoRegistered = true; - - return producer; + return new InstanceProducer(serviceType, registration) + { + IsContainerAutoRegistered = true + }; } - private InstanceProducer TryBuildInstanceProducerForCollection(Type serviceType) + private InstanceProducer TryBuildInstanceProducerForStream(Type serviceType) { if (!Types.IsGenericCollectionType(serviceType)) { @@ -518,7 +552,7 @@ private InstanceProducer TryBuildInstanceProducerForCollection(Type serviceType) if (!this.emptyAndRedirectedCollectionRegistrationCache.TryGetValue(serviceType, out producer)) { // This call might lock the container - producer = this.TryBuildCollectionInstanceProducer(serviceType); + producer = this.TryBuildStreamInstanceProducer(serviceType); this.emptyAndRedirectedCollectionRegistrationCache[serviceType] = producer; } @@ -527,32 +561,39 @@ private InstanceProducer TryBuildInstanceProducerForCollection(Type serviceType) } } - private InstanceProducer TryBuildCollectionInstanceProducer(Type collectionType) + private InstanceProducer TryBuildStreamInstanceProducer(Type collectionType) { Type serviceTypeDefinition = collectionType.GetGenericTypeDefinition(); - if (serviceTypeDefinition != typeof(IEnumerable<>)) + if (serviceTypeDefinition == typeof(IEnumerable<>)) { - Type elementType = collectionType.GetGenericArguments()[0]; - - var collection = this.GetAllInstances(elementType) as IContainerControlledCollection; + return null; + } - if (collection != null) - { - var registration = SingletonLifestyle.CreateSingleInstanceRegistration(collectionType, collection, this); + Type elementType = collectionType.GetGenericArguments()[0]; - var producer = new InstanceProducer(collectionType, registration); + object stream = this.GetAllInstances(elementType); - if (!((IEnumerable)collection).Any()) - { - producer.IsContainerAutoRegistered = true; - } + if (!(stream is IContainerControlledCollection)) + { + return null; + } - return producer; - } + // We need special handling for Collection, because the ContainerControlledCollection does not + // (and can't) inherit from Collection. So we have to wrap that stream into a Collection. + if (serviceTypeDefinition == typeof(Collection<>)) + { + Type listType = typeof(IList<>).MakeGenericType(elementType); + stream = collectionType.GetConstructor(new[] { listType }).Invoke(new[] { stream }); } - return null; + var registration = + SingletonLifestyle.CreateSingleInstanceRegistration(collectionType, stream, this); + + return new InstanceProducer(collectionType, registration) + { + IsContainerAutoRegistered = !((IEnumerable)stream).Any() + }; } private InstanceProducer BuildEmptyCollectionInstanceProducerForEnumerable(Type enumerableType) diff --git a/src/SimpleInjector/Types.cs b/src/SimpleInjector/Types.cs index 5ae00a89e..dd3d774e4 100644 --- a/src/SimpleInjector/Types.cs +++ b/src/SimpleInjector/Types.cs @@ -24,6 +24,7 @@ namespace SimpleInjector { using System; using System.Collections.Generic; + using System.Collections.ObjectModel; using System.Globalization; using System.Linq; using System.Reflection; @@ -125,7 +126,8 @@ internal static bool IsGenericCollectionType(Type serviceType) #endif serviceTypeDefinition == typeof(IEnumerable<>) || serviceTypeDefinition == typeof(IList<>) || - serviceTypeDefinition == typeof(ICollection<>); + serviceTypeDefinition == typeof(ICollection<>) || + serviceTypeDefinition == typeof(Collection<>); } // Return a list of all base types T inherits, all interfaces T implements and T itself.