Skip to content

Commit

Permalink
Forwards cancellation token to recipients. (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
tommasobertoni committed Feb 6, 2021
2 parents 20f481b + 1d2ee9b commit 15e1ae2
Show file tree
Hide file tree
Showing 32 changed files with 721 additions and 235 deletions.
56 changes: 46 additions & 10 deletions src/NScatterGather/Aggregator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,27 @@ namespace NScatterGather
{
public class Aggregator
{
public TimeSpan CancellationWindow
{
get { return _cancellationWindow; }
set
{
if (value.IsNegative())
throw new ArgumentException($"{nameof(CancellationToken)} can't be negative.");

_cancellationWindow = value;
}
}

public bool AllowCancellationWindowOnAllRecipients { get; set; } = false;

private TimeSpan _cancellationWindow;
private readonly IRecipientsScope _scope;

public Aggregator(RecipientsCollection collection)
{
_scope = collection.CreateScope();
CancellationWindow = TimeSpan.FromMilliseconds(100);
}

public async Task<AggregatedResponse<object?>> Send(
Expand Down Expand Up @@ -46,19 +62,19 @@ public Aggregator(RecipientsCollection collection)
object request,
CancellationToken cancellationToken)
{
var runners = recipients.SelectMany(recipient => recipient.Accept(request)).ToArray();
var runners = recipients.SelectMany(recipient => recipient.Accept(request, cancellationToken)).ToArray();

var tasks = runners
.Select(runner => runner.Start())
.ToArray();

var allTasksCompleted = Task.WhenAll(tasks);

if (allTasksCompleted.IsCompletedSuccessfully())
return runners;
using var cancellation = new CancellationTokenTaskSource<object?[]>(cancellationToken);
await Task.WhenAny(allTasksCompleted, cancellation.Task).ConfigureAwait(false);

using (var cancellation = new CancellationTokenTaskSource<object?[]>(cancellationToken))
await Task.WhenAny(allTasksCompleted, cancellation.Task).ConfigureAwait(false);
if (cancellationToken.IsCancellationRequested)
await WaitForLatecomers(runners).ConfigureAwait(false);

return runners;
}
Expand Down Expand Up @@ -90,21 +106,41 @@ public Aggregator(RecipientsCollection collection)
object request,
CancellationToken cancellationToken)
{
var runners = recipients.SelectMany(recipient => recipient.ReplyWith<TResponse>(request)).ToArray();
var runners = recipients.SelectMany(recipient => recipient.ReplyWith<TResponse>(request, cancellationToken)).ToArray();

var tasks = runners
.Select(runner => runner.Start())
.ToArray();

var allTasksCompleted = Task.WhenAll(tasks);

if (allTasksCompleted.IsCompletedSuccessfully())
return runners;
using var cancellation = new CancellationTokenTaskSource<TResponse[]>(cancellationToken);
await Task.WhenAny(allTasksCompleted, cancellation.Task).ConfigureAwait(false);

using (var cancellation = new CancellationTokenTaskSource<TResponse[]>(cancellationToken))
await Task.WhenAny(allTasksCompleted, cancellation.Task).ConfigureAwait(false);
if (cancellationToken.IsCancellationRequested)
await WaitForLatecomers(runners).ConfigureAwait(false);

return runners;
}

private async Task WaitForLatecomers<TResponse>(IReadOnlyList<RecipientRunner<TResponse>> runners)
{
var incompleteRunners = runners.Where(r => !r.Task.IsCompleted);

if (!AllowCancellationWindowOnAllRecipients)
incompleteRunners = incompleteRunners.Where(r => r.AcceptedCancellationToken);

var completionTasks = incompleteRunners.Select(CreateCompletionTask).ToArray();

if (!completionTasks.Any()) return;

await Task.WhenAll(completionTasks).ConfigureAwait(false);

async Task CreateCompletionTask(RecipientRunner<TResponse> runner)
{
var wait = Task.Delay(CancellationWindow);
await Task.WhenAny(runner.Task, wait).ConfigureAwait(false);
}
}
}
}
33 changes: 31 additions & 2 deletions src/NScatterGather/Inspection/MethodAnalyzer.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;

