diff --git a/src/Splat.Microsoft.Extensions.DependencyInjection.Tests/Splat.Microsoft.Extensions.DependencyInjection.Tests.csproj b/src/Splat.Microsoft.Extensions.DependencyInjection.Tests/Splat.Microsoft.Extensions.DependencyInjection.Tests.csproj index e6117b078..3244e46eb 100644 --- a/src/Splat.Microsoft.Extensions.DependencyInjection.Tests/Splat.Microsoft.Extensions.DependencyInjection.Tests.csproj +++ b/src/Splat.Microsoft.Extensions.DependencyInjection.Tests/Splat.Microsoft.Extensions.DependencyInjection.Tests.csproj @@ -1,7 +1,7 @@  - net6.0 + net6.0;net8.0 $(NoWarn);1591;CA1707;SA1633;CA2000 false enable diff --git a/src/Splat.Microsoft.Extensions.DependencyInjection/MicrosoftDependencyResolver.cs b/src/Splat.Microsoft.Extensions.DependencyInjection/MicrosoftDependencyResolver.cs index f24032995..67e0e8d60 100644 --- a/src/Splat.Microsoft.Extensions.DependencyInjection/MicrosoftDependencyResolver.cs +++ b/src/Splat.Microsoft.Extensions.DependencyInjection/MicrosoftDependencyResolver.cs @@ -3,9 +3,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for full license information. -using System.Collections.Concurrent; -using System.Data; -using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using Microsoft.Extensions.DependencyInjection; namespace Splat.Microsoft.Extensions.DependencyInjection; @@ -17,7 +15,6 @@ namespace Splat.Microsoft.Extensions.DependencyInjection; public class MicrosoftDependencyResolver : IDependencyResolver { private const string ImmutableExceptionMessage = "This container has already been built and cannot be modified."; - private static readonly Type _dictionaryType = typeof(ContractDictionary<>); private readonly object _syncLock = new(); private IServiceCollection? _serviceCollection; private bool _isImmutable; @@ -91,7 +88,7 @@ public virtual IEnumerable GetServices(Type? serviceType, string? contra var isNull = serviceType is null; serviceType ??= typeof(NullServiceType); - IEnumerable services; + IEnumerable services = Enumerable.Empty(); if (contract is null || string.IsNullOrWhiteSpace(contract)) { @@ -99,21 +96,19 @@ public virtual IEnumerable GetServices(Type? serviceType, string? contra services = ServiceProvider.GetServices(serviceType) .Where(a => a is not null) .Select(a => a!); - - if (isNull) - { - services = services - .Cast() - .Select(nst => nst.Factory()!); - } } - else + else if (ServiceProvider is IKeyedServiceProvider serviceProvider) + { + services = serviceProvider.GetKeyedServices(serviceType, contract) + .Where(a => a is not null) + .Select(a => a!); + } + + if (isNull) { - var dic = GetContractDictionary(serviceType, false); - services = dic? - .GetFactories(contract) - .Select(f => f()!) - ?? Array.Empty(); + services = services + .Cast() + .Select(nst => nst.Factory()!); } return services; @@ -142,9 +137,10 @@ public virtual void Register(Func factory, Type? serviceType, string? c } else { - var dic = GetContractDictionary(serviceType, true); - - dic?.AddFactory(contract, factory); + _serviceCollection?.AddKeyedTransient(serviceType, contract, (_, __) => + isNull + ? new NullServiceType(factory) + : factory()!); } // required so that it gets rebuilt if not injected externally. @@ -166,7 +162,7 @@ public virtual void UnregisterCurrent(Type? serviceType, string? contract = null { if (contract is null || string.IsNullOrWhiteSpace(contract)) { - var sd = _serviceCollection?.LastOrDefault(s => s.ServiceType == serviceType); + var sd = _serviceCollection?.LastOrDefault(s => !s.IsKeyedService && s.ServiceType == serviceType); if (sd is not null) { _serviceCollection?.Remove(sd); @@ -174,14 +170,10 @@ public virtual void UnregisterCurrent(Type? serviceType, string? contract = null } else { - var dic = GetContractDictionary(serviceType, false); - if (dic is not null) + var sd = _serviceCollection?.LastOrDefault(sd => MatchesKeyedContract(serviceType, contract, sd)); + if (sd is not null) { - dic.RemoveLastFactory(contract); - if (dic.IsEmpty) - { - RemoveContractService(serviceType); - } + _serviceCollection?.Remove(sd); } } @@ -196,7 +188,7 @@ public virtual void UnregisterCurrent(Type? serviceType, string? contract = null /// ignoring the argument. /// /// The service type to unregister. - /// This parameter is ignored. Service will be removed from all contracts. + /// A optional value which will remove only an object registered with the same contract. public virtual void UnregisterAll(Type? serviceType, string? contract = null) { if (_isImmutable) @@ -208,34 +200,28 @@ public virtual void UnregisterAll(Type? serviceType, string? contract = null) lock (_syncLock) { - switch (contract) + if (_serviceCollection is null) + { + // required so that it gets rebuilt if not injected externally. + _serviceProvider = null; + return; + } + + IEnumerable sds = Enumerable.Empty(); + + if (contract is null || string.IsNullOrWhiteSpace(contract)) { - case null when _serviceCollection is not null: - { - var sds = _serviceCollection - .Where(s => s.ServiceType == serviceType) - .ToList(); - - foreach (var sd in sds) - { - _serviceCollection.Remove(sd); - } - - break; - } - - case null: - throw new ArgumentException("There must be a valid contract if there is no service collection.", nameof(contract)); - default: - { - var dic = GetContractDictionary(serviceType, false); - if (dic?.TryRemoveContract(contract) == true && dic.IsEmpty) - { - RemoveContractService(serviceType); - } - - break; - } + sds = _serviceCollection.Where(s => !s.IsKeyedService && s.ServiceType == serviceType); + } + else + { + sds = _serviceCollection + .Where(sd => MatchesKeyedContract(serviceType, contract, sd)); + } + + foreach (var sd in sds.ToList()) + { + _serviceCollection.Remove(sd); } // required so that it gets rebuilt if not injected externally. @@ -255,16 +241,10 @@ public virtual bool HasRegistration(Type? serviceType, string? contract = null) { if (contract is null || string.IsNullOrWhiteSpace(contract)) { - return _serviceCollection?.Any(sd => sd.ServiceType == serviceType) == true; + return _serviceCollection?.Any(sd => !sd.IsKeyedService && sd.ServiceType == serviceType) == true; } - var dictionary = (ContractDictionary?)_serviceCollection?.FirstOrDefault(sd => sd.ServiceType == GetDictionaryType(serviceType))?.ImplementationInstance; - - return dictionary switch - { - null => false, - _ => dictionary.GetFactories(contract).Select(f => f()).Any() - }; + return _serviceCollection?.Any(sd => MatchesKeyedContract(serviceType, contract, sd)) == true; } if (contract is null) @@ -273,8 +253,12 @@ public virtual bool HasRegistration(Type? serviceType, string? contract = null) return service is not null; } - var dic = GetContractDictionary(serviceType, false); - return dic?.IsEmpty == false; + if (_serviceProvider is IKeyedServiceProvider keyedServiceProvider) + { + return keyedServiceProvider.GetKeyedService(serviceType, contract) is not null; + } + + return false; } /// @@ -292,103 +276,9 @@ protected virtual void Dispose(bool disposing) { } - private static Type GetDictionaryType(Type serviceType) => _dictionaryType.MakeGenericType(serviceType); - - private void RemoveContractService(Type serviceType) - { - var dicType = GetDictionaryType(serviceType); - var sd = _serviceCollection?.SingleOrDefault(s => s.ServiceType == serviceType); - - if (sd is not null) - { - _serviceCollection?.Remove(sd); - } - } - - [SuppressMessage("Naming Rules", "SA1300", Justification = "Intentional")] - private ContractDictionary? GetContractDictionary(Type serviceType, bool createIfNotExists) - { - var dicType = GetDictionaryType(serviceType); - - if (ServiceProvider is null) - { - throw new InvalidOperationException("The ServiceProvider is null."); - } - - if (_isImmutable) - { - return (ContractDictionary?)ServiceProvider.GetService(dicType); - } - - var dic = getDictionary(); - if (createIfNotExists && dic is null) - { - lock (_syncLock) - { - if (createIfNotExists) - { - dic = (ContractDictionary?)Activator.CreateInstance(dicType); - - if (dic is not null) - { - _serviceCollection?.AddSingleton(dicType, dic); - } - } - } - } - - return dic; - - ContractDictionary? getDictionary() => _serviceCollection? - .Where(sd => sd.ServiceType == dicType) - .Select(sd => sd.ImplementationInstance) - .Cast() - .SingleOrDefault(); - } - - private class ContractDictionary - { - private readonly ConcurrentDictionary>> _dictionary = new(); - - public bool IsEmpty => _dictionary.IsEmpty; - - public bool TryRemoveContract(string contract) => - _dictionary.TryRemove(contract, out var _); - - public Func? GetFactory(string contract) => - GetFactories(contract) - .LastOrDefault(); - - public IEnumerable> GetFactories(string contract) => - _dictionary.TryGetValue(contract, out var collection) - ? collection ?? Enumerable.Empty>() - : Array.Empty>(); - - public void AddFactory(string contract, Func factory) => - _dictionary.AddOrUpdate(contract, _ => new() { factory }, (_, list) => - { - (list ??= []).Add(factory); - return list; - }); - - public void RemoveLastFactory(string contract) => - _dictionary.AddOrUpdate(contract, [], (_, list) => - { - var lastIndex = list.Count - 1; - if (lastIndex > 0) - { - list.RemoveAt(lastIndex); - } - - // TODO if list empty remove contract entirely - // need to find how to atomically update or remove - // https://github.com/dotnet/corefx/issues/24246 - return list; - }); - } - - [SuppressMessage("Design", "CA1812: Unused class.", Justification = "Used in reflection.")] - private sealed class ContractDictionary : ContractDictionary - { - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool MatchesKeyedContract(Type? serviceType, string contract, ServiceDescriptor sd) => + sd.ServiceType == serviceType + && sd is { IsKeyedService: true, ServiceKey: string serviceKey } + && serviceKey == contract; }