Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion gateway/src/gateway.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import { textResponse } from './utils'

export async function gateway(request: Request, ctx: ExecutionContext, env: GatewayEnv): Promise<Response> {
const { pathname } = new URL(request.url)
const providerMatch = /^\/([^/]+)\/(.*)$/.exec(pathname)
const proxyRegex = env.proxyRegex ?? /^\/(.+?)\/(.*)$/
const providerMatch = proxyRegex.exec(pathname)
if (!providerMatch) {
return textResponse(404, 'Path not found')
}
Expand Down
2 changes: 2 additions & 0 deletions gateway/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ export interface GatewayEnv {
kv: KVNamespace
kvVersion: string
subFetch: SubFetch
/** Regex used to split the request pathname into provider and remainder. Defaults to `/^\/(.+?)\/(.*)$/`, i.e. the gateway proxies at the root path; override to mount it under a path prefix. */
proxyRegex?: RegExp
}

export async function gatewayFetch(request: Request, ctx: ExecutionContext, env: GatewayEnv): Promise<Response> {
Expand Down
224 changes: 224 additions & 0 deletions gateway/test/gateway.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import { createExecutionContext, env, waitOnExecutionContext } from 'cloudflare:test'
import { gatewayFetch, LimitDbD1 } from '@pydantic/ai-gateway'
import OpenAI from 'openai'
import { describe, expect, it } from 'vitest'
import { test } from './setup'
import { buildGatewayEnv, type DisableEvent, IDS } from './worker'

describe('invalid request', () => {
  // Unwraps a gateway response into its status code and body text,
  // which is what every assertion in this group needs.
  const read = async (res: Response): Promise<[number, string]> => [res.status, await res.text()]

  test('401 on no auth header', async ({ gateway }) => {
    // No Authorization header at all.
    const res = await gateway.fetch('https://example.com/openai/gpt-5')
    const [status, body] = await read(res)
    expect(status, `got ${status} response: ${body}`).toBe(401)
    expect(body).toMatchInlineSnapshot(`"Unauthorized - Missing Authorization Header"`)
  })

  test('401 on unknown auth header', async ({ gateway }) => {
    // Header present but the token is not a known key.
    const res = await gateway.fetch('https://example.com/openai/gpt-5', {
      headers: { Authorization: 'unknown-token' },
    })
    const [status, body] = await read(res)
    expect(status, `got ${status} response: ${body}`).toBe(401)
    expect(body).toMatchInlineSnapshot(`"Unauthorized - Key not found"`)
  })

  test('400 on unknown provider', async ({ gateway }) => {
    // First path segment is not a supported provider; rejected with 400.
    const res = await gateway.fetch('https://example.com/wrong/gpt-5', {
      headers: { Authorization: 'unknown-token' },
    })
    const [status, body] = await read(res)
    expect(status, `got ${status} response: ${body}`).toBe(400)
    expect(body).toMatchInlineSnapshot(
      `"Invalid provider 'wrong', should be one of groq, openai, google-vertex, anthropic, bedrock"`,
    )
  })
})

describe('key status', () => {
  // Happy path: a completion whose cost stays under the key's limit writes
  // spend rows but never creates a keyStatus row (the key stays enabled).
  test('should not change key status if limit is not exceeded', async ({ gateway }) => {
    const { fetch } = gateway
    const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/test', fetch })
    await client.chat.completions.create({
      model: 'gpt-5',
      messages: [
        { role: 'developer', content: 'You are a helpful assistant.' },
        { role: 'user', content: 'Give me an essay on the history of the universe.' },
      ],
    })
    // Spend is rounded to 3 decimal places so the snapshot stays stable.
    const allSpends = await env.limitsDB
      .prepare(
        `SELECT entityId, entityType, scope, round(spend, 3) spend, spendingLimit FROM spend order by spendingLimit`,
      )
      .run<{ entityId: number; entityType: number; scope: number; spend: string; spendingLimit: number }>()
    expect(allSpends.results).toMatchSnapshot('spend-table')
    // No limit was exceeded, so the keyStatus table must remain empty.
    const allKeyStatus = await env.limitsDB
      .prepare('SELECT count(*) as count FROM keyStatus')
      .first<{ count: number }>()
    expect(allKeyStatus?.count).toBe(0)
  })

  // A key already marked disabled is rejected before any spend is recorded:
  // both the spend and keyStatus tables stay empty.
  test('should block request if key is disabled', async ({ gateway }) => {
    const { fetch } = gateway

    const response = await fetch('https://example.com/openai/xxx', { headers: { Authorization: 'disabled' } })
    const text = await response.text()
    expect(response.status, `got response: ${response.status} ${text}`).toBe(403)
    expect(text).toMatchInlineSnapshot(`"Unauthorized - Key disabled"`)

    const spendCount = await env.limitsDB.prepare('SELECT count(*) count FROM spend').first<{ count: number }>()
    expect(spendCount?.count).toBe(0)
    const keyStatusCount = await env.limitsDB
      .prepare('SELECT count(*) count FROM keyStatus')
      .first<{ count: number }>()
    expect(keyStatusCount?.count).toBe(0)
  })

  // End-to-end limit enforcement: a completion that overruns the key's tiny
  // daily limit fires a disable event, persists a keyStatus row with an
  // expiry, and causes the next request with the same key to be rejected.
  test('should change key status if limit is exceeded', async ({ gateway }) => {
    const { fetch, disableEvents } = gateway

    // Sanity check: no disable events before the request.
    expect(disableEvents).toEqual([])

    const client = new OpenAI({ apiKey: 'tiny-limit', baseURL: 'https://example.com/test', fetch })
    await client.chat.completions.create({
      model: 'gpt-5',
      messages: [
        { role: 'developer', content: 'You are a helpful assistant.' },
        { role: 'user', content: 'Give me an essay on the history of the universe.' },
      ],
    })

    // The cached auth entry in KV reflects the key's new state.
    const apiValue = await env.KV.get('apiKeyAuth:test:tiny-limit')
    expect(apiValue).toBeTypeOf('string')
    expect(JSON.parse(apiValue!)).toMatchSnapshot('kv-value')

    // Spend was still recorded for both the key (limit 0.01, exceeded) and
    // its parent entity (limit 4, not exceeded).
    const allSpends1 = await env.limitsDB
      .prepare(
        `SELECT entityId, entityType, scope, round(spend, 3) spend, spendingLimit FROM spend order by spendingLimit`,
      )
      .run<{ entityId: number; entityType: number; scope: number; spend: string; spendingLimit: number }>()
    expect(allSpends1.results).toMatchInlineSnapshot(`
      [
        {
          "entityId": 5,
          "entityType": 3,
          "scope": 1,
          "spend": 0.018,
          "spendingLimit": 0.01,
        },
        {
          "entityId": 1,
          "entityType": 1,
          "scope": 3,
          "spend": 0.018,
          "spendingLimit": 4,
        },
      ]
    `)

    // Exactly one disable event, attributed to the daily key limit; the TTL
    // is bounded by the daily window (at most 24h from now).
    expect(disableEvents).toEqual([
      {
        id: IDS.keyTinyLimit,
        reason: 'limits exceeded: key-daily',
        newStatus: 'limit-exceeded',
        expirationTtl: expect.any(Number),
      },
    ])
    expect(disableEvents[0]!.expirationTtl).toBeGreaterThanOrEqual(0)
    expect(disableEvents[0]!.expirationTtl).toBeLessThanOrEqual(86400)

    // The keyStatus row's expiry should agree with the event TTL (allow a
    // couple of seconds of clock skew between the insert and this query).
    const keyStatusQuery = await env.limitsDB
      .prepare("SELECT id, status, strftime('%s', expiresAt) - strftime('%s','now') as expiresAtDiff FROM keyStatus")
      .run<{ id: string; status: string; expiresAtDiff: number }>()
    expect(keyStatusQuery.results).toEqual([
      { id: IDS.keyTinyLimit, status: 'limit-exceeded', expiresAtDiff: expect.any(Number) },
    ])
    expect(Math.abs(keyStatusQuery.results[0]!.expiresAtDiff - disableEvents[0]!.expirationTtl!)).toBeLessThan(2)

    // A follow-up request with the now-disabled key is rejected with 403.
    {
      const response = await fetch('https://example.com/openai/xxx', { headers: { Authorization: 'tiny-limit' } })
      const text = await response.text()
      expect(response.status, `got ${response.status} response: ${text}`).toBe(403)
      expect(text).toMatchInlineSnapshot(`"Unauthorized - Key limit-exceeded"`)
    }

    // The rejected request must not have produced a second disable event.
    expect(disableEvents).toEqual([
      {
        id: IDS.keyTinyLimit,
        reason: 'limits exceeded: key-daily',
        newStatus: 'limit-exceeded',
        expirationTtl: expect.any(Number),
      },
    ])
  })
})

describe('LimitDbD1', () => {
  it('updates limit', async () => {
    // Reads whatever single row is currently in the spend table.
    const readSpendRow = () => env.limitsDB.prepare('SELECT * FROM spend').first()

    const db = new LimitDbD1(env.limitsDB)

    // Record a spend of 1 against the default user's daily limit of 2.
    await db.incrementSpend(
      [{ entityId: IDS.userDefault, entityType: 'user', scope: 'daily', scopeInterval: 123, limit: 2 }],
      1,
    )

    expect(await readSpendRow()).toMatchInlineSnapshot(`
      {
        "entityId": 2,
        "entityType": 2,
        "scope": 1,
        "scopeInterval": 123,
        "spend": 1,
        "spendingLimit": 2,
      }
    `)

    // Raising the user's limits must update the existing spend row in place:
    // same spend, new daily spendingLimit.
    await db.updateUserLimits(IDS.userDefault, { daily: 3, weekly: 5 })

    expect(await readSpendRow()).toMatchInlineSnapshot(`
      {
        "entityId": 2,
        "entityType": 2,
        "scope": 1,
        "scopeInterval": 123,
        "spend": 1,
        "spendingLimit": 3,
      }
    `)
  })
})

describe('custom proxyRegex', () => {
it('proxyRegex', async () => {
const ctx = createExecutionContext()
const disableEvents: DisableEvent[] = []

async function mockFetch(url: RequestInfo | URL, init?: RequestInit): Promise<Response> {
const request = new Request<unknown, IncomingRequestCfProperties>(
url,
init as RequestInit<IncomingRequestCfProperties>,
)
const response = await gatewayFetch(
request,
ctx,
buildGatewayEnv(env, disableEvents, fetch, /^\/proxy\/([^/]+)\/(.*)$/),
)
await waitOnExecutionContext(ctx)
return response
}

const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/proxy/openai', fetch: mockFetch })

const completion = await client.chat.completions.create({
model: 'gpt-5',
messages: [
{ role: 'developer', content: 'You are a helpful assistant.' },
{ role: 'user', content: 'What is the capital of France?' },
],
max_completion_tokens: 1024,
})
expect(completion).toMatchSnapshot('proxyRegex')
})
})
Original file line number Diff line number Diff line change
@@ -1,5 +1,46 @@
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html

exports[`custom proxyRegex > proxyRegex > proxyRegex 1`] = `
{
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"annotations": [],
"content": "Paris.",
"refusal": null,
"role": "assistant",
},
},
],
"created": 1758119097,
"id": "chatcmpl-CGnNZyR6FU8Xsw5X4N2YdTtaYBuWV",
"model": "gpt-5-2025-08-07",
"object": "chat.completion",
"service_tier": "default",
"system_fingerprint": null,
"usage": {
"completion_tokens": 11,
"completion_tokens_details": {
"accepted_prediction_tokens": 0,
"audio_tokens": 0,
"reasoning_tokens": 0,
"rejected_prediction_tokens": 0,
},
"prompt_tokens": 23,
"prompt_tokens_details": {
"audio_tokens": 0,
"cached_tokens": 0,
},
"pydantic_ai_gateway": {
"cost_estimate": 0.00013875,
},
"total_tokens": 34,
},
}
`;

exports[`key status > should change key status if limit is exceeded > kv-value 1`] = `
{
"id": 5,
Expand Down
Loading