From 4d3e1b3787629494eebf4bf0fb9aeddea3fea744 Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Sat, 16 May 2026 21:45:46 -0700 Subject: [PATCH 1/6] fix(security): KB fileUrl LFI, MCP/Agiloft SSRF pinning, form OTP, KB authz --- .../app/api/chat/[identifier]/otp/route.ts | 223 +----- .../api/form/[identifier]/otp/route.test.ts | 695 ++++++++++++++++++ .../app/api/form/[identifier]/otp/route.ts | 267 +++++++ apps/sim/app/api/form/utils.test.ts | 10 +- apps/sim/app/api/form/utils.ts | 2 +- .../[id]/documents/upsert/route.test.ts | 111 +++ apps/sim/app/api/knowledge/[id]/route.test.ts | 53 +- apps/sim/app/api/knowledge/[id]/route.ts | 8 +- apps/sim/app/api/knowledge/route.test.ts | 17 + apps/sim/app/api/knowledge/route.ts | 5 + .../api/mcp/servers/test-connection/route.ts | 12 +- .../api/tools/agiloft/attach/route.test.ts | 144 ++++ .../sim/app/api/tools/agiloft/attach/route.ts | 26 +- .../api/tools/agiloft/retrieve/route.test.ts | 163 ++++ .../app/api/tools/agiloft/retrieve/route.ts | 26 +- .../[identifier]/components/email-auth.tsx | 284 +++++++ .../app/form/[identifier]/components/index.ts | 1 + apps/sim/app/form/[identifier]/form.tsx | 5 + apps/sim/hooks/queries/forms.ts | 32 + apps/sim/lib/api/contracts/forms.ts | 38 + .../lib/api/contracts/knowledge/documents.ts | 5 +- .../api/contracts/knowledge/shared.test.ts | 57 ++ .../sim/lib/api/contracts/knowledge/shared.ts | 15 + apps/sim/lib/core/security/otp.ts | 247 +++++++ .../knowledge/documents/document-processor.ts | 9 +- apps/sim/lib/knowledge/service.test.ts | 114 +++ apps/sim/lib/knowledge/service.ts | 106 ++- apps/sim/lib/mcp/client.ts | 37 +- apps/sim/lib/mcp/connection-manager.ts | 4 +- apps/sim/lib/mcp/domain-check.test.ts | 121 ++- apps/sim/lib/mcp/domain-check.ts | 64 +- apps/sim/lib/mcp/pinned-fetch.test.ts | 90 +++ apps/sim/lib/mcp/pinned-fetch.ts | 25 + apps/sim/lib/mcp/service.ts | 67 +- apps/sim/lib/mcp/types.ts | 8 + apps/sim/package.json | 1 + apps/sim/tools/agiloft/utils.test.ts | 212 ++++++ apps/sim/tools/agiloft/utils.ts | 50 +- apps/sim/tools/grafana/update_alert_rule.ts | 8 +- apps/sim/tools/grafana/update_dashboard.ts | 24 +- bun.lock | 1 + scripts/check-api-validation-contracts.ts | 4 +- 42 files changed, 3006 insertions(+), 385 deletions(-) create mode 100644 apps/sim/app/api/form/[identifier]/otp/route.test.ts create mode 100644 apps/sim/app/api/form/[identifier]/otp/route.ts create mode 100644 apps/sim/app/api/knowledge/[id]/documents/upsert/route.test.ts create mode 100644 apps/sim/app/api/tools/agiloft/attach/route.test.ts create mode 100644 apps/sim/app/api/tools/agiloft/retrieve/route.test.ts create mode 100644 apps/sim/app/form/[identifier]/components/email-auth.tsx create mode 100644 apps/sim/lib/api/contracts/knowledge/shared.test.ts create mode 100644 apps/sim/lib/core/security/otp.ts create mode 100644 apps/sim/lib/knowledge/service.test.ts create mode 100644 apps/sim/lib/mcp/pinned-fetch.test.ts create mode 100644 apps/sim/lib/mcp/pinned-fetch.ts create mode 100644 apps/sim/tools/agiloft/utils.test.ts diff --git a/apps/sim/app/api/chat/[identifier]/otp/route.ts b/apps/sim/app/api/chat/[identifier]/otp/route.ts index b2e129b5fa8..c89cd4721e1 100644 --- a/apps/sim/app/api/chat/[identifier]/otp/route.ts +++ b/apps/sim/app/api/chat/[identifier]/otp/route.ts @@ -1,18 +1,24 @@ -import { randomInt } from 'crypto' import { db } from '@sim/db' -import { chat, verification } from '@sim/db/schema' +import { chat } from '@sim/db/schema' import { createLogger } from '@sim/logger' -import { generateId } from '@sim/utils/id' -import { and, eq, gt, isNull } from 'drizzle-orm' +import { and, eq, isNull } from 'drizzle-orm' import type { NextRequest } from 'next/server' import { renderOTPEmail } from '@/components/emails' import { requestChatEmailOtpContract, verifyChatEmailOtpContract } from '@/lib/api/contracts/chats' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' -import { getRedisClient } from '@/lib/core/config/redis' -import type { TokenBucketConfig } from '@/lib/core/rate-limiter' import { RateLimiter } from '@/lib/core/rate-limiter' import { addCorsHeaders, isEmailAllowed } from '@/lib/core/security/deployment' -import { getStorageMethod } from '@/lib/core/storage' +import { + decodeOTPValue, + deleteOTP, + generateOTP, + getOTP, + incrementOTPAttempts, + MAX_OTP_ATTEMPTS, + OTP_EMAIL_RATE_LIMIT, + OTP_IP_RATE_LIMIT, + storeOTP, +} from '@/lib/core/security/otp' import { generateRequestId, getClientIp } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { sendEmail } from '@/lib/messaging/email/mailer' @@ -23,199 +29,6 @@ const logger = createLogger('ChatOtpAPI') const rateLimiter = new RateLimiter() -const OTP_IP_RATE_LIMIT: TokenBucketConfig = { - maxTokens: 10, - refillRate: 10, - refillIntervalMs: 15 * 60_000, -} - -const OTP_EMAIL_RATE_LIMIT: TokenBucketConfig = { - maxTokens: 3, - refillRate: 3, - refillIntervalMs: 15 * 60_000, -} - -function generateOTP(): string { - return randomInt(100000, 1000000).toString() -} - -const OTP_EXPIRY = 15 * 60 // 15 minutes -const OTP_EXPIRY_MS = OTP_EXPIRY * 1000 -const MAX_OTP_ATTEMPTS = 5 - -/** - * OTP values are stored as "code:attempts" (e.g. "654321:0"). - * This keeps the attempt counter in the same key/row as the OTP itself. - */ -function encodeOTPValue(otp: string, attempts: number): string { - return `${otp}:${attempts}` -} - -function decodeOTPValue(value: string): { otp: string; attempts: number } { - const lastColon = value.lastIndexOf(':') - if (lastColon === -1) return { otp: value, attempts: 0 } - const attempts = Number.parseInt(value.slice(lastColon + 1), 10) - return { otp: value.slice(0, lastColon), attempts: Number.isNaN(attempts) ? 0 : attempts } -} - -/** - * Stores OTP in Redis or database depending on storage method. - * Uses the verification table for database storage. - */ -async function storeOTP(email: string, chatId: string, otp: string): Promise { - const identifier = `chat-otp:${chatId}:${email}` - const storageMethod = getStorageMethod() - const value = encodeOTPValue(otp, 0) - - if (storageMethod === 'redis') { - const redis = getRedisClient() - if (!redis) { - throw new Error('Redis configured but client unavailable') - } - await redis.set(`otp:${email}:${chatId}`, value, 'EX', OTP_EXPIRY) - } else { - const now = new Date() - const expiresAt = new Date(now.getTime() + OTP_EXPIRY_MS) - - await db.transaction(async (tx) => { - await tx.delete(verification).where(eq(verification.identifier, identifier)) - await tx.insert(verification).values({ - id: generateId(), - identifier, - value, - expiresAt, - createdAt: now, - updatedAt: now, - }) - }) - } -} - -async function getOTP(email: string, chatId: string): Promise { - const identifier = `chat-otp:${chatId}:${email}` - const storageMethod = getStorageMethod() - - if (storageMethod === 'redis') { - const redis = getRedisClient() - if (!redis) { - throw new Error('Redis configured but client unavailable') - } - return redis.get(`otp:${email}:${chatId}`) - } - - const now = new Date() - const [record] = await db - .select({ value: verification.value }) - .from(verification) - .where(and(eq(verification.identifier, identifier), gt(verification.expiresAt, now))) - .limit(1) - - return record?.value ?? null -} - -/** - * Lua script for atomic OTP attempt increment. - * Returns: "LOCKED" if max attempts reached (key deleted), new encoded value otherwise, nil if key missing. - */ -const ATOMIC_INCREMENT_SCRIPT = ` -local val = redis.call('GET', KEYS[1]) -if not val then return nil end -local colon = val:find(':([^:]*$)') -local otp, attempts -if colon then - otp = val:sub(1, colon - 1) - attempts = tonumber(val:sub(colon + 1)) or 0 -else - otp = val - attempts = 0 -end -attempts = attempts + 1 -if attempts >= tonumber(ARGV[1]) then - redis.call('DEL', KEYS[1]) - return 'LOCKED' -end -local newVal = otp .. ':' .. attempts -local ttl = redis.call('TTL', KEYS[1]) -if ttl > 0 then - redis.call('SET', KEYS[1], newVal, 'EX', ttl) -else - redis.call('SET', KEYS[1], newVal) -end -return newVal -` - -/** - * Atomically increments OTP attempts. Returns 'locked' if max reached, 'incremented' otherwise. - */ -async function incrementOTPAttempts( - email: string, - chatId: string, - currentValue: string -): Promise<'locked' | 'incremented'> { - const identifier = `chat-otp:${chatId}:${email}` - const storageMethod = getStorageMethod() - - if (storageMethod === 'redis') { - const redis = getRedisClient() - if (!redis) { - throw new Error('Redis configured but client unavailable') - } - const key = `otp:${email}:${chatId}` - const result = await redis.eval(ATOMIC_INCREMENT_SCRIPT, 1, key, MAX_OTP_ATTEMPTS) - if (result === null || result === 'LOCKED') return 'locked' - return 'incremented' - } - - // DB path: optimistic locking with retry on conflict - const MAX_RETRIES = 3 - let value = currentValue - - for (let attempt = 0; attempt < MAX_RETRIES; attempt++) { - const { otp, attempts } = decodeOTPValue(value) - const newAttempts = attempts + 1 - - if (newAttempts >= MAX_OTP_ATTEMPTS) { - await db.delete(verification).where(eq(verification.identifier, identifier)) - return 'locked' - } - - const newValue = encodeOTPValue(otp, newAttempts) - const updated = await db - .update(verification) - .set({ value: newValue, updatedAt: new Date() }) - .where(and(eq(verification.identifier, identifier), eq(verification.value, value))) - .returning({ id: verification.id }) - - if (updated.length > 0) return 'incremented' - - // Conflict: another request already incremented — re-read and retry - const fresh = await getOTP(email, chatId) - if (!fresh) return 'locked' - value = fresh - } - - // Exhausted retries — re-read final state to determine outcome - const final = await getOTP(email, chatId) - if (!final) return 'locked' - const { attempts: finalAttempts } = decodeOTPValue(final) - return finalAttempts >= MAX_OTP_ATTEMPTS ? 'locked' : 'incremented' -} - -async function deleteOTP(email: string, chatId: string): Promise { - const identifier = `chat-otp:${chatId}:${email}` - const storageMethod = getStorageMethod() - - if (storageMethod === 'redis') { - const redis = getRedisClient() - if (!redis) { - throw new Error('Redis configured but client unavailable') - } - await redis.del(`otp:${email}:${chatId}`) - } else { - await db.delete(verification).where(eq(verification.identifier, identifier)) - } -} - export const POST = withRouteHandler( async (request: NextRequest, context: { params: Promise<{ identifier: string }> }) => { const { identifier } = await context.params @@ -305,7 +118,7 @@ export const POST = withRouteHandler( } const otp = generateOTP() - await storeOTP(email, deployment.id, otp) + await storeOTP('chat', deployment.id, email, otp) const emailHtml = await renderOTPEmail( otp, @@ -379,7 +192,7 @@ export const PUT = withRouteHandler( const deployment = deploymentResult[0] - const storedValue = await getOTP(email, deployment.id) + const storedValue = await getOTP('chat', deployment.id, email) if (!storedValue) { return addCorsHeaders( createErrorResponse('No verification code found, request a new one', 400), @@ -390,7 +203,7 @@ export const PUT = withRouteHandler( const { otp: storedOTP, attempts } = decodeOTPValue(storedValue) if (attempts >= MAX_OTP_ATTEMPTS) { - await deleteOTP(email, deployment.id) + await deleteOTP('chat', deployment.id, email) logger.warn(`[${requestId}] OTP already locked out for ${email}`) return addCorsHeaders( createErrorResponse('Too many failed attempts. Please request a new code.', 429), @@ -399,7 +212,7 @@ export const PUT = withRouteHandler( } if (storedOTP !== otp) { - const result = await incrementOTPAttempts(email, deployment.id, storedValue) + const result = await incrementOTPAttempts('chat', deployment.id, email, storedValue) if (result === 'locked') { logger.warn(`[${requestId}] OTP invalidated after max failed attempts for ${email}`) return addCorsHeaders( @@ -410,7 +223,7 @@ export const PUT = withRouteHandler( return addCorsHeaders(createErrorResponse('Invalid verification code', 400), request) } - await deleteOTP(email, deployment.id) + await deleteOTP('chat', deployment.id, email) const response = addCorsHeaders( createSuccessResponse({ diff --git a/apps/sim/app/api/form/[identifier]/otp/route.test.ts b/apps/sim/app/api/form/[identifier]/otp/route.test.ts new file mode 100644 index 00000000000..4b3b13441d0 --- /dev/null +++ b/apps/sim/app/api/form/[identifier]/otp/route.test.ts @@ -0,0 +1,695 @@ +/** + * Tests for form OTP API route + * + * @vitest-environment node + */ +import { + redisConfigMock, + redisConfigMockFns, + requestUtilsMockFns, + workflowsApiUtilsMock, + workflowsApiUtilsMockFns, +} from '@sim/testing' +import { NextRequest } from 'next/server' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockRedisSet, + mockRedisGet, + mockRedisDel, + mockRedisTtl, + mockRedisEval, + mockRedisClient, + mockDbSelect, + mockDbInsert, + mockDbDelete, + mockDbUpdate, + mockSendEmail, + mockRenderOTPEmail, + mockAddCorsHeaders, + mockSetFormAuthCookie, + mockGetStorageMethod, + mockZodParse, + mockGetEnv, +} = vi.hoisted(() => { + const mockRedisSet = vi.fn() + const mockRedisGet = vi.fn() + const mockRedisDel = vi.fn() + const mockRedisTtl = vi.fn() + const mockRedisEval = vi.fn() + const mockRedisClient = { + set: mockRedisSet, + get: mockRedisGet, + del: mockRedisDel, + ttl: mockRedisTtl, + eval: mockRedisEval, + } + return { + mockRedisSet, + mockRedisGet, + mockRedisDel, + mockRedisTtl, + mockRedisEval, + mockRedisClient, + mockDbSelect: vi.fn(), + mockDbInsert: vi.fn(), + mockDbDelete: vi.fn(), + mockDbUpdate: vi.fn(), + mockSendEmail: vi.fn(), + mockRenderOTPEmail: vi.fn(), + mockAddCorsHeaders: vi.fn(), + mockSetFormAuthCookie: vi.fn(), + mockGetStorageMethod: vi.fn(), + mockZodParse: vi.fn(), + mockGetEnv: vi.fn(), + } +}) + +const mockGetRedisClient = redisConfigMockFns.mockGetRedisClient +const mockCreateSuccessResponse = workflowsApiUtilsMockFns.mockCreateSuccessResponse +const mockCreateErrorResponse = workflowsApiUtilsMockFns.mockCreateErrorResponse + +vi.mock('@/lib/core/config/redis', () => redisConfigMock) + +vi.mock('@sim/db', () => ({ + db: { + select: mockDbSelect, + insert: mockDbInsert, + delete: mockDbDelete, + update: mockDbUpdate, + transaction: vi.fn(async (callback: (tx: Record) => unknown) => { + return callback({ + select: mockDbSelect, + insert: mockDbInsert, + delete: mockDbDelete, + update: mockDbUpdate, + }) + }), + }, +})) + +vi.mock('drizzle-orm', () => ({ + eq: vi.fn((field: string, value: string) => ({ field, value, type: 'eq' })), + and: vi.fn((...conditions: unknown[]) => ({ conditions, type: 'and' })), + gt: vi.fn((field: string, value: string) => ({ field, value, type: 'gt' })), + lt: vi.fn((field: string, value: string) => ({ field, value, type: 'lt' })), + isNull: vi.fn((field: unknown) => ({ field, type: 'isNull' })), +})) + +vi.mock('@/lib/core/storage', () => ({ + getStorageMethod: mockGetStorageMethod, +})) + +const { mockCheckRateLimitDirect } = vi.hoisted(() => ({ + mockCheckRateLimitDirect: vi.fn(), +})) + +vi.mock('@/lib/core/rate-limiter', () => ({ + RateLimiter: class { + checkRateLimitDirect = mockCheckRateLimitDirect + }, +})) + +vi.mock('@/lib/messaging/email/mailer', () => ({ + sendEmail: mockSendEmail, +})) + +vi.mock('@/components/emails', () => ({ + renderOTPEmail: mockRenderOTPEmail, +})) + +vi.mock('@/lib/core/security/deployment', () => ({ + addCorsHeaders: mockAddCorsHeaders, + isEmailAllowed: (email: string, allowedEmails: string[]) => { + if (allowedEmails.includes(email)) return true + const atIndex = email.indexOf('@') + if (atIndex > 0) { + const domain = email.substring(atIndex + 1) + if (domain && allowedEmails.some((allowed: string) => allowed === `@${domain}`)) return true + } + return false + }, +})) + +vi.mock('@/app/api/form/utils', () => ({ + setFormAuthCookie: mockSetFormAuthCookie, +})) + +vi.mock('@/app/api/workflows/utils', () => workflowsApiUtilsMock) + +vi.mock('@/lib/core/config/env', () => ({ + env: { + NEXT_PUBLIC_APP_URL: 'http://localhost:3000', + NODE_ENV: 'test', + }, + getEnv: mockGetEnv, + isTruthy: vi.fn().mockReturnValue(false), + isFalsy: vi.fn().mockReturnValue(true), +})) + +vi.mock('zod', () => { + class ZodError extends Error { + errors: Array<{ message: string }> + constructor(issues: Array<{ message: string }>) { + super('ZodError') + this.errors = issues + } + } + const chainable: Record = {} + const proxy: Record = new Proxy(chainable, { + get(target, prop) { + if (prop === 'parse') return mockZodParse + if (prop === 'safeParse') { + return (data: unknown) => ({ success: true, data }) + } + if (prop === 'then') return undefined + if (typeof prop === 'symbol') return Reflect.get(target, prop) + if (!(prop in target)) { + target[prop as string] = vi.fn().mockReturnValue(proxy) + } + return target[prop as string] + }, + }) + const makeChain = vi.fn(() => proxy) + return { + z: new Proxy( + { ZodError }, + { + get(target, prop) { + if (prop === 'ZodError') return ZodError + if (typeof prop === 'symbol') return Reflect.get(target, prop) + return makeChain + }, + } + ), + } +}) + +import { POST, PUT } from './route' + +describe('Form OTP API Route', () => { + const mockEmail = 'user@example.com' + const mockFormId = 'form-123' + const mockIdentifier = 'test-form' + const mockOTP = '123456' + + const deploymentRow = { + id: mockFormId, + authType: 'email', + allowedEmails: [mockEmail], + title: 'Test Form', + isActive: true, + } + + const verifyDeploymentRow = { + id: mockFormId, + authType: 'email', + password: null, + allowedEmails: [mockEmail], + isActive: true, + } + + const selectOnce = (rows: unknown[]) => + mockDbSelect.mockImplementationOnce(() => ({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue(rows), + }), + }), + })) + + beforeEach(() => { + vi.clearAllMocks() + + vi.spyOn(Math, 'random').mockReturnValue(0.123456) + vi.spyOn(Date, 'now').mockReturnValue(1640995200000) + + vi.stubGlobal('crypto', { + ...crypto, + randomUUID: vi.fn().mockReturnValue('test-uuid-1234'), + }) + + mockGetRedisClient.mockReturnValue(mockRedisClient) + mockRedisSet.mockResolvedValue('OK') + mockRedisGet.mockResolvedValue(null) + mockRedisDel.mockResolvedValue(1) + mockRedisTtl.mockResolvedValue(600) + + mockDbSelect.mockImplementation(() => ({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockResolvedValue([]), + }), + }), + })) + mockDbInsert.mockImplementation(() => ({ values: vi.fn().mockResolvedValue(undefined) })) + mockDbDelete.mockImplementation(() => ({ where: vi.fn().mockResolvedValue(undefined) })) + mockDbUpdate.mockImplementation(() => ({ + set: vi.fn().mockReturnValue({ where: vi.fn().mockResolvedValue(undefined) }), + })) + + mockGetStorageMethod.mockReturnValue('redis') + + mockSendEmail.mockResolvedValue({ success: true }) + mockRenderOTPEmail.mockResolvedValue('OTP Email') + + mockAddCorsHeaders.mockImplementation((response: unknown) => response) + mockCreateSuccessResponse.mockImplementation((data: unknown) => ({ + json: () => Promise.resolve(data), + status: 200, + })) + mockCreateErrorResponse.mockImplementation((message: string, status: number) => ({ + json: () => Promise.resolve({ error: message }), + status, + })) + + requestUtilsMockFns.mockGenerateRequestId.mockReturnValue('req-123') + requestUtilsMockFns.mockGetClientIp.mockReturnValue('1.2.3.4') + + mockCheckRateLimitDirect.mockResolvedValue({ + allowed: true, + remaining: 10, + resetAt: new Date(Date.now() + 60_000), + }) + + mockZodParse.mockImplementation((data: unknown) => data) + mockGetEnv.mockReturnValue('http://localhost:3000') + }) + + afterEach(() => { + vi.restoreAllMocks() + }) + + describe('POST /otp - request code', () => { + it('stores OTP in Redis when storage is redis and sends email', async () => { + selectOnce([deploymentRow]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockRedisSet).toHaveBeenCalledWith( + `form-otp:${mockEmail}:${mockFormId}`, + expect.stringMatching(/^\d{6}:0$/), + 'EX', + 900 + ) + expect(mockSendEmail).toHaveBeenCalledWith( + expect.objectContaining({ to: mockEmail, subject: expect.stringContaining('Test Form') }) + ) + expect(mockDbInsert).not.toHaveBeenCalled() + }) + + it('stores OTP in database when storage is database', async () => { + mockGetStorageMethod.mockReturnValue('database') + mockGetRedisClient.mockReturnValue(null) + selectOnce([deploymentRow]) + const insertValues = vi.fn().mockResolvedValue(undefined) + mockDbInsert.mockImplementationOnce(() => ({ values: insertValues })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(insertValues).toHaveBeenCalledWith( + expect.objectContaining({ + identifier: `form-otp:${mockFormId}:${mockEmail}`, + value: expect.stringMatching(/^\d{6}:0$/), + }) + ) + expect(mockRedisSet).not.toHaveBeenCalled() + }) + + it('returns 404 when form is not found', async () => { + selectOnce([]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith('Form not found', 404) + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('returns 403 when form is inactive', async () => { + selectOnce([{ ...deploymentRow, isActive: false }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'This form is currently unavailable', + 403 + ) + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('returns 400 when form authType is not email', async () => { + selectOnce([{ ...deploymentRow, authType: 'public' }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'This form does not use email authentication', + 400 + ) + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('returns 403 when email is not in allowedEmails', async () => { + selectOnce([{ ...deploymentRow, allowedEmails: ['other@example.com'] }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'Email not authorized for this form', + 403 + ) + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('authorizes by domain match in allowedEmails', async () => { + selectOnce([{ ...deploymentRow, allowedEmails: ['@example.com'] }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockSendEmail).toHaveBeenCalled() + }) + + it('returns 429 with Retry-After when IP rate limit is exceeded', async () => { + mockCheckRateLimitDirect.mockResolvedValueOnce({ + allowed: false, + remaining: 0, + resetAt: new Date(Date.now() + 900_000), + retryAfterMs: 900_000, + }) + const headerSet = vi.fn() + mockCreateErrorResponse.mockImplementationOnce((message: string, status: number) => ({ + json: () => Promise.resolve({ error: message }), + status, + headers: { set: headerSet }, + })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + const response = await POST(request, { + params: Promise.resolve({ identifier: mockIdentifier }), + }) + + expect(response.status).toBe(429) + expect(headerSet).toHaveBeenCalledWith('Retry-After', '900') + expect(mockSendEmail).not.toHaveBeenCalled() + expect(mockDbSelect).not.toHaveBeenCalled() + }) + + it('returns 429 with Retry-After when email rate limit is exceeded', async () => { + mockCheckRateLimitDirect + .mockResolvedValueOnce({ + allowed: true, + remaining: 9, + resetAt: new Date(Date.now() + 60_000), + }) + .mockResolvedValueOnce({ + allowed: false, + remaining: 0, + resetAt: new Date(Date.now() + 900_000), + retryAfterMs: 900_000, + }) + const headerSet = vi.fn() + mockCreateErrorResponse.mockImplementationOnce((message: string, status: number) => ({ + json: () => Promise.resolve({ error: message }), + status, + headers: { set: headerSet }, + })) + selectOnce([deploymentRow]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + const response = await POST(request, { + params: Promise.resolve({ identifier: mockIdentifier }), + }) + + expect(response.status).toBe(429) + expect(headerSet).toHaveBeenCalledWith('Retry-After', '900') + expect(mockSendEmail).not.toHaveBeenCalled() + }) + + it('rate-limits the IP bucket before reading the deployment row', async () => { + mockCheckRateLimitDirect.mockResolvedValueOnce({ + allowed: false, + remaining: 0, + resetAt: new Date(Date.now() + 900_000), + retryAfterMs: 900_000, + }) + mockCreateErrorResponse.mockImplementationOnce((message: string, status: number) => ({ + json: () => Promise.resolve({ error: message }), + status, + headers: { set: vi.fn() }, + })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockDbSelect).not.toHaveBeenCalled() + }) + + it('returns 500 when email send fails', async () => { + selectOnce([deploymentRow]) + mockSendEmail.mockResolvedValueOnce({ success: false, message: 'smtp down' }) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'POST', + body: JSON.stringify({ email: mockEmail }), + }) + + await POST(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith('Failed to send verification email', 500) + }) + }) + + describe('PUT /otp - verify code', () => { + it('verifies OTP, deletes it, and sets the form auth cookie on success', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue(`${mockOTP}:0`) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockRedisGet).toHaveBeenCalledWith(`form-otp:${mockEmail}:${mockFormId}`) + expect(mockRedisDel).toHaveBeenCalledWith(`form-otp:${mockEmail}:${mockFormId}`) + expect(mockSetFormAuthCookie).toHaveBeenCalledWith( + expect.any(Object), + mockFormId, + 'email', + null + ) + expect(mockCreateSuccessResponse).toHaveBeenCalledWith({ authenticated: true }) + }) + + it('returns 404 when form is not found', async () => { + selectOnce([]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith('Form not found', 404) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('returns 403 when form is inactive at verify time', async () => { + selectOnce([{ ...verifyDeploymentRow, isActive: false }]) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'This form is currently unavailable', + 403 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('returns 403 when email is no longer in allowedEmails at verify time', async () => { + selectOnce([{ ...verifyDeploymentRow, allowedEmails: ['other@example.com'] }]) + mockRedisGet.mockResolvedValue(`${mockOTP}:0`) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'Email not authorized for this form', + 403 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('returns 400 when no OTP is stored', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue(null) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'No verification code found, request a new one', + 400 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('atomically increments attempts on wrong OTP and returns 400', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue('654321:0') + mockRedisEval.mockResolvedValue('654321:1') + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: 'wrong1' }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockRedisEval).toHaveBeenCalledWith( + expect.any(String), + 1, + `form-otp:${mockEmail}:${mockFormId}`, + 5 + ) + expect(mockCreateErrorResponse).toHaveBeenCalledWith('Invalid verification code', 400) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('invalidates OTP and returns 429 after max failed attempts', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue('654321:4') + mockRedisEval.mockResolvedValue('LOCKED') + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: 'wrong5' }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'Too many failed attempts. Please request a new code.', + 429 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('rejects when stored OTP is already at max attempts', async () => { + selectOnce([verifyDeploymentRow]) + mockRedisGet.mockResolvedValue(`${mockOTP}:5`) + const deleteWhere = vi.fn().mockResolvedValue(undefined) + mockDbDelete.mockImplementation(() => ({ where: deleteWhere })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockCreateErrorResponse).toHaveBeenCalledWith( + 'Too many failed attempts. Please request a new code.', + 429 + ) + expect(mockSetFormAuthCookie).not.toHaveBeenCalled() + }) + + it('uses database storage path when configured', async () => { + mockGetStorageMethod.mockReturnValue('database') + mockGetRedisClient.mockReturnValue(null) + let selectCallCount = 0 + mockDbSelect.mockImplementation(() => ({ + from: vi.fn().mockReturnValue({ + where: vi.fn().mockReturnValue({ + limit: vi.fn().mockImplementation(() => { + selectCallCount++ + if (selectCallCount === 1) return Promise.resolve([verifyDeploymentRow]) + return Promise.resolve([ + { + value: `${mockOTP}:0`, + expiresAt: new Date(Date.now() + 10 * 60 * 1000), + }, + ]) + }), + }), + }), + })) + const deleteWhere = vi.fn().mockResolvedValue(undefined) + mockDbDelete.mockImplementation(() => ({ where: deleteWhere })) + + const request = new NextRequest('http://localhost:3000/api/form/test/otp', { + method: 'PUT', + body: JSON.stringify({ email: mockEmail, otp: mockOTP }), + }) + + await PUT(request, { params: Promise.resolve({ identifier: mockIdentifier }) }) + + expect(mockDbDelete).toHaveBeenCalled() + expect(mockRedisDel).not.toHaveBeenCalled() + expect(mockSetFormAuthCookie).toHaveBeenCalled() + }) + }) +}) diff --git a/apps/sim/app/api/form/[identifier]/otp/route.ts b/apps/sim/app/api/form/[identifier]/otp/route.ts new file mode 100644 index 00000000000..176a3be50b6 --- /dev/null +++ b/apps/sim/app/api/form/[identifier]/otp/route.ts @@ -0,0 +1,267 @@ +import { db } from '@sim/db' +import { form } from '@sim/db/schema' +import { createLogger } from '@sim/logger' +import { and, eq, isNull } from 'drizzle-orm' +import type { NextRequest } from 'next/server' +import { renderOTPEmail } from '@/components/emails' +import { requestFormEmailOtpContract, verifyFormEmailOtpContract } from '@/lib/api/contracts/forms' +import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' +import { RateLimiter } from '@/lib/core/rate-limiter' +import { addCorsHeaders, isEmailAllowed } from '@/lib/core/security/deployment' +import { + decodeOTPValue, + deleteOTP, + generateOTP, + getOTP, + incrementOTPAttempts, + MAX_OTP_ATTEMPTS, + OTP_EMAIL_RATE_LIMIT, + OTP_IP_RATE_LIMIT, + storeOTP, +} from '@/lib/core/security/otp' +import { generateRequestId, getClientIp } from '@/lib/core/utils/request' +import { withRouteHandler } from '@/lib/core/utils/with-route-handler' +import { sendEmail } from '@/lib/messaging/email/mailer' +import { setFormAuthCookie } from '@/app/api/form/utils' +import { createErrorResponse, createSuccessResponse } from '@/app/api/workflows/utils' + +const logger = createLogger('FormOtpAPI') + +const rateLimiter = new RateLimiter() + +export const POST = withRouteHandler( + async (request: NextRequest, context: { params: Promise<{ identifier: string }> }) => { + const { identifier } = await context.params + const requestId = generateRequestId() + + try { + const ip = getClientIp(request) + const ipRateLimit = await rateLimiter.checkRateLimitDirect( + `form-otp:ip:${identifier}:${ip}`, + OTP_IP_RATE_LIMIT + ) + if (!ipRateLimit.allowed) { + logger.warn(`[${requestId}] OTP IP rate limit exceeded for ${identifier} from ${ip}`) + const retryAfter = Math.ceil( + (ipRateLimit.retryAfterMs ?? OTP_IP_RATE_LIMIT.refillIntervalMs) / 1000 + ) + const response = createErrorResponse('Too many requests. Please try again later.', 429) + response.headers.set('Retry-After', String(retryAfter)) + return addCorsHeaders(response, request) + } + + const parsed = await parseRequest(requestFormEmailOtpContract, request, context, { + validationErrorResponse: (error) => + addCorsHeaders( + createErrorResponse(getValidationErrorMessage(error, 'Invalid request'), 400), + request + ), + }) + if (!parsed.success) return parsed.response + const { email } = parsed.data.body + + const deploymentResult = await db + .select({ + id: form.id, + authType: form.authType, + allowedEmails: form.allowedEmails, + title: form.title, + isActive: form.isActive, + }) + .from(form) + .where(and(eq(form.identifier, identifier), isNull(form.archivedAt))) + .limit(1) + + if (deploymentResult.length === 0) { + logger.warn(`[${requestId}] Form not found for identifier: ${identifier}`) + return addCorsHeaders(createErrorResponse('Form not found', 404), request) + } + + const deployment = deploymentResult[0] + + if (!deployment.isActive) { + return addCorsHeaders( + createErrorResponse('This form is currently unavailable', 403), + request + ) + } + + if (deployment.authType !== 'email') { + return addCorsHeaders( + createErrorResponse('This form does not use email authentication', 400), + request + ) + } + + const allowedEmails: string[] = Array.isArray(deployment.allowedEmails) + ? (deployment.allowedEmails as string[]) + : [] + + if (!isEmailAllowed(email, allowedEmails)) { + return addCorsHeaders( + createErrorResponse('Email not authorized for this form', 403), + request + ) + } + + const emailRateLimit = await rateLimiter.checkRateLimitDirect( + `form-otp:email:${deployment.id}:${email.toLowerCase()}`, + OTP_EMAIL_RATE_LIMIT + ) + if (!emailRateLimit.allowed) { + logger.warn( + `[${requestId}] OTP email rate limit exceeded for ${email} on form ${deployment.id}` + ) + const retryAfter = Math.ceil( + (emailRateLimit.retryAfterMs ?? OTP_EMAIL_RATE_LIMIT.refillIntervalMs) / 1000 + ) + const response = createErrorResponse( + 'Too many verification code requests. Please try again later.', + 429 + ) + response.headers.set('Retry-After', String(retryAfter)) + return addCorsHeaders(response, request) + } + + const otp = generateOTP() + await storeOTP('form', deployment.id, email, otp) + + const emailHtml = await renderOTPEmail( + otp, + email, + 'email-verification', + deployment.title || 'Form' + ) + + const emailResult = await sendEmail({ + to: email, + subject: `Verification code for ${deployment.title || 'Form'}`, + html: emailHtml, + }) + + if (!emailResult.success) { + logger.error(`[${requestId}] Failed to send OTP email:`, emailResult.message) + return addCorsHeaders( + createErrorResponse('Failed to send verification email', 500), + request + ) + } + + logger.info(`[${requestId}] OTP sent to ${email} for form ${deployment.id}`) + return addCorsHeaders(createSuccessResponse({ message: 'Verification code sent' }), request) + } catch (error: any) { + logger.error(`[${requestId}] Error processing OTP request:`, error) + return addCorsHeaders( + createErrorResponse(error.message || 'Failed to process request', 500), + request + ) + } + } +) + +export const PUT = withRouteHandler( + async (request: NextRequest, context: { params: Promise<{ identifier: string }> }) => { + const { identifier } = await context.params + const requestId = generateRequestId() + + try { + const parsed = await parseRequest(verifyFormEmailOtpContract, request, context, { + validationErrorResponse: (error) => + addCorsHeaders( + createErrorResponse(getValidationErrorMessage(error, 'Invalid request'), 400), + request + ), + }) + if (!parsed.success) return parsed.response + const { email, otp } = parsed.data.body + + const deploymentResult = await db + .select({ + id: form.id, + authType: form.authType, + password: form.password, + allowedEmails: form.allowedEmails, + isActive: form.isActive, + }) + .from(form) + .where(and(eq(form.identifier, identifier), isNull(form.archivedAt))) + .limit(1) + + if (deploymentResult.length === 0) { + logger.warn(`[${requestId}] Form not found for identifier: ${identifier}`) + return addCorsHeaders(createErrorResponse('Form not found', 404), request) + } + + const deployment = deploymentResult[0] + + if (!deployment.isActive) { + return addCorsHeaders( + createErrorResponse('This form is currently unavailable', 403), + request + ) + } + + if (deployment.authType !== 'email') { + return addCorsHeaders( + createErrorResponse('This form does not use email authentication', 400), + request + ) + } + + const allowedEmails: string[] = Array.isArray(deployment.allowedEmails) + ? (deployment.allowedEmails as string[]) + : [] + + if (!isEmailAllowed(email, allowedEmails)) { + return addCorsHeaders( + createErrorResponse('Email not authorized for this form', 403), + request + ) + } + + const storedValue = await getOTP('form', deployment.id, email) + if (!storedValue) { + return addCorsHeaders( + createErrorResponse('No verification code found, request a new one', 400), + request + ) + } + + const { otp: storedOTP, attempts } = decodeOTPValue(storedValue) + + if (attempts >= MAX_OTP_ATTEMPTS) { + await deleteOTP('form', deployment.id, email) + logger.warn(`[${requestId}] OTP already locked out for ${email}`) + return addCorsHeaders( + createErrorResponse('Too many failed attempts. Please request a new code.', 429), + request + ) + } + + if (storedOTP !== otp) { + const result = await incrementOTPAttempts('form', deployment.id, email, storedValue) + if (result === 'locked') { + logger.warn(`[${requestId}] OTP invalidated after max failed attempts for ${email}`) + return addCorsHeaders( + createErrorResponse('Too many failed attempts. Please request a new code.', 429), + request + ) + } + return addCorsHeaders(createErrorResponse('Invalid verification code', 400), request) + } + + await deleteOTP('form', deployment.id, email) + + const response = addCorsHeaders(createSuccessResponse({ authenticated: true }), request) + setFormAuthCookie(response, deployment.id, deployment.authType, deployment.password) + + return response + } catch (error: any) { + logger.error(`[${requestId}] Error verifying OTP:`, error) + return addCorsHeaders( + createErrorResponse(error.message || 'Failed to process request', 500), + request + ) + } + } +) diff --git a/apps/sim/app/api/form/utils.test.ts b/apps/sim/app/api/form/utils.test.ts index 9c36ccc6e92..1826d9386c1 100644 --- a/apps/sim/app/api/form/utils.test.ts +++ b/apps/sim/app/api/form/utils.test.ts @@ -239,18 +239,20 @@ describe('Form API Utils', () => { }, } as any - // Exact email match should authorize + // Exact email match should require OTP verification, not authorize directly mockIsEmailAllowed.mockReturnValue(true) const result1 = await validateFormAuth('request-id', deployment, mockRequest, { email: 'user@example.com', }) - expect(result1.authorized).toBe(true) + expect(result1.authorized).toBe(false) + expect(result1.error).toBe('otp_required') - // Domain match should authorize + // Domain match should also require OTP verification const result2 = await validateFormAuth('request-id', deployment, mockRequest, { email: 'other@company.com', }) - expect(result2.authorized).toBe(true) + expect(result2.authorized).toBe(false) + expect(result2.error).toBe('otp_required') // Unknown email should not authorize mockIsEmailAllowed.mockReturnValue(false) diff --git a/apps/sim/app/api/form/utils.ts b/apps/sim/app/api/form/utils.ts index 55bbe65e17f..7b1f1df54dc 100644 --- a/apps/sim/app/api/form/utils.ts +++ b/apps/sim/app/api/form/utils.ts @@ -159,7 +159,7 @@ export async function validateFormAuth( const allowedEmails: string[] = deployment.allowedEmails || [] if (isEmailAllowed(email, allowedEmails)) { - return { authorized: true } + return { authorized: false, error: 'otp_required' } } return { authorized: false, error: 'Email not authorized for this form' } diff --git a/apps/sim/app/api/knowledge/[id]/documents/upsert/route.test.ts b/apps/sim/app/api/knowledge/[id]/documents/upsert/route.test.ts new file mode 100644 index 00000000000..d5f64cf306e --- /dev/null +++ b/apps/sim/app/api/knowledge/[id]/documents/upsert/route.test.ts @@ -0,0 +1,111 @@ +/** + * Tests for knowledge base document upsert API route + * + * @vitest-environment node + */ +import { + auditMock, + createMockRequest, + hybridAuthMock, + hybridAuthMockFns, + knowledgeApiUtilsMock, +} from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockDbChain } = vi.hoisted(() => { + const chain = { + select: vi.fn().mockReturnThis(), + from: vi.fn().mockReturnThis(), + where: vi.fn().mockReturnThis(), + limit: vi.fn().mockResolvedValue([]), + } + return { mockDbChain: chain } +}) + +vi.mock('@sim/db', () => ({ db: mockDbChain })) +vi.mock('@/lib/auth/hybrid', () => hybridAuthMock) +vi.mock('@/app/api/knowledge/utils', () => knowledgeApiUtilsMock) +vi.mock('@sim/audit', () => auditMock) + +vi.mock('@/lib/knowledge/documents/service', () => ({ + createDocumentRecords: vi.fn(), + deleteDocument: vi.fn(), + getProcessingConfig: vi.fn().mockReturnValue({ maxConcurrentDocuments: 1, batchSize: 1 }), + processDocumentsWithQueue: vi.fn(), +})) + +import { createDocumentRecords, processDocumentsWithQueue } from '@/lib/knowledge/documents/service' +import { POST } from '@/app/api/knowledge/[id]/documents/upsert/route' +import { checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils' + +describe('POST /api/knowledge/[id]/documents/upsert', () => { + const params = Promise.resolve({ id: 'kb-123' }) + + beforeEach(() => { + vi.clearAllMocks() + mockDbChain.select.mockReturnThis() + mockDbChain.from.mockReturnThis() + mockDbChain.where.mockReturnThis() + mockDbChain.limit.mockResolvedValue([]) + + hybridAuthMockFns.mockCheckSessionOrInternalAuth.mockResolvedValue({ + success: true, + userId: 'user-1', + authType: 'session', + userName: 'Test User', + userEmail: 'test@example.com', + }) + + vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({ + hasAccess: true, + knowledgeBase: { id: 'kb-123', userId: 'user-1', workspaceId: 'ws-1', name: 'KB' }, + } as any) + + vi.mocked(createDocumentRecords).mockResolvedValue([ + { documentId: 'doc-new', filename: 'note.txt' }, + ] as any) + vi.mocked(processDocumentsWithQueue).mockResolvedValue(undefined as any) + }) + + const baseBody = { + filename: 'note.txt', + fileSize: 11, + mimeType: 'text/plain', + } + + it('accepts a data: URI', async () => { + const req = createMockRequest('POST', { + ...baseBody, + fileUrl: 'data:text/plain;base64,SGVsbG8gd29ybGQ=', + }) + const res = await POST(req, { params }) + expect(res.status).toBe(200) + expect(createDocumentRecords).toHaveBeenCalled() + }) + + it('accepts an https URL', async () => { + const req = createMockRequest('POST', { + ...baseBody, + fileUrl: 'https://example.com/note.txt', + }) + const res = await POST(req, { params }) + expect(res.status).toBe(200) + expect(createDocumentRecords).toHaveBeenCalled() + }) + + it.each([ + ['absolute local path', '/etc/passwd'], + ['app config path', '/app/.env'], + ['file:// URL', 'file:///etc/passwd'], + ['relative serve path', '/api/files/serve/kb/foo.pdf'], + ['ftp URL', 'ftp://example.com/file.pdf'], + ['parent traversal', '../../etc/passwd'], + ['windows path', 'C:\\Windows\\System32\\config\\SAM'], + ])('rejects %s with 400 and never invokes the pipeline', async (_label, fileUrl) => { + const req = createMockRequest('POST', { ...baseBody, fileUrl }) + const res = await POST(req, { params }) + expect(res.status).toBe(400) + expect(createDocumentRecords).not.toHaveBeenCalled() + expect(processDocumentsWithQueue).not.toHaveBeenCalled() + }) +}) diff --git a/apps/sim/app/api/knowledge/[id]/route.test.ts b/apps/sim/app/api/knowledge/[id]/route.test.ts index 111e42829bd..2d0dc1ce2e2 100644 --- a/apps/sim/app/api/knowledge/[id]/route.test.ts +++ b/apps/sim/app/api/knowledge/[id]/route.test.ts @@ -31,6 +31,7 @@ vi.mock('@/lib/knowledge/service', async (importOriginal) => { getKnowledgeBaseById: vi.fn(), updateKnowledgeBase: vi.fn(), deleteKnowledgeBase: vi.fn(), + KnowledgeBasePermissionError: actual.KnowledgeBasePermissionError, } }) @@ -39,6 +40,7 @@ vi.mock('@/app/api/knowledge/utils', () => knowledgeApiUtilsMock) import { deleteKnowledgeBase, getKnowledgeBaseById, + KnowledgeBasePermissionError, updateKnowledgeBase, } from '@/lib/knowledge/service' import { DELETE, GET, PUT } from '@/app/api/knowledge/[id]/route' @@ -229,10 +231,59 @@ describe('Knowledge Base By ID API Route', () => { workspaceId: undefined, chunkingConfig: undefined, }, - expect.any(String) + expect.any(String), + { actorUserId: 'user-123' } ) }) + it('returns 403 when service rejects a cross-workspace transfer', async () => { + authMockFns.mockGetSession.mockResolvedValue({ + user: { id: 'attacker', email: 'a@example.com' }, + }) + + resetMocks() + + vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValueOnce({ + hasAccess: true, + knowledgeBase: { id: 'kb-123', userId: 'user-123', workspaceId: 'ws-current' }, + }) + + vi.mocked(updateKnowledgeBase).mockRejectedValueOnce( + new KnowledgeBasePermissionError('User does not have permission on the target workspace') + ) + + const req = createMockRequest('PUT', { workspaceId: 'ws-target' }) + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(403) + expect(data.error).toBe('User does not have permission on the target workspace') + }) + + it('returns 403 when service rejects clearing workspaceId', async () => { + authMockFns.mockGetSession.mockResolvedValue({ + user: { id: 'user-123', email: 'test@example.com' }, + }) + + resetMocks() + + vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValueOnce({ + hasAccess: true, + knowledgeBase: { id: 'kb-123', userId: 'user-123', workspaceId: 'ws-current' }, + }) + + vi.mocked(updateKnowledgeBase).mockRejectedValueOnce( + new KnowledgeBasePermissionError('Knowledge base workspace cannot be cleared') + ) + + const req = createMockRequest('PUT', { workspaceId: null }) + const response = await PUT(req, { params: mockParams }) + const data = await response.json() + + expect(response.status).toBe(403) + expect(data.error).toBe('Knowledge base workspace cannot be cleared') + }) + it('should return unauthorized for unauthenticated user', async () => { authMockFns.mockGetSession.mockResolvedValue(null) diff --git a/apps/sim/app/api/knowledge/[id]/route.ts b/apps/sim/app/api/knowledge/[id]/route.ts index 5013a8c29c7..34a85d1a684 100644 --- a/apps/sim/app/api/knowledge/[id]/route.ts +++ b/apps/sim/app/api/knowledge/[id]/route.ts @@ -11,6 +11,7 @@ import { deleteKnowledgeBase, getKnowledgeBaseById, KnowledgeBaseConflictError, + KnowledgeBasePermissionError, updateKnowledgeBase, } from '@/lib/knowledge/service' import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils' @@ -101,7 +102,8 @@ export const PUT = withRouteHandler( workspaceId: validatedData.workspaceId, chunkingConfig: validatedData.chunkingConfig, }, - requestId + requestId, + { actorUserId: userId } ) logger.info(`[${requestId}] Knowledge base updated: ${id} for user ${userId}`) @@ -141,6 +143,10 @@ export const PUT = withRouteHandler( if (error instanceof KnowledgeBaseConflictError) { return NextResponse.json({ error: error.message }, { status: 409 }) } + if (error instanceof KnowledgeBasePermissionError) { + logger.warn(`[${requestId}] Forbidden knowledge base update on ${id}: ${error.message}`) + return NextResponse.json({ error: error.message }, { status: 403 }) + } logger.error(`[${requestId}] Error updating knowledge base`, error) return NextResponse.json({ error: 'Failed to update knowledge base' }, { status: 500 }) diff --git a/apps/sim/app/api/knowledge/route.test.ts b/apps/sim/app/api/knowledge/route.test.ts index bc0ab08d755..4ad2aad2acf 100644 --- a/apps/sim/app/api/knowledge/route.test.ts +++ b/apps/sim/app/api/knowledge/route.test.ts @@ -155,6 +155,23 @@ describe('Knowledge Base API Route', () => { expect(data.details).toBeDefined() }) + it('returns 403 when user lacks permission on target workspace', async () => { + authMockFns.mockGetSession.mockResolvedValue({ + user: { id: 'attacker', email: 'a@example.com' }, + }) + permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValueOnce('read') + + const req = createMockRequest('POST', validKnowledgeBaseData) + const response = await POST(req) + const data = await response.json() + + expect(response.status).toBe(403) + expect(data.error).toBe( + 'User does not have permission to create knowledge bases in this workspace' + ) + expect(mockDbChain.insert).not.toHaveBeenCalled() + }) + it('should validate chunking config constraints', async () => { authMockFns.mockGetSession.mockResolvedValue({ user: { id: 'user-123', email: 'test@example.com' }, diff --git a/apps/sim/app/api/knowledge/route.ts b/apps/sim/app/api/knowledge/route.ts index 8cea52b8eb7..e14efd14656 100644 --- a/apps/sim/app/api/knowledge/route.ts +++ b/apps/sim/app/api/knowledge/route.ts @@ -15,6 +15,7 @@ import { createKnowledgeBase, getKnowledgeBases, KnowledgeBaseConflictError, + KnowledgeBasePermissionError, type KnowledgeBaseScope, } from '@/lib/knowledge/service' import { captureServerEvent } from '@/lib/posthog/server' @@ -159,6 +160,10 @@ export const POST = withRouteHandler(async (req: NextRequest) => { if (createError instanceof KnowledgeBaseConflictError) { return NextResponse.json({ error: createError.message }, { status: 409 }) } + if (createError instanceof KnowledgeBasePermissionError) { + logger.warn(`[${requestId}] Forbidden knowledge base creation: ${createError.message}`) + return NextResponse.json({ error: createError.message }, { status: 403 }) + } throw createError } } catch (error) { diff --git a/apps/sim/app/api/mcp/servers/test-connection/route.ts b/apps/sim/app/api/mcp/servers/test-connection/route.ts index 46ef05fc2bd..bd88be77aad 100644 --- a/apps/sim/app/api/mcp/servers/test-connection/route.ts +++ b/apps/sim/app/api/mcp/servers/test-connection/route.ts @@ -95,6 +95,9 @@ export const POST = withRouteHandler( } try { + // Initial pre-resolution check; the authoritative resolved IP is + // captured after env-var resolution below and used to pin the + // connection against DNS rebinding. await validateMcpServerSsrf(body.url) } catch (e) { if (e instanceof McpDnsResolutionError) { @@ -140,8 +143,9 @@ export const POST = withRouteHandler( throw e } + let resolvedIP: string | null try { - await validateMcpServerSsrf(testConfig.url) + resolvedIP = await validateMcpServerSsrf(testConfig.url) } catch (e) { if (e instanceof McpDnsResolutionError) { return createMcpErrorResponse(e, e.message, 502) @@ -162,7 +166,11 @@ export const POST = withRouteHandler( let client: McpClient | null = null try { - client = new McpClient(testConfig, testSecurityPolicy) + client = new McpClient({ + config: testConfig, + securityPolicy: testSecurityPolicy, + resolvedIP: resolvedIP ?? undefined, + }) await client.connect() result.negotiatedVersion = client.getNegotiatedVersion() diff --git a/apps/sim/app/api/tools/agiloft/attach/route.test.ts b/apps/sim/app/api/tools/agiloft/attach/route.test.ts new file mode 100644 index 00000000000..f1e4c8c4264 --- /dev/null +++ b/apps/sim/app/api/tools/agiloft/attach/route.test.ts @@ -0,0 +1,144 @@ +/** + * @vitest-environment node + */ +import { + createMockRequest, + hybridAuthMockFns, + inputValidationMock, + inputValidationMockFns, +} from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockProcessFilesToUserFiles, mockDownloadFileFromStorage, mockAssertToolFileAccess } = + vi.hoisted(() => ({ + mockProcessFilesToUserFiles: vi.fn(), + mockDownloadFileFromStorage: vi.fn(), + mockAssertToolFileAccess: vi.fn(), + })) + +vi.mock('@/lib/core/security/input-validation.server', () => inputValidationMock) +vi.mock('@/lib/uploads/utils/file-utils', () => ({ + processFilesToUserFiles: mockProcessFilesToUserFiles, +})) +vi.mock('@/lib/uploads/utils/file-utils.server', () => ({ + downloadFileFromStorage: mockDownloadFileFromStorage, +})) +vi.mock('@/app/api/files/authorization', () => ({ + assertToolFileAccess: mockAssertToolFileAccess, +})) + +import { POST } from '@/app/api/tools/agiloft/attach/route' + +const PINNED_IP = '93.184.216.34' + +const baseBody = { + instanceUrl: 'https://example.agiloft.com', + knowledgeBase: 'demo', + login: 'admin', + password: 'secret', + table: 'contracts', + recordId: '42', + fieldName: 'attachments', + file: { key: 's3://bucket/file.txt', name: 'file.txt', size: 5, type: 'text/plain' }, + fileName: 'file.txt', +} + +function mockSecureFetchResponse(body: { + ok?: boolean + status?: number + json?: unknown + text?: string +}) { + return { + ok: body.ok ?? true, + status: body.status ?? 200, + statusText: '', + headers: new Headers(), + body: null, + text: async () => body.text ?? '', + json: async () => body.json ?? {}, + arrayBuffer: async () => new ArrayBuffer(0), + } +} + +beforeEach(() => { + vi.clearAllMocks() + hybridAuthMockFns.mockCheckInternalAuth.mockResolvedValue({ + success: true, + userId: 'user-1', + authType: 'internal_jwt', + }) + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValue({ + isValid: true, + resolvedIP: PINNED_IP, + originalHostname: 'example.agiloft.com', + }) + mockProcessFilesToUserFiles.mockReturnValue([ + { key: 's3://bucket/file.txt', name: 'file.txt', size: 5, type: 'text/plain' }, + ]) + mockAssertToolFileAccess.mockResolvedValue(null) + mockDownloadFileFromStorage.mockResolvedValue(Buffer.from('hello')) +}) + +describe('POST /api/tools/agiloft/attach', () => { + it('rejects unauthenticated requests', async () => { + hybridAuthMockFns.mockCheckInternalAuth.mockResolvedValueOnce({ + success: false, + error: 'unauthorized', + }) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(401) + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + }) + + it('blocks SSRF when the instance URL fails DNS validation', async () => { + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValueOnce({ + isValid: false, + error: 'instanceUrl resolves to a blocked IP address', + }) + + const response = await POST( + createMockRequest('POST', { ...baseBody, instanceUrl: 'https://attacker.example.com' }) + ) + + expect(response.status).toBe(400) + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + }) + + it('pins the resolved IP for login, attach, and logout (TOCTOU fix)', async () => { + inputValidationMockFns.mockSecureFetchWithPinnedIP + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-att' } })) + .mockResolvedValueOnce(mockSecureFetchResponse({ text: '1' })) + .mockResolvedValueOnce(mockSecureFetchResponse({})) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(200) + const data = (await response.json()) as { + success: true + output: { totalAttachments: number; fileName: string } + } + expect(data.output.totalAttachments).toBe(1) + expect(data.output.fileName).toBe('file.txt') + + const calls = inputValidationMockFns.mockSecureFetchWithPinnedIP.mock.calls + expect(calls).toHaveLength(3) + for (const call of calls) { + expect(call[1]).toBe(PINNED_IP) + } + + expect(calls[0][0]).toContain('https://example.agiloft.com/ewws/EWLogin') + expect(calls[1][0]).toContain('https://example.agiloft.com/ewws/EWAttach') + expect(calls[1][2]).toMatchObject({ + method: 'PUT', + headers: { + Authorization: 'Bearer tok-att', + 'Content-Type': 'application/octet-stream', + }, + }) + expect(calls[2][0]).toContain('https://example.agiloft.com/ewws/EWLogout') + + // DNS only resolved once. + expect(inputValidationMockFns.mockValidateUrlWithDNS).toHaveBeenCalledTimes(1) + }) +}) diff --git a/apps/sim/app/api/tools/agiloft/attach/route.ts b/apps/sim/app/api/tools/agiloft/attach/route.ts index 6257502ae4c..7a21cc02b54 100644 --- a/apps/sim/app/api/tools/agiloft/attach/route.ts +++ b/apps/sim/app/api/tools/agiloft/attach/route.ts @@ -4,14 +4,19 @@ import { type NextRequest, NextResponse } from 'next/server' import { agiloftAttachContract } from '@/lib/api/contracts/tools/agiloft' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkInternalAuth } from '@/lib/auth/hybrid' -import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' +import { secureFetchWithPinnedIP } from '@/lib/core/security/input-validation.server' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import type { RawFileInput } from '@/lib/uploads/utils/file-schemas' import { processFilesToUserFiles } from '@/lib/uploads/utils/file-utils' import { downloadFileFromStorage } from '@/lib/uploads/utils/file-utils.server' import { assertToolFileAccess } from '@/app/api/files/authorization' -import { agiloftLogin, agiloftLogout, buildAttachFileUrl } from '@/tools/agiloft/utils' +import { + agiloftLogin, + agiloftLogout, + buildAttachFileUrl, + resolveAgiloftInstance, +} from '@/tools/agiloft/utils' export const dynamic = 'force-dynamic' @@ -72,18 +77,17 @@ export const POST = withRouteHandler(async (request: NextRequest) => { const fileBuffer = await downloadFileFromStorage(userFile, requestId, logger) const resolvedFileName = data.fileName || userFile.name || 'attachment' - const urlValidation = await validateUrlWithDNS(data.instanceUrl, 'instanceUrl') - if (!urlValidation.isValid) { + let resolvedIP: string + try { + resolvedIP = await resolveAgiloftInstance(data.instanceUrl) + } catch (error) { logger.warn(`[${requestId}] SSRF attempt blocked for Agiloft instance URL`, { instanceUrl: data.instanceUrl, }) - return NextResponse.json( - { success: false, error: urlValidation.error || 'Invalid instance URL' }, - { status: 400 } - ) + return NextResponse.json({ success: false, error: toError(error).message }, { status: 400 }) } - const token = await agiloftLogin(data) + const token = await agiloftLogin(data, resolvedIP) const base = data.instanceUrl.replace(/\/$/, '') try { @@ -91,7 +95,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { logger.info(`[${requestId}] Uploading file to Agiloft: ${resolvedFileName}`) - const agiloftResponse = await fetch(url, { + const agiloftResponse = await secureFetchWithPinnedIP(url, resolvedIP, { method: 'PUT', headers: { 'Content-Type': 'application/octet-stream', @@ -135,7 +139,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { }, }) } finally { - await agiloftLogout(data.instanceUrl, data.knowledgeBase, token) + await agiloftLogout(data.instanceUrl, data.knowledgeBase, token, resolvedIP) } } catch (error) { logger.error(`[${requestId}] Error attaching file to Agiloft:`, error) diff --git a/apps/sim/app/api/tools/agiloft/retrieve/route.test.ts b/apps/sim/app/api/tools/agiloft/retrieve/route.test.ts new file mode 100644 index 00000000000..efd435b5b04 --- /dev/null +++ b/apps/sim/app/api/tools/agiloft/retrieve/route.test.ts @@ -0,0 +1,163 @@ +/** + * @vitest-environment node + */ +import { + createMockRequest, + hybridAuthMockFns, + inputValidationMock, + inputValidationMockFns, +} from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@/lib/core/security/input-validation.server', () => inputValidationMock) + +import { POST } from '@/app/api/tools/agiloft/retrieve/route' + +const PINNED_IP = '93.184.216.34' + +const baseBody = { + instanceUrl: 'https://example.agiloft.com', + knowledgeBase: 'demo', + login: 'admin', + password: 'secret', + table: 'contracts', + recordId: '42', + fieldName: 'attachments', + position: '0', +} + +function mockSecureFetchResponse(body: { + ok?: boolean + status?: number + json?: unknown + text?: string + arrayBuffer?: ArrayBuffer + headers?: Headers +}) { + return { + ok: body.ok ?? true, + status: body.status ?? 200, + statusText: '', + headers: body.headers ?? new Headers(), + body: null, + text: async () => body.text ?? '', + json: async () => body.json ?? {}, + arrayBuffer: async () => body.arrayBuffer ?? new ArrayBuffer(0), + } +} + +beforeEach(() => { + vi.clearAllMocks() + hybridAuthMockFns.mockCheckInternalAuth.mockResolvedValue({ + success: true, + userId: 'user-1', + authType: 'internal_jwt', + }) + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValue({ + isValid: true, + resolvedIP: PINNED_IP, + originalHostname: 'example.agiloft.com', + }) +}) + +describe('POST /api/tools/agiloft/retrieve', () => { + it('rejects unauthenticated requests', async () => { + hybridAuthMockFns.mockCheckInternalAuth.mockResolvedValueOnce({ + success: false, + error: 'unauthorized', + }) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(401) + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + }) + + it('blocks SSRF when the instance URL fails DNS validation', async () => { + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValueOnce({ + isValid: false, + error: 'instanceUrl resolves to a blocked IP address', + }) + + const response = await POST( + createMockRequest('POST', { ...baseBody, instanceUrl: 'https://attacker.example.com' }) + ) + + expect(response.status).toBe(400) + const data = (await response.json()) as { success: false; error: string } + expect(data.success).toBe(false) + expect(data.error).toContain('blocked IP') + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + }) + + it('pins the resolved IP for login, retrieve, and logout (TOCTOU fix)', async () => { + const fileBytes = Buffer.from('hello-attachment', 'utf-8') + + inputValidationMockFns.mockSecureFetchWithPinnedIP + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-xyz' } })) + .mockResolvedValueOnce( + mockSecureFetchResponse({ + arrayBuffer: fileBytes.buffer.slice( + fileBytes.byteOffset, + fileBytes.byteOffset + fileBytes.byteLength + ) as ArrayBuffer, + headers: new Headers({ + 'content-type': 'text/plain', + 'content-disposition': 'attachment; filename="report.txt"', + }), + }) + ) + .mockResolvedValueOnce(mockSecureFetchResponse({})) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(200) + const data = (await response.json()) as { + success: true + output: { file: { name: string; mimeType: string; data: string; size: number } } + } + + expect(data.output.file.name).toBe('report.txt') + expect(data.output.file.mimeType).toBe('text/plain') + expect(data.output.file.size).toBe(fileBytes.length) + expect(Buffer.from(data.output.file.data, 'base64').toString('utf-8')).toBe('hello-attachment') + + const calls = inputValidationMockFns.mockSecureFetchWithPinnedIP.mock.calls + expect(calls).toHaveLength(3) + + // All three outbound calls must use the pre-resolved IP. + for (const call of calls) { + expect(call[1]).toBe(PINNED_IP) + } + + // Original hostname is preserved in the URL (so TLS SNI works). + expect(calls[0][0]).toContain('https://example.agiloft.com/ewws/EWLogin') + expect(calls[1][0]).toContain('https://example.agiloft.com/ewws/EWRetrieve') + expect(calls[1][2]).toMatchObject({ + method: 'GET', + headers: { Authorization: 'Bearer tok-xyz' }, + }) + expect(calls[2][0]).toContain('https://example.agiloft.com/ewws/EWLogout') + + // DNS only resolved once — no second lookup that could rebind. + expect(inputValidationMockFns.mockValidateUrlWithDNS).toHaveBeenCalledTimes(1) + }) + + it('propagates upstream errors and still calls logout', async () => { + inputValidationMockFns.mockSecureFetchWithPinnedIP + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-err' } })) + .mockResolvedValueOnce( + mockSecureFetchResponse({ ok: false, status: 404, text: 'Record not found' }) + ) + .mockResolvedValueOnce(mockSecureFetchResponse({})) + + const response = await POST(createMockRequest('POST', baseBody)) + expect(response.status).toBe(404) + const data = (await response.json()) as { success: false; error: string } + expect(data.error).toContain('Record not found') + + // Logout still runs. + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).toHaveBeenCalledTimes(3) + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP.mock.calls[2][0]).toContain( + '/ewws/EWLogout' + ) + }) +}) diff --git a/apps/sim/app/api/tools/agiloft/retrieve/route.ts b/apps/sim/app/api/tools/agiloft/retrieve/route.ts index 64bd72daae8..0d6137988ed 100644 --- a/apps/sim/app/api/tools/agiloft/retrieve/route.ts +++ b/apps/sim/app/api/tools/agiloft/retrieve/route.ts @@ -4,10 +4,15 @@ import { type NextRequest, NextResponse } from 'next/server' import { agiloftRetrieveContract } from '@/lib/api/contracts/tools/agiloft' import { getValidationErrorMessage, parseRequest } from '@/lib/api/server' import { checkInternalAuth } from '@/lib/auth/hybrid' -import { validateUrlWithDNS } from '@/lib/core/security/input-validation.server' +import { secureFetchWithPinnedIP } from '@/lib/core/security/input-validation.server' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' -import { agiloftLogin, agiloftLogout, buildRetrieveAttachmentUrl } from '@/tools/agiloft/utils' +import { + agiloftLogin, + agiloftLogout, + buildRetrieveAttachmentUrl, + resolveAgiloftInstance, +} from '@/tools/agiloft/utils' export const dynamic = 'force-dynamic' @@ -48,18 +53,17 @@ export const POST = withRouteHandler(async (request: NextRequest) => { if (!parsed.success) return parsed.response const data = parsed.data.body - const urlValidation = await validateUrlWithDNS(data.instanceUrl, 'instanceUrl') - if (!urlValidation.isValid) { + let resolvedIP: string + try { + resolvedIP = await resolveAgiloftInstance(data.instanceUrl) + } catch (error) { logger.warn(`[${requestId}] SSRF attempt blocked for Agiloft instance URL`, { instanceUrl: data.instanceUrl, }) - return NextResponse.json( - { success: false, error: urlValidation.error || 'Invalid instance URL' }, - { status: 400 } - ) + return NextResponse.json({ success: false, error: toError(error).message }, { status: 400 }) } - const token = await agiloftLogin(data) + const token = await agiloftLogin(data, resolvedIP) const base = data.instanceUrl.replace(/\/$/, '') try { @@ -71,7 +75,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { position: data.position, }) - const agiloftResponse = await fetch(url, { + const agiloftResponse = await secureFetchWithPinnedIP(url, resolvedIP, { method: 'GET', headers: { Authorization: `Bearer ${token}`, @@ -123,7 +127,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { }, }) } finally { - await agiloftLogout(data.instanceUrl, data.knowledgeBase, token) + await agiloftLogout(data.instanceUrl, data.knowledgeBase, token, resolvedIP) } } catch (error) { logger.error(`[${requestId}] Error retrieving Agiloft attachment:`, error) diff --git a/apps/sim/app/form/[identifier]/components/email-auth.tsx b/apps/sim/app/form/[identifier]/components/email-auth.tsx new file mode 100644 index 00000000000..b75cb159c3c --- /dev/null +++ b/apps/sim/app/form/[identifier]/components/email-auth.tsx @@ -0,0 +1,284 @@ +'use client' + +import { useEffect, useState } from 'react' +import { createLogger } from '@sim/logger' +import { toError } from '@sim/utils/errors' +import { Input, InputOTP, InputOTPGroup, InputOTPSlot, Label, Loader } from '@/components/emcn' +import { cn } from '@/lib/core/utils/cn' +import { quickValidateEmail } from '@/lib/messaging/email/validation' +import AuthBackground from '@/app/(auth)/components/auth-background' +import { AUTH_SUBMIT_BTN } from '@/app/(auth)/components/auth-button-classes' +import { SupportFooter } from '@/app/(auth)/components/support-footer' +import Navbar from '@/app/(landing)/components/navbar/navbar' +import { useFormEmailOtpRequest, useFormEmailOtpVerify } from '@/hooks/queries/forms' + +const logger = createLogger('FormEmailAuth') + +interface EmailAuthProps { + identifier: string + onAuthenticated: () => void +} + +function validateEmailField(emailValue: string): string[] { + const errors: string[] = [] + + if (!emailValue || !emailValue.trim()) { + errors.push('Email is required.') + return errors + } + + const validation = quickValidateEmail(emailValue.trim().toLowerCase()) + if (!validation.isValid) { + errors.push(validation.reason || 'Please enter a valid email address.') + } + + return errors +} + +export function EmailAuth({ identifier, onAuthenticated }: EmailAuthProps) { + const [email, setEmail] = useState('') + const [authError, setAuthError] = useState(null) + const [emailErrors, setEmailErrors] = useState([]) + const [showEmailValidationError, setShowEmailValidationError] = useState(false) + + const [showOtpVerification, setShowOtpVerification] = useState(false) + const [otpValue, setOtpValue] = useState('') + const [countdown, setCountdown] = useState(0) + + const requestOtp = useFormEmailOtpRequest(identifier) + const verifyOtp = useFormEmailOtpVerify(identifier) + + useEffect(() => { + if (countdown <= 0) return + const timer = setTimeout(() => setCountdown((c) => c - 1), 1000) + return () => clearTimeout(timer) + }, [countdown]) + + const handleEmailChange = (e: React.ChangeEvent) => { + const newEmail = e.target.value + setEmail(newEmail) + const errors = validateEmailField(newEmail) + setEmailErrors(errors) + setShowEmailValidationError(false) + } + + const handleSendOtp = async () => { + const emailValidationErrors = validateEmailField(email) + setEmailErrors(emailValidationErrors) + setShowEmailValidationError(emailValidationErrors.length > 0) + + if (emailValidationErrors.length > 0) return + + setAuthError(null) + + try { + await requestOtp.mutateAsync({ email }) + setShowOtpVerification(true) + } catch (error) { + logger.error('Error sending OTP:', error) + setEmailErrors([toError(error).message || 'Failed to send verification code']) + setShowEmailValidationError(true) + } + } + + const handleVerifyOtp = async (otp?: string) => { + const codeToVerify = otp || otpValue + if (!codeToVerify || codeToVerify.length !== 6) return + + setAuthError(null) + + try { + await verifyOtp.mutateAsync({ email, otp: codeToVerify }) + onAuthenticated() + } catch (error) { + logger.error('Error verifying OTP:', error) + setAuthError(toError(error).message || 'Invalid verification code') + } + } + + const handleResendOtp = async () => { + setAuthError(null) + setCountdown(30) + + try { + await requestOtp.mutateAsync({ email }) + setOtpValue('') + } catch (error) { + logger.error('Error resending OTP:', error) + setAuthError(toError(error).message || 'Failed to resend verification code') + setCountdown(0) + } + } + + return ( + +
+
+ +
+
+
+
+
+

