Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion apps/browser-proxy/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ AWS_REGION=us-east-1
LOGFLARE_SOURCE_URL="<logflare-source-url>"
# enable PROXY protocol support
#PROXIED=true
WILDCARD_DOMAIN=browser.staging.db.build
SUPABASE_URL="<supabase-url>"
SUPABASE_ANON_KEY="<supabase-anon-key>"
WILDCARD_DOMAIN=browser.staging.db.build
2 changes: 2 additions & 0 deletions apps/browser-proxy/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
},
"dependencies": {
"@aws-sdk/client-s3": "^3.645.0",
"@supabase/supabase-js": "^2.45.4",
"debug": "^4.3.7",
"expiry-map": "^2.0.0",
"findhit-proxywrap": "^0.3.13",
"nanoid": "^5.0.7",
"p-memoize": "^7.1.1",
"pg-gateway": "^0.3.0-beta.3",
"ws": "^8.18.0"
Expand Down
53 changes: 53 additions & 0 deletions apps/browser-proxy/src/connection-manager.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import type { PostgresConnection } from 'pg-gateway'
import type { WebSocket } from 'ws'

type DatabaseId = string
type ConnectionId = string

class ConnectionManager {
private socketsByDatabase: Map<DatabaseId, ConnectionId> = new Map()
private sockets: Map<ConnectionId, PostgresConnection> = new Map()
private websockets: Map<DatabaseId, WebSocket> = new Map()

constructor() {}

public hasSocketForDatabase(databaseId: DatabaseId) {
return this.socketsByDatabase.has(databaseId)
}

public getSocket(connectionId: ConnectionId) {
return this.sockets.get(connectionId)
}

public setSocket(databaseId: DatabaseId, connectionId: ConnectionId, socket: PostgresConnection) {
this.sockets.set(connectionId, socket)
this.socketsByDatabase.set(databaseId, connectionId)
}

public deleteSocketForDatabase(databaseId: DatabaseId) {
const connectionId = this.socketsByDatabase.get(databaseId)
this.socketsByDatabase.delete(databaseId)
if (connectionId) {
this.sockets.delete(connectionId)
}
}

public hasWebsocket(databaseId: DatabaseId) {
return this.websockets.has(databaseId)
}

public getWebsocket(databaseId: DatabaseId) {
return this.websockets.get(databaseId)
}

public setWebsocket(databaseId: DatabaseId, websocket: WebSocket) {
this.websockets.set(databaseId, websocket)
}

public deleteWebsocket(databaseId: DatabaseId) {
this.websockets.delete(databaseId)
this.deleteSocketForDatabase(databaseId)
}
}

export const connectionManager = new ConnectionManager()
17 changes: 12 additions & 5 deletions apps/browser-proxy/src/create-message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ export function createStartupMessage(
user: string,
database: string,
additionalParams: Record<string, string> = {}
): ArrayBuffer {
): Uint8Array {
const encoder = new TextEncoder()

// Protocol version number (3.0)
Expand All @@ -22,9 +22,8 @@ export function createStartupMessage(
}
messageLength += 1 // Null terminator

const message = new ArrayBuffer(4 + messageLength)
const view = new DataView(message)
const uint8Array = new Uint8Array(message)
const uint8Array = new Uint8Array(4 + messageLength)
const view = new DataView(uint8Array.buffer)

let offset = 0
view.setInt32(offset, messageLength + 4, false) // Total message length (including itself)
Expand All @@ -44,5 +43,13 @@ export function createStartupMessage(

uint8Array.set([0], offset) // Final null terminator

return message
return uint8Array
}

export function createTerminateMessage(): Uint8Array {
const uint8Array = new Uint8Array(5)
const view = new DataView(uint8Array.buffer)
view.setUint8(0, 'X'.charCodeAt(0))
view.setUint32(1, 4, false)
return uint8Array
}
5 changes: 5 additions & 0 deletions apps/browser-proxy/src/debug.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import createDebug from 'debug'

createDebug.formatters.e = (fn) => fn()

export const debug = createDebug('browser-proxy')
198 changes: 25 additions & 173 deletions apps/browser-proxy/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,179 +1,12 @@
import * as nodeNet from 'node:net'
import * as https from 'node:https'
import { BackendError, PostgresConnection } from 'pg-gateway'
import { fromNodeSocket } from 'pg-gateway/node'
import { WebSocketServer, type WebSocket } from 'ws'
import makeDebug from 'debug'
import { extractDatabaseId, isValidServername } from './servername.ts'
import { getTls, setSecureContext } from './tls.ts'
import { createStartupMessage } from './create-message.ts'
import { extractIP } from './extract-ip.ts'
import {
DatabaseShared,
DatabaseUnshared,
logEvent,
UserConnected,
UserDisconnected,
} from './telemetry.ts'
import { httpsServer } from './websocket-server.ts'
import { tcpServer } from './tcp-server.ts'

const debug = makeDebug('browser-proxy')

const tcpConnections = new Map<string, PostgresConnection>()
const websocketConnections = new Map<string, WebSocket>()

const httpsServer = https.createServer({
SNICallback: (servername, callback) => {
debug('SNICallback', servername)
if (isValidServername(servername)) {
debug('SNICallback', 'valid')
callback(null)
} else {
debug('SNICallback', 'invalid')
callback(new Error('invalid SNI'))
}
},
process.on('unhandledRejection', (reason, promise) => {
console.error({ location: 'unhandledRejection', reason, promise })
})
await setSecureContext(httpsServer)
// reset the secure context every week to pick up any new TLS certificates
setInterval(() => setSecureContext(httpsServer), 1000 * 60 * 60 * 24 * 7)

const websocketServer = new WebSocketServer({
server: httpsServer,
})

websocketServer.on('error', (error) => {
debug('websocket server error', error)
})

websocketServer.on('connection', (socket, request) => {
debug('websocket connection')

const host = request.headers.host

if (!host) {
debug('No host header present')
socket.close()
return
}

const databaseId = extractDatabaseId(host)

if (websocketConnections.has(databaseId)) {
socket.send('sorry, too many clients already')
socket.close()
return
}

websocketConnections.set(databaseId, socket)

logEvent(new DatabaseShared({ databaseId }))

socket.on('message', (data: Buffer) => {
debug('websocket message', data.toString('hex'))
const tcpConnection = tcpConnections.get(databaseId)
tcpConnection?.streamWriter?.write(data)
})

socket.on('close', () => {
websocketConnections.delete(databaseId)
logEvent(new DatabaseUnshared({ databaseId }))
})
})

// we need to use proxywrap to make our tcp server to enable the PROXY protocol support
const net = (
process.env.PROXIED ? (await import('findhit-proxywrap')).default.proxy(nodeNet) : nodeNet
) as typeof nodeNet

const tcpServer = net.createServer()

tcpServer.on('connection', async (socket) => {
let databaseId: string | undefined

const connection = await fromNodeSocket(socket, {
tls: getTls,
onTlsUpgrade(state) {
if (!state.tlsInfo?.serverName || !isValidServername(state.tlsInfo.serverName)) {
throw BackendError.create({
code: '08006',
message: 'invalid SNI',
severity: 'FATAL',
})
}

const _databaseId = extractDatabaseId(state.tlsInfo.serverName!)

if (!websocketConnections.has(_databaseId!)) {
throw BackendError.create({
code: 'XX000',
message: 'the browser is not sharing the database',
severity: 'FATAL',
})
}

if (tcpConnections.has(_databaseId)) {
throw BackendError.create({
code: '53300',
message: 'sorry, too many clients already',
severity: 'FATAL',
})
}

// only set the databaseId after we've verified the connection
databaseId = _databaseId
tcpConnections.set(databaseId!, connection)
logEvent(new UserConnected({ databaseId }))
},
serverVersion() {
return '16.3'
},
onAuthenticated() {
const websocket = websocketConnections.get(databaseId!)

if (!websocket) {
throw BackendError.create({
code: 'XX000',
message: 'the browser is not sharing the database',
severity: 'FATAL',
})
}

const clientIpMessage = createStartupMessage('postgres', 'postgres', {
client_ip: extractIP(socket.remoteAddress!),
})
websocket.send(clientIpMessage)
},
onMessage(message, state) {
if (!state.isAuthenticated) {
return
}

const websocket = websocketConnections.get(databaseId!)

if (!websocket) {
throw BackendError.create({
code: 'XX000',
message: 'the browser is not sharing the database',
severity: 'FATAL',
})
}

debug('tcp message', { message })
websocket.send(message)

// return an empty buffer to indicate that the message has been handled
return new Uint8Array()
},
})

socket.on('close', () => {
if (databaseId) {
tcpConnections.delete(databaseId)
logEvent(new UserDisconnected({ databaseId }))
const websocket = websocketConnections.get(databaseId)
websocket?.send(createStartupMessage('postgres', 'postgres', { client_ip: '' }))
}
})
process.on('uncaughtException', (error) => {
console.error({ location: 'uncaughtException', error })
})

httpsServer.listen(443, () => {
Expand All @@ -183,3 +16,22 @@ httpsServer.listen(443, () => {
tcpServer.listen(5432, () => {
console.log('tcp server listening on port 5432')
})

const shutdown = async () => {
await Promise.allSettled([
new Promise<void>((res) =>
httpsServer.close(() => {
res()
})
),
new Promise<void>((res) =>
tcpServer.close(() => {
res()
})
),
])
process.exit(0)
}

process.on('SIGTERM', shutdown)
process.on('SIGINT', shutdown)
2 changes: 2 additions & 0 deletions apps/browser-proxy/src/pg-dump-middleware/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export const VECTOR_OID = 99999
export const FIRST_NORMAL_OID = 16384
Loading