Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 9 additions & 19 deletions apps/sim/app/api/mcp/oauth/callback/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,9 @@ import {
import { NextRequest } from 'next/server'
import { beforeEach, describe, expect, it, vi } from 'vitest'

const { mockMcpAuth, mockCreateSsrfGuardedMcpFetch, mockGuardedFetch, mockDiscoverServerTools } =
vi.hoisted(() => ({
mockMcpAuth: vi.fn(),
mockCreateSsrfGuardedMcpFetch: vi.fn(),
mockGuardedFetch: vi.fn(),
mockDiscoverServerTools: vi.fn(),
}))
const { mockDiscoverServerTools } = vi.hoisted(() => ({
mockDiscoverServerTools: vi.fn(),
}))

vi.mock('@sim/db', () => dbChainMock)
vi.mock('@sim/db/schema', () => schemaMock)
Expand All @@ -28,13 +24,7 @@ vi.mock('drizzle-orm', () => ({
eq: vi.fn(),
isNull: vi.fn(),
}))
vi.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({
auth: mockMcpAuth,
}))
vi.mock('@/lib/mcp/oauth', () => mcpOauthMock)
vi.mock('@/lib/mcp/pinned-fetch', () => ({
createSsrfGuardedMcpFetch: mockCreateSsrfGuardedMcpFetch,
}))
vi.mock('@/lib/mcp/service', () => ({
mcpService: { discoverServerTools: mockDiscoverServerTools },
}))
Expand All @@ -45,7 +35,6 @@ describe('MCP OAuth callback route', () => {
beforeEach(() => {
vi.clearAllMocks()
resetDbChainMock()
mockCreateSsrfGuardedMcpFetch.mockReturnValue(mockGuardedFetch)
authMockFns.mockGetSession.mockResolvedValue({ user: { id: 'user-1' } })
mcpOauthMockFns.mockLoadOauthRowByState.mockResolvedValue({
id: 'oauth-row-1',
Expand All @@ -61,24 +50,25 @@ describe('MCP OAuth callback route', () => {
},
])
mcpOauthMockFns.mockLoadPreregisteredClient.mockResolvedValue(undefined)
mockMcpAuth.mockResolvedValue('AUTHORIZED')
mcpOauthMockFns.mockMcpAuthGuarded.mockResolvedValue('AUTHORIZED')
mockDiscoverServerTools.mockResolvedValue(undefined)
})

it('performs the token exchange through the SSRF-guarded fetch', async () => {
it('performs the token exchange through the SSRF-guarded mcpAuthGuarded wrapper', async () => {
const request = new NextRequest(
'http://localhost:3000/api/mcp/oauth/callback?state=state-1&code=auth-code-1'
)

await GET(request)

expect(mockCreateSsrfGuardedMcpFetch).toHaveBeenCalledTimes(1)
expect(mockMcpAuth).toHaveBeenCalledWith(
// The route must call the guarded wrapper (which defaults fetchFn to the
// SSRF-guarded fetch internally) rather than the raw SDK `auth()` — see
// apps/sim/lib/mcp/oauth/auth.test.ts for the wrapper's own fetchFn coverage.
expect(mcpOauthMockFns.mockMcpAuthGuarded).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({
serverUrl: 'https://mcp.example.com/mcp',
authorizationCode: 'auth-code-1',
fetchFn: mockGuardedFetch,
})
)
})
Expand Down
8 changes: 3 additions & 5 deletions apps/sim/app/api/mcp/oauth/callback/route.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { auth as mcpAuth } from '@modelcontextprotocol/sdk/client/auth.js'
import { db } from '@sim/db'
import { mcpServers } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
Expand All @@ -17,9 +16,9 @@ import {
loadOauthRowByState,
loadPreregisteredClient,
type McpOauthCallbackReason,
mcpAuthGuarded,
SimMcpOauthProvider,
} from '@/lib/mcp/oauth'
import { createSsrfGuardedMcpFetch } from '@/lib/mcp/pinned-fetch'
import { mcpService } from '@/lib/mcp/service'

const logger = createLogger('McpOauthCallbackAPI')
Expand Down Expand Up @@ -145,12 +144,11 @@ export const GET = withRouteHandler(async (request: NextRequest) => {

const preregistered = await loadPreregisteredClient(server.id)
const provider = new SimMcpOauthProvider({ row, preregistered })
let result: Awaited<ReturnType<typeof mcpAuth>>
let result: Awaited<ReturnType<typeof mcpAuthGuarded>>
try {
result = await mcpAuth(provider, {
result = await mcpAuthGuarded(provider, {
serverUrl: server.url,
authorizationCode: code,
fetchFn: createSsrfGuardedMcpFetch(),
})
} catch (e) {
logger.error('Token exchange failed during MCP OAuth callback', e)
Expand Down
29 changes: 10 additions & 19 deletions apps/sim/app/api/mcp/oauth/start/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,13 @@ import {
import { NextRequest } from 'next/server'
import { beforeEach, describe, expect, it, vi } from 'vitest'

const { mockMcpAuth, mockCreateSsrfGuardedMcpFetch, mockGuardedFetch } = vi.hoisted(() => ({
mockMcpAuth: vi.fn(),
mockCreateSsrfGuardedMcpFetch: vi.fn(),
mockGuardedFetch: vi.fn(),
}))

vi.mock('@sim/db', () => dbChainMock)
vi.mock('@sim/db/schema', () => schemaMock)
vi.mock('drizzle-orm', () => ({
and: vi.fn(),
eq: vi.fn(),
isNull: vi.fn(),
}))
vi.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({
auth: mockMcpAuth,
}))
vi.mock('@/lib/mcp/pinned-fetch', () => ({
createSsrfGuardedMcpFetch: mockCreateSsrfGuardedMcpFetch,
}))
vi.mock('@/lib/auth/hybrid', () => hybridAuthMock)
vi.mock('@/lib/workspaces/permissions/utils', () => permissionsMock)
vi.mock('@/lib/mcp/oauth', () => mcpOauthMock)
Expand Down Expand Up @@ -77,21 +65,24 @@ describe('MCP OAuth start route', () => {
updatedAt: new Date(),
})
mcpOauthMockFns.mockLoadPreregisteredClient.mockResolvedValue(undefined)
mockMcpAuth.mockRejectedValue(new McpOauthRedirectRequiredMock('https://mcp.exa.ai/authorize'))
mockCreateSsrfGuardedMcpFetch.mockReturnValue(mockGuardedFetch)
mcpOauthMockFns.mockMcpAuthGuarded.mockRejectedValue(
new McpOauthRedirectRequiredMock('https://mcp.exa.ai/authorize')
)
})

it('routes OAuth discovery through the SSRF-guarded fetch', async () => {
it('routes OAuth discovery through the SSRF-guarded mcpAuthGuarded wrapper', async () => {
const request = new NextRequest(
'http://localhost:3000/api/mcp/oauth/start?workspaceId=workspace-1&serverId=server-1'
)

await GET(request)

expect(mockCreateSsrfGuardedMcpFetch).toHaveBeenCalledTimes(1)
expect(mockMcpAuth).toHaveBeenCalledWith(
// The route must call the guarded wrapper (which defaults fetchFn to the
// SSRF-guarded fetch internally) rather than the raw SDK `auth()` — see
// apps/sim/lib/mcp/oauth/auth.test.ts for the wrapper's own fetchFn coverage.
expect(mcpOauthMockFns.mockMcpAuthGuarded).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({ serverUrl: 'https://mcp.exa.ai/mcp', fetchFn: mockGuardedFetch })
expect.objectContaining({ serverUrl: 'https://mcp.exa.ai/mcp' })
)
})

Expand Down Expand Up @@ -152,7 +143,7 @@ describe('MCP OAuth start route', () => {

expect(response.status).toBe(409)
expect(body.error).toBe('OAuth authorization already in progress for this server')
expect(mockMcpAuth).not.toHaveBeenCalled()
expect(mcpOauthMockFns.mockMcpAuthGuarded).not.toHaveBeenCalled()
})

it('does not leak non-OAuth internal error details to the client', async () => {
Expand Down
6 changes: 2 additions & 4 deletions apps/sim/app/api/mcp/oauth/start/route.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { auth as mcpAuth } from '@modelcontextprotocol/sdk/client/auth.js'
import { OAuthError, ServerError } from '@modelcontextprotocol/sdk/server/auth/errors.js'
import { db } from '@sim/db'
import { mcpServers } from '@sim/db/schema'
Expand All @@ -17,10 +16,10 @@ import {
loadPreregisteredClient,
McpOauthInsecureUrlError,
McpOauthRedirectRequired,
mcpAuthGuarded,
SimMcpOauthProvider,
setOauthRowUser,
} from '@/lib/mcp/oauth'
import { createSsrfGuardedMcpFetch } from '@/lib/mcp/pinned-fetch'
import { createMcpErrorResponse } from '@/lib/mcp/utils'

const logger = createLogger('McpOauthStartAPI')
Expand Down Expand Up @@ -130,9 +129,8 @@ export const GET = withRouteHandler(
const provider = new SimMcpOauthProvider({ row, preregistered })

try {
const result = await mcpAuth(provider, {
const result = await mcpAuthGuarded(provider, {
serverUrl: server.url,
fetchFn: createSsrfGuardedMcpFetch(),
})
if (result === 'AUTHORIZED') {
return NextResponse.json({ status: 'already_authorized' })
Expand Down
54 changes: 54 additions & 0 deletions apps/sim/lib/mcp/oauth/auth.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/**
* @vitest-environment node
*/
import type { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js'
import { beforeEach, describe, expect, it, vi } from 'vitest'

const { mockAuth, mockCreateSsrfGuardedMcpFetch, mockGuardedFetch } = vi.hoisted(() => ({
mockAuth: vi.fn(),
mockCreateSsrfGuardedMcpFetch: vi.fn(),
mockGuardedFetch: vi.fn(),
}))

vi.mock('@modelcontextprotocol/sdk/client/auth.js', () => ({
auth: mockAuth,
}))
vi.mock('@/lib/mcp/pinned-fetch', () => ({
createSsrfGuardedMcpFetch: mockCreateSsrfGuardedMcpFetch,
}))

import { mcpAuthGuarded } from '@/lib/mcp/oauth/auth'

describe('mcpAuthGuarded', () => {
const provider = {} as OAuthClientProvider

beforeEach(() => {
vi.clearAllMocks()
mockCreateSsrfGuardedMcpFetch.mockReturnValue(mockGuardedFetch)
mockAuth.mockResolvedValue('AUTHORIZED')
})

it('defaults fetchFn to the SSRF-guarded fetch', async () => {
await mcpAuthGuarded(provider, { serverUrl: 'https://mcp.example.com/mcp' })

expect(mockCreateSsrfGuardedMcpFetch).toHaveBeenCalledTimes(1)
expect(mockAuth).toHaveBeenCalledWith(provider, {
serverUrl: 'https://mcp.example.com/mcp',
fetchFn: mockGuardedFetch,
})
})

it('lets a caller-supplied fetchFn override the default', async () => {
const overrideFetch = vi.fn()

await mcpAuthGuarded(provider, {
serverUrl: 'https://mcp.example.com/mcp',
fetchFn: overrideFetch,
})

expect(mockAuth).toHaveBeenCalledWith(provider, {
serverUrl: 'https://mcp.example.com/mcp',
fetchFn: overrideFetch,
})
})
})
19 changes: 19 additions & 0 deletions apps/sim/lib/mcp/oauth/auth.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { auth, type OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth.js'
import { createSsrfGuardedMcpFetch } from '@/lib/mcp/pinned-fetch'

type McpAuthOptions = Parameters<typeof auth>[1]

/**
* Wraps the MCP SDK's `auth()` and defaults `fetchFn` to the SSRF-guarded
* fetch. Every URL touched during an MCP OAuth exchange — discovery,
* authorization, token, and revocation endpoints — can come from
* attacker-controllable authorization-server metadata, so callers must not
* be able to omit the guard by forgetting to pass `fetchFn` explicitly.
* Pass `fetchFn` in `options` to override (e.g. in tests).
*/
export function mcpAuthGuarded(
provider: OAuthClientProvider,
options: McpAuthOptions
): ReturnType<typeof auth> {
return auth(provider, { fetchFn: createSsrfGuardedMcpFetch(), ...options })
}
1 change: 1 addition & 0 deletions apps/sim/lib/mcp/oauth/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export { mcpAuthGuarded } from './auth'
export type {
McpOauthCallbackMessage,
McpOauthCallbackReason,
Expand Down
2 changes: 2 additions & 0 deletions packages/testing/src/mocks/mcp-oauth.mock.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { vi } from 'vitest'
*/
export const mcpOauthMockFns = {
mockAssertSafeOauthServerUrl: vi.fn(),
mockMcpAuthGuarded: vi.fn(),
mockGetOrCreateOauthRow: vi.fn(),
mockLoadOauthRow: vi.fn(),
mockLoadOauthRowByState: vi.fn(),
Expand Down Expand Up @@ -63,6 +64,7 @@ function buildSimMcpOauthProvider(value: object) {
*/
export const mcpOauthMock = {
assertSafeOauthServerUrl: mcpOauthMockFns.mockAssertSafeOauthServerUrl,
mcpAuthGuarded: mcpOauthMockFns.mockMcpAuthGuarded,
getOrCreateOauthRow: mcpOauthMockFns.mockGetOrCreateOauthRow,
loadOauthRow: mcpOauthMockFns.mockLoadOauthRow,
loadOauthRowByState: mcpOauthMockFns.mockLoadOauthRowByState,
Expand Down
Loading