namespace NScatterGather.Inspection
Expand All @@ -10,6 +12,7 @@ internal class MethodAnalyzer
public bool IsMatch(
MethodInspection inspection,
Type requestType,
bool allowCancellationTokenParameter,
[NotNullWhen(true)] out MethodInfo? match)
{
match = null;
Expand All @@ -22,9 +25,18 @@ internal class MethodAnalyzer
return true;
}

if (parameters.Count != 1)
if (parameters.Count > 2)
return false;

if (parameters.Count == 2)
{
if (!allowCancellationTokenParameter)
return false;

if (!AcceptsCancellationToken(inspection))
return false;
}

var theParameter = parameters[0];

if (IsSameOrCompatible(baseType: theParameter.ParameterType, requestType))
Expand All @@ -40,11 +52,12 @@ internal class MethodAnalyzer
MethodInspection inspection,
Type requestType,
Type responseType,
bool allowCancellationTokenParameter,
[NotNullWhen(true)] out MethodInfo? match)
{
match = null;

if (!IsMatch(inspection, requestType, out _))
if (!IsMatch(inspection, requestType, allowCancellationTokenParameter, out _))
return false;

// Method has the correct input parameter.
Expand Down Expand Up @@ -74,6 +87,22 @@ internal class MethodAnalyzer
return false;
}

public bool AcceptsCancellationToken(MethodInspection inspection) =>
AcceptCancellationToken(inspection.Parameters);

public bool AcceptsCancellationToken(MethodInfo method) =>
AcceptCancellationToken(method.GetParameters());

private bool AcceptCancellationToken(IReadOnlyList<ParameterInfo> parameters)
{
if (parameters.Count != 2)
return false;

var theCancellationTokenParameter = parameters[1];

return theCancellationTokenParameter.ParameterType == typeof(CancellationToken);
}

