Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion deploy/example.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,28 @@ export const config: Config<ProviderKeys> = {
injectCost: true,
// credentials are used by the ProviderProxy to authenticate the forwarded request
credentials: env.OPENAI_API_KEY,
apiTypes: ['chat', 'responses'],
},
b: {
providerId: 'groq',
baseUrl: 'https://api.groq.com',
injectCost: true,
credentials: env.GROQ_API_KEY,
apiTypes: ['groq'],
},
b: { providerId: 'groq', baseUrl: 'https://api.groq.com', injectCost: true, credentials: env.GROQ_API_KEY },
c: {
providerId: 'google-vertex',
baseUrl: 'https://us-central1-aiplatform.googleapis.com',
injectCost: true,
credentials: env.GOOGLE_SERVICE_ACCOUNT_KEY,
apiTypes: ['gemini', 'anthropic'],
},
d: {
providerId: 'anthropic',
baseUrl: 'https://api.anthropic.com',
injectCost: true,
credentials: env.ANTHROPIC_API_KEY,
apiTypes: ['anthropic'],
},
},
// individual apiKeys
Expand Down
12 changes: 11 additions & 1 deletion deploy/test.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ export const config: Config<ProviderKeys> = {
injectCost: true,
// credentials are used by the ProviderProxy to authenticate the forwarded request
credentials: env.OPENAI_API_KEY,
apiTypes: ['chat'],
},
groq: {
baseUrl: 'http://localhost:8005/groq',
providerId: 'groq',
injectCost: true,
credentials: env.GROQ_API_KEY,
apiTypes: ['groq'],
},
// google: {
// baseUrl:
Expand All @@ -51,14 +53,22 @@ export const config: Config<ProviderKeys> = {
providerId: 'anthropic',
injectCost: true,
credentials: env.ANTHROPIC_API_KEY,
apiTypes: ['anthropic'],
},
bedrock: {
baseUrl: 'http://localhost:8005/bedrock',
providerId: 'bedrock',
injectCost: true,
credentials: env.AWS_BEARER_TOKEN_BEDROCK,
apiTypes: ['anthropic', 'converse'],
},
test: {
baseUrl: 'http://test.example.com/test',
providerId: 'test',
injectCost: true,
credentials: 'test',
apiTypes: ['test'],
},
test: { baseUrl: 'http://test.example.com/test', providerId: 'test', injectCost: true, credentials: 'test' },
},
// individual apiKeys
apiKeys: {
Expand Down
2 changes: 1 addition & 1 deletion deploy/test/index.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ describe('deploy', () => {

const client = new OpenAI({
apiKey: 'healthy-key',
baseURL: 'https://example.com/openai',
baseURL: 'https://example.com/chat',
fetch: SELF.fetch.bind(SELF),
})

Expand Down
19 changes: 12 additions & 7 deletions gateway/src/gateway.ts
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also a logic that needs attention!

Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { currentScopeIntervals, type ExceededScope, endOfMonth, endOfWeek, type
import { OtelTrace } from './otel'
import { genAiOtelAttributes } from './otel/attributes'
import { getProvider } from './providers'
import { type ApiKeyInfo, guardProviderID, providerIdArray } from './types'
import { type ApiKeyInfo, apiTypesArray, guardAPIType } from './types'
import { runAfter, textResponse } from './utils'

export async function gateway(
Expand All @@ -14,14 +14,14 @@ export async function gateway(
ctx: ExecutionContext,
options: GatewayOptions,
): Promise<Response> {
const providerMatch = /^\/([^/]+)\/(.*)$/.exec(proxyPath)
if (!providerMatch) {
const apiTypeMatch = /^\/([^/]+)\/(.*)$/.exec(proxyPath)
if (!apiTypeMatch) {
return textResponse(404, 'Path not found')
}
const [, provider, restOfPath] = providerMatch as unknown as [string, string, string]
const [, apiType, restOfPath] = apiTypeMatch as unknown as [string, string, string]

if (!guardProviderID(provider)) {
return textResponse(400, `Invalid provider '${provider}', should be one of ${providerIdArray.join(', ')}`)
if (!guardAPIType(apiType)) {
return textResponse(400, `Invalid API type '${apiType}', should be one of ${apiTypesArray.join(', ')}`)
}

const apiKeyInfo = await apiKeyAuth(request, ctx, options)
Expand All @@ -30,7 +30,12 @@ export async function gateway(
return textResponse(403, `Unauthorized - Key ${apiKeyInfo.status}`)
}

let providerProxies = apiKeyInfo.providers.filter((p) => p.providerId === provider)
let providerProxies = apiKeyInfo.providers.filter((p) => p.apiTypes.includes(apiType))

const routingGroup = request.headers.get('pydantic-ai-gateway-routing-group')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason we can't use a query parameter here? I that would be easier to debug.

Also, I think paig-routing is a better name maybe?

if (routingGroup !== null) {
providerProxies = providerProxies.filter((p) => p.routingGroup === routingGroup)
}

const profile = request.headers.get('pydantic-ai-gateway-profile')
if (profile !== null) {
Expand Down
25 changes: 17 additions & 8 deletions gateway/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,22 @@ export interface ApiKeyInfo {
export type ProviderID = 'groq' | 'openai' | 'google-vertex' | 'anthropic' | 'test' | 'bedrock'
// TODO | 'azure' | 'fireworks' | 'mistral' | 'cohere'

const providerIds: Record<ProviderID, boolean> = {
groq: true,
openai: true,
'google-vertex': true,
export type APIType = 'chat' | 'responses' | 'converse' | 'anthropic' | 'gemini' | 'groq' | 'test'

const apiTypes: Record<APIType, boolean> = {
chat: true,
responses: true,
converse: true,
anthropic: true,
gemini: true,
groq: true,
test: true,
bedrock: true,
}

export const providerIdArray = Object.keys(providerIds).filter((id) => id !== 'test') as ProviderID[]
export const apiTypesArray = Object.keys(apiTypes) as APIType[]

export function guardProviderID(id: string): id is ProviderID {
return id in providerIds
export function guardAPIType(type: string): type is APIType {
return type in apiTypes
}

export interface ProviderProxy {
Expand All @@ -67,6 +70,12 @@ export interface ProviderProxy {
priority?: number
/** @disableKey: weather to disable the key in case of error, if missing defaults to True. */
disableKey?: boolean

/** @apiTypes: the APIs that the provider supports. Example: ['chat', 'responses'] */
apiTypes: APIType[]
/** @routingGroups: a grouping of APIs that serve the same models.
* @example: 'anthropic' would route the requests to Anthropic, Bedrock and Vertex AI. */
routingGroup?: string
}

export interface OtelSettings {
Expand Down
14 changes: 7 additions & 7 deletions gateway/test/gateway.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ import { buildGatewayEnv, type DisableEvent, IDS } from './worker'

describe('invalid request', () => {
test('401 on no auth header', async ({ gateway }) => {
const response = await gateway.fetch('https://example.com/openai/gpt-5')
const response = await gateway.fetch('https://example.com/chat/gpt-5')
const text = await response.text()
expect(response.status, `got ${response.status} response: ${text}`).toBe(401)
expect(text).toMatchInlineSnapshot(`"Unauthorized - Missing Authorization Header"`)
})
test('401 on unknown auth header', async ({ gateway }) => {
const response = await gateway.fetch('https://example.com/openai/gpt-5', {
const response = await gateway.fetch('https://example.com/chat/gpt-5', {
headers: { Authorization: 'unknown-token' },
})
const text = await response.text()
Expand All @@ -35,7 +35,7 @@ describe('invalid request', () => {
const text = await response.text()
expect(response.status, `got ${response.status} response: ${text}`).toBe(400)
expect(text).toMatchInlineSnapshot(
`"Invalid provider 'wrong', should be one of groq, openai, google-vertex, anthropic, bedrock"`,
`"Invalid API type 'wrong', should be one of chat, responses, converse, anthropic, gemini, groq, test"`,
)
})
})
Expand Down Expand Up @@ -66,7 +66,7 @@ describe('key status', () => {
test('should block request if key is disabled', async ({ gateway }) => {
const { fetch } = gateway

const response = await fetch('https://example.com/openai/xxx', { headers: { Authorization: 'disabled' } })
const response = await fetch('https://example.com/chat/xxx', { headers: { Authorization: 'disabled' } })
const text = await response.text()
expect(response.status, `got response: ${response.status} ${text}`).toBe(403)
expect(text).toMatchInlineSnapshot(`"Unauthorized - Key disabled"`)
Expand Down Expand Up @@ -132,7 +132,7 @@ describe('key status', () => {
expect(Math.abs(keyStatusQuery.results[0]!.expiresAtDiff - disableEvents[0]!.expirationTtl!)).toBeLessThan(2)

{
const response = await fetch('https://example.com/openai/xxx', { headers: { Authorization: 'tiny-limit' } })
const response = await fetch('https://example.com/chat/xxx', { headers: { Authorization: 'tiny-limit' } })
const text = await response.text()
expect(response.status, `got ${response.status} response: ${text}`).toBe(403)
expect(text).toMatchInlineSnapshot(`"Unauthorized - Key limit-exceeded"`)
Expand Down Expand Up @@ -215,7 +215,7 @@ describe('custom proxyPrefixLength', () => {
const disableEvents: DisableEvent[] = []
const mockFetch = mockFetchFactory(disableEvents)

const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/proxy/openai', fetch: mockFetch })
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/proxy/chat', fetch: mockFetch })

const completion = await client.chat.completions.create({
model: 'gpt-5',
Expand Down Expand Up @@ -249,7 +249,7 @@ describe('custom middleware', () => {
)[] = []

const ctx = createExecutionContext()
const request = new Request<unknown, IncomingRequestCfProperties>('https://example.com/openai/gpt-5', {
const request = new Request<unknown, IncomingRequestCfProperties>('https://example.com/chat/gpt-5', {
headers: { Authorization: 'healthy' },
})

Expand Down
3 changes: 3 additions & 0 deletions gateway/test/gateway.spec.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ exports[`key status > should change key status if limit is exceeded > kv-value 1
"projectSpendingLimitMonthly": 4,
"providers": [
{
"apiTypes": [
"test",
],
"baseUrl": "http://test.example.com/test",
"credentials": "test",
"injectCost": true,
Expand Down
2 changes: 1 addition & 1 deletion gateway/test/providers/bedrock.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ describe('bedrock', () => {
test('should call bedrock via gateway', async ({ gateway }) => {
const { fetch, otelBatch } = gateway

const result = await fetch('https://example.com/bedrock/model/amazon.nova-micro-v1%3A0/converse', {
const result = await fetch('https://example.com/converse/model/amazon.nova-micro-v1%3A0/converse', {
method: 'POST',
headers: { 'Content-Type': 'application/json', Authorization: 'healthy' },
body: JSON.stringify({
Expand Down
8 changes: 4 additions & 4 deletions gateway/test/providers/openai.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ describe('openai', () => {
test('openai chat', async ({ gateway }) => {
const { fetch, otelBatch } = gateway

const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/openai', fetch })
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/chat', fetch })

const completion = await client.chat.completions.create({
model: 'gpt-5',
Expand Down Expand Up @@ -103,7 +103,7 @@ describe('openai', () => {
test('openai responses', async ({ gateway }) => {
const { fetch, otelBatch } = gateway

const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/openai', fetch })
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/chat', fetch })

const completion = await client.responses.create({
model: 'gpt-5',
Expand All @@ -118,7 +118,7 @@ describe('openai', () => {
test('openai responses with builtin tools', async ({ gateway }) => {
const { fetch, otelBatch } = gateway

const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/openai', fetch })
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/chat', fetch })

const completion = await client.responses.create({
model: 'gpt-5',
Expand Down Expand Up @@ -150,7 +150,7 @@ describe('openai', () => {
test('openai chat stream', async ({ gateway }) => {
const { fetch, otelBatch } = gateway

const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/openai', fetch })
const client = new OpenAI({ apiKey: 'healthy', baseURL: 'https://example.com/chat', fetch })

const stream = await client.chat.completions.create({
stream: true,
Expand Down
2 changes: 1 addition & 1 deletion gateway/test/providers/openai.spec.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ exports[`openai > openai chat stream > span 1`] = `
{
"key": "url.full",
"value": {
"stringValue": "https://example.com/openai/chat/completions",
"stringValue": "https://example.com/chat/chat/completions",
},
},
{
Expand Down
19 changes: 17 additions & 2 deletions gateway/test/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,13 @@ class TestKeysDB extends KeysDbD1 {
super(env.limitsDB)
this.disableEvents = disableEvents
this.allProviders = [
{ baseUrl: 'http://test.example.com/test', providerId: 'test', injectCost: true, credentials: 'test' },
{
baseUrl: 'http://test.example.com/test',
providerId: 'test',
injectCost: true,
credentials: 'test',
apiTypes: ['test'],
},
{
// baseUrl decides what URL the request will be forwarded to
baseUrl: 'http://localhost:8005/openai',
Expand All @@ -70,19 +76,28 @@ class TestKeysDB extends KeysDbD1 {
injectCost: true,
// credentials are used by the ProviderProxy to authenticate the forwarded request
credentials: env.OPENAI_API_KEY,
apiTypes: ['chat'],
},
{
baseUrl: 'http://localhost:8005/groq',
providerId: 'groq',
injectCost: true,
credentials: env.GROQ_API_KEY,
apiTypes: ['groq'],
},
{ baseUrl: 'http://localhost:8005/groq', providerId: 'groq', injectCost: true, credentials: env.GROQ_API_KEY },
{
baseUrl: 'http://localhost:8005/anthropic',
providerId: 'anthropic',
injectCost: true,
credentials: env.ANTHROPIC_API_KEY,
apiTypes: ['anthropic'],
},
{
baseUrl: 'http://localhost:8005/bedrock',
providerId: 'bedrock',
injectCost: true,
credentials: env.AWS_BEARER_TOKEN_BEDROCK,
apiTypes: ['anthropic', 'converse'],
},
]
}
Expand Down