+ {showOtpVerification ? 'Verify Your Email' : 'Email Verification'} +

+

+ {showOtpVerification + ? `A verification code has been sent to ${email}` + : 'This form requires email verification'} +

+
+ +
+ {!showOtpVerification ? ( +
{ + e.preventDefault() + handleSendOtp() + }} + className='space-y-6' + > +
+ + 0 && + 'border-red-500 focus:border-red-500' + )} + /> + {showEmailValidationError && emailErrors.length > 0 && ( +
+ {emailErrors.map((error) => ( +

{error}

+ ))} +
+ )} +
+ + +
+ ) : ( +
+

+ Enter the 6-digit code to verify your account. If you don't see it in your + inbox, check your spam folder. +

+ +
+ { + setOtpValue(value) + if (value.length === 6) { + handleVerifyOtp(value) + } + }} + disabled={verifyOtp.isPending} + className={cn('gap-2', authError && 'otp-error')} + > + + {[0, 1, 2, 3, 4, 5].map((index) => ( + + ))} + + +
+ + {authError && ( +
+

{authError}

+
+ )} + + + +
+

+ Didn't receive a code?{' '} + {countdown > 0 ? ( + + Resend in{' '} + + {countdown}s + + + ) : ( + + )} +

+
+ +
+ +
+
+ )} +
+
+
+
+ +
+
+ ) +} diff --git a/apps/sim/app/form/[identifier]/components/index.ts b/apps/sim/app/form/[identifier]/components/index.ts index 31cb46d6843..e888196967c 100644 --- a/apps/sim/app/form/[identifier]/components/index.ts +++ b/apps/sim/app/form/[identifier]/components/index.ts @@ -1,3 +1,4 @@ +export { EmailAuth } from './email-auth' export { FormErrorState } from './error-state' export { FormField } from './form-field' export { FormLoadingState } from './loading-state' diff --git a/apps/sim/app/form/[identifier]/form.tsx b/apps/sim/app/form/[identifier]/form.tsx index f6264ddf0fb..4e809096b55 100644 --- a/apps/sim/app/form/[identifier]/form.tsx +++ b/apps/sim/app/form/[identifier]/form.tsx @@ -10,6 +10,7 @@ import { AUTH_SUBMIT_BTN } from '@/app/(auth)/components/auth-button-classes' import { SupportFooter } from '@/app/(auth)/components/support-footer' import Navbar from '@/app/(landing)/components/navbar/navbar' import { + EmailAuth, FormErrorState, FormField, FormLoadingState, @@ -241,6 +242,10 @@ export default function Form({ identifier }: { identifier: string }) { return } + if (authRequired === 'email') { + return fetchFormConfig()} /> + } + if (isSubmitted && thankYouData) { return ( diff --git a/apps/sim/hooks/queries/forms.ts b/apps/sim/hooks/queries/forms.ts index e43e8fe7ecc..e20df733b5e 100644 --- a/apps/sim/hooks/queries/forms.ts +++ b/apps/sim/hooks/queries/forms.ts @@ -14,8 +14,10 @@ import { type FormStatusResponse, getFormDetailContract, getFormStatusContract, + requestFormEmailOtpContract, type UpdateFormInput, updateFormContract, + verifyFormEmailOtpContract, } from '@/lib/api/contracts/forms' import { deploymentKeys } from './deployments' @@ -35,6 +37,36 @@ export const formKeys = { */ export type { FormAuthType } +/** + * Requests a one-time passcode for an email-gated deployed form. + * Used for both the initial send and resend flows. + */ +export function useFormEmailOtpRequest(identifier: string) { + return useMutation({ + mutationFn: async ({ email }: { email: string }) => { + await requestJson(requestFormEmailOtpContract, { + params: { identifier }, + body: { email }, + }) + }, + }) +} + +/** + * Verifies a one-time passcode for an email-gated deployed form. + * On success the server sets the auth cookie; the caller should re-fetch the form config. + */ +export function useFormEmailOtpVerify(identifier: string) { + return useMutation({ + mutationFn: async ({ email, otp }: { email: string; otp: string }) => { + await requestJson(verifyFormEmailOtpContract, { + params: { identifier }, + body: { email, otp }, + }) + }, + }) +} + /** * Field configuration for form fields */ diff --git a/apps/sim/lib/api/contracts/forms.ts b/apps/sim/lib/api/contracts/forms.ts index af252043b65..0283121edca 100644 --- a/apps/sim/lib/api/contracts/forms.ts +++ b/apps/sim/lib/api/contracts/forms.ts @@ -148,6 +148,22 @@ export const formMutationResponseSchema = z.object({ message: z.string(), }) +export const formEmailOtpRequestBodySchema = z.object({ + email: z.string().email('Invalid email address'), +}) + +export const formEmailOtpVerifyBodySchema = formEmailOtpRequestBodySchema.extend({ + otp: z.string().length(6, 'OTP must be 6 digits'), +}) + +export const formEmailOtpRequestResponseSchema = z.object({ + message: z.string(), +}) + +export const formEmailOtpVerifyResponseSchema = z.object({ + authenticated: z.literal(true), +}) + export const getFormStatusContract = defineRouteContract({ method: 'GET', path: '/api/workflows/[id]/form/status', @@ -199,6 +215,28 @@ export const deleteFormContract = defineRouteContract({ }, }) +export const requestFormEmailOtpContract = defineRouteContract({ + method: 'POST', + path: '/api/form/[identifier]/otp', + params: formIdentifierParamsSchema, + body: formEmailOtpRequestBodySchema, + response: { + mode: 'json', + schema: formEmailOtpRequestResponseSchema, + }, +}) + +export const verifyFormEmailOtpContract = defineRouteContract({ + method: 'PUT', + path: '/api/form/[identifier]/otp', + params: formIdentifierParamsSchema, + body: formEmailOtpVerifyBodySchema, + response: { + mode: 'json', + schema: formEmailOtpVerifyResponseSchema, + }, +}) + export const validateFormIdentifierContract = defineRouteContract({ method: 'GET', path: '/api/form/validate', diff --git a/apps/sim/lib/api/contracts/knowledge/documents.ts b/apps/sim/lib/api/contracts/knowledge/documents.ts index 17cb344f37f..6a85005f700 100644 --- a/apps/sim/lib/api/contracts/knowledge/documents.ts +++ b/apps/sim/lib/api/contracts/knowledge/documents.ts @@ -5,6 +5,7 @@ import { documentNumberFieldSchema, documentTagFieldSchema, knowledgeBaseParamsSchema, + knowledgeDocumentFileUrlSchema, knowledgeDocumentParamsSchema, nullableWireDateSchema, paginationSchema, @@ -55,7 +56,7 @@ export const listKnowledgeDocumentsQuerySchema = z.object({ export const createDocumentBodySchema = z.object({ filename: z.string().min(1, 'Filename is required'), - fileUrl: z.string().url('File URL must be valid'), + fileUrl: knowledgeDocumentFileUrlSchema, fileSize: z.number().min(1, 'File size must be greater than 0'), mimeType: z.string().min(1, 'MIME type is required'), tag1: z.string().optional(), @@ -101,7 +102,7 @@ export type SingleCreateDocumentBody = z.input { + it('accepts data: URIs', () => { + const result = knowledgeDocumentFileUrlSchema.safeParse( + 'data:text/plain;base64,SGVsbG8gd29ybGQ=' + ) + expect(result.success).toBe(true) + }) + + it('accepts https URLs', () => { + const result = knowledgeDocumentFileUrlSchema.safeParse('https://example.com/file.pdf') + expect(result.success).toBe(true) + }) + + it('accepts http URLs', () => { + const result = knowledgeDocumentFileUrlSchema.safeParse( + 'http://localhost:3000/api/files/serve/kb/foo.pdf?context=knowledge-base' + ) + expect(result.success).toBe(true) + }) + + it('is case-insensitive on the scheme', () => { + expect(knowledgeDocumentFileUrlSchema.safeParse('HTTPS://example.com/x').success).toBe(true) + expect(knowledgeDocumentFileUrlSchema.safeParse('Http://example.com/x').success).toBe(true) + }) + + it.each([ + ['absolute local path', '/etc/passwd'], + ['app path', '/app/.env'], + ['relative path', './secrets.txt'], + ['parent traversal', '../../etc/shadow'], + ['file:// scheme', 'file:///etc/passwd'], + ['ftp scheme', 'ftp://example.com/x'], + ['javascript scheme', 'javascript:alert(1)'], + ['gopher scheme', 'gopher://example.com'], + ['relative serve path', '/api/files/serve/kb/foo.pdf'], + ['windows path', 'C:\\Windows\\System32\\config\\SAM'], + ['empty string', ''], + ['whitespace prefix', ' https://example.com/x'], + ])('rejects %s', (_label, value) => { + const result = knowledgeDocumentFileUrlSchema.safeParse(value) + expect(result.success).toBe(false) + }) + + it('returns a useful error message for unsupported schemes', () => { + const result = knowledgeDocumentFileUrlSchema.safeParse('/etc/passwd') + if (result.success) throw new Error('expected failure') + expect(result.error.issues[0].message).toMatch(/data: URI or an http\(s\):\/\/ URL/) + }) +}) diff --git a/apps/sim/lib/api/contracts/knowledge/shared.ts b/apps/sim/lib/api/contracts/knowledge/shared.ts index 070cd4606de..9e68e895a66 100644 --- a/apps/sim/lib/api/contracts/knowledge/shared.ts +++ b/apps/sim/lib/api/contracts/knowledge/shared.ts @@ -23,6 +23,21 @@ export const knowledgeConnectorParamsSchema = knowledgeBaseParamsSchema.extend({ connectorId: z.string().min(1), }) +/** + * A `fileUrl` accepted by knowledge document ingestion endpoints. + * + * Must be a `data:` URI or an `http(s)://` URL. Local paths, `file://`, + * and other schemes are rejected at the boundary to prevent the background + * parser from reading arbitrary files off the Sim server's filesystem. + */ +export const knowledgeDocumentFileUrlSchema = z + .string() + .min(1, 'File URL is required') + .refine( + (value) => value.startsWith('data:') || /^https?:\/\//i.test(value), + 'File URL must be a data: URI or an http(s):// URL' + ) + export const documentTagFieldSchema = z.string().nullable().optional() export const documentNumberFieldSchema = z.number().nullable().optional() export const documentBooleanFieldSchema = z.boolean().nullable().optional() diff --git a/apps/sim/lib/core/security/otp.ts b/apps/sim/lib/core/security/otp.ts new file mode 100644 index 00000000000..576d61ed508 --- /dev/null +++ b/apps/sim/lib/core/security/otp.ts @@ -0,0 +1,247 @@ +import { randomInt } from 'crypto' +import { db } from '@sim/db' +import { verification } from '@sim/db/schema' +import { generateId } from '@sim/utils/id' +import { and, eq, gt } from 'drizzle-orm' +import { getRedisClient } from '@/lib/core/config/redis' +import type { TokenBucketConfig } from '@/lib/core/rate-limiter' +import { getStorageMethod } from '@/lib/core/storage' + +export type DeploymentKind = 'chat' | 'form' + +/** + * Shared OTP configuration for deployment (chat/form) email-auth gates. + */ +export const OTP_EXPIRY_SECONDS = 15 * 60 +export const OTP_EXPIRY_MS = OTP_EXPIRY_SECONDS * 1000 +export const MAX_OTP_ATTEMPTS = 5 + +export const OTP_IP_RATE_LIMIT: TokenBucketConfig = { + maxTokens: 10, + refillRate: 10, + refillIntervalMs: 15 * 60_000, +} + +export const OTP_EMAIL_RATE_LIMIT: TokenBucketConfig = { + maxTokens: 3, + refillRate: 3, + refillIntervalMs: 15 * 60_000, +} + +/** + * Key formats are kept per-kind to preserve any in-flight OTPs already issued + * against existing chat deployments. The chat Redis key uses the legacy `otp:` + * prefix; the chat DB identifier uses `chat-otp:`. Forms use `form-otp:` for + * both. + */ +const OTP_KEYS = { + chat: { + redisKey: (email: string, deploymentId: string) => `otp:${email}:${deploymentId}`, + dbIdentifier: (email: string, deploymentId: string) => `chat-otp:${deploymentId}:${email}`, + }, + form: { + redisKey: (email: string, deploymentId: string) => `form-otp:${email}:${deploymentId}`, + dbIdentifier: (email: string, deploymentId: string) => `form-otp:${deploymentId}:${email}`, + }, +} as const satisfies Record< + DeploymentKind, + { + redisKey: (email: string, deploymentId: string) => string + dbIdentifier: (email: string, deploymentId: string) => string + } +> + +/** Returns a cryptographically random 6-digit OTP code. */ +export function generateOTP(): string { + return randomInt(100000, 1000000).toString() +} + +/** + * OTP values are stored as `"code:attempts"` (e.g. `"654321:0"`). + * This keeps the attempt counter in the same key/row as the OTP itself. + */ +function encodeOTPValue(otp: string, attempts: number): string { + return `${otp}:${attempts}` +} + +export function decodeOTPValue(value: string): { otp: string; attempts: number } { + const lastColon = value.lastIndexOf(':') + if (lastColon === -1) return { otp: value, attempts: 0 } + const attempts = Number.parseInt(value.slice(lastColon + 1), 10) + return { otp: value.slice(0, lastColon), attempts: Number.isNaN(attempts) ? 0 : attempts } +} + +/** + * Stores an OTP for a deployment+email pair, choosing Redis or the + * `verification` table based on the configured storage method. + */ +export async function storeOTP( + kind: DeploymentKind, + deploymentId: string, + email: string, + otp: string +): Promise { + const keys = OTP_KEYS[kind] + const value = encodeOTPValue(otp, 0) + const storageMethod = getStorageMethod() + + if (storageMethod === 'redis') { + const redis = getRedisClient() + if (!redis) throw new Error('Redis configured but client unavailable') + await redis.set(keys.redisKey(email, deploymentId), value, 'EX', OTP_EXPIRY_SECONDS) + return + } + + const now = new Date() + const expiresAt = new Date(now.getTime() + OTP_EXPIRY_MS) + const identifier = keys.dbIdentifier(email, deploymentId) + + await db.transaction(async (tx) => { + await tx.delete(verification).where(eq(verification.identifier, identifier)) + await tx.insert(verification).values({ + id: generateId(), + identifier, + value, + expiresAt, + createdAt: now, + updatedAt: now, + }) + }) +} + +export async function getOTP( + kind: DeploymentKind, + deploymentId: string, + email: string +): Promise { + const keys = OTP_KEYS[kind] + const storageMethod = getStorageMethod() + + if (storageMethod === 'redis') { + const redis = getRedisClient() + if (!redis) throw new Error('Redis configured but client unavailable') + return redis.get(keys.redisKey(email, deploymentId)) + } + + const now = new Date() + const [record] = await db + .select({ value: verification.value }) + .from(verification) + .where( + and( + eq(verification.identifier, keys.dbIdentifier(email, deploymentId)), + gt(verification.expiresAt, now) + ) + ) + .limit(1) + + return record?.value ?? null +} + +/** + * Lua script for atomic OTP attempt increment in Redis. + * Returns `'LOCKED'` if max attempts reached (key deleted), new encoded value + * otherwise, nil if key missing. + */ +const ATOMIC_INCREMENT_SCRIPT = ` +local val = redis.call('GET', KEYS[1]) +if not val then return nil end +local colon = val:find(':([^:]*$)') +local otp, attempts +if colon then + otp = val:sub(1, colon - 1) + attempts = tonumber(val:sub(colon + 1)) or 0 +else + otp = val + attempts = 0 +end +attempts = attempts + 1 +if attempts >= tonumber(ARGV[1]) then + redis.call('DEL', KEYS[1]) + return 'LOCKED' +end +local newVal = otp .. ':' .. attempts +local ttl = redis.call('TTL', KEYS[1]) +if ttl > 0 then + redis.call('SET', KEYS[1], newVal, 'EX', ttl) +else + redis.call('SET', KEYS[1], newVal) +end +return newVal +` + +/** + * Atomically increments an OTP's failed-attempt counter. Returns `'locked'` + * if the max-attempts threshold was reached (and the OTP was deleted), or + * `'incremented'` otherwise. The DB path uses optimistic locking with retry. + */ +export async function incrementOTPAttempts( + kind: DeploymentKind, + deploymentId: string, + email: string, + currentValue: string +): Promise<'locked' | 'incremented'> { + const keys = OTP_KEYS[kind] + const storageMethod = getStorageMethod() + + if (storageMethod === 'redis') { + const redis = getRedisClient() + if (!redis) throw new Error('Redis configured but client unavailable') + const key = keys.redisKey(email, deploymentId) + const result = await redis.eval(ATOMIC_INCREMENT_SCRIPT, 1, key, MAX_OTP_ATTEMPTS) + if (result === null || result === 'LOCKED') return 'locked' + return 'incremented' + } + + const identifier = keys.dbIdentifier(email, deploymentId) + const MAX_RETRIES = 3 + let value = currentValue + + for (let attempt = 0; attempt < MAX_RETRIES; attempt++) { + const { otp, attempts } = decodeOTPValue(value) + const newAttempts = attempts + 1 + + if (newAttempts >= MAX_OTP_ATTEMPTS) { + await db.delete(verification).where(eq(verification.identifier, identifier)) + return 'locked' + } + + const newValue = encodeOTPValue(otp, newAttempts) + const updated = await db + .update(verification) + .set({ value: newValue, updatedAt: new Date() }) + .where(and(eq(verification.identifier, identifier), eq(verification.value, value))) + .returning({ id: verification.id }) + + if (updated.length > 0) return 'incremented' + + const fresh = await getOTP(kind, deploymentId, email) + if (!fresh) return 'locked' + value = fresh + } + + const final = await getOTP(kind, deploymentId, email) + if (!final) return 'locked' + const { attempts: finalAttempts } = decodeOTPValue(final) + return finalAttempts >= MAX_OTP_ATTEMPTS ? 'locked' : 'incremented' +} + +export async function deleteOTP( + kind: DeploymentKind, + deploymentId: string, + email: string +): Promise { + const keys = OTP_KEYS[kind] + const storageMethod = getStorageMethod() + + if (storageMethod === 'redis') { + const redis = getRedisClient() + if (!redis) throw new Error('Redis configured but client unavailable') + await redis.del(keys.redisKey(email, deploymentId)) + return + } + + await db + .delete(verification) + .where(eq(verification.identifier, keys.dbIdentifier(email, deploymentId))) +} diff --git a/apps/sim/lib/knowledge/documents/document-processor.ts b/apps/sim/lib/knowledge/documents/document-processor.ts index 67f246d06b5..d68f6f4aaff 100644 --- a/apps/sim/lib/knowledge/documents/document-processor.ts +++ b/apps/sim/lib/knowledge/documents/document-processor.ts @@ -15,7 +15,7 @@ import { } from '@/lib/chunkers' import type { ChunkingStrategy, StrategyOptions } from '@/lib/chunkers/types' import { env, envNumber } from '@/lib/core/config/env' -import { parseBuffer, parseFile } from '@/lib/file-parsers' +import { parseBuffer } from '@/lib/file-parsers' import type { FileParseMetadata } from '@/lib/file-parsers/types' import { resolveParserExtension } from '@/lib/knowledge/documents/parser-extension' import { retryWithExponentialBackoff } from '@/lib/knowledge/documents/utils' @@ -395,8 +395,7 @@ async function downloadFileForBase64(fileUrl: string): Promise { if (fileUrl.startsWith('http')) { return downloadFileWithTimeout(fileUrl) } - const fs = await import('fs/promises') - return fs.readFile(fileUrl) + throw new Error('Unsupported fileUrl scheme: only data: URIs and http(s):// URLs are allowed') } function processOCRContent(result: OCRResult, filename: string): string { @@ -790,9 +789,7 @@ async function parseWithFileParser(fileUrl: string, filename: string, mimeType: content = result.content metadata = result.metadata || {} } else { - const result = await parseFile(fileUrl) - content = result.content - metadata = result.metadata || {} + throw new Error('Unsupported fileUrl scheme: only data: URIs and http(s):// URLs are allowed') } if (!content.trim()) { diff --git a/apps/sim/lib/knowledge/service.test.ts b/apps/sim/lib/knowledge/service.test.ts new file mode 100644 index 00000000000..08e81b2746c --- /dev/null +++ b/apps/sim/lib/knowledge/service.test.ts @@ -0,0 +1,114 @@ +/** + * @vitest-environment node + */ +import { + dbChainMock, + dbChainMockFns, + permissionsMock, + permissionsMockFns, + resetDbChainMock, +} from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@sim/db', () => dbChainMock) +vi.mock('@/lib/workspaces/permissions/utils', () => permissionsMock) + +import { KnowledgeBasePermissionError, updateKnowledgeBase } from '@/lib/knowledge/service' + +/** + * These tests guard the workspace mass-assignment fix: + * a user with write/admin on the *source* workspace must not be able to move a + * knowledge base into a workspace where they have no permission, and must not + * be able to clear `workspaceId` (which would orphan the KB to its original + * `userId`, who may not be the caller). + */ +describe('updateKnowledgeBase — workspace transfer authorization', () => { + beforeEach(() => { + vi.clearAllMocks() + dbChainMockFns.limit.mockReset() + resetDbChainMock() + }) + + it('rejects workspaceId change without actorUserId', async () => { + await expect( + updateKnowledgeBase('kb-1', { workspaceId: 'ws-target' }, 'req-1') + ).rejects.toBeInstanceOf(KnowledgeBasePermissionError) + expect(permissionsMockFns.mockGetUserEntityPermissions).not.toHaveBeenCalled() + }) + + it('rejects clearing workspaceId to null when actor is not the KB owner', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'owner' }]) + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: null }, 'req-1', { actorUserId: 'attacker' }) + ).rejects.toMatchObject({ + code: 'KNOWLEDGE_BASE_FORBIDDEN', + message: 'Only the knowledge base owner can remove it from a workspace', + }) + expect(permissionsMockFns.mockGetUserEntityPermissions).not.toHaveBeenCalled() + }) + + it('allows the KB owner to clear workspaceId to null (gate passes; target permission not checked)', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'owner' }]) + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: null }, 'req-1', { actorUserId: 'owner' }) + ).rejects.not.toBeInstanceOf(KnowledgeBasePermissionError) + expect(permissionsMockFns.mockGetUserEntityPermissions).not.toHaveBeenCalled() + }) + + it('rejects transfer when actor has no permission on target workspace', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'u-1' }]) + permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValueOnce(null) + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: 'ws-target' }, 'req-1', { + actorUserId: 'attacker', + }) + ).rejects.toMatchObject({ + code: 'KNOWLEDGE_BASE_FORBIDDEN', + message: 'User does not have permission on the target workspace', + }) + expect(permissionsMockFns.mockGetUserEntityPermissions).toHaveBeenCalledWith( + 'attacker', + 'workspace', + 'ws-target' + ) + }) + + it('rejects transfer when actor only has read permission on target workspace', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'u-1' }]) + permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValueOnce('read') + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: 'ws-target' }, 'req-1', { + actorUserId: 'reader', + }) + ).rejects.toBeInstanceOf(KnowledgeBasePermissionError) + }) + + it('throws when knowledge base does not exist during transfer', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([]) + + await expect( + updateKnowledgeBase('kb-missing', { workspaceId: 'ws-target' }, 'req-1', { + actorUserId: 'u-1', + }) + ).rejects.toThrow('Knowledge base kb-missing not found') + expect(permissionsMockFns.mockGetUserEntityPermissions).not.toHaveBeenCalled() + }) + + it('locks the knowledge base row (SELECT … FOR UPDATE) before the permission check', async () => { + dbChainMockFns.limit.mockResolvedValueOnce([{ workspaceId: 'ws-current', userId: 'u-1' }]) + permissionsMockFns.mockGetUserEntityPermissions.mockResolvedValueOnce(null) + + await expect( + updateKnowledgeBase('kb-1', { workspaceId: 'ws-target' }, 'req-1', { + actorUserId: 'attacker', + }) + ).rejects.toBeInstanceOf(KnowledgeBasePermissionError) + + expect(dbChainMockFns.transaction).toHaveBeenCalledTimes(1) + expect(dbChainMockFns.for).toHaveBeenCalledWith('update') + }) +}) diff --git a/apps/sim/lib/knowledge/service.ts b/apps/sim/lib/knowledge/service.ts index 0ba5de8162a..8bca342cb77 100644 --- a/apps/sim/lib/knowledge/service.ts +++ b/apps/sim/lib/knowledge/service.ts @@ -21,6 +21,10 @@ export class KnowledgeBaseConflictError extends Error { } } +export class KnowledgeBasePermissionError extends Error { + readonly code = 'KNOWLEDGE_BASE_FORBIDDEN' as const +} + export type KnowledgeBaseScope = 'active' | 'archived' | 'all' /** @@ -148,7 +152,9 @@ export async function createKnowledgeBase( const hasPermission = await getUserEntityPermissions(data.userId, 'workspace', data.workspaceId) if (hasPermission !== 'admin' && hasPermission !== 'write') { - throw new Error('User does not have permission to create knowledge bases in this workspace') + throw new KnowledgeBasePermissionError( + 'User does not have permission to create knowledge bases in this workspace' + ) } const newKnowledgeBase = { @@ -226,7 +232,8 @@ export async function updateKnowledgeBase( overlap: number } }, - requestId: string + requestId: string, + options?: { actorUserId?: string } ): Promise { const now = new Date() const updateData: { @@ -252,38 +259,81 @@ export async function updateKnowledgeBase( updateData.chunkingConfig = updates.chunkingConfig } - if (updates.name !== undefined) { - const existing = await db - .select({ id: knowledgeBase.id, workspaceId: knowledgeBase.workspaceId }) - .from(knowledgeBase) - .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) - .limit(1) + if (updates.workspaceId !== undefined && !options?.actorUserId) { + throw new KnowledgeBasePermissionError( + 'actorUserId is required to change a knowledge base workspace' + ) + } - if (existing.length > 0 && existing[0].workspaceId) { - const duplicate = await db - .select({ id: knowledgeBase.id }) + try { + await db.transaction(async (tx) => { + const [currentKb] = await tx + .select({ workspaceId: knowledgeBase.workspaceId, userId: knowledgeBase.userId }) .from(knowledgeBase) - .where( - and( - eq(knowledgeBase.workspaceId, existing[0].workspaceId), - eq(knowledgeBase.name, updates.name), - isNull(knowledgeBase.deletedAt), - ne(knowledgeBase.id, knowledgeBaseId) - ) - ) + .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) + .for('update') .limit(1) - if (duplicate.length > 0) { - throw new KnowledgeBaseConflictError(updates.name) + if (!currentKb) { + throw new Error(`Knowledge base ${knowledgeBaseId} not found`) } - } - } - try { - await db - .update(knowledgeBase) - .set(updateData) - .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) + if (updates.workspaceId !== undefined) { + const actorUserId = options?.actorUserId as string + const currentWorkspaceId = currentKb.workspaceId ?? null + const targetWorkspaceId = updates.workspaceId ?? null + + if (targetWorkspaceId !== currentWorkspaceId) { + if (!targetWorkspaceId) { + if (actorUserId !== currentKb.userId) { + throw new KnowledgeBasePermissionError( + 'Only the knowledge base owner can remove it from a workspace' + ) + } + } else { + const targetPermission = await getUserEntityPermissions( + actorUserId, + 'workspace', + targetWorkspaceId + ) + if (targetPermission !== 'write' && targetPermission !== 'admin') { + throw new KnowledgeBasePermissionError( + 'User does not have permission on the target workspace' + ) + } + } + } + } + + if (updates.name !== undefined) { + const effectiveWorkspaceId = + updates.workspaceId !== undefined ? updates.workspaceId : currentKb.workspaceId + + if (effectiveWorkspaceId) { + const duplicate = await tx + .select({ id: knowledgeBase.id }) + .from(knowledgeBase) + .where( + and( + eq(knowledgeBase.workspaceId, effectiveWorkspaceId), + eq(knowledgeBase.name, updates.name), + isNull(knowledgeBase.deletedAt), + ne(knowledgeBase.id, knowledgeBaseId) + ) + ) + .limit(1) + + if (duplicate.length > 0) { + throw new KnowledgeBaseConflictError(updates.name) + } + } + } + + await tx + .update(knowledgeBase) + .set(updateData) + .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) + }) } catch (error: unknown) { if (getPostgresErrorCode(error) === '23505' && updates.name !== undefined) { throw new KnowledgeBaseConflictError(updates.name) diff --git a/apps/sim/lib/mcp/client.ts b/apps/sim/lib/mcp/client.ts index 93588aecdd3..bbc5cb19e00 100644 --- a/apps/sim/lib/mcp/client.ts +++ b/apps/sim/lib/mcp/client.ts @@ -18,6 +18,7 @@ import { import { createLogger } from '@sim/logger' import { getErrorMessage } from '@sim/utils/errors' import { getMaxExecutionTimeout } from '@/lib/core/execution-limits' +import { createMcpPinnedFetch } from '@/lib/mcp/pinned-fetch' import { type McpClientOptions, McpConnectionError, @@ -51,34 +52,15 @@ export class McpClient { '2024-11-05', // Initial stable release ] - /** - * Creates a new MCP client. - * - * Accepts either the legacy (config, securityPolicy?) signature - * or a single McpClientOptions object with an optional onToolsChanged callback. - */ - constructor(config: McpServerConfig, securityPolicy?: McpSecurityPolicy) - constructor(options: McpClientOptions) - constructor( - configOrOptions: McpServerConfig | McpClientOptions, - securityPolicy?: McpSecurityPolicy - ) { - if ('config' in configOrOptions) { - this.config = configOrOptions.config - this.securityPolicy = configOrOptions.securityPolicy ?? { - requireConsent: true, - auditLevel: 'basic', - maxToolExecutionsPerHour: 1000, - } - this.onToolsChanged = configOrOptions.onToolsChanged - } else { - this.config = configOrOptions - this.securityPolicy = securityPolicy ?? { - requireConsent: true, - auditLevel: 'basic', - maxToolExecutionsPerHour: 1000, - } + constructor(options: McpClientOptions) { + this.config = options.config + this.securityPolicy = options.securityPolicy ?? { + requireConsent: true, + auditLevel: 'basic', + maxToolExecutionsPerHour: 1000, } + this.onToolsChanged = options.onToolsChanged + const resolvedIP = options.resolvedIP this.connectionStatus = { connected: false } @@ -90,6 +72,7 @@ export class McpClient { requestInit: { headers: this.config.headers, }, + ...(resolvedIP ? { fetch: createMcpPinnedFetch(resolvedIP) } : {}), }) this.client = new Client( diff --git a/apps/sim/lib/mcp/connection-manager.ts b/apps/sim/lib/mcp/connection-manager.ts index 3d6627be57b..a150b194a87 100644 --- a/apps/sim/lib/mcp/connection-manager.ts +++ b/apps/sim/lib/mcp/connection-manager.ts @@ -71,7 +71,8 @@ export class McpConnectionManager { async connect( config: McpServerConfig, userId: string, - workspaceId: string + workspaceId: string, + resolvedIP?: string | null ): Promise<{ supportsListChanged: boolean }> { if (this.disposed) { logger.warn('Connection manager is disposed, ignoring connect request') @@ -106,6 +107,7 @@ export class McpConnectionManager { maxToolExecutionsPerHour: 1000, }, onToolsChanged, + resolvedIP: resolvedIP ?? undefined, }) try { diff --git a/apps/sim/lib/mcp/domain-check.test.ts b/apps/sim/lib/mcp/domain-check.test.ts index 6cc76716ca0..ff559caa8cf 100644 --- a/apps/sim/lib/mcp/domain-check.test.ts +++ b/apps/sim/lib/mcp/domain-check.test.ts @@ -4,13 +4,17 @@ import { inputValidationMock, inputValidationMockFns } from '@sim/testing' import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockGetAllowedMcpDomainsFromEnv, mockDnsLookup } = vi.hoisted(() => ({ +const { mockGetAllowedMcpDomainsFromEnv, mockDnsLookup, hostedFlag } = vi.hoisted(() => ({ mockGetAllowedMcpDomainsFromEnv: vi.fn<() => string[] | null>(), mockDnsLookup: vi.fn(), + hostedFlag: { value: false }, })) vi.mock('@/lib/core/config/feature-flags', () => ({ getAllowedMcpDomainsFromEnv: mockGetAllowedMcpDomainsFromEnv, + get isHosted() { + return hostedFlag.value + }, })) vi.mock('@/lib/core/security/input-validation.server', () => inputValidationMock) @@ -331,41 +335,44 @@ describe('validateMcpServerSsrf', () => { beforeEach(() => { vi.clearAllMocks() mockGetAllowedMcpDomainsFromEnv.mockReturnValue(null) + hostedFlag.value = false }) - it('does nothing for undefined URL', async () => { - await expect(validateMcpServerSsrf(undefined)).resolves.toBeUndefined() + it('returns null for undefined URL', async () => { + await expect(validateMcpServerSsrf(undefined)).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('skips validation for env var URLs', async () => { - await expect(validateMcpServerSsrf('{{MCP_SERVER_URL}}')).resolves.toBeUndefined() + it('returns null and skips validation for env var URLs', async () => { + await expect(validateMcpServerSsrf('{{MCP_SERVER_URL}}')).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('skips validation for URLs with env var in hostname', async () => { - await expect(validateMcpServerSsrf('https://{{MCP_HOST}}/mcp')).resolves.toBeUndefined() + it('returns null and skips validation for URLs with env var in hostname', async () => { + await expect(validateMcpServerSsrf('https://{{MCP_HOST}}/mcp')).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('allows localhost URLs without DNS lookup', async () => { - await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).resolves.toBeUndefined() + it('returns null for localhost URLs without DNS lookup', async () => { + await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('allows 127.0.0.1 URLs without DNS lookup', async () => { - await expect(validateMcpServerSsrf('http://127.0.0.1:8080/mcp')).resolves.toBeUndefined() + it('returns null for 127.0.0.1 literal without DNS lookup', async () => { + await expect(validateMcpServerSsrf('http://127.0.0.1:8080/mcp')).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) - it('allows URLs that resolve to public IPs', async () => { + it('returns resolved IP for URLs that resolve to public IPs', async () => { mockDnsLookup.mockResolvedValue({ address: '93.184.216.34' }) - await expect(validateMcpServerSsrf('https://example.com/mcp')).resolves.toBeUndefined() + await expect(validateMcpServerSsrf('https://example.com/mcp')).resolves.toBe('93.184.216.34') }) - it('allows HTTP URLs on non-localhost hosts', async () => { + it('returns resolved IP for HTTP URLs on non-localhost hosts', async () => { mockDnsLookup.mockResolvedValue({ address: '93.184.216.34' }) - await expect(validateMcpServerSsrf('http://example.com:3000/mcp')).resolves.toBeUndefined() + await expect(validateMcpServerSsrf('http://example.com:3000/mcp')).resolves.toBe( + '93.184.216.34' + ) }) it('throws McpSsrfError for cloud metadata IP literal', async () => { @@ -402,21 +409,97 @@ describe('validateMcpServerSsrf', () => { ) }) - it('allows URLs resolving to loopback (localhost alias)', async () => { + it('returns resolved IP for URLs resolving to loopback on self-hosted (localhost alias)', async () => { mockDnsLookup.mockResolvedValue({ address: '127.0.0.1' }) - await expect(validateMcpServerSsrf('http://my-local-alias:3000/mcp')).resolves.toBeUndefined() + await expect(validateMcpServerSsrf('http://my-local-alias:3000/mcp')).resolves.toBe('127.0.0.1') }) it('throws for malformed URLs', async () => { await expect(validateMcpServerSsrf('not-a-url')).rejects.toThrow(McpSsrfError) }) + describe('hosted environment', () => { + beforeEach(() => { + hostedFlag.value = true + }) + + it('rejects localhost URLs on hosted', async () => { + await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('rejects 127.0.0.1 URLs on hosted', async () => { + await expect(validateMcpServerSsrf('http://127.0.0.1:8080/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('rejects [::1] URLs on hosted', async () => { + await expect(validateMcpServerSsrf('http://[::1]:8080/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('rejects URLs resolving to loopback on hosted', async () => { + mockDnsLookup.mockResolvedValue({ address: '127.0.0.1' }) + await expect(validateMcpServerSsrf('http://my-local-alias:3000/mcp')).rejects.toThrow( + McpSsrfError + ) + }) + + it('returns resolved IP for public IP resolutions on hosted', async () => { + mockDnsLookup.mockResolvedValue({ address: '93.184.216.34' }) + await expect(validateMcpServerSsrf('https://example.com/mcp')).resolves.toBe('93.184.216.34') + }) + + it('skips loopback check on hosted when allowlist is configured', async () => { + mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['localhost']) + await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).resolves.toBeNull() + }) + + it('still blocks RFC-1918 IP literals on hosted (regression)', async () => { + await expect(validateMcpServerSsrf('http://10.0.0.1/mcp')).rejects.toThrow(McpSsrfError) + await expect(validateMcpServerSsrf('http://192.168.1.1/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('still blocks cloud metadata IP on hosted (regression)', async () => { + await expect( + validateMcpServerSsrf('http://169.254.169.254/latest/meta-data/') + ).rejects.toThrow(McpSsrfError) + }) + + it('still blocks DNS resolutions to private IPs on hosted (regression)', async () => { + mockDnsLookup.mockResolvedValue({ address: '10.0.0.5' }) + await expect(validateMcpServerSsrf('https://internal.corp/mcp')).rejects.toThrow(McpSsrfError) + }) + + it('still skips env var hostnames on hosted', async () => { + await expect(validateMcpServerSsrf('{{MCP_SERVER_URL}}')).resolves.toBeNull() + await expect(validateMcpServerSsrf('https://{{MCP_HOST}}/mcp')).resolves.toBeNull() + expect(mockDnsLookup).not.toHaveBeenCalled() + }) + }) + + describe('self-hosted environment (regression)', () => { + beforeEach(() => { + hostedFlag.value = false + }) + + it('still allows localhost URLs (returns null, no pinning needed)', async () => { + await expect(validateMcpServerSsrf('http://localhost:3000/mcp')).resolves.toBeNull() + }) + + it('still allows 127.0.0.1 URLs (returns null, no pinning needed)', async () => { + await expect(validateMcpServerSsrf('http://127.0.0.1:8080/mcp')).resolves.toBeNull() + }) + + it('returns resolved loopback IP for DNS aliases (caller pins)', async () => { + mockDnsLookup.mockResolvedValue({ address: '127.0.0.1' }) + await expect(validateMcpServerSsrf('http://my-local-alias/mcp')).resolves.toBe('127.0.0.1') + }) + }) + it('skips all checks when ALLOWED_MCP_DOMAINS is configured', async () => { mockGetAllowedMcpDomainsFromEnv.mockReturnValue(['internal.corp']) - await expect(validateMcpServerSsrf('http://10.0.0.1/mcp')).resolves.toBeUndefined() + await expect(validateMcpServerSsrf('http://10.0.0.1/mcp')).resolves.toBeNull() await expect( validateMcpServerSsrf('http://169.254.169.254/latest/meta-data/') - ).resolves.toBeUndefined() + ).resolves.toBeNull() expect(mockDnsLookup).not.toHaveBeenCalled() }) }) diff --git a/apps/sim/lib/mcp/domain-check.ts b/apps/sim/lib/mcp/domain-check.ts index 83ec36c69f5..9e57b23c7f4 100644 --- a/apps/sim/lib/mcp/domain-check.ts +++ b/apps/sim/lib/mcp/domain-check.ts @@ -2,7 +2,7 @@ import dns from 'dns/promises' import { createLogger } from '@sim/logger' import { toError } from '@sim/utils/errors' import * as ipaddr from 'ipaddr.js' -import { getAllowedMcpDomainsFromEnv } from '@/lib/core/config/feature-flags' +import { getAllowedMcpDomainsFromEnv, isHosted } from '@/lib/core/config/feature-flags' import { isPrivateOrReservedIP } from '@/lib/core/security/input-validation.server' import { createEnvVarPattern } from '@/executor/utils/reference-validation' @@ -133,16 +133,25 @@ function isLocalhostHostname(hostname: string): boolean { * Does NOT enforce protocol (HTTP is allowed) or block service ports — MCP * servers legitimately run on HTTP and on arbitrary ports. * - * Localhost/loopback is always allowed for local dev MCP servers. + * Localhost/loopback is allowed for local dev MCP servers in self-hosted + * deployments, but blocked on the hosted environment (sim.ai) where users + * must not be able to reach the server's own loopback interface. * URLs with env var references in the hostname are skipped — they will be * validated after resolution at execution time. * + * Returns the resolved IP address when DNS resolution was performed (so the + * caller can pin subsequent connections to that IP and prevent DNS-rebinding + * TOCTOU attacks). Returns null in cases where pinning is unnecessary or + * impossible: no URL, allowlist-only mode, env-var hostnames (validated later), + * IP literals (no DNS to rebind), and localhost on self-hosted (no rebinding + * risk against a fixed loopback). + * * @throws McpSsrfError if the URL resolves to a blocked IP address */ -export async function validateMcpServerSsrf(url: string | undefined): Promise { - if (!url) return - if (getAllowedMcpDomainsFromEnv() !== null) return - if (hasEnvVarInHostname(url)) return +export async function validateMcpServerSsrf(url: string | undefined): Promise { + if (!url) return null + if (getAllowedMcpDomainsFromEnv() !== null) return null + if (hasEnvVarInHostname(url)) return null let hostname: string try { @@ -154,28 +163,47 @@ export async function validateMcpServerSsrf(url: string | undefined): Promise { + const capturedAgentOptions: unknown[] = [] + class MockAgent { + constructor(options: unknown) { + capturedAgentOptions.push(options) + } + } + return { + mockAgent: MockAgent, + mockCreatePinnedLookup: vi.fn(), + mockFetch: vi.fn(), + capturedAgentOptions, + } +}) + +vi.mock('undici', () => ({ Agent: mockAgent })) +vi.mock('@/lib/core/security/input-validation.server', () => ({ + createPinnedLookup: mockCreatePinnedLookup, +})) + +import { createMcpPinnedFetch } from './pinned-fetch' + +describe('createMcpPinnedFetch', () => { + const originalFetch = globalThis.fetch + + beforeEach(() => { + vi.clearAllMocks() + capturedAgentOptions.length = 0 + mockCreatePinnedLookup.mockReturnValue('pinned-lookup-fn') + globalThis.fetch = mockFetch as unknown as typeof fetch + mockFetch.mockResolvedValue(new Response('ok')) + }) + + afterEach(() => { + globalThis.fetch = originalFetch + }) + + it('builds an undici Agent with the pinned lookup for the resolved IP', () => { + createMcpPinnedFetch('203.0.113.10') + expect(mockCreatePinnedLookup).toHaveBeenCalledWith('203.0.113.10') + expect(capturedAgentOptions).toHaveLength(1) + expect(capturedAgentOptions[0]).toEqual({ connect: { lookup: 'pinned-lookup-fn' } }) + }) + + it('forwards the dispatcher on every fetch call', async () => { + const fetchLike = createMcpPinnedFetch('203.0.113.10') + await fetchLike('https://example.com/mcp', { method: 'POST' }) + expect(mockFetch).toHaveBeenCalledTimes(1) + const [url, init] = mockFetch.mock.calls[0] + expect(url).toBe('https://example.com/mcp') + expect((init as { dispatcher?: unknown }).dispatcher).toBeInstanceOf(mockAgent) + expect((init as { method?: string }).method).toBe('POST') + }) + + it('preserves caller-provided init options (headers, signal)', async () => { + const fetchLike = createMcpPinnedFetch('203.0.113.10') + const controller = new AbortController() + await fetchLike('https://example.com/mcp', { + method: 'GET', + headers: { 'x-test': '1' }, + signal: controller.signal, + }) + const init = mockFetch.mock.calls[0][1] as RequestInit & { dispatcher?: unknown } + expect(init.headers).toEqual({ 'x-test': '1' }) + expect(init.signal).toBe(controller.signal) + expect(init.dispatcher).toBeInstanceOf(mockAgent) + }) + + it('handles undefined init gracefully', async () => { + const fetchLike = createMcpPinnedFetch('203.0.113.10') + await fetchLike('https://example.com/mcp') + const init = mockFetch.mock.calls[0][1] as { dispatcher?: unknown } + expect(init.dispatcher).toBeInstanceOf(mockAgent) + }) + + it('reuses the same dispatcher across calls (one Agent per fetch instance)', async () => { + const fetchLike = createMcpPinnedFetch('203.0.113.10') + await fetchLike('https://example.com/a') + await fetchLike('https://example.com/b') + expect(capturedAgentOptions).toHaveLength(1) + const d1 = (mockFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher + const d2 = (mockFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher + expect(d1).toBe(d2) + }) +}) diff --git a/apps/sim/lib/mcp/pinned-fetch.ts b/apps/sim/lib/mcp/pinned-fetch.ts new file mode 100644 index 00000000000..227c746fb01 --- /dev/null +++ b/apps/sim/lib/mcp/pinned-fetch.ts @@ -0,0 +1,25 @@ +import type { FetchLike } from '@modelcontextprotocol/sdk/shared/transport.js' +import { Agent } from 'undici' +import { createPinnedLookup } from '@/lib/core/security/input-validation.server' + +/** + * Creates a FetchLike that pins all outbound HTTP connections to a pre-resolved + * IP address. Used by the MCP transport to prevent DNS-rebinding (TOCTOU) + * attacks: validation performs DNS once and confirms the IP is allowed; this + * fetch then forces every subsequent request (initial POST, SSE GET, redirects) + * to use that same IP, regardless of what the hostname now resolves to. + * + * The original hostname is preserved on the request so TLS SNI and the Host + * header continue to match the certificate. + */ +export function createMcpPinnedFetch(resolvedIP: string): FetchLike { + const dispatcher = new Agent({ + connect: { lookup: createPinnedLookup(resolvedIP) }, + }) + + return (url, init) => + globalThis.fetch(url, { + ...(init ?? {}), + dispatcher, + } as RequestInit & { dispatcher: Agent }) +} diff --git a/apps/sim/lib/mcp/service.ts b/apps/sim/lib/mcp/service.ts index 4ec764c3d41..7838f682822 100644 --- a/apps/sim/lib/mcp/service.ts +++ b/apps/sim/lib/mcp/service.ts @@ -69,13 +69,13 @@ class McpService { config: McpServerConfig, userId: string, workspaceId?: string - ): Promise { + ): Promise<{ config: McpServerConfig; resolvedIP: string | null }> { const { config: resolvedConfig } = await resolveMcpConfigEnvVars(config, userId, workspaceId, { strict: true, }) validateMcpDomain(resolvedConfig.url) - await validateMcpServerSsrf(resolvedConfig.url) - return resolvedConfig + const resolvedIP = await validateMcpServerSsrf(resolvedConfig.url) + return { config: resolvedConfig, resolvedIP } } /** @@ -156,7 +156,10 @@ class McpService { /** * Create and connect to an MCP client */ - private async createClient(config: McpServerConfig): Promise { + private async createClient( + config: McpServerConfig, + resolvedIP: string | null + ): Promise { const securityPolicy = { requireConsent: true, auditLevel: 'basic' as const, @@ -164,7 +167,11 @@ class McpService { allowedOrigins: config.url ? [new URL(config.url).origin] : undefined, } - const client = new McpClient(config, securityPolicy) + const client = new McpClient({ + config, + securityPolicy, + resolvedIP: resolvedIP ?? undefined, + }) await client.connect() return client } @@ -194,11 +201,15 @@ class McpService { throw new Error(`Server ${serverId} not found or not accessible`) } - const resolvedConfig = await this.resolveConfigEnvVars(config, userId, workspaceId) + const { config: resolvedConfig, resolvedIP } = await this.resolveConfigEnvVars( + config, + userId, + workspaceId + ) if (extraHeaders && Object.keys(extraHeaders).length > 0) { resolvedConfig.headers = { ...resolvedConfig.headers, ...extraHeaders } } - const client = await this.createClient(resolvedConfig) + const client = await this.createClient(resolvedConfig, resolvedIP) try { const result = await client.callTool(toolCall) @@ -348,14 +359,18 @@ class McpService { const allTools: McpTool[] = [] const results = await Promise.allSettled( servers.map(async (config) => { - const resolvedConfig = await this.resolveConfigEnvVars(config, userId, workspaceId) - const client = await this.createClient(resolvedConfig) + const { config: resolvedConfig, resolvedIP } = await this.resolveConfigEnvVars( + config, + userId, + workspaceId + ) + const client = await this.createClient(resolvedConfig, resolvedIP) try { const tools = await client.listTools() logger.debug( `[${requestId}] Discovered ${tools.length} tools from server ${config.name}` ) - return { serverId: config.id, tools, resolvedConfig } + return { serverId: config.id, tools, resolvedConfig, resolvedIP } } finally { await client.disconnect() } @@ -394,13 +409,15 @@ class McpService { if (mcpConnectionManager) { for (const [index, result] of results.entries()) { if (result.status === 'fulfilled') { - const { resolvedConfig } = result.value - mcpConnectionManager.connect(resolvedConfig, userId, workspaceId).catch((err) => { - logger.warn( - `[${requestId}] Persistent connection failed for ${servers[index].name}:`, - err - ) - }) + const { resolvedConfig, resolvedIP } = result.value + mcpConnectionManager + .connect(resolvedConfig, userId, workspaceId, resolvedIP) + .catch((err) => { + logger.warn( + `[${requestId}] Persistent connection failed for ${servers[index].name}:`, + err + ) + }) } } } @@ -450,8 +467,12 @@ class McpService { throw new Error(`Server ${serverId} not found or not accessible`) } - const resolvedConfig = await this.resolveConfigEnvVars(config, userId, workspaceId) - const client = await this.createClient(resolvedConfig) + const { config: resolvedConfig, resolvedIP } = await this.resolveConfigEnvVars( + config, + userId, + workspaceId + ) + const client = await this.createClient(resolvedConfig, resolvedIP) try { const tools = await client.listTools() @@ -490,8 +511,12 @@ class McpService { for (const config of servers) { try { - const resolvedConfig = await this.resolveConfigEnvVars(config, userId, workspaceId) - const client = await this.createClient(resolvedConfig) + const { config: resolvedConfig, resolvedIP } = await this.resolveConfigEnvVars( + config, + userId, + workspaceId + ) + const client = await this.createClient(resolvedConfig, resolvedIP) const tools = await client.listTools() await client.disconnect() diff --git a/apps/sim/lib/mcp/types.ts b/apps/sim/lib/mcp/types.ts index db9ac11fd0a..f4f5c939efd 100644 --- a/apps/sim/lib/mcp/types.ts +++ b/apps/sim/lib/mcp/types.ts @@ -161,6 +161,14 @@ export interface McpClientOptions { config: McpServerConfig securityPolicy?: McpSecurityPolicy onToolsChanged?: McpToolsChangedCallback + /** + * Pre-resolved IP address to pin all transport HTTP connections to. When + * set, the SDK transport uses a custom fetch backed by an undici Agent with + * a fixed DNS lookup, preventing DNS-rebinding (TOCTOU) attacks between + * URL validation and connection. Should be supplied by callers that have + * just validated the URL via `validateMcpServerSsrf`. + */ + resolvedIP?: string } /** diff --git a/apps/sim/package.json b/apps/sim/package.json index 71521b166bb..f175df18d51 100644 --- a/apps/sim/package.json +++ b/apps/sim/package.json @@ -196,6 +196,7 @@ "three": "0.177.0", "tldts": "7.0.30", "twilio": "5.9.0", + "undici": "7.25.0", "unified": "11.0.5", "unpdf": "1.4.0", "xlsx": "https://cdn.sheetjs.com/xlsx-0.20.3/xlsx-0.20.3.tgz", diff --git a/apps/sim/tools/agiloft/utils.test.ts b/apps/sim/tools/agiloft/utils.test.ts new file mode 100644 index 00000000000..10f6edced9a --- /dev/null +++ b/apps/sim/tools/agiloft/utils.test.ts @@ -0,0 +1,212 @@ +/** + * @vitest-environment node + */ +import { inputValidationMock, inputValidationMockFns } from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +vi.mock('@/lib/core/security/input-validation.server', () => inputValidationMock) + +import { executeAgiloftRequest, resolveAgiloftInstance } from '@/tools/agiloft/utils' + +const baseParams = { + instanceUrl: 'https://example.agiloft.com', + knowledgeBase: 'demo', + login: 'admin', + password: 'secret', + table: 'contracts', +} + +const PINNED_IP = '93.184.216.34' + +function mockSecureFetchResponse(body: { + ok?: boolean + status?: number + json?: unknown + text?: string + arrayBuffer?: ArrayBuffer +}) { + return { + ok: body.ok ?? true, + status: body.status ?? 200, + statusText: '', + headers: new Headers(), + body: null, + text: async () => body.text ?? '', + json: async () => body.json ?? {}, + arrayBuffer: async () => body.arrayBuffer ?? new ArrayBuffer(0), + } +} + +beforeEach(() => { + vi.clearAllMocks() + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValue({ + isValid: true, + resolvedIP: PINNED_IP, + originalHostname: 'example.agiloft.com', + }) +}) + +describe('resolveAgiloftInstance', () => { + it('returns the resolved IP for a valid URL', async () => { + const ip = await resolveAgiloftInstance('https://example.agiloft.com') + expect(ip).toBe(PINNED_IP) + expect(inputValidationMockFns.mockValidateUrlWithDNS).toHaveBeenCalledWith( + 'https://example.agiloft.com', + 'instanceUrl' + ) + }) + + it('throws when the URL resolves to a blocked IP', async () => { + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValueOnce({ + isValid: false, + error: 'instanceUrl resolves to a blocked IP address', + }) + + await expect(resolveAgiloftInstance('https://attacker.example.com')).rejects.toThrow( + 'instanceUrl resolves to a blocked IP address' + ) + }) + + it('throws when validation succeeds but no IP is returned', async () => { + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValueOnce({ + isValid: true, + }) + + await expect(resolveAgiloftInstance('https://example.agiloft.com')).rejects.toThrow( + 'Invalid Agiloft instance URL' + ) + }) +}) + +describe('executeAgiloftRequest', () => { + it('pins the resolved IP across login, operation, and logout', async () => { + inputValidationMockFns.mockSecureFetchWithPinnedIP + // EWLogin + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-1' } })) + // operation + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { id: 42, fields: { name: 'foo' } } })) + // EWLogout + .mockResolvedValueOnce(mockSecureFetchResponse({})) + + const result = await executeAgiloftRequest( + baseParams, + (base) => ({ + url: `${base}/ewws/REST/demo/contracts/42`, + method: 'GET', + headers: { Accept: 'application/json' }, + }), + async (response) => { + const data = (await response.json()) as { id: number; fields: Record } + return { + success: response.ok, + output: { id: String(data.id), fields: data.fields }, + } + } + ) + + expect(result).toEqual({ success: true, output: { id: '42', fields: { name: 'foo' } } }) + + const calls = inputValidationMockFns.mockSecureFetchWithPinnedIP.mock.calls + expect(calls).toHaveLength(3) + + // Every call MUST use the pre-resolved IP — this is the SSRF fix. + for (const call of calls) { + expect(call[1]).toBe(PINNED_IP) + } + + // Login URL preserves the original hostname (TLS SNI requirement). + expect(calls[0][0]).toBe( + 'https://example.agiloft.com/ewws/EWLogin?$KB=demo&$login=admin&$password=secret' + ) + expect(calls[0][2]).toEqual({ method: 'POST' }) + + // Operation request includes the bearer token issued by login. + expect(calls[1][0]).toBe('https://example.agiloft.com/ewws/REST/demo/contracts/42') + expect(calls[1][2]).toMatchObject({ + method: 'GET', + headers: { Accept: 'application/json', Authorization: 'Bearer tok-1' }, + }) + + // Logout uses the bearer token and the original hostname. + expect(calls[2][0]).toBe('https://example.agiloft.com/ewws/EWLogout?$KB=demo') + expect(calls[2][2]).toMatchObject({ + method: 'POST', + headers: { Authorization: 'Bearer tok-1' }, + }) + + // DNS lookup happens exactly once, before any HTTP request. + expect(inputValidationMockFns.mockValidateUrlWithDNS).toHaveBeenCalledTimes(1) + }) + + it('still calls logout when the operation throws', async () => { + inputValidationMockFns.mockSecureFetchWithPinnedIP + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-2' } })) + .mockResolvedValueOnce(mockSecureFetchResponse({ ok: false, status: 500 })) + .mockResolvedValueOnce(mockSecureFetchResponse({})) + + await expect( + executeAgiloftRequest( + baseParams, + (base) => ({ url: `${base}/ewws/REST/demo/contracts/42`, method: 'GET' }), + async (response) => { + if (!response.ok) throw new Error('operation failed') + return { success: true, output: {} } + } + ) + ).rejects.toThrow('operation failed') + + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).toHaveBeenCalledTimes(3) + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP.mock.calls[2][0]).toContain( + '/ewws/EWLogout' + ) + }) + + it('swallows logout failures (best-effort)', async () => { + inputValidationMockFns.mockSecureFetchWithPinnedIP + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-3' } })) + .mockResolvedValueOnce(mockSecureFetchResponse({ json: { ok: true } })) + .mockRejectedValueOnce(new Error('logout network error')) + + const result = await executeAgiloftRequest( + baseParams, + (base) => ({ url: `${base}/ewws/REST/demo/contracts/42`, method: 'GET' }), + async () => ({ success: true, output: {} }) + ) + + expect(result.success).toBe(true) + }) + + it('throws when login does not return an access token', async () => { + inputValidationMockFns.mockSecureFetchWithPinnedIP.mockResolvedValueOnce( + mockSecureFetchResponse({ json: {} }) + ) + // Login failure should still trigger no logout, since no token was issued. + + await expect( + executeAgiloftRequest( + baseParams, + (base) => ({ url: `${base}/ewws/REST/demo/contracts/42`, method: 'GET' }), + async () => ({ success: true, output: {} }) + ) + ).rejects.toThrow('Agiloft login did not return an access token') + + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).toHaveBeenCalledTimes(1) + }) + + it('refuses to call any external endpoint when validation rejects the URL', async () => { + inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValueOnce({ + isValid: false, + error: 'instanceUrl resolves to a blocked IP address', + }) + + await expect( + executeAgiloftRequest( + { ...baseParams, instanceUrl: 'https://attacker.example.com' }, + (base) => ({ url: `${base}/ewws/REST/demo/contracts/42`, method: 'GET' }), + async () => ({ success: true, output: {} }) + ) + ).rejects.toThrow('instanceUrl resolves to a blocked IP address') + + expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + }) +}) diff --git a/apps/sim/tools/agiloft/utils.ts b/apps/sim/tools/agiloft/utils.ts index 47184deb5fb..8fd13c5526c 100644 --- a/apps/sim/tools/agiloft/utils.ts +++ b/apps/sim/tools/agiloft/utils.ts @@ -1,5 +1,9 @@ import { createLogger } from '@sim/logger' -import { validateExternalUrl } from '@/lib/core/security/input-validation' +import { + type SecureFetchResponse, + secureFetchWithPinnedIP, + validateUrlWithDNS, +} from '@/lib/core/security/input-validation.server' import type { AgiloftAttachmentInfoParams, AgiloftBaseParams, @@ -21,33 +25,44 @@ interface AgiloftRequestConfig { url: string method: HttpMethod headers?: Record - body?: BodyInit + body?: string | Buffer | Uint8Array +} + +/** + * Validates the Agiloft instance URL and resolves its DNS once, returning the + * resolved IP so subsequent requests can pin to it. This prevents DNS-rebinding + * (TOCTOU) SSRF where the hostname could resolve to a private IP on a later + * lookup. + */ +export async function resolveAgiloftInstance(instanceUrl: string): Promise { + const validation = await validateUrlWithDNS(instanceUrl, 'instanceUrl') + if (!validation.isValid || !validation.resolvedIP) { + throw new Error(validation.error || 'Invalid Agiloft instance URL') + } + return validation.resolvedIP } /** * Exchanges login/password for a short-lived Bearer token via EWLogin. + * Requires a pre-resolved IP to prevent DNS rebinding between validation and + * the actual request. */ -async function agiloftLogin(params: AgiloftBaseParams): Promise { +async function agiloftLogin(params: AgiloftBaseParams, resolvedIP: string): Promise { const base = params.instanceUrl.replace(/\/$/, '') - const urlValidation = validateExternalUrl(params.instanceUrl, 'instanceUrl') - if (!urlValidation.isValid) { - throw new Error(`Invalid Agiloft instance URL: ${urlValidation.error}`) - } - const kb = encodeURIComponent(params.knowledgeBase) const login = encodeURIComponent(params.login) const password = encodeURIComponent(params.password) const url = `${base}/ewws/EWLogin?$KB=${kb}&$login=${login}&$password=${password}` - const response = await fetch(url, { method: 'POST' }) + const response = await secureFetchWithPinnedIP(url, resolvedIP, { method: 'POST' }) if (!response.ok) { const errorText = await response.text() throw new Error(`Agiloft login failed: ${response.status} - ${errorText}`) } - const data = await response.json() + const data = (await response.json()) as { access_token?: string } const token = data.access_token if (!token) { @@ -59,16 +74,18 @@ async function agiloftLogin(params: AgiloftBaseParams): Promise { /** * Cleans up the server session. Best-effort — failures are logged but not thrown. + * Requires a pre-resolved IP to prevent DNS rebinding. */ async function agiloftLogout( instanceUrl: string, knowledgeBase: string, - token: string + token: string, + resolvedIP: string ): Promise { try { const base = instanceUrl.replace(/\/$/, '') const kb = encodeURIComponent(knowledgeBase) - await fetch(`${base}/ewws/EWLogout?$KB=${kb}`, { + await secureFetchWithPinnedIP(`${base}/ewws/EWLogout?$KB=${kb}`, resolvedIP, { method: 'POST', headers: { Authorization: `Bearer ${token}` }, }) @@ -90,14 +107,15 @@ async function agiloftLogout( export async function executeAgiloftRequest( params: AgiloftBaseParams, buildRequest: (base: string) => AgiloftRequestConfig, - transformResponse: (response: Response) => Promise + transformResponse: (response: SecureFetchResponse) => Promise ): Promise { - const token = await agiloftLogin(params) + const resolvedIP = await resolveAgiloftInstance(params.instanceUrl) + const token = await agiloftLogin(params, resolvedIP) const base = params.instanceUrl.replace(/\/$/, '') try { const req = buildRequest(base) - const response = await fetch(req.url, { + const response = await secureFetchWithPinnedIP(req.url, resolvedIP, { method: req.method, headers: { ...req.headers, @@ -107,7 +125,7 @@ export async function executeAgiloftRequest( }) return await transformResponse(response) } finally { - await agiloftLogout(params.instanceUrl, params.knowledgeBase, token) + await agiloftLogout(params.instanceUrl, params.knowledgeBase, token, resolvedIP) } } diff --git a/apps/sim/tools/grafana/update_alert_rule.ts b/apps/sim/tools/grafana/update_alert_rule.ts index 9ca23bff773..e47276490cb 100644 --- a/apps/sim/tools/grafana/update_alert_rule.ts +++ b/apps/sim/tools/grafana/update_alert_rule.ts @@ -1,3 +1,4 @@ +import { secureFetchWithValidation } from '@/lib/core/security/input-validation.server' import { ALERT_RULE_OUTPUT_FIELDS, type GrafanaUpdateAlertRuleParams } from '@/tools/grafana/types' import { mapAlertRule } from '@/tools/grafana/utils' import type { ToolConfig, ToolResponse } from '@/tools/types' @@ -269,13 +270,14 @@ export const updateAlertRuleTool: ToolConfig return { success: true, output: mapAlertRule(data) } }, diff --git a/apps/sim/tools/grafana/update_dashboard.ts b/apps/sim/tools/grafana/update_dashboard.ts index 23449f36830..245f85af64a 100644 --- a/apps/sim/tools/grafana/update_dashboard.ts +++ b/apps/sim/tools/grafana/update_dashboard.ts @@ -1,3 +1,4 @@ +import { secureFetchWithValidation } from '@/lib/core/security/input-validation.server' import type { GrafanaUpdateDashboardParams } from '@/tools/grafana/types' import type { ToolConfig, ToolResponse } from '@/tools/types' @@ -183,11 +184,15 @@ export const updateDashboardTool: ToolConfig Date: Sat, 16 May 2026 22:10:42 -0700 Subject: [PATCH 2/6] fix(otp): don't leak caught error.message; fail-closed on DB retry exhaust - Chat/form OTP routes: replace `error.message || fallback` with generic `Failed to process request` in 500 responses (logger still captures detail). - otp.ts incrementOTPAttempts DB path: on MAX_RETRIES exhaustion, delete the verification row and return `'locked'` instead of trusting a possibly- undercounted final read. Co-Authored-By: Claude Opus 4.7 --- apps/sim/app/api/chat/[identifier]/otp/route.ts | 14 ++++---------- apps/sim/app/api/form/[identifier]/otp/route.ts | 14 ++++---------- apps/sim/lib/core/security/otp.ts | 12 ++++++++---- 3 files changed, 16 insertions(+), 24 deletions(-) diff --git a/apps/sim/app/api/chat/[identifier]/otp/route.ts b/apps/sim/app/api/chat/[identifier]/otp/route.ts index c89cd4721e1..fcccc003e86 100644 --- a/apps/sim/app/api/chat/[identifier]/otp/route.ts +++ b/apps/sim/app/api/chat/[identifier]/otp/route.ts @@ -143,12 +143,9 @@ export const POST = withRouteHandler( logger.info(`[${requestId}] OTP sent to ${email} for chat ${deployment.id}`) return addCorsHeaders(createSuccessResponse({ message: 'Verification code sent' }), request) - } catch (error: any) { + } catch (error) { logger.error(`[${requestId}] Error processing OTP request:`, error) - return addCorsHeaders( - createErrorResponse(error.message || 'Failed to process request', 500), - request - ) + return addCorsHeaders(createErrorResponse('Failed to process request', 500), request) } } ) @@ -239,12 +236,9 @@ export const PUT = withRouteHandler( setChatAuthCookie(response, deployment.id, deployment.authType, deployment.password) return response - } catch (error: any) { + } catch (error) { logger.error(`[${requestId}] Error verifying OTP:`, error) - return addCorsHeaders( - createErrorResponse(error.message || 'Failed to process request', 500), - request - ) + return addCorsHeaders(createErrorResponse('Failed to process request', 500), request) } } ) diff --git a/apps/sim/app/api/form/[identifier]/otp/route.ts b/apps/sim/app/api/form/[identifier]/otp/route.ts index 176a3be50b6..0d9804efa55 100644 --- a/apps/sim/app/api/form/[identifier]/otp/route.ts +++ b/apps/sim/app/api/form/[identifier]/otp/route.ts @@ -149,12 +149,9 @@ export const POST = withRouteHandler( logger.info(`[${requestId}] OTP sent to ${email} for form ${deployment.id}`) return addCorsHeaders(createSuccessResponse({ message: 'Verification code sent' }), request) - } catch (error: any) { + } catch (error) { logger.error(`[${requestId}] Error processing OTP request:`, error) - return addCorsHeaders( - createErrorResponse(error.message || 'Failed to process request', 500), - request - ) + return addCorsHeaders(createErrorResponse('Failed to process request', 500), request) } } ) @@ -256,12 +253,9 @@ export const PUT = withRouteHandler( setFormAuthCookie(response, deployment.id, deployment.authType, deployment.password) return response - } catch (error: any) { + } catch (error) { logger.error(`[${requestId}] Error verifying OTP:`, error) - return addCorsHeaders( - createErrorResponse(error.message || 'Failed to process request', 500), - request - ) + return addCorsHeaders(createErrorResponse('Failed to process request', 500), request) } } ) diff --git a/apps/sim/lib/core/security/otp.ts b/apps/sim/lib/core/security/otp.ts index 576d61ed508..5163487d10b 100644 --- a/apps/sim/lib/core/security/otp.ts +++ b/apps/sim/lib/core/security/otp.ts @@ -220,10 +220,14 @@ export async function incrementOTPAttempts( value = fresh } - const final = await getOTP(kind, deploymentId, email) - if (!final) return 'locked' - const { attempts: finalAttempts } = decodeOTPValue(final) - return finalAttempts >= MAX_OTP_ATTEMPTS ? 'locked' : 'incremented' + /** + * Retry exhaustion under heavy DB-path contention: this request did not + * succeed in writing its own +1, so the stored count may not reflect it. + * Fail closed — invalidate the OTP rather than return `'incremented'` with + * a possibly-undercounted attempt total. + */ + await db.delete(verification).where(eq(verification.identifier, identifier)) + return 'locked' } export async function deleteOTP( From 9fe9cb238fadf0e2e131601883fb51c930927cb4 Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Sat, 16 May 2026 22:12:11 -0700 Subject: [PATCH 3/6] fix(mcp): use undici fetch directly in pinned-fetch for typed dispatcher Replace `globalThis.fetch` + double-cast with `undici.fetch` so the `dispatcher` option is part of the real type contract. This guarantees pinning won't silently break if a future runtime swaps the underlying fetch implementation. Co-Authored-By: Claude Opus 4.7 --- apps/sim/lib/mcp/pinned-fetch.test.ts | 53 ++++++++++++--------------- apps/sim/lib/mcp/pinned-fetch.ts | 21 ++++++++--- 2 files changed, 40 insertions(+), 34 deletions(-) diff --git a/apps/sim/lib/mcp/pinned-fetch.test.ts b/apps/sim/lib/mcp/pinned-fetch.test.ts index 0e59158c2e8..3237ae4fe44 100644 --- a/apps/sim/lib/mcp/pinned-fetch.test.ts +++ b/apps/sim/lib/mcp/pinned-fetch.test.ts @@ -1,43 +1,38 @@ /** * @vitest-environment node */ -import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { beforeEach, describe, expect, it, vi } from 'vitest' -const { mockAgent, mockCreatePinnedLookup, mockFetch, capturedAgentOptions } = vi.hoisted(() => { - const capturedAgentOptions: unknown[] = [] - class MockAgent { - constructor(options: unknown) { - capturedAgentOptions.push(options) +const { mockAgent, mockCreatePinnedLookup, mockUndiciFetch, capturedAgentOptions } = vi.hoisted( + () => { + const capturedAgentOptions: unknown[] = [] + class MockAgent { + constructor(options: unknown) { + capturedAgentOptions.push(options) + } + } + return { + mockAgent: MockAgent, + mockCreatePinnedLookup: vi.fn(), + mockUndiciFetch: vi.fn(), + capturedAgentOptions, } } - return { - mockAgent: MockAgent, - mockCreatePinnedLookup: vi.fn(), - mockFetch: vi.fn(), - capturedAgentOptions, - } -}) +) -vi.mock('undici', () => ({ Agent: mockAgent })) +vi.mock('undici', () => ({ Agent: mockAgent, fetch: mockUndiciFetch })) vi.mock('@/lib/core/security/input-validation.server', () => ({ createPinnedLookup: mockCreatePinnedLookup, })) -import { createMcpPinnedFetch } from './pinned-fetch' +import { createMcpPinnedFetch } from '@/lib/mcp/pinned-fetch' describe('createMcpPinnedFetch', () => { - const originalFetch = globalThis.fetch - beforeEach(() => { vi.clearAllMocks() capturedAgentOptions.length = 0 mockCreatePinnedLookup.mockReturnValue('pinned-lookup-fn') - globalThis.fetch = mockFetch as unknown as typeof fetch - mockFetch.mockResolvedValue(new Response('ok')) - }) - - afterEach(() => { - globalThis.fetch = originalFetch + mockUndiciFetch.mockResolvedValue(new Response('ok')) }) it('builds an undici Agent with the pinned lookup for the resolved IP', () => { @@ -50,8 +45,8 @@ describe('createMcpPinnedFetch', () => { it('forwards the dispatcher on every fetch call', async () => { const fetchLike = createMcpPinnedFetch('203.0.113.10') await fetchLike('https://example.com/mcp', { method: 'POST' }) - expect(mockFetch).toHaveBeenCalledTimes(1) - const [url, init] = mockFetch.mock.calls[0] + expect(mockUndiciFetch).toHaveBeenCalledTimes(1) + const [url, init] = mockUndiciFetch.mock.calls[0] expect(url).toBe('https://example.com/mcp') expect((init as { dispatcher?: unknown }).dispatcher).toBeInstanceOf(mockAgent) expect((init as { method?: string }).method).toBe('POST') @@ -65,7 +60,7 @@ describe('createMcpPinnedFetch', () => { headers: { 'x-test': '1' }, signal: controller.signal, }) - const init = mockFetch.mock.calls[0][1] as RequestInit & { dispatcher?: unknown } + const init = mockUndiciFetch.mock.calls[0][1] as RequestInit & { dispatcher?: unknown } expect(init.headers).toEqual({ 'x-test': '1' }) expect(init.signal).toBe(controller.signal) expect(init.dispatcher).toBeInstanceOf(mockAgent) @@ -74,7 +69,7 @@ describe('createMcpPinnedFetch', () => { it('handles undefined init gracefully', async () => { const fetchLike = createMcpPinnedFetch('203.0.113.10') await fetchLike('https://example.com/mcp') - const init = mockFetch.mock.calls[0][1] as { dispatcher?: unknown } + const init = mockUndiciFetch.mock.calls[0][1] as { dispatcher?: unknown } expect(init.dispatcher).toBeInstanceOf(mockAgent) }) @@ -83,8 +78,8 @@ describe('createMcpPinnedFetch', () => { await fetchLike('https://example.com/a') await fetchLike('https://example.com/b') expect(capturedAgentOptions).toHaveLength(1) - const d1 = (mockFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher - const d2 = (mockFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher + const d1 = (mockUndiciFetch.mock.calls[0][1] as { dispatcher: unknown }).dispatcher + const d2 = (mockUndiciFetch.mock.calls[1][1] as { dispatcher: unknown }).dispatcher expect(d1).toBe(d2) }) }) diff --git a/apps/sim/lib/mcp/pinned-fetch.ts b/apps/sim/lib/mcp/pinned-fetch.ts index 227c746fb01..d480e0896cd 100644 --- a/apps/sim/lib/mcp/pinned-fetch.ts +++ b/apps/sim/lib/mcp/pinned-fetch.ts @@ -1,5 +1,5 @@ import type { FetchLike } from '@modelcontextprotocol/sdk/shared/transport.js' -import { Agent } from 'undici' +import { Agent, type RequestInit as UndiciRequestInit, fetch as undiciFetch } from 'undici' import { createPinnedLookup } from '@/lib/core/security/input-validation.server' /** @@ -9,6 +9,10 @@ import { createPinnedLookup } from '@/lib/core/security/input-validation.server' * fetch then forces every subsequent request (initial POST, SSE GET, redirects) * to use that same IP, regardless of what the hostname now resolves to. * + * Uses undici's `fetch` directly so the `dispatcher` option is part of the + * real type contract — not a cast that would silently break if a future + * runtime swapped out the implementation. + * * The original hostname is preserved on the request so TLS SNI and the Host * header continue to match the certificate. */ @@ -17,9 +21,16 @@ export function createMcpPinnedFetch(resolvedIP: string): FetchLike { connect: { lookup: createPinnedLookup(resolvedIP) }, }) - return (url, init) => - globalThis.fetch(url, { - ...(init ?? {}), + return (async (url, init) => { + // DOM `RequestInit` and undici's `RequestInit` are structurally compatible + // at runtime (Node's global fetch IS undici) but differ in TS types. + // Cast the init through unknown to bridge the typing without losing the + // critical `dispatcher` typing on the call itself. + const undiciInit: UndiciRequestInit = { + ...(init as unknown as UndiciRequestInit), dispatcher, - } as RequestInit & { dispatcher: Agent }) + } + const response = await undiciFetch(url as string | URL, undiciInit) + return response as unknown as Response + }) satisfies FetchLike } From 2b06b80da07fb3c663e923f9211ac214e4e0b1ec Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Sat, 16 May 2026 22:18:34 -0700 Subject: [PATCH 4/6] fix(build): keep agiloft/grafana tool configs client-safe MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tool config files are statically reachable from the client bundle (via tools/registry.ts → tools/{service}/index.ts). Importing `@/lib/core/security/input-validation.server` from these files pulled `node:dns/promises` into the Turbopack client bundle and broke the build. Split agiloft utils into client-safe (`utils.ts`, plain fetch + sync `validateExternalUrl`) and server-only (`utils.server.ts`, DNS-pinned variants). Routes that need TOCTOU protection import the pinned helpers; the executor-side tool path falls back to sync URL validation (matches the supabase precedent and pre-PR baseline). Grafana update tools likewise switch from `secureFetchWithValidation` (server-only) to inline sync `validateExternalUrl` + plain fetch. Co-Authored-By: Claude Opus 4.7 --- .../sim/app/api/tools/agiloft/attach/route.ts | 12 +- .../app/api/tools/agiloft/retrieve/route.ts | 12 +- apps/sim/tools/agiloft/utils.server.ts | 79 ++++++++++ apps/sim/tools/agiloft/utils.test.ts | 140 ++++-------------- apps/sim/tools/agiloft/utils.ts | 48 ++---- apps/sim/tools/grafana/update_alert_rule.ts | 16 +- apps/sim/tools/grafana/update_dashboard.ts | 25 ++-- 7 files changed, 165 insertions(+), 167 deletions(-) create mode 100644 apps/sim/tools/agiloft/utils.server.ts diff --git a/apps/sim/app/api/tools/agiloft/attach/route.ts b/apps/sim/app/api/tools/agiloft/attach/route.ts index 7a21cc02b54..b0fcb351751 100644 --- a/apps/sim/app/api/tools/agiloft/attach/route.ts +++ b/apps/sim/app/api/tools/agiloft/attach/route.ts @@ -11,12 +11,12 @@ import type { RawFileInput } from '@/lib/uploads/utils/file-schemas' import { processFilesToUserFiles } from '@/lib/uploads/utils/file-utils' import { downloadFileFromStorage } from '@/lib/uploads/utils/file-utils.server' import { assertToolFileAccess } from '@/app/api/files/authorization' +import { buildAttachFileUrl } from '@/tools/agiloft/utils' import { - agiloftLogin, - agiloftLogout, - buildAttachFileUrl, + agiloftLoginPinned, + agiloftLogoutPinned, resolveAgiloftInstance, -} from '@/tools/agiloft/utils' +} from '@/tools/agiloft/utils.server' export const dynamic = 'force-dynamic' @@ -87,7 +87,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { return NextResponse.json({ success: false, error: toError(error).message }, { status: 400 }) } - const token = await agiloftLogin(data, resolvedIP) + const token = await agiloftLoginPinned(data, resolvedIP) const base = data.instanceUrl.replace(/\/$/, '') try { @@ -139,7 +139,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { }, }) } finally { - await agiloftLogout(data.instanceUrl, data.knowledgeBase, token, resolvedIP) + await agiloftLogoutPinned(data.instanceUrl, data.knowledgeBase, token, resolvedIP) } } catch (error) { logger.error(`[${requestId}] Error attaching file to Agiloft:`, error) diff --git a/apps/sim/app/api/tools/agiloft/retrieve/route.ts b/apps/sim/app/api/tools/agiloft/retrieve/route.ts index 0d6137988ed..539f0bf7c2e 100644 --- a/apps/sim/app/api/tools/agiloft/retrieve/route.ts +++ b/apps/sim/app/api/tools/agiloft/retrieve/route.ts @@ -7,12 +7,12 @@ import { checkInternalAuth } from '@/lib/auth/hybrid' import { secureFetchWithPinnedIP } from '@/lib/core/security/input-validation.server' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' +import { buildRetrieveAttachmentUrl } from '@/tools/agiloft/utils' import { - agiloftLogin, - agiloftLogout, - buildRetrieveAttachmentUrl, + agiloftLoginPinned, + agiloftLogoutPinned, resolveAgiloftInstance, -} from '@/tools/agiloft/utils' +} from '@/tools/agiloft/utils.server' export const dynamic = 'force-dynamic' @@ -63,7 +63,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { return NextResponse.json({ success: false, error: toError(error).message }, { status: 400 }) } - const token = await agiloftLogin(data, resolvedIP) + const token = await agiloftLoginPinned(data, resolvedIP) const base = data.instanceUrl.replace(/\/$/, '') try { @@ -127,7 +127,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { }, }) } finally { - await agiloftLogout(data.instanceUrl, data.knowledgeBase, token, resolvedIP) + await agiloftLogoutPinned(data.instanceUrl, data.knowledgeBase, token, resolvedIP) } } catch (error) { logger.error(`[${requestId}] Error retrieving Agiloft attachment:`, error) diff --git a/apps/sim/tools/agiloft/utils.server.ts b/apps/sim/tools/agiloft/utils.server.ts new file mode 100644 index 00000000000..3aaa0c62b71 --- /dev/null +++ b/apps/sim/tools/agiloft/utils.server.ts @@ -0,0 +1,79 @@ +import { createLogger } from '@sim/logger' +import { + type SecureFetchResponse, + secureFetchWithPinnedIP, + validateUrlWithDNS, +} from '@/lib/core/security/input-validation.server' +import type { AgiloftBaseParams } from '@/tools/agiloft/types' + +const logger = createLogger('AgiloftAuthServer') + +/** + * Validates the Agiloft instance URL and resolves its DNS once, returning the + * resolved IP so subsequent requests can pin to it. This prevents DNS-rebinding + * (TOCTOU) SSRF where the hostname could resolve to a private IP on a later + * lookup. Server-only — uses node:dns/promises. + */ +export async function resolveAgiloftInstance(instanceUrl: string): Promise { + const validation = await validateUrlWithDNS(instanceUrl, 'instanceUrl') + if (!validation.isValid || !validation.resolvedIP) { + throw new Error(validation.error || 'Invalid Agiloft instance URL') + } + return validation.resolvedIP +} + +/** + * DNS-pinned variant of agiloftLogin. Requires a pre-resolved IP so the + * connection cannot be steered to a different host between validation and + * the actual TCP connection. + */ +export async function agiloftLoginPinned( + params: AgiloftBaseParams, + resolvedIP: string +): Promise { + const base = params.instanceUrl.replace(/\/$/, '') + const kb = encodeURIComponent(params.knowledgeBase) + const login = encodeURIComponent(params.login) + const password = encodeURIComponent(params.password) + + const url = `${base}/ewws/EWLogin?$KB=${kb}&$login=${login}&$password=${password}` + const response = await secureFetchWithPinnedIP(url, resolvedIP, { method: 'POST' }) + + if (!response.ok) { + const errorText = await response.text() + throw new Error(`Agiloft login failed: ${response.status} - ${errorText}`) + } + + const data = (await response.json()) as { access_token?: string } + const token = data.access_token + + if (!token) { + throw new Error('Agiloft login did not return an access token') + } + + return token +} + +/** + * DNS-pinned variant of agiloftLogout. Best-effort — failures are logged but + * not thrown. + */ +export async function agiloftLogoutPinned( + instanceUrl: string, + knowledgeBase: string, + token: string, + resolvedIP: string +): Promise { + try { + const base = instanceUrl.replace(/\/$/, '') + const kb = encodeURIComponent(knowledgeBase) + await secureFetchWithPinnedIP(`${base}/ewws/EWLogout?$KB=${kb}`, resolvedIP, { + method: 'POST', + headers: { Authorization: `Bearer ${token}` }, + }) + } catch (error) { + logger.warn('Agiloft logout failed (best-effort)', { error }) + } +} + +export type { SecureFetchResponse } diff --git a/apps/sim/tools/agiloft/utils.test.ts b/apps/sim/tools/agiloft/utils.test.ts index 10f6edced9a..b80eb2a33ba 100644 --- a/apps/sim/tools/agiloft/utils.test.ts +++ b/apps/sim/tools/agiloft/utils.test.ts @@ -1,12 +1,8 @@ /** * @vitest-environment node */ -import { inputValidationMock, inputValidationMockFns } from '@sim/testing' -import { beforeEach, describe, expect, it, vi } from 'vitest' - -vi.mock('@/lib/core/security/input-validation.server', () => inputValidationMock) - -import { executeAgiloftRequest, resolveAgiloftInstance } from '@/tools/agiloft/utils' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { executeAgiloftRequest } from '@/tools/agiloft/utils' const baseParams = { instanceUrl: 'https://example.agiloft.com', @@ -16,77 +12,34 @@ const baseParams = { table: 'contracts', } -const PINNED_IP = '93.184.216.34' - -function mockSecureFetchResponse(body: { - ok?: boolean - status?: number - json?: unknown - text?: string - arrayBuffer?: ArrayBuffer -}) { +function mockFetchResponse(body: { ok?: boolean; status?: number; json?: unknown; text?: string }) { return { ok: body.ok ?? true, status: body.status ?? 200, statusText: '', headers: new Headers(), - body: null, text: async () => body.text ?? '', json: async () => body.json ?? {}, - arrayBuffer: async () => body.arrayBuffer ?? new ArrayBuffer(0), - } + } as unknown as Response } +const fetchSpy = vi.fn() + beforeEach(() => { - vi.clearAllMocks() - inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValue({ - isValid: true, - resolvedIP: PINNED_IP, - originalHostname: 'example.agiloft.com', - }) + fetchSpy.mockReset() + vi.stubGlobal('fetch', fetchSpy) }) -describe('resolveAgiloftInstance', () => { - it('returns the resolved IP for a valid URL', async () => { - const ip = await resolveAgiloftInstance('https://example.agiloft.com') - expect(ip).toBe(PINNED_IP) - expect(inputValidationMockFns.mockValidateUrlWithDNS).toHaveBeenCalledWith( - 'https://example.agiloft.com', - 'instanceUrl' - ) - }) - - it('throws when the URL resolves to a blocked IP', async () => { - inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValueOnce({ - isValid: false, - error: 'instanceUrl resolves to a blocked IP address', - }) - - await expect(resolveAgiloftInstance('https://attacker.example.com')).rejects.toThrow( - 'instanceUrl resolves to a blocked IP address' - ) - }) - - it('throws when validation succeeds but no IP is returned', async () => { - inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValueOnce({ - isValid: true, - }) - - await expect(resolveAgiloftInstance('https://example.agiloft.com')).rejects.toThrow( - 'Invalid Agiloft instance URL' - ) - }) +afterEach(() => { + vi.unstubAllGlobals() }) describe('executeAgiloftRequest', () => { - it('pins the resolved IP across login, operation, and logout', async () => { - inputValidationMockFns.mockSecureFetchWithPinnedIP - // EWLogin - .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-1' } })) - // operation - .mockResolvedValueOnce(mockSecureFetchResponse({ json: { id: 42, fields: { name: 'foo' } } })) - // EWLogout - .mockResolvedValueOnce(mockSecureFetchResponse({})) + it('logs in, runs the operation with the bearer token, then logs out', async () => { + fetchSpy + .mockResolvedValueOnce(mockFetchResponse({ json: { access_token: 'tok-1' } })) + .mockResolvedValueOnce(mockFetchResponse({ json: { id: 42, fields: { name: 'foo' } } })) + .mockResolvedValueOnce(mockFetchResponse({})) const result = await executeAgiloftRequest( baseParams, @@ -106,43 +59,24 @@ describe('executeAgiloftRequest', () => { expect(result).toEqual({ success: true, output: { id: '42', fields: { name: 'foo' } } }) - const calls = inputValidationMockFns.mockSecureFetchWithPinnedIP.mock.calls + const calls = fetchSpy.mock.calls expect(calls).toHaveLength(3) - - // Every call MUST use the pre-resolved IP — this is the SSRF fix. - for (const call of calls) { - expect(call[1]).toBe(PINNED_IP) - } - - // Login URL preserves the original hostname (TLS SNI requirement). expect(calls[0][0]).toBe( 'https://example.agiloft.com/ewws/EWLogin?$KB=demo&$login=admin&$password=secret' ) - expect(calls[0][2]).toEqual({ method: 'POST' }) - - // Operation request includes the bearer token issued by login. expect(calls[1][0]).toBe('https://example.agiloft.com/ewws/REST/demo/contracts/42') - expect(calls[1][2]).toMatchObject({ + expect(calls[1][1]).toMatchObject({ method: 'GET', headers: { Accept: 'application/json', Authorization: 'Bearer tok-1' }, }) - - // Logout uses the bearer token and the original hostname. expect(calls[2][0]).toBe('https://example.agiloft.com/ewws/EWLogout?$KB=demo') - expect(calls[2][2]).toMatchObject({ - method: 'POST', - headers: { Authorization: 'Bearer tok-1' }, - }) - - // DNS lookup happens exactly once, before any HTTP request. - expect(inputValidationMockFns.mockValidateUrlWithDNS).toHaveBeenCalledTimes(1) }) it('still calls logout when the operation throws', async () => { - inputValidationMockFns.mockSecureFetchWithPinnedIP - .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-2' } })) - .mockResolvedValueOnce(mockSecureFetchResponse({ ok: false, status: 500 })) - .mockResolvedValueOnce(mockSecureFetchResponse({})) + fetchSpy + .mockResolvedValueOnce(mockFetchResponse({ json: { access_token: 'tok-2' } })) + .mockResolvedValueOnce(mockFetchResponse({ ok: false, status: 500 })) + .mockResolvedValueOnce(mockFetchResponse({})) await expect( executeAgiloftRequest( @@ -155,16 +89,14 @@ describe('executeAgiloftRequest', () => { ) ).rejects.toThrow('operation failed') - expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).toHaveBeenCalledTimes(3) - expect(inputValidationMockFns.mockSecureFetchWithPinnedIP.mock.calls[2][0]).toContain( - '/ewws/EWLogout' - ) + expect(fetchSpy).toHaveBeenCalledTimes(3) + expect(fetchSpy.mock.calls[2][0]).toContain('/ewws/EWLogout') }) it('swallows logout failures (best-effort)', async () => { - inputValidationMockFns.mockSecureFetchWithPinnedIP - .mockResolvedValueOnce(mockSecureFetchResponse({ json: { access_token: 'tok-3' } })) - .mockResolvedValueOnce(mockSecureFetchResponse({ json: { ok: true } })) + fetchSpy + .mockResolvedValueOnce(mockFetchResponse({ json: { access_token: 'tok-3' } })) + .mockResolvedValueOnce(mockFetchResponse({ json: { ok: true } })) .mockRejectedValueOnce(new Error('logout network error')) const result = await executeAgiloftRequest( @@ -177,10 +109,7 @@ describe('executeAgiloftRequest', () => { }) it('throws when login does not return an access token', async () => { - inputValidationMockFns.mockSecureFetchWithPinnedIP.mockResolvedValueOnce( - mockSecureFetchResponse({ json: {} }) - ) - // Login failure should still trigger no logout, since no token was issued. + fetchSpy.mockResolvedValueOnce(mockFetchResponse({ json: {} })) await expect( executeAgiloftRequest( @@ -190,23 +119,18 @@ describe('executeAgiloftRequest', () => { ) ).rejects.toThrow('Agiloft login did not return an access token') - expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).toHaveBeenCalledTimes(1) + expect(fetchSpy).toHaveBeenCalledTimes(1) }) - it('refuses to call any external endpoint when validation rejects the URL', async () => { - inputValidationMockFns.mockValidateUrlWithDNS.mockResolvedValueOnce({ - isValid: false, - error: 'instanceUrl resolves to a blocked IP address', - }) - + it('rejects an instance URL that fails synchronous URL validation', async () => { await expect( executeAgiloftRequest( - { ...baseParams, instanceUrl: 'https://attacker.example.com' }, + { ...baseParams, instanceUrl: 'not-a-valid-url' }, (base) => ({ url: `${base}/ewws/REST/demo/contracts/42`, method: 'GET' }), async () => ({ success: true, output: {} }) ) - ).rejects.toThrow('instanceUrl resolves to a blocked IP address') + ).rejects.toThrow(/Invalid Agiloft instance URL/) - expect(inputValidationMockFns.mockSecureFetchWithPinnedIP).not.toHaveBeenCalled() + expect(fetchSpy).not.toHaveBeenCalled() }) }) diff --git a/apps/sim/tools/agiloft/utils.ts b/apps/sim/tools/agiloft/utils.ts index 8fd13c5526c..811187ab833 100644 --- a/apps/sim/tools/agiloft/utils.ts +++ b/apps/sim/tools/agiloft/utils.ts @@ -1,9 +1,5 @@ import { createLogger } from '@sim/logger' -import { - type SecureFetchResponse, - secureFetchWithPinnedIP, - validateUrlWithDNS, -} from '@/lib/core/security/input-validation.server' +import { validateExternalUrl } from '@/lib/core/security/input-validation' import type { AgiloftAttachmentInfoParams, AgiloftBaseParams, @@ -25,37 +21,26 @@ interface AgiloftRequestConfig { url: string method: HttpMethod headers?: Record - body?: string | Buffer | Uint8Array -} - -/** - * Validates the Agiloft instance URL and resolves its DNS once, returning the - * resolved IP so subsequent requests can pin to it. This prevents DNS-rebinding - * (TOCTOU) SSRF where the hostname could resolve to a private IP on a later - * lookup. - */ -export async function resolveAgiloftInstance(instanceUrl: string): Promise { - const validation = await validateUrlWithDNS(instanceUrl, 'instanceUrl') - if (!validation.isValid || !validation.resolvedIP) { - throw new Error(validation.error || 'Invalid Agiloft instance URL') - } - return validation.resolvedIP + body?: BodyInit } /** * Exchanges login/password for a short-lived Bearer token via EWLogin. - * Requires a pre-resolved IP to prevent DNS rebinding between validation and - * the actual request. */ -async function agiloftLogin(params: AgiloftBaseParams, resolvedIP: string): Promise { +async function agiloftLogin(params: AgiloftBaseParams): Promise { const base = params.instanceUrl.replace(/\/$/, '') + const urlValidation = validateExternalUrl(params.instanceUrl, 'instanceUrl') + if (!urlValidation.isValid) { + throw new Error(`Invalid Agiloft instance URL: ${urlValidation.error}`) + } + const kb = encodeURIComponent(params.knowledgeBase) const login = encodeURIComponent(params.login) const password = encodeURIComponent(params.password) const url = `${base}/ewws/EWLogin?$KB=${kb}&$login=${login}&$password=${password}` - const response = await secureFetchWithPinnedIP(url, resolvedIP, { method: 'POST' }) + const response = await fetch(url, { method: 'POST' }) if (!response.ok) { const errorText = await response.text() @@ -74,18 +59,16 @@ async function agiloftLogin(params: AgiloftBaseParams, resolvedIP: string): Prom /** * Cleans up the server session. Best-effort — failures are logged but not thrown. - * Requires a pre-resolved IP to prevent DNS rebinding. */ async function agiloftLogout( instanceUrl: string, knowledgeBase: string, - token: string, - resolvedIP: string + token: string ): Promise { try { const base = instanceUrl.replace(/\/$/, '') const kb = encodeURIComponent(knowledgeBase) - await secureFetchWithPinnedIP(`${base}/ewws/EWLogout?$KB=${kb}`, resolvedIP, { + await fetch(`${base}/ewws/EWLogout?$KB=${kb}`, { method: 'POST', headers: { Authorization: `Bearer ${token}` }, }) @@ -107,15 +90,14 @@ async function agiloftLogout( export async function executeAgiloftRequest( params: AgiloftBaseParams, buildRequest: (base: string) => AgiloftRequestConfig, - transformResponse: (response: SecureFetchResponse) => Promise + transformResponse: (response: Response) => Promise ): Promise { - const resolvedIP = await resolveAgiloftInstance(params.instanceUrl) - const token = await agiloftLogin(params, resolvedIP) + const token = await agiloftLogin(params) const base = params.instanceUrl.replace(/\/$/, '') try { const req = buildRequest(base) - const response = await secureFetchWithPinnedIP(req.url, resolvedIP, { + const response = await fetch(req.url, { method: req.method, headers: { ...req.headers, @@ -125,7 +107,7 @@ export async function executeAgiloftRequest( }) return await transformResponse(response) } finally { - await agiloftLogout(params.instanceUrl, params.knowledgeBase, token, resolvedIP) + await agiloftLogout(params.instanceUrl, params.knowledgeBase, token) } } diff --git a/apps/sim/tools/grafana/update_alert_rule.ts b/apps/sim/tools/grafana/update_alert_rule.ts index e47276490cb..19f2bf8164d 100644 --- a/apps/sim/tools/grafana/update_alert_rule.ts +++ b/apps/sim/tools/grafana/update_alert_rule.ts @@ -1,4 +1,4 @@ -import { secureFetchWithValidation } from '@/lib/core/security/input-validation.server' +import { validateExternalUrl } from '@/lib/core/security/input-validation' import { ALERT_RULE_OUTPUT_FIELDS, type GrafanaUpdateAlertRuleParams } from '@/tools/grafana/types' import { mapAlertRule } from '@/tools/grafana/utils' import type { ToolConfig, ToolResponse } from '@/tools/types' @@ -270,14 +270,22 @@ export const updateAlertRuleTool: ToolConfig Date: Sat, 16 May 2026 22:18:42 -0700 Subject: [PATCH 5/6] fix(knowledge): case-insensitive scheme checks for fileUrl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Boundary schema accepted uppercase schemes (e.g. HTTPS://, DATA:) via the case-insensitive http regex, but the processor's case-sensitive startsWith('data:') / startsWith('http') / startsWith('https://') checks rejected them with a confusing "Unsupported fileUrl scheme" error. Aligns processor checks to the schema using case-insensitive regex per RFC 3986 §3.1. Co-Authored-By: Claude Opus 4.7 --- apps/sim/lib/api/contracts/knowledge/shared.ts | 2 +- apps/sim/lib/knowledge/documents/document-processor.ts | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/apps/sim/lib/api/contracts/knowledge/shared.ts b/apps/sim/lib/api/contracts/knowledge/shared.ts index 9e68e895a66..e75b3bde368 100644 --- a/apps/sim/lib/api/contracts/knowledge/shared.ts +++ b/apps/sim/lib/api/contracts/knowledge/shared.ts @@ -34,7 +34,7 @@ export const knowledgeDocumentFileUrlSchema = z .string() .min(1, 'File URL is required') .refine( - (value) => value.startsWith('data:') || /^https?:\/\//i.test(value), + (value) => /^data:/i.test(value) || /^https?:\/\//i.test(value), 'File URL must be a data: URI or an http(s):// URL' ) diff --git a/apps/sim/lib/knowledge/documents/document-processor.ts b/apps/sim/lib/knowledge/documents/document-processor.ts index d68f6f4aaff..5e550f5af60 100644 --- a/apps/sim/lib/knowledge/documents/document-processor.ts +++ b/apps/sim/lib/knowledge/documents/document-processor.ts @@ -315,7 +315,7 @@ async function handleFileForOCR( userId?: string, workspaceId?: string | null ) { - const isExternalHttps = fileUrl.startsWith('https://') && !isInternalFileUrl(fileUrl) + const isExternalHttps = /^https:\/\//i.test(fileUrl) && !isInternalFileUrl(fileUrl) if (isExternalHttps) { if (mimeType === 'application/pdf') { @@ -385,14 +385,14 @@ async function downloadFileWithTimeout(fileUrl: string): Promise { } async function downloadFileForBase64(fileUrl: string): Promise { - if (fileUrl.startsWith('data:')) { + if (/^data:/i.test(fileUrl)) { const [, base64Data] = fileUrl.split(',') if (!base64Data) { throw new Error('Invalid data URI format') } return Buffer.from(base64Data, 'base64') } - if (fileUrl.startsWith('http')) { + if (/^https?:\/\//i.test(fileUrl)) { return downloadFileWithTimeout(fileUrl) } throw new Error('Unsupported fileUrl scheme: only data: URIs and http(s):// URLs are allowed') @@ -782,9 +782,9 @@ async function parseWithFileParser(fileUrl: string, filename: string, mimeType: let content: string let metadata: FileParseMetadata = {} - if (fileUrl.startsWith('data:')) { + if (/^data:/i.test(fileUrl)) { content = await parseDataURI(fileUrl, filename, mimeType) - } else if (fileUrl.startsWith('http')) { + } else if (/^https?:\/\//i.test(fileUrl)) { const result = await parseHttpFile(fileUrl, filename, mimeType) content = result.content metadata = result.metadata || {} From 21a93a58f824a0a30ded5901abd580f062ffcfec Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Sat, 16 May 2026 22:22:17 -0700 Subject: [PATCH 6/6] fix(mcp): annotate undici/DOM type-bridge double-casts in pinned-fetch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Strict audit was failing on two new `as unknown as` casts in pinned-fetch.ts. They bridge DOM `RequestInit`/`Response` ↔ undici equivalents (structurally compatible at runtime since Node's global fetch is undici) and are required to satisfy the FetchLike contract. Annotate so they count as documented exemptions instead of new violations. Co-Authored-By: Claude Opus 4.7 --- apps/sim/lib/mcp/pinned-fetch.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/apps/sim/lib/mcp/pinned-fetch.ts b/apps/sim/lib/mcp/pinned-fetch.ts index d480e0896cd..798de5710e6 100644 --- a/apps/sim/lib/mcp/pinned-fetch.ts +++ b/apps/sim/lib/mcp/pinned-fetch.ts @@ -27,10 +27,12 @@ export function createMcpPinnedFetch(resolvedIP: string): FetchLike { // Cast the init through unknown to bridge the typing without losing the // critical `dispatcher` typing on the call itself. const undiciInit: UndiciRequestInit = { + // double-cast-allowed: DOM RequestInit and undici RequestInit are structurally compatible at runtime (Node's global fetch IS undici) but the TS types differ ...(init as unknown as UndiciRequestInit), dispatcher, } const response = await undiciFetch(url as string | URL, undiciInit) + // double-cast-allowed: undici Response and DOM Response are structurally compatible at runtime; bridging the types is required to satisfy the FetchLike contract return response as unknown as Response }) satisfies FetchLike }