diff --git a/src/__fixtures__/test-constants.ts b/src/__fixtures__/test-constants.ts index 7ebbeebec..c563f0e22 100644 --- a/src/__fixtures__/test-constants.ts +++ b/src/__fixtures__/test-constants.ts @@ -1,4 +1,5 @@ export const TEST_URL = 'http://base.test'; export const TEST_HOST = 'test-host'; +export const TEST_LOCAL_HOST = 'localhost:8080'; export const TEST_AUTH_ENTITY = 'test-entity'; export const TEST_SIGNALING_ADDRESS = 'https://signaling.test'; diff --git a/src/robot/__tests__/client.spec.ts b/src/robot/__tests__/client.spec.ts index 3502da7de..4de097d91 100644 --- a/src/robot/__tests__/client.spec.ts +++ b/src/robot/__tests__/client.spec.ts @@ -12,6 +12,7 @@ import { } from '../../__fixtures__/credentials'; import { TEST_HOST, + TEST_LOCAL_HOST, TEST_SIGNALING_ADDRESS, } from '../../__fixtures__/test-constants'; import { baseDialConfig } from '../__fixtures__/dial-configs'; @@ -383,4 +384,225 @@ describe('RobotClient', () => { expect(mockResetFn).not.toHaveBeenCalled(); }); }); + + describe('dial error handling', () => { + const captureDisconnectedEvents = () => { + const events: unknown[] = []; + const setupListener = (client: RobotClient) => { + client.on('disconnected', (event) => { + events.push(event); + }); + }; + return { events, setupListener }; + }; + + const findEventWithError = ( + events: unknown[], + errorMessage?: string + ): unknown => { + return events.find((event) => { + if ( + typeof event !== 'object' || + event === null || + !('error' in event) + ) { + return false; + } + if (errorMessage === undefined || errorMessage === '') { + return true; + } + const { error } = event as { error: Error }; + return error.message === errorMessage; + }); + }; + + it('should return client instance when WebRTC connection succeeds', async () => { + // Arrange + const client = setupClientMocks(); + + // Act + const result = await client.dial({ + ...baseDialConfig, + noReconnect: true, + }); + + // Assert + expect(result).toBe(client); + }); + + it('should throw error when both WebRTC and gRPC connections fail', async () => { + // Arrange + const client = new RobotClient(); + const webrtcError = new Error('WebRTC connection failed'); + + vi.mocked(rpcModule.dialWebRTC).mockRejectedValue(webrtcError); + + // Act & Assert + await expect( + client.dial({ + ...baseDialConfig, + noReconnect: true, + }) + ).rejects.toThrow('Failed to connect via all methods'); + }); + + it('should emit DISCONNECTED events for both failures before throwing', async () => { + // Arrange + const client = new RobotClient(); + const webrtcError = new Error('WebRTC connection failed'); + const { events, setupListener } = captureDisconnectedEvents(); + + setupListener(client); + vi.mocked(rpcModule.dialWebRTC).mockRejectedValue(webrtcError); + + // Act + try { + await client.dial({ + ...baseDialConfig, + noReconnect: true, + }); + } catch { + // Expected to throw + } + + // Assert + expect(events.length).toBeGreaterThanOrEqual(2); + const webrtcEvent = findEventWithError( + events, + 'WebRTC connection failed' + ); + expect(webrtcEvent).toBeDefined(); + expect(webrtcEvent).toMatchObject({ error: webrtcError }); + }); + + it('should emit DISCONNECTED event when gRPC fails and throw', async () => { + // Arrange + const client = new RobotClient(); + const { events, setupListener } = captureDisconnectedEvents(); + + setupListener(client); + + // Act + try { + await client.dial({ + host: TEST_HOST, + noReconnect: true, + }); + } catch { + // Expected to throw + } + + // Assert + expect(events.length).toBeGreaterThanOrEqual(1); + const errorEvent = findEventWithError(events); + expect(errorEvent).toBeDefined(); + expect((errorEvent as { error: Error }).error).toBeInstanceOf(Error); + }); + + it('should include both errors in thrown error cause', async () => { + // Arrange + const client = new RobotClient(); + const webrtcError = new Error('WebRTC connection failed'); + + vi.mocked(rpcModule.dialWebRTC).mockRejectedValue(webrtcError); + + // Act + let caughtError: Error | undefined; + try { + await client.dial({ + ...baseDialConfig, + noReconnect: true, + }); + } catch (error) { + caughtError = error as Error; + } + + // Assert + expect(caughtError).toBeDefined(); + expect(caughtError).toBeInstanceOf(Error); + expect(caughtError!.message).toBe('Failed to connect via all methods'); + expect(caughtError!.cause).toBeDefined(); + expect(Array.isArray(caughtError!.cause)).toBe(true); + const causes = caughtError!.cause as Error[]; + expect(causes).toHaveLength(2); + expect(causes[0]).toBe(webrtcError); + expect(causes[1]).toBeInstanceOf(Error); + }); + + it('should convert non-Error objects to Errors before throwing', async () => { + // Arrange + const client = new RobotClient(); + const webrtcError = 'string error'; + + vi.mocked(rpcModule.dialWebRTC).mockRejectedValue(webrtcError); + + // Act + let caughtError: Error | undefined; + try { + await client.dial({ + ...baseDialConfig, + noReconnect: true, + }); + } catch (error) { + caughtError = error as Error; + } + + // Assert + expect(caughtError).toBeDefined(); + expect(caughtError).toBeInstanceOf(Error); + expect(caughtError!.cause).toBeDefined(); + expect(Array.isArray(caughtError!.cause)).toBe(true); + const causes = caughtError!.cause as Error[]; + expect(causes.length).toBeGreaterThan(0); + const [firstCause] = causes; + expect(firstCause).toBeInstanceOf(Error); + expect(firstCause?.message).toBe('string error'); + }); + + it('should fallback to gRPC when WebRTC fails and emit WebRTC error', async () => { + // Arrange + const client = new RobotClient(); + const webrtcError = new Error('WebRTC connection failed'); + const { events, setupListener } = captureDisconnectedEvents(); + + setupListener(client); + vi.mocked(rpcModule.dialWebRTC).mockRejectedValue(webrtcError); + vi.mocked(rpcModule.dialDirect).mockResolvedValue( + createMockRobotServiceTransport() + ); + + // Act + const result = await client.dial({ + ...baseDialConfig, + host: TEST_LOCAL_HOST, + noReconnect: true, + }); + + // Assert + expect(result).toBe(client); + expect(events.length).toBeGreaterThanOrEqual(1); + const webrtcEvent = findEventWithError( + events, + 'WebRTC connection failed' + ); + expect(webrtcEvent).toBeDefined(); + }); + + it('should return client instance when only gRPC connection is used', async () => { + // Arrange + const client = new RobotClient(); + vi.mocked(rpcModule.dialDirect).mockResolvedValue( + createMockRobotServiceTransport() + ); + + // Act + const result = await client.dial({ + host: TEST_LOCAL_HOST, + noReconnect: true, + }); + + // Assert + expect(result).toBe(client); + }); + }); }); diff --git a/src/robot/client.ts b/src/robot/client.ts index 97ed36656..ee98a48c4 100644 --- a/src/robot/client.ts +++ b/src/robot/client.ts @@ -631,14 +631,20 @@ export class RobotClient extends EventDispatcher implements Robot { : conf.reconnectMaxAttempts; this.currentRetryAttempt = 0; + let webRTCError: Error | undefined; + let directError: Error | undefined; // Try to dial via WebRTC first. if (isDialWebRTCConf(conf) && !conf.reconnectAbortSignal?.abort) { try { return await backOff(async () => this.dialWebRTC(conf), backOffOpts); - } catch { + } catch (error) { + webRTCError = error instanceof Error ? error : new Error(String(error)); // eslint-disable-next-line no-console - console.debug('Failed to connect via WebRTC'); + console.debug('Failed to connect via WebRTC', webRTCError); + this.emit(MachineConnectionEvent.DISCONNECTED, { + error: webRTCError, + }); } } @@ -647,12 +653,22 @@ export class RobotClient extends EventDispatcher implements Robot { if (!conf.reconnectAbortSignal?.abort) { try { return await backOff(async () => this.dialDirect(conf), backOffOpts); - } catch { + } catch (error) { + directError = error instanceof Error ? error : new Error(String(error)); // eslint-disable-next-line no-console - console.debug('Failed to connect via gRPC'); + console.debug('Failed to connect via gRPC', directError); + this.emit(MachineConnectionEvent.DISCONNECTED, { + error: directError, + }); } } + if (webRTCError && directError) { + throw new Error('Failed to connect via all methods', { + cause: [webRTCError, directError], + }); + } + return this; } diff --git a/src/rpc/__tests__/dial.spec.ts b/src/rpc/__tests__/dial.spec.ts index 938b5b7f1..6e0c64fa8 100644 --- a/src/rpc/__tests__/dial.spec.ts +++ b/src/rpc/__tests__/dial.spec.ts @@ -27,8 +27,8 @@ import { } from '../../__mocks__/webrtc'; import { withICEServers } from '../__fixtures__/dial-webrtc-options'; import { createMockTransport } from '../../__mocks__/transports'; -import { createMockSignalingExchange } from '../__mocks__/signaling-exchanges'; import { ClientChannel } from '../client-channel'; +import type { Transport } from '@connectrpc/connect'; vi.mock('../peer'); vi.mock('../signaling-exchange'); @@ -52,15 +52,12 @@ const setupDialWebRTCMocks = () => { const peerConnection = createMockPeerConnection(); const dataChannel = createMockDataChannel(); const transport = createMockTransport(); - const signalingExchange = createMockSignalingExchange(transport); vi.mocked(newPeerConnectionForClient).mockResolvedValue({ pc: peerConnection, dc: dataChannel, }); - vi.mocked(SignalingExchange).mockImplementation(() => signalingExchange); - const optionalWebRTCConfigFn = vi.fn().mockResolvedValue({ config: { additionalIceServers: [], @@ -68,12 +65,20 @@ const setupDialWebRTCMocks = () => { }, }); - vi.mocked(createClient).mockReturnValue({ + const mockClient = { optionalWebRTCConfig: optionalWebRTCConfigFn, - } as unknown as ReturnType); + } as unknown as ReturnType; + vi.mocked(createClient).mockReturnValue(mockClient); vi.mocked(createGrpcWebTransport).mockReturnValue(transport); + const signalingExchange = { + doExchange: vi.fn().mockResolvedValue(transport), + terminate: vi.fn(), + } as unknown as SignalingExchange; + + vi.mocked(SignalingExchange).mockImplementation(() => signalingExchange); + return { peerConnection, dataChannel, @@ -207,21 +212,18 @@ describe('dialWebRTC', () => { expect(vi.mocked(peerConnection.close)).toHaveBeenCalled(); }); - it('should close peer connection if dialDirect fails', async () => { + it('should propagate error if transport creation fails', async () => { // Arrange - const { peerConnection, transport } = setupDialWebRTCMocks(); - // First call succeeds (getOptionalWebRTCConfig), second call fails (signaling) - vi.mocked(createGrpcWebTransport) - .mockReturnValueOnce(transport) - .mockImplementationOnce(() => { - throw new Error('Transport creation failed'); - }); + setupDialWebRTCMocks(); + vi.mocked(createGrpcWebTransport).mockImplementation(() => { + throw new Error('Transport creation failed'); + }); // Act & Assert await expect(dialWebRTC(TEST_URL, TEST_HOST)).rejects.toThrow( 'Transport creation failed' ); - expect(vi.mocked(peerConnection.close)).toHaveBeenCalled(); + expect(newPeerConnectionForClient).not.toHaveBeenCalled(); }); it('should rethrow errors after cleanup', async () => { @@ -327,6 +329,103 @@ describe('validateDialOptions', () => { }); }); +describe('resource management', () => { + it('should reuse a single transport for config fetching and signaling', async () => { + // Arrange + setupDialWebRTCMocks(); + + // Act + await dialWebRTC(TEST_URL, TEST_HOST); + + // Assert + expect(createGrpcWebTransport).toHaveBeenCalledTimes(1); + expect(createGrpcWebTransport).toHaveBeenCalledWith({ + baseUrl: TEST_URL, + credentials: 'same-origin', + }); + }); + + it('should reuse a single signaling client for config fetching and signaling', async () => { + // Arrange + setupDialWebRTCMocks(); + + // Act + await dialWebRTC(TEST_URL, TEST_HOST); + + // Assert + expect(createClient).toHaveBeenCalledTimes(1); + expect(createClient).toHaveBeenCalledWith( + expect.anything(), + expect.anything() + ); + }); + + it('should not leak transports on successful connection', async () => { + // Arrange + const { transport } = setupDialWebRTCMocks(); + const transportCount = { created: 0 }; + + vi.mocked(createGrpcWebTransport).mockImplementation(() => { + transportCount.created += 1; + return transport; + }); + + // Act + await dialWebRTC(TEST_URL, TEST_HOST); + + // Assert + expect(transportCount.created).toBe(1); + }); + + it('should not leak transports on connection failure', async () => { + // Arrange + const { transport, signalingExchange } = setupDialWebRTCMocks(); + const transportCount = { created: 0 }; + + vi.mocked(createGrpcWebTransport).mockImplementation(() => { + transportCount.created += 1; + return transport; + }); + + const error = new Error('Connection failed'); + vi.mocked(signalingExchange.doExchange).mockRejectedValueOnce(error); + + // Act + await dialWebRTC(TEST_URL, TEST_HOST).catch(() => { + // Ignore error for this test + }); + + // Assert + expect(transportCount.created).toBe(1); + }); + + it('should use the same transport reference for both config and signaling', async () => { + // Arrange + setupDialWebRTCMocks(); + const capturedTransports: Transport[] = []; + + vi.mocked(createClient).mockImplementation( + (_service, capturedTransport) => { + capturedTransports.push(capturedTransport); + return { + optionalWebRTCConfig: vi.fn().mockResolvedValue({ + config: { + additionalIceServers: [], + disableTrickle: false, + }, + }), + } as unknown as ReturnType; + } + ); + + // Act + await dialWebRTC(TEST_URL, TEST_HOST); + + // Assert + expect(capturedTransports.length).toBe(1); + }); +}); + describe('dialDirect', () => { afterEach(() => { vi.restoreAllMocks(); diff --git a/src/rpc/dial.ts b/src/rpc/dial.ts index d1f0bd2be..fb7318de2 100644 --- a/src/rpc/dial.ts +++ b/src/rpc/dial.ts @@ -309,20 +309,24 @@ export interface WebRTCConnection { dataChannel: RTCDataChannel; } -const getOptionalWebRTCConfig = async ( +const getSignalingClient = async ( signalingAddress: string, - callOpts: CallOptions, - dialOpts?: DialOptions, + signalingExchangeOpts: DialOptions | undefined, transportCredentialsInclude = false -): Promise => { - const optsCopy = { ...dialOpts } as DialOptions; - const directTransport = await dialDirect( +) => { + const transport = await dialDirect( signalingAddress, - optsCopy, + signalingExchangeOpts, transportCredentialsInclude ); - const signalingClient = createClient(SignalingService, directTransport); + return createClient(SignalingService, transport); +}; + +const getOptionalWebRTCConfig = async ( + callOpts: CallOptions, + signalingClient: ReturnType> +): Promise => { try { const resp = await signalingClient.optionalWebRTCConfig({}, callOpts); return resp.config ?? new WebRTCConfig(); @@ -363,18 +367,25 @@ export const dialWebRTC = async ( }; /** - * First complete our WebRTC options, gathering any extra information like - * TURN servers from a cloud server. + * First, derive options specifically for signaling against our target. Then + * complete our WebRTC options, gathering any extra information like TURN + * servers from a cloud server. This also creates the transport and signaling + * client that we'll reuse to avoid resource leaks. */ - const webrtcOpts = await processWebRTCOpts( + const exchangeOpts = processSignalingExchangeOpts( usableSignalingAddress, - callOpts, - dialOpts, - transportCredentialsInclude + dialOpts ); - // then derive options specifically for signaling against our target. - const exchangeOpts = processSignalingExchangeOpts( + + const signalingClient = await getSignalingClient( usableSignalingAddress, + exchangeOpts, + transportCredentialsInclude + ); + + const webrtcOpts = await processWebRTCOpts( + signalingClient, + callOpts, dialOpts ); @@ -385,21 +396,6 @@ export const dialWebRTC = async ( ); let successful = false; - let directTransport: Transport; - try { - directTransport = await dialDirect( - usableSignalingAddress, - exchangeOpts, - transportCredentialsInclude - ); - } catch (error) { - pc.close(); - dc.close(); - throw error; - } - - const signalingClient = createClient(SignalingService, directTransport); - const exchange = new SignalingExchange( signalingClient, callOpts, @@ -453,18 +449,11 @@ export const dialWebRTC = async ( }; const processWebRTCOpts = async ( - signalingAddress: string, + signalingClient: ReturnType>, callOpts: CallOptions, - dialOpts?: DialOptions, - transportCredentialsInclude = false + dialOpts: DialOptions | undefined ): Promise => { - // Get TURN servers, if any. - const config = await getOptionalWebRTCConfig( - signalingAddress, - callOpts, - dialOpts, - transportCredentialsInclude - ); + const config = await getOptionalWebRTCConfig(callOpts, signalingClient); const additionalIceServers: RTCIceServer[] = config.additionalIceServers.map( (ice) => { const iceUrls = [];