Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions apps/browser-proxy/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,22 @@ httpsServer.listen(443, () => {
tcpServer.listen(5432, () => {
console.log('tcp server listening on port 5432')
})

const shutdown = async () => {
await Promise.allSettled([
new Promise<void>((res) =>
httpsServer.close(() => {
res()
})
),
new Promise<void>((res) =>
tcpServer.close(() => {
res()
})
),
])
process.exit(0)
}

process.on('SIGTERM', shutdown)
process.on('SIGINT', shutdown)
2 changes: 2 additions & 0 deletions apps/browser-proxy/src/pg-dump-middleware/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export const VECTOR_OID = 99999
export const FIRST_NORMAL_OID = 16384
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import { VECTOR_OID } from './constants.ts'
import { parseDataRowFields, parseRowDescription } from './utils.ts'

export function isGetExtensionMembershipQuery(message: Uint8Array): boolean {
// Check if it's a SimpleQuery message (starts with 'Q')
if (message[0] !== 0x51) {
// 'Q' in ASCII
return false
}

const query =
"SELECT classid, objid, refobjid FROM pg_depend WHERE refclassid = 'pg_extension'::regclass AND deptype = 'e' ORDER BY 3"

// Skip the message type (1 byte) and message length (4 bytes)
const messageString = new TextDecoder().decode(message.slice(5))

// Trim any trailing null character
const trimmedMessage = messageString.replace(/\0+$/, '')

// Check if the message exactly matches the query
return trimmedMessage === query
}

export function patchGetExtensionMembershipResult(data: Uint8Array, vectorOid: string): Uint8Array {
let offset = 0
const messages: Uint8Array[] = []
let isDependencyTable = false
let objidIndex = -1
let refobjidIndex = -1
let patchedRowCount = 0
let totalRowsProcessed = 0

const expectedColumns = ['classid', 'objid', 'refobjid']

while (offset < data.length) {
const messageType = data[offset]
const messageLength = new DataView(data.buffer, data.byteOffset + offset + 1, 4).getUint32(
0,
false
)
const message = data.subarray(offset, offset + messageLength + 1)

if (messageType === 0x54) {
// RowDescription
const columnNames = parseRowDescription(message)
isDependencyTable =
columnNames.length === 3 && columnNames.every((col) => expectedColumns.includes(col))
if (isDependencyTable) {
objidIndex = columnNames.indexOf('objid')
refobjidIndex = columnNames.indexOf('refobjid')
}
} else if (messageType === 0x44 && isDependencyTable) {
// DataRow
const fields = parseDataRowFields(message)
totalRowsProcessed++

if (fields.length === 3) {
const refobjid = fields[refobjidIndex]!.value

if (refobjid === vectorOid) {
const patchedMessage = patchDependencyRow(message, refobjidIndex)
messages.push(patchedMessage)
patchedRowCount++
offset += messageLength + 1
continue
}
}
}

messages.push(message)
offset += messageLength + 1
}

return new Uint8Array(
messages.reduce((acc, val) => {
const combined = new Uint8Array(acc.length + val.length)
combined.set(acc)
combined.set(val, acc.length)
return combined
}, new Uint8Array())
)
}

function patchDependencyRow(message: Uint8Array, refobjidIndex: number): Uint8Array {
const newArray = new Uint8Array(message)
let offset = 7 // Start after message type (1 byte), message length (4 bytes), and field count (2 bytes)

// Navigate to the refobjid field
for (let i = 0; i < refobjidIndex; i++) {
const fieldLength = new DataView(newArray.buffer, offset, 4).getInt32(0)
offset += 4 // Skip the length field
if (fieldLength > 0) {
offset += fieldLength // Skip the field value
}
}

// Now we're at the start of the refobjid field
const refobjidLength = new DataView(newArray.buffer, offset, 4).getInt32(0)
offset += 4 // Move past the length field

const encoder = new TextEncoder()

// Write the new OID value
const newRefobjidBytes = encoder.encode(VECTOR_OID.toString().padStart(refobjidLength, '0'))
newArray.set(newRefobjidBytes, offset)

return newArray
}
125 changes: 125 additions & 0 deletions apps/browser-proxy/src/pg-dump-middleware/get-extensions-query.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import { VECTOR_OID } from './constants.ts'
import { parseDataRowFields, parseRowDescription } from './utils.ts'

export function isGetExtensionsQuery(message: Uint8Array): boolean {
// Check if it's a SimpleQuery message (starts with 'Q')
if (message[0] !== 0x51) {
// 'Q' in ASCII
return false
}

const query =
'SELECT x.tableoid, x.oid, x.extname, n.nspname, x.extrelocatable, x.extversion, x.extconfig, x.extcondition FROM pg_extension x JOIN pg_namespace n ON n.oid = x.extnamespace'

// Skip the message type (1 byte) and message length (4 bytes)
const messageString = new TextDecoder().decode(message.slice(5))

// Trim any trailing null character
const trimmedMessage = messageString.replace(/\0+$/, '')

// Check if the message exactly matches the query
return trimmedMessage === query
}

export function patchGetExtensionsResult(data: Uint8Array) {
let offset = 0
const messages: Uint8Array[] = []
let isVectorExtensionTable = false
let oidColumnIndex = -1
let extnameColumnIndex = -1
let vectorOid: string | null = null

const expectedColumns = [
'tableoid',
'oid',
'extname',
'nspname',
'extrelocatable',
'extversion',
'extconfig',
'extcondition',
]

while (offset < data.length) {
const messageType = data[offset]
const messageLength = new DataView(data.buffer, data.byteOffset + offset + 1, 4).getUint32(
0,
false
)

const message = data.subarray(offset, offset + messageLength + 1)

if (messageType === 0x54) {
// RowDescription
const columnNames = parseRowDescription(message)

isVectorExtensionTable =
columnNames.length === expectedColumns.length &&
columnNames.every((col) => expectedColumns.includes(col))

if (isVectorExtensionTable) {
oidColumnIndex = columnNames.indexOf('oid')
extnameColumnIndex = columnNames.indexOf('extname')
}
} else if (messageType === 0x44 && isVectorExtensionTable) {
// DataRow
const fields = parseDataRowFields(message)
if (fields[extnameColumnIndex]?.value === 'vector') {
vectorOid = fields[oidColumnIndex]!.value!
const patchedMessage = patchOidField(message, oidColumnIndex, fields)
messages.push(patchedMessage)
offset += messageLength + 1
continue
}
}

messages.push(message)
offset += messageLength + 1
}

return {
message: Buffer.concat(messages),
vectorOid,
}
}

function patchOidField(
message: Uint8Array,
oidIndex: number,
fields: { value: string | null; length: number }[]
): Uint8Array {
const oldOidField = fields[oidIndex]!
const newOid = VECTOR_OID.toString().padStart(oldOidField.length, '0')

const newArray = new Uint8Array(message)

let offset = 7 // Start after message type (1 byte), message length (4 bytes), and field count (2 bytes)

// Navigate to the OID field
for (let i = 0; i < oidIndex; i++) {
const fieldLength = new DataView(newArray.buffer, offset, 4).getInt32(0)
offset += 4 // Skip the length field
if (fieldLength > 0) {
offset += fieldLength // Skip the field value
}
}

// Now we're at the start of the OID field
const oidLength = new DataView(newArray.buffer, offset, 4).getInt32(0)
offset += 4 // Move past the length field

// Ensure the new OID fits in the allocated space
if (newOid.length !== oidLength) {
console.warn(
`New OID length (${newOid.length}) doesn't match the original length (${oidLength}). Skipping patch.`
)
return message
}

// Write the new OID value
for (let i = 0; i < oidLength; i++) {
newArray[offset + i] = newOid.charCodeAt(i)
}

return newArray
}
111 changes: 111 additions & 0 deletions apps/browser-proxy/src/pg-dump-middleware/pg-dump-middleware.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import type { ClientParameters } from 'pg-gateway'
import { isGetExtensionsQuery, patchGetExtensionsResult } from './get-extensions-query.ts'
import {
isGetExtensionMembershipQuery,
patchGetExtensionMembershipResult,
} from './get-extension-membership-query.ts'
import { FIRST_NORMAL_OID } from './constants.ts'
import type { Socket } from 'node:net'

type ConnectionId = string

type State =
| { step: 'wait-for-get-extensions-query' }
| { step: 'get-extensions-query-received' }
| { step: 'wait-for-get-extension-membership-query'; vectorOid: string }
| { step: 'get-extension-membership-query-received'; vectorOid: string }
| { step: 'complete' }

/**
* Middleware to patch pg_dump results for PGlite < v0.2.8
* PGlite < v0.2.8 has a bug in which userland extensions are not dumped because their oid is lower than FIRST_NORMAL_OID
* This middleware patches the results of the get_extensions and get_extension_membership queries to increase the oid of the `vector` extension so it can be dumped
* For more context, see: https://github.com/electric-sql/pglite/issues/352
*/
class PgDumpMiddleware {
private state: Map<ConnectionId, State> = new Map()

constructor() {}

client(
socket: Socket,
connectionId: string,
context: {
clientParams?: ClientParameters
},
message: Uint8Array
) {
if (context.clientParams?.application_name !== 'pg_dump') {
return message
}

if (!this.state.has(connectionId)) {
this.state.set(connectionId, { step: 'wait-for-get-extensions-query' })
socket.on('close', () => {
this.state.delete(connectionId)
})
}

const connectionState = this.state.get(connectionId)!

switch (connectionState.step) {
case 'wait-for-get-extensions-query':
// https://github.com/postgres/postgres/blob/a19f83f87966f763991cc76404f8e42a36e7e842/src/bin/pg_dump/pg_dump.c#L5834-L5837
if (isGetExtensionsQuery(message)) {
this.state.set(connectionId, { step: 'get-extensions-query-received' })
}
break
case 'wait-for-get-extension-membership-query':
// https://github.com/postgres/postgres/blob/a19f83f87966f763991cc76404f8e42a36e7e842/src/bin/pg_dump/pg_dump.c#L18173-L18178
if (isGetExtensionMembershipQuery(message)) {
this.state.set(connectionId, {
step: 'get-extension-membership-query-received',
vectorOid: connectionState.vectorOid,
})
}
break
}

return message
}

server(
connectionId: string,
context: {
clientParams?: ClientParameters
},
message: Uint8Array
) {
if (context.clientParams?.application_name !== 'pg_dump' || !this.state.has(connectionId)) {
return message
}

const connectionState = this.state.get(connectionId)!

switch (connectionState.step) {
case 'get-extensions-query-received':
const patched = patchGetExtensionsResult(message)
if (patched.vectorOid) {
if (parseInt(patched.vectorOid) >= FIRST_NORMAL_OID) {
this.state.set(connectionId, {
step: 'complete',
})
} else {
this.state.set(connectionId, {
step: 'wait-for-get-extension-membership-query',
vectorOid: patched.vectorOid,
})
}
}
return patched.message
case 'get-extension-membership-query-received':
const patchedMessage = patchGetExtensionMembershipResult(message, connectionState.vectorOid)
this.state.set(connectionId, { step: 'complete' })
return patchedMessage
default:
return message
}
}
}

export const pgDumpMiddleware = new PgDumpMiddleware()
Loading