Skip to content

Commit

Permalink
fix(mitm): fix reusing sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
blakebyrnes committed Mar 31, 2021
1 parent 3d7ee18 commit 5d56597
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 31 deletions.
1 change: 1 addition & 0 deletions mitm-socket/index.ts
Expand Up @@ -33,6 +33,7 @@ export default class MitmSocket extends TypedEventEmitter<{
public dialTime: Date;
public connectTime: Date;
public closeTime: Date;
public isReused = false;

public get pid(): number | undefined {
return this.child?.pid;
Expand Down
10 changes: 8 additions & 2 deletions mitm-socket/lib/connect.go
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"log"
"net"
"os"
"os/signal"
"syscall"
Expand Down Expand Up @@ -33,8 +34,8 @@ func main() {
debug := connectArgs.Debug

domainSocketPiper := &DomainSocketPiper{
Path: socketPath,
debug: connectArgs.Debug,
Path: socketPath,
debug: connectArgs.Debug,
keepAlive: connectArgs.KeepAlive,
}

Expand All @@ -47,6 +48,11 @@ func main() {
addr := fmt.Sprintf("%s:%s", connectArgs.Host, connectArgs.Port)
dialConn, err := Dial(addr, connectArgs)

tcpConn := dialConn.(*net.TCPConn)
if connectArgs.KeepAlive {
tcpConn.SetKeepAlive(true)
}

if err != nil {
log.Fatalf("Dial (proxy/remote) Error: %+v\n", err)
}
Expand Down
8 changes: 4 additions & 4 deletions mitm-socket/lib/domain_socket_piper.go
Expand Up @@ -49,7 +49,7 @@ func (piper *DomainSocketPiper) Pipe(remoteConn net.Conn, sigc chan os.Signal) {
}
}

copyUntilTimeout := func(dst io.Writer, src net.Conn, notifyChan chan error, counter *uint32) {
copy := func(dst io.Writer, src net.Conn, notifyChan chan error, counter *uint32) {
var readErr error
var writeErr error
var n int
Expand Down Expand Up @@ -83,7 +83,7 @@ func (piper *DomainSocketPiper) Pipe(remoteConn net.Conn, sigc chan os.Signal) {
}
return
}
time.Sleep(1 * time.Second)
time.Sleep(200 * time.Millisecond)
}
} else if n == 0 {
time.Sleep(200 * time.Millisecond)
Expand All @@ -92,8 +92,8 @@ func (piper *DomainSocketPiper) Pipe(remoteConn net.Conn, sigc chan os.Signal) {
}

// Pipe data
go copyUntilTimeout(remoteConn, piper.client, localNotify, &piper.completeCounter)
go copyUntilTimeout(piper.client, remoteConn, remoteNotify, &piper.completeCounter)
go copy(remoteConn, piper.client, localNotify, &piper.completeCounter)
go copy(piper.client, remoteConn, remoteNotify, &piper.completeCounter)

// Read until one of these errors occur
var err error
Expand Down
108 changes: 86 additions & 22 deletions mitm/lib/MitmRequestAgent.ts
Expand Up @@ -58,22 +58,16 @@ export default class MitmRequestAgent {
rejectUnauthorized: allowUnverifiedCertificates === false,
};

ctx.setState(ResourceState.GetSocket);
let mitmSocket = await this.getAvailableSocket(ctx, requestSettings);
if (!mitmSocket) {
mitmSocket = await this.waitForFreeSocket(ctx.url.origin);
}
MitmRequestContext.assignMitmSocket(ctx, mitmSocket);
await this.assignSocket(ctx, requestSettings);

ctx.cacheHandler.onRequest();
await HeadersHandler.modifyHeaders(ctx);

requestSettings.headers = ctx.requestHeaders;
requestSettings.createConnection = () => mitmSocket.socket;
requestSettings.agent = null;

if (ctx.isServerHttp2) {
HeadersHandler.prepareRequestHeadersForHttp2(ctx);
return this.http2Request(ctx, mitmSocket);
return this.http2Request(ctx);
}

return this.http1Request(ctx, requestSettings);
Expand All @@ -84,14 +78,16 @@ export default class MitmRequestAgent {
return;
}
const connectionHeader = ctx.responseHeaders?.Connection ?? ctx.responseHeaders?.connection;
const isCloseRequested = connectionHeader === 'close';
const isCloseRequested = connectionHeader !== 'keep-alive';

const socket = ctx.proxyToServerMitmSocket;

if (!socket.isReusable() || isCloseRequested) {
return socket.close();
}

socket.isReused = true;

const pool = this.getSocketPoolByOrigin(ctx.url.origin);
const pending = pool.pending.shift();
if (pending) {
Expand All @@ -118,6 +114,8 @@ export default class MitmRequestAgent {
this.sockets.clear();
}

/////// ////////// Socket Connection Management ///////////////////////////////////////////////////

private async createSocketConnection(
ctx: IMitmRequestContext,
options: RequestOptions,
Expand All @@ -132,6 +130,7 @@ export default class MitmRequestAgent {
ctx.setState(ResourceState.LookupDns);
const ipIfNeeded = await session.lookupDns(options.host);
ctx.dnsResolvedIp = ipIfNeeded || 'Not Found';

const mitmSocket = new MitmSocket(session.sessionId, {
host: ipIfNeeded || options.host,
port: String(options.port),
Expand Down Expand Up @@ -162,7 +161,18 @@ export default class MitmRequestAgent {
return mitmSocket;
}

/////// ////////// Socket Connection Management ///////////////////////////////////////////////////
private async assignSocket(
ctx: IMitmRequestContext,
requestSettings: RequestOptions,
): Promise<MitmSocket> {
ctx.setState(ResourceState.GetSocket);
let mitmSocket = await this.getAvailableSocket(ctx, requestSettings);
if (!mitmSocket) {
mitmSocket = await this.waitForFreeSocket(ctx.url.origin);
}
MitmRequestContext.assignMitmSocket(ctx, mitmSocket);
return mitmSocket;
}

private waitForFreeSocket(origin: string): Promise<MitmSocket> {
const socketPool = this.getSocketPoolByOrigin(origin);
Expand Down Expand Up @@ -257,22 +267,76 @@ export default class MitmRequestAgent {
});
}

private http1Request(
private async http1Request(
ctx: IMitmRequestContext,
requestSettings: http.RequestOptions,
): http.ClientRequest {
): Promise<http.ClientRequest> {
const httpModule = ctx.isSSL ? https : http;
ctx.setState(ResourceState.CreateProxyToServerRequest);
return httpModule.request(requestSettings);

let didHaveFlushErrors = false;

const request = httpModule.request({
...requestSettings,
createConnection: () => ctx.proxyToServerMitmSocket.socket,
agent: null,
});

function initError(error): void {
if (error.code === 'ECONNRESET') {
didHaveFlushErrors = true;
return;
}
log.info(`MitmHttpRequest.Http1SendRequestError`, {
sessionId: ctx.requestSession.sessionId,
request: requestSettings,
error,
});
}

request.once('error', initError);

let response: http.IncomingMessage;
request.once('response', x => {
response = x;
});
const rebroadcast = (event: string, handler: (result: any) => void): http.ClientRequest => {
if (event === 'response' && response) {
handler(response);
response = null;
}
// hand off to another fn
if (event === 'error') request.off('error', initError);
return request;
};
const originalOn = request.on.bind(request);
const originalOnce = request.once.bind(request);
request.on = function onOverride(event, handler): http.ClientRequest {
originalOn(event, handler);
return rebroadcast(event, handler);
};
request.once = function onOverride(event, handler): http.ClientRequest {
originalOnce(event, handler);
return rebroadcast(event, handler);
};

// if re-using, we need to make sure the connection can still be written to by probing it
if (ctx.proxyToServerMitmSocket.isReused) {
if (!request.headersSent) request.flushHeaders();
// give this 100 ms to flush (go is on a wait timer right now)
await new Promise(resolve => setTimeout(resolve, 100));
if (didHaveFlushErrors) {
await this.assignSocket(ctx, requestSettings);
return this.http1Request(ctx, requestSettings);
}
}
return request;
}

/////// ////////// Http2 helpers //////////////////////////////////////////////////////////////////

private http2Request(
ctx: IMitmRequestContext,
connectResult: MitmSocket,
): http2.ClientHttp2Stream {
const client = this.createHttp2Session(ctx, connectResult);
private http2Request(ctx: IMitmRequestContext): http2.ClientHttp2Stream {
const client = this.createHttp2Session(ctx);
ctx.setState(ResourceState.CreateProxyToServerRequest);
return client.request(ctx.requestHeaders, { waitForTrailers: true });
}
Expand Down Expand Up @@ -459,7 +523,7 @@ export default class MitmRequestAgent {
});
}

private createHttp2Session(ctx: IMitmRequestContext, mitmSocket: MitmSocket): ClientHttp2Session {
private createHttp2Session(ctx: IMitmRequestContext): ClientHttp2Session {
const origin = ctx.url.origin;
const existing = this.getHttp2Session(origin);
if (existing) return existing.client;
Expand All @@ -468,7 +532,7 @@ export default class MitmRequestAgent {

ctx.setState(ResourceState.CreateH2Session);
const proxyToServerH2Client = http2.connect(origin, {
createConnection: () => mitmSocket.socket,
createConnection: () => ctx.proxyToServerMitmSocket.socket,
});

proxyToServerH2Client.on('stream', this.onHttp2ServerToProxyPush.bind(this, ctx));
Expand Down Expand Up @@ -543,7 +607,7 @@ export default class MitmRequestAgent {
this.http2Sessions.push({
origin,
client: proxyToServerH2Client,
mitmSocket,
mitmSocket: ctx.proxyToServerMitmSocket,
});

return proxyToServerH2Client;
Expand Down
74 changes: 71 additions & 3 deletions mitm/test/MitmRequestAgent.test.ts
@@ -1,7 +1,11 @@
import { Helpers } from '@secret-agent/testing';
import { runHttpsServer } from '@secret-agent/testing/helpers';
import { getProxyAgent, runHttpsServer } from '@secret-agent/testing/helpers';
import * as WebSocket from 'ws';
import * as HttpProxyAgent from 'http-proxy-agent';
import { IncomingHttpHeaders, IncomingMessage } from 'http';
import * as net from 'net';
import { URL } from 'url';
import * as https from 'https';
import MitmServer from '../lib/MitmProxy';
import RequestSession from '../handlers/RequestSession';
import HeadersHandler from '../handlers/HeadersHandler';
Expand All @@ -21,6 +25,10 @@ beforeAll(() => {
});
});

beforeEach(() => {
process.env.MITM_ALLOW_INSECURE = 'false';
});

afterAll(Helpers.afterAll);
afterEach(Helpers.afterEach);

Expand All @@ -29,6 +37,7 @@ test('should create up to a max number of secure connections per origin', async
MitmRequestAgent.defaultMaxConnectionsPerOrigin = 2;
const server = await runHttpsServer((req, res) => {
remotePorts.push(req.connection.remotePort);
res.socket.setKeepAlive(true);
res.end('I am here');
});
const mitmServer = await startMitmServer();
Expand Down Expand Up @@ -57,7 +66,6 @@ test('should create up to a max number of secure connections per origin', async
promises.push(p);
}
await Promise.all(promises);
process.env.MITM_ALLOW_INSECURE = 'false';

expect(connectionsByOrigin[server.baseUrl].all.size).toBe(2);
await session.close();
Expand Down Expand Up @@ -98,7 +106,6 @@ test('should create new connections as needed when no keepalive', async () => {
promises.push(p);
}
await Promise.all(promises);
process.env.MITM_ALLOW_INSECURE = 'false';

// they all close after use, so should be gone now
expect(connectionsByOrigin[server.baseUrl].all.size).toBe(0);
Expand All @@ -108,6 +115,67 @@ test('should create new connections as needed when no keepalive', async () => {
expect(uniquePorts.size).toBe(4);
});

test('should be able to handle a reused socket that closes on server', async () => {
const server = await Helpers.runHttpsServer(async (req, res) => {
res.setHeader('Connection', 'keep-alive');

res.end('Looks good');
});
const mitmServer = await startMitmServer();

const session = createMitmSession();
const proxyCredentials = session.getProxyCredentials();
process.env.MITM_ALLOW_INSECURE = 'true';

{
let headers: IncomingHttpHeaders;
const response = await Helpers.httpRequest(
server.baseUrl,
'GET',
`http://localhost:${mitmServer.port}`,
proxyCredentials,
{
connection: 'keep-alive',
},
res => {
headers = res.headers;
},
);
expect(headers.connection).toBe('keep-alive');
expect(response).toBe('Looks good');
}

// node seems to default reset the connection at 5 seconds
await new Promise(resolve => setTimeout(resolve, 5e3));

{
const request = https.request({
host: 'localhost',
port: server.port,
method: 'GET',
path: '/',
headers: {
connection: 'keep-alive',
},
rejectUnauthorized: false,
agent: getProxyAgent(
new URL(server.baseUrl),
`http://localhost:${mitmServer.port}`,
proxyCredentials,
),
});
const responseP = new Promise<IncomingMessage>(resolve => request.on('response', resolve));
request.end();
const response = await responseP;
expect(response.headers.connection).toBe('keep-alive');
const body = [];
for await (const chunk of response) {
body.push(chunk.toString());
}
expect(body.join('')).toBe('Looks good');
}
});

test('it should not put upgrade connections in a pool', async () => {
const httpServer = await Helpers.runHttpServer();
const mitmServer = await startMitmServer();
Expand Down

0 comments on commit 5d56597

Please sign in to comment.