private bool IsSameOrCompatible(Type baseType, Type otherType)
{
if (baseType == otherType)
Expand Down
26 changes: 20 additions & 6 deletions src/NScatterGather/Inspection/TypeInspector.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection;

Expand Down Expand Up @@ -69,7 +68,11 @@ private MethodMatchEvaluation FindOrEvaluate(Type requestType)

var matches = ListMatchingMethods(requestType);

var evaluation = new MethodMatchEvaluation(requestType, responseType: null, matches);
var evaluation = new MethodMatchEvaluation(
requestType,
responseType: null,
matches);

_evaluationsCache.TryAdd(evaluation);

return evaluation;
Expand All @@ -78,9 +81,14 @@ private MethodMatchEvaluation FindOrEvaluate(Type requestType)
private IReadOnlyList<MethodInfo> ListMatchingMethods(Type requestType)
{
return _methodInspections
.Select(i =>
.Select(inspection =>
{
var isMatch = _methodAnalyzer.IsMatch(i, requestType, out var match);
var isMatch = _methodAnalyzer.IsMatch(
inspection,
requestType,
allowCancellationTokenParameter: true,
out var match);
return (isMatch, match);
})
.Where(x => x.isMatch)
Expand Down Expand Up @@ -137,9 +145,15 @@ private MethodMatchEvaluation FindOrEvaluate(Type requestType, Type responseType
private IReadOnlyList<MethodInfo> ListMatchingMethods(Type requestType, Type responseType)
{
return _methodInspections
.Select(i =>
.Select(inspection =>
{
var isMatch = _methodAnalyzer.IsMatch(i, requestType, responseType, out var match);
var isMatch = _methodAnalyzer.IsMatch(
inspection,
requestType,
responseType,
allowCancellationTokenParameter: true,
out var match);
return (isMatch, match);
})
.Where(x => x.isMatch)
Expand Down
10 changes: 10 additions & 0 deletions src/NScatterGather/Internals/ValidationExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using System;

namespace NScatterGather
{
internal static class ValidationExtensions
{
public static bool IsNegative(this TimeSpan timeSpan) =>
timeSpan.Ticks < 0;
}
}
15 changes: 15 additions & 0 deletions src/NScatterGather/Recipients/Collection/RecipientsCollection.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using NScatterGather.Inspection;
using NScatterGather.Recipients;
using NScatterGather.Recipients.Collection.Scope;
Expand Down Expand Up @@ -121,6 +122,20 @@ static bool HasADefaultConstructor<T>()
return delegateRecipient.Id;
}

public Guid Add<TRequest, TResponse>(
Func<TRequest, CancellationToken, TResponse> @delegate,
string? name = null)
{
if (@delegate is null)
throw new ArgumentNullException(nameof(@delegate));

var delegateRecipient = DelegateRecipient.Create(@delegate, name);

_recipients.Add(delegateRecipient);

return delegateRecipient.Id;
}

internal IRecipientsScope CreateScope()
{
var scope = new RecipientsScope();
Expand Down
55 changes: 51 additions & 4 deletions src/NScatterGather/Recipients/DelegateRecipient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Threading;
using NScatterGather.Recipients.Descriptors;
using NScatterGather.Recipients.Invokers;

Expand All @@ -10,6 +11,8 @@ internal class DelegateRecipient : Recipient

public Type ResponseType { get; }

public bool AcceptsCancellationToken { get; }

public static DelegateRecipient Create<TRequest, TResponse>(
Func<TRequest, TResponse> @delegate,
string? name)
Expand All @@ -24,29 +27,73 @@ internal class DelegateRecipient : Recipient
return response;
}

var descriptor = new DelegateRecipientDescriptor(typeof(TRequest), typeof(TResponse));
var acceptsCancellationToken = false;

var descriptor = new DelegateRecipientDescriptor(typeof(TRequest), typeof(TResponse), acceptsCancellationToken);
var invoker = new DelegateRecipientInvoker(descriptor, delegateInvoker);

return new DelegateRecipient(descriptor.RequestType, descriptor.ResponseType, descriptor, invoker, name);
return new DelegateRecipient(
descriptor.RequestType,
descriptor.ResponseType,
descriptor,
invoker,
name,
acceptsCancellationToken);
}

public static DelegateRecipient Create<TRequest, TResponse>(
Func<TRequest, CancellationToken, TResponse> @delegate,
string? name)
{
if (@delegate is null)
throw new ArgumentNullException(nameof(@delegate));

object? delegateInvoker(object request, CancellationToken cancellationToken)
{
var typedRequest = (TRequest)request;
TResponse response = @delegate(typedRequest, cancellationToken);
return response;
}

var acceptsCancellationToken = true;

var descriptor = new DelegateRecipientDescriptor(typeof(TRequest), typeof(TResponse), acceptsCancellationToken);
var invoker = new DelegateRecipientInvoker(descriptor, delegateInvoker);

return new DelegateRecipient(
descriptor.RequestType,
descriptor.ResponseType,
descriptor,
invoker,
name,
acceptsCancellationToken);
}

protected DelegateRecipient(
Type requestType,
Type responseType,
IRecipientDescriptor descriptor,
IRecipientInvoker invoker,
string? name)
string? name,
bool acceptsCancellationToken)
: base(descriptor, invoker, name, Lifetime.Singleton, CollisionStrategy.IgnoreRecipient)
{
RequestType = requestType;
ResponseType = responseType;
AcceptsCancellationToken = acceptsCancellationToken;
}

#if NETSTANDARD2_0 || NETSTANDARD2_1
public override Recipient Clone() =>
#else
public override DelegateRecipient Clone() =>
#endif
new DelegateRecipient(RequestType, ResponseType, _descriptor, _invoker, Name);
new DelegateRecipient(
RequestType,
ResponseType,
_descriptor,
_invoker,
Name,
AcceptsCancellationToken);
}
}
Loading

0 comments on commit 15e1ae2

Please sign in to comment.