diff --git a/apps/sim/app/api/chat/[identifier]/otp/route.ts b/apps/sim/app/api/chat/[identifier]/otp/route.ts index b2e129b5fa8..fcccc003e86 100644 --- a/apps/sim/app/api/chat/[identifier]/otp/route.ts +++ b/apps/sim/app/api/chat/[identifier]/otp/route.ts @@ -1,18 +1,24 @@ -import { randomInt } from 'crypto' import { db } from '@sim/db' -import { chat, verification } from '@sim/db/schema' +import { chat } from '@sim/db/schema' import { createLogger } from '@sim/logger' -import { generateId } from '@sim/utils/id' -import { and, eq, gt, isNull } from 'drizzle-orm' +import { and, eq, isNull } from 'drizzle-orm' import type { NextRequest } from 'next/server' import { renderOTPEmail } from '@/components/emails' import { requestChatEmailOtpContract, verifyChatEmailOtpContract } from '@/lib/api/contracts/chats' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' -import { getRedisClient } from '@/lib/core/config/redis' -import type { TokenBucketConfig } from '@/lib/core/rate-limiter' import { RateLimiter } from '@/lib/core/rate-limiter' import { addCorsHeaders, isEmailAllowed } from '@/lib/core/security/deployment' -import { getStorageMethod } from '@/lib/core/storage' +import { + decodeOTPValue, + deleteOTP, + generateOTP, + getOTP, + incrementOTPAttempts, + MAX_OTP_ATTEMPTS, + OTP_EMAIL_RATE_LIMIT, + OTP_IP_RATE_LIMIT, + storeOTP, +} from '@/lib/core/security/otp' import { generateRequestId, getClientIp } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { sendEmail } from '@/lib/messaging/email/mailer' @@ -23,199 +29,6 @@ const logger = createLogger('ChatOtpAPI') const rateLimiter = new RateLimiter() -const OTP_IP_RATE_LIMIT: TokenBucketConfig = { - maxTokens: 10, - refillRate: 10, - refillIntervalMs: 15 * 60_000, -} - -const OTP_EMAIL_RATE_LIMIT: TokenBucketConfig = { - maxTokens: 3, - refillRate: 3, - refillIntervalMs: 15 * 60_000, -} - -function generateOTP(): string { - return randomInt(100000, 1000000).toString() -} - -const OTP_EXPIRY = 15 * 60 // 15 minutes -const OTP_EXPIRY_MS = OTP_EXPIRY * 1000 -const MAX_OTP_ATTEMPTS = 5 - -/** - * OTP values are stored as "code:attempts" (e.g. "654321:0"). - * This keeps the attempt counter in the same key/row as the OTP itself. - */ -function encodeOTPValue(otp: string, attempts: number): string { - return `${otp}:${attempts}` -} - -function decodeOTPValue(value: string): { otp: string; attempts: number } { - const lastColon = value.lastIndexOf(':') - if (lastColon === -1) return { otp: value, attempts: 0 } - const attempts = Number.parseInt(value.slice(lastColon + 1), 10) - return { otp: value.slice(0, lastColon), attempts: Number.isNaN(attempts) ? 0 : attempts } -} - -/** - * Stores OTP in Redis or database depending on storage method. - * Uses the verification table for database storage. - */ -async function storeOTP(email: string, chatId: string, otp: string): Promise { - const identifier = `chat-otp:${chatId}:${email}` - const storageMethod = getStorageMethod() - const value = encodeOTPValue(otp, 0) - - if (storageMethod === 'redis') { - const redis = getRedisClient() - if (!redis) { - throw new Error('Redis configured but client unavailable') - } - await redis.set(`otp:${email}:${chatId}`, value, 'EX', OTP_EXPIRY) - } else { - const now = new Date() - const expiresAt = new Date(now.getTime() + OTP_EXPIRY_MS) - - await db.transaction(async (tx) => { - await tx.delete(verification).where(eq(verification.identifier, identifier)) - await tx.insert(verification).values({ - id: generateId(), - identifier, - value, - expiresAt, - createdAt: now, - updatedAt: now, - }) - }) - } -} - -async function getOTP(email: string, chatId: string): Promise { - const identifier = `chat-otp:${chatId}:${email}` - const storageMethod = getStorageMethod() - - if (storageMethod === 'redis') { - const redis = getRedisClient() - if (!redis) { - throw new Error('Redis configured but client unavailable') - } - return redis.get(`otp:${email}:${chatId}`) - } - - const now = new Date() - const [record] = await db - .select({ value: verification.value }) - .from(verification) - .where(and(eq(verification.identifier, identifier), gt(verification.expiresAt, now))) - .limit(1) - - return record?.value ?? null -} - -/** - * Lua script for atomic OTP attempt increment. - * Returns: "LOCKED" if max attempts reached (key deleted), new encoded value otherwise, nil if key missing. - */ -const ATOMIC_INCREMENT_SCRIPT = ` -local val = redis.call('GET', KEYS[1]) -if not val then return nil end -local colon = val:find(':([^:]*$)') -local otp, attempts -if colon then - otp = val:sub(1, colon - 1) - attempts = tonumber(val:sub(colon + 1)) or 0 -else - otp = val - attempts = 0 -end -attempts = attempts + 1 -if attempts >= tonumber(ARGV[1]) then - redis.call('DEL', KEYS[1]) - return 'LOCKED' -end -local newVal = otp .. ':' .. attempts -local ttl = redis.call('TTL', KEYS[1]) -if ttl > 0 then - redis.call('SET', KEYS[1], newVal, 'EX', ttl) -else - redis.call('SET', KEYS[1], newVal) -end -return newVal -` - -/** - * Atomically increments OTP attempts. Returns 'locked' if max reached, 'incremented' otherwise. - */ -async function incrementOTPAttempts( - email: string, - chatId: string, - currentValue: string -): Promise<'locked' | 'incremented'> { - const identifier = `chat-otp:${chatId}:${email}` - const storageMethod = getStorageMethod() - - if (storageMethod === 'redis') { - const redis = getRedisClient() - if (!redis) { - throw new Error('Redis configured but client unavailable') - } - const key = `otp:${email}:${chatId}` - const result = await redis.eval(ATOMIC_INCREMENT_SCRIPT, 1, key, MAX_OTP_ATTEMPTS) - if (result === null || result === 'LOCKED') return 'locked' - return 'incremented' - } - - // DB path: optimistic locking with retry on conflict - const MAX_RETRIES = 3 - let value = currentValue - - for (let attempt = 0; attempt < MAX_RETRIES; attempt++) { - const { otp, attempts } = decodeOTPValue(value) - const newAttempts = attempts + 1 - - if (newAttempts >= MAX_OTP_ATTEMPTS) { - await db.delete(verification).where(eq(verification.identifier, identifier)) - return 'locked' - } - - const newValue = encodeOTPValue(otp, newAttempts) - const updated = await db - .update(verification) - .set({ value: newValue, updatedAt: new Date() }) - .where(and(eq(verification.identifier, identifier), eq(verification.value, value))) - .returning({ id: verification.id }) - - if (updated.length > 0) return 'incremented' - - // Conflict: another request already incremented — re-read and retry - const fresh = await getOTP(email, chatId) - if (!fresh) return 'locked' - value = fresh - } - - // Exhausted retries — re-read final state to determine outcome - const final = await getOTP(email, chatId) - if (!final) return 'locked' - const { attempts: finalAttempts } = decodeOTPValue(final) - return finalAttempts >= MAX_OTP_ATTEMPTS ? 'locked' : 'incremented' -} - -async function deleteOTP(email: string, chatId: string): Promise { - const identifier = `chat-otp:${chatId}:${email}` - const storageMethod = getStorageMethod() - - if (storageMethod === 'redis') { - const redis = getRedisClient() - if (!redis) { - throw new Error('Redis configured but client unavailable') - } - await redis.del(`otp:${email}:${chatId}`) - } else { - await db.delete(verification).where(eq(verification.identifier, identifier)) - } -} - export const POST = withRouteHandler( async (request: NextRequest, context: { params: Promise<{ identifier: string }> }) => { const { identifier } = await context.params @@ -305,7 +118,7 @@ export const POST = withRouteHandler( } const otp = generateOTP() - await storeOTP(email, deployment.id, otp) + await storeOTP('chat', deployment.id, email, otp) const emailHtml = await renderOTPEmail( otp, @@ -330,12 +143,9 @@ export const POST = withRouteHandler( logger.info(`[${requestId}] OTP sent to ${email} for chat ${deployment.id}`) return addCorsHeaders(createSuccessResponse({ message: 'Verification code sent' }), request) - } catch (error: any) { + } catch (error) { logger.error(`[${requestId}] Error processing OTP request:`, error) - return addCorsHeaders( - createErrorResponse(error.message || 'Failed to process request', 500), - request - ) + return addCorsHeaders(createErrorResponse('Failed to process request', 500), request) } } ) @@ -379,7 +189,7 @@ export const PUT = withRouteHandler( const deployment = deploymentResult[0] - const storedValue = await getOTP(email, deployment.id) + const storedValue = await getOTP('chat', deployment.id, email) if (!storedValue) { return addCorsHeaders( createErrorResponse('No verification code found, request a new one', 400), @@ -390,7 +200,7 @@ export const PUT = withRouteHandler( const { otp: storedOTP, attempts } = decodeOTPValue(storedValue) if (attempts >= MAX_OTP_ATTEMPTS) { - await deleteOTP(email, deployment.id) + await deleteOTP('chat', deployment.id, email) logger.warn(`[${requestId}] OTP already locked out for ${email}`) return addCorsHeaders( createErrorResponse('Too many failed attempts. Please request a new code.', 429), @@ -399,7 +209,7 @@ export const PUT = withRouteHandler( } if (storedOTP !== otp) { - const result = await incrementOTPAttempts(email, deployment.id, storedValue) + const result = await incrementOTPAttempts('chat', deployment.id, email, storedValue) if (result === 'locked') { logger.warn(`[${requestId}] OTP invalidated after max failed attempts for ${email}`) return addCorsHeaders( @@ -410,7 +220,7 @@ export const PUT = withRouteHandler( return addCorsHeaders(createErrorResponse('Invalid verification code', 400), request) } - await deleteOTP(email, deployment.id) + await deleteOTP('chat', deployment.id, email) const response = addCorsHeaders( createSuccessResponse({ @@ -426,12 +236,9 @@ export const PUT = withRouteHandler( setChatAuthCookie(response, deployment.id, deployment.authType, deployment.password) return response - } catch (error: any) { + } catch (error) { logger.error(`[${requestId}] Error verifying OTP:`, error) - return addCorsHeaders( - createErrorResponse(error.message || 'Failed to process request', 500), - request - ) + return addCorsHeaders(createErrorResponse('Failed to process request', 500), request) } } ) diff --git a/apps/sim/app/api/form/[identifier]/otp/route.test.ts b/apps/sim/app/api/form/[identifier]/otp/route.test.ts new file mode 100644 index 00000000000..4b3b13441d0 --- /dev/null +++ b/apps/sim/app/api/form/[identifier]/otp/route.test.ts @@ -0,0 +1,695 @@ +/** + * Tests for form OTP API route + * + * @vitest-environment node + */ +import { + redisConfigMock, + redisConfigMockFns, + requestUtilsMockFns, + workflowsApiUtilsMock, + workflowsApiUtilsMockFns, +} from '@sim/testing' +import { NextRequest } from 'next/server' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockRedisSet, + mockRedisGet, + mockRedisDel, + mockRedisTtl, + mockRedisEval, + mockRedisClient, + mockDbSelect, + mockDbInsert, + mockDbDelete, + mockDbUpdate, + mockSendEmail, + mockRenderOTPEmail, + mockAddCorsHeaders, + mockSetFormAuthCookie, + mockGetStorageMethod, + mockZodParse, + mockGetEnv, +} = vi.hoisted(() => { + const mockRedisSet = vi.fn() + const mockRedisGet = vi.fn() + const mockRedisDel = vi.fn() + const mockRedisTtl = vi.fn() + const mockRedisEval = vi.fn() + const mockRedisClient = { + set: mockRedisSet, + get: mockRedisGet, + del: mockRedisDel, + ttl: mockRedisTtl, + eval: mockRedisEval, + } + return { + mockRedisSet, + mockRedisGet, + mockRedisDel, + mockRedisTtl, + mockRedisEval, + mockRedisClient, + mockDbSelect: vi.fn(), + mockDbInsert: vi.fn(), + mockDbDelete: vi.fn(), + mockDbUpdate: vi.fn(), + mockSendEmail: vi.fn(), + mockRenderOTPEmail: vi.fn(), + mockAddCorsHeaders: vi.fn(), + mockSetFormAuthCookie: vi.fn(), + mockGetStorageMethod: vi.fn(), + mockZodParse: vi.fn(), + mockGetEnv: vi.fn(), + } +}) + +const mockGetRedisClient = redisConfigMockFns.mockGetRedisClient +const mockCreateSuccessResponse = workflowsApiUtilsMockFns.mockCreateSuccessResponse +const mockCreateErrorResponse = workflowsApiUtilsMockFns.mockCreateErrorResponse + +vi.mock('@/lib/core/config/redis', () => redisConfigMock) + +vi.mock('@sim/db', () => ({ + db: { + select: mockDbSelect, + insert: mockDbInsert, + delete: mockDbDelete, + update: mockDbUpdate, + transaction: vi.fn(async (callback: (tx: Record) => unknown) => { + return callback({ + select: mockDbSelect, + insert: mockDbInsert, + delete: mockDbDelete, + update: mockDbUpdate, + }) + }), + }, +})) + +vi.mock('drizzle-orm', () => ({ + eq: vi.fn((field: string, value: string) => ({ field, value, type: 'eq' })), + and: vi.fn((...conditions: unknown[]) => ({ conditions, type: 'and' })), + gt: vi.fn((field: string, value: string) => ({ field, value, type: 'gt' })), + lt: vi.fn((field: string, value: string) => ({ field, value, type: 'lt' })), + isNull: vi.fn((field: unknown) => ({ field, type: 'isNull' })), +})) + +vi.mock('@/lib/core/storage', () => ({ + getStorageMethod: mockGetStorageMethod, +})) + +const { mockCheckRateLimitDirect } = vi.hoisted(() => ({ + mockCheckRateLimitDirect: vi.fn(), +})) + +vi.mock('@/lib/core/rate-limiter', () => ({ + RateLimiter: class { + checkRateLimitDirect = mockCheckRateLimitDirect + }, +})) + +vi.mock('@/lib/messaging/email/mailer', () => ({ + sendEmail: mockSendEmail, +})) + +vi.mock('@/components/emails', () => ({ + renderOTPEmail: mockRenderOTPEmail, +})) + +vi.mock('@/lib/core/security/deployment', () => ({ + addCorsHeaders: mockAddCorsHeaders, + isEmailAllowed: (email: string, allowedEmails: string[]) => { + if (allowedEmails.includes(email)) return true + const atIndex = email.indexOf('@') + if (atIndex > 0) { + const domain = email.substring(atIndex + 1) + if (domain && allowedEmails.some((allowed: string) => allowed === `@${domain}`)) return true + } + return false + }, +})) + +vi.mock('@/app/api/form/utils', () => ({ + setFormAuthCookie: mockSetFormAuthCookie, +})) + +vi.mock('@/app/api/workflows/utils', () => workflowsApiUtilsMock) + +vi.mock('@/lib/core/config/env', () => ({ + env: { + NEXT_PUBLIC_APP_URL: 'http://localhost:3000', + NODE_ENV: 'test', + }, + getEnv: mockGetEnv, + isTruthy: vi.fn().mockReturnValue(false), + isFalsy: vi.fn().mockReturnValue(true), +})) + +vi.mock('zod', () => { + class ZodError extends Error { + errors: Array<{ message: string }> + constructor(issues: Array<{ message: string }>) { + super('ZodError') + this.errors = issues + } + } + const chainable: Record = {} + const proxy: Record = new Proxy(chainable, { + get(target, prop) { + if (prop === 'parse') return mockZodParse + if (prop === 'safeParse') { + return (data: unknown) => ({ success: true, data }) + } + if (prop === 'then') return undefined + if (typeof prop === 'symbol') return Reflect.get(target, prop) + if (!(prop in target)) { + target[prop as string] = vi.fn().mockReturnValue(proxy) + } + return target[prop as string] + }, + }) + const makeChain = vi.fn(() => proxy) + return { + z: new Proxy( + { ZodError }, + { + get(target, prop) { + if (prop === 'ZodError') return ZodError + if (typeof prop === 'symbol') return Reflect.get(target, prop) + return makeChain + }, + } + ), + } +}) + +import { POST, PUT } from './route' + +describe('Form OTP API Route', () => { + const mockEmail = 'user@example.com' + const mockFormId = 'form-123' + const mockIdentifier = 'test-form' + const mockOTP = '123456' + + const deploymentRow = { + id: mockFormId, + authType: 'email', + allowedEmails: [mockEmail], + title: 'Test Form', + isActive: true, + } + + const verifyDeploymentRow = { + id: mockFormId, + authType: 'email', + password: null, + allowedEmails: [mockEmail], + isActive: true, + } + + const selectOnce = (rows: unknown[]) => + mockDbSelect.mockImplementationOnce(() => ({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue(rows), + }), + }), + })) + + beforeEach(() => { + vi.clearAllMocks() + + vi.spyOn(Math, 'random').mockReturnValue(0.123456) + vi.spyOn(Date, 'now').mockReturnValue(1640995200000) + + vi.stubGlobal('crypto', { + ...crypto, + randomUUID: vi.fn().mockReturnValue('test-uuid-1234'), + }) + + mockGetRedisClient.mockReturnValue(mockRedisClient) + mockRedisSet.mockResolvedValue('OK') + mockRedisGet.mockResolvedValue(null) + mockRedisDel.mockResolvedValue(1) + mockRedisTtl.mockResolvedValue(600) + + mockDbSelect.mockImplementation(() => ({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([]), + }), + }), + })) + mockDbInsert.mockImplementation(() => ({ values: vi.fn().mockResolvedValue(undefined) })) + mockDbDelete.mockImplementation(() => ({ where: vi.fn().mockResolvedValue(undefined) })) + mockDbUpdate.mockImplementation(() => ({ + set: vi.fn().mockReturnValue({ where: vi.fn().mockResolvedValue(undefined) }), + })) + + mockGetStorageMethod.mockReturnValue('redis') + + mockSendEmail.mockResolvedValue({ success: true }) + mockRenderOTPEmail.mockResolvedValue('OTP Email') + + mockAddCorsHeaders.mockImplementation((response: unknown) => response) + mockCreateSuccessResponse.mockImplementation((data: unknown) => ({ + json: () => Promise.resolve(data), + status: 200, + })) + mockCreateErrorResponse.mockImplementation((message: string, status: number) => ({ + json: () => Promise.resolve({ error: message }), + status, + })) + + requestUtilsMockFns.mockGenerateRequestId.mockReturnValue('req-123') + requestUtilsMockFns.mockGetClientIp.mockReturnValue('1.2.3.4') + + mockCheckRateLimitDirect.mockResolvedValue({ + allowed: true, + remaining: 10, + resetAt: new Date(Date.now() + 60_000), + }) + + mockZodParse.mockImplementation((data: unknown) => data) + mockGetEnv.mockReturnValue('http://localhost:3000') + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + describe('POST /otp - request code', () => { + it('stores OTP in Redis when storage is redis and sends email', async () => { + selectOnce([deploymentRow]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockRedisSet).toHaveBeenCalledWith( + `form-otp:${mockEmail}:${mockFormId}`, + expect.stringMatching(/^\d{6}:0$/), + 'EX', + 900 + ) + expect(mockSendEmail).toHaveBeenCalledWith( + expect.objectContaining({ to: mockEmail, subject: expect.stringContaining('Test Form') }) + ) + expect(mockDbInsert).not.toHaveBeenCalled() + }) + + it('stores OTP in database when storage is database', async () => { + mockGetStorageMethod.mockReturnValue('database') + mockGetRedisClient.mockReturnValue(null) + selectOnce([deploymentRow]) + const insertValues = vi.fn().mockResolvedValue(undefined) + mockDbInsert.mockImplementationOnce(() => ({ values: insertValues })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(insertValues).toHaveBeenCalledWith( + expect.objectContaining({ + identifier: `form-otp:${mockFormId}:${mockEmail}`, + value: expect.stringMatching(/^\d{6}:0$/), + }) + ) + expect(mockRedisSet).not.toHaveBeenCalled() + }) + + it('returns 404 when form is not found', async () => { + selectOnce([]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith('Form not found', 404) + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('returns 403 when form is inactive', async () => { + selectOnce([{ ...deploymentRow, isActive: false }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'This form is currently unavailable', + 403 + ) + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('returns 400 when form authType is not email', async () => { + selectOnce([{ ...deploymentRow, authType: 'public' }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'This form does not use email authentication', + 400 + ) + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('returns 403 when email is not in allowedEmails', async () => { + selectOnce([{ ...deploymentRow, allowedEmails: ['other@example.com'] }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'Email not authorized for this form', + 403 + ) + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('authorizes by domain match in allowedEmails', async () => { + selectOnce([{ ...deploymentRow, allowedEmails: ['@example.com'] }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockSendEmail).toHaveBeenCalled() + }) + + it('returns 429 with Retry-After when IP rate limit is exceeded', async () => { + mockCheckRateLimitDirect.mockResolvedValueOnce({ + allowed: false, + remaining: 0, + resetAt: new Date(Date.now() + 900_000), + retryAfterMs: 900_000, + }) + const headerSet = vi.fn() + mockCreateErrorResponse.mockImplementationOnce((message: string, status: number) => ({ + json: () => Promise.resolve({ error: message }), + status, + headers: { set: headerSet }, + })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + const response = await POST(request, { + params: Promise.resolve({ identifier: mockIdentifier }), + }) + + expect(response.status).toBe(429) + expect(headerSet).toHaveBeenCalledWith('Retry-After', '900') + expect(mockSendEmail).not.toHaveBeenCalled() + expect(mockDbSelect).not.toHaveBeenCalled() + }) + + it('returns 429 with Retry-After when email rate limit is exceeded', async () => { + mockCheckRateLimitDirect + .mockResolvedValueOnce({ + allowed: true, + remaining: 9, + resetAt: new Date(Date.now() + 60_000), + }) + .mockResolvedValueOnce({ + allowed: false, + remaining: 0, + resetAt: new Date(Date.now() + 900_000), + retryAfterMs: 900_000, + }) + const headerSet = vi.fn() + mockCreateErrorResponse.mockImplementationOnce((message: string, status: number) => ({ + json: () => Promise.resolve({ error: message }), + status, + headers: { set: headerSet }, + })) + selectOnce([deploymentRow]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + const response = await POST(request, { + params: Promise.resolve({ identifier: mockIdentifier }), + }) + + expect(response.status).toBe(429) + expect(headerSet).toHaveBeenCalledWith('Retry-After', '900') + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('rate-limits the IP bucket before reading the deployment row', async () => { + mockCheckRateLimitDirect.mockResolvedValueOnce({ + allowed: false, + remaining: 0, + resetAt: new Date(Date.now() + 900_000), + retryAfterMs: 900_000, + }) + mockCreateErrorResponse.mockImplementationOnce((message: string, status: number) => ({ + json: () => Promise.resolve({ error: message }), + status, + headers: { set: vi.fn() }, + })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockDbSelect).not.toHaveBeenCalled() + }) + + it('returns 500 when email send fails', async () => { + selectOnce([deploymentRow]) + mockSendEmail.mockResolvedValueOnce({ success: false, message: 'smtp down' }) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith('Failed to send verification email', 500) + }) + }) + + describe('PUT /otp - verify code', () => { + it('verifies OTP, deletes it, and sets the form auth cookie on success', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue(`${mockOTP}:0`) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockRedisGet).toHaveBeenCalledWith(`form-otp:${mockEmail}:${mockFormId}`) + expect(mockRedisDel).toHaveBeenCalledWith(`form-otp:${mockEmail}:${mockFormId}`) + expect(mockSetFormAuthCookie).toHaveBeenCalledWith( + expect.any(Object), + mockFormId, + 'email', + null + ) + expect(mockCreateSuccessResponse).toHaveBeenCalledWith({ authenticated: true }) + }) + + it('returns 404 when form is not found', async () => { + selectOnce([]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith('Form not found', 404) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('returns 403 when form is inactive at verify time', async () => { + selectOnce([{ ...verifyDeploymentRow, isActive: false }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'This form is currently unavailable', + 403 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('returns 403 when email is no longer in allowedEmails at verify time', async () => { + selectOnce([{ ...verifyDeploymentRow, allowedEmails: ['other@example.com'] }]) + mockRedisGet.mockResolvedValue(`${mockOTP}:0`) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'Email not authorized for this form', + 403 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('returns 400 when no OTP is stored', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue(null) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'No verification code found, request a new one', + 400 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('atomically increments attempts on wrong OTP and returns 400', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue('654321:0') + mockRedisEval.mockResolvedValue('654321:1') + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: 'wrong1' }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockRedisEval).toHaveBeenCalledWith( + expect.any(String), + 1, + `form-otp:${mockEmail}:${mockFormId}`, + 5 + ) + expect(mockCreateErrorResponse).toHaveBeenCalledWith('Invalid verification code', 400) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('invalidates OTP and returns 429 after max failed attempts', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue('654321:4') + mockRedisEval.mockResolvedValue('LOCKED') + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: 'wrong5' }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'Too many failed attempts. Please request a new code.', + 429 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('rejects when stored OTP is already at max attempts', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue(`${mockOTP}:5`) + const deleteWhere = vi.fn().mockResolvedValue(undefined) + mockDbDelete.mockImplementation(() => ({ where: deleteWhere })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'Too many failed attempts. Please request a new code.', + 429 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('uses database storage path when configured', async () => { + mockGetStorageMethod.mockReturnValue('database') + mockGetRedisClient.mockReturnValue(null) + let selectCallCount = 0 + mockDbSelect.mockImplementation(() => ({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockImplementation(() => { + selectCallCount++ + if (selectCallCount === 1) return Promise.resolve([verifyDeploymentRow]) + return Promise.resolve([ + { + value: `${mockOTP}:0`, + expiresAt: new Date(Date.now() + 10 * 60 * 1000), + }, + ]) + }), + }), + }), + })) + const deleteWhere = vi.fn().mockResolvedValue(undefined) + mockDbDelete.mockImplementation(() => ({ where: deleteWhere })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockDbDelete).toHaveBeenCalled() + expect(mockRedisDel).not.toHaveBeenCalled() + expect(mockSetFormAuthCookie).toHaveBeenCalled() + }) + }) +}) diff --git a/apps/sim/app/api/form/[identifier]/otp/route.ts b/apps/sim/app/api/form/[identifier]/otp/route.ts new file mode 100644 index 00000000000..0d9804efa55 --- /dev/null +++ b/apps/sim/app/api/form/[identifier]/otp/route.ts @@ -0,0 +1,261 @@ +import { db } from '@sim/db' +import { form } from '@sim/db/schema' +import { createLogger } from '@sim/logger' +import { and, eq, isNull } from 'drizzle-orm' +import type { NextRequest } from 'next/server' +import { renderOTPEmail } from '@/components/emails' +import { requestFormEmailOtpContract, verifyFormEmailOtpContract } from '@/lib/api/contracts/forms' +import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' +import { RateLimiter } from '@/lib/core/rate-limiter' +import { addCorsHeaders, isEmailAllowed } from '@/lib/core/security/deployment' +import { + decodeOTPValue, + deleteOTP, + generateOTP, + getOTP, + incrementOTPAttempts, + MAX_OTP_ATTEMPTS, + OTP_EMAIL_RATE_LIMIT, + OTP_IP_RATE_LIMIT, + storeOTP, +} from '@/lib/core/security/otp' +import { generateRequestId, getClientIp } from '@/lib/core/utils/request' +import { withRouteHandler } from '@/lib/core/utils/with-route-handler' +import { sendEmail } from '@/lib/messaging/email/mailer' +import { setFormAuthCookie } from '@/app/api/form/utils' +import { createErrorResponse, createSuccessResponse } from '@/app/api/workflows/utils' + +const logger = createLogger('FormOtpAPI') + +const rateLimiter = new RateLimiter() + +export const POST = withRouteHandler( + async (request: NextRequest, context: { params: Promise<{ identifier: string }> }) => { + const { identifier } = await context.params + const requestId = generateRequestId() + + try { + const ip = getClientIp(request) + const ipRateLimit = await rateLimiter.checkRateLimitDirect( + `form-otp:ip:${identifier}:${ip}`, + OTP_IP_RATE_LIMIT + ) + if (!ipRateLimit.allowed) { + logger.warn(`[${requestId}] OTP IP rate limit exceeded for ${identifier} from ${ip}`) + const retryAfter = Math.ceil( + (ipRateLimit.retryAfterMs ?? OTP_IP_RATE_LIMIT.refillIntervalMs) / 1000 + ) + const response = createErrorResponse('Too many requests. Please try again later.', 429) + response.headers.set('Retry-After', String(retryAfter)) + return addCorsHeaders(response, request) + } + + const parsed = await parseRequest(requestFormEmailOtpContract, request, context, { + validationErrorResponse: (error) => + addCorsHeaders( + createErrorResponse(getValidationErrorMessage(error, 'Invalid request'), 400), + request + ), + }) + if (!parsed.success) return parsed.response + const { email } = parsed.data.body + + const deploymentResult = await db + .select({ + id: form.id, + authType: form.authType, + allowedEmails: form.allowedEmails, + title: form.title, + isActive: form.isActive, + }) + .from(form) + .where(and(eq(form.identifier, identifier), isNull(form.archivedAt))) + .limit(1) + + if (deploymentResult.length === 0) { + logger.warn(`[${requestId}] Form not found for identifier: ${identifier}`) + return addCorsHeaders(createErrorResponse('Form not found', 404), request) + } + + const deployment = deploymentResult[0] + + if (!deployment.isActive) { + return addCorsHeaders( + createErrorResponse('This form is currently unavailable', 403), + request + ) + } + + if (deployment.authType !== 'email') { + return addCorsHeaders( + createErrorResponse('This form does not use email authentication', 400), + request + ) + } + + const allowedEmails: string[] = Array.isArray(deployment.allowedEmails) + ? (deployment.allowedEmails as string[]) + : [] + + if (!isEmailAllowed(email, allowedEmails)) { + return addCorsHeaders( + createErrorResponse('Email not authorized for this form', 403), + request + ) + } + + const emailRateLimit = await rateLimiter.checkRateLimitDirect( + `form-otp:email:${deployment.id}:${email.toLowerCase()}`, + OTP_EMAIL_RATE_LIMIT + ) + if (!emailRateLimit.allowed) { + logger.warn( + `[${requestId}] OTP email rate limit exceeded for ${email} on form ${deployment.id}` + ) + const retryAfter = Math.ceil( + (emailRateLimit.retryAfterMs ?? OTP_EMAIL_RATE_LIMIT.refillIntervalMs) / 1000 + ) + const response = createErrorResponse( + 'Too many verification code requests. Please try again later.', + 429 + ) + response.headers.set('Retry-After', String(retryAfter)) + return addCorsHeaders(response, request) + } + + const otp = generateOTP() + await storeOTP('form', deployment.id, email, otp) + + const emailHtml = await renderOTPEmail( + otp, + email, + 'email-verification', + deployment.title || 'Form' + ) + + const emailResult = await sendEmail({ + to: email, + subject: `Verification code for ${deployment.title || 'Form'}`, + html: emailHtml, + }) + + if (!emailResult.success) { + logger.error(`[${requestId}] Failed to send OTP email:`, emailResult.message) + return addCorsHeaders( + createErrorResponse('Failed to send verification email', 500), + request + ) + } + + logger.info(`[${requestId}] OTP sent to ${email} for form ${deployment.id}`) + return addCorsHeaders(createSuccessResponse({ message: 'Verification code sent' }), request) + } catch (error) { + logger.error(`[${requestId}] Error processing OTP request:`, error) + return addCorsHeaders(createErrorResponse('Failed to process request', 500), request) + } + } +) + +export const PUT = withRouteHandler( + async (request: NextRequest, context: { params: Promise<{ identifier: string }> }) => { + const { identifier } = await context.params + const requestId = generateRequestId() + + try { + const parsed = await parseRequest(verifyFormEmailOtpContract, request, context, { + validationErrorResponse: (error) => + addCorsHeaders( + createErrorResponse(getValidationErrorMessage(error, 'Invalid request'), 400), + request + ), + }) + if (!parsed.success) return parsed.response + const { email, otp } = parsed.data.body + + const deploymentResult = await db + .select({ + id: form.id, + authType: form.authType, + password: form.password, + allowedEmails: form.allowedEmails, + isActive: form.isActive, + }) + .from(form) + .where(and(eq(form.identifier, identifier), isNull(form.archivedAt))) + .limit(1) + + if (deploymentResult.length === 0) { + logger.warn(`[${requestId}] Form not found for identifier: ${identifier}`) + return addCorsHeaders(createErrorResponse('Form not found', 404), request) + } + + const deployment = deploymentResult[0] + + if (!deployment.isActive) { + return addCorsHeaders( + createErrorResponse('This form is currently unavailable', 403), + request + ) + } + + if (deployment.authType !== 'email') { + return addCorsHeaders( + createErrorResponse('This form does not use email authentication', 400), + request + ) + } + + const allowedEmails: string[] = Array.isArray(deployment.allowedEmails) + ? (deployment.allowedEmails as string[]) + : [] + + if (!isEmailAllowed(email, allowedEmails)) { + return addCorsHeaders( + createErrorResponse('Email not authorized for this form', 403), + request + ) + } + + const storedValue = await getOTP('form', deployment.id, email) + if (!storedValue) { + return addCorsHeaders( + createErrorResponse('No verification code found, request a new one', 400), + request + ) + } + + const { otp: storedOTP, attempts } = decodeOTPValue(storedValue) + + if (attempts >= MAX_OTP_ATTEMPTS) { + await deleteOTP('form', deployment.id, email) + logger.warn(`[${requestId}] OTP already locked out for ${email}`) + return addCorsHeaders( + createErrorResponse('Too many failed attempts. Please request a new code.', 429), + request + ) + } + + if (storedOTP !== otp) { + const result = await incrementOTPAttempts('form', deployment.id, email, storedValue) + if (result === 'locked') { + logger.warn(`[${requestId}] OTP invalidated after max failed attempts for ${email}`) + return addCorsHeaders( + createErrorResponse('Too many failed attempts. Please request a new code.', 429), + request + ) + } + return addCorsHeaders(createErrorResponse('Invalid verification code', 400), request) + } + + await deleteOTP('form', deployment.id, email) + + const response = addCorsHeaders(createSuccessResponse({ authenticated: true }), request) + setFormAuthCookie(response, deployment.id, deployment.authType, deployment.password) + + return response + } catch (error) { + logger.error(`[${requestId}] Error verifying OTP:`, error) + return addCorsHeaders(createErrorResponse('Failed to process request', 500), request) + } + } +) diff --git a/apps/sim/app/api/form/utils.test.ts b/apps/sim/app/api/form/utils.test.ts index 9c36ccc6e92..1826d9386c1 100644 --- a/apps/sim/app/api/form/utils.test.ts +++ b/apps/sim/app/api/form/utils.test.ts @@ -239,18 +239,20 @@ describe('Form API Utils', () => { }, } as any - // Exact email match should authorize + // Exact email match should require OTP verification, not authorize directly mockIsEmailAllowed.mockReturnValue(true) const result1 = await validateFormAuth('request-id', deployment, mockRequest, { email: 'user@example.com', }) - expect(result1.authorized).toBe(true) + expect(result1.authorized).toBe(false) + expect(result1.error).toBe('otp_required') - // Domain match should authorize + // Domain match should also require OTP verification const result2 = await validateFormAuth('request-id', deployment, mockRequest, { email: 'other@company.com', }) - expect(result2.authorized).toBe(true) + expect(result2.authorized).toBe(false) + expect(result2.error).toBe('otp_required') // Unknown email should not authorize mockIsEmailAllowed.mockReturnValue(false) diff --git a/apps/sim/app/api/form/utils.ts b/apps/sim/app/api/form/utils.ts index 55bbe65e17f..7b1f1df54dc 100644 --- a/apps/sim/app/api/form/utils.ts +++ b/apps/sim/app/api/form/utils.ts @@ -159,7 +159,7 @@ export async function validateFormAuth( const allowedEmails: string[] = deployment.allowedEmails || [] if (isEmailAllowed(email, allowedEmails)) { - return { authorized: true } + return { authorized: false, error: 'otp_required' } } return { authorized: false, error: 'Email not authorized for this form' } diff --git a/apps/sim/app/api/knowledge/[id]/documents/upsert/route.test.ts b/apps/sim/app/api/knowledge/[id]/documents/upsert/route.test.ts new file mode 100644 index 00000000000..d5f64cf306e --- /dev/null +++ b/apps/sim/app/api/knowledge/[id]/documents/upsert/route.test.ts @@ -0,0 +1,111 @@ +/** + * Tests for knowledge base document upsert API route + * + * @vitest-environment node + */ +import { + auditMock, + createMockRequest, + hybridAuthMock, + hybridAuthMockFns, + knowledgeApiUtilsMock, +} from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockDbChain } = vi.hoisted(() => { + const chain = { + select: vi.fn().mockReturnThis(), + from: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + limit: vi.fn().mockResolvedValue([]), + } + return { mockDbChain: chain } +}) + +vi.mock('@sim/db', () => ({ db: mockDbChain })) +vi.mock('@/lib/auth/hybrid', () => hybridAuthMock) +vi.mock('@/app/api/knowledge/utils', () => knowledgeApiUtilsMock) +vi.mock('@sim/audit', () => auditMock) + +vi.mock('@/lib/knowledge/documents/service', () => ({ + createDocumentRecords: vi.fn(), + deleteDocument: vi.fn(), + getProcessingConfig: vi.fn().mockReturnValue({ maxConcurrentDocuments: 1, batchSize: 1 }), + processDocumentsWithQueue: vi.fn(), +})) + +import { createDocumentRecords, processDocumentsWithQueue } from '@/lib/knowledge/documents/service' +import { POST } from '@/app/api/knowledge/[id]/documents/upsert/route' +import { checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils' + +describe('POST /api/knowledge/[id]/documents/upsert', () => { + const params = Promise.resolve({ id: 'kb-123' }) + + beforeEach(() => { + vi.clearAllMocks() + mockDbChain.select.mockReturnThis() + mockDbChain.from.mockReturnThis() + mockDbChain.where.mockReturnThis() + mockDbChain.limit.mockResolvedValue([]) + + hybridAuthMockFns.mockCheckSessionOrInternalAuth.mockResolvedValue({ + success: true, + userId: 'user-1', + authType: 'session', + userName: 'Test User', + userEmail: 'test@example.com', + }) + + vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({ + hasAccess: true, + knowledgeBase: { id: 'kb-123', userId: 'user-1', workspaceId: 'ws-1', name: 'KB' }, + } as any) + + vi.mocked(createDocumentRecords).mockResolvedValue([ + { documentId: 'doc-new', filename: 'note.txt' }, + ] as any) + vi.mocked(processDocumentsWithQueue).mockResolvedValue(undefined as any) + }) + + const baseBody = { + filename: 'note.txt', + fileSize: 11, + mimeType: 'text/plain', + } + + it('accepts a data: URI', async () => { + const req = createMockRequest('POST', { + ...baseBody, + fileUrl: 'data:text/plain;base64,SGVsbG8gd29ybGQ=', + }) + const res = await POST(req, { params }) + expect(res.status).toBe(200) + expect(createDocumentRecords).toHaveBeenCalled() + }) + + it('accepts an https URL', async () => { + const req = createMockRequest('POST', { + ...baseBody, + fileUrl: 'https://example.com/note.txt', + }) + const res = await POST(req, { params }) + expect(res.status).toBe(200) + expect(createDocumentRecords).toHaveBeenCalled() + }) + + it.each([ + ['absolute local path', '/etc/passwd'], + ['app config path', '/app/.env'], + ['file:// URL', 'file:///etc/passwd'], + ['relative serve path', '/api/files/serve/kb/foo.pdf'], + ['ftp URL', 'ftp://example.com/file.pdf'], + ['parent traversal', '../../etc/passwd'], + ['windows path', 'C:\\Windows\\System32\\config\\SAM'], + ])('rejects %s with 400 and never invokes the pipeline', async (_label, fileUrl) => { + const req = createMockRequest('POST', { ...baseBody, fileUrl }) + const res = await POST(req, { params }) + expect(res.status).toBe(400) + expect(createDocumentRecords).not.toHaveBeenCalled() + expect(processDocumentsWithQueue).not.toHaveBeenCalled() + }) +}) diff --git a/apps/sim/app/api/knowledge/[id]/route.test.ts b/apps/sim/app/api/knowledge/[id]/route.test.ts index 111e42829bd..2d0dc1ce2e2 100644 --- a/apps/sim/app/api/knowledge/[id]/route.test.ts +++ b/apps/sim/app/api/knowledge/[id]/route.test.ts @@ -31,6 +31,7 @@ vi.mock('@/lib/knowledge/service', async (importOriginal) => { getKnowledgeBaseById: vi.fn(), updateKnowledgeBase: vi.fn(), deleteKnowledgeBase: vi.fn(), + KnowledgeBasePermissionError: actual.KnowledgeBasePermissionError, } }) @@ -39,6 +40,7 @@ vi.mock('@/app/api/knowledge/utils', () => knowledgeApiUtilsMock) import { deleteKnowledgeBase, getKnowledgeBaseById, + KnowledgeBasePermissionError, updateKnowledgeBase, } from '@/lib/knowledge/service' import { DELETE, GET, PUT } from '@/app/api/knowledge/[id]/route' @@ -229,10 +231,59 @@ describe('Knowledge Base By ID API Route', () => { workspaceId: undefined, chunkingConfig: undefined, }, - expect.any(String) + expect.any(String), + { actorUserId: 'user-123' } ) }) + it('returns 403 when service rejects a cross-workspace transfer', async () => { + authMockFns.mockGetSession.mockResolvedValue({ + user: { id: 'attacker', email: 'a@example.com' }, + }) + + resetMocks() + + vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValueOnce({ + hasAccess: true, + knowledgeBase: { id: 'kb-123', userId: 'user-123', workspaceId: 'ws-current' }, + }) + + vi.mocked(updateKnowledgeBase).mockRejectedValueOnce( + new KnowledgeBasePermissionError('User does not have permission on the target workspace') + ) + + const req = createMockRequest('PUT', { workspaceId: 'ws-target' }) + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(403) + expect(data.error).toBe('User does not have permission on the target workspace') + }) + + it('returns 403 when service rejects clearing workspaceId', async () => { + authMockFns.mockGetSession.mockResolvedValue({ + user: { id: 'user-123', email: 'test@example.com' }, + }) + + resetMocks() + + vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValueOnce({ + hasAccess: true, + knowledgeBase: { id: 'kb-123', userId: 'user-123', workspaceId: 'ws-current' }, + }) + + vi.mocked(updateKnowledgeBase).mockRejectedValueOnce( + new KnowledgeBasePermissionError('Knowledge base workspace cannot be cleared') + ) + + const req = createMockRequest('PUT', { workspaceId: null }) + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(403) + expect(data.error).toBe('Knowledge base workspace cannot be cleared') + }) + it('should return unauthorized for unauthenticated user', async () => { authMockFns.mockGetSession.mockResolvedValue(null) diff --git a/apps/sim/app/api/knowledge/[id]/route.ts b/apps/sim/app/api/knowledge/[id]/route.ts index 5013a8c29c7..34a85d1a684 100644 --- a/apps/sim/app/api/knowledge/[id]/route.ts +++ b/apps/sim/app/api/knowledge/[id]/route.ts @@ -11,6 +11,7 @@ import { deleteKnowledgeBase, getKnowledgeBaseById, KnowledgeBaseConflictError, + KnowledgeBasePermissionError, updateKnowledgeBase, } from '@/lib/knowledge/service' import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils' @@ -101,7 +102,8 @@ export const PUT = withRouteHandler( workspaceId: validatedData.workspaceId, chunkingConfig: validatedData.chunkingConfig, }, - requestId + requestId, + { actorUserId: userId } ) logger.info(`[${requestId}] Knowledge base updated: ${id} for user ${userId}`) @@ -141,6 +143,10 @@ export const PUT = withRouteHandler( if (error instanceof KnowledgeBaseConflictError) { return NextResponse.json({ error: error.message }, { status: 409 }) } + if (error instanceof KnowledgeBasePermissionError) { + logger.warn(`[${requestId}] Forbidden knowledge base update on ${id}: ${error.message}`) + return NextResponse.json({ error: error.message }, { status: 403 }) + } logger.error(`[${requestId}] Error updating knowledge base`, error) return NextResponse.json({ error: 'Failed to update knowledge base' }, { status: 500 }) diff --git a/apps/sim/app/api/knowledge/route.test.ts b/apps/sim/app/api/knowledge/route.test.ts index bc0ab08d755..4ad2aad2acf 100644 --- a/apps/sim/app/api/knowledge/route.test.ts +++ b/apps/sim/app/api/knowledge/route.test.ts @@ -155,6 +155,23 @@ describe('Knowledge Base API Route', () => { expect(data.details).toBeDefined() }) + it('returns 403 when user lacks permission on target workspace', async () => { + authMockFns.mockGetSession.mockResolvedValue({ + user: { id: 'attacker', email: 'a@example.com' }, + }) + permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValueOnce('read') + + const req = createMockRequest('POST', validKnowledgeBaseData) + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(403) + expect(data.error).toBe( + 'User does not have permission to create knowledge bases in this workspace' + ) + expect(mockDbChain.insert).not.toHaveBeenCalled() + }) + it('should validate chunking config constraints', async () => { authMockFns.mockGetSession.mockResolvedValue({ user: { id: 'user-123', email: 'test@example.com' }, diff --git a/apps/sim/app/api/knowledge/route.ts b/apps/sim/app/api/knowledge/route.ts index 8cea52b8eb7..e14efd14656 100644 --- a/apps/sim/app/api/knowledge/route.ts +++ b/apps/sim/app/api/knowledge/route.ts @@ -15,6 +15,7 @@ import { createKnowledgeBase, getKnowledgeBases, KnowledgeBaseConflictError, + KnowledgeBasePermissionError, type KnowledgeBaseScope, } from '@/lib/knowledge/service' import { captureServerEvent } from '@/lib/posthog/server' @@ -159,6 +160,10 @@ export const POST = withRouteHandler(async (req: NextRequest) => { if (createError instanceof KnowledgeBaseConflictError) { return NextResponse.json({ error: createError.message }, { status: 409 }) } + if (createError instanceof KnowledgeBasePermissionError) { + logger.warn(`[${requestId}] Forbidden knowledge base creation: ${createError.message}`) + return NextResponse.json({ error: createError.message }, { status: 403 }) + } throw createError } } catch (error) { diff --git a/apps/sim/app/api/mcp/servers/test-connection/route.ts b/apps/sim/app/api/mcp/servers/test-connection/route.ts index 46ef05fc2bd..bd88be77aad 100644 --- a/apps/sim/app/api/mcp/servers/test-connection/route.ts +++ b/apps/sim/app/api/mcp/servers/test-connection/route.ts @@ -95,6 +95,9 @@ export const POST = withRouteHandler( } try { + // Initial pre-resolution check; the authoritative resolved IP is + // captured after env-var resolution below and used to pin the + // connection against DNS rebinding. await validateMcpServerSsrf(body.url) } catch (e) { if (e instanceof McpDnsResolutionError) { @@ -140,8 +143,9 @@ export const POST = withRouteHandler( throw e } + let resolvedIP: string | null try { - await validateMcpServerSsrf(testConfig.url) + resolvedIP = await validateMcpServerSsrf(testConfig.url) } catch (e) { if (e instanceof McpDnsResolutionError) { return createMcpErrorResponse(e, e.message, 502) @@ -162,7 +166,11 @@ export const POST = withRouteHandler( let client: McpClient | null = null try { - client = new McpClient(testConfig, testSecurityPolicy) + client = new McpClient({ + config: testConfig, + securityPolicy: testSecurityPolicy, + resolvedIP: resolvedIP ?? undefined, + }) await client.connect() result.negotiatedVersion = client.getNegotiatedVersion() diff --git a/apps/sim/app/api/tools/agiloft/attach/route.test.ts b/apps/sim/app/api/tools/agiloft/attach/route.test.ts new file mode 100644 index 00000000000..f1e4c8c4264 --- /dev/null +++ b/apps/sim/app/api/tools/agiloft/attach/route.test.ts @@ -0,0 +1,144 @@ +/** + * @vitest-environment node + */ +import { + createMockRequest, + hybridAuthMockFns, + inputValidationMock, + inputValidationMockFns, +} from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockProcessFilesToUserFiles, mockDownloadFileFromStorage, mockAssertToolFileAccess } = + vi.hoisted(() => ({ + mockProcessFilesToUserFiles: vi.fn(), + mockDownloadFileFromStorage: vi.fn(), + mockAssertToolFileAccess: vi.fn(), + })) + +vi.mock('@/lib/core/security/input-validation.server', () => inputValidationMock) +vi.mock('@/lib/uploads/utils/file-utils', () => ({ + processFilesToUserFiles: mockProcessFilesToUserFiles, +})) +vi.mock('@/lib/uploads/utils/file-utils.server', () => ({ + downloadFileFromStorage: mockDownloadFileFromStorage, +})) +vi.mock('@/app/api/files/authorization', () => ({ + assertToolFileAccess: mockAssertToolFileAccess, +})) + +import { POST } from '@/app/api/tools/agiloft/attach/route' + +const PINNED_IP = '93.184.216.34' + +const baseBody = { + instanceUrl: 'https://example.agiloft.com', + knowledgeBase: 'demo', + login: 'admin', + password: 'secret', + table: 'contracts', + recordId: '42', + fieldName: 'attachments', + file: { key: 's3://bucket/file.txt', name: 'file.txt', size: 5, type: 'text/plain' }, + fileName: 'file.txt', +} + +function mockSecureFetchResponse(body: { + ok?: boolean + status?: number + json?: unknown + text?: string +}) { + return { + ok: body.ok ?? true, + status: body.status ?? 200, + statusText: '', + headers: new Headers(), + body: null, + text: async () => body.text ?? '', + json: async () => body.json ?? {}, + arrayBuffer: async () => new ArrayBuffer(0), + } +} + +beforeEach(() => { + vi.clearAllMocks() + hybridAuthMockFns.mockCheckInternalAuth.mockResolvedValue({ + success: true, + userId: 'user-1', + authType: 'internal_jwt', + }) + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValue({ + isValid: true, + resolvedIP: PINNED_IP, + originalHostname: 'example.agiloft.com', + }) + mockProcessFilesToUserFiles.mockReturnValue([ + { key: 's3://bucket/file.txt', name: 'file.txt', size: 5, type: 'text/plain' }, + ]) + mockAssertToolFileAccess.mockResolvedValue(null) + mockDownloadFileFromStorage.mockResolvedValue(Buffer.from('hello')) +}) + +describe('POST /api/tools/agiloft/attach', () => { + it('rejects unauthenticated requests', async () => { + hybridAuthMockFns.mockCheckInternalAuth.mockResolvedValueOnce({ + success: false, + error: 'unauthorized', + }) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(401) + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + }) + + it('blocks SSRF when the instance URL fails DNS validation', async () => { + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValueOnce({ + isValid: false, + error: 'instanceUrl resolves to a blocked IP address', + }) + + const response = await POST( + createMockRequest('POST', { ...baseBody, instanceUrl: 'https://attacker.example.com' }) + ) + + expect(response.status).toBe(400) + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + }) + + it('pins the resolved IP for login, attach, and logout (TOCTOU fix)', async () => { + inputValidationMockFns.mockSecureFetchWithPinnedIP + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-att' } })) + .mockResolvedValueOnce(mockSecureFetchResponse({ text: '1' })) + .mockResolvedValueOnce(mockSecureFetchResponse({})) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(200) + const data = (await response.json()) as { + success: true + output: { totalAttachments: number; fileName: string } + } + expect(data.output.totalAttachments).toBe(1) + expect(data.output.fileName).toBe('file.txt') + + const calls = inputValidationMockFns.mockSecureFetchWithPinnedIP.mock.calls + expect(calls).toHaveLength(3) + for (const call of calls) { + expect(call[1]).toBe(PINNED_IP) + } + + expect(calls[0][0]).toContain('https://example.agiloft.com/ewws/EWLogin') + expect(calls[1][0]).toContain('https://example.agiloft.com/ewws/EWAttach') + expect(calls[1][2]).toMatchObject({ + method: 'PUT', + headers: { + Authorization: 'Bearer tok-att', + 'Content-Type': 'application/octet-stream', + }, + }) + expect(calls[2][0]).toContain('https://example.agiloft.com/ewws/EWLogout') + + // DNS only resolved once. + expect(inputValidationMockFns.mockValidateUrlWithDNS).toHaveBeenCalledTimes(1) + }) +}) diff --git a/apps/sim/app/api/tools/agiloft/attach/route.ts b/apps/sim/app/api/tools/agiloft/attach/route.ts index 6257502ae4c..b0fcb351751 100644 --- a/apps/sim/app/api/tools/agiloft/attach/route.ts +++ b/apps/sim/app/api/tools/agiloft/attach/route.ts @@ -4,14 +4,19 @@ import { type NextRequest, NextResponse } from 'next/server' import { agiloftAttachContract } from '@/lib/api/contracts/tools/agiloft' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkInternalAuth } from '@/lib/auth/hybrid' -import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' +import { secureFetchWithPinnedIP } from '@/lib/core/security/input-validation.server' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import type { RawFileInput } from '@/lib/uploads/utils/file-schemas' import { processFilesToUserFiles } from '@/lib/uploads/utils/file-utils' import { downloadFileFromStorage } from '@/lib/uploads/utils/file-utils.server' import { assertToolFileAccess } from '@/app/api/files/authorization' -import { agiloftLogin, agiloftLogout, buildAttachFileUrl } from '@/tools/agiloft/utils' +import { buildAttachFileUrl } from '@/tools/agiloft/utils' +import { + agiloftLoginPinned, + agiloftLogoutPinned, + resolveAgiloftInstance, +} from '@/tools/agiloft/utils.server' export const dynamic = 'force-dynamic' @@ -72,18 +77,17 @@ export const POST = withRouteHandler(async (request: NextRequest) => { const fileBuffer = await downloadFileFromStorage(userFile, requestId, logger) const resolvedFileName = data.fileName || userFile.name || 'attachment' - const urlValidation = await validateUrlWithDNS(data.instanceUrl, 'instanceUrl') - if (!urlValidation.isValid) { + let resolvedIP: string + try { + resolvedIP = await resolveAgiloftInstance(data.instanceUrl) + } catch (error) { logger.warn(`[${requestId}] SSRF attempt blocked for Agiloft instance URL`, { instanceUrl: data.instanceUrl, }) - return NextResponse.json( - { success: false, error: urlValidation.error || 'Invalid instance URL' }, - { status: 400 } - ) + return NextResponse.json({ success: false, error: toError(error).message }, { status: 400 }) } - const token = await agiloftLogin(data) + const token = await agiloftLoginPinned(data, resolvedIP) const base = data.instanceUrl.replace(/\/$/, '') try { @@ -91,7 +95,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { logger.info(`[${requestId}] Uploading file to Agiloft: ${resolvedFileName}`) - const agiloftResponse = await fetch(url, { + const agiloftResponse = await secureFetchWithPinnedIP(url, resolvedIP, { method: 'PUT', headers: { 'Content-Type': 'application/octet-stream', @@ -135,7 +139,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { }, }) } finally { - await agiloftLogout(data.instanceUrl, data.knowledgeBase, token) + await agiloftLogoutPinned(data.instanceUrl, data.knowledgeBase, token, resolvedIP) } } catch (error) { logger.error(`[${requestId}] Error attaching file to Agiloft:`, error) diff --git a/apps/sim/app/api/tools/agiloft/retrieve/route.test.ts b/apps/sim/app/api/tools/agiloft/retrieve/route.test.ts new file mode 100644 index 00000000000..efd435b5b04 --- /dev/null +++ b/apps/sim/app/api/tools/agiloft/retrieve/route.test.ts @@ -0,0 +1,163 @@ +/** + * @vitest-environment node + */ +import { + createMockRequest, + hybridAuthMockFns, + inputValidationMock, + inputValidationMockFns, +} from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@/lib/core/security/input-validation.server', () => inputValidationMock) + +import { POST } from '@/app/api/tools/agiloft/retrieve/route' + +const PINNED_IP = '93.184.216.34' + +const baseBody = { + instanceUrl: 'https://example.agiloft.com', + knowledgeBase: 'demo', + login: 'admin', + password: 'secret', + table: 'contracts', + recordId: '42', + fieldName: 'attachments', + position: '0', +} + +function mockSecureFetchResponse(body: { + ok?: boolean + status?: number + json?: unknown + text?: string + arrayBuffer?: ArrayBuffer + headers?: Headers +}) { + return { + ok: body.ok ?? true, + status: body.status ?? 200, + statusText: '', + headers: body.headers ?? new Headers(), + body: null, + text: async () => body.text ?? '', + json: async () => body.json ?? {}, + arrayBuffer: async () => body.arrayBuffer ?? new ArrayBuffer(0), + } +} + +beforeEach(() => { + vi.clearAllMocks() + hybridAuthMockFns.mockCheckInternalAuth.mockResolvedValue({ + success: true, + userId: 'user-1', + authType: 'internal_jwt', + }) + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValue({ + isValid: true, + resolvedIP: PINNED_IP, + originalHostname: 'example.agiloft.com', + }) +}) + +describe('POST /api/tools/agiloft/retrieve', () => { + it('rejects unauthenticated requests', async () => { + hybridAuthMockFns.mockCheckInternalAuth.mockResolvedValueOnce({ + success: false, + error: 'unauthorized', + }) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(401) + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + }) + + it('blocks SSRF when the instance URL fails DNS validation', async () => { + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValueOnce({ + isValid: false, + error: 'instanceUrl resolves to a blocked IP address', + }) + + const response = await POST( + createMockRequest('POST', { ...baseBody, instanceUrl: 'https://attacker.example.com' }) + ) + + expect(response.status).toBe(400) + const data = (await response.json()) as { success: false; error: string } + expect(data.success).toBe(false) + expect(data.error).toContain('blocked IP') + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + }) + + it('pins the resolved IP for login, retrieve, and logout (TOCTOU fix)', async () => { + const fileBytes = Buffer.from('hello-attachment', 'utf-8') + + inputValidationMockFns.mockSecureFetchWithPinnedIP + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-xyz' } })) + .mockResolvedValueOnce( + mockSecureFetchResponse({ + arrayBuffer: fileBytes.buffer.slice( + fileBytes.byteOffset, + fileBytes.byteOffset + fileBytes.byteLength + ) as ArrayBuffer, + headers: new Headers({ + 'content-type': 'text/plain', + 'content-disposition': 'attachment; filename="report.txt"', + }), + }) + ) + .mockResolvedValueOnce(mockSecureFetchResponse({})) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(200) + const data = (await response.json()) as { + success: true + output: { file: { name: string; mimeType: string; data: string; size: number } } + } + + expect(data.output.file.name).toBe('report.txt') + expect(data.output.file.mimeType).toBe('text/plain') + expect(data.output.file.size).toBe(fileBytes.length) + expect(Buffer.from(data.output.file.data, 'base64').toString('utf-8')).toBe('hello-attachment') + + const calls = inputValidationMockFns.mockSecureFetchWithPinnedIP.mock.calls + expect(calls).toHaveLength(3) + + // All three outbound calls must use the pre-resolved IP. + for (const call of calls) { + expect(call[1]).toBe(PINNED_IP) + } + + // Original hostname is preserved in the URL (so TLS SNI works). + expect(calls[0][0]).toContain('https://example.agiloft.com/ewws/EWLogin') + expect(calls[1][0]).toContain('https://example.agiloft.com/ewws/EWRetrieve') + expect(calls[1][2]).toMatchObject({ + method: 'GET', + headers: { Authorization: 'Bearer tok-xyz' }, + }) + expect(calls[2][0]).toContain('https://example.agiloft.com/ewws/EWLogout') + + // DNS only resolved once — no second lookup that could rebind. + expect(inputValidationMockFns.mockValidateUrlWithDNS).toHaveBeenCalledTimes(1) + }) + + it('propagates upstream errors and still calls logout', async () => { + inputValidationMockFns.mockSecureFetchWithPinnedIP + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-err' } })) + .mockResolvedValueOnce( + mockSecureFetchResponse({ ok: false, status: 404, text: 'Record not found' }) + ) + .mockResolvedValueOnce(mockSecureFetchResponse({})) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(404) + const data = (await response.json()) as { success: false; error: string } + expect(data.error).toContain('Record not found') + + // Logout still runs. + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).toHaveBeenCalledTimes(3) + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP.mock.calls[2][0]).toContain( + '/ewws/EWLogout' + ) + }) +}) diff --git a/apps/sim/app/api/tools/agiloft/retrieve/route.ts b/apps/sim/app/api/tools/agiloft/retrieve/route.ts index 64bd72daae8..539f0bf7c2e 100644 --- a/apps/sim/app/api/tools/agiloft/retrieve/route.ts +++ b/apps/sim/app/api/tools/agiloft/retrieve/route.ts @@ -4,10 +4,15 @@ import { type NextRequest, NextResponse } from 'next/server' import { agiloftRetrieveContract } from '@/lib/api/contracts/tools/agiloft' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkInternalAuth } from '@/lib/auth/hybrid' -import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' +import { secureFetchWithPinnedIP } from '@/lib/core/security/input-validation.server' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' -import { agiloftLogin, agiloftLogout, buildRetrieveAttachmentUrl } from '@/tools/agiloft/utils' +import { buildRetrieveAttachmentUrl } from '@/tools/agiloft/utils' +import { + agiloftLoginPinned, + agiloftLogoutPinned, + resolveAgiloftInstance, +} from '@/tools/agiloft/utils.server' export const dynamic = 'force-dynamic' @@ -48,18 +53,17 @@ export const POST = withRouteHandler(async (request: NextRequest) => { if (!parsed.success) return parsed.response const data = parsed.data.body - const urlValidation = await validateUrlWithDNS(data.instanceUrl, 'instanceUrl') - if (!urlValidation.isValid) { + let resolvedIP: string + try { + resolvedIP = await resolveAgiloftInstance(data.instanceUrl) + } catch (error) { logger.warn(`[${requestId}] SSRF attempt blocked for Agiloft instance URL`, { instanceUrl: data.instanceUrl, }) - return NextResponse.json( - { success: false, error: urlValidation.error || 'Invalid instance URL' }, - { status: 400 } - ) + return NextResponse.json({ success: false, error: toError(error).message }, { status: 400 }) } - const token = await agiloftLogin(data) + const token = await agiloftLoginPinned(data, resolvedIP) const base = data.instanceUrl.replace(/\/$/, '') try { @@ -71,7 +75,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { position: data.position, }) - const agiloftResponse = await fetch(url, { + const agiloftResponse = await secureFetchWithPinnedIP(url, resolvedIP, { method: 'GET', headers: { Authorization: `Bearer ${token}`, @@ -123,7 +127,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { }, }) } finally { - await agiloftLogout(data.instanceUrl, data.knowledgeBase, token) + await agiloftLogoutPinned(data.instanceUrl, data.knowledgeBase, token, resolvedIP) } } catch (error) { logger.error(`[${requestId}] Error retrieving Agiloft attachment:`, error) diff --git a/apps/sim/app/form/[identifier]/components/email-auth.tsx b/apps/sim/app/form/[identifier]/components/email-auth.tsx new file mode 100644 index 00000000000..b75cb159c3c --- /dev/null +++ b/apps/sim/app/form/[identifier]/components/email-auth.tsx @@ -0,0 +1,284 @@ +'use client' + +import { useEffect, useState } from 'react' +import { createLogger } from '@sim/logger' +import { toError } from '@sim/utils/errors' +import { Input, InputOTP, InputOTPGroup, InputOTPSlot, Label, Loader } from '@/components/emcn' +import { cn } from '@/lib/core/utils/cn' +import { quickValidateEmail } from '@/lib/messaging/email/validation' +import AuthBackground from '@/app/(auth)/components/auth-background' +import { AUTH_SUBMIT_BTN } from '@/app/(auth)/components/auth-button-classes' +import { SupportFooter } from '@/app/(auth)/components/support-footer' +import Navbar from '@/app/(landing)/components/navbar/navbar' +import { useFormEmailOtpRequest, useFormEmailOtpVerify } from '@/hooks/queries/forms' + +const logger = createLogger('FormEmailAuth') + +interface EmailAuthProps { + identifier: string + onAuthenticated: () => void +} + +function validateEmailField(emailValue: string): string[] { + const errors: string[] = [] + + if (!emailValue || !emailValue.trim()) { + errors.push('Email is required.') + return errors + } + + const validation = quickValidateEmail(emailValue.trim().toLowerCase()) + if (!validation.isValid) { + errors.push(validation.reason || 'Please enter a valid email address.') + } + + return errors +} + +export function EmailAuth({ identifier, onAuthenticated }: EmailAuthProps) { + const [email, setEmail] = useState('') + const [authError, setAuthError] = useState(null) + const [emailErrors, setEmailErrors] = useState([]) + const [showEmailValidationError, setShowEmailValidationError] = useState(false) + + const [showOtpVerification, setShowOtpVerification] = useState(false) + const [otpValue, setOtpValue] = useState('') + const [countdown, setCountdown] = useState(0) + + const requestOtp = useFormEmailOtpRequest(identifier) + const verifyOtp = useFormEmailOtpVerify(identifier) + + useEffect(() => { + if (countdown <= 0) return + const timer = setTimeout(() => setCountdown((c) => c - 1), 1000) + return () => clearTimeout(timer) + }, [countdown]) + + const handleEmailChange = (e: React.ChangeEvent) => { + const newEmail = e.target.value + setEmail(newEmail) + const errors = validateEmailField(newEmail) + setEmailErrors(errors) + setShowEmailValidationError(false) + } + + const handleSendOtp = async () => { + const emailValidationErrors = validateEmailField(email) + setEmailErrors(emailValidationErrors) + setShowEmailValidationError(emailValidationErrors.length > 0) + + if (emailValidationErrors.length > 0) return + + setAuthError(null) + + try { + await requestOtp.mutateAsync({ email }) + setShowOtpVerification(true) + } catch (error) { + logger.error('Error sending OTP:', error) + setEmailErrors([toError(error).message || 'Failed to send verification code']) + setShowEmailValidationError(true) + } + } + + const handleVerifyOtp = async (otp?: string) => { + const codeToVerify = otp || otpValue + if (!codeToVerify || codeToVerify.length !== 6) return + + setAuthError(null) + + try { + await verifyOtp.mutateAsync({ email, otp: codeToVerify }) + onAuthenticated() + } catch (error) { + logger.error('Error verifying OTP:', error) + setAuthError(toError(error).message || 'Invalid verification code') + } + } + + const handleResendOtp = async () => { + setAuthError(null) + setCountdown(30) + + try { + await requestOtp.mutateAsync({ email }) + setOtpValue('') + } catch (error) { + logger.error('Error resending OTP:', error) + setAuthError(toError(error).message || 'Failed to resend verification code') + setCountdown(0) + } + } + + return ( + +
+
+ +
+
+
+
+
+

