diff --git a/apps/sim/app/api/auth/sso/register/route.ts b/apps/sim/app/api/auth/sso/register/route.ts index 00e499d6fb3..94c57c93478 100644 --- a/apps/sim/app/api/auth/sso/register/route.ts +++ b/apps/sim/app/api/auth/sso/register/route.ts @@ -4,6 +4,10 @@ import { z } from 'zod' import { auth, getSession } from '@/lib/auth' import { hasSSOAccess } from '@/lib/billing' import { env } from '@/lib/core/config/env' +import { + secureFetchWithPinnedIP, + validateUrlWithDNS, +} from '@/lib/core/security/input-validation.server' import { REDACTED_MARKER } from '@/lib/core/security/redaction' const logger = createLogger('SSORegisterRoute') @@ -156,24 +160,66 @@ export async function POST(request: NextRequest) { hasJwksEndpoint: !!oidcConfig.jwksEndpoint, }) - const discoveryResponse = await fetch(discoveryUrl, { - headers: { Accept: 'application/json' }, - }) + const urlValidation = await validateUrlWithDNS(discoveryUrl, 'OIDC discovery URL') + if (!urlValidation.isValid || !urlValidation.resolvedIP) { + logger.warn('OIDC discovery URL failed SSRF validation', { + discoveryUrl, + error: urlValidation.error, + }) + return NextResponse.json( + { error: urlValidation.error ?? 'SSRF validation failed' }, + { status: 400 } + ) + } + + const discoveryResponse = await secureFetchWithPinnedIP( + discoveryUrl, + urlValidation.resolvedIP, + { + headers: { Accept: 'application/json' }, + } + ) if (!discoveryResponse.ok) { logger.error('Failed to fetch OIDC discovery document', { status: discoveryResponse.status, - statusText: discoveryResponse.statusText, }) return NextResponse.json( { - error: `Failed to fetch OIDC discovery document from ${discoveryUrl}. Status: ${discoveryResponse.status}. Provide all endpoints explicitly or verify the issuer URL.`, + error: + 'Failed to fetch OIDC discovery document. Provide all endpoints explicitly or verify the issuer URL.', }, { status: 400 } ) } - const discovery = await discoveryResponse.json() + const discovery = (await discoveryResponse.json()) as Record + + const discoveredEndpoints: Record = { + authorization_endpoint: discovery.authorization_endpoint, + token_endpoint: discovery.token_endpoint, + userinfo_endpoint: discovery.userinfo_endpoint, + jwks_uri: discovery.jwks_uri, + } + + for (const [key, value] of Object.entries(discoveredEndpoints)) { + if (typeof value === 'string') { + const endpointValidation = await validateUrlWithDNS(value, `OIDC ${key}`) + if (!endpointValidation.isValid) { + logger.warn('OIDC discovered endpoint failed SSRF validation', { + endpoint: key, + url: value, + error: endpointValidation.error, + }) + return NextResponse.json( + { + error: `Discovered OIDC ${key} failed security validation: ${endpointValidation.error}`, + }, + { status: 400 } + ) + } + } + } oidcConfig.authorizationEndpoint = oidcConfig.authorizationEndpoint || discovery.authorization_endpoint @@ -196,7 +242,8 @@ export async function POST(request: NextRequest) { }) return NextResponse.json( { - error: `Failed to fetch OIDC discovery document from ${discoveryUrl}. Please verify the issuer URL is correct or provide all endpoints explicitly.`, + error: + 'Failed to fetch OIDC discovery document. Please verify the issuer URL is correct or provide all endpoints explicitly.', }, { status: 400 } ) diff --git a/apps/sim/app/api/chat/[identifier]/otp/route.test.ts b/apps/sim/app/api/chat/[identifier]/otp/route.test.ts index 84a824238c2..7904dc1c3ab 100644 --- a/apps/sim/app/api/chat/[identifier]/otp/route.test.ts +++ b/apps/sim/app/api/chat/[identifier]/otp/route.test.ts @@ -10,11 +10,14 @@ const { mockRedisSet, mockRedisGet, mockRedisDel, + mockRedisTtl, + mockRedisEval, mockGetRedisClient, mockRedisClient, mockDbSelect, mockDbInsert, mockDbDelete, + mockDbUpdate, mockSendEmail, mockRenderOTPEmail, mockAddCorsHeaders, @@ -29,15 +32,20 @@ const { const mockRedisSet = vi.fn() const mockRedisGet = vi.fn() const mockRedisDel = vi.fn() + const mockRedisTtl = vi.fn() + const mockRedisEval = vi.fn() const mockRedisClient = { set: mockRedisSet, get: mockRedisGet, del: mockRedisDel, + ttl: mockRedisTtl, + eval: mockRedisEval, } const mockGetRedisClient = vi.fn() const mockDbSelect = vi.fn() const mockDbInsert = vi.fn() const mockDbDelete = vi.fn() + const mockDbUpdate = vi.fn() const mockSendEmail = vi.fn() const mockRenderOTPEmail = vi.fn() const mockAddCorsHeaders = vi.fn() @@ -53,11 +61,14 @@ const { mockRedisSet, mockRedisGet, mockRedisDel, + mockRedisTtl, + mockRedisEval, mockGetRedisClient, mockRedisClient, mockDbSelect, mockDbInsert, mockDbDelete, + mockDbUpdate, mockSendEmail, mockRenderOTPEmail, mockAddCorsHeaders, @@ -80,11 +91,13 @@ vi.mock('@sim/db', () => ({ select: mockDbSelect, insert: mockDbInsert, delete: mockDbDelete, + update: mockDbUpdate, transaction: vi.fn(async (callback: (tx: Record) => unknown) => { return callback({ select: mockDbSelect, insert: mockDbInsert, delete: mockDbDelete, + update: mockDbUpdate, }) }), }, @@ -126,12 +139,24 @@ vi.mock('@/lib/messaging/email/mailer', () => ({ sendEmail: mockSendEmail, })) -vi.mock('@/components/emails/render-email', () => ({ +vi.mock('@/components/emails', () => ({ renderOTPEmail: mockRenderOTPEmail, })) -vi.mock('@/app/api/chat/utils', () => ({ +vi.mock('@/lib/core/security/deployment', () => ({ addCorsHeaders: mockAddCorsHeaders, + isEmailAllowed: (email: string, allowedEmails: string[]) => { + if (allowedEmails.includes(email)) return true + const atIndex = email.indexOf('@') + if (atIndex > 0) { + const domain = email.substring(atIndex + 1) + if (domain && allowedEmails.some((allowed: string) => allowed === `@${domain}`)) return true + } + return false + }, +})) + +vi.mock('@/app/api/chat/utils', () => ({ setChatAuthCookie: mockSetChatAuthCookie, })) @@ -209,6 +234,7 @@ describe('Chat OTP API Route', () => { mockRedisSet.mockResolvedValue('OK') mockRedisGet.mockResolvedValue(null) mockRedisDel.mockResolvedValue(1) + mockRedisTtl.mockResolvedValue(600) const createDbChain = (result: unknown) => ({ from: vi.fn().mockReturnValue({ @@ -225,6 +251,11 @@ describe('Chat OTP API Route', () => { mockDbDelete.mockImplementation(() => ({ where: vi.fn().mockResolvedValue(undefined), })) + mockDbUpdate.mockImplementation(() => ({ + set: vi.fn().mockReturnValue({ + where: vi.fn().mockResolvedValue(undefined), + }), + })) mockGetStorageMethod.mockReturnValue('redis') @@ -349,7 +380,7 @@ describe('Chat OTP API Route', () => { describe('PUT - Verify OTP (Redis path)', () => { beforeEach(() => { mockGetStorageMethod.mockReturnValue('redis') - mockRedisGet.mockResolvedValue(mockOTP) + mockRedisGet.mockResolvedValue(`${mockOTP}:0`) }) it('should retrieve OTP from Redis and verify successfully', async () => { @@ -374,9 +405,7 @@ describe('Chat OTP API Route', () => { await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) expect(mockRedisGet).toHaveBeenCalledWith(`otp:${mockEmail}:${mockChatId}`) - expect(mockRedisDel).toHaveBeenCalledWith(`otp:${mockEmail}:${mockChatId}`) - expect(mockDbSelect).toHaveBeenCalledTimes(1) }) }) @@ -405,7 +434,7 @@ describe('Chat OTP API Route', () => { } return Promise.resolve([ { - value: mockOTP, + value: `${mockOTP}:0`, expiresAt: new Date(Date.now() + 10 * 60 * 1000), }, ]) @@ -475,7 +504,7 @@ describe('Chat OTP API Route', () => { }) it('should delete OTP from Redis after verification', async () => { - mockRedisGet.mockResolvedValue(mockOTP) + mockRedisGet.mockResolvedValue(`${mockOTP}:0`) mockDbSelect.mockImplementationOnce(() => ({ from: vi.fn().mockReturnValue({ @@ -519,7 +548,7 @@ describe('Chat OTP API Route', () => { return Promise.resolve([{ id: mockChatId, authType: 'email' }]) } return Promise.resolve([ - { value: mockOTP, expiresAt: new Date(Date.now() + 10 * 60 * 1000) }, + { value: `${mockOTP}:0`, expiresAt: new Date(Date.now() + 10 * 60 * 1000) }, ]) }), }), @@ -543,6 +572,97 @@ describe('Chat OTP API Route', () => { }) }) + describe('Brute-force protection', () => { + beforeEach(() => { + mockGetStorageMethod.mockReturnValue('redis') + }) + + it('should atomically increment attempts on wrong OTP', async () => { + mockRedisGet.mockResolvedValue('654321:0') + mockRedisEval.mockResolvedValue('654321:1') + + mockDbSelect.mockImplementationOnce(() => ({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([{ id: mockChatId, authType: 'email' }]), + }), + }), + })) + + const request = new NextRequest('http://localhost:3000/api/chat/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: 'wrong1' }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockRedisEval).toHaveBeenCalledWith( + expect.any(String), + 1, + `otp:${mockEmail}:${mockChatId}`, + 5 + ) + expect(mockCreateErrorResponse).toHaveBeenCalledWith('Invalid verification code', 400) + }) + + it('should invalidate OTP and return 429 after max failed attempts', async () => { + mockRedisGet.mockResolvedValue('654321:4') + mockRedisEval.mockResolvedValue('LOCKED') + + mockDbSelect.mockImplementationOnce(() => ({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([{ id: mockChatId, authType: 'email' }]), + }), + }), + })) + + const request = new NextRequest('http://localhost:3000/api/chat/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: 'wrong5' }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockRedisEval).toHaveBeenCalled() + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'Too many failed attempts. Please request a new code.', + 429 + ) + }) + + it('should store OTP with zero attempts on generation', async () => { + mockDbSelect.mockImplementationOnce(() => ({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([ + { + id: mockChatId, + authType: 'email', + allowedEmails: [mockEmail], + title: 'Test Chat', + }, + ]), + }), + }), + })) + + const request = new NextRequest('http://localhost:3000/api/chat/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockRedisSet).toHaveBeenCalledWith( + `otp:${mockEmail}:${mockChatId}`, + expect.stringMatching(/^\d{6}:0$/), + 'EX', + 900 + ) + }) + }) + describe('Behavior consistency between Redis and Database', () => { it('should have same behavior for missing OTP in both storage methods', async () => { mockGetStorageMethod.mockReturnValue('redis') diff --git a/apps/sim/app/api/chat/[identifier]/otp/route.ts b/apps/sim/app/api/chat/[identifier]/otp/route.ts index 983e65d9850..4d2358fcf5b 100644 --- a/apps/sim/app/api/chat/[identifier]/otp/route.ts +++ b/apps/sim/app/api/chat/[identifier]/otp/route.ts @@ -1,4 +1,4 @@ -import { randomUUID } from 'crypto' +import { randomInt, randomUUID } from 'crypto' import { db } from '@sim/db' import { chat, verification } from '@sim/db/schema' import { createLogger } from '@sim/logger' @@ -7,7 +7,7 @@ import type { NextRequest } from 'next/server' import { z } from 'zod' import { renderOTPEmail } from '@/components/emails' import { getRedisClient } from '@/lib/core/config/redis' -import { addCorsHeaders } from '@/lib/core/security/deployment' +import { addCorsHeaders, isEmailAllowed } from '@/lib/core/security/deployment' import { getStorageMethod } from '@/lib/core/storage' import { generateRequestId } from '@/lib/core/utils/request' import { sendEmail } from '@/lib/messaging/email/mailer' @@ -16,12 +16,28 @@ import { createErrorResponse, createSuccessResponse } from '@/app/api/workflows/ const logger = createLogger('ChatOtpAPI') -function generateOTP() { - return Math.floor(100000 + Math.random() * 900000).toString() +function generateOTP(): string { + return randomInt(100000, 1000000).toString() } const OTP_EXPIRY = 15 * 60 // 15 minutes const OTP_EXPIRY_MS = OTP_EXPIRY * 1000 +const MAX_OTP_ATTEMPTS = 5 + +/** + * OTP values are stored as "code:attempts" (e.g. "654321:0"). + * This keeps the attempt counter in the same key/row as the OTP itself. + */ +function encodeOTPValue(otp: string, attempts: number): string { + return `${otp}:${attempts}` +} + +function decodeOTPValue(value: string): { otp: string; attempts: number } { + const lastColon = value.lastIndexOf(':') + if (lastColon === -1) return { otp: value, attempts: 0 } + const attempts = Number.parseInt(value.slice(lastColon + 1), 10) + return { otp: value.slice(0, lastColon), attempts: Number.isNaN(attempts) ? 0 : attempts } +} /** * Stores OTP in Redis or database depending on storage method. @@ -30,14 +46,14 @@ const OTP_EXPIRY_MS = OTP_EXPIRY * 1000 async function storeOTP(email: string, chatId: string, otp: string): Promise { const identifier = `chat-otp:${chatId}:${email}` const storageMethod = getStorageMethod() + const value = encodeOTPValue(otp, 0) if (storageMethod === 'redis') { const redis = getRedisClient() if (!redis) { throw new Error('Redis configured but client unavailable') } - const key = `otp:${email}:${chatId}` - await redis.set(key, otp, 'EX', OTP_EXPIRY) + await redis.set(`otp:${email}:${chatId}`, value, 'EX', OTP_EXPIRY) } else { const now = new Date() const expiresAt = new Date(now.getTime() + OTP_EXPIRY_MS) @@ -47,7 +63,7 @@ async function storeOTP(email: string, chatId: string, otp: string): Promise { if (!redis) { throw new Error('Redis configured but client unavailable') } - const key = `otp:${email}:${chatId}` - return redis.get(key) + return redis.get(`otp:${email}:${chatId}`) } const now = new Date() const [record] = await db - .select({ - value: verification.value, - expiresAt: verification.expiresAt, - }) + .select({ value: verification.value }) .from(verification) .where(and(eq(verification.identifier, identifier), gt(verification.expiresAt, now))) .limit(1) - if (!record) return null + return record?.value ?? null +} + +/** + * Lua script for atomic OTP attempt increment. + * Returns: "LOCKED" if max attempts reached (key deleted), new encoded value otherwise, nil if key missing. + */ +const ATOMIC_INCREMENT_SCRIPT = ` +local val = redis.call('GET', KEYS[1]) +if not val then return nil end +local colon = val:find(':([^:]*$)') +local otp, attempts +if colon then + otp = val:sub(1, colon - 1) + attempts = tonumber(val:sub(colon + 1)) or 0 +else + otp = val + attempts = 0 +end +attempts = attempts + 1 +if attempts >= tonumber(ARGV[1]) then + redis.call('DEL', KEYS[1]) + return 'LOCKED' +end +local newVal = otp .. ':' .. attempts +local ttl = redis.call('TTL', KEYS[1]) +if ttl > 0 then + redis.call('SET', KEYS[1], newVal, 'EX', ttl) +else + redis.call('SET', KEYS[1], newVal) +end +return newVal +` - return record.value +/** + * Atomically increments OTP attempts. Returns 'locked' if max reached, 'incremented' otherwise. + */ +async function incrementOTPAttempts( + email: string, + chatId: string, + currentValue: string +): Promise<'locked' | 'incremented'> { + const identifier = `chat-otp:${chatId}:${email}` + const storageMethod = getStorageMethod() + + if (storageMethod === 'redis') { + const redis = getRedisClient() + if (!redis) { + throw new Error('Redis configured but client unavailable') + } + const key = `otp:${email}:${chatId}` + const result = await redis.eval(ATOMIC_INCREMENT_SCRIPT, 1, key, MAX_OTP_ATTEMPTS) + if (result === null || result === 'LOCKED') return 'locked' + return 'incremented' + } + + // DB path: optimistic locking with retry on conflict + const MAX_RETRIES = 3 + let value = currentValue + + for (let attempt = 0; attempt < MAX_RETRIES; attempt++) { + const { otp, attempts } = decodeOTPValue(value) + const newAttempts = attempts + 1 + + if (newAttempts >= MAX_OTP_ATTEMPTS) { + await db.delete(verification).where(eq(verification.identifier, identifier)) + return 'locked' + } + + const newValue = encodeOTPValue(otp, newAttempts) + const updated = await db + .update(verification) + .set({ value: newValue, updatedAt: new Date() }) + .where(and(eq(verification.identifier, identifier), eq(verification.value, value))) + .returning({ id: verification.id }) + + if (updated.length > 0) return 'incremented' + + // Conflict: another request already incremented — re-read and retry + const fresh = await getOTP(email, chatId) + if (!fresh) return 'locked' + value = fresh + } + + // Exhausted retries — re-read final state to determine outcome + const final = await getOTP(email, chatId) + if (!final) return 'locked' + const { attempts: finalAttempts } = decodeOTPValue(final) + return finalAttempts >= MAX_OTP_ATTEMPTS ? 'locked' : 'incremented' } async function deleteOTP(email: string, chatId: string): Promise { @@ -93,8 +191,7 @@ async function deleteOTP(email: string, chatId: string): Promise { if (!redis) { throw new Error('Redis configured but client unavailable') } - const key = `otp:${email}:${chatId}` - await redis.del(key) + await redis.del(`otp:${email}:${chatId}`) } else { await db.delete(verification).where(eq(verification.identifier, identifier)) } @@ -149,17 +246,7 @@ export async function POST( ? deployment.allowedEmails : [] - const isEmailAllowed = - allowedEmails.includes(email) || - allowedEmails.some((allowed: string) => { - if (allowed.startsWith('@')) { - const domain = email.split('@')[1] - return domain && allowed === `@${domain}` - } - return false - }) - - if (!isEmailAllowed) { + if (!isEmailAllowed(email, allowedEmails)) { return addCorsHeaders(createErrorResponse('Email not authorized for this chat', 403), request) } @@ -216,6 +303,7 @@ export async function PUT( .select({ id: chat.id, authType: chat.authType, + password: chat.password, }) .from(chat) .where(and(eq(chat.identifier, identifier), eq(chat.isActive, true), isNull(chat.archivedAt))) @@ -228,22 +316,41 @@ export async function PUT( const deployment = deploymentResult[0] - const storedOTP = await getOTP(email, deployment.id) - if (!storedOTP) { + const storedValue = await getOTP(email, deployment.id) + if (!storedValue) { return addCorsHeaders( createErrorResponse('No verification code found, request a new one', 400), request ) } + const { otp: storedOTP, attempts } = decodeOTPValue(storedValue) + + if (attempts >= MAX_OTP_ATTEMPTS) { + await deleteOTP(email, deployment.id) + logger.warn(`[${requestId}] OTP already locked out for ${email}`) + return addCorsHeaders( + createErrorResponse('Too many failed attempts. Please request a new code.', 429), + request + ) + } + if (storedOTP !== otp) { + const result = await incrementOTPAttempts(email, deployment.id, storedValue) + if (result === 'locked') { + logger.warn(`[${requestId}] OTP invalidated after max failed attempts for ${email}`) + return addCorsHeaders( + createErrorResponse('Too many failed attempts. Please request a new code.', 429), + request + ) + } return addCorsHeaders(createErrorResponse('Invalid verification code', 400), request) } await deleteOTP(email, deployment.id) const response = addCorsHeaders(createSuccessResponse({ authenticated: true }), request) - setChatAuthCookie(response, deployment.id, deployment.authType) + setChatAuthCookie(response, deployment.id, deployment.authType, deployment.password) return response } catch (error: any) { diff --git a/apps/sim/app/api/chat/utils.test.ts b/apps/sim/app/api/chat/utils.test.ts index acf629072ac..de604854f07 100644 --- a/apps/sim/app/api/chat/utils.test.ts +++ b/apps/sim/app/api/chat/utils.test.ts @@ -7,13 +7,23 @@ import { databaseMock, loggerMock, requestUtilsMock } from '@sim/testing' import type { NextResponse } from 'next/server' import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockDecryptSecret, mockMergeSubblockStateWithValues, mockMergeSubBlockValues } = vi.hoisted( - () => ({ - mockDecryptSecret: vi.fn(), - mockMergeSubblockStateWithValues: vi.fn().mockReturnValue({}), - mockMergeSubBlockValues: vi.fn().mockReturnValue({}), - }) -) +const { + mockDecryptSecret, + mockMergeSubblockStateWithValues, + mockMergeSubBlockValues, + mockValidateAuthToken, + mockSetDeploymentAuthCookie, + mockAddCorsHeaders, + mockIsEmailAllowed, +} = vi.hoisted(() => ({ + mockDecryptSecret: vi.fn(), + mockMergeSubblockStateWithValues: vi.fn().mockReturnValue({}), + mockMergeSubBlockValues: vi.fn().mockReturnValue({}), + mockValidateAuthToken: vi.fn().mockReturnValue(false), + mockSetDeploymentAuthCookie: vi.fn(), + mockAddCorsHeaders: vi.fn((response: unknown) => response), + mockIsEmailAllowed: vi.fn(), +})) vi.mock('@sim/db', () => databaseMock) vi.mock('@sim/logger', () => loggerMock) @@ -45,6 +55,13 @@ vi.mock('@/lib/core/security/encryption', () => ({ vi.mock('@/lib/core/utils/request', () => requestUtilsMock) +vi.mock('@/lib/core/security/deployment', () => ({ + validateAuthToken: mockValidateAuthToken, + setDeploymentAuthCookie: mockSetDeploymentAuthCookie, + addCorsHeaders: mockAddCorsHeaders, + isEmailAllowed: mockIsEmailAllowed, +})) + vi.mock('@/lib/core/config/feature-flags', () => ({ isDev: true, isHosted: false, @@ -55,7 +72,6 @@ vi.mock('@/lib/workflows/utils', () => ({ authorizeWorkflowByWorkspacePermission: vi.fn(), })) -import { addCorsHeaders, validateAuthToken } from '@/lib/core/security/deployment' import { decryptSecret } from '@/lib/core/security/encryption' import { setChatAuthCookie, validateChatAuth } from '@/app/api/chat/utils' @@ -72,90 +88,66 @@ describe('Chat API Utils', () => { }) describe('Auth token utils', () => { - it.concurrent('should validate auth tokens', () => { - const chatId = 'test-chat-id' - const type = 'password' + it('should accept valid auth cookie via validateChatAuth', async () => { + mockValidateAuthToken.mockReturnValue(true) - const token = Buffer.from(`${chatId}:${type}:${Date.now()}`).toString('base64') - expect(typeof token).toBe('string') - expect(token.length).toBeGreaterThan(0) + const deployment = { + id: 'chat-id', + authType: 'password', + password: 'encrypted-password', + } - const isValid = validateAuthToken(token, chatId) - expect(isValid).toBe(true) + const mockRequest = { + method: 'POST', + cookies: { + get: vi.fn().mockReturnValue({ value: 'valid-token' }), + }, + } as any - const isInvalidChat = validateAuthToken(token, 'wrong-chat-id') - expect(isInvalidChat).toBe(false) + const result = await validateChatAuth('request-id', deployment, mockRequest) + expect(mockValidateAuthToken).toHaveBeenCalledWith( + 'valid-token', + 'chat-id', + 'encrypted-password' + ) + expect(result.authorized).toBe(true) }) - it.concurrent('should reject expired tokens', () => { - const chatId = 'test-chat-id' - const expiredToken = Buffer.from( - `${chatId}:password:${Date.now() - 25 * 60 * 60 * 1000}` - ).toString('base64') + it('should reject invalid auth cookie via validateChatAuth', async () => { + mockValidateAuthToken.mockReturnValue(false) - const isValid = validateAuthToken(expiredToken, chatId) - expect(isValid).toBe(false) - }) - }) + const deployment = { + id: 'chat-id', + authType: 'password', + password: 'encrypted-password', + } - describe('Cookie handling', () => { - it('should set auth cookie correctly', () => { - const mockSet = vi.fn() - const mockResponse = { + const mockRequest = { + method: 'GET', cookies: { - set: mockSet, + get: vi.fn().mockReturnValue({ value: 'invalid-token' }), }, - } as unknown as NextResponse - - const chatId = 'test-chat-id' - const type = 'password' - - setChatAuthCookie(mockResponse, chatId, type) + } as any - expect(mockSet).toHaveBeenCalledWith({ - name: `chat_auth_${chatId}`, - value: expect.any(String), - httpOnly: true, - secure: false, // Development mode - sameSite: 'lax', - path: '/', - domain: undefined, // Development mode - maxAge: 60 * 60 * 24, - }) + const result = await validateChatAuth('request-id', deployment, mockRequest) + expect(result.authorized).toBe(false) }) }) - describe('CORS handling', () => { - it('should add CORS headers for localhost in development', () => { - const mockRequest = { - headers: { - get: vi.fn().mockReturnValue('http://localhost:3000'), - }, - } as any - + describe('Cookie handling', () => { + it('should delegate to setDeploymentAuthCookie', () => { const mockResponse = { - headers: { - set: vi.fn(), - }, + cookies: { set: vi.fn() }, } as unknown as NextResponse - addCorsHeaders(mockResponse, mockRequest) + setChatAuthCookie(mockResponse, 'test-chat-id', 'password') - expect(mockResponse.headers.set).toHaveBeenCalledWith( - 'Access-Control-Allow-Origin', - 'http://localhost:3000' - ) - expect(mockResponse.headers.set).toHaveBeenCalledWith( - 'Access-Control-Allow-Credentials', - 'true' - ) - expect(mockResponse.headers.set).toHaveBeenCalledWith( - 'Access-Control-Allow-Methods', - 'GET, POST, OPTIONS' - ) - expect(mockResponse.headers.set).toHaveBeenCalledWith( - 'Access-Control-Allow-Headers', - 'Content-Type, X-Requested-With' + expect(mockSetDeploymentAuthCookie).toHaveBeenCalledWith( + mockResponse, + 'chat', + 'test-chat-id', + 'password', + undefined ) }) }) @@ -283,6 +275,7 @@ describe('Chat API Utils', () => { }, } as any + mockIsEmailAllowed.mockReturnValue(true) const result1 = await validateChatAuth('request-id', deployment, mockRequest, { email: 'user@example.com', }) @@ -295,6 +288,7 @@ describe('Chat API Utils', () => { expect(result2.authorized).toBe(false) expect(result2.error).toBe('otp_required') + mockIsEmailAllowed.mockReturnValue(false) const result3 = await validateChatAuth('request-id', deployment, mockRequest, { email: 'user@unknown.com', }) diff --git a/apps/sim/app/api/files/authorization.ts b/apps/sim/app/api/files/authorization.ts index 610cd7f3107..1be57e4a389 100644 --- a/apps/sim/app/api/files/authorization.ts +++ b/apps/sim/app/api/files/authorization.ts @@ -114,9 +114,9 @@ export async function verifyFileAccess( // Infer context from key if not explicitly provided const inferredContext = context || inferContextFromKey(cloudKey) - // 0. Profile pictures: Public access (anyone can view creator profile pictures) - if (inferredContext === 'profile-pictures') { - logger.info('Profile picture access allowed (public)', { cloudKey }) + // 0. Public contexts: profile pictures and OG images are publicly accessible + if (inferredContext === 'profile-pictures' || inferredContext === 'og-images') { + logger.info('Public file access allowed', { cloudKey, context: inferredContext }) return true } diff --git a/apps/sim/app/api/files/serve/[...path]/route.ts b/apps/sim/app/api/files/serve/[...path]/route.ts index 28ccfad0e85..bc14086395a 100644 --- a/apps/sim/app/api/files/serve/[...path]/route.ts +++ b/apps/sim/app/api/files/serve/[...path]/route.ts @@ -94,12 +94,11 @@ export async function GET( const isCloudPath = isS3Path || isBlobPath const cloudKey = isCloudPath ? path.slice(1).join('/') : fullPath - const contextParam = request.nextUrl.searchParams.get('context') - const raw = request.nextUrl.searchParams.get('raw') === '1' - - const context = contextParam || (isCloudPath ? inferContextFromKey(cloudKey) : undefined) + const isPublicByKeyPrefix = + cloudKey.startsWith('profile-pictures/') || cloudKey.startsWith('og-images/') - if (context === 'profile-pictures' || context === 'og-images') { + if (isPublicByKeyPrefix) { + const context = inferContextFromKey(cloudKey) logger.info(`Serving public ${context}:`, { cloudKey }) if (isUsingCloudStorage() || isCloudPath) { return await handleCloudProxyPublic(cloudKey, context) @@ -107,6 +106,8 @@ export async function GET( return await handleLocalFilePublic(fullPath) } + const raw = request.nextUrl.searchParams.get('raw') === '1' + const authResult = await checkSessionOrInternalAuth(request, { requireWorkflowId: false }) if (!authResult.success || !authResult.userId) { @@ -120,7 +121,7 @@ export async function GET( const userId = authResult.userId if (isUsingCloudStorage()) { - return await handleCloudProxy(cloudKey, userId, contextParam, raw) + return await handleCloudProxy(cloudKey, userId, raw) } return await handleLocalFile(cloudKey, userId, raw) @@ -192,19 +193,11 @@ async function handleLocalFile( async function handleCloudProxy( cloudKey: string, userId: string, - contextParam?: string | null, raw = false ): Promise { try { - let context: StorageContext - - if (contextParam) { - context = contextParam as StorageContext - logger.info(`Using explicit context: ${context} for key: ${cloudKey}`) - } else { - context = inferContextFromKey(cloudKey) - logger.info(`Inferred context: ${context} from key pattern: ${cloudKey}`) - } + const context = inferContextFromKey(cloudKey) + logger.info(`Inferred context: ${context} from key pattern: ${cloudKey}`) const hasAccess = await verifyFileAccess( cloudKey, diff --git a/apps/sim/app/api/form/utils.test.ts b/apps/sim/app/api/form/utils.test.ts index f40773efdd0..d88146cb69f 100644 --- a/apps/sim/app/api/form/utils.test.ts +++ b/apps/sim/app/api/form/utils.test.ts @@ -7,8 +7,18 @@ import { databaseMock, loggerMock } from '@sim/testing' import type { NextResponse } from 'next/server' import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockDecryptSecret } = vi.hoisted(() => ({ +const { + mockDecryptSecret, + mockValidateAuthToken, + mockSetDeploymentAuthCookie, + mockAddCorsHeaders, + mockIsEmailAllowed, +} = vi.hoisted(() => ({ mockDecryptSecret: vi.fn(), + mockValidateAuthToken: vi.fn().mockReturnValue(false), + mockSetDeploymentAuthCookie: vi.fn(), + mockAddCorsHeaders: vi.fn((response: unknown) => response), + mockIsEmailAllowed: vi.fn(), })) vi.mock('@sim/db', () => databaseMock) @@ -18,6 +28,13 @@ vi.mock('@/lib/core/security/encryption', () => ({ decryptSecret: mockDecryptSecret, })) +vi.mock('@/lib/core/security/deployment', () => ({ + validateAuthToken: mockValidateAuthToken, + setDeploymentAuthCookie: mockSetDeploymentAuthCookie, + addCorsHeaders: mockAddCorsHeaders, + isEmailAllowed: mockIsEmailAllowed, +})) + vi.mock('@/lib/core/config/feature-flags', () => ({ isDev: true, isHosted: false, @@ -28,8 +45,6 @@ vi.mock('@/lib/workflows/utils', () => ({ authorizeWorkflowByWorkspacePermission: vi.fn(), })) -import crypto from 'crypto' -import { addCorsHeaders, validateAuthToken } from '@/lib/core/security/deployment' import { decryptSecret } from '@/lib/core/security/encryption' import { DEFAULT_FORM_CUSTOMIZATIONS, @@ -43,126 +58,67 @@ describe('Form API Utils', () => { }) describe('Auth token utils', () => { - it.concurrent('should validate auth tokens', () => { - const formId = 'test-form-id' - const type = 'password' - - const token = Buffer.from(`${formId}:${type}:${Date.now()}`).toString('base64') - expect(typeof token).toBe('string') - expect(token.length).toBeGreaterThan(0) - - const isValid = validateAuthToken(token, formId) - expect(isValid).toBe(true) + it('should accept valid auth cookie via validateFormAuth', async () => { + mockValidateAuthToken.mockReturnValue(true) - const isInvalidForm = validateAuthToken(token, 'wrong-form-id') - expect(isInvalidForm).toBe(false) - }) - - it.concurrent('should reject expired tokens', () => { - const formId = 'test-form-id' - const expiredToken = Buffer.from( - `${formId}:password:${Date.now() - 25 * 60 * 60 * 1000}` - ).toString('base64') - - const isValid = validateAuthToken(expiredToken, formId) - expect(isValid).toBe(false) - }) - - it.concurrent('should validate tokens with password hash', () => { - const formId = 'test-form-id' - const encryptedPassword = 'encrypted-password-value' - const pwHash = crypto - .createHash('sha256') - .update(encryptedPassword) - .digest('hex') - .substring(0, 8) - - const token = Buffer.from(`${formId}:password:${Date.now()}:${pwHash}`).toString('base64') - - const isValid = validateAuthToken(token, formId, encryptedPassword) - expect(isValid).toBe(true) - - const isInvalidPassword = validateAuthToken(token, formId, 'different-password') - expect(isInvalidPassword).toBe(false) - }) - }) + const deployment = { + id: 'form-id', + authType: 'password', + password: 'encrypted-password', + } - describe('Cookie handling', () => { - it('should set auth cookie correctly', () => { - const mockSet = vi.fn() - const mockResponse = { + const mockRequest = { + method: 'POST', cookies: { - set: mockSet, + get: vi.fn().mockReturnValue({ value: 'valid-token' }), }, - } as unknown as NextResponse + } as any - const formId = 'test-form-id' - const type = 'password' + const result = await validateFormAuth('request-id', deployment, mockRequest) + expect(mockValidateAuthToken).toHaveBeenCalledWith( + 'valid-token', + 'form-id', + 'encrypted-password' + ) + expect(result.authorized).toBe(true) + }) - setFormAuthCookie(mockResponse, formId, type) + it('should reject invalid auth cookie via validateFormAuth', async () => { + mockValidateAuthToken.mockReturnValue(false) - expect(mockSet).toHaveBeenCalledWith({ - name: `form_auth_${formId}`, - value: expect.any(String), - httpOnly: true, - secure: false, // Development mode - sameSite: 'lax', - path: '/', - maxAge: 60 * 60 * 24, - }) - }) - }) + const deployment = { + id: 'form-id', + authType: 'password', + password: 'encrypted-password', + } - describe('CORS handling', () => { - it.concurrent('should add CORS headers for any origin', () => { const mockRequest = { - headers: { - get: vi.fn().mockReturnValue('http://localhost:3000'), + method: 'GET', + cookies: { + get: vi.fn().mockReturnValue({ value: 'invalid-token' }), }, } as any - const mockResponse = { - headers: { - set: vi.fn(), - }, - } as unknown as NextResponse - - addCorsHeaders(mockResponse, mockRequest) - - expect(mockResponse.headers.set).toHaveBeenCalledWith( - 'Access-Control-Allow-Origin', - 'http://localhost:3000' - ) - expect(mockResponse.headers.set).toHaveBeenCalledWith( - 'Access-Control-Allow-Credentials', - 'true' - ) - expect(mockResponse.headers.set).toHaveBeenCalledWith( - 'Access-Control-Allow-Methods', - 'GET, POST, OPTIONS' - ) - expect(mockResponse.headers.set).toHaveBeenCalledWith( - 'Access-Control-Allow-Headers', - 'Content-Type, X-Requested-With' - ) + const result = await validateFormAuth('request-id', deployment, mockRequest) + expect(result.authorized).toBe(false) }) + }) - it.concurrent('should not set CORS headers when no origin', () => { - const mockRequest = { - headers: { - get: vi.fn().mockReturnValue(''), - }, - } as any - + describe('Cookie handling', () => { + it('should delegate to setDeploymentAuthCookie', () => { const mockResponse = { - headers: { - set: vi.fn(), - }, + cookies: { set: vi.fn() }, } as unknown as NextResponse - addCorsHeaders(mockResponse, mockRequest) + setFormAuthCookie(mockResponse, 'test-form-id', 'password') - expect(mockResponse.headers.set).not.toHaveBeenCalled() + expect(mockSetDeploymentAuthCookie).toHaveBeenCalledWith( + mockResponse, + 'form', + 'test-form-id', + 'password', + undefined + ) }) }) @@ -291,6 +247,7 @@ describe('Form API Utils', () => { } as any // Exact email match should authorize + mockIsEmailAllowed.mockReturnValue(true) const result1 = await validateFormAuth('request-id', deployment, mockRequest, { email: 'user@example.com', }) @@ -303,6 +260,7 @@ describe('Form API Utils', () => { expect(result2.authorized).toBe(true) // Unknown email should not authorize + mockIsEmailAllowed.mockReturnValue(false) const result3 = await validateFormAuth('request-id', deployment, mockRequest, { email: 'user@unknown.com', }) diff --git a/apps/sim/app/api/mcp/servers/[id]/route.ts b/apps/sim/app/api/mcp/servers/[id]/route.ts index 597244a9703..54265bb687c 100644 --- a/apps/sim/app/api/mcp/servers/[id]/route.ts +++ b/apps/sim/app/api/mcp/servers/[id]/route.ts @@ -4,7 +4,13 @@ import { createLogger } from '@sim/logger' import { and, eq, isNull } from 'drizzle-orm' import type { NextRequest } from 'next/server' import { AuditAction, AuditResourceType, recordAudit } from '@/lib/audit/log' -import { McpDomainNotAllowedError, validateMcpDomain } from '@/lib/mcp/domain-check' +import { + McpDnsResolutionError, + McpDomainNotAllowedError, + McpSsrfError, + validateMcpDomain, + validateMcpServerSsrf, +} from '@/lib/mcp/domain-check' import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' import { mcpService } from '@/lib/mcp/service' import { createMcpErrorResponse, createMcpSuccessResponse } from '@/lib/mcp/utils' @@ -44,6 +50,18 @@ export const PATCH = withMcpAuth<{ id: string }>('write')( } throw e } + + try { + await validateMcpServerSsrf(updateData.url) + } catch (e) { + if (e instanceof McpDnsResolutionError) { + return createMcpErrorResponse(e, e.message, 502) + } + if (e instanceof McpSsrfError) { + return createMcpErrorResponse(e, e.message, 403) + } + throw e + } } // Get the current server to check if URL is changing diff --git a/apps/sim/app/api/mcp/servers/route.ts b/apps/sim/app/api/mcp/servers/route.ts index 3087ff9bde5..ff08085d1ec 100644 --- a/apps/sim/app/api/mcp/servers/route.ts +++ b/apps/sim/app/api/mcp/servers/route.ts @@ -4,7 +4,13 @@ import { createLogger } from '@sim/logger' import { and, eq, isNull } from 'drizzle-orm' import type { NextRequest } from 'next/server' import { AuditAction, AuditResourceType, recordAudit } from '@/lib/audit/log' -import { McpDomainNotAllowedError, validateMcpDomain } from '@/lib/mcp/domain-check' +import { + McpDnsResolutionError, + McpDomainNotAllowedError, + McpSsrfError, + validateMcpDomain, + validateMcpServerSsrf, +} from '@/lib/mcp/domain-check' import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' import { mcpService } from '@/lib/mcp/service' import { @@ -83,6 +89,18 @@ export const POST = withMcpAuth('write')( throw e } + try { + await validateMcpServerSsrf(body.url) + } catch (e) { + if (e instanceof McpDnsResolutionError) { + return createMcpErrorResponse(e, e.message, 502) + } + if (e instanceof McpSsrfError) { + return createMcpErrorResponse(e, e.message, 403) + } + throw e + } + const serverId = body.url ? generateMcpServerId(workspaceId, body.url) : crypto.randomUUID() const [existingServer] = await db diff --git a/apps/sim/app/api/mcp/servers/test-connection/route.ts b/apps/sim/app/api/mcp/servers/test-connection/route.ts index 4f9f6a990d9..37d6696b9c0 100644 --- a/apps/sim/app/api/mcp/servers/test-connection/route.ts +++ b/apps/sim/app/api/mcp/servers/test-connection/route.ts @@ -1,7 +1,13 @@ import { createLogger } from '@sim/logger' import type { NextRequest } from 'next/server' import { McpClient } from '@/lib/mcp/client' -import { McpDomainNotAllowedError, validateMcpDomain } from '@/lib/mcp/domain-check' +import { + McpDnsResolutionError, + McpDomainNotAllowedError, + McpSsrfError, + validateMcpDomain, + validateMcpServerSsrf, +} from '@/lib/mcp/domain-check' import { getParsedBody, withMcpAuth } from '@/lib/mcp/middleware' import { resolveMcpConfigEnvVars } from '@/lib/mcp/resolve-config' import type { McpTransport } from '@/lib/mcp/types' @@ -95,6 +101,18 @@ export const POST = withMcpAuth('write')( throw e } + try { + await validateMcpServerSsrf(body.url) + } catch (e) { + if (e instanceof McpDnsResolutionError) { + return createMcpErrorResponse(e, e.message, 502) + } + if (e instanceof McpSsrfError) { + return createMcpErrorResponse(e, e.message, 403) + } + throw e + } + // Build initial config for resolution const initialConfig = { id: `test-${requestId}`, @@ -119,7 +137,7 @@ export const POST = withMcpAuth('write')( logger.warn(`[${requestId}] Some environment variables not found:`, { missingVars }) } - // Re-validate domain after env var resolution + // Re-validate domain and SSRF after env var resolution try { validateMcpDomain(testConfig.url) } catch (e) { @@ -129,6 +147,18 @@ export const POST = withMcpAuth('write')( throw e } + try { + await validateMcpServerSsrf(testConfig.url) + } catch (e) { + if (e instanceof McpDnsResolutionError) { + return createMcpErrorResponse(e, e.message, 502) + } + if (e instanceof McpSsrfError) { + return createMcpErrorResponse(e, e.message, 403) + } + throw e + } + const testSecurityPolicy = { requireConsent: false, auditLevel: 'none' as const, diff --git a/apps/sim/app/api/tools/onepassword/utils.ts b/apps/sim/app/api/tools/onepassword/utils.ts index 703b7e5ac59..b4efe69d516 100644 --- a/apps/sim/app/api/tools/onepassword/utils.ts +++ b/apps/sim/app/api/tools/onepassword/utils.ts @@ -1,3 +1,4 @@ +import dns from 'dns/promises' import type { Item, ItemCategory, @@ -8,6 +9,9 @@ import type { VaultOverview, Website, } from '@1password/sdk' +import { createLogger } from '@sim/logger' +import * as ipaddr from 'ipaddr.js' +import { secureFetchWithPinnedIP } from '@/lib/core/security/input-validation.server' /** Connect-format field type strings returned by normalization. */ type ConnectFieldType = @@ -238,6 +242,63 @@ export async function createOnePasswordClient(serviceAccountToken: string) { }) } +const connectLogger = createLogger('OnePasswordConnect') + +/** + * Validates that a Connect server URL does not target cloud metadata endpoints. + * Allows private IPs and localhost since 1Password Connect is designed to be self-hosted. + * Returns the resolved IP for DNS pinning to prevent TOCTOU rebinding. + * @throws Error if the URL is invalid, points to a link-local address, or DNS fails. + */ +async function validateConnectServerUrl(serverUrl: string): Promise { + let hostname: string + try { + hostname = new URL(serverUrl).hostname + } catch { + throw new Error('1Password server URL is not a valid URL') + } + + const clean = + hostname.startsWith('[') && hostname.endsWith(']') ? hostname.slice(1, -1) : hostname + + if (ipaddr.isValid(clean)) { + const addr = ipaddr.process(clean) + if (addr.range() === 'linkLocal') { + throw new Error('1Password server URL cannot point to a link-local address') + } + return clean + } + + try { + const { address } = await dns.lookup(clean, { verbatim: true }) + if (ipaddr.isValid(address) && ipaddr.process(address).range() === 'linkLocal') { + connectLogger.warn('1Password Connect server URL resolves to link-local IP', { + hostname: clean, + resolvedIP: address, + }) + throw new Error('1Password server URL resolves to a link-local address') + } + return address + } catch (error) { + if (error instanceof Error && error.message.startsWith('1Password')) throw error + connectLogger.warn('DNS lookup failed for 1Password Connect server URL', { + hostname: clean, + error: error instanceof Error ? error.message : String(error), + }) + throw new Error('1Password server URL hostname could not be resolved') + } +} + +/** Minimal response shape used by all connectRequest callers. */ +export interface ConnectResponse { + ok: boolean + status: number + statusText: string + // eslint-disable-next-line @typescript-eslint/no-explicit-any + json: () => Promise + text: () => Promise +} + /** Proxy a request to the 1Password Connect Server. */ export async function connectRequest(options: { serverUrl: string @@ -246,7 +307,9 @@ export async function connectRequest(options: { method: string body?: unknown query?: string -}): Promise { +}): Promise { + const resolvedIP = await validateConnectServerUrl(options.serverUrl) + const base = options.serverUrl.replace(/\/$/, '') const queryStr = options.query ? `?${options.query}` : '' const url = `${base}${options.path}${queryStr}` @@ -259,10 +322,11 @@ export async function connectRequest(options: { headers['Content-Type'] = 'application/json' } - return fetch(url, { + return secureFetchWithPinnedIP(url, resolvedIP, { method: options.method, headers, body: options.body ? JSON.stringify(options.body) : undefined, + allowHttp: true, }) } diff --git a/apps/sim/app/api/tools/ssh/execute-command/route.ts b/apps/sim/app/api/tools/ssh/execute-command/route.ts index 94bd2b365ba..ba7f7b91c89 100644 --- a/apps/sim/app/api/tools/ssh/execute-command/route.ts +++ b/apps/sim/app/api/tools/ssh/execute-command/route.ts @@ -3,7 +3,12 @@ import { createLogger } from '@sim/logger' import { type NextRequest, NextResponse } from 'next/server' import { z } from 'zod' import { checkInternalAuth } from '@/lib/auth/hybrid' -import { createSSHConnection, executeSSHCommand, sanitizeCommand } from '@/app/api/tools/ssh/utils' +import { + createSSHConnection, + escapeShellArg, + executeSSHCommand, + sanitizeCommand, +} from '@/app/api/tools/ssh/utils' const logger = createLogger('SSHExecuteCommandAPI') @@ -52,7 +57,8 @@ export async function POST(request: NextRequest) { try { let command = sanitizeCommand(params.command) if (params.workingDirectory) { - command = `cd "${params.workingDirectory}" && ${command}` + const escapedWorkDir = escapeShellArg(params.workingDirectory) + command = `cd '${escapedWorkDir}' && ${command}` } const result = await executeSSHCommand(client, command) diff --git a/apps/sim/app/api/tools/ssh/execute-script/route.ts b/apps/sim/app/api/tools/ssh/execute-script/route.ts index 55c6df58f3c..f52dbc00c54 100644 --- a/apps/sim/app/api/tools/ssh/execute-script/route.ts +++ b/apps/sim/app/api/tools/ssh/execute-script/route.ts @@ -55,9 +55,10 @@ export async function POST(request: NextRequest) { const escapedScriptPath = escapeShellArg(scriptPath) const escapedInterpreter = escapeShellArg(params.interpreter) - let command = `cat > '${escapedScriptPath}' << 'SIMEOF' + const heredocDelimiter = `SIMEOF_${randomUUID().replace(/-/g, '')}` + let command = `cat > '${escapedScriptPath}' << '${heredocDelimiter}' ${params.script} -SIMEOF +${heredocDelimiter} chmod +x '${escapedScriptPath}'` if (params.workingDirectory) { diff --git a/apps/sim/background/schedule-execution.ts b/apps/sim/background/schedule-execution.ts index d1231e16a61..6aa3a306044 100644 --- a/apps/sim/background/schedule-execution.ts +++ b/apps/sim/background/schedule-execution.ts @@ -913,7 +913,7 @@ export async function executeJobInline(payload: JobExecutionPayload) { try { const url = buildAPIUrl('/api/mothership/execute') - const headers = await buildAuthHeaders() + const headers = await buildAuthHeaders(jobRecord.sourceUserId) const body = { messages: [{ role: 'user', content: promptText }], diff --git a/apps/sim/executor/handlers/agent/agent-handler.ts b/apps/sim/executor/handlers/agent/agent-handler.ts index 27873de1142..492d07c9b00 100644 --- a/apps/sim/executor/handlers/agent/agent-handler.ts +++ b/apps/sim/executor/handlers/agent/agent-handler.ts @@ -289,7 +289,7 @@ export class AgentBlockHandler implements BlockHandler { } try { - const headers = await buildAuthHeaders() + const headers = await buildAuthHeaders(ctx.userId) const params: Record = {} if (ctx.workspaceId) { @@ -467,7 +467,7 @@ export class AgentBlockHandler implements BlockHandler { throw new Error('workflowId is required for internal JWT authentication') } - const headers = await buildAuthHeaders() + const headers = await buildAuthHeaders(ctx.userId) const url = buildAPIUrl('/api/mcp/tools/discover', { serverId, workspaceId: ctx.workspaceId, diff --git a/apps/sim/executor/handlers/evaluator/evaluator-handler.ts b/apps/sim/executor/handlers/evaluator/evaluator-handler.ts index 710db01ee15..95d25ba4656 100644 --- a/apps/sim/executor/handlers/evaluator/evaluator-handler.ts +++ b/apps/sim/executor/handlers/evaluator/evaluator-handler.ts @@ -134,7 +134,7 @@ export class EvaluatorBlockHandler implements BlockHandler { const response = await fetch(url.toString(), { method: 'POST', - headers: await buildAuthHeaders(), + headers: await buildAuthHeaders(ctx.userId), body: stringifyJSON(providerRequest), }) diff --git a/apps/sim/executor/handlers/mothership/mothership-handler.ts b/apps/sim/executor/handlers/mothership/mothership-handler.ts index d8aecacf6ca..97a399cd2b2 100644 --- a/apps/sim/executor/handlers/mothership/mothership-handler.ts +++ b/apps/sim/executor/handlers/mothership/mothership-handler.ts @@ -32,7 +32,7 @@ export class MothershipBlockHandler implements BlockHandler { const chatId = crypto.randomUUID() const url = buildAPIUrl('/api/mothership/execute') - const headers = await buildAuthHeaders() + const headers = await buildAuthHeaders(ctx.userId) const body: Record = { messages, diff --git a/apps/sim/executor/handlers/router/router-handler.ts b/apps/sim/executor/handlers/router/router-handler.ts index c107f679377..f4c715d2ced 100644 --- a/apps/sim/executor/handlers/router/router-handler.ts +++ b/apps/sim/executor/handlers/router/router-handler.ts @@ -107,7 +107,7 @@ export class RouterBlockHandler implements BlockHandler { const response = await fetch(url.toString(), { method: 'POST', - headers: await buildAuthHeaders(), + headers: await buildAuthHeaders(ctx.userId), body: JSON.stringify(providerRequest), }) @@ -256,7 +256,7 @@ export class RouterBlockHandler implements BlockHandler { const response = await fetch(url.toString(), { method: 'POST', - headers: await buildAuthHeaders(), + headers: await buildAuthHeaders(ctx.userId), body: JSON.stringify(providerRequest), }) diff --git a/apps/sim/executor/handlers/workflow/workflow-handler.ts b/apps/sim/executor/handlers/workflow/workflow-handler.ts index 8aa478b6e84..50db926be7d 100644 --- a/apps/sim/executor/handlers/workflow/workflow-handler.ts +++ b/apps/sim/executor/handlers/workflow/workflow-handler.ts @@ -95,7 +95,7 @@ export class WorkflowBlockHandler implements BlockHandler { let childWorkflowSnapshotId: string | undefined try { if (ctx.isDeployedContext) { - const hasActiveDeployment = await this.checkChildDeployment(workflowId) + const hasActiveDeployment = await this.checkChildDeployment(workflowId, ctx.userId) if (!hasActiveDeployment) { throw new Error( `Child workflow is not deployed. Please deploy the workflow before invoking it.` @@ -104,8 +104,8 @@ export class WorkflowBlockHandler implements BlockHandler { } const childWorkflow = ctx.isDeployedContext - ? await this.loadChildWorkflowDeployed(workflowId) - : await this.loadChildWorkflow(workflowId) + ? await this.loadChildWorkflowDeployed(workflowId, ctx.userId) + : await this.loadChildWorkflow(workflowId, ctx.userId) if (!childWorkflow) { throw new Error(`Child workflow ${workflowId} not found`) @@ -323,8 +323,8 @@ export class WorkflowBlockHandler implements BlockHandler { return { chain, rootError: rootError.trim() || 'Unknown error' } } - private async loadChildWorkflow(workflowId: string) { - const headers = await buildAuthHeaders() + private async loadChildWorkflow(workflowId: string, userId?: string) { + const headers = await buildAuthHeaders(userId) const url = buildAPIUrl(`/api/workflows/${workflowId}`) const response = await fetch(url.toString(), { headers }) @@ -384,9 +384,9 @@ export class WorkflowBlockHandler implements BlockHandler { } } - private async checkChildDeployment(workflowId: string): Promise { + private async checkChildDeployment(workflowId: string, userId?: string): Promise { try { - const headers = await buildAuthHeaders() + const headers = await buildAuthHeaders(userId) const url = buildAPIUrl(`/api/workflows/${workflowId}/deployed`) const response = await fetch(url.toString(), { @@ -404,8 +404,8 @@ export class WorkflowBlockHandler implements BlockHandler { } } - private async loadChildWorkflowDeployed(workflowId: string) { - const headers = await buildAuthHeaders() + private async loadChildWorkflowDeployed(workflowId: string, userId?: string) { + const headers = await buildAuthHeaders(userId) const deployedUrl = buildAPIUrl(`/api/workflows/${workflowId}/deployed`) const deployedRes = await fetch(deployedUrl.toString(), { diff --git a/apps/sim/executor/utils/http.ts b/apps/sim/executor/utils/http.ts index ac4792dd74a..57ea632a41b 100644 --- a/apps/sim/executor/utils/http.ts +++ b/apps/sim/executor/utils/http.ts @@ -2,13 +2,13 @@ import { generateInternalToken } from '@/lib/auth/internal' import { getBaseUrl, getInternalApiBaseUrl } from '@/lib/core/utils/urls' import { HTTP } from '@/executor/constants' -export async function buildAuthHeaders(): Promise> { +export async function buildAuthHeaders(userId?: string): Promise> { const headers: Record = { 'Content-Type': HTTP.CONTENT_TYPE.JSON, } if (typeof window === 'undefined') { - const token = await generateInternalToken() + const token = await generateInternalToken(userId) headers.Authorization = `Bearer ${token}` } diff --git a/apps/sim/lib/auth/hybrid.ts b/apps/sim/lib/auth/hybrid.ts index af1e64da011..c9a9262ebc6 100644 --- a/apps/sim/lib/auth/hybrid.ts +++ b/apps/sim/lib/auth/hybrid.ts @@ -27,39 +27,18 @@ export interface AuthResult { /** * Resolves userId from a verified internal JWT token. - * Extracts userId from the JWT payload, URL search params, or POST body. + * Only trusts the userId embedded in the JWT payload — never from user-controlled sources. */ -async function resolveUserFromJwt( - request: NextRequest, +function resolveUserFromJwt( verificationUserId: string | null, options: { requireWorkflowId?: boolean } -): Promise { - let userId: string | null = verificationUserId - - if (!userId) { - const { searchParams } = new URL(request.url) - userId = searchParams.get('userId') - } - - if (!userId && request.method === 'POST') { - try { - const clonedRequest = request.clone() - const bodyText = await clonedRequest.text() - if (bodyText) { - const body = JSON.parse(bodyText) - userId = body.userId || body._context?.userId || null - } - } catch { - // Ignore JSON parse errors - } - } - - if (userId) { - return { success: true, userId, authType: AuthType.INTERNAL_JWT } +): AuthResult { + if (verificationUserId) { + return { success: true, userId: verificationUserId, authType: AuthType.INTERNAL_JWT } } if (options.requireWorkflowId !== false) { - return { success: false, error: 'userId required for internal JWT calls' } + return { success: false, error: 'userId required but not present in JWT' } } return { success: true, authType: AuthType.INTERNAL_JWT } @@ -103,7 +82,7 @@ export async function checkInternalAuth( return { success: false, error: 'Invalid internal token' } } - return resolveUserFromJwt(request, verification.userId || null, options) + return resolveUserFromJwt(verification.userId || null, options) } catch (error) { logger.error('Error in internal authentication:', error) return { @@ -143,7 +122,7 @@ export async function checkSessionOrInternalAuth( const verification = await verifyInternalToken(token) if (verification.valid) { - return resolveUserFromJwt(request, verification.userId || null, options) + return resolveUserFromJwt(verification.userId || null, options) } } @@ -192,7 +171,7 @@ export async function checkHybridAuth( const verification = await verifyInternalToken(token) if (verification.valid) { - return resolveUserFromJwt(request, verification.userId || null, options) + return resolveUserFromJwt(verification.userId || null, options) } } diff --git a/apps/sim/lib/core/security/deployment.ts b/apps/sim/lib/core/security/deployment.ts index 9b038ae0771..10a0781f83a 100644 --- a/apps/sim/lib/core/security/deployment.ts +++ b/apps/sim/lib/core/security/deployment.ts @@ -1,5 +1,6 @@ -import { createHash } from 'crypto' +import { createHash, createHmac, timingSafeEqual } from 'crypto' import type { NextRequest, NextResponse } from 'next/server' +import { env } from '@/lib/core/config/env' import { isDev } from '@/lib/core/config/feature-flags' /** @@ -7,21 +8,29 @@ import { isDev } from '@/lib/core/config/feature-flags' * These functions handle token generation, validation, cookies, and CORS. */ -function hashPassword(encryptedPassword: string): string { - return createHash('sha256').update(encryptedPassword).digest('hex').substring(0, 8) +function signPayload(payload: string): string { + return createHmac('sha256', env.BETTER_AUTH_SECRET).update(payload).digest('hex') } -function encryptAuthToken( +function passwordSlot(encryptedPassword?: string | null): string { + if (!encryptedPassword) return '' + return createHash('sha256').update(encryptedPassword).digest('hex').slice(0, 8) +} + +function generateAuthToken( deploymentId: string, type: string, encryptedPassword?: string | null ): string { - const pwHash = encryptedPassword ? hashPassword(encryptedPassword) : '' - return Buffer.from(`${deploymentId}:${type}:${Date.now()}:${pwHash}`).toString('base64') + const payload = `${deploymentId}:${type}:${Date.now()}:${passwordSlot(encryptedPassword)}` + const sig = signPayload(payload) + return Buffer.from(`${payload}:${sig}`).toString('base64') } /** - * Validates an authentication token for a deployment (chat or form) + * Validates an HMAC-signed authentication token for a deployment (chat or form). + * Includes a password-derived slot so changing the deployment password immediately + * invalidates existing sessions. */ export function validateAuthToken( token: string, @@ -30,27 +39,32 @@ export function validateAuthToken( ): boolean { try { const decoded = Buffer.from(token, 'base64').toString() - const parts = decoded.split(':') - const [storedId, _type, timestamp, storedPwHash] = parts + const lastColon = decoded.lastIndexOf(':') + if (lastColon === -1) return false + + const payload = decoded.slice(0, lastColon) + const sig = decoded.slice(lastColon + 1) - if (storedId !== deploymentId) { + const expectedSig = signPayload(payload) + if ( + sig.length !== expectedSig.length || + !timingSafeEqual(Buffer.from(sig), Buffer.from(expectedSig)) + ) { return false } - const createdAt = Number.parseInt(timestamp) - const now = Date.now() - const expireTime = 24 * 60 * 60 * 1000 + const parts = payload.split(':') + if (parts.length < 4) return false + const [storedId, _type, timestamp, storedPwSlot] = parts - if (now - createdAt > expireTime) { - return false - } + if (storedId !== deploymentId) return false - if (encryptedPassword) { - const currentPwHash = hashPassword(encryptedPassword) - if (storedPwHash !== currentPwHash) { - return false - } - } + const expectedPwSlot = passwordSlot(encryptedPassword) + if (storedPwSlot !== expectedPwSlot) return false + + const createdAt = Number.parseInt(timestamp) + const expireTime = 24 * 60 * 60 * 1000 + if (Date.now() - createdAt > expireTime) return false return true } catch (_e) { @@ -68,7 +82,7 @@ export function setDeploymentAuthCookie( authType: string, encryptedPassword?: string | null ): void { - const token = encryptAuthToken(deploymentId, authType, encryptedPassword) + const token = generateAuthToken(deploymentId, authType, encryptedPassword) response.cookies.set({ name: `${cookiePrefix}_auth_${deploymentId}`, value: token, @@ -82,15 +96,15 @@ export function setDeploymentAuthCookie( /** * Adds CORS headers to allow cross-origin requests for embedded deployments. - * Embedded chat widgets and forms are designed to run on any customer domain, - * so we reflect the requesting origin rather than restricting to an allowlist. + * We reflect the requesting origin to support same-site cross-origin setups + * (e.g. subdomains), but never set Allow-Credentials — auth cookies use + * SameSite=Lax and are handled within same-origin iframe contexts. */ export function addCorsHeaders(response: NextResponse, request: NextRequest): NextResponse { - const origin = request.headers.get('origin') || '' + const origin = request.headers.get('origin') if (origin) { response.headers.set('Access-Control-Allow-Origin', origin) - response.headers.set('Access-Control-Allow-Credentials', 'true') response.headers.set('Access-Control-Allow-Methods', 'GET, POST, OPTIONS') response.headers.set('Access-Control-Allow-Headers', 'Content-Type, X-Requested-With') } diff --git a/apps/sim/lib/core/security/input-validation.server.ts b/apps/sim/lib/core/security/input-validation.server.ts index dab2c769d9f..78c93d6d13d 100644 --- a/apps/sim/lib/core/security/input-validation.server.ts +++ b/apps/sim/lib/core/security/input-validation.server.ts @@ -24,7 +24,7 @@ export interface AsyncValidationResult extends ValidationResult { * - IPv4-mapped IPv6 (::ffff:127.0.0.1) * - Various edge cases that regex patterns miss */ -function isPrivateOrReservedIP(ip: string): boolean { +export function isPrivateOrReservedIP(ip: string): boolean { try { if (!ipaddr.isValid(ip)) { return true diff --git a/apps/sim/lib/mcp/domain-check.test.ts b/apps/sim/lib/mcp/domain-check.test.ts index d34974194ee..fa7b120416b 100644 --- a/apps/sim/lib/mcp/domain-check.test.ts +++ b/apps/sim/lib/mcp/domain-check.test.ts @@ -3,19 +3,45 @@ */ import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockGetAllowedMcpDomainsFromEnv } = vi.hoisted(() => ({ +const { mockGetAllowedMcpDomainsFromEnv, mockDnsLookup } = vi.hoisted(() => ({ mockGetAllowedMcpDomainsFromEnv: vi.fn<() => string[] | null>(), + mockDnsLookup: vi.fn(), })) vi.mock('@/lib/core/config/feature-flags', () => ({ getAllowedMcpDomainsFromEnv: mockGetAllowedMcpDomainsFromEnv, })) +vi.mock('@/lib/core/security/input-validation.server', () => ({ + isPrivateOrReservedIP: (ip: string) => { + if (ip.startsWith('10.') || ip.startsWith('192.168.')) return true + if (ip.startsWith('172.')) { + const second = Number.parseInt(ip.split('.')[1], 10) + if (second >= 16 && second <= 31) return true + } + if (ip.startsWith('169.254.')) return true + if (ip.startsWith('127.') || ip === '::1') return true + if (ip === '0.0.0.0') return true + return false + }, +})) + +vi.mock('dns/promises', () => ({ + default: { lookup: mockDnsLookup }, +})) + vi.mock('@/executor/utils/reference-validation', () => ({ createEnvVarPattern: () => /\{\{([^}]+)\}\}/g, })) -import { isMcpDomainAllowed, McpDomainNotAllowedError, validateMcpDomain } from './domain-check' +import { + isMcpDomainAllowed, + McpDnsResolutionError, + McpDomainNotAllowedError, + McpSsrfError, + validateMcpDomain, + validateMcpServerSsrf, +} from './domain-check' describe('McpDomainNotAllowedError', () => { it.concurrent('creates error with correct name and message', () => { @@ -299,3 +325,97 @@ describe('validateMcpDomain', () => { }) }) }) + +describe('validateMcpServerSsrf', () => { + beforeEach(() => { + vi.clearAllMocks() + mockGetAllowedMcpDomainsFromEnv.mockReturnValue(null) + }) + + it('does nothing for undefined URL', async () => { + await expect(validateMcpServerSsrf(undefined)).resolves.toBeUndefined() + expect(mockDnsLookup).not.toHaveBeenCalled() + }) + + it('skips validation for env var URLs', async () => { + await expect(validateMcpServerSsrf('{{MCP_SERVER_URL}}')).resolves.toBeUndefined() + expect(mockDnsLookup).not.toHaveBeenCalled() + }) + + it('skips validation for URLs with env var in hostname', async () => { + await expect(validateMcpServerSsrf('https://{{MCP_HOST}}/mcp')).resolves.toBeUndefined() + expect(mockDnsLookup).not.toHaveBeenCalled() + }) + + it('allows localhost URLs without DNS lookup', async () => { + await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).resolves.toBeUndefined() + expect(mockDnsLookup).not.toHaveBeenCalled() + }) + + it('allows 127.0.0.1 URLs without DNS lookup', async () => { + await expect(validateMcpServerSsrf('http://127.0.0.1:8080/mcp')).resolves.toBeUndefined() + expect(mockDnsLookup).not.toHaveBeenCalled() + }) + + it('allows URLs that resolve to public IPs', async () => { + mockDnsLookup.mockResolvedValue({ address: '93.184.216.34' }) + await expect(validateMcpServerSsrf('https://example.com/mcp')).resolves.toBeUndefined() + }) + + it('allows HTTP URLs on non-localhost hosts', async () => { + mockDnsLookup.mockResolvedValue({ address: '93.184.216.34' }) + await expect(validateMcpServerSsrf('http://example.com:3000/mcp')).resolves.toBeUndefined() + }) + + it('throws McpSsrfError for cloud metadata IP literal', async () => { + await expect(validateMcpServerSsrf('http://169.254.169.254/latest/meta-data/')).rejects.toThrow( + McpSsrfError + ) + expect(mockDnsLookup).not.toHaveBeenCalled() + }) + + it('throws McpSsrfError for RFC-1918 IP literal', async () => { + await expect(validateMcpServerSsrf('http://10.0.0.1/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('throws McpSsrfError for 192.168.x.x IP literal', async () => { + await expect(validateMcpServerSsrf('http://192.168.1.1/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('throws McpSsrfError for URLs resolving to private IPs', async () => { + mockDnsLookup.mockResolvedValue({ address: '10.0.0.5' }) + await expect(validateMcpServerSsrf('https://internal.corp/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('throws McpSsrfError for URLs resolving to link-local IPs', async () => { + mockDnsLookup.mockResolvedValue({ address: '169.254.169.254' }) + await expect(validateMcpServerSsrf('https://metadata.internal/latest')).rejects.toThrow( + McpSsrfError + ) + }) + + it('throws McpDnsResolutionError when DNS lookup fails', async () => { + mockDnsLookup.mockRejectedValue(new Error('ENOTFOUND')) + await expect(validateMcpServerSsrf('https://nonexistent.invalid/mcp')).rejects.toThrow( + McpDnsResolutionError + ) + }) + + it('allows URLs resolving to loopback (localhost alias)', async () => { + mockDnsLookup.mockResolvedValue({ address: '127.0.0.1' }) + await expect(validateMcpServerSsrf('http://my-local-alias:3000/mcp')).resolves.toBeUndefined() + }) + + it('throws for malformed URLs', async () => { + await expect(validateMcpServerSsrf('not-a-url')).rejects.toThrow(McpSsrfError) + }) + + it('skips all checks when ALLOWED_MCP_DOMAINS is configured', async () => { + mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['internal.corp']) + await expect(validateMcpServerSsrf('http://10.0.0.1/mcp')).resolves.toBeUndefined() + await expect( + validateMcpServerSsrf('http://169.254.169.254/latest/meta-data/') + ).resolves.toBeUndefined() + expect(mockDnsLookup).not.toHaveBeenCalled() + }) +}) diff --git a/apps/sim/lib/mcp/domain-check.ts b/apps/sim/lib/mcp/domain-check.ts index 84d4d59d5eb..fa09d163e96 100644 --- a/apps/sim/lib/mcp/domain-check.ts +++ b/apps/sim/lib/mcp/domain-check.ts @@ -1,6 +1,12 @@ +import dns from 'dns/promises' +import { createLogger } from '@sim/logger' +import * as ipaddr from 'ipaddr.js' import { getAllowedMcpDomainsFromEnv } from '@/lib/core/config/feature-flags' +import { isPrivateOrReservedIP } from '@/lib/core/security/input-validation.server' import { createEnvVarPattern } from '@/executor/utils/reference-validation' +const logger = createLogger('McpDomainCheck') + export class McpDomainNotAllowedError extends Error { constructor(domain: string) { super(`MCP server domain "${domain}" is not allowed by the server's ALLOWED_MCP_DOMAINS policy`) @@ -8,6 +14,20 @@ export class McpDomainNotAllowedError extends Error { } } +export class McpSsrfError extends Error { + constructor(message: string) { + super(message) + this.name = 'McpSsrfError' + } +} + +export class McpDnsResolutionError extends Error { + constructor(hostname: string) { + super(`MCP server URL hostname "${hostname}" could not be resolved`) + this.name = 'McpDnsResolutionError' + } +} + /** * Core domain check. Returns null if the URL is allowed, or the hostname/url * string to use in the rejection error. @@ -76,3 +96,85 @@ export function validateMcpDomain(url: string | undefined): void { throw new McpDomainNotAllowedError(rejected) } } + +/** + * Returns true if the IP is a loopback address (full 127.0.0.0/8 range, or ::1). + */ +function isLoopbackIP(ip: string): boolean { + try { + if (!ipaddr.isValid(ip)) return false + return ipaddr.process(ip).range() === 'loopback' + } catch { + return false + } +} + +/** + * Returns true if the hostname is localhost or a loopback IP literal. + * Expects IPv6 brackets to already be stripped. + */ +function isLocalhostHostname(hostname: string): boolean { + const clean = hostname.toLowerCase() + if (clean === 'localhost') return true + return ipaddr.isValid(clean) && isLoopbackIP(clean) +} + +/** + * Validates an MCP server URL against SSRF attacks by resolving DNS and + * rejecting private/reserved IP ranges (RFC-1918, link-local, cloud metadata). + * + * Only active when ALLOWED_MCP_DOMAINS is **not configured**. When an admin + * has set an explicit domain allowlist, they control which domains are + * reachable and private-network MCP servers are legitimate. Applying SSRF + * blocking on top of an admin-curated list would break self-hosted + * deployments where MCP servers run on internal networks. + * + * Does NOT enforce protocol (HTTP is allowed) or block service ports — MCP + * servers legitimately run on HTTP and on arbitrary ports. + * + * Localhost/loopback is always allowed for local dev MCP servers. + * URLs with env var references in the hostname are skipped — they will be + * validated after resolution at execution time. + * + * @throws McpSsrfError if the URL resolves to a blocked IP address + */ +export async function validateMcpServerSsrf(url: string | undefined): Promise { + if (!url) return + if (getAllowedMcpDomainsFromEnv() !== null) return + if (hasEnvVarInHostname(url)) return + + let hostname: string + try { + hostname = new URL(url).hostname + } catch { + throw new McpSsrfError('MCP server URL is not a valid URL') + } + + const cleanHostname = + hostname.startsWith('[') && hostname.endsWith(']') ? hostname.slice(1, -1) : hostname + + if (isLocalhostHostname(cleanHostname)) return + + if (ipaddr.isValid(cleanHostname) && isPrivateOrReservedIP(cleanHostname)) { + throw new McpSsrfError('MCP server URL cannot point to a private or reserved IP address') + } + + try { + const { address } = await dns.lookup(cleanHostname, { verbatim: true }) + + if (isPrivateOrReservedIP(address) && !isLoopbackIP(address)) { + logger.warn('MCP server URL resolves to blocked IP address', { + hostname, + resolvedIP: address, + }) + throw new McpSsrfError('MCP server URL resolves to a blocked IP address') + } + } catch (error) { + if (error instanceof McpSsrfError) throw error + logger.warn('DNS lookup failed for MCP server URL', { + hostname, + error: error instanceof Error ? error.message : String(error), + }) + throw new McpDnsResolutionError(cleanHostname) + } +} diff --git a/apps/sim/lib/mcp/service.ts b/apps/sim/lib/mcp/service.ts index 69e7cc81178..44326c31602 100644 --- a/apps/sim/lib/mcp/service.ts +++ b/apps/sim/lib/mcp/service.ts @@ -10,7 +10,11 @@ import { isTest } from '@/lib/core/config/feature-flags' import { generateRequestId } from '@/lib/core/utils/request' import { McpClient } from '@/lib/mcp/client' import { mcpConnectionManager } from '@/lib/mcp/connection-manager' -import { isMcpDomainAllowed, validateMcpDomain } from '@/lib/mcp/domain-check' +import { + isMcpDomainAllowed, + validateMcpDomain, + validateMcpServerSsrf, +} from '@/lib/mcp/domain-check' import { resolveMcpConfigEnvVars } from '@/lib/mcp/resolve-config' import { createMcpCacheAdapter, @@ -68,6 +72,7 @@ class McpService { strict: true, }) validateMcpDomain(resolvedConfig.url) + await validateMcpServerSsrf(resolvedConfig.url) return resolvedConfig } diff --git a/apps/sim/lib/uploads/utils/file-utils.ts b/apps/sim/lib/uploads/utils/file-utils.ts index 85655391570..ac947260909 100644 --- a/apps/sim/lib/uploads/utils/file-utils.ts +++ b/apps/sim/lib/uploads/utils/file-utils.ts @@ -419,10 +419,11 @@ export function inferContextFromKey(key: string): StorageContext { if (key.startsWith('execution/')) return 'execution' if (key.startsWith('workspace/')) return 'workspace' if (key.startsWith('profile-pictures/')) return 'profile-pictures' + if (key.startsWith('og-images/')) return 'og-images' if (key.startsWith('logs/')) return 'logs' throw new Error( - `File key must start with a context prefix (kb/, chat/, copilot/, execution/, workspace/, profile-pictures/, or logs/). Got: ${key}` + `File key must start with a context prefix (kb/, chat/, copilot/, execution/, workspace/, profile-pictures/, og-images/, or logs/). Got: ${key}` ) } diff --git a/apps/sim/tools/index.ts b/apps/sim/tools/index.ts index fcec6fe3710..c354b691c2b 100644 --- a/apps/sim/tools/index.ts +++ b/apps/sim/tools/index.ts @@ -1017,12 +1017,13 @@ async function addInternalAuthIfNeeded( headers: Headers | Record, isInternalRoute: boolean, requestId: string, - context: string + context: string, + userId?: string ): Promise { if (typeof window === 'undefined') { if (isInternalRoute) { try { - const internalToken = await generateInternalToken() + const internalToken = await generateInternalToken(userId) if (headers instanceof Headers) { headers.set('Authorization', `Bearer ${internalToken}`) } else { @@ -1163,7 +1164,13 @@ async function executeToolRequest( } const headers = new Headers(requestParams.headers) - await addInternalAuthIfNeeded(headers, isInternalRoute, requestId, toolId) + await addInternalAuthIfNeeded( + headers, + isInternalRoute, + requestId, + toolId, + params._context?.userId + ) const shouldPropagateCallChain = isInternalRoute || isSelfOriginUrl(fullUrl) if (shouldPropagateCallChain) { @@ -1518,7 +1525,7 @@ async function executeMcpTool( if (typeof window === 'undefined') { try { - const internalToken = await generateInternalToken() + const internalToken = await generateInternalToken(executionContext?.userId) headers.Authorization = `Bearer ${internalToken}` } catch (error) { logger.error(`[${actualRequestId}] Failed to generate internal token:`, error) diff --git a/apps/sim/tools/utils.ts b/apps/sim/tools/utils.ts index a8a2885e93b..581d4a6ac58 100644 --- a/apps/sim/tools/utils.ts +++ b/apps/sim/tools/utils.ts @@ -388,7 +388,7 @@ async function fetchCustomToolFromAPI( if (typeof window === 'undefined') { try { const { generateInternalToken } = await import('@/lib/auth/internal') - const internalToken = await generateInternalToken() + const internalToken = await generateInternalToken(userId) headers.Authorization = `Bearer ${internalToken}` } catch (error) { logger.warn('Failed to generate internal token for custom tools fetch', { error })