Skip to content

Commit

Permalink
Protocol: fix emitOnCapturedContext. (#192)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds authored Apr 13, 2023
1 parent 056ffa2 commit ab86c36
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 43 deletions.
88 changes: 51 additions & 37 deletions src/Tmds.DBus.Protocol/DBusConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ public void Invoke(Exception? exception, Message message)
private string? _localName;
private Message? _currentMessage;
private Observer? _currentObserver;
private SynchronizationContext? _currentSynchronizationContext;
private TaskCompletionSource<Exception?>? _disconnectedTcs;

public string? UniqueName => _localName;
Expand Down Expand Up @@ -398,16 +399,27 @@ private void EmitOnSynchronizationContextHelper(Observer observer, Synchronizati
{
_currentMessage = message;
_currentObserver = observer;
_currentSynchronizationContext = synchronizationContext;

#pragma warning disable VSTHRD001 // Await JoinableTaskFactory.SwitchToMainThreadAsync() to switch to the UI thread instead of APIs that can deadlock or require specifying a priority.
// note: Send blocks the current thread until the SynchronizationContext ran the delegate.
synchronizationContext.Send(static o => {
DBusConnection conn = (DBusConnection)o;
conn._currentObserver!.Emit(conn._currentMessage!);
SynchronizationContext previousContext = SynchronizationContext.Current;
try
{
DBusConnection conn = (DBusConnection)o;
SynchronizationContext.SetSynchronizationContext(conn._currentSynchronizationContext);
conn._currentObserver!.InvokeHandler(conn._currentMessage!);
}
finally
{
SynchronizationContext.SetSynchronizationContext(previousContext);
}
}, this);

_currentMessage = null;
_currentObserver = null;
_currentSynchronizationContext = null;
}

public void AddMethodHandlers(IList<IMethodHandler> methodHandlers)
Expand Down Expand Up @@ -472,7 +484,7 @@ public void Dispose()
{
foreach (var observer in matchMaker.Observers)
{
observer.Disconnect(new DisconnectedException(disconnectReason));
observer.Dispose(new DisconnectedException(disconnectReason), removeObserver: false);
}
}
_matchMakers.Clear();
Expand Down Expand Up @@ -690,7 +702,7 @@ private async ValueTask<IDisposable> AddMatchAsync(SynchronizationContext? synch
}
catch
{
observer.Dispose(invokeHandler: false);
observer.Dispose(exception: null);

throw;
}
Expand Down Expand Up @@ -734,9 +746,9 @@ public Observer(SynchronizationContext? synchronizationContext, MatchMaker match
Subscribes = subscribes;
}

public void Dispose() => Dispose(invokeHandler: true);
public void Dispose() => Dispose(s_objectDisposedException);

public void Dispose(bool invokeHandler)
public void Dispose(Exception? exception, bool removeObserver = true)
{
lock (_gate)
{
Expand All @@ -747,67 +759,69 @@ public void Dispose(bool invokeHandler)
_disposed = true;
}

if (invokeHandler)
if (exception is not null)
{
_messageHandler.Invoke(s_objectDisposedException, null!);
Emit(exception);
}

_matchMaker.Connection.RemoveObserver(_matchMaker, this);
if (removeObserver)
{
_matchMaker.Connection.RemoveObserver(_matchMaker, this);
}
}

public void EmitOnSynchronizationContext(Message message)
public void Emit(Message message)
{
if (_synchronizationContext is null)
{
Emit(message);
InvokeHandler(message);
}
else
{
_matchMaker.Connection.EmitOnSynchronizationContextHelper(this, _synchronizationContext, message);
}
}

public void Emit(Message message)
private void Emit(Exception exception)
{
if (Subscribes && !_matchMaker.HasSubscribed)
if (_synchronizationContext is null ||
SynchronizationContext.Current == _synchronizationContext)
{
return;
_messageHandler.Invoke(exception, null!);
}

lock (_gate)
else
{
if (_disposed)
{
return;
}

_messageHandler.Invoke(null, message);
_synchronizationContext.Send(
delegate {
SynchronizationContext previousContext = SynchronizationContext.Current;
try
{
SynchronizationContext.SetSynchronizationContext(_synchronizationContext);
_messageHandler.Invoke(exception, null!);
}
finally
{
SynchronizationContext.SetSynchronizationContext(previousContext);
}
}, null);
}
}

internal void Disconnect(DisconnectedException disconnectedException)
internal void InvokeHandler(Message message)
{
if (Subscribes && !_matchMaker.HasSubscribed)
{
return;
}

lock (_gate)
{
if (_disposed)
{
return;
}
_disposed = true;
}

if (_synchronizationContext is null)
{
InvokeHandler(disconnectedException);
}
else
{
_synchronizationContext.Send(delegate { InvokeHandler(disconnectedException); }, null);
}

void InvokeHandler(DisconnectedException disconnectedException)
{
_messageHandler.Invoke(disconnectedException, null!);
_messageHandler.Invoke(null, message);
}
}
}
Expand Down
45 changes: 39 additions & 6 deletions test/Tmds.DBus.Protocol.Tests/ConnectionTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

Expand All @@ -18,19 +19,48 @@ public async Task MethodAsync()
Assert.Equal("hello world", reply);
}

[Fact]
public async Task SignalAsync()
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task SignalAsync(bool setSynchronizationContext)
{
if (setSynchronizationContext)
{
SynchronizationContext.SetSynchronizationContext(new MySynchronizationContext());
}
SynchronizationContext? expectedSynchronizationContext = SynchronizationContext.Current;

var connections = PairedConnection.CreatePair();
using var conn1 = connections.Item1;
using var conn2 = connections.Item2;

var proxy = new HelloWorld(conn1, "servicename");
var tcs = new TaskCompletionSource<string>();
await proxy.WatchHelloWorldAsync((ex, msg) => tcs.SetResult(msg));
var msgTcs = new TaskCompletionSource<(string, SynchronizationContext?)>();
var exTcs = new TaskCompletionSource<(Exception?, SynchronizationContext?)>();

await proxy.WatchHelloWorldAsync((ex, msg) =>
{
if (msg is not null)
{
msgTcs.SetResult((msg, SynchronizationContext.Current));
}
else
{
exTcs.SetResult((ex, SynchronizationContext.Current));
}
});

SendHelloWorld(conn2);
var reply = await tcs.Task;
Assert.Equal("hello world", reply);

var msg = await msgTcs.Task;
Assert.Equal("hello world", msg.Item1);
Assert.Equal(expectedSynchronizationContext, msg.Item2);

conn1.Dispose();

var ex = await exTcs.Task;
Assert.IsType<DisconnectedException>(ex.Item1);
Assert.Equal(expectedSynchronizationContext, ex.Item2);

static void SendHelloWorld(Connection connection)
{
Expand All @@ -51,6 +81,9 @@ static void SendHelloWorld(Connection connection)
}
}

sealed class MySynchronizationContext : SynchronizationContext
{ }

[InlineData("tcp:host=localhost,port=1")]
[InlineData("unix:path=/does/not/exist")]
[Theory]
Expand Down

0 comments on commit ab86c36

Please sign in to comment.