+ {showOtpVerification ? 'Verify Your Email' : 'Email Verification'} +

+

+ {showOtpVerification + ? `A verification code has been sent to ${email}` + : 'This form requires email verification'} +

+
+ +
+ {!showOtpVerification ? ( +
{ + e.preventDefault() + handleSendOtp() + }} + className='space-y-6' + > +
+ + 0 && + 'border-red-500 focus:border-red-500' + )} + /> + {showEmailValidationError && emailErrors.length > 0 && ( +
+ {emailErrors.map((error) => ( +

{error}

+ ))} +
+ )} +
+ + +
+ ) : ( +
+

+ Enter the 6-digit code to verify your account. If you don't see it in your + inbox, check your spam folder. +

+ +
+ { + setOtpValue(value) + if (value.length === 6) { + handleVerifyOtp(value) + } + }} + disabled={verifyOtp.isPending} + className={cn('gap-2', authError && 'otp-error')} + > + + {[0, 1, 2, 3, 4, 5].map((index) => ( + + ))} + + +
+ + {authError && ( +
+

{authError}

+
+ )} + + + +
+

+ Didn't receive a code?{' '} + {countdown > 0 ? ( + + Resend in{' '} + + {countdown}s + + + ) : ( + + )} +

+
+ +
+ +
+
+ )} +
+
+
+
+ +
+
+ ) +} diff --git a/apps/sim/app/form/[identifier]/components/index.ts b/apps/sim/app/form/[identifier]/components/index.ts index 31cb46d6843..e888196967c 100644 --- a/apps/sim/app/form/[identifier]/components/index.ts +++ b/apps/sim/app/form/[identifier]/components/index.ts @@ -1,3 +1,4 @@ +export { EmailAuth } from './email-auth' export { FormErrorState } from './error-state' export { FormField } from './form-field' export { FormLoadingState } from './loading-state' diff --git a/apps/sim/app/form/[identifier]/form.tsx b/apps/sim/app/form/[identifier]/form.tsx index f6264ddf0fb..4e809096b55 100644 --- a/apps/sim/app/form/[identifier]/form.tsx +++ b/apps/sim/app/form/[identifier]/form.tsx @@ -10,6 +10,7 @@ import { AUTH_SUBMIT_BTN } from '@/app/(auth)/components/auth-button-classes' import { SupportFooter } from '@/app/(auth)/components/support-footer' import Navbar from '@/app/(landing)/components/navbar/navbar' import { + EmailAuth, FormErrorState, FormField, FormLoadingState, @@ -241,6 +242,10 @@ export default function Form({ identifier }: { identifier: string }) { return } + if (authRequired === 'email') { + return fetchFormConfig()} /> + } + if (isSubmitted && thankYouData) { return ( diff --git a/apps/sim/hooks/queries/forms.ts b/apps/sim/hooks/queries/forms.ts index e43e8fe7ecc..e20df733b5e 100644 --- a/apps/sim/hooks/queries/forms.ts +++ b/apps/sim/hooks/queries/forms.ts @@ -14,8 +14,10 @@ import { type FormStatusResponse, getFormDetailContract, getFormStatusContract, + requestFormEmailOtpContract, type UpdateFormInput, updateFormContract, + verifyFormEmailOtpContract, } from '@/lib/api/contracts/forms' import { deploymentKeys } from './deployments' @@ -35,6 +37,36 @@ export const formKeys = { */ export type { FormAuthType } +/** + * Requests a one-time passcode for an email-gated deployed form. + * Used for both the initial send and resend flows. + */ +export function useFormEmailOtpRequest(identifier: string) { + return useMutation({ + mutationFn: async ({ email }: { email: string }) => { + await requestJson(requestFormEmailOtpContract, { + params: { identifier }, + body: { email }, + }) + }, + }) +} + +/** + * Verifies a one-time passcode for an email-gated deployed form. + * On success the server sets the auth cookie; the caller should re-fetch the form config. + */ +export function useFormEmailOtpVerify(identifier: string) { + return useMutation({ + mutationFn: async ({ email, otp }: { email: string; otp: string }) => { + await requestJson(verifyFormEmailOtpContract, { + params: { identifier }, + body: { email, otp }, + }) + }, + }) +} + /** * Field configuration for form fields */ diff --git a/apps/sim/lib/api/contracts/forms.ts b/apps/sim/lib/api/contracts/forms.ts index af252043b65..0283121edca 100644 --- a/apps/sim/lib/api/contracts/forms.ts +++ b/apps/sim/lib/api/contracts/forms.ts @@ -148,6 +148,22 @@ export const formMutationResponseSchema = z.object({ message: z.string(), }) +export const formEmailOtpRequestBodySchema = z.object({ + email: z.string().email('Invalid email address'), +}) + +export const formEmailOtpVerifyBodySchema = formEmailOtpRequestBodySchema.extend({ + otp: z.string().length(6, 'OTP must be 6 digits'), +}) + +export const formEmailOtpRequestResponseSchema = z.object({ + message: z.string(), +}) + +export const formEmailOtpVerifyResponseSchema = z.object({ + authenticated: z.literal(true), +}) + export const getFormStatusContract = defineRouteContract({ method: 'GET', path: '/api/workflows/[id]/form/status', @@ -199,6 +215,28 @@ export const deleteFormContract = defineRouteContract({ }, }) +export const requestFormEmailOtpContract = defineRouteContract({ + method: 'POST', + path: '/api/form/[identifier]/otp', + params: formIdentifierParamsSchema, + body: formEmailOtpRequestBodySchema, + response: { + mode: 'json', + schema: formEmailOtpRequestResponseSchema, + }, +}) + +export const verifyFormEmailOtpContract = defineRouteContract({ + method: 'PUT', + path: '/api/form/[identifier]/otp', + params: formIdentifierParamsSchema, + body: formEmailOtpVerifyBodySchema, + response: { + mode: 'json', + schema: formEmailOtpVerifyResponseSchema, + }, +}) + export const validateFormIdentifierContract = defineRouteContract({ method: 'GET', path: '/api/form/validate', diff --git a/apps/sim/lib/api/contracts/knowledge/documents.ts b/apps/sim/lib/api/contracts/knowledge/documents.ts index 17cb344f37f..6a85005f700 100644 --- a/apps/sim/lib/api/contracts/knowledge/documents.ts +++ b/apps/sim/lib/api/contracts/knowledge/documents.ts @@ -5,6 +5,7 @@ import { documentNumberFieldSchema, documentTagFieldSchema, knowledgeBaseParamsSchema, + knowledgeDocumentFileUrlSchema, knowledgeDocumentParamsSchema, nullableWireDateSchema, paginationSchema, @@ -55,7 +56,7 @@ export const listKnowledgeDocumentsQuerySchema = z.object({ export const createDocumentBodySchema = z.object({ filename: z.string().min(1, 'Filename is required'), - fileUrl: z.string().url('File URL must be valid'), + fileUrl: knowledgeDocumentFileUrlSchema, fileSize: z.number().min(1, 'File size must be greater than 0'), mimeType: z.string().min(1, 'MIME type is required'), tag1: z.string().optional(), @@ -101,7 +102,7 @@ export type SingleCreateDocumentBody = z.input { + it('accepts data: URIs', () => { + const result = knowledgeDocumentFileUrlSchema.safeParse( + 'data:text/plain;base64,SGVsbG8gd29ybGQ=' + ) + expect(result.success).toBe(true) + }) + + it('accepts https URLs', () => { + const result = knowledgeDocumentFileUrlSchema.safeParse('https://example.com/file.pdf') + expect(result.success).toBe(true) + }) + + it('accepts http URLs', () => { + const result = knowledgeDocumentFileUrlSchema.safeParse( + 'http://localhost:3000/api/files/serve/kb/foo.pdf?context=knowledge-base' + ) + expect(result.success).toBe(true) + }) + + it('is case-insensitive on the scheme', () => { + expect(knowledgeDocumentFileUrlSchema.safeParse('HTTPS://example.com/x').success).toBe(true) + expect(knowledgeDocumentFileUrlSchema.safeParse('Http://example.com/x').success).toBe(true) + }) + + it.each([ + ['absolute local path', '/etc/passwd'], + ['app path', '/app/.env'], + ['relative path', './secrets.txt'], + ['parent traversal', '../../etc/shadow'], + ['file:// scheme', 'file:///etc/passwd'], + ['ftp scheme', 'ftp://example.com/x'], + ['javascript scheme', 'javascript:alert(1)'], + ['gopher scheme', 'gopher://example.com'], + ['relative serve path', '/api/files/serve/kb/foo.pdf'], + ['windows path', 'C:\\Windows\\System32\\config\\SAM'], + ['empty string', ''], + ['whitespace prefix', ' https://example.com/x'], + ])('rejects %s', (_label, value) => { + const result = knowledgeDocumentFileUrlSchema.safeParse(value) + expect(result.success).toBe(false) + }) + + it('returns a useful error message for unsupported schemes', () => { + const result = knowledgeDocumentFileUrlSchema.safeParse('/etc/passwd') + if (result.success) throw new Error('expected failure') + expect(result.error.issues[0].message).toMatch(/data: URI or an http\(s\):\/\/ URL/) + }) +}) diff --git a/apps/sim/lib/api/contracts/knowledge/shared.ts b/apps/sim/lib/api/contracts/knowledge/shared.ts index 070cd4606de..e75b3bde368 100644 --- a/apps/sim/lib/api/contracts/knowledge/shared.ts +++ b/apps/sim/lib/api/contracts/knowledge/shared.ts @@ -23,6 +23,21 @@ export const knowledgeConnectorParamsSchema = knowledgeBaseParamsSchema.extend({ connectorId: z.string().min(1), }) +/** + * A `fileUrl` accepted by knowledge document ingestion endpoints. + * + * Must be a `data:` URI or an `http(s)://` URL. Local paths, `file://`, + * and other schemes are rejected at the boundary to prevent the background + * parser from reading arbitrary files off the Sim server's filesystem. + */ +export const knowledgeDocumentFileUrlSchema = z + .string() + .min(1, 'File URL is required') + .refine( + (value) => /^data:/i.test(value) || /^https?:\/\//i.test(value), + 'File URL must be a data: URI or an http(s):// URL' + ) + export const documentTagFieldSchema = z.string().nullable().optional() export const documentNumberFieldSchema = z.number().nullable().optional() export const documentBooleanFieldSchema = z.boolean().nullable().optional() diff --git a/apps/sim/lib/core/security/otp.ts b/apps/sim/lib/core/security/otp.ts new file mode 100644 index 00000000000..5163487d10b --- /dev/null +++ b/apps/sim/lib/core/security/otp.ts @@ -0,0 +1,251 @@ +import { randomInt } from 'crypto' +import { db } from '@sim/db' +import { verification } from '@sim/db/schema' +import { generateId } from '@sim/utils/id' +import { and, eq, gt } from 'drizzle-orm' +import { getRedisClient } from '@/lib/core/config/redis' +import type { TokenBucketConfig } from '@/lib/core/rate-limiter' +import { getStorageMethod } from '@/lib/core/storage' + +export type DeploymentKind = 'chat' | 'form' + +/** + * Shared OTP configuration for deployment (chat/form) email-auth gates. + */ +export const OTP_EXPIRY_SECONDS = 15 * 60 +export const OTP_EXPIRY_MS = OTP_EXPIRY_SECONDS * 1000 +export const MAX_OTP_ATTEMPTS = 5 + +export const OTP_IP_RATE_LIMIT: TokenBucketConfig = { + maxTokens: 10, + refillRate: 10, + refillIntervalMs: 15 * 60_000, +} + +export const OTP_EMAIL_RATE_LIMIT: TokenBucketConfig = { + maxTokens: 3, + refillRate: 3, + refillIntervalMs: 15 * 60_000, +} + +/** + * Key formats are kept per-kind to preserve any in-flight OTPs already issued + * against existing chat deployments. The chat Redis key uses the legacy `otp:` + * prefix; the chat DB identifier uses `chat-otp:`. Forms use `form-otp:` for + * both. + */ +const OTP_KEYS = { + chat: { + redisKey: (email: string, deploymentId: string) => `otp:${email}:${deploymentId}`, + dbIdentifier: (email: string, deploymentId: string) => `chat-otp:${deploymentId}:${email}`, + }, + form: { + redisKey: (email: string, deploymentId: string) => `form-otp:${email}:${deploymentId}`, + dbIdentifier: (email: string, deploymentId: string) => `form-otp:${deploymentId}:${email}`, + }, +} as const satisfies Record< + DeploymentKind, + { + redisKey: (email: string, deploymentId: string) => string + dbIdentifier: (email: string, deploymentId: string) => string + } +> + +/** Returns a cryptographically random 6-digit OTP code. */ +export function generateOTP(): string { + return randomInt(100000, 1000000).toString() +} + +/** + * OTP values are stored as `"code:attempts"` (e.g. `"654321:0"`). + * This keeps the attempt counter in the same key/row as the OTP itself. + */ +function encodeOTPValue(otp: string, attempts: number): string { + return `${otp}:${attempts}` +} + +export function decodeOTPValue(value: string): { otp: string; attempts: number } { + const lastColon = value.lastIndexOf(':') + if (lastColon === -1) return { otp: value, attempts: 0 } + const attempts = Number.parseInt(value.slice(lastColon + 1), 10) + return { otp: value.slice(0, lastColon), attempts: Number.isNaN(attempts) ? 0 : attempts } +} + +/** + * Stores an OTP for a deployment+email pair, choosing Redis or the + * `verification` table based on the configured storage method. + */ +export async function storeOTP( + kind: DeploymentKind, + deploymentId: string, + email: string, + otp: string +): Promise { + const keys = OTP_KEYS[kind] + const value = encodeOTPValue(otp, 0) + const storageMethod = getStorageMethod() + + if (storageMethod === 'redis') { + const redis = getRedisClient() + if (!redis) throw new Error('Redis configured but client unavailable') + await redis.set(keys.redisKey(email, deploymentId), value, 'EX', OTP_EXPIRY_SECONDS) + return + } + + const now = new Date() + const expiresAt = new Date(now.getTime() + OTP_EXPIRY_MS) + const identifier = keys.dbIdentifier(email, deploymentId) + + await db.transaction(async (tx) => { + await tx.delete(verification).where(eq(verification.identifier, identifier)) + await tx.insert(verification).values({ + id: generateId(), + identifier, + value, + expiresAt, + createdAt: now, + updatedAt: now, + }) + }) +} + +export async function getOTP( + kind: DeploymentKind, + deploymentId: string, + email: string +): Promise { + const keys = OTP_KEYS[kind] + const storageMethod = getStorageMethod() + + if (storageMethod === 'redis') { + const redis = getRedisClient() + if (!redis) throw new Error('Redis configured but client unavailable') + return redis.get(keys.redisKey(email, deploymentId)) + } + + const now = new Date() + const [record] = await db + .select({ value: verification.value }) + .from(verification) + .where( + and( + eq(verification.identifier, keys.dbIdentifier(email, deploymentId)), + gt(verification.expiresAt, now) + ) + ) + .limit(1) + + return record?.value ?? null +} + +/** + * Lua script for atomic OTP attempt increment in Redis. + * Returns `'LOCKED'` if max attempts reached (key deleted), new encoded value + * otherwise, nil if key missing. + */ +const ATOMIC_INCREMENT_SCRIPT = ` +local val = redis.call('GET', KEYS[1]) +if not val then return nil end +local colon = val:find(':([^:]*$)') +local otp, attempts +if colon then + otp = val:sub(1, colon - 1) + attempts = tonumber(val:sub(colon + 1)) or 0 +else + otp = val + attempts = 0 +end +attempts = attempts + 1 +if attempts >= tonumber(ARGV[1]) then + redis.call('DEL', KEYS[1]) + return 'LOCKED' +end +local newVal = otp .. ':' .. attempts +local ttl = redis.call('TTL', KEYS[1]) +if ttl > 0 then + redis.call('SET', KEYS[1], newVal, 'EX', ttl) +else + redis.call('SET', KEYS[1], newVal) +end +return newVal +` + +/** + * Atomically increments an OTP's failed-attempt counter. Returns `'locked'` + * if the max-attempts threshold was reached (and the OTP was deleted), or + * `'incremented'` otherwise. The DB path uses optimistic locking with retry. + */ +export async function incrementOTPAttempts( + kind: DeploymentKind, + deploymentId: string, + email: string, + currentValue: string +): Promise<'locked' | 'incremented'> { + const keys = OTP_KEYS[kind] + const storageMethod = getStorageMethod() + + if (storageMethod === 'redis') { + const redis = getRedisClient() + if (!redis) throw new Error('Redis configured but client unavailable') + const key = keys.redisKey(email, deploymentId) + const result = await redis.eval(ATOMIC_INCREMENT_SCRIPT, 1, key, MAX_OTP_ATTEMPTS) + if (result === null || result === 'LOCKED') return 'locked' + return 'incremented' + } + + const identifier = keys.dbIdentifier(email, deploymentId) + const MAX_RETRIES = 3 + let value = currentValue + + for (let attempt = 0; attempt < MAX_RETRIES; attempt++) { + const { otp, attempts } = decodeOTPValue(value) + const newAttempts = attempts + 1 + + if (newAttempts >= MAX_OTP_ATTEMPTS) { + await db.delete(verification).where(eq(verification.identifier, identifier)) + return 'locked' + } + + const newValue = encodeOTPValue(otp, newAttempts) + const updated = await db + .update(verification) + .set({ value: newValue, updatedAt: new Date() }) + .where(and(eq(verification.identifier, identifier), eq(verification.value, value))) + .returning({ id: verification.id }) + + if (updated.length > 0) return 'incremented' + + const fresh = await getOTP(kind, deploymentId, email) + if (!fresh) return 'locked' + value = fresh + } + + /** + * Retry exhaustion under heavy DB-path contention: this request did not + * succeed in writing its own +1, so the stored count may not reflect it. + * Fail closed — invalidate the OTP rather than return `'incremented'` with + * a possibly-undercounted attempt total. + */ + await db.delete(verification).where(eq(verification.identifier, identifier)) + return 'locked' +} + +export async function deleteOTP( + kind: DeploymentKind, + deploymentId: string, + email: string +): Promise { + const keys = OTP_KEYS[kind] + const storageMethod = getStorageMethod() + + if (storageMethod === 'redis') { + const redis = getRedisClient() + if (!redis) throw new Error('Redis configured but client unavailable') + await redis.del(keys.redisKey(email, deploymentId)) + return + } + + await db + .delete(verification) + .where(eq(verification.identifier, keys.dbIdentifier(email, deploymentId))) +} diff --git a/apps/sim/lib/knowledge/documents/document-processor.ts b/apps/sim/lib/knowledge/documents/document-processor.ts index 67f246d06b5..5e550f5af60 100644 --- a/apps/sim/lib/knowledge/documents/document-processor.ts +++ b/apps/sim/lib/knowledge/documents/document-processor.ts @@ -15,7 +15,7 @@ import { } from '@/lib/chunkers' import type { ChunkingStrategy, StrategyOptions } from '@/lib/chunkers/types' import { env, envNumber } from '@/lib/core/config/env' -import { parseBuffer, parseFile } from '@/lib/file-parsers' +import { parseBuffer } from '@/lib/file-parsers' import type { FileParseMetadata } from '@/lib/file-parsers/types' import { resolveParserExtension } from '@/lib/knowledge/documents/parser-extension' import { retryWithExponentialBackoff } from '@/lib/knowledge/documents/utils' @@ -315,7 +315,7 @@ async function handleFileForOCR( userId?: string, workspaceId?: string | null ) { - const isExternalHttps = fileUrl.startsWith('https://') && !isInternalFileUrl(fileUrl) + const isExternalHttps = /^https:\/\//i.test(fileUrl) && !isInternalFileUrl(fileUrl) if (isExternalHttps) { if (mimeType === 'application/pdf') { @@ -385,18 +385,17 @@ async function downloadFileWithTimeout(fileUrl: string): Promise { } async function downloadFileForBase64(fileUrl: string): Promise { - if (fileUrl.startsWith('data:')) { + if (/^data:/i.test(fileUrl)) { const [, base64Data] = fileUrl.split(',') if (!base64Data) { throw new Error('Invalid data URI format') } return Buffer.from(base64Data, 'base64') } - if (fileUrl.startsWith('http')) { + if (/^https?:\/\//i.test(fileUrl)) { return downloadFileWithTimeout(fileUrl) } - const fs = await import('fs/promises') - return fs.readFile(fileUrl) + throw new Error('Unsupported fileUrl scheme: only data: URIs and http(s):// URLs are allowed') } function processOCRContent(result: OCRResult, filename: string): string { @@ -783,16 +782,14 @@ async function parseWithFileParser(fileUrl: string, filename: string, mimeType: let content: string let metadata: FileParseMetadata = {} - if (fileUrl.startsWith('data:')) { + if (/^data:/i.test(fileUrl)) { content = await parseDataURI(fileUrl, filename, mimeType) - } else if (fileUrl.startsWith('http')) { + } else if (/^https?:\/\//i.test(fileUrl)) { const result = await parseHttpFile(fileUrl, filename, mimeType) content = result.content metadata = result.metadata || {} } else { - const result = await parseFile(fileUrl) - content = result.content - metadata = result.metadata || {} + throw new Error('Unsupported fileUrl scheme: only data: URIs and http(s):// URLs are allowed') } if (!content.trim()) { diff --git a/apps/sim/lib/knowledge/service.test.ts b/apps/sim/lib/knowledge/service.test.ts new file mode 100644 index 00000000000..08e81b2746c --- /dev/null +++ b/apps/sim/lib/knowledge/service.test.ts @@ -0,0 +1,114 @@ +/** + * @vitest-environment node + */ +import { + dbChainMock, + dbChainMockFns, + permissionsMock, + permissionsMockFns, + resetDbChainMock, +} from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@sim/db', () => dbChainMock) +vi.mock('@/lib/workspaces/permissions/utils', () => permissionsMock) + +import { KnowledgeBasePermissionError, updateKnowledgeBase } from '@/lib/knowledge/service' + +/** + * These tests guard the workspace mass-assignment fix: + * a user with write/admin on the *source* workspace must not be able to move a + * knowledge base into a workspace where they have no permission, and must not + * be able to clear `workspaceId` (which would orphan the KB to its original + * `userId`, who may not be the caller). + */ +describe('updateKnowledgeBase — workspace transfer authorization', () => { + beforeEach(() => { + vi.clearAllMocks() + dbChainMockFns.limit.mockReset() + resetDbChainMock() + }) + + it('rejects workspaceId change without actorUserId', async () => { + await expect( + updateKnowledgeBase('kb-1', { workspaceId: 'ws-target' }, 'req-1') + ).rejects.toBeInstanceOf(KnowledgeBasePermissionError) + expect(permissionsMockFns.mockGetUserEntityPermissions).not.toHaveBeenCalled() + }) + + it('rejects clearing workspaceId to null when actor is not the KB owner', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'owner' }]) + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: null }, 'req-1', { actorUserId: 'attacker' }) + ).rejects.toMatchObject({ + code: 'KNOWLEDGE_BASE_FORBIDDEN', + message: 'Only the knowledge base owner can remove it from a workspace', + }) + expect(permissionsMockFns.mockGetUserEntityPermissions).not.toHaveBeenCalled() + }) + + it('allows the KB owner to clear workspaceId to null (gate passes; target permission not checked)', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'owner' }]) + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: null }, 'req-1', { actorUserId: 'owner' }) + ).rejects.not.toBeInstanceOf(KnowledgeBasePermissionError) + expect(permissionsMockFns.mockGetUserEntityPermissions).not.toHaveBeenCalled() + }) + + it('rejects transfer when actor has no permission on target workspace', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'u-1' }]) + permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValueOnce(null) + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: 'ws-target' }, 'req-1', { + actorUserId: 'attacker', + }) + ).rejects.toMatchObject({ + code: 'KNOWLEDGE_BASE_FORBIDDEN', + message: 'User does not have permission on the target workspace', + }) + expect(permissionsMockFns.mockGetUserEntityPermissions).toHaveBeenCalledWith( + 'attacker', + 'workspace', + 'ws-target' + ) + }) + + it('rejects transfer when actor only has read permission on target workspace', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'u-1' }]) + permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValueOnce('read') + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: 'ws-target' }, 'req-1', { + actorUserId: 'reader', + }) + ).rejects.toBeInstanceOf(KnowledgeBasePermissionError) + }) + + it('throws when knowledge base does not exist during transfer', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([]) + + await expect( + updateKnowledgeBase('kb-missing', { workspaceId: 'ws-target' }, 'req-1', { + actorUserId: 'u-1', + }) + ).rejects.toThrow('Knowledge base kb-missing not found') + expect(permissionsMockFns.mockGetUserEntityPermissions).not.toHaveBeenCalled() + }) + + it('locks the knowledge base row (SELECT … FOR UPDATE) before the permission check', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'u-1' }]) + permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValueOnce(null) + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: 'ws-target' }, 'req-1', { + actorUserId: 'attacker', + }) + ).rejects.toBeInstanceOf(KnowledgeBasePermissionError) + + expect(dbChainMockFns.transaction).toHaveBeenCalledTimes(1) + expect(dbChainMockFns.for).toHaveBeenCalledWith('update') + }) +}) diff --git a/apps/sim/lib/knowledge/service.ts b/apps/sim/lib/knowledge/service.ts index 0ba5de8162a..8bca342cb77 100644 --- a/apps/sim/lib/knowledge/service.ts +++ b/apps/sim/lib/knowledge/service.ts @@ -21,6 +21,10 @@ export class KnowledgeBaseConflictError extends Error { } } +export class KnowledgeBasePermissionError extends Error { + readonly code = 'KNOWLEDGE_BASE_FORBIDDEN' as const +} + export type KnowledgeBaseScope = 'active' | 'archived' | 'all' /** @@ -148,7 +152,9 @@ export async function createKnowledgeBase( const hasPermission = await getUserEntityPermissions(data.userId, 'workspace', data.workspaceId) if (hasPermission !== 'admin' && hasPermission !== 'write') { - throw new Error('User does not have permission to create knowledge bases in this workspace') + throw new KnowledgeBasePermissionError( + 'User does not have permission to create knowledge bases in this workspace' + ) } const newKnowledgeBase = { @@ -226,7 +232,8 @@ export async function updateKnowledgeBase( overlap: number } }, - requestId: string + requestId: string, + options?: { actorUserId?: string } ): Promise { const now = new Date() const updateData: { @@ -252,38 +259,81 @@ export async function updateKnowledgeBase( updateData.chunkingConfig = updates.chunkingConfig } - if (updates.name !== undefined) { - const existing = await db - .select({ id: knowledgeBase.id, workspaceId: knowledgeBase.workspaceId }) - .from(knowledgeBase) - .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) - .limit(1) + if (updates.workspaceId !== undefined && !options?.actorUserId) { + throw new KnowledgeBasePermissionError( + 'actorUserId is required to change a knowledge base workspace' + ) + } - if (existing.length > 0 && existing[0].workspaceId) { - const duplicate = await db - .select({ id: knowledgeBase.id }) + try { + await db.transaction(async (tx) => { + const [currentKb] = await tx + .select({ workspaceId: knowledgeBase.workspaceId, userId: knowledgeBase.userId }) .from(knowledgeBase) - .where( - and( - eq(knowledgeBase.workspaceId, existing[0].workspaceId), - eq(knowledgeBase.name, updates.name), - isNull(knowledgeBase.deletedAt), - ne(knowledgeBase.id, knowledgeBaseId) - ) - ) + .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) + .for('update') .limit(1) - if (duplicate.length > 0) { - throw new KnowledgeBaseConflictError(updates.name) + if (!currentKb) { + throw new Error(`Knowledge base ${knowledgeBaseId} not found`) } - } - } - try { - await db - .update(knowledgeBase) - .set(updateData) - .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) + if (updates.workspaceId !== undefined) { + const actorUserId = options?.actorUserId as string + const currentWorkspaceId = currentKb.workspaceId ?? null + const targetWorkspaceId = updates.workspaceId ?? null + + if (targetWorkspaceId !== currentWorkspaceId) { + if (!targetWorkspaceId) { + if (actorUserId !== currentKb.userId) { + throw new KnowledgeBasePermissionError( + 'Only the knowledge base owner can remove it from a workspace' + ) + } + } else { + const targetPermission = await getUserEntityPermissions( + actorUserId, + 'workspace', + targetWorkspaceId + ) + if (targetPermission !== 'write' && targetPermission !== 'admin') { + throw new KnowledgeBasePermissionError( + 'User does not have permission on the target workspace' + ) + } + } + } + } + + if (updates.name !== undefined) { + const effectiveWorkspaceId = + updates.workspaceId !== undefined ? updates.workspaceId : currentKb.workspaceId + + if (effectiveWorkspaceId) { + const duplicate = await tx + .select({ id: knowledgeBase.id }) + .from(knowledgeBase) + .where( + and( + eq(knowledgeBase.workspaceId, effectiveWorkspaceId), + eq(knowledgeBase.name, updates.name), + isNull(knowledgeBase.deletedAt), + ne(knowledgeBase.id, knowledgeBaseId) + ) + ) + .limit(1) + + if (duplicate.length > 0) { + throw new KnowledgeBaseConflictError(updates.name) + } + } + } + + await tx + .update(knowledgeBase) + .set(updateData) + .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) + }) } catch (error: unknown) { if (getPostgresErrorCode(error) === '23505' && updates.name !== undefined) { throw new KnowledgeBaseConflictError(updates.name) diff --git a/apps/sim/lib/mcp/client.ts b/apps/sim/lib/mcp/client.ts index 93588aecdd3..bbc5cb19e00 100644 --- a/apps/sim/lib/mcp/client.ts +++ b/apps/sim/lib/mcp/client.ts @@ -18,6 +18,7 @@ import { import { createLogger } from '@sim/logger' import { getErrorMessage } from '@sim/utils/errors' import { getMaxExecutionTimeout } from '@/lib/core/execution-limits' +import { createMcpPinnedFetch } from '@/lib/mcp/pinned-fetch' import { type McpClientOptions, McpConnectionError, @@ -51,34 +52,15 @@ export class McpClient { '2024-11-05', // Initial stable release ] - /** - * Creates a new MCP client. - * - * Accepts either the legacy (config, securityPolicy?) signature - * or a single McpClientOptions object with an optional onToolsChanged callback. - */ - constructor(config: McpServerConfig, securityPolicy?: McpSecurityPolicy) - constructor(options: McpClientOptions) - constructor( - configOrOptions: McpServerConfig | McpClientOptions, - securityPolicy?: McpSecurityPolicy - ) { - if ('config' in configOrOptions) { - this.config = configOrOptions.config - this.securityPolicy = configOrOptions.securityPolicy ?? { - requireConsent: true, - auditLevel: 'basic', - maxToolExecutionsPerHour: 1000, - } - this.onToolsChanged = configOrOptions.onToolsChanged - } else { - this.config = configOrOptions - this.securityPolicy = securityPolicy ?? { - requireConsent: true, - auditLevel: 'basic', - maxToolExecutionsPerHour: 1000, - } + constructor(options: McpClientOptions) { + this.config = options.config + this.securityPolicy = options.securityPolicy ?? { + requireConsent: true, + auditLevel: 'basic', + maxToolExecutionsPerHour: 1000, } + this.onToolsChanged = options.onToolsChanged + const resolvedIP = options.resolvedIP this.connectionStatus = { connected: false } @@ -90,6 +72,7 @@ export class McpClient { requestInit: { headers: this.config.headers, }, + ...(resolvedIP ? { fetch: createMcpPinnedFetch(resolvedIP) } : {}), }) this.client = new Client( diff --git a/apps/sim/lib/mcp/connection-manager.ts b/apps/sim/lib/mcp/connection-manager.ts index 3d6627be57b..a150b194a87 100644 --- a/apps/sim/lib/mcp/connection-manager.ts +++ b/apps/sim/lib/mcp/connection-manager.ts @@ -71,7 +71,8 @@ export class McpConnectionManager { async connect( config: McpServerConfig, userId: string, - workspaceId: string + workspaceId: string, + resolvedIP?: string | null ): Promise<{ supportsListChanged: boolean }> { if (this.disposed) { logger.warn('Connection manager is disposed, ignoring connect request') @@ -106,6 +107,7 @@ export class McpConnectionManager { maxToolExecutionsPerHour: 1000, }, onToolsChanged, + resolvedIP: resolvedIP ?? undefined, }) try { diff --git a/apps/sim/lib/mcp/domain-check.test.ts b/apps/sim/lib/mcp/domain-check.test.ts index 6cc76716ca0..ff559caa8cf 100644 --- a/apps/sim/lib/mcp/domain-check.test.ts +++ b/apps/sim/lib/mcp/domain-check.test.ts @@ -4,13 +4,17 @@ import { inputValidationMock, inputValidationMockFns } from '@sim/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockGetAllowedMcpDomainsFromEnv, mockDnsLookup } = vi.hoisted(() => ({ +const { mockGetAllowedMcpDomainsFromEnv, mockDnsLookup, hostedFlag } = vi.hoisted(() => ({ mockGetAllowedMcpDomainsFromEnv: vi.fn<() => string[] | null>(), mockDnsLookup: vi.fn(), + hostedFlag: { value: false }, })) vi.mock('@/lib/core/config/feature-flags', () => ({ getAllowedMcpDomainsFromEnv: mockGetAllowedMcpDomainsFromEnv, + get isHosted() { + return hostedFlag.value + }, })) vi.mock('@/lib/core/security/input-validation.server', () => inputValidationMock) @@ -331,41 +335,44 @@ describe('validateMcpServerSsrf', () => { beforeEach(() => { vi.clearAllMocks() mockGetAllowedMcpDomainsFromEnv.mockReturnValue(null) + hostedFlag.value = false }) - it('does nothing for undefined URL', async () => { - await expect(validateMcpServerSsrf(undefined)).resolves.toBeUndefined() + it('returns null for undefined URL', async () => { + await expect(validateMcpServerSsrf(undefined)).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('skips validation for env var URLs', async () => { - await expect(validateMcpServerSsrf('{{MCP_SERVER_URL}}')).resolves.toBeUndefined() + it('returns null and skips validation for env var URLs', async () => { + await expect(validateMcpServerSsrf('{{MCP_SERVER_URL}}')).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('skips validation for URLs with env var in hostname', async () => { - await expect(validateMcpServerSsrf('https://{{MCP_HOST}}/mcp')).resolves.toBeUndefined() + it('returns null and skips validation for URLs with env var in hostname', async () => { + await expect(validateMcpServerSsrf('https://{{MCP_HOST}}/mcp')).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('allows localhost URLs without DNS lookup', async () => { - await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).resolves.toBeUndefined() + it('returns null for localhost URLs without DNS lookup', async () => { + await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('allows 127.0.0.1 URLs without DNS lookup', async () => { - await expect(validateMcpServerSsrf('http://127.0.0.1:8080/mcp')).resolves.toBeUndefined() + it('returns null for 127.0.0.1 literal without DNS lookup', async () => { + await expect(validateMcpServerSsrf('http://127.0.0.1:8080/mcp')).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('allows URLs that resolve to public IPs', async () => { + it('returns resolved IP for URLs that resolve to public IPs', async () => { mockDnsLookup.mockResolvedValue({ address: '93.184.216.34' }) - await expect(validateMcpServerSsrf('https://example.com/mcp')).resolves.toBeUndefined() + await expect(validateMcpServerSsrf('https://example.com/mcp')).resolves.toBe('93.184.216.34') }) - it('allows HTTP URLs on non-localhost hosts', async () => { + it('returns resolved IP for HTTP URLs on non-localhost hosts', async () => { mockDnsLookup.mockResolvedValue({ address: '93.184.216.34' }) - await expect(validateMcpServerSsrf('http://example.com:3000/mcp')).resolves.toBeUndefined() + await expect(validateMcpServerSsrf('http://example.com:3000/mcp')).resolves.toBe( + '93.184.216.34' + ) }) it('throws McpSsrfError for cloud metadata IP literal', async () => { @@ -402,21 +409,97 @@ describe('validateMcpServerSsrf', () => { ) }) - it('allows URLs resolving to loopback (localhost alias)', async () => { + it('returns resolved IP for URLs resolving to loopback on self-hosted (localhost alias)', async () => { mockDnsLookup.mockResolvedValue({ address: '127.0.0.1' }) - await expect(validateMcpServerSsrf('http://my-local-alias:3000/mcp')).resolves.toBeUndefined() + await expect(validateMcpServerSsrf('http://my-local-alias:3000/mcp')).resolves.toBe('127.0.0.1') }) it('throws for malformed URLs', async () => { await expect(validateMcpServerSsrf('not-a-url')).rejects.toThrow(McpSsrfError) }) + describe('hosted environment', () => { + beforeEach(() => { + hostedFlag.value = true + }) + + it('rejects localhost URLs on hosted', async () => { + await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('rejects 127.0.0.1 URLs on hosted', async () => { + await expect(validateMcpServerSsrf('http://127.0.0.1:8080/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('rejects [::1] URLs on hosted', async () => { + await expect(validateMcpServerSsrf('http://[::1]:8080/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('rejects URLs resolving to loopback on hosted', async () => { + mockDnsLookup.mockResolvedValue({ address: '127.0.0.1' }) + await expect(validateMcpServerSsrf('http://my-local-alias:3000/mcp')).rejects.toThrow( + McpSsrfError + ) + }) + + it('returns resolved IP for public IP resolutions on hosted', async () => { + mockDnsLookup.mockResolvedValue({ address: '93.184.216.34' }) + await expect(validateMcpServerSsrf('https://example.com/mcp')).resolves.toBe('93.184.216.34') + }) + + it('skips loopback check on hosted when allowlist is configured', async () => { + mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['localhost']) + await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).resolves.toBeNull() + }) + + it('still blocks RFC-1918 IP literals on hosted (regression)', async () => { + await expect(validateMcpServerSsrf('http://10.0.0.1/mcp')).rejects.toThrow(McpSsrfError) + await expect(validateMcpServerSsrf('http://192.168.1.1/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('still blocks cloud metadata IP on hosted (regression)', async () => { + await expect( + validateMcpServerSsrf('http://169.254.169.254/latest/meta-data/') + ).rejects.toThrow(McpSsrfError) + }) + + it('still blocks DNS resolutions to private IPs on hosted (regression)', async () => { + mockDnsLookup.mockResolvedValue({ address: '10.0.0.5' }) + await expect(validateMcpServerSsrf('https://internal.corp/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('still skips env var hostnames on hosted', async () => { + await expect(validateMcpServerSsrf('{{MCP_SERVER_URL}}')).resolves.toBeNull() + await expect(validateMcpServerSsrf('https://{{MCP_HOST}}/mcp')).resolves.toBeNull() + expect(mockDnsLookup).not.toHaveBeenCalled() + }) + }) + + describe('self-hosted environment (regression)', () => { + beforeEach(() => { + hostedFlag.value = false + }) + + it('still allows localhost URLs (returns null, no pinning needed)', async () => { + await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).resolves.toBeNull() + }) + + it('still allows 127.0.0.1 URLs (returns null, no pinning needed)', async () => { + await expect(validateMcpServerSsrf('http://127.0.0.1:8080/mcp')).resolves.toBeNull() + }) + + it('returns resolved loopback IP for DNS aliases (caller pins)', async () => { + mockDnsLookup.mockResolvedValue({ address: '127.0.0.1' }) + await expect(validateMcpServerSsrf('http://my-local-alias/mcp')).resolves.toBe('127.0.0.1') + }) + }) + it('skips all checks when ALLOWED_MCP_DOMAINS is configured', async () => { mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['internal.corp']) - await expect(validateMcpServerSsrf('http://10.0.0.1/mcp')).resolves.toBeUndefined() + await expect(validateMcpServerSsrf('http://10.0.0.1/mcp')).resolves.toBeNull() await expect( validateMcpServerSsrf('http://169.254.169.254/latest/meta-data/') - ).resolves.toBeUndefined() + ).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) }) diff --git a/apps/sim/lib/mcp/domain-check.ts b/apps/sim/lib/mcp/domain-check.ts index 83ec36c69f5..9e57b23c7f4 100644 --- a/apps/sim/lib/mcp/domain-check.ts +++ b/apps/sim/lib/mcp/domain-check.ts @@ -2,7 +2,7 @@ import dns from 'dns/promises' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' import * as ipaddr from 'ipaddr.js' -import { getAllowedMcpDomainsFromEnv } from '@/lib/core/config/feature-flags' +import { getAllowedMcpDomainsFromEnv, isHosted } from '@/lib/core/config/feature-flags' import { isPrivateOrReservedIP } from '@/lib/core/security/input-validation.server' import { createEnvVarPattern } from '@/executor/utils/reference-validation' @@ -133,16 +133,25 @@ function isLocalhostHostname(hostname: string): boolean { * Does NOT enforce protocol (HTTP is allowed) or block service ports — MCP * servers legitimately run on HTTP and on arbitrary ports. * - * Localhost/loopback is always allowed for local dev MCP servers. + * Localhost/loopback is allowed for local dev MCP servers in self-hosted + * deployments, but blocked on the hosted environment (sim.ai) where users + * must not be able to reach the server's own loopback interface. * URLs with env var references in the hostname are skipped — they will be * validated after resolution at execution time. * + * Returns the resolved IP address when DNS resolution was performed (so the + * caller can pin subsequent connections to that IP and prevent DNS-rebinding + * TOCTOU attacks). Returns null in cases where pinning is unnecessary or + * impossible: no URL, allowlist-only mode, env-var hostnames (validated later), + * IP literals (no DNS to rebind), and localhost on self-hosted (no rebinding + * risk against a fixed loopback). + * * @throws McpSsrfError if the URL resolves to a blocked IP address */ -export async function validateMcpServerSsrf(url: string | undefined): Promise { - if (!url) return - if (getAllowedMcpDomainsFromEnv() !== null) return - if (hasEnvVarInHostname(url)) return +export async function validateMcpServerSsrf(url: string | undefined): Promise { + if (!url) return null + if (getAllowedMcpDomainsFromEnv() !== null) return null + if (hasEnvVarInHostname(url)) return null let hostname: string try { @@ -154,28 +163,47 @@ export async function validateMcpServerSsrf(url: string | undefined): Promise { + const capturedAgentOptions: unknown[] = [] + class MockAgent { + constructor(options: unknown) { + capturedAgentOptions.push(options) + } + } + return { + mockAgent: MockAgent, + mockCreatePinnedLookup: vi.fn(), + mockUndiciFetch: vi.fn(), + capturedAgentOptions, + } + } +) + +vi.mock('undici', () => ({ Agent: mockAgent, fetch: mockUndiciFetch })) +vi.mock('@/lib/core/security/input-validation.server', () => ({ + createPinnedLookup: mockCreatePinnedLookup, +})) + +import { createMcpPinnedFetch } from '@/lib/mcp/pinned-fetch' + +describe('createMcpPinnedFetch', () => { + beforeEach(() => { + vi.clearAllMocks() + capturedAgentOptions.length = 0 + mockCreatePinnedLookup.mockReturnValue('pinned-lookup-fn') + mockUndiciFetch.mockResolvedValue(new Response('ok')) + }) + + it('builds an undici Agent with the pinned lookup for the resolved IP', () => { + createMcpPinnedFetch('203.0.113.10') + expect(mockCreatePinnedLookup).toHaveBeenCalledWith('203.0.113.10') + expect(capturedAgentOptions).toHaveLength(1) + expect(capturedAgentOptions[0]).toEqual({ connect: { lookup: 'pinned-lookup-fn' } }) + }) + + it('forwards the dispatcher on every fetch call', async () => { + const fetchLike = createMcpPinnedFetch('203.0.113.10') + await fetchLike('https://example.com/mcp', { method: 'POST' }) + expect(mockUndiciFetch).toHaveBeenCalledTimes(1) + const [url, init] = mockUndiciFetch.mock.calls[0] + expect(url).toBe('https://example.com/mcp') + expect((init as { dispatcher?: unknown }).dispatcher).toBeInstanceOf(mockAgent) + expect((init as { method?: string }).method).toBe('POST') + }) + + it('preserves caller-provided init options (headers, signal)', async () => { + const fetchLike = createMcpPinnedFetch('203.0.113.10') + const controller = new AbortController() + await fetchLike('https://example.com/mcp', { + method: 'GET', + headers: { 'x-test': '1' }, + signal: controller.signal, + }) + const init = mockUndiciFetch.mock.calls[0][1] as RequestInit & { dispatcher?: unknown } + expect(init.headers).toEqual({ 'x-test': '1' }) + expect(init.signal).toBe(controller.signal) + expect(init.dispatcher).toBeInstanceOf(mockAgent) + }) + + it('handles undefined init gracefully', async () => { + const fetchLike = createMcpPinnedFetch('203.0.113.10') + await fetchLike('https://example.com/mcp') + const init = mockUndiciFetch.mock.calls[0][1] as { dispatcher?: unknown } + expect(init.dispatcher).toBeInstanceOf(mockAgent) + }) + + it('reuses the same dispatcher across calls (one Agent per fetch instance)', async () => { + const fetchLike = createMcpPinnedFetch('203.0.113.10') + await fetchLike('https://example.com/a') + await fetchLike('https://example.com/b') + expect(capturedAgentOptions).toHaveLength(1) + const d1 = (mockUndiciFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher + const d2 = (mockUndiciFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher + expect(d1).toBe(d2) + }) +}) diff --git a/apps/sim/lib/mcp/pinned-fetch.ts b/apps/sim/lib/mcp/pinned-fetch.ts new file mode 100644 index 00000000000..798de5710e6 --- /dev/null +++ b/apps/sim/lib/mcp/pinned-fetch.ts @@ -0,0 +1,38 @@ +import type { FetchLike } from '@modelcontextprotocol/sdk/shared/transport.js' +import { Agent, type RequestInit as UndiciRequestInit, fetch as undiciFetch } from 'undici' +import { createPinnedLookup } from '@/lib/core/security/input-validation.server' + +/** + * Creates a FetchLike that pins all outbound HTTP connections to a pre-resolved + * IP address. Used by the MCP transport to prevent DNS-rebinding (TOCTOU) + * attacks: validation performs DNS once and confirms the IP is allowed; this + * fetch then forces every subsequent request (initial POST, SSE GET, redirects) + * to use that same IP, regardless of what the hostname now resolves to. + * + * Uses undici's `fetch` directly so the `dispatcher` option is part of the + * real type contract — not a cast that would silently break if a future + * runtime swapped out the implementation. + * + * The original hostname is preserved on the request so TLS SNI and the Host + * header continue to match the certificate. + */ +export function createMcpPinnedFetch(resolvedIP: string): FetchLike { + const dispatcher = new Agent({ + connect: { lookup: createPinnedLookup(resolvedIP) }, + }) + + return (async (url, init) => { + // DOM `RequestInit` and undici's `RequestInit` are structurally compatible + // at runtime (Node's global fetch IS undici) but differ in TS types. + // Cast the init through unknown to bridge the typing without losing the + // critical `dispatcher` typing on the call itself. + const undiciInit: UndiciRequestInit = { + // double-cast-allowed: DOM RequestInit and undici RequestInit are structurally compatible at runtime (Node's global fetch IS undici) but the TS types differ + ...(init as unknown as UndiciRequestInit), + dispatcher, + } + const response = await undiciFetch(url as string | URL, undiciInit) + // double-cast-allowed: undici Response and DOM Response are structurally compatible at runtime; bridging the types is required to satisfy the FetchLike contract + return response as unknown as Response + }) satisfies FetchLike +} diff --git a/apps/sim/lib/mcp/service.ts b/apps/sim/lib/mcp/service.ts index 4ec764c3d41..7838f682822 100644 --- a/apps/sim/lib/mcp/service.ts +++ b/apps/sim/lib/mcp/service.ts @@ -69,13 +69,13 @@ class McpService { config: McpServerConfig, userId: string, workspaceId?: string - ): Promise { + ): Promise<{ config: McpServerConfig; resolvedIP: string | null }> { const { config: resolvedConfig } = await resolveMcpConfigEnvVars(config, userId, workspaceId, { strict: true, }) validateMcpDomain(resolvedConfig.url) - await validateMcpServerSsrf(resolvedConfig.url) - return resolvedConfig + const resolvedIP = await validateMcpServerSsrf(resolvedConfig.url) + return { config: resolvedConfig, resolvedIP } } /** @@ -156,7 +156,10 @@ class McpService { /** * Create and connect to an MCP client */ - private async createClient(config: McpServerConfig): Promise { + private async createClient( + config: McpServerConfig, + resolvedIP: string | null + ): Promise { const securityPolicy = { requireConsent: true, auditLevel: 'basic' as const, @@ -164,7 +167,11 @@ class McpService { allowedOrigins: config.url ? [new URL(config.url).origin] : undefined, } - const client = new McpClient(config, securityPolicy) + const client = new McpClient({ + config, + securityPolicy, + resolvedIP: resolvedIP ?? undefined, + }) await client.connect() return client } @@ -194,11 +201,15 @@ class McpService { throw new Error(`Server ${serverId} not found or not accessible`) } - const resolvedConfig = await this.resolveConfigEnvVars(config, userId, workspaceId) + const { config: resolvedConfig, resolvedIP } = await this.resolveConfigEnvVars( + config, + userId, + workspaceId + ) if (extraHeaders && Object.keys(extraHeaders).length > 0) { resolvedConfig.headers = { ...resolvedConfig.headers, ...extraHeaders } } - const client = await this.createClient(resolvedConfig) + const client = await this.createClient(resolvedConfig, resolvedIP) try { const result = await client.callTool(toolCall) @@ -348,14 +359,18 @@ class McpService { const allTools: McpTool[] = [] const results = await Promise.allSettled( servers.map(async (config) => { - const resolvedConfig = await this.resolveConfigEnvVars(config, userId, workspaceId) - const client = await this.createClient(resolvedConfig) + const { config: resolvedConfig, resolvedIP } = await this.resolveConfigEnvVars( + config, + userId, + workspaceId + ) + const client = await this.createClient(resolvedConfig, resolvedIP) try { const tools = await client.listTools() logger.debug( `[${requestId}] Discovered ${tools.length} tools from server ${config.name}` ) - return { serverId: config.id, tools, resolvedConfig } + return { serverId: config.id, tools, resolvedConfig, resolvedIP } } finally { await client.disconnect() } @@ -394,13 +409,15 @@ class McpService { if (mcpConnectionManager) { for (const [index, result] of results.entries()) { if (result.status === 'fulfilled') { - const { resolvedConfig } = result.value - mcpConnectionManager.connect(resolvedConfig, userId, workspaceId).catch((err) => { - logger.warn( - `[${requestId}] Persistent connection failed for ${servers[index].name}:`, - err - ) - }) + const { resolvedConfig, resolvedIP } = result.value + mcpConnectionManager + .connect(resolvedConfig, userId, workspaceId, resolvedIP) + .catch((err) => { + logger.warn( + `[${requestId}] Persistent connection failed for ${servers[index].name}:`, + err + ) + }) } } } @@ -450,8 +467,12 @@ class McpService { throw new Error(`Server ${serverId} not found or not accessible`) } - const resolvedConfig = await this.resolveConfigEnvVars(config, userId, workspaceId) - const client = await this.createClient(resolvedConfig) + const { config: resolvedConfig, resolvedIP } = await this.resolveConfigEnvVars( + config, + userId, + workspaceId + ) + const client = await this.createClient(resolvedConfig, resolvedIP) try { const tools = await client.listTools() @@ -490,8 +511,12 @@ class McpService { for (const config of servers) { try { - const resolvedConfig = await this.resolveConfigEnvVars(config, userId, workspaceId) - const client = await this.createClient(resolvedConfig) + const { config: resolvedConfig, resolvedIP } = await this.resolveConfigEnvVars( + config, + userId, + workspaceId + ) + const client = await this.createClient(resolvedConfig, resolvedIP) const tools = await client.listTools() await client.disconnect() diff --git a/apps/sim/lib/mcp/types.ts b/apps/sim/lib/mcp/types.ts index db9ac11fd0a..f4f5c939efd 100644 --- a/apps/sim/lib/mcp/types.ts +++ b/apps/sim/lib/mcp/types.ts @@ -161,6 +161,14 @@ export interface McpClientOptions { config: McpServerConfig securityPolicy?: McpSecurityPolicy onToolsChanged?: McpToolsChangedCallback + /** + * Pre-resolved IP address to pin all transport HTTP connections to. When + * set, the SDK transport uses a custom fetch backed by an undici Agent with + * a fixed DNS lookup, preventing DNS-rebinding (TOCTOU) attacks between + * URL validation and connection. Should be supplied by callers that have + * just validated the URL via `validateMcpServerSsrf`. + */ + resolvedIP?: string } /** diff --git a/apps/sim/package.json b/apps/sim/package.json index 71521b166bb..f175df18d51 100644 --- a/apps/sim/package.json +++ b/apps/sim/package.json @@ -196,6 +196,7 @@ "three": "0.177.0", "tldts": "7.0.30", "twilio": "5.9.0", + "undici": "7.25.0", "unified": "11.0.5", "unpdf": "1.4.0", "xlsx": "https://cdn.sheetjs.com/xlsx-0.20.3/xlsx-0.20.3.tgz", diff --git a/apps/sim/tools/agiloft/utils.server.ts b/apps/sim/tools/agiloft/utils.server.ts new file mode 100644 index 00000000000..3aaa0c62b71 --- /dev/null +++ b/apps/sim/tools/agiloft/utils.server.ts @@ -0,0 +1,79 @@ +import { createLogger } from '@sim/logger' +import { + type SecureFetchResponse, + secureFetchWithPinnedIP, + validateUrlWithDNS, +} from '@/lib/core/security/input-validation.server' +import type { AgiloftBaseParams } from '@/tools/agiloft/types' + +const logger = createLogger('AgiloftAuthServer') + +/** + * Validates the Agiloft instance URL and resolves its DNS once, returning the + * resolved IP so subsequent requests can pin to it. This prevents DNS-rebinding + * (TOCTOU) SSRF where the hostname could resolve to a private IP on a later + * lookup. Server-only — uses node:dns/promises. + */ +export async function resolveAgiloftInstance(instanceUrl: string): Promise { + const validation = await validateUrlWithDNS(instanceUrl, 'instanceUrl') + if (!validation.isValid || !validation.resolvedIP) { + throw new Error(validation.error || 'Invalid Agiloft instance URL') + } + return validation.resolvedIP +} + +/** + * DNS-pinned variant of agiloftLogin. Requires a pre-resolved IP so the + * connection cannot be steered to a different host between validation and + * the actual TCP connection. + */ +export async function agiloftLoginPinned( + params: AgiloftBaseParams, + resolvedIP: string +): Promise { + const base = params.instanceUrl.replace(/\/$/, '') + const kb = encodeURIComponent(params.knowledgeBase) + const login = encodeURIComponent(params.login) + const password = encodeURIComponent(params.password) + + const url = `${base}/ewws/EWLogin?$KB=${kb}&$login=${login}&$password=${password}` + const response = await secureFetchWithPinnedIP(url, resolvedIP, { method: 'POST' }) + + if (!response.ok) { + const errorText = await response.text() + throw new Error(`Agiloft login failed: ${response.status} - ${errorText}`) + } + + const data = (await response.json()) as { access_token?: string } + const token = data.access_token + + if (!token) { + throw new Error('Agiloft login did not return an access token') + } + + return token +} + +/** + * DNS-pinned variant of agiloftLogout. Best-effort — failures are logged but + * not thrown. + */ +export async function agiloftLogoutPinned( + instanceUrl: string, + knowledgeBase: string, + token: string, + resolvedIP: string +): Promise { + try { + const base = instanceUrl.replace(/\/$/, '') + const kb = encodeURIComponent(knowledgeBase) + await secureFetchWithPinnedIP(`${base}/ewws/EWLogout?$KB=${kb}`, resolvedIP, { + method: 'POST', + headers: { Authorization: `Bearer ${token}` }, + }) + } catch (error) { + logger.warn('Agiloft logout failed (best-effort)', { error }) + } +} + +export type { SecureFetchResponse } diff --git a/apps/sim/tools/agiloft/utils.test.ts b/apps/sim/tools/agiloft/utils.test.ts new file mode 100644 index 00000000000..b80eb2a33ba --- /dev/null +++ b/apps/sim/tools/agiloft/utils.test.ts @@ -0,0 +1,136 @@ +/** + * @vitest-environment node + */ +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { executeAgiloftRequest } from '@/tools/agiloft/utils' + +const baseParams = { + instanceUrl: 'https://example.agiloft.com', + knowledgeBase: 'demo', + login: 'admin', + password: 'secret', + table: 'contracts', +} + +function mockFetchResponse(body: { ok?: boolean; status?: number; json?: unknown; text?: string }) { + return { + ok: body.ok ?? true, + status: body.status ?? 200, + statusText: '', + headers: new Headers(), + text: async () => body.text ?? '', + json: async () => body.json ?? {}, + } as unknown as Response +} + +const fetchSpy = vi.fn() + +beforeEach(() => { + fetchSpy.mockReset() + vi.stubGlobal('fetch', fetchSpy) +}) + +afterEach(() => { + vi.unstubAllGlobals() +}) + +describe('executeAgiloftRequest', () => { + it('logs in, runs the operation with the bearer token, then logs out', async () => { + fetchSpy + .mockResolvedValueOnce(mockFetchResponse({ json: { access_token: 'tok-1' } })) + .mockResolvedValueOnce(mockFetchResponse({ json: { id: 42, fields: { name: 'foo' } } })) + .mockResolvedValueOnce(mockFetchResponse({})) + + const result = await executeAgiloftRequest( + baseParams, + (base) => ({ + url: `${base}/ewws/REST/demo/contracts/42`, + method: 'GET', + headers: { Accept: 'application/json' }, + }), + async (response) => { + const data = (await response.json()) as { id: number; fields: Record } + return { + success: response.ok, + output: { id: String(data.id), fields: data.fields }, + } + } + ) + + expect(result).toEqual({ success: true, output: { id: '42', fields: { name: 'foo' } } }) + + const calls = fetchSpy.mock.calls + expect(calls).toHaveLength(3) + expect(calls[0][0]).toBe( + 'https://example.agiloft.com/ewws/EWLogin?$KB=demo&$login=admin&$password=secret' + ) + expect(calls[1][0]).toBe('https://example.agiloft.com/ewws/REST/demo/contracts/42') + expect(calls[1][1]).toMatchObject({ + method: 'GET', + headers: { Accept: 'application/json', Authorization: 'Bearer tok-1' }, + }) + expect(calls[2][0]).toBe('https://example.agiloft.com/ewws/EWLogout?$KB=demo') + }) + + it('still calls logout when the operation throws', async () => { + fetchSpy + .mockResolvedValueOnce(mockFetchResponse({ json: { access_token: 'tok-2' } })) + .mockResolvedValueOnce(mockFetchResponse({ ok: false, status: 500 })) + .mockResolvedValueOnce(mockFetchResponse({})) + + await expect( + executeAgiloftRequest( + baseParams, + (base) => ({ url: `${base}/ewws/REST/demo/contracts/42`, method: 'GET' }), + async (response) => { + if (!response.ok) throw new Error('operation failed') + return { success: true, output: {} } + } + ) + ).rejects.toThrow('operation failed') + + expect(fetchSpy).toHaveBeenCalledTimes(3) + expect(fetchSpy.mock.calls[2][0]).toContain('/ewws/EWLogout') + }) + + it('swallows logout failures (best-effort)', async () => { + fetchSpy + .mockResolvedValueOnce(mockFetchResponse({ json: { access_token: 'tok-3' } })) + .mockResolvedValueOnce(mockFetchResponse({ json: { ok: true } })) + .mockRejectedValueOnce(new Error('logout network error')) + + const result = await executeAgiloftRequest( + baseParams, + (base) => ({ url: `${base}/ewws/REST/demo/contracts/42`, method: 'GET' }), + async () => ({ success: true, output: {} }) + ) + + expect(result.success).toBe(true) + }) + + it('throws when login does not return an access token', async () => { + fetchSpy.mockResolvedValueOnce(mockFetchResponse({ json: {} })) + + await expect( + executeAgiloftRequest( + baseParams, + (base) => ({ url: `${base}/ewws/REST/demo/contracts/42`, method: 'GET' }), + async () => ({ success: true, output: {} }) + ) + ).rejects.toThrow('Agiloft login did not return an access token') + + expect(fetchSpy).toHaveBeenCalledTimes(1) + }) + + it('rejects an instance URL that fails synchronous URL validation', async () => { + await expect( + executeAgiloftRequest( + { ...baseParams, instanceUrl: 'not-a-valid-url' }, + (base) => ({ url: `${base}/ewws/REST/demo/contracts/42`, method: 'GET' }), + async () => ({ success: true, output: {} }) + ) + ).rejects.toThrow(/Invalid Agiloft instance URL/) + + expect(fetchSpy).not.toHaveBeenCalled() + }) +}) diff --git a/apps/sim/tools/agiloft/utils.ts b/apps/sim/tools/agiloft/utils.ts index 47184deb5fb..811187ab833 100644 --- a/apps/sim/tools/agiloft/utils.ts +++ b/apps/sim/tools/agiloft/utils.ts @@ -47,7 +47,7 @@ async function agiloftLogin(params: AgiloftBaseParams): Promise { throw new Error(`Agiloft login failed: ${response.status} - ${errorText}`) } - const data = await response.json() + const data = (await response.json()) as { access_token?: string } const token = data.access_token if (!token) { diff --git a/apps/sim/tools/grafana/update_alert_rule.ts b/apps/sim/tools/grafana/update_alert_rule.ts index 9ca23bff773..19f2bf8164d 100644 --- a/apps/sim/tools/grafana/update_alert_rule.ts +++ b/apps/sim/tools/grafana/update_alert_rule.ts @@ -1,3 +1,4 @@ +import { validateExternalUrl } from '@/lib/core/security/input-validation' import { ALERT_RULE_OUTPUT_FIELDS, type GrafanaUpdateAlertRuleParams } from '@/tools/grafana/types' import { mapAlertRule } from '@/tools/grafana/utils' import type { ToolConfig, ToolResponse } from '@/tools/types' @@ -269,6 +270,15 @@ export const updateAlertRuleTool: ToolConfig return { success: true, output: mapAlertRule(data) } }, diff --git a/apps/sim/tools/grafana/update_dashboard.ts b/apps/sim/tools/grafana/update_dashboard.ts index 23449f36830..99a5f4352d3 100644 --- a/apps/sim/tools/grafana/update_dashboard.ts +++ b/apps/sim/tools/grafana/update_dashboard.ts @@ -1,3 +1,4 @@ +import { validateExternalUrl } from '@/lib/core/security/input-validation' import type { GrafanaUpdateDashboardParams } from '@/tools/grafana/types' import type { ToolConfig, ToolResponse } from '@/tools/types' @@ -183,6 +184,15 @@ export const updateDashboardTool: ToolConfig