diff --git a/src/Tact.Configuration/Extensions/ConfigurationExtensions.cs b/src/Tact.Configuration/Extensions/ConfigurationExtensions.cs index faebf0a..189db84 100644 --- a/src/Tact.Configuration/Extensions/ConfigurationExtensions.cs +++ b/src/Tact.Configuration/Extensions/ConfigurationExtensions.cs @@ -17,6 +17,9 @@ public static T Create(this IConfiguration config) public static object Create(this IConfiguration config, Type type) { + if (type == null) + throw new ArgumentNullException(nameof(type)); + var value = Activator.CreateInstance(type); Bind(config, type, value); return value; @@ -25,6 +28,9 @@ public static object Create(this IConfiguration config, Type type) public static T CreateAndValidate(this IConfiguration config) where T : new() { + if (config == null) + throw new ArgumentNullException(nameof(config)); + var type = typeof(T); var value = new T(); BindAndValidate(config, type, value); @@ -33,6 +39,9 @@ public static T CreateAndValidate(this IConfiguration config) public static object CreateAndValidate(this IConfiguration config, Type type) { + if (type == null) + throw new ArgumentNullException(nameof(type)); + var value = Activator.CreateInstance(type); BindAndValidate(config, type, value); return value; @@ -46,6 +55,9 @@ public static void BindAndValidate(IConfiguration config, object value) public static void BindAndValidate(IConfiguration config, Type type, object value) { + if (type == null) + throw new ArgumentNullException(nameof(type)); + Bind(config, type, value); var context = new ValidationContext(value); context.ValidateObject(); @@ -53,6 +65,9 @@ public static void BindAndValidate(IConfiguration config, Type type, object valu private static void Bind(IConfiguration config, Type type, object value) { + if (config == null) + throw new ArgumentNullException(nameof(config)); + config.GetSection(type.Name).Bind(value); } } diff --git a/src/Tact.Configuration/Extensions/ContainerExtensions.cs b/src/Tact.Configuration/Extensions/ContainerExtensions.cs index c49136a..c3d7650 100644 --- a/src/Tact.Configuration/Extensions/ContainerExtensions.cs +++ b/src/Tact.Configuration/Extensions/ContainerExtensions.cs @@ -33,6 +33,9 @@ public static void ConfigureByAttribute(this IContainer container, IConfigura public static void ConfigureByAttribute(this IContainer container, IConfiguration configuration, params Type[] types) where T : IRegisterConfigurationAttribute { + if(container == null) + throw new ArgumentNullException(nameof(container)); + ILog logger; container.TryResolve(out logger); diff --git a/src/Tact.Configuration/Extensions/ValidationContextExtensions.cs b/src/Tact.Configuration/Extensions/ValidationContextExtensions.cs index 8ae86e1..34c2f97 100644 --- a/src/Tact.Configuration/Extensions/ValidationContextExtensions.cs +++ b/src/Tact.Configuration/Extensions/ValidationContextExtensions.cs @@ -17,6 +17,9 @@ public static class ValidationContextExtensions public static void ValidateObject(this ValidationContext context, bool validateAllProperties = true) { + if (context == null) + throw new ArgumentNullException(nameof(context)); + var instance = context.ObjectInstance; var type = instance.GetType(); diff --git a/src/Tact/Extensions/CollectionExtensions.cs b/src/Tact/Extensions/CollectionExtensions.cs index 1b8f261..99ff141 100644 --- a/src/Tact/Extensions/CollectionExtensions.cs +++ b/src/Tact/Extensions/CollectionExtensions.cs @@ -38,6 +38,12 @@ public static class CollectionExtensions Func> func, int? maxParallelization = null) { + if (collection == null) + throw new ArgumentNullException(nameof(collection)); + + if (func == null) + throw new ArgumentNullException(nameof(func)); + var results = new TOutput[collection.Count]; await collection diff --git a/src/Tact/Extensions/ContainerExtensions.cs b/src/Tact/Extensions/ContainerExtensions.cs index d04ce21..4cfabd1 100644 --- a/src/Tact/Extensions/ContainerExtensions.cs +++ b/src/Tact/Extensions/ContainerExtensions.cs @@ -36,6 +36,9 @@ public static void InitializeByAttribute(this IContainer container, params As public static void InitializeByAttribute(this IContainer container, params Type[] types) where T : IInitializeAttribute { + if (container == null) + throw new ArgumentNullException(nameof(container)); + ILog logger; container.TryResolve(out logger); @@ -96,6 +99,9 @@ public static void RegisterByAttribute(this IContainer container, params Type where TRegister : IRegisterAttribute where TCondition : IRegisterConditionAttribute { + if (container == null) + throw new ArgumentNullException(nameof(container)); + ILog logger; container.TryResolve(out logger); @@ -141,6 +147,9 @@ public static void RegisterPerResolve(this IContainer container, Func factory = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + if (factory == null) { container.RegisterPerResolve(type, type); @@ -160,6 +169,9 @@ public static void RegisterPerResolve(this IContainer container, string key, public static void RegisterPerResolve(this IContainer container, Type type, string key, Func factory = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + if (factory == null) { container.RegisterPerResolve(type, type, key); @@ -180,6 +192,9 @@ public static void RegisterPerResolve(this IContainer container, Type type, stri public static void RegisterPerResolve(this IContainer container, Type fromType, Type toType, string key = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + toType.EnsureSingleCostructor(); var lifetimeManager = new PerResolveLifetimeManager(toType, container); @@ -202,6 +217,9 @@ public static void RegisterPerScope(this IContainer container, Func factory = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + if (factory == null) { container.RegisterPerScope(type, type); @@ -221,6 +239,9 @@ public static void RegisterPerScope(this IContainer container, string key, Fu public static void RegisterPerScope(this IContainer container, Type type, string key, Func factory = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + if (factory == null) { container.RegisterPerScope(type, type, key); @@ -241,6 +262,9 @@ public static void RegisterPerScope(this IContainer container, Type type, string public static void RegisterPerScope(this IContainer container, Type fromType, Type toType, string key = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + toType.EnsureSingleCostructor(); var lifetimeManager = new PerScopeLifetimeManager(toType, container); @@ -256,6 +280,9 @@ public static void RegisterPerScope(this IContainer container, Type fromType, Ty public static void RegisterInstance(this IContainer container, T value, string key = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + var type = typeof(T); if (string.IsNullOrWhiteSpace(key)) @@ -266,6 +293,9 @@ public static void RegisterInstance(this IContainer container, T value, strin public static void RegisterInstance(this IContainer container, Type type, object value, string key = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + var lifetimeManager = new InstanceLifetimeManager(value, container); if (string.IsNullOrWhiteSpace(key)) @@ -287,6 +317,9 @@ public static void RegisterSingleton(this IContainer container, Func factory = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + if (factory == null) { container.RegisterSingleton(type, type); @@ -306,6 +339,9 @@ public static void RegisterSingleton(this IContainer container, string key, F public static void RegisterSingleton(this IContainer container, Type type, string key, Func factory = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + if (factory == null) { container.RegisterSingleton(type, type, key); @@ -326,6 +362,9 @@ public static void RegisterSingleton(this IContainer container, Type type, strin public static void RegisterSingleton(this IContainer container, Type fromType, Type toType, string key = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + toType.EnsureSingleCostructor(); var lifetimeManager = new SingletonLifetimeManager(toType, container); @@ -348,6 +387,9 @@ public static void RegisterTransient(this IContainer container, Func factory = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + if (factory == null) { container.RegisterTransient(type, type); @@ -367,6 +409,9 @@ public static void RegisterTransient(this IContainer container, string key, F public static void RegisterTransient(this IContainer container, Type type, string key, Func factory = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + if (factory == null) { container.RegisterTransient(type, type, key); @@ -387,6 +432,9 @@ public static void RegisterTransient(this IContainer container, Type type, strin public static void RegisterTransient(this IContainer container, Type fromType, Type toType, string key = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + toType.EnsureSingleCostructor(); var lifetimeManager = new TransientLifetimeManager(toType, container); @@ -409,6 +457,9 @@ public static void RegisterTransient(this IContainer container, Type fromType, T public static void RegisterProxy(this IContainer container, Type fromType, Type toType, string fromKey = null, string toKey = null) { + if (container == null) + throw new ArgumentNullException(nameof(container)); + var lifetimeManager = new ProxyLifetimeManager(toType, fromKey, container); if (string.IsNullOrWhiteSpace(toKey)) diff --git a/src/Tact/Extensions/EnumerableExtensions.cs b/src/Tact/Extensions/EnumerableExtensions.cs index a5552a1..5e50d27 100644 --- a/src/Tact/Extensions/EnumerableExtensions.cs +++ b/src/Tact/Extensions/EnumerableExtensions.cs @@ -39,6 +39,12 @@ public static class EnumerableExtensions Func func, int? maxParallelization = null) { + if (enumerable == null) + throw new ArgumentNullException(nameof(enumerable)); + + if (func == null) + throw new ArgumentNullException(nameof(func)); + var exceptions = new ConcurrentQueue(); var maxCount = maxParallelization ?? Environment.ProcessorCount; var tasks = new List(maxCount); diff --git a/src/Tact/Extensions/ReaderWriterLockSlimExtensions.cs b/src/Tact/Extensions/ReaderWriterLockSlimExtensions.cs index 171eeb5..ff927cb 100644 --- a/src/Tact/Extensions/ReaderWriterLockSlimExtensions.cs +++ b/src/Tact/Extensions/ReaderWriterLockSlimExtensions.cs @@ -12,13 +12,21 @@ public static IDisposable UseReadLock(this ReaderWriterLockSlim lockSlim, int mi public static IDisposable UseReadLock(this ReaderWriterLockSlim lockSlim, TimeSpan timeout) { + if (lockSlim == null) + throw new ArgumentNullException(nameof(lockSlim)); + var result = lockSlim.TryEnterReadLock(timeout); - if (!result) throw new TimeoutException("Unable to obtain a read lock"); + if (!result) + throw new TimeoutException("Unable to obtain a read lock"); + return new ReaderWriterLockSlimWrapper(lockSlim, false); } public static IDisposable UseReadLock(this ReaderWriterLockSlim lockSlim) { + if (lockSlim == null) + throw new ArgumentNullException(nameof(lockSlim)); + lockSlim.EnterReadLock(); return new ReaderWriterLockSlimWrapper(lockSlim, false); } @@ -30,13 +38,21 @@ public static IDisposable UseWriteLock(this ReaderWriterLockSlim lockSlim, int m public static IDisposable UseWriteLock(this ReaderWriterLockSlim lockSlim, TimeSpan timeout) { + if (lockSlim == null) + throw new ArgumentNullException(nameof(lockSlim)); + var result = lockSlim.TryEnterWriteLock(timeout); - if (!result) throw new TimeoutException("Unable to obtain a write lock"); + if (!result) + throw new TimeoutException("Unable to obtain a write lock"); + return new ReaderWriterLockSlimWrapper(lockSlim, true); } public static IDisposable UseWriteLock(this ReaderWriterLockSlim lockSlim) { + if (lockSlim == null) + throw new ArgumentNullException(nameof(lockSlim)); + lockSlim.EnterWriteLock(); return new ReaderWriterLockSlimWrapper(lockSlim, true); } diff --git a/src/Tact/Extensions/ResolverExtensions.cs b/src/Tact/Extensions/ResolverExtensions.cs index 7746eee..e69e256 100644 --- a/src/Tact/Extensions/ResolverExtensions.cs +++ b/src/Tact/Extensions/ResolverExtensions.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Linq; using Tact.Practices; @@ -8,12 +9,18 @@ public static class ResolverExtensions { public static T Resolve(this IResolver resolver) { + if (resolver == null) + throw new ArgumentNullException(nameof(resolver)); + var type = typeof(T); return (T) resolver.Resolve(type); } public static bool TryResolve(this IResolver resolver, out T result) { + if (resolver == null) + throw new ArgumentNullException(nameof(resolver)); + var type = typeof(T); object objResult; if (resolver.TryResolve(type, out objResult)) @@ -28,12 +35,18 @@ public static bool TryResolve(this IResolver resolver, out T result) public static T Resolve(this IResolver resolver, string key) { + if (resolver == null) + throw new ArgumentNullException(nameof(resolver)); + var type = typeof(T); return (T) resolver.Resolve(type, key); } public static bool TryResolve(this IResolver resolver, string key, out T result) { + if (resolver == null) + throw new ArgumentNullException(nameof(resolver)); + var type = typeof(T); object objResult; if (resolver.TryResolve(type, key, out objResult)) @@ -48,6 +61,9 @@ public static bool TryResolve(this IResolver resolver, string key, out T resu public static IEnumerable ResolveAll(this IResolver resolver) { + if (resolver == null) + throw new ArgumentNullException(nameof(resolver)); + var type = typeof(T); return resolver.ResolveAll(type).Cast(); } diff --git a/src/Tact/Extensions/SemaphoreSlimExtensions.cs b/src/Tact/Extensions/SemaphoreSlimExtensions.cs index 271d671..93feccd 100644 --- a/src/Tact/Extensions/SemaphoreSlimExtensions.cs +++ b/src/Tact/Extensions/SemaphoreSlimExtensions.cs @@ -13,14 +13,23 @@ public static Task UseAsync(this SemaphoreSlim semaphore, int milli public static async Task UseAsync(this SemaphoreSlim semaphore, TimeSpan timeout, CancellationToken cancelToken = default(CancellationToken)) { + if (semaphore == null) + throw new ArgumentNullException(nameof(semaphore)); + var result = await semaphore.WaitAsync(timeout, cancelToken).ConfigureAwait(false); - if (!result) throw new TimeoutException("Unable to obtain a lock"); + if (!result) + throw new TimeoutException("Unable to obtain a lock"); + return new SemaphoreSlimWrapper(semaphore); } public static async Task UseAsync(this SemaphoreSlim semaphore, CancellationToken cancelToken = default(CancellationToken)) { + if (semaphore == null) + throw new ArgumentNullException(nameof(semaphore)); + await semaphore.WaitAsync(cancelToken).ConfigureAwait(false); + return new SemaphoreSlimWrapper(semaphore); } diff --git a/src/Tact/Extensions/TaskExtensions.cs b/src/Tact/Extensions/TaskExtensions.cs index ac3b432..467fff0 100644 --- a/src/Tact/Extensions/TaskExtensions.cs +++ b/src/Tact/Extensions/TaskExtensions.cs @@ -9,6 +9,7 @@ namespace Tact public static class TaskExtensions { private const string CompleteTaskMessage = "Task must be complete"; + private const string ResultPropertyName = "Result"; private static readonly Type GenericTaskType = typeof(Task<>); @@ -17,6 +18,9 @@ public static class TaskExtensions public static Task IgnoreCancellation(this Task task, CancellationToken token) { + if (task == null) + throw new ArgumentNullException(nameof(task)); + // ReSharper disable once MethodSupportsCancellation return task .ContinueWith(t => @@ -36,6 +40,9 @@ public static Task IgnoreCancellation(this Task task, CancellationToken token) public static Task IgnoreCancellation(this Task task) { + if (task == null) + throw new ArgumentNullException(nameof(task)); + return task .ContinueWith(t => { @@ -53,6 +60,9 @@ public static Task IgnoreCancellation(this Task task) public static Task IgnoreCancellation(this Task task, CancellationToken token) { + if (task == null) + throw new ArgumentNullException(nameof(task)); + // ReSharper disable once MethodSupportsCancellation return task .ContinueWith(t => @@ -72,6 +82,9 @@ public static Task IgnoreCancellation(this Task task, CancellationToken public static Task IgnoreCancellation(this Task task) { + if (task == null) + throw new ArgumentNullException(nameof(task)); + return task .ContinueWith(t => { diff --git a/src/Tact/Extensions/TypeExtensions.cs b/src/Tact/Extensions/TypeExtensions.cs index ff9a1e0..7b969a1 100644 --- a/src/Tact/Extensions/TypeExtensions.cs +++ b/src/Tact/Extensions/TypeExtensions.cs @@ -8,6 +8,7 @@ namespace Tact public static class TypeExtensions { private const string ClassRequired = "TTo must be a class"; + private const string ConstructorRequired = "There must be a single public constructor defined"; private static readonly ConcurrentDictionary> ResultMap = @@ -15,6 +16,9 @@ public static class TypeExtensions public static ConstructorInfo EnsureSingleCostructor(this Type type) { + if (type == null) + throw new ArgumentNullException(nameof(type)); + var result = ResultMap.GetOrAdd(type, t => { var typeInfo = type.GetTypeInfo(); @@ -43,6 +47,7 @@ public static ConstructorInfo EnsureSingleCostructor(this Type type) throw new ArgumentOutOfRangeException(); } } + public static EfficientInvoker GetMethodInvoker(this Type type, string methodName) { return EfficientInvoker.ForMethod(type, methodName); diff --git a/src/Tact/Practices/Base/ContainerBase.cs b/src/Tact/Practices/Base/ContainerBase.cs index a05653f..a4ba52b 100644 --- a/src/Tact/Practices/Base/ContainerBase.cs +++ b/src/Tact/Practices/Base/ContainerBase.cs @@ -22,7 +22,11 @@ public abstract class ContainerBase : IContainer protected ContainerBase(ILog log) { + if (log == null) + throw new ArgumentNullException(nameof(log)); + Log = log; + this.RegisterInstance(log); } @@ -179,6 +183,9 @@ public void Register(Type fromType, string key, ILifetimeManager lifetimeManager public object CreateInstance(Type type, Stack stack) { + if (type == null) + throw new ArgumentNullException(nameof(type)); + var constructor = type.EnsureSingleCostructor(); var parameterTypes = constructor.GetParameters().Select(p => p.ParameterType).ToArray(); var arguments = new object[parameterTypes.Length]; @@ -195,6 +202,12 @@ public object CreateInstance(Type type, Stack stack) private bool TryResolve(Type type, Stack stack, bool canThrow, out object result) { + if (type == null) + throw new ArgumentNullException(nameof(type)); + + if (stack == null) + throw new ArgumentNullException(nameof(stack)); + using (EnterPush(type, stack)) using (_lock.UseReadLock()) { @@ -217,6 +230,15 @@ private bool TryResolve(Type type, Stack stack, bool canThrow, out object private bool TryResolve(Type type, string key, Stack stack, bool canThrow, out object result) { + if (type == null) + throw new ArgumentNullException(nameof(type)); + + if (key == null) + throw new ArgumentNullException(nameof(key)); + + if (stack == null) + throw new ArgumentNullException(nameof(stack)); + using (EnterPush(type, stack)) using (_lock.UseReadLock()) { diff --git a/src/Tact/Practices/Implementation/Container.cs b/src/Tact/Practices/Implementation/Container.cs index 1363f21..f0ee3bf 100644 --- a/src/Tact/Practices/Implementation/Container.cs +++ b/src/Tact/Practices/Implementation/Container.cs @@ -38,7 +38,7 @@ public sealed class Container : ContainerBase public Container(ILog log, IList resolutionHandlers) : base(log) { - ResolutionHandlers = resolutionHandlers; + ResolutionHandlers = resolutionHandlers ?? new List(); } protected override IList ResolutionHandlers { get; } diff --git a/src/Tact/Reflection/EfficientInvoker.cs b/src/Tact/Reflection/EfficientInvoker.cs index 44327b4..f31b691 100644 --- a/src/Tact/Reflection/EfficientInvoker.cs +++ b/src/Tact/Reflection/EfficientInvoker.cs @@ -24,6 +24,9 @@ private EfficientInvoker(Func func) public static EfficientInvoker ForDelegate(Delegate del) { + if (del == null) + throw new ArgumentNullException(nameof(del)); + var type = del.GetType(); return TypeToWrapperMap.GetOrAdd(type, t => { @@ -35,6 +38,12 @@ public static EfficientInvoker ForDelegate(Delegate del) public static EfficientInvoker ForMethod(Type type, string methodName) { + if (type == null) + throw new ArgumentNullException(nameof(type)); + + if (methodName == null) + throw new ArgumentNullException(nameof(methodName)); + var key = new MethodKey(type, methodName); return MethodToWrapperMap.GetOrAdd(key, k => { @@ -46,6 +55,12 @@ public static EfficientInvoker ForMethod(Type type, string methodName) public static EfficientInvoker ForProperty(Type type, string propertyName) { + if (type == null) + throw new ArgumentNullException(nameof(type)); + + if (propertyName == null) + throw new ArgumentNullException(nameof(propertyName)); + var key = new MethodKey(type, propertyName); return MethodToWrapperMap.GetOrAdd(key, k => { diff --git a/src/Tact/Threading/Using.cs b/src/Tact/Threading/Using.cs index bde4937..8761766 100644 --- a/src/Tact/Threading/Using.cs +++ b/src/Tact/Threading/Using.cs @@ -20,6 +20,12 @@ public static class Using Func> func) where TInput : IAsyncDisposable { + if (disposable == null) + throw new ArgumentNullException(nameof(disposable)); + + if (func == null) + throw new ArgumentNullException(nameof(func)); + try { return await func(disposable, cancelToken).ConfigureAwait(false); @@ -44,6 +50,12 @@ public static class Using Func func) where T : IAsyncDisposable { + if (disposable == null) + throw new ArgumentNullException(nameof(disposable)); + + if (func == null) + throw new ArgumentNullException(nameof(func)); + try { await func(disposable, cancelToken).ConfigureAwait(false);