From fb800f09393a311fccf9e185f607e431f04f069f Mon Sep 17 00:00:00 2001 From: Chris Pulman Date: Mon, 1 Jan 2024 00:16:07 +0000 Subject: [PATCH] Fix to handle Cancellation Token Tasks for ReactiveCommand.CreateFromTask (#3704) **What kind of change does this PR introduce?** Fix for #1245 Fix for #2153 Fix for #3450 **What is the current behavior?** ReactiveCommand does not properly support Cancellation tokens properly for CreateFromTask due to an underlying issue in System.Reactive **What is the new behavior?** Fix the issues with the base functions within ReactiveCommand due to an issue with Observable.FromAsync from System.Reactive by using a new ObservableMixins.FromAsyncWithAllNotifications as the new function, this extends Observable.FromAsync handling the error bubbling as required. ObservableMixins.FromAsyncWithAllNotifications can be used to transform a Cancellation Task into an Observable producing the expected cancellation, errors and results. **What might this PR break?** ReactiveCommand.CreateFromTask will now handle exceptions as expected, any existing workarounds could be removed once tested with actual implementation in end users code. **Please check if the PR fulfills these requirements** - [x] Tests for the changes have been added (for bug fixes / features) - [ ] Docs have been added / updated (for bug fixes / features) **Other information**: Co-authored-by: @idg10 - created the base code in #3556 --- .../ReactiveUI.Fody.Analyzer.Tests.csproj | 1 + ...valTests.ReactiveUI.DotNet6_0.verified.txt | 3 + ...valTests.ReactiveUI.DotNet7_0.verified.txt | 3 + ...valTests.ReactiveUI.DotNet8_0.verified.txt | 3 + ...provalTests.ReactiveUI.Net4_7.verified.txt | 3 + .../Commands/ReactiveCommandTest.cs | 385 ++++++++++++++++-- src/ReactiveUI.WinUI/ReactiveUI.WinUI.csproj | 1 + src/ReactiveUI/Mixins/ObservableMixins.cs | 99 ++++- .../ReactiveCommand/ReactiveCommand.cs | 208 +++++++--- 9 files changed, 603 insertions(+), 103 deletions(-) diff --git a/src/ReactiveUI.Fody.Analyzer.Test/ReactiveUI.Fody.Analyzer.Tests.csproj b/src/ReactiveUI.Fody.Analyzer.Test/ReactiveUI.Fody.Analyzer.Tests.csproj index 9e9c02c504..225e1c4d7f 100644 --- a/src/ReactiveUI.Fody.Analyzer.Test/ReactiveUI.Fody.Analyzer.Tests.csproj +++ b/src/ReactiveUI.Fody.Analyzer.Test/ReactiveUI.Fody.Analyzer.Tests.csproj @@ -2,6 +2,7 @@ net472;net6.0 false + $(NoWarn);MSB3243 diff --git a/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.DotNet6_0.verified.txt b/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.DotNet6_0.verified.txt index 418ec591e2..08891ca66e 100644 --- a/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.DotNet6_0.verified.txt +++ b/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.DotNet6_0.verified.txt @@ -695,6 +695,9 @@ namespace ReactiveUI } public class ReactiveCommand : ReactiveUI.ReactiveCommandBase { + protected ReactiveCommand([System.Runtime.CompilerServices.TupleElementNames(new string?[]?[] { + "Result", + "Cancel"})] System.Func, System.Action>>> execute, System.IObservable? canExecute, System.Reactive.Concurrency.IScheduler? outputScheduler) { } protected ReactiveCommand(System.Func> execute, System.IObservable? canExecute, System.Reactive.Concurrency.IScheduler? outputScheduler) { } public override System.IObservable CanExecute { get; } public override System.IObservable IsExecuting { get; } diff --git a/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.DotNet7_0.verified.txt b/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.DotNet7_0.verified.txt index b3fc73c125..a5fc5ef9f0 100644 --- a/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.DotNet7_0.verified.txt +++ b/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.DotNet7_0.verified.txt @@ -695,6 +695,9 @@ namespace ReactiveUI } public class ReactiveCommand : ReactiveUI.ReactiveCommandBase { + protected ReactiveCommand([System.Runtime.CompilerServices.TupleElementNames(new string?[]?[] { + "Result", + "Cancel"})] System.Func, System.Action>>> execute, System.IObservable? canExecute, System.Reactive.Concurrency.IScheduler? outputScheduler) { } protected ReactiveCommand(System.Func> execute, System.IObservable? canExecute, System.Reactive.Concurrency.IScheduler? outputScheduler) { } public override System.IObservable CanExecute { get; } public override System.IObservable IsExecuting { get; } diff --git a/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.DotNet8_0.verified.txt b/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.DotNet8_0.verified.txt index d6688f5094..e38976eadc 100644 --- a/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.DotNet8_0.verified.txt +++ b/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.DotNet8_0.verified.txt @@ -695,6 +695,9 @@ namespace ReactiveUI } public class ReactiveCommand : ReactiveUI.ReactiveCommandBase { + protected ReactiveCommand([System.Runtime.CompilerServices.TupleElementNames(new string?[]?[] { + "Result", + "Cancel"})] System.Func, System.Action>>> execute, System.IObservable? canExecute, System.Reactive.Concurrency.IScheduler? outputScheduler) { } protected ReactiveCommand(System.Func> execute, System.IObservable? canExecute, System.Reactive.Concurrency.IScheduler? outputScheduler) { } public override System.IObservable CanExecute { get; } public override System.IObservable IsExecuting { get; } diff --git a/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.Net4_7.verified.txt b/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.Net4_7.verified.txt index dcec3c8cb9..f8a1b294b7 100644 --- a/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.Net4_7.verified.txt +++ b/src/ReactiveUI.Tests/API/ApiApprovalTests.ReactiveUI.Net4_7.verified.txt @@ -693,6 +693,9 @@ namespace ReactiveUI } public class ReactiveCommand : ReactiveUI.ReactiveCommandBase { + protected ReactiveCommand([System.Runtime.CompilerServices.TupleElementNames(new string?[]?[] { + "Result", + "Cancel"})] System.Func, System.Action>>> execute, System.IObservable? canExecute, System.Reactive.Concurrency.IScheduler? outputScheduler) { } protected ReactiveCommand(System.Func> execute, System.IObservable? canExecute, System.Reactive.Concurrency.IScheduler? outputScheduler) { } public override System.IObservable CanExecute { get; } public override System.IObservable IsExecuting { get; } diff --git a/src/ReactiveUI.Tests/Commands/ReactiveCommandTest.cs b/src/ReactiveUI.Tests/Commands/ReactiveCommandTest.cs index 51e84ed87b..e3fa8380cd 100644 --- a/src/ReactiveUI.Tests/Commands/ReactiveCommandTest.cs +++ b/src/ReactiveUI.Tests/Commands/ReactiveCommandTest.cs @@ -3,10 +3,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for full license information. +using System.Diagnostics; using System.Windows.Input; using DynamicData; - +using FluentAssertions; using Microsoft.Reactive.Testing; using ReactiveUI.Testing; @@ -1163,49 +1164,116 @@ public void SynchronousCommandsFailCorrectly() [Fact] public async Task ReactiveCommandCreateFromTaskHandlesTaskExceptionAsync() + { + using var testSequencer = new TestSequencer(); + var subj = new Subject(); + var isExecuting = false; + Exception? fail = null; + var fixture = ReactiveCommand.CreateFromTask( + async _ => + { + await subj.Take(1); + throw new Exception("break execution"); + }, + outputScheduler: ImmediateScheduler.Instance); + + fixture.IsExecuting.Subscribe(async x => + { + isExecuting = x; + await testSequencer.AdvancePhaseAsync("Executing {false, true, false}"); + }); + fixture.ThrownExceptions.Subscribe(async ex => + { + fail = ex; + await testSequencer.AdvancePhaseAsync("Exception"); + }); + + await testSequencer.AdvancePhaseAsync("Executing {false}"); + Assert.False(isExecuting); + Assert.Null(fail); + + fixture.Execute().Subscribe(); + await testSequencer.AdvancePhaseAsync("Executing {true}"); + Assert.True(isExecuting); + Assert.Null(fail); + + subj.OnNext(Unit.Default); + + // Wait to allow execution to complete + await testSequencer.AdvancePhaseAsync("Executing {false}"); + await testSequencer.AdvancePhaseAsync("Exception"); + Assert.False(isExecuting); + Assert.Equal("break execution", fail?.Message); + testSequencer.Dispose(); + } + + [Fact] + public async Task ReactiveCommandCreateFromTaskThenCancelSetsIsExecutingFalseOnlyAfterCancellationCompleteAsync() + { + using var testSequencer = new TestSequencer(); + var statusTrail = new List<(int Position, string Status)>(); + var position = 0; + + var fixture = ReactiveCommand.CreateFromTask(async (token) => { - using var testSequencer = new TestSequencer(); - var subj = new Subject(); - var isExecuting = false; - Exception? fail = null; - var fixture = ReactiveCommand.CreateFromTask( - async _ => - { - await subj.Take(1); - throw new Exception("break execution"); - }, - outputScheduler: ImmediateScheduler.Instance); - - fixture.IsExecuting.Subscribe(async x => + // Phase 1: command execution has begun. + await testSequencer.AdvancePhaseAsync("Phase 1"); + statusTrail.Add((position++, "started command")); + try { - isExecuting = x; - await testSequencer.AdvancePhaseAsync("Executing {false, true, false}"); - }); - fixture.ThrownExceptions.Subscribe(async ex => + await Task.Delay(10000, token); + } + catch (OperationCanceledException) { - fail = ex; - await testSequencer.AdvancePhaseAsync("Exception"); - }); + // Phase 2: command task has detected cancellation request. + await testSequencer.AdvancePhaseAsync("Phase 2"); - await testSequencer.AdvancePhaseAsync("Executing {false}"); - Assert.False(isExecuting); - Assert.Null(fail); + // Phase 3: test has observed IsExecuting while cancellation is in progress. + await testSequencer.AdvancePhaseAsync("Phase 3"); + throw; + } - fixture.Execute().Subscribe(); - await testSequencer.AdvancePhaseAsync("Executing {true}"); - Assert.True(isExecuting); - Assert.Null(fail); + statusTrail.Add((position++, "finished command")); + }); - subj.OnNext(Unit.Default); + var latestIsExecutingValue = false; + fixture.IsExecuting.Subscribe(isExecuting => + { + statusTrail.Add((position++, $"command executing = {isExecuting}")); + Volatile.Write(ref latestIsExecutingValue, isExecuting); + }); - // Wait to allow execution to complete - await testSequencer.AdvancePhaseAsync("Executing {false}"); - await testSequencer.AdvancePhaseAsync("Exception"); - Assert.False(isExecuting); - Assert.Equal("break execution", fail?.Message); - testSequencer.Dispose(); + var disposable = fixture.Execute().Subscribe(); + + // Phase 1: command execution has begun. + await testSequencer.AdvancePhaseAsync("Phase 1"); + + Volatile.Read(ref latestIsExecutingValue).Should().BeTrue("IsExecuting should be true when execution is underway"); + + disposable.Dispose(); + + // Phase 2: command task has detected cancellation request. + await testSequencer.AdvancePhaseAsync("Phase 2"); + + Volatile.Read(ref latestIsExecutingValue).Should().BeTrue("IsExecuting should remain true while cancellation is in progress"); + + // Phase 3: test has observed IsExecuting while cancellation is in progress. + await testSequencer.AdvancePhaseAsync("Phase 3"); + + var start = Environment.TickCount; + while (unchecked(Environment.TickCount - start) < 1000 && Volatile.Read(ref latestIsExecutingValue)) + { + await Task.Yield(); } + Volatile.Read(ref latestIsExecutingValue).Should().BeFalse("IsExecuting should be false once cancellation completes"); + statusTrail.Should().Equal( + (0, "command executing = False"), + (1, "command executing = True"), + (2, "started command"), + (3, "command executing = False")); + } + [Fact] public async Task ReactiveCommandExecutesFromInvokeCommand() { @@ -1233,20 +1301,259 @@ public async Task ReactiveCommandExecutesFromInvokeCommand() // set var fooVm = new Mocks.FooViewModel(new()); - Assert.Equal(42, fooVm.Foo.Value); // initial value unchanged + fooVm.Foo.Value.Should().Be(42, "initial value unchanged"); // act scheduler.AdvanceByMs(11); // async processing - Assert.Equal(0, fooVm.Foo.Value); // value set to default Setpoint value + fooVm.Foo.Value.Should().Be(0, "value set to default Setpoint value"); fooVm.Setpoint = 123; scheduler.AdvanceByMs(5); // async task processing // assert - Assert.Equal(0, fooVm.Foo.Value); // value unchanged as async task still processing + fooVm.Foo.Value.Should().Be(0, "value unchanged as async task still processing"); scheduler.AdvanceByMs(6); // process async setpoint setting - Assert.Equal(123, fooVm.Foo.Value); + fooVm.Foo.Value.Should().Be(123, "value set to Setpoint value"); return Task.CompletedTask; }); + + [Fact] + public async Task ReactiveCommandCreateFromTaskHandlesExecuteCancellation() + { + using var testSequencer = new TestSequencer(); + var statusTrail = new List<(int Position, string Status)>(); + var position = 0; + var fixture = ReactiveCommand.CreateFromTask( + async cts => + { + await testSequencer.AdvancePhaseAsync("Phase 1"); // #1 + statusTrail.Add((position++, "started command")); + try + { + await Task.Delay(10000, cts); + } + catch (OperationCanceledException) + { + // User Handles cancellation. + statusTrail.Add((position++, "starting cancelling command")); + await testSequencer.AdvancePhaseAsync("Phase 2"); // #2 + + // dummy cleanup + await testSequencer.AdvancePhaseAsync("Phase 3"); // #3 + statusTrail.Add((position++, "finished cancelling command")); + throw; + } + + return Unit.Default; + }, + outputScheduler: ImmediateScheduler.Instance); + + Exception? fail = null; + fixture.ThrownExceptions.Subscribe(ex => fail = ex); + var latestIsExecutingValue = false; + fixture.IsExecuting.Subscribe(isExecuting => + { + statusTrail.Add((position++, $"command executing = {isExecuting}")); + Volatile.Write(ref latestIsExecutingValue, isExecuting); + }); + + fail.Should().BeNull(); + var result = false; + var disposable = fixture.Execute().Subscribe(_ => result = true); + await testSequencer.AdvancePhaseAsync("Phase 1"); // #1 + Volatile.Read(ref latestIsExecutingValue).Should().BeTrue(); + statusTrail.Any(x => x.Status == "started command").Should().BeTrue(); + disposable.Dispose(); + await testSequencer.AdvancePhaseAsync("Phase 2"); // #2 + Volatile.Read(ref latestIsExecutingValue).Should().BeTrue(); + await testSequencer.AdvancePhaseAsync("Phase 3"); // #3 + + var start = Environment.TickCount; + while (unchecked(Environment.TickCount - start) < 1000 && Volatile.Read(ref latestIsExecutingValue)) + { + await Task.Yield(); + } + + // No result expected as cancelled + result.Should().BeFalse(); + statusTrail.Should().Equal( + (0, "command executing = False"), + (1, "command executing = True"), + (2, "started command"), + (3, "starting cancelling command"), + (4, "finished cancelling command"), + (5, "command executing = False")); + (fail as OperationCanceledException).Should().NotBeNull(); + } + + [Fact] + public void ReactiveCommandCreateFromTaskHandlesTaskException() => + new TestScheduler().With( + async scheduler => + { + var subj = new Subject(); + Exception? fail = null; + var fixture = ReactiveCommand.CreateFromTask( + async cts => + { + await subj.Take(1); + throw new Exception("break execution"); + }, + outputScheduler: scheduler); + fixture.IsExecuting.ToObservableChangeSet(ImmediateScheduler.Instance).Bind(out var isExecuting).Subscribe(); + fixture.ThrownExceptions.Subscribe(ex => fail = ex); + isExecuting[0].Should().BeFalse(); + fail.Should().BeNull(); + fixture.Execute().Subscribe(); + + scheduler.AdvanceByMs(10); + isExecuting[1].Should().BeTrue(); + fail.Should().BeNull(); + + scheduler.AdvanceByMs(10); + subj.OnNext(Unit.Default); + + scheduler.AdvanceByMs(10); + isExecuting[2].Should().BeFalse(); + fail?.Message.Should().Be("break execution"); + + // Required for correct async / await task handling + await Task.Delay(0); + }); + + [Fact] + public async Task ReactiveCommandCreateFromTaskHandlesCancellation() + { + using var testSequencer = new TestSequencer(); + var statusTrail = new List<(int Position, string Status)>(); + var position = 0; + var fixture = ReactiveCommand.CreateFromTask( + async cts => + { + statusTrail.Add((position++, "started command")); + await testSequencer.AdvancePhaseAsync("Phase 1"); // #1 + try + { + await Task.Delay(10000, cts); + } + catch (OperationCanceledException) + { + // User Handles cancellation. + statusTrail.Add((position++, "starting cancelling command")); + await testSequencer.AdvancePhaseAsync("Phase 2"); // #2 + + // dummy cleanup + statusTrail.Add((position++, "finished cancelling command")); + await testSequencer.AdvancePhaseAsync("Phase 3"); // #3 + throw; + } + + return Unit.Default; + }, + outputScheduler: ImmediateScheduler.Instance); + + Exception? fail = null; + fixture.ThrownExceptions.Subscribe(ex => fail = ex); + var latestIsExecutingValue = false; + fixture.IsExecuting.Subscribe(isExecuting => + { + statusTrail.Add((position++, $"command executing = {isExecuting}")); + Volatile.Write(ref latestIsExecutingValue, isExecuting); + }); + + fail.Should().BeNull(); + var result = false; + var disposable = fixture.Execute().Subscribe(_ => result = true); + await testSequencer.AdvancePhaseAsync("Phase 1"); // #1 + Volatile.Read(ref latestIsExecutingValue).Should().BeTrue(); + statusTrail.Any(x => x.Status == "started command").Should().BeTrue(); + disposable.Dispose(); + await testSequencer.AdvancePhaseAsync("Phase 2"); // #2 + Volatile.Read(ref latestIsExecutingValue).Should().BeTrue(); + await testSequencer.AdvancePhaseAsync("Phase 3"); // #3 + var start = Environment.TickCount; + while (unchecked(Environment.TickCount - start) < 1000 && Volatile.Read(ref latestIsExecutingValue)) + { + await Task.Yield(); + } + + // No result expected as cancelled + result.Should().BeFalse(); + statusTrail.Should().Equal( + (0, "command executing = False"), + (1, "command executing = True"), + (2, "started command"), + (3, "starting cancelling command"), + (4, "finished cancelling command"), + (5, "command executing = False")); + (fail as OperationCanceledException).Should().NotBeNull(); + } + + [Fact] + public async Task ReactiveCommandCreateFromTaskHandlesCompletion() + { + using var testSequencer = new TestSequencer(); + var statusTrail = new List<(int Position, string Status)>(); + var position = 0; + var fixture = ReactiveCommand.CreateFromTask( + async cts => + { + await testSequencer.AdvancePhaseAsync("Phase 1"); // #1 + statusTrail.Add((position++, "started command")); + try + { + await Task.Delay(1000, cts); + } + catch (OperationCanceledException) + { + // User Handles cancellation. + statusTrail.Add((position++, "starting cancelling command")); + + // dummy cleanup + await Task.Delay(5000, CancellationToken.None); + statusTrail.Add((position++, "finished cancelling command")); + throw; + } + + statusTrail.Add((position++, "finished command")); + await testSequencer.AdvancePhaseAsync("Phase 2"); // #2 + return Unit.Default; + }, + outputScheduler: ImmediateScheduler.Instance); + + Exception? fail = null; + fixture.ThrownExceptions.Subscribe(ex => fail = ex); + var latestIsExecutingValue = false; + fixture.IsExecuting.Subscribe(isExecuting => + { + statusTrail.Add((position++, $"command executing = {isExecuting}")); + Volatile.Write(ref latestIsExecutingValue, isExecuting); + }); + + fail.Should().BeNull(); + var result = false; + fixture.Execute().Subscribe(_ => result = true); + await testSequencer.AdvancePhaseAsync("Phase 1"); // #1 + Volatile.Read(ref latestIsExecutingValue).Should().BeTrue(); + await testSequencer.AdvancePhaseAsync("Phase 2"); // #2 + + var start = Environment.TickCount; + while (unchecked(Environment.TickCount - start) < 1000 && Volatile.Read(ref latestIsExecutingValue)) + { + await Task.Yield(); + } + + result.Should().BeTrue(); + statusTrail.Should().Equal( + (0, "command executing = False"), + (1, "command executing = True"), + (2, "started command"), + (3, "finished command"), + (4, "command executing = False")); + fail.Should().BeNull(); + + // Check execution completed + Volatile.Read(ref latestIsExecutingValue).Should().BeFalse(); + } } diff --git a/src/ReactiveUI.WinUI/ReactiveUI.WinUI.csproj b/src/ReactiveUI.WinUI/ReactiveUI.WinUI.csproj index 295ea854fb..5a6c0c7b9d 100644 --- a/src/ReactiveUI.WinUI/ReactiveUI.WinUI.csproj +++ b/src/ReactiveUI.WinUI/ReactiveUI.WinUI.csproj @@ -6,6 +6,7 @@ ReactiveUI.WinUI.Desktop mvvm;reactiveui;rx;reactive extensions;observable;LINQ;events;winui true + $(NoWarn);NETSDK1206 IS_WINUI;WINUI_TARGET; win-x64;win-x86;win-arm64 10.0.19041.0 diff --git a/src/ReactiveUI/Mixins/ObservableMixins.cs b/src/ReactiveUI/Mixins/ObservableMixins.cs index abb299343f..38aa8cd41f 100644 --- a/src/ReactiveUI/Mixins/ObservableMixins.cs +++ b/src/ReactiveUI/Mixins/ObservableMixins.cs @@ -21,4 +21,101 @@ public static class ObservableMixins observable .Where(x => x is not null) .Select(x => x!); -} \ No newline at end of file + + /// + /// Converts an asynchronous action into an observable sequence. Each subscription + /// to the resulting sequence causes the action to be started. The CancellationToken + /// passed to the asynchronous action is tied to the observable sequence's subscription + /// that triggered the action's invocation and can be used for best-effort cancellation. + /// + /// Asynchronous action to convert. + /// An observable sequence exposing a Unit value upon completion of the action, or an exception. + internal static IObservable<(IObservable Result, Action Cancel)> FromAsyncWithAllNotifications( + Func actionAsync) => Observable.Defer( + () => + { + var cts = new CancellationTokenSource(); + var result = Observable.FromAsync( + async ctsBase => + { + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cts.Token, ctsBase); + await actionAsync(linkedCts.Token); + }); + return Observable.Return<(IObservable Result, Action Cancel)>((result, () => cts.Cancel())); + }); + + /// + /// Converts an asynchronous action into an observable sequence. Each subscription + /// to the resulting sequence causes the action to be started. The CancellationToken + /// passed to the asynchronous action is tied to the observable sequence's subscription + /// that triggered the action's invocation and can be used for best-effort cancellation. + /// + /// The type of the parameter. + /// Asynchronous action to convert. + /// The parameter. + /// An observable sequence exposing a Unit value upon completion of the action, or an exception. + internal static IObservable<(IObservable Result, Action Cancel)> FromAsyncWithAllNotifications( + Func actionAsync, TParam param) => Observable.Defer( + () => + { + var cts = new CancellationTokenSource(); + var result = Observable.FromAsync( + async ctsBase => + { + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cts.Token, ctsBase); + await actionAsync(param, linkedCts.Token); + }); + + return Observable.Return<(IObservable Result, Action Cancel)>((result, () => cts.Cancel())); + }); + + /// + /// Converts an asynchronous action into an observable sequence. Each subscription + /// to the resulting sequence causes the action to be started. The CancellationToken + /// passed to the asynchronous action is tied to the observable sequence's subscription + /// that triggered the action's invocation and can be used for best-effort cancellation. + /// + /// The type of the result. + /// Asynchronous action to convert. + /// An observable sequence exposing a Unit value upon completion of the action, or an exception. + internal static IObservable<(IObservable Result, Action Cancel)> FromAsyncWithAllNotifications( + Func> actionAsync) => Observable.Defer( + () => + { + var cts = new CancellationTokenSource(); + var result = Observable.FromAsync( + async ctsBase => + { + var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cts.Token, ctsBase); + return await actionAsync(linkedCts.Token); + }); + + return Observable.Return<(IObservable Result, Action Cancel)>((result, () => cts.Cancel())); + }); + + /// + /// Converts an asynchronous action into an observable sequence. Each subscription + /// to the resulting sequence causes the action to be started. The CancellationToken + /// passed to the asynchronous action is tied to the observable sequence's subscription + /// that triggered the action's invocation and can be used for best-effort cancellation. + /// + /// The type of the parameter. + /// The type of the result. + /// Asynchronous action to convert. + /// The parameter. + /// An observable sequence exposing a Unit value upon completion of the action, or an exception. + internal static IObservable<(IObservable Result, Action Cancel)> FromAsyncWithAllNotifications( + Func> actionAsync, TParam param) => Observable.Defer( + () => + { + var cts = new CancellationTokenSource(); + var result = Observable.FromAsync( + async cancelFromRx => + { + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cts.Token, cancelFromRx); + return await actionAsync(param, linkedCts.Token); + }); + + return Observable.Return<(IObservable Result, Action Cancel)>((result, () => cts.Cancel())); + }); +} diff --git a/src/ReactiveUI/ReactiveCommand/ReactiveCommand.cs b/src/ReactiveUI/ReactiveCommand/ReactiveCommand.cs index db3180c943..f0dfc4bbfb 100644 --- a/src/ReactiveUI/ReactiveCommand/ReactiveCommand.cs +++ b/src/ReactiveUI/ReactiveCommand/ReactiveCommand.cs @@ -468,7 +468,7 @@ public static class ReactiveCommand throw new ArgumentNullException(nameof(execute)); } - return CreateFromObservable(() => Observable.FromAsync(execute), canExecute, outputScheduler); + return CreateFromObservableCancellable(() => ObservableMixins.FromAsyncWithAllNotifications(execute), canExecute, outputScheduler); } /// @@ -524,7 +524,7 @@ public static class ReactiveCommand throw new ArgumentNullException(nameof(execute)); } - return CreateFromObservable(() => Observable.FromAsync(execute), canExecute, outputScheduler); + return CreateFromObservableCancellable(() => ObservableMixins.FromAsyncWithAllNotifications(execute), canExecute, outputScheduler); } /// @@ -595,8 +595,8 @@ public static class ReactiveCommand throw new ArgumentNullException(nameof(execute)); } - return CreateFromObservable( - param => Observable.FromAsync(ct => execute(param, ct)), + return CreateFromObservableCancellable( + param => ObservableMixins.FromAsyncWithAllNotifications(ct => execute(param, ct)), canExecute, outputScheduler); } @@ -663,11 +663,76 @@ public static class ReactiveCommand throw new ArgumentNullException(nameof(execute)); } - return CreateFromObservable( - param => Observable.FromAsync(ct => execute(param, ct)), + return CreateFromObservableCancellable( + param => ObservableMixins.FromAsyncWithAllNotifications(ct => execute(param, ct)), canExecute, outputScheduler); } + + /// + /// Creates a parameterless with asynchronous execution logic. + /// + /// The type of the parameter. + /// The type of the command's result. + /// Provides an observable representing the command's asynchronous execution logic. + /// An optional observable that dictates the availability of the command for execution. + /// An optional scheduler that is used to surface events. Defaults to RxApp.MainThreadScheduler. + /// + /// The ReactiveCommand instance. + /// + /// execute. + internal static ReactiveCommand CreateFromObservableCancellable( + Func Result, Action Cancel)>> execute, + IObservable? canExecute = null, + IScheduler? outputScheduler = null) + { + if (execute is null) + { + throw new ArgumentNullException(nameof(execute)); + } + + return new ReactiveCommand( + _ => execute(), + canExecute, + outputScheduler); + } + + /// + /// Creates a with asynchronous execution logic that takes a parameter of type . + /// + /// + /// Provides an observable representing the command's asynchronous execution logic. + /// + /// + /// An optional observable that dictates the availability of the command for execution. + /// + /// + /// An optional scheduler that is used to surface events. Defaults to RxApp.MainThreadScheduler. + /// + /// + /// The ReactiveCommand instance. + /// + /// + /// The type of the parameter passed through to command execution. + /// + /// + /// The type of the command's result. + /// + internal static ReactiveCommand CreateFromObservableCancellable( + Func Result, Action Cancel)>> execute, + IObservable? canExecute = null, + IScheduler? outputScheduler = null) + { + if (execute is null) + { + throw new ArgumentNullException(nameof(execute)); + } + + return new ReactiveCommand( + execute, + canExecute, + outputScheduler); + } } /// @@ -696,7 +761,7 @@ public class ReactiveCommand : ReactiveCommandBase _exceptions; - private readonly Func> _execute; + private readonly Func Result, Action Cancel)>> _execute; [SuppressMessage("Design", "CA2213: Dispose member", Justification = "Internal use only")] private readonly Subject _executionInfo; private readonly IObservable _isExecuting; @@ -705,7 +770,9 @@ public class ReactiveCommand : ReactiveCommandBase _synchronizedExecutionInfo; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class for work + /// that signals cancellation through a separate callback (as opposed to cancelling by + /// unsubscribing). /// /// The Func to perform when the command is executed. /// A observable which has a value if the command can execute. @@ -715,7 +782,7 @@ public class ReactiveCommand : ReactiveCommandBase /// Thrown if any dependent parameters are null. protected internal ReactiveCommand( - Func> execute, + Func Result, Action Cancel)>> execute, IObservable? canExecute, IScheduler? outputScheduler) { @@ -728,11 +795,11 @@ public class ReactiveCommand : ReactiveCommandBase next.Demarcation switch - { - ExecutionDemarcation.Begin => acc + 1, - ExecutionDemarcation.End => acc > 0 ? acc - 1 : acc = 0, - _ => acc - }) + { + ExecutionDemarcation.Begin => acc + 1, + ExecutionDemarcation.End => acc > 0 ? acc - 1 : acc = 0, + _ => acc + }) .Select(inFlightCount => inFlightCount > 0) .StartWith(false) .DistinctUntilChanged() @@ -758,6 +825,38 @@ public class ReactiveCommand : ReactiveCommandBase + /// Initializes a new instance of the class. + /// + /// The Func to perform when the command is executed. + /// A observable which has a value if the command can execute. + /// The scheduler where to send output after the main execution. + /// + /// execute. + /// + /// Thrown if any dependent parameters are null. + protected internal ReactiveCommand( + Func> execute, + IObservable? canExecute, + IScheduler? outputScheduler) + : this( + p => + { + var resultObservable = execute(p); + return Observable.Defer( + () => + { + var cancelationSubject = new Subject(); + void Cancel() => cancelationSubject.OnNext(Unit.Default); + return Observable + .Return((resultObservable.TakeUntil(cancelationSubject), (Action)Cancel)); + }); + }, + canExecute, + outputScheduler) + { + } + private enum ExecutionDemarcation { Begin, @@ -783,22 +882,31 @@ public override IObservable Execute(TParam parameter) { return Observable.Defer( () => - { - _synchronizedExecutionInfo.OnNext(ExecutionInfo.CreateBegin()); - - return Observable.Empty; - }) - .Concat(_execute(parameter)) - .Do(result => _synchronizedExecutionInfo.OnNext(ExecutionInfo.CreateResult(result))) - .Catch( - ex => - { - _exceptions.OnNext(ex); - return Observable.Throw(ex); - }) - .Finally(() => _synchronizedExecutionInfo.OnNext(ExecutionInfo.CreateEnd())) - .PublishLast() - .RefCount(); + { + _synchronizedExecutionInfo.OnNext(ExecutionInfo.CreateBegin()); + return Observable<(IObservable, Action)>.Empty; + }) + .Concat(_execute(parameter)) + .SelectMany(sourceAndCancellation => + { + var (sourceObservable, cancelCallback) = sourceAndCancellation; + var sharedSource = sourceObservable.Publish().RefCount(2); + + // This is the subscription that survives for however long sourceObservable takes to complete (or fail). + sharedSource + .Do(result => _synchronizedExecutionInfo.OnNext(ExecutionInfo.CreateResult(result))) + .Catch( + ex => + { + _exceptions.OnNext(ex); + return Observable.Empty(); + }) + .Finally(() => _synchronizedExecutionInfo.OnNext(ExecutionInfo.CreateEnd())) + .Subscribe(); + + // TODO: Check if it is a problem that we always cancel, even on normal completion!!! + return sharedSource.Finally(() => cancelCallback()); + }); } catch (Exception ex) { @@ -809,39 +917,11 @@ public override IObservable Execute(TParam parameter) } /// - public override IObservable Execute() - { - try - { - return Observable.Defer( - () => - { - _synchronizedExecutionInfo.OnNext(ExecutionInfo.CreateBegin()); - - return Observable.Empty; - }) - .Concat(_execute(default!)) - .Do(result => _synchronizedExecutionInfo.OnNext(ExecutionInfo.CreateResult(result))) - .Catch( - ex => - { - _exceptions.OnNext(ex); - return Observable.Throw(ex); - }) - .Finally(() => _synchronizedExecutionInfo.OnNext(ExecutionInfo.CreateEnd())) - .PublishLast() - .RefCount(); - } - catch (Exception ex) - { - _synchronizedExecutionInfo.OnNext(ExecutionInfo.CreateEnd()); - _exceptions.OnNext(ex); - return Observable.Throw(ex); - } - } + public override IObservable Execute() => Execute(default!); /// - public override IDisposable Subscribe(IObserver observer) => _results.Subscribe(observer); + public override IDisposable Subscribe(IObserver observer) => + _results.Subscribe(observer); /// protected override void Dispose(bool disposing) @@ -866,11 +946,13 @@ private ExecutionInfo(ExecutionDemarcation demarcation, TResult result) public TResult Result { get; } - public static ExecutionInfo CreateBegin() => new(ExecutionDemarcation.Begin, default!); + public static ExecutionInfo CreateBegin() => + new(ExecutionDemarcation.Begin, default!); public static ExecutionInfo CreateResult(TResult result) => new(ExecutionDemarcation.Result, result); - public static ExecutionInfo CreateEnd() => new(ExecutionDemarcation.End, default!); + public static ExecutionInfo CreateEnd() => + new(ExecutionDemarcation.End, default!); } }