diff --git a/gateway/src/gateway.ts b/gateway/src/gateway.ts index efcdae2..550ef92 100644 --- a/gateway/src/gateway.ts +++ b/gateway/src/gateway.ts @@ -11,7 +11,8 @@ import { textResponse } from './utils' export async function gateway(request: Request, ctx: ExecutionContext, env: GatewayEnv): Promise { const { pathname } = new URL(request.url) - const providerMatch = /^\/([^/]+)\/(.*)$/.exec(pathname) + const proxyRegex = env.proxyRegex ?? /^\/(.+?)\/(.*)$/ + const providerMatch = proxyRegex.exec(pathname) if (!providerMatch) { return textResponse(404, 'Path not found') } diff --git a/gateway/src/index.ts b/gateway/src/index.ts index b0019fe..cec99bf 100644 --- a/gateway/src/index.ts +++ b/gateway/src/index.ts @@ -30,6 +30,8 @@ export interface GatewayEnv { kv: KVNamespace kvVersion: string subFetch: SubFetch + /** proxyRegex: defaults to `/^\/(.+?)\/(.*)$/`, e.g. proxy at the root */ + proxyRegex?: RegExp } export async function gatewayFetch(request: Request, ctx: ExecutionContext, env: GatewayEnv): Promise { diff --git a/gateway/test/gateway.spec.ts b/gateway/test/gateway.spec.ts new file mode 100644 index 0000000..5859ebb --- /dev/null +++ b/gateway/test/gateway.spec.ts @@ -0,0 +1,224 @@ +import { createExecutionContext, env, waitOnExecutionContext } from 'cloudflare:test' +import { gatewayFetch, LimitDbD1 } from '@pydantic/ai-gateway' +import OpenAI from 'openai' +import { describe, expect, it } from 'vitest' +import { test } from './setup' +import { buildGatewayEnv, type DisableEvent, IDS } from './worker' + +describe('invalid request', () => { + test('401 on no auth header', async ({ gateway }) => { + const response = await gateway.fetch('https://example.com/openai/gpt-5') + const text = await response.text() + expect(response.status, `got ${response.status} response: ${text}`).toBe(401) + expect(text).toMatchInlineSnapshot(`"Unauthorized - Missing Authorization Header"`) + }) + test('401 on unknown auth header', async ({ gateway }) => { + const response = await gateway.fetch('https://example.com/openai/gpt-5', { + headers: { Authorization: 'unknown-token' }, + }) + const text = await response.text() + expect(response.status, `got ${response.status} response: ${text}`).toBe(401) + expect(text).toMatchInlineSnapshot(`"Unauthorized - Key not found"`) + }) + test('400 on unknown provider', async ({ gateway }) => { + const response = await gateway.fetch('https://example.com/wrong/gpt-5', { + headers: { Authorization: 'unknown-token' }, + }) + const text = await response.text() + expect(response.status, `got ${response.status} response: ${text}`).toBe(400) + expect(text).toMatchInlineSnapshot( + `"Invalid provider 'wrong', should be one of groq, openai, google-vertex, anthropic, bedrock"`, + ) + }) +}) + +describe('key status', () => { + test('should not change key status if limit is not exceeded', async ({ gateway }) => { + const { fetch } = gateway + const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/test', fetch }) + await client.chat.completions.create({ + model: 'gpt-5', + messages: [ + { role: 'developer', content: 'You are a helpful assistant.' }, + { role: 'user', content: 'Give me an essay on the history of the universe.' }, + ], + }) + const allSpends = await env.limitsDB + .prepare( + `SELECT entityId, entityType, scope, round(spend, 3) spend, spendingLimit FROM spend order by spendingLimit`, + ) + .run<{ entityId: number; entityType: number; scope: number; spend: string; spendingLimit: number }>() + expect(allSpends.results).toMatchSnapshot('spend-table') + const allKeyStatus = await env.limitsDB + .prepare('SELECT count(*) as count FROM keyStatus') + .first<{ count: number }>() + expect(allKeyStatus?.count).toBe(0) + }) + + test('should block request if key is disabled', async ({ gateway }) => { + const { fetch } = gateway + + const response = await fetch('https://example.com/openai/xxx', { headers: { Authorization: 'disabled' } }) + const text = await response.text() + expect(response.status, `got response: ${response.status} ${text}`).toBe(403) + expect(text).toMatchInlineSnapshot(`"Unauthorized - Key disabled"`) + + const spendCount = await env.limitsDB.prepare('SELECT count(*) count FROM spend').first<{ count: number }>() + expect(spendCount?.count).toBe(0) + const keyStatusCount = await env.limitsDB + .prepare('SELECT count(*) count FROM keyStatus') + .first<{ count: number }>() + expect(keyStatusCount?.count).toBe(0) + }) + + test('should change key status if limit is exceeded', async ({ gateway }) => { + const { fetch, disableEvents } = gateway + + expect(disableEvents).toEqual([]) + + const client = new OpenAI({ apiKey: 'tiny-limit', baseURL: 'https://example.com/test', fetch }) + await client.chat.completions.create({ + model: 'gpt-5', + messages: [ + { role: 'developer', content: 'You are a helpful assistant.' }, + { role: 'user', content: 'Give me an essay on the history of the universe.' }, + ], + }) + + const apiValue = await env.KV.get('apiKeyAuth:test:tiny-limit') + expect(apiValue).toBeTypeOf('string') + expect(JSON.parse(apiValue!)).toMatchSnapshot('kv-value') + + const allSpends1 = await env.limitsDB + .prepare( + `SELECT entityId, entityType, scope, round(spend, 3) spend, spendingLimit FROM spend order by spendingLimit`, + ) + .run<{ entityId: number; entityType: number; scope: number; spend: string; spendingLimit: number }>() + expect(allSpends1.results).toMatchInlineSnapshot(` + [ + { + "entityId": 5, + "entityType": 3, + "scope": 1, + "spend": 0.018, + "spendingLimit": 0.01, + }, + { + "entityId": 1, + "entityType": 1, + "scope": 3, + "spend": 0.018, + "spendingLimit": 4, + }, + ] + `) + + expect(disableEvents).toEqual([ + { + id: IDS.keyTinyLimit, + reason: 'limits exceeded: key-daily', + newStatus: 'limit-exceeded', + expirationTtl: expect.any(Number), + }, + ]) + expect(disableEvents[0]!.expirationTtl).toBeGreaterThanOrEqual(0) + expect(disableEvents[0]!.expirationTtl).toBeLessThanOrEqual(86400) + + const keyStatusQuery = await env.limitsDB + .prepare("SELECT id, status, strftime('%s', expiresAt) - strftime('%s','now') as expiresAtDiff FROM keyStatus") + .run<{ id: string; status: string; expiresAtDiff: number }>() + expect(keyStatusQuery.results).toEqual([ + { id: IDS.keyTinyLimit, status: 'limit-exceeded', expiresAtDiff: expect.any(Number) }, + ]) + expect(Math.abs(keyStatusQuery.results[0]!.expiresAtDiff - disableEvents[0]!.expirationTtl!)).toBeLessThan(2) + + { + const response = await fetch('https://example.com/openai/xxx', { headers: { Authorization: 'tiny-limit' } }) + const text = await response.text() + expect(response.status, `got ${response.status} response: ${text}`).toBe(403) + expect(text).toMatchInlineSnapshot(`"Unauthorized - Key limit-exceeded"`) + } + + expect(disableEvents).toEqual([ + { + id: IDS.keyTinyLimit, + reason: 'limits exceeded: key-daily', + newStatus: 'limit-exceeded', + expirationTtl: expect.any(Number), + }, + ]) + }) +}) + +describe('LimitDbD1', () => { + it('updates limit', async () => { + const db = new LimitDbD1(env.limitsDB) + await db.incrementSpend( + [{ entityId: IDS.userDefault, entityType: 'user', scope: 'daily', scopeInterval: 123, limit: 2 }], + 1, + ) + + { + const state = await env.limitsDB.prepare('SELECT * FROM spend').first() + expect(state).toMatchInlineSnapshot(` + { + "entityId": 2, + "entityType": 2, + "scope": 1, + "scopeInterval": 123, + "spend": 1, + "spendingLimit": 2, + } + `) + } + + await db.updateUserLimits(IDS.userDefault, { daily: 3, weekly: 5 }) + + { + const state = await env.limitsDB.prepare('SELECT * FROM spend').first() + expect(state).toMatchInlineSnapshot(` + { + "entityId": 2, + "entityType": 2, + "scope": 1, + "scopeInterval": 123, + "spend": 1, + "spendingLimit": 3, + } + `) + } + }) +}) + +describe('custom proxyRegex', () => { + it('proxyRegex', async () => { + const ctx = createExecutionContext() + const disableEvents: DisableEvent[] = [] + + async function mockFetch(url: RequestInfo | URL, init?: RequestInit): Promise { + const request = new Request( + url, + init as RequestInit, + ) + const response = await gatewayFetch( + request, + ctx, + buildGatewayEnv(env, disableEvents, fetch, /^\/proxy\/([^/]+)\/(.*)$/), + ) + await waitOnExecutionContext(ctx) + return response + } + + const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/proxy/openai', fetch: mockFetch }) + + const completion = await client.chat.completions.create({ + model: 'gpt-5', + messages: [ + { role: 'developer', content: 'You are a helpful assistant.' }, + { role: 'user', content: 'What is the capital of France?' }, + ], + max_completion_tokens: 1024, + }) + expect(completion).toMatchSnapshot('proxyRegex') + }) +}) diff --git a/gateway/test/index.spec.ts.snap b/gateway/test/gateway.spec.ts.snap similarity index 52% rename from gateway/test/index.spec.ts.snap rename to gateway/test/gateway.spec.ts.snap index b774fde..f5614b6 100644 --- a/gateway/test/index.spec.ts.snap +++ b/gateway/test/gateway.spec.ts.snap @@ -1,5 +1,46 @@ // Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html +exports[`custom proxyRegex > proxyRegex > proxyRegex 1`] = ` +{ + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "annotations": [], + "content": "Paris.", + "refusal": null, + "role": "assistant", + }, + }, + ], + "created": 1758119097, + "id": "chatcmpl-CGnNZyR6FU8Xsw5X4N2YdTtaYBuWV", + "model": "gpt-5-2025-08-07", + "object": "chat.completion", + "service_tier": "default", + "system_fingerprint": null, + "usage": { + "completion_tokens": 11, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + }, + "prompt_tokens": 23, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0, + }, + "pydantic_ai_gateway": { + "cost_estimate": 0.00013875, + }, + "total_tokens": 34, + }, +} +`; + exports[`key status > should change key status if limit is exceeded > kv-value 1`] = ` { "id": 5, diff --git a/gateway/test/index.spec.ts b/gateway/test/index.spec.ts index 2add7d3..4fdc387 100644 --- a/gateway/test/index.spec.ts +++ b/gateway/test/index.spec.ts @@ -1,9 +1,5 @@ -import { env } from 'cloudflare:test' -import { LimitDbD1 } from '@pydantic/ai-gateway' -import OpenAI from 'openai' -import { describe, expect, it } from 'vitest' +import { describe, expect } from 'vitest' import { test } from './setup' -import { IDS } from './worker' describe('index', () => { test('responds with index html', async ({ gateway }) => { @@ -27,188 +23,3 @@ describe('index', () => { expect(response.headers.get('content-type')).toBe('text/plain; charset=utf-8') }) }) - -describe('invalid request', () => { - test('401 on no auth header', async ({ gateway }) => { - const response = await gateway.fetch('https://example.com/openai/gpt-5') - const text = await response.text() - expect(response.status, `got ${response.status} response: ${text}`).toBe(401) - expect(text).toMatchInlineSnapshot(`"Unauthorized - Missing Authorization Header"`) - }) - test('401 on unknown auth header', async ({ gateway }) => { - const response = await gateway.fetch('https://example.com/openai/gpt-5', { - headers: { Authorization: 'unknown-token' }, - }) - const text = await response.text() - expect(response.status, `got ${response.status} response: ${text}`).toBe(401) - expect(text).toMatchInlineSnapshot(`"Unauthorized - Key not found"`) - }) - test('400 on unknown provider', async ({ gateway }) => { - const response = await gateway.fetch('https://example.com/wrong/gpt-5', { - headers: { Authorization: 'unknown-token' }, - }) - const text = await response.text() - expect(response.status, `got ${response.status} response: ${text}`).toBe(400) - expect(text).toMatchInlineSnapshot( - `"Invalid provider 'wrong', should be one of groq, openai, google-vertex, anthropic, bedrock"`, - ) - }) -}) - -describe('key status', () => { - test('should not change key status if limit is not exceeded', async ({ gateway }) => { - const { fetch } = gateway - const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/test', fetch }) - await client.chat.completions.create({ - model: 'gpt-5', - messages: [ - { role: 'developer', content: 'You are a helpful assistant.' }, - { role: 'user', content: 'Give me an essay on the history of the universe.' }, - ], - }) - const allSpends = await env.limitsDB - .prepare( - `SELECT entityId, entityType, scope, round(spend, 3) spend, spendingLimit FROM spend order by spendingLimit`, - ) - .run<{ entityId: number; entityType: number; scope: number; spend: string; spendingLimit: number }>() - expect(allSpends.results).toMatchSnapshot('spend-table') - const allKeyStatus = await env.limitsDB - .prepare('SELECT count(*) as count FROM keyStatus') - .first<{ count: number }>() - expect(allKeyStatus?.count).toBe(0) - }) - - test('should block request if key is disabled', async ({ gateway }) => { - const { fetch } = gateway - - const response = await fetch('https://example.com/openai/xxx', { headers: { Authorization: 'disabled' } }) - const text = await response.text() - expect(response.status, `got response: ${response.status} ${text}`).toBe(403) - expect(text).toMatchInlineSnapshot(`"Unauthorized - Key disabled"`) - - const spendCount = await env.limitsDB.prepare('SELECT count(*) count FROM spend').first<{ count: number }>() - expect(spendCount?.count).toBe(0) - const keyStatusCount = await env.limitsDB - .prepare('SELECT count(*) count FROM keyStatus') - .first<{ count: number }>() - expect(keyStatusCount?.count).toBe(0) - }) - - test('should change key status if limit is exceeded', async ({ gateway }) => { - const { fetch, disableEvents } = gateway - - expect(disableEvents).toEqual([]) - - const client = new OpenAI({ apiKey: 'tiny-limit', baseURL: 'https://example.com/test', fetch }) - await client.chat.completions.create({ - model: 'gpt-5', - messages: [ - { role: 'developer', content: 'You are a helpful assistant.' }, - { role: 'user', content: 'Give me an essay on the history of the universe.' }, - ], - }) - - const apiValue = await env.KV.get('apiKeyAuth:test:tiny-limit') - expect(apiValue).toBeTypeOf('string') - expect(JSON.parse(apiValue!)).toMatchSnapshot('kv-value') - - const allSpends1 = await env.limitsDB - .prepare( - `SELECT entityId, entityType, scope, round(spend, 3) spend, spendingLimit FROM spend order by spendingLimit`, - ) - .run<{ entityId: number; entityType: number; scope: number; spend: string; spendingLimit: number }>() - expect(allSpends1.results).toMatchInlineSnapshot(` - [ - { - "entityId": 5, - "entityType": 3, - "scope": 1, - "spend": 0.018, - "spendingLimit": 0.01, - }, - { - "entityId": 1, - "entityType": 1, - "scope": 3, - "spend": 0.018, - "spendingLimit": 4, - }, - ] - `) - - expect(disableEvents).toEqual([ - { - id: IDS.keyTinyLimit, - reason: 'limits exceeded: key-daily', - newStatus: 'limit-exceeded', - expirationTtl: expect.any(Number), - }, - ]) - expect(disableEvents[0]!.expirationTtl).toBeGreaterThanOrEqual(0) - expect(disableEvents[0]!.expirationTtl).toBeLessThanOrEqual(86400) - - const keyStatusQuery = await env.limitsDB - .prepare("SELECT id, status, strftime('%s', expiresAt) - strftime('%s','now') as expiresAtDiff FROM keyStatus") - .run<{ id: string; status: string; expiresAtDiff: number }>() - expect(keyStatusQuery.results).toEqual([ - { id: IDS.keyTinyLimit, status: 'limit-exceeded', expiresAtDiff: expect.any(Number) }, - ]) - expect(Math.abs(keyStatusQuery.results[0]!.expiresAtDiff - disableEvents[0]!.expirationTtl!)).toBeLessThan(2) - - { - const response = await fetch('https://example.com/openai/xxx', { headers: { Authorization: 'tiny-limit' } }) - const text = await response.text() - expect(response.status, `got ${response.status} response: ${text}`).toBe(403) - expect(text).toMatchInlineSnapshot(`"Unauthorized - Key limit-exceeded"`) - } - - expect(disableEvents).toEqual([ - { - id: IDS.keyTinyLimit, - reason: 'limits exceeded: key-daily', - newStatus: 'limit-exceeded', - expirationTtl: expect.any(Number), - }, - ]) - }) -}) - -describe('LimitDbD1', () => { - it('updates limit', async () => { - const db = new LimitDbD1(env.limitsDB) - await db.incrementSpend( - [{ entityId: IDS.userDefault, entityType: 'user', scope: 'daily', scopeInterval: 123, limit: 2 }], - 1, - ) - - { - const state = await env.limitsDB.prepare('SELECT * FROM spend').first() - expect(state).toMatchInlineSnapshot(` - { - "entityId": 2, - "entityType": 2, - "scope": 1, - "scopeInterval": 123, - "spend": 1, - "spendingLimit": 2, - } - `) - } - - await db.updateUserLimits(IDS.userDefault, { daily: 3, weekly: 5 }) - - { - const state = await env.limitsDB.prepare('SELECT * FROM spend').first() - expect(state).toMatchInlineSnapshot(` - { - "entityId": 2, - "entityType": 2, - "scope": 1, - "scopeInterval": 123, - "spend": 1, - "spendingLimit": 3, - } - `) - } - }) -}) diff --git a/gateway/test/worker.ts b/gateway/test/worker.ts index ea689d3..e5f4fe3 100644 --- a/gateway/test/worker.ts +++ b/gateway/test/worker.ts @@ -21,7 +21,12 @@ export interface DisableEvent { expirationTtl?: number } -export function buildGatewayEnv(env: Env, disableEvents: DisableEvent[], subFetch: SubFetch): GatewayEnv { +export function buildGatewayEnv( + env: Env, + disableEvents: DisableEvent[], + subFetch: SubFetch, + proxyRegex?: RegExp, +): GatewayEnv { return { githubSha: 'test', keysDb: new TestKeysDB(env, disableEvents), @@ -29,6 +34,7 @@ export function buildGatewayEnv(env: Env, disableEvents: DisableEvent[], subFetc kv: env.KV, kvVersion: 'test', subFetch, + proxyRegex, } }