diff --git a/apps/docs/content/docs/en/tools/firecrawl.mdx b/apps/docs/content/docs/en/tools/firecrawl.mdx index 6c582bfc76d..84ccd439563 100644 --- a/apps/docs/content/docs/en/tools/firecrawl.mdx +++ b/apps/docs/content/docs/en/tools/firecrawl.mdx @@ -254,6 +254,8 @@ Parse uploaded documents (PDF, DOCX, HTML, etc.) into clean markdown using Firec | `proxy` | string | No | Proxy mode: "basic" or "auto" | | `zeroDataRetention` | boolean | No | Enable zero data retention. Defaults to false. | | `apiKey` | string | Yes | Firecrawl API key | +| `pricing` | custom | No | No description | +| `metadata` | string | No | No description | | `rateLimit` | string | No | No description | #### Output diff --git a/apps/docs/content/docs/en/tools/knowledge.mdx b/apps/docs/content/docs/en/tools/knowledge.mdx index 975754ed556..83cbe9b8fb3 100644 --- a/apps/docs/content/docs/en/tools/knowledge.mdx +++ b/apps/docs/content/docs/en/tools/knowledge.mdx @@ -47,6 +47,8 @@ Search for similar content in a knowledge base using vector similarity | `properties` | string | No | No description | | `tagName` | string | No | No description | | `tagValue` | string | No | No description | +| `rerankerEnabled` | boolean | No | Whether to apply Cohere reranking to vector search results | +| `rerankerModel` | string | No | Cohere rerank model to use \(one of: rerank-v4.0-pro, rerank-v4.0-fast, rerank-v3.5\) | | `tagFilters` | string | No | No description | #### Output diff --git a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts index 6935272ad6b..58a83dc98e7 100644 --- a/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts +++ b/apps/sim/app/api/knowledge/[id]/documents/[documentId]/chunks/route.ts @@ -215,7 +215,12 @@ export const POST = withRouteHandler( let cost = null try { - cost = calculateCost('text-embedding-3-small', newChunk.tokenCount, 0, false) + cost = calculateCost( + accessCheck.knowledgeBase.embeddingModel, + newChunk.tokenCount, + 0, + false + ) } catch (error) { logger.warn(`[${requestId}] Failed to calculate cost for chunk upload`, { error: error instanceof Error ? 
error.message : 'Unknown error', @@ -240,7 +245,7 @@ export const POST = withRouteHandler( completion: 0, total: newChunk.tokenCount, }, - model: 'text-embedding-3-small', + model: accessCheck.knowledgeBase.embeddingModel, pricing: cost.pricing, }, } diff --git a/apps/sim/app/api/knowledge/[id]/route.ts b/apps/sim/app/api/knowledge/[id]/route.ts index 6f97a2515c4..05db471d433 100644 --- a/apps/sim/app/api/knowledge/[id]/route.ts +++ b/apps/sim/app/api/knowledge/[id]/route.ts @@ -27,8 +27,6 @@ const logger = createLogger('KnowledgeBaseByIdAPI') const UpdateKnowledgeBaseSchema = z.object({ name: z.string().min(1, 'Name is required').optional(), description: z.string().optional(), - embeddingModel: z.literal('text-embedding-3-small').optional(), - embeddingDimension: z.literal(1536).optional(), workspaceId: z.string().nullable().optional(), chunkingConfig: z .object({ diff --git a/apps/sim/app/api/knowledge/route.ts b/apps/sim/app/api/knowledge/route.ts index 7f8b0c1309b..ed16ffcdeeb 100644 --- a/apps/sim/app/api/knowledge/route.ts +++ b/apps/sim/app/api/knowledge/route.ts @@ -6,6 +6,7 @@ import { getSession } from '@/lib/auth' import { PlatformEvents } from '@/lib/core/telemetry' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' +import { EMBEDDING_DIMENSIONS, getConfiguredEmbeddingModel } from '@/lib/knowledge/embeddings' import { createKnowledgeBase, getKnowledgeBases, @@ -20,8 +21,6 @@ const CreateKnowledgeBaseSchema = z.object({ name: z.string().min(1, 'Name is required'), description: z.string().optional(), workspaceId: z.string().min(1, 'Workspace ID is required'), - embeddingModel: z.literal('text-embedding-3-small').default('text-embedding-3-small'), - embeddingDimension: z.literal(1536).default(1536), chunkingConfig: z .object({ maxSize: z.number().min(100).max(4000).default(1024), @@ -118,9 +117,13 @@ export const POST = withRouteHandler(async (req: NextRequest) => { try { const validatedData = CreateKnowledgeBaseSchema.parse(body) + const embeddingModel = getConfiguredEmbeddingModel() + const createData = { ...validatedData, userId: session.user.id, + embeddingModel, + embeddingDimension: EMBEDDING_DIMENSIONS, } const newKnowledgeBase = await createKnowledgeBase(createData, requestId) @@ -166,8 +169,8 @@ export const POST = withRouteHandler(async (req: NextRequest) => { metadata: { name: validatedData.name, description: validatedData.description, - embeddingModel: validatedData.embeddingModel, - embeddingDimension: validatedData.embeddingDimension, + embeddingModel, + embeddingDimension: EMBEDDING_DIMENSIONS, chunkingStrategy: validatedData.chunkingConfig.strategy, chunkMaxSize: validatedData.chunkingConfig.maxSize, chunkMinSize: validatedData.chunkingConfig.minSize, diff --git a/apps/sim/app/api/knowledge/search/route.test.ts b/apps/sim/app/api/knowledge/search/route.test.ts index 52c1fc47ccf..40aad7f0afe 100644 --- a/apps/sim/app/api/knowledge/search/route.test.ts +++ b/apps/sim/app/api/knowledge/search/route.test.ts @@ -432,6 +432,7 @@ describe('Knowledge Search API Route', () => { userId: 'user-123', name: 'Test KB', deletedAt: null, + embeddingModel: 'text-embedding-3-small', }, }) @@ -524,6 +525,7 @@ describe('Knowledge Search API Route', () => { userId: 'user-123', name: 'Test KB', deletedAt: null, + embeddingModel: 'text-embedding-3-small', }, }) @@ -571,6 +573,7 @@ describe('Knowledge Search API Route', () => { userId: 'user-123', name: 'Test KB', deletedAt: null, + embeddingModel: 
'text-embedding-3-small', }, }) @@ -625,6 +628,7 @@ describe('Knowledge Search API Route', () => { userId: 'user-123', name: 'Test KB', deletedAt: null, + embeddingModel: 'text-embedding-3-small', }, }) @@ -694,6 +698,7 @@ describe('Knowledge Search API Route', () => { userId: 'user-123', name: 'Test KB', deletedAt: null, + embeddingModel: 'text-embedding-3-small', }, }) @@ -739,6 +744,7 @@ describe('Knowledge Search API Route', () => { userId: 'user-123', name: 'Test KB', deletedAt: null, + embeddingModel: 'text-embedding-3-small', }, }) @@ -877,6 +883,7 @@ describe('Knowledge Search API Route', () => { userId: 'user-123', name: 'Test KB', deletedAt: null, + embeddingModel: 'text-embedding-3-small', }, }) @@ -921,11 +928,17 @@ describe('Knowledge Search API Route', () => { userId: 'user-123', name: 'Test KB', deletedAt: null, + embeddingModel: 'text-embedding-3-small', }, }) .mockResolvedValueOnce({ hasAccess: true, - knowledgeBase: { id: 'kb-456', userId: 'user-123', name: 'Test KB 2' }, + knowledgeBase: { + id: 'kb-456', + userId: 'user-123', + name: 'Test KB 2', + embeddingModel: 'text-embedding-3-small', + }, }) mockGetDocumentTagDefinitions.mockResolvedValue(mockTagDefinitions) diff --git a/apps/sim/app/api/knowledge/search/route.ts b/apps/sim/app/api/knowledge/search/route.ts index 6c9db51ccc2..d10a3757e0b 100644 --- a/apps/sim/app/api/knowledge/search/route.ts +++ b/apps/sim/app/api/knowledge/search/route.ts @@ -7,6 +7,8 @@ import { PlatformEvents } from '@/lib/core/telemetry' import { generateRequestId } from '@/lib/core/utils/request' import { withRouteHandler } from '@/lib/core/utils/with-route-handler' import { ALL_TAG_SLOTS } from '@/lib/knowledge/constants' +import { getEmbeddingModelInfo } from '@/lib/knowledge/embedding-models' +import { DEFAULT_RERANKER_MODEL, rerank, SUPPORTED_RERANKER_MODELS } from '@/lib/knowledge/reranker' import { getDocumentTagDefinitions } from '@/lib/knowledge/tags/service' import { buildUndefinedTagsError, validateTagValue } from '@/lib/knowledge/tags/utils' import type { StructuredFilter } from '@/lib/knowledge/types' @@ -20,7 +22,8 @@ import { handleVectorOnlySearch, type SearchResult, } from '@/app/api/knowledge/search/utils' -import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils' +import { checkKnowledgeBaseAccess, type KnowledgeBaseAccessResult } from '@/app/api/knowledge/utils' +import { getRerankModelPricing } from '@/providers/models' import { calculateCost } from '@/providers/utils' const logger = createLogger('VectorSearchAPI') @@ -59,6 +62,11 @@ const VectorSearchSchema = z .optional() .nullable() .transform((val) => val || undefined), + rerankerEnabled: z.boolean().optional().default(false), + rerankerModel: z + .enum(SUPPORTED_RERANKER_MODELS as unknown as [string, ...string[]]) + .optional() + .default(DEFAULT_RERANKER_MODEL), }) .refine( (data) => { @@ -235,12 +243,26 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } - const workspaceId = accessChecks.find((ac) => ac?.hasAccess)?.knowledgeBase?.workspaceId + const accessibleKbs = accessChecks + .filter((ac): ac is KnowledgeBaseAccessResult => Boolean(ac?.hasAccess)) + .map((ac) => ac.knowledgeBase) + const workspaceId = accessibleKbs[0]?.workspaceId + + const useReranker = validatedData.rerankerEnabled && Boolean(validatedData.query?.trim()) + const rerankerModel = useReranker ? validatedData.rerankerModel : null const hasQuery = validatedData.query && validatedData.query.trim().length > 0 - const queryEmbeddingPromise = hasQuery - ? 
generateSearchEmbedding(validatedData.query!, undefined, workspaceId) - : Promise.resolve(null) + const embeddingModels = Array.from(new Set(accessibleKbs.map((kb) => kb.embeddingModel))) + if (hasQuery && embeddingModels.length > 1) { + return NextResponse.json( + { + error: + 'Selected knowledge bases use different embedding models and cannot be searched together. Search them separately.', + }, + { status: 400 } + ) + } + const queryEmbeddingModel = embeddingModels[0] // Check if any requested knowledge bases were not accessible const inaccessibleKbIds = knowledgeBaseIds.filter((id) => !accessibleKbIds.includes(id)) @@ -252,6 +274,10 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + const queryEmbeddingPromise = hasQuery + ? generateSearchEmbedding(validatedData.query!, queryEmbeddingModel, workspaceId) + : Promise.resolve(null) + if (workflowId) { const authorization = await authorizeWorkflowByWorkspacePermission({ workflowId, @@ -278,6 +304,10 @@ export const POST = withRouteHandler(async (request: NextRequest) => { const hasFilters = structuredFilters && structuredFilters.length > 0 + // Oversample candidates when reranking so the reranker has more to choose from. + // Cap at 100 to bound Cohere request cost (1 search unit = ≤100 docs). + const candidateTopK = useReranker ? Math.min(100, validatedData.topK * 4) : validatedData.topK + if (!hasQuery && hasFilters) { // Tag-only search without vector similarity results = await handleTagOnlySearch({ @@ -291,24 +321,24 @@ export const POST = withRouteHandler(async (request: NextRequest) => { `[${requestId}] Executing tag + vector search with filters:`, structuredFilters ) - const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK) + const strategy = getQueryStrategy(accessibleKbIds.length, candidateTopK) const queryVector = JSON.stringify(await queryEmbeddingPromise) results = await handleTagAndVectorSearch({ knowledgeBaseIds: accessibleKbIds, - topK: validatedData.topK, + topK: candidateTopK, structuredFilters, queryVector, distanceThreshold: strategy.distanceThreshold, }) } else if (hasQuery && !hasFilters) { // Vector-only search - const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK) + const strategy = getQueryStrategy(accessibleKbIds.length, candidateTopK) const queryVector = JSON.stringify(await queryEmbeddingPromise) results = await handleVectorOnlySearch({ knowledgeBaseIds: accessibleKbIds, - topK: validatedData.topK, + topK: candidateTopK, queryVector, distanceThreshold: strategy.distanceThreshold, }) @@ -323,13 +353,60 @@ export const POST = withRouteHandler(async (request: NextRequest) => { ) } + // Optional Cohere rerank pass on top of vector results. + const rerankedScores = new Map() + // `rerankBilled` = Cohere was successfully called (even with 0 results) and we owe the search unit. 
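+  // Illustrative arithmetic for the candidateTopK cap above (topK values assumed, not from a real request):
+  //   topK = 10 → candidateTopK = min(100, 40)  = 40 docs sent to Cohere → 1 search unit
+  //   topK = 30 → candidateTopK = min(100, 120) = 100 docs — still 1 search unit, since a unit covers ≤100 docs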
+ let rerankBilled = false + let rerankIsBYOK = false + if (useReranker && rerankerModel && results.length > 0) { + const candidateCount = results.length + try { + const { results: ranked, isBYOK } = await rerank( + validatedData.query!, + results.map((r) => ({ id: r.id, text: r.content })), + { model: rerankerModel, topN: validatedData.topK, workspaceId } + ) + rerankBilled = true + rerankIsBYOK = isBYOK + if (ranked.length === 0) { + logger.warn( + `[${requestId}] Reranker returned 0 results; falling back to vector ordering`, + { model: rerankerModel, candidateCount } + ) + results = results.slice(0, validatedData.topK) + } else { + const idToResult = new Map(results.map((r) => [r.id, r])) + results = ranked + .map((r) => idToResult.get(r.item.id)) + .filter((r): r is SearchResult => Boolean(r)) + for (const r of ranked) rerankedScores.set(r.item.id, r.relevanceScore) + logger.info(`[${requestId}] Reranked ${candidateCount} → ${results.length} results`, { + model: rerankerModel, + }) + } + } catch (error) { + logger.warn(`[${requestId}] Reranker failed; falling back to vector ordering`, { + error: error instanceof Error ? error.message : 'Unknown error', + model: rerankerModel, + candidateCount, + workspaceId, + }) + results = results.slice(0, validatedData.topK) + } + } else if (useReranker) { + results = results.slice(0, validatedData.topK) + } + // Calculate cost for the embedding (with fallback if calculation fails) let cost = null let tokenCount = null if (hasQuery) { try { - tokenCount = estimateTokenCount(validatedData.query!, 'openai') - cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false) + tokenCount = estimateTokenCount( + validatedData.query!, + getEmbeddingModelInfo(queryEmbeddingModel).tokenizerProvider + ) + cost = calculateCost(queryEmbeddingModel, tokenCount.count, 0, false) } catch (error) { logger.warn(`[${requestId}] Failed to calculate cost for search query`, { error: error instanceof Error ? error.message : 'Unknown error', @@ -338,6 +415,32 @@ export const POST = withRouteHandler(async (request: NextRequest) => { } } + // Add Cohere rerank cost (1 search unit per successful call, since we cap candidates ≤100). + // Bill on every successful API response — Cohere charges even when 0 results are returned. + let rerankerCost = 0 + if (rerankBilled && rerankerModel && !rerankIsBYOK) { + const pricing = getRerankModelPricing(rerankerModel) + if (pricing) { + rerankerCost = pricing.perSearchUnit + if (cost) { + cost = { + ...cost, + input: cost.input + rerankerCost, + total: cost.total + rerankerCost, + } + } else { + cost = { + input: rerankerCost, + output: 0, + total: rerankerCost, + pricing: { input: 0, output: 0, updatedAt: pricing.updatedAt }, + } + } + } else { + logger.warn(`[${requestId}] No pricing entry for rerank model ${rerankerModel}`) + } + } + // Fetch tag definitions for display name mapping (reuse the same fetch from filtering) const tagDefsResults = await Promise.all( accessibleKbIds.map(async (kbId) => { @@ -400,6 +503,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { } }) + const rerankerScore = rerankedScores.get(result.id) return { documentId: result.documentId, documentName: documentNameMap[result.documentId] || undefined, @@ -407,6 +511,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { chunkIndex: result.chunkIndex, metadata: tags, // Clean display name mapped tags similarity: hasQuery ? 
1 - result.distance : 1, // Perfect similarity for tag-only searches + ...(rerankerScore !== undefined && { rerankerScore }), } }), query: validatedData.query || '', @@ -414,19 +519,22 @@ export const POST = withRouteHandler(async (request: NextRequest) => { knowledgeBaseId: accessibleKbIds[0], topK: validatedData.topK, totalResults: results.length, - ...(cost && tokenCount + ...(cost ? { cost: { input: cost.input, output: cost.output, total: cost.total, tokens: { - prompt: tokenCount.count, + prompt: tokenCount?.count ?? 0, completion: 0, - total: tokenCount.count, + total: tokenCount?.count ?? 0, }, - model: 'text-embedding-3-small', + model: queryEmbeddingModel, pricing: cost.pricing, + ...(rerankBilled && !rerankIsBYOK + ? { rerankerCost, rerankerModel, rerankerSearchUnits: 1 } + : {}), }, } : {}), diff --git a/apps/sim/app/api/knowledge/search/utils.test.ts b/apps/sim/app/api/knowledge/search/utils.test.ts index 9fd4fa34538..9ebdbe89b3c 100644 --- a/apps/sim/app/api/knowledge/search/utils.test.ts +++ b/apps/sim/app/api/knowledge/search/utils.test.ts @@ -220,7 +220,7 @@ describe('Knowledge Search Utils', () => { Object.keys(env).forEach((key) => delete (env as any)[key]) }) - it('should use default API version when not provided in Azure config', async () => { + it('falls back to OpenAI when AZURE_OPENAI_API_VERSION is not set', async () => { const { env } = await import('@/lib/core/config/env') Object.keys(env).forEach((key) => delete (env as any)[key]) Object.assign(env, { @@ -240,7 +240,7 @@ describe('Knowledge Search Utils', () => { await generateSearchEmbedding('test query') expect(vi.mocked(fetch)).toHaveBeenCalledWith( - expect.stringContaining('api-version='), + 'https://api.openai.com/v1/embeddings', expect.any(Object) ) @@ -282,7 +282,7 @@ describe('Knowledge Search Utils', () => { Object.keys(env).forEach((key) => delete (env as any)[key]) await expect(generateSearchEmbedding('test query')).rejects.toThrow( - 'Either OPENAI_API_KEY or Azure OpenAI configuration (AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT) must be configured' + 'OPENAI_API_KEY is not configured' ) }) @@ -354,6 +354,7 @@ describe('Knowledge Search Utils', () => { body: JSON.stringify({ input: ['test query'], encoding_format: 'float', + dimensions: 1536, }), }) ) diff --git a/apps/sim/app/api/knowledge/utils.test.ts b/apps/sim/app/api/knowledge/utils.test.ts index 650c7b1dc6b..313d4e24f0a 100644 --- a/apps/sim/app/api/knowledge/utils.test.ts +++ b/apps/sim/app/api/knowledge/utils.test.ts @@ -212,6 +212,7 @@ describe('Knowledge Utils', () => { id: 'kb1', userId: 'user1', workspaceId: null, + embeddingModel: 'text-embedding-3-small', chunkingConfig: { maxSize: 1024, minSize: 1, overlap: 200 }, }) docRows.push({ id: 'doc1', knowledgeBaseId: 'kb1' }) @@ -370,7 +371,7 @@ describe('Knowledge Utils', () => { Object.keys(env).forEach((key) => delete (env as any)[key]) await expect(generateEmbeddings(['test text'])).rejects.toThrow( - 'Either OPENAI_API_KEY or Azure OpenAI configuration (AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT) must be configured' + 'OPENAI_API_KEY is not configured' ) }) }) diff --git a/apps/sim/app/api/knowledge/utils.ts b/apps/sim/app/api/knowledge/utils.ts index 60042ccccf1..bdb7066ab45 100644 --- a/apps/sim/app/api/knowledge/utils.ts +++ b/apps/sim/app/api/knowledge/utils.ts @@ -103,7 +103,10 @@ export interface EmbeddingData { export interface KnowledgeBaseAccessResult { hasAccess: true - knowledgeBase: Pick + knowledgeBase: Pick< + KnowledgeBaseData, + 'id' | 'userId' | 'workspaceId' | 
'name' | 'embeddingModel'
+  >
 }

 export interface KnowledgeBaseAccessDenied {
@@ -117,7 +120,10 @@ export type KnowledgeBaseAccessCheck = KnowledgeBaseAccessResult | KnowledgeBase

 export interface DocumentAccessResult {
   hasAccess: true
   document: DocumentData
-  knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId' | 'workspaceId' | 'name'>
+  knowledgeBase: Pick<
+    KnowledgeBaseData,
+    'id' | 'userId' | 'workspaceId' | 'name' | 'embeddingModel'
+  >
 }

 export interface DocumentAccessDenied {
@@ -132,7 +138,10 @@ export interface ChunkAccessResult {
   hasAccess: true
   chunk: EmbeddingData
   document: DocumentData
-  knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId' | 'workspaceId' | 'name'>
+  knowledgeBase: Pick<
+    KnowledgeBaseData,
+    'id' | 'userId' | 'workspaceId' | 'name' | 'embeddingModel'
+  >
 }

 export interface ChunkAccessDenied {
@@ -156,6 +165,7 @@ export async function checkKnowledgeBaseAccess(
       userId: knowledgeBase.userId,
       workspaceId: knowledgeBase.workspaceId,
       name: knowledgeBase.name,
+      embeddingModel: knowledgeBase.embeddingModel,
     })
     .from(knowledgeBase)
     .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt)))
@@ -200,6 +210,7 @@ export async function checkKnowledgeBaseWriteAccess(
       userId: knowledgeBase.userId,
       workspaceId: knowledgeBase.workspaceId,
       name: knowledgeBase.name,
+      embeddingModel: knowledgeBase.embeddingModel,
     })
     .from(knowledgeBase)
     .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt)))
diff --git a/apps/sim/app/api/v1/knowledge/route.ts b/apps/sim/app/api/v1/knowledge/route.ts
index a24ce394966..3a85651b9d7 100644
--- a/apps/sim/app/api/v1/knowledge/route.ts
+++ b/apps/sim/app/api/v1/knowledge/route.ts
@@ -2,6 +2,7 @@ import { AuditAction, AuditResourceType, recordAudit } from '@sim/audit'
 import { type NextRequest, NextResponse } from 'next/server'
 import { z } from 'zod'
 import { withRouteHandler } from '@/lib/core/utils/with-route-handler'
+import { EMBEDDING_DIMENSIONS, getConfiguredEmbeddingModel } from '@/lib/knowledge/embeddings'
 import { createKnowledgeBase, getKnowledgeBases } from '@/lib/knowledge/service'
 import {
   authenticateRequest,
@@ -92,8 +93,8 @@ export const POST = withRouteHandler(async (request: NextRequest) => {
         description,
         workspaceId,
         userId,
-        embeddingModel: 'text-embedding-3-small',
-        embeddingDimension: 1536,
+        embeddingModel: getConfiguredEmbeddingModel(),
+        embeddingDimension: EMBEDDING_DIMENSIONS,
         chunkingConfig: chunkingConfig ?? { maxSize: 1024, minSize: 100, overlap: 200 },
       },
       requestId
diff --git a/apps/sim/app/api/v1/knowledge/search/route.test.ts b/apps/sim/app/api/v1/knowledge/search/route.test.ts
new file mode 100644
index 00000000000..517a30e5bd2
--- /dev/null
+++ b/apps/sim/app/api/v1/knowledge/search/route.test.ts
@@ -0,0 +1,179 @@
+/**
+ * Tests for v1 knowledge search API route.
+ * Specifically guards the per-KB embedding model resolution and the
+ * multi-model rejection so the v1 endpoint stays in lockstep with the
+ * internal route.
+ * + * @vitest-environment node + */ +import { createMockRequest, knowledgeApiUtilsMock, knowledgeApiUtilsMockFns } from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockHandleVectorOnlySearch, + mockHandleTagOnlySearch, + mockHandleTagAndVectorSearch, + mockGetQueryStrategy, + mockGenerateSearchEmbedding, + mockGetDocumentNamesByIds, + mockAuthenticateRequest, + mockValidateWorkspaceAccess, +} = vi.hoisted(() => ({ + mockHandleVectorOnlySearch: vi.fn(), + mockHandleTagOnlySearch: vi.fn(), + mockHandleTagAndVectorSearch: vi.fn(), + mockGetQueryStrategy: vi.fn(), + mockGenerateSearchEmbedding: vi.fn(), + mockGetDocumentNamesByIds: vi.fn(), + mockAuthenticateRequest: vi.fn(), + mockValidateWorkspaceAccess: vi.fn(), +})) + +vi.mock('@/app/api/knowledge/search/utils', () => ({ + handleVectorOnlySearch: mockHandleVectorOnlySearch, + handleTagOnlySearch: mockHandleTagOnlySearch, + handleTagAndVectorSearch: mockHandleTagAndVectorSearch, + getQueryStrategy: mockGetQueryStrategy, + generateSearchEmbedding: mockGenerateSearchEmbedding, + getDocumentNamesByIds: mockGetDocumentNamesByIds, +})) + +vi.mock('@/app/api/knowledge/utils', () => knowledgeApiUtilsMock) + +vi.mock('@/app/api/v1/knowledge/utils', () => ({ + authenticateRequest: mockAuthenticateRequest, + validateWorkspaceAccess: mockValidateWorkspaceAccess, + parseJsonBody: async (req: Request) => { + try { + return { success: true, data: await req.json() } + } catch { + return { + success: false, + response: new Response(JSON.stringify({ error: 'Invalid JSON' }), { status: 400 }), + } + } + }, + validateSchema: ( + schema: { + safeParse: (v: unknown) => { + success: boolean + data?: T + error?: { issues: { message: string }[] } + } + }, + data: unknown + ) => { + const result = schema.safeParse(data) + if (!result.success) { + return { + success: false, + response: new Response( + JSON.stringify({ error: result.error?.issues.map((i) => i.message).join(', ') }), + { status: 400 } + ), + } + } + return { success: true, data: result.data } + }, + handleError: (e: unknown) => + new Response(JSON.stringify({ error: e instanceof Error ? 
e.message : 'error' }), { + status: 500, + }), +})) + +vi.mock('@/lib/knowledge/tags/service', () => ({ + getDocumentTagDefinitions: vi.fn().mockResolvedValue([]), +})) + +import { POST } from '@/app/api/v1/knowledge/search/route' + +const mockCheckKnowledgeBaseAccess = knowledgeApiUtilsMockFns.mockCheckKnowledgeBaseAccess + +const baseKb = (id: string, embeddingModel: string) => ({ + id, + userId: 'user-1', + name: `KB ${id}`, + workspaceId: 'ws-1', + embeddingModel, + deletedAt: null, +}) + +describe('v1 knowledge search route — per-KB embedding model', () => { + beforeEach(() => { + vi.clearAllMocks() + mockAuthenticateRequest.mockResolvedValue({ + requestId: 'req-1', + userId: 'user-1', + rateLimit: {}, + }) + mockValidateWorkspaceAccess.mockResolvedValue(null) + mockGetQueryStrategy.mockReturnValue({ distanceThreshold: 0.5 }) + mockGenerateSearchEmbedding.mockResolvedValue([0.1, 0.2, 0.3]) + mockHandleVectorOnlySearch.mockResolvedValue([]) + mockGetDocumentNamesByIds.mockResolvedValue({}) + }) + + it('passes the KB embedding model into generateSearchEmbedding', async () => { + mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({ + hasAccess: true, + knowledgeBase: baseKb('kb-gemini', 'gemini-embedding-001'), + }) + + const req = createMockRequest('POST', { + workspaceId: 'ws-1', + knowledgeBaseIds: 'kb-gemini', + query: 'hello', + }) + const res = await POST(req) + + expect(res.status).toBe(200) + expect(mockGenerateSearchEmbedding).toHaveBeenCalledWith( + 'hello', + 'gemini-embedding-001', + 'ws-1' + ) + }) + + it('rejects cross-KB queries with mixed embedding models', async () => { + mockCheckKnowledgeBaseAccess + .mockResolvedValueOnce({ + hasAccess: true, + knowledgeBase: baseKb('kb-openai', 'text-embedding-3-small'), + }) + .mockResolvedValueOnce({ + hasAccess: true, + knowledgeBase: baseKb('kb-gemini', 'gemini-embedding-001'), + }) + + const req = createMockRequest('POST', { + workspaceId: 'ws-1', + knowledgeBaseIds: ['kb-openai', 'kb-gemini'], + query: 'hello', + }) + const res = await POST(req) + + expect(res.status).toBe(400) + expect(mockGenerateSearchEmbedding).not.toHaveBeenCalled() + }) + + it('allows tag-only search across mixed embedding models', async () => { + mockHandleTagOnlySearch.mockResolvedValue([]) + mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({ + hasAccess: true, + knowledgeBase: baseKb('kb-mixed', 'text-embedding-3-small'), + }) + + const req = createMockRequest('POST', { + workspaceId: 'ws-1', + knowledgeBaseIds: 'kb-mixed', + tagFilters: [{ tagName: 'category', operator: 'eq', value: 'docs' }], + }) + const res = await POST(req) + + expect(res.status).toBe(400) + // tagName "category" is undefined in our empty getDocumentTagDefinitions mock, + // so the route returns 400 before reaching the search handlers — but crucially + // it never tries to generate an embedding. 
+ expect(mockGenerateSearchEmbedding).not.toHaveBeenCalled() + }) +}) diff --git a/apps/sim/app/api/v1/knowledge/search/route.ts b/apps/sim/app/api/v1/knowledge/search/route.ts index 4a622ff0bcf..fc671c482af 100644 --- a/apps/sim/app/api/v1/knowledge/search/route.ts +++ b/apps/sim/app/api/v1/knowledge/search/route.ts @@ -14,7 +14,7 @@ import { handleVectorOnlySearch, type SearchResult, } from '@/app/api/knowledge/search/utils' -import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils' +import { checkKnowledgeBaseAccess, type KnowledgeBaseAccessResult } from '@/app/api/knowledge/utils' import { authenticateRequest, handleError, @@ -84,11 +84,13 @@ export const POST = withRouteHandler(async (request: NextRequest) => { const accessChecks = await Promise.all( knowledgeBaseIds.map((kbId) => checkKnowledgeBaseAccess(kbId, userId)) ) - const accessibleKbIds = knowledgeBaseIds.filter( - (_, idx) => - accessChecks[idx]?.hasAccess && - accessChecks[idx]?.knowledgeBase?.workspaceId === workspaceId - ) + const accessibleKbs = accessChecks + .filter( + (ac): ac is KnowledgeBaseAccessResult => + ac.hasAccess === true && ac.knowledgeBase.workspaceId === workspaceId + ) + .map((ac) => ac.knowledgeBase) + const accessibleKbIds = accessibleKbs.map((kb) => kb.id) if (accessibleKbIds.length === 0) { return NextResponse.json( @@ -173,6 +175,18 @@ export const POST = withRouteHandler(async (request: NextRequest) => { const hasQuery = query && query.trim().length > 0 const hasFilters = structuredFilters.length > 0 + const embeddingModels = Array.from(new Set(accessibleKbs.map((kb) => kb.embeddingModel))) + if (hasQuery && embeddingModels.length > 1) { + return NextResponse.json( + { + error: + 'Selected knowledge bases use different embedding models and cannot be searched together. 
Search them separately.', + }, + { status: 400 } + ) + } + const queryEmbeddingModel = embeddingModels[0] + let results: SearchResult[] if (!hasQuery && hasFilters) { @@ -184,7 +198,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { } else if (hasQuery && hasFilters) { const strategy = getQueryStrategy(accessibleKbIds.length, topK) const queryVector = JSON.stringify( - await generateSearchEmbedding(query!, undefined, workspaceId) + await generateSearchEmbedding(query!, queryEmbeddingModel, workspaceId) ) results = await handleTagAndVectorSearch({ knowledgeBaseIds: accessibleKbIds, @@ -196,7 +210,7 @@ export const POST = withRouteHandler(async (request: NextRequest) => { } else if (hasQuery) { const strategy = getQueryStrategy(accessibleKbIds.length, topK) const queryVector = JSON.stringify( - await generateSearchEmbedding(query!, undefined, workspaceId) + await generateSearchEmbedding(query!, queryEmbeddingModel, workspaceId) ) results = await handleVectorOnlySearch({ knowledgeBaseIds: accessibleKbIds, diff --git a/apps/sim/blocks/blocks/knowledge.ts b/apps/sim/blocks/blocks/knowledge.ts index 34b449373a1..37c0d5b2914 100644 --- a/apps/sim/blocks/blocks/knowledge.ts +++ b/apps/sim/blocks/blocks/knowledge.ts @@ -1,4 +1,5 @@ import { PackageSearchIcon } from '@/components/icons' +import { DEFAULT_RERANKER_MODEL, SUPPORTED_RERANKER_MODELS } from '@/lib/knowledge/reranker-models' import type { BlockConfig } from '@/blocks/types' export const KnowledgeBlock: BlockConfig = { @@ -86,6 +87,24 @@ export const KnowledgeBlock: BlockConfig = { dependsOn: ['knowledgeBaseSelector'], condition: { field: 'operation', value: 'search' }, }, + { + id: 'rerankerEnabled', + title: 'Rerank Results', + type: 'switch', + condition: { field: 'operation', value: 'search' }, + }, + { + id: 'rerankerModel', + title: 'Rerank Model', + type: 'dropdown', + options: SUPPORTED_RERANKER_MODELS.map((id) => ({ label: id, id })), + value: () => DEFAULT_RERANKER_MODEL, + condition: { + field: 'operation', + value: 'search', + and: { field: 'rerankerEnabled', value: true }, + }, + }, // --- List Documents --- { @@ -397,6 +416,8 @@ export const KnowledgeBlock: BlockConfig = { limit: { type: 'number', description: 'Max items to return' }, offset: { type: 'number', description: 'Pagination offset' }, tagFilters: { type: 'string', description: 'Tag filter criteria' }, + rerankerEnabled: { type: 'boolean', description: 'Apply Cohere reranking to search results' }, + rerankerModel: { type: 'string', description: 'Cohere rerank model identifier' }, documentTags: { type: 'string', description: 'Document tags' }, chunkSearch: { type: 'string', description: 'Search filter for chunks' }, chunkEnabledFilter: { type: 'string', description: 'Filter chunks by enabled status' }, diff --git a/apps/sim/lib/chunkers/docs-chunker.ts b/apps/sim/lib/chunkers/docs-chunker.ts index ddfecc3ab19..26b1fe449a8 100644 --- a/apps/sim/lib/chunkers/docs-chunker.ts +++ b/apps/sim/lib/chunkers/docs-chunker.ts @@ -4,7 +4,7 @@ import { createLogger } from '@sim/logger' import { TextChunker } from '@/lib/chunkers/text-chunker' import type { DocChunk, DocsChunkerOptions } from '@/lib/chunkers/types' import { estimateTokens } from '@/lib/chunkers/utils' -import { generateEmbeddings } from '@/lib/knowledge/embeddings' +import { generateEmbeddings, getConfiguredEmbeddingModel } from '@/lib/knowledge/embeddings' interface HeaderInfo { level: number @@ -74,9 +74,9 @@ export class DocsChunker { const headers = this.extractHeaders(cleanedContent) 
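+    // How getConfiguredEmbeddingModel() resolves, per lib/knowledge/embeddings.ts
+    // later in this diff (env values here are assumed examples):
+    //   KB_EMBEDDING_MODEL=gemini-embedding-001 → 'gemini-embedding-001'
+    //   KB_EMBEDDING_MODEL=some-unknown-model   → warn, then 'text-embedding-3-small'
+    //   unset                                   → 'text-embedding-3-small' (DEFAULT_EMBEDDING_MODEL)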
logger.info(`Generating embeddings for ${textChunks.length} chunks in ${relativePath}`) + const embeddingModel = getConfiguredEmbeddingModel() const embeddings: number[][] = - textChunks.length > 0 ? (await generateEmbeddings(textChunks)).embeddings : [] - const embeddingModel = 'text-embedding-3-small' + textChunks.length > 0 ? (await generateEmbeddings(textChunks, embeddingModel)).embeddings : [] const chunks: DocChunk[] = [] let currentPosition = 0 diff --git a/apps/sim/lib/copilot/tools/server/knowledge/knowledge-base.ts b/apps/sim/lib/copilot/tools/server/knowledge/knowledge-base.ts index 46d6062eb93..60d093c6a65 100644 --- a/apps/sim/lib/copilot/tools/server/knowledge/knowledge-base.ts +++ b/apps/sim/lib/copilot/tools/server/knowledge/knowledge-base.ts @@ -18,7 +18,11 @@ import { processDocumentAsync, updateDocument, } from '@/lib/knowledge/documents/service' -import { generateSearchEmbedding } from '@/lib/knowledge/embeddings' +import { + EMBEDDING_DIMENSIONS, + generateSearchEmbedding, + getConfiguredEmbeddingModel, +} from '@/lib/knowledge/embeddings' import { createKnowledgeBase, deleteKnowledgeBase, @@ -107,8 +111,8 @@ export const knowledgeBaseServerTool: BaseServerTool { logger.info(`[${requestId}] Generating embedding for manual chunk`) - const { embeddings } = await generateEmbeddings([chunkData.content], undefined, workspaceId) + const kbRow = await db + .select({ embeddingModel: knowledgeBase.embeddingModel }) + .from(knowledgeBase) + .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) + .limit(1) + if (kbRow.length === 0) { + throw new Error('Knowledge base not found') + } + const kbEmbeddingModel = kbRow[0].embeddingModel + const { embeddings } = await generateEmbeddings( + [chunkData.content], + kbEmbeddingModel, + workspaceId + ) - // Calculate accurate token count - const tokenCount = estimateTokenCount(chunkData.content, 'openai') + const tokenCount = estimateTokenCount( + chunkData.content, + getEmbeddingModelInfo(kbEmbeddingModel).tokenizerProvider + ) const chunkId = generateId() const now = new Date() @@ -160,7 +176,7 @@ export async function createChunk( contentLength: chunkData.content.length, tokenCount: tokenCount.count, embedding: embeddings[0], - embeddingModel: 'text-embedding-3-small', + embeddingModel: kbEmbeddingModel, startOffset: 0, // Manual chunks don't have document offsets endOffset: chunkData.content.length, // Inherit text tags from parent document @@ -360,10 +376,22 @@ export async function updateChunk( if (content !== currentChunk[0].content) { logger.info(`[${requestId}] Content changed, regenerating embedding for chunk ${chunkId}`) - const { embeddings } = await generateEmbeddings([content], undefined, workspaceId) - - // Calculate accurate token count - const tokenCount = estimateTokenCount(content, 'openai') + const kbRow = await tx + .select({ embeddingModel: knowledgeBase.embeddingModel }) + .from(knowledgeBase) + .innerJoin(document, eq(document.knowledgeBaseId, knowledgeBase.id)) + .where(eq(document.id, currentChunk[0].documentId)) + .limit(1) + const chunkEmbeddingModel = kbRow[0]?.embeddingModel + if (!chunkEmbeddingModel) { + throw new Error('Knowledge base for chunk not found') + } + const { embeddings } = await generateEmbeddings([content], chunkEmbeddingModel, workspaceId) + + const tokenCount = estimateTokenCount( + content, + getEmbeddingModelInfo(chunkEmbeddingModel).tokenizerProvider + ) dbUpdateData.content = content dbUpdateData.contentLength = newContentLength diff --git 
a/apps/sim/lib/knowledge/documents/service.ts b/apps/sim/lib/knowledge/documents/service.ts index 774f30d9867..3e8a0b26cce 100644 --- a/apps/sim/lib/knowledge/documents/service.ts +++ b/apps/sim/lib/knowledge/documents/service.ts @@ -34,6 +34,7 @@ import { env } from '@/lib/core/config/env' import { getCostMultiplier, isTriggerDevEnabled } from '@/lib/core/config/feature-flags' import { processDocument } from '@/lib/knowledge/documents/document-processor' import type { DocumentSortField, SortOrder } from '@/lib/knowledge/documents/types' +import { getEmbeddingModelInfo } from '@/lib/knowledge/embedding-models' import { generateEmbeddings } from '@/lib/knowledge/embeddings' import { buildUndefinedTagsError, @@ -43,6 +44,7 @@ import { validateTagValue, } from '@/lib/knowledge/tags/utils' import type { ProcessedDocumentTags } from '@/lib/knowledge/types' +import { estimateTokenCount } from '@/lib/tokenization/estimators' import { deleteFile } from '@/lib/uploads/core/storage-service' import { extractStorageKey } from '@/lib/uploads/utils/file-utils' import type { DocumentProcessingPayload } from '@/background/knowledge-processing' @@ -380,6 +382,7 @@ export async function processDocumentAsync( userId: knowledgeBase.userId, workspaceId: knowledgeBase.workspaceId, chunkingConfig: knowledgeBase.chunkingConfig, + embeddingModel: knowledgeBase.embeddingModel, }) .from(knowledgeBase) .where(and(eq(knowledgeBase.id, knowledgeBaseId), isNull(knowledgeBase.deletedAt))) @@ -429,9 +432,11 @@ export async function processDocumentAsync( overlap: rawConfig?.overlap ?? 200, } + const kbEmbeddingModel = kb[0].embeddingModel let totalEmbeddingTokens = 0 let embeddingIsBYOK = false - let embeddingModelName = 'text-embedding-3-small' + let embeddingModelName = kbEmbeddingModel + let embeddingPricingId = kbEmbeddingModel await withTimeout( (async () => { @@ -480,7 +485,8 @@ export async function processDocumentAsync( totalTokens: batchTokens, isBYOK, modelName, - } = await generateEmbeddings(batch, undefined, kb[0].workspaceId) + pricingId, + } = await generateEmbeddings(batch, kbEmbeddingModel, kb[0].workspaceId) for (const emb of batchEmbeddings) { embeddings.push(emb) } @@ -488,6 +494,7 @@ export async function processDocumentAsync( if (i === 0) { embeddingIsBYOK = isBYOK embeddingModelName = modelName + embeddingPricingId = pricingId } } } @@ -528,6 +535,8 @@ export async function processDocumentAsync( logger.info(`[${documentId}] Creating embedding records with tags`) + const tokenizerProvider = getEmbeddingModelInfo(kbEmbeddingModel).tokenizerProvider + const embeddingRecords = processed.chunks.map((chunk, chunkIndex) => ({ id: generateId(), knowledgeBaseId, @@ -536,9 +545,9 @@ export async function processDocumentAsync( chunkHash: sha256Hex(chunk.text), content: chunk.text, contentLength: chunk.text.length, - tokenCount: Math.ceil(chunk.text.length / 4), + tokenCount: estimateTokenCount(chunk.text, tokenizerProvider).count, embedding: embeddings[chunkIndex] || null, - embeddingModel: 'text-embedding-3-small', + embeddingModel: kbEmbeddingModel, startOffset: chunk.metadata.startIndex, endOffset: chunk.metadata.endIndex, tag1: documentTags.tag1, @@ -620,7 +629,7 @@ export async function processDocumentAsync( try { const costMultiplier = getCostMultiplier() const { total: cost } = calculateCost( - embeddingModelName, + embeddingPricingId, totalEmbeddingTokens, 0, false, diff --git a/apps/sim/lib/knowledge/embedding-models.ts b/apps/sim/lib/knowledge/embedding-models.ts new file mode 100644 index 
00000000000..5d837a1fbb8
--- /dev/null
+++ b/apps/sim/lib/knowledge/embedding-models.ts
@@ -0,0 +1,48 @@
+/**
+ * Registry of embedding models supported by the platform.
+ * Selection happens server-side via the `KB_EMBEDDING_MODEL` env var; this
+ * registry exists to resolve provider, tokenizer, and pricing metadata at
+ * runtime for any model recorded on a knowledge base row.
+ */
+
+export const EMBEDDING_DIMENSIONS = 1536 as const
+
+export const DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small'
+
+export type EmbeddingProviderKind = 'openai' | 'azure-openai' | 'gemini'
+
+export type TokenizerProviderId = 'openai' | 'google'
+
+export interface EmbeddingModelInfo {
+  provider: EmbeddingProviderKind
+  /** Pricing/billing label — must match an entry in EMBEDDING_MODEL_PRICING when billed. */
+  pricingId: string
+  /** Provider id for `estimateTokenCount` so token counts match the embedding provider's tokenization. */
+  tokenizerProvider: TokenizerProviderId
+}
+
+export const SUPPORTED_EMBEDDING_MODELS: Partial<Record<string, EmbeddingModelInfo>> = {
+  'text-embedding-3-small': {
+    provider: 'openai',
+    pricingId: 'text-embedding-3-small',
+    tokenizerProvider: 'openai',
+  },
+  'text-embedding-3-large': {
+    provider: 'openai',
+    pricingId: 'text-embedding-3-large',
+    tokenizerProvider: 'openai',
+  },
+  'gemini-embedding-001': {
+    provider: 'gemini',
+    pricingId: 'gemini-embedding-001',
+    tokenizerProvider: 'google',
+  },
+}
+
+export function getEmbeddingModelInfo(model: string): EmbeddingModelInfo {
+  const info = SUPPORTED_EMBEDDING_MODELS[model]
+  if (!info) {
+    throw new Error(`Unsupported embedding model: ${model}`)
+  }
+  return info
+}
diff --git a/apps/sim/lib/knowledge/embeddings.ts b/apps/sim/lib/knowledge/embeddings.ts
index bff67ab41c0..1791cc08499 100644
--- a/apps/sim/lib/knowledge/embeddings.ts
+++ b/apps/sim/lib/knowledge/embeddings.ts
@@ -1,24 +1,30 @@
 import { createLogger } from '@sim/logger'
 import { getBYOKKey } from '@/lib/api-key/byok'
+import { getRotatingApiKey } from '@/lib/core/config/api-keys'
 import { env } from '@/lib/core/config/env'
 import { isRetryableError, retryWithExponentialBackoff } from '@/lib/knowledge/documents/utils'
-import { batchByTokenLimit } from '@/lib/tokenization'
+import {
+  DEFAULT_EMBEDDING_MODEL,
+  EMBEDDING_DIMENSIONS,
+  getEmbeddingModelInfo,
+  SUPPORTED_EMBEDDING_MODELS,
+  type TokenizerProviderId,
+} from '@/lib/knowledge/embedding-models'
+import { batchByTokenLimit, estimateTokenCount } from '@/lib/tokenization'

 const logger = createLogger('EmbeddingUtils')

 const MAX_TOKENS_PER_REQUEST = 8000
 const MAX_CONCURRENT_BATCHES = env.KB_CONFIG_CONCURRENCY_LIMIT || 50
-const EMBEDDING_DIMENSIONS = 1536
+const EMBEDDING_REQUEST_TIMEOUT_MS = 60_000

-/**
- * Check if the model supports custom dimensions.
- * text-embedding-3-* models support the dimensions parameter.
- * Checks for 'embedding-3' to handle Azure deployments with custom naming conventions.
- */ -function supportsCustomDimensions(modelName: string): boolean { - const name = modelName.toLowerCase() - return name.includes('embedding-3') && !name.includes('ada') -} +export type { EmbeddingModelInfo } from '@/lib/knowledge/embedding-models' +export { + DEFAULT_EMBEDDING_MODEL, + EMBEDDING_DIMENSIONS, + getEmbeddingModelInfo, + SUPPORTED_EMBEDDING_MODELS, +} from '@/lib/knowledge/embedding-models' export class EmbeddingAPIError extends Error { public status: number @@ -30,112 +36,245 @@ export class EmbeddingAPIError extends Error { } } -interface EmbeddingConfig { - useAzure: boolean +export type EmbeddingInputType = 'document' | 'query' + +interface ProviderRequest { apiUrl: string headers: Record + body: unknown + parse: (json: unknown) => number[][] +} + +interface ResolvedProvider { modelName: string + pricingId: string isBYOK: boolean + /** Tokenizer used to estimate tokens when the API does not return a usage field. */ + tokenizerProvider: TokenizerProviderId + /** Hard per-request item cap enforced by the provider (e.g. Gemini caps at 100). */ + maxItemsPerRequest?: number + buildRequest: (inputs: string[], inputType: EmbeddingInputType) => ProviderRequest } -interface EmbeddingResponseItem { - embedding: number[] - index: number +/** Gemini's `batchEmbedContents` rejects requests with more than 100 items. */ +const GEMINI_MAX_ITEMS_PER_REQUEST = 100 + +async function resolveOpenAIKey(workspaceId?: string | null): Promise<{ + apiKey: string + isBYOK: boolean +}> { + if (workspaceId) { + const byokResult = await getBYOKKey(workspaceId, 'openai') + if (byokResult) { + logger.info('Using workspace BYOK key for OpenAI embeddings') + return { apiKey: byokResult.apiKey, isBYOK: true } + } + } + if (env.OPENAI_API_KEY) { + return { apiKey: env.OPENAI_API_KEY, isBYOK: false } + } + try { + return { apiKey: getRotatingApiKey('openai'), isBYOK: false } + } catch { + throw new Error('OPENAI_API_KEY is not configured') + } } -interface EmbeddingAPIResponse { - data: EmbeddingResponseItem[] - model: string - usage: { - prompt_tokens: number - total_tokens: number +async function resolveGeminiKey(workspaceId?: string | null): Promise<{ + apiKey: string + isBYOK: boolean +}> { + if (workspaceId) { + const byokResult = await getBYOKKey(workspaceId, 'google') + if (byokResult) { + logger.info('Using workspace BYOK key for Gemini embeddings') + return { apiKey: byokResult.apiKey, isBYOK: true } + } } + if (env.GEMINI_API_KEY) { + return { apiKey: env.GEMINI_API_KEY, isBYOK: false } + } + try { + return { apiKey: getRotatingApiKey('gemini'), isBYOK: false } + } catch { + throw new Error( + 'GEMINI_API_KEY (or GEMINI_API_KEY_1/2/3 for rotation) must be configured for Gemini embeddings' + ) + } +} + +function buildOpenAIProvider(modelName: string, apiKey: string): ResolvedProvider['buildRequest'] { + return (inputs) => ({ + apiUrl: 'https://api.openai.com/v1/embeddings', + headers: { + Authorization: `Bearer ${apiKey}`, + 'Content-Type': 'application/json', + }, + body: { + input: inputs, + model: modelName, + encoding_format: 'float', + dimensions: EMBEDDING_DIMENSIONS, + }, + parse: (json) => { + const data = json as { data: Array<{ embedding: number[] }> } + return data.data.map((item) => item.embedding) + }, + }) } -async function getEmbeddingConfig( - embeddingModel = 'text-embedding-3-small', +function buildAzureOpenAIProvider( + deployment: string, + apiKey: string, + endpoint: string, + apiVersion: string +): ResolvedProvider['buildRequest'] { + return (inputs) => ({ + apiUrl: 
`${endpoint}/openai/deployments/${deployment}/embeddings?api-version=${apiVersion}`, + headers: { + 'api-key': apiKey, + 'Content-Type': 'application/json', + }, + body: { + input: inputs, + encoding_format: 'float', + dimensions: EMBEDDING_DIMENSIONS, + }, + parse: (json) => { + const data = json as { data: Array<{ embedding: number[] }> } + return data.data.map((item) => item.embedding) + }, + }) +} + +/** + * Gemini does NOT auto-normalize embeddings when `outputDimensionality` is set below the + * native 3072 dimension on `gemini-embedding-001`. Manually L2-normalize so cosine and + * inner-product similarity work correctly. + */ +function l2Normalize(vector: number[]): number[] { + let sumSquares = 0 + for (const v of vector) sumSquares += v * v + const norm = Math.sqrt(sumSquares) + if (norm === 0) return vector + return vector.map((v) => v / norm) +} + +function buildGeminiProvider(modelName: string, apiKey: string): ResolvedProvider['buildRequest'] { + return (inputs, inputType) => ({ + apiUrl: `https://generativelanguage.googleapis.com/v1beta/models/${modelName}:batchEmbedContents`, + headers: { + 'Content-Type': 'application/json', + 'x-goog-api-key': apiKey, + }, + body: { + requests: inputs.map((text) => ({ + model: `models/${modelName}`, + content: { parts: [{ text }] }, + taskType: inputType === 'query' ? 'RETRIEVAL_QUERY' : 'RETRIEVAL_DOCUMENT', + outputDimensionality: EMBEDDING_DIMENSIONS, + })), + }, + parse: (json) => { + const data = json as { embeddings: Array<{ values: number[] }> } + return data.embeddings.map((item) => l2Normalize(item.values)) + }, + }) +} + +/** + * Returns the embedding model to use for new knowledge bases. + * Sourced from the `KB_EMBEDDING_MODEL` env var; falls back to the default if + * unset or set to an unsupported model. + */ +export function getConfiguredEmbeddingModel(): string { + const configured = env.KB_EMBEDDING_MODEL + if (configured && SUPPORTED_EMBEDDING_MODELS[configured]) { + return configured + } + if (configured) { + logger.warn( + `KB_EMBEDDING_MODEL="${configured}" is not a supported embedding model — falling back to ${DEFAULT_EMBEDDING_MODEL}` + ) + } + return DEFAULT_EMBEDDING_MODEL +} + +async function resolveProvider( + embeddingModel: string, workspaceId?: string | null -): Promise { +): Promise { const azureApiKey = env.AZURE_OPENAI_API_KEY const azureEndpoint = env.AZURE_OPENAI_ENDPOINT const azureApiVersion = env.AZURE_OPENAI_API_VERSION - const kbModelName = env.KB_OPENAI_MODEL_NAME || embeddingModel + const isOpenAIModel = SUPPORTED_EMBEDDING_MODELS[embeddingModel]?.provider === 'openai' + /** + * Azure deployment names default to the embedding model name when + * `KB_OPENAI_MODEL_NAME` is unset — this matches the pre-existing + * convention where deployments are named after the model they host. 
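+   * For illustration (deployment names assumed, not from this diff): a deployment
+   * literally named `text-embedding-3-small` needs no extra config, while one
+   * named `kb-embeddings` requires KB_OPENAI_MODEL_NAME=kb-embeddings.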
+ */ + const azureDeploymentName = env.KB_OPENAI_MODEL_NAME || embeddingModel + const useAzure = Boolean(isOpenAIModel && azureApiKey && azureEndpoint && azureApiVersion) - const useAzure = !!(azureApiKey && azureEndpoint) + const info = getEmbeddingModelInfo(embeddingModel) if (useAzure) { return { - useAzure: true, - apiUrl: `${azureEndpoint}/openai/deployments/${kbModelName}/embeddings?api-version=${azureApiVersion}`, - headers: { - 'api-key': azureApiKey!, - 'Content-Type': 'application/json', - }, - modelName: kbModelName, + modelName: azureDeploymentName, + pricingId: info.pricingId, isBYOK: false, + tokenizerProvider: info.tokenizerProvider, + buildRequest: buildAzureOpenAIProvider( + azureDeploymentName, + azureApiKey!, + azureEndpoint!, + azureApiVersion! + ), } } - let openaiApiKey = env.OPENAI_API_KEY - let isBYOK = false - - if (workspaceId) { - const byokResult = await getBYOKKey(workspaceId, 'openai') - if (byokResult) { - logger.info('Using workspace BYOK key for OpenAI embeddings') - openaiApiKey = byokResult.apiKey - isBYOK = true + if (info.provider === 'openai') { + const { apiKey, isBYOK } = await resolveOpenAIKey(workspaceId) + return { + modelName: embeddingModel, + pricingId: info.pricingId, + isBYOK, + tokenizerProvider: info.tokenizerProvider, + buildRequest: buildOpenAIProvider(embeddingModel, apiKey), } } - if (!openaiApiKey) { - throw new Error( - 'Either OPENAI_API_KEY or Azure OpenAI configuration (AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT) must be configured' - ) + if (info.provider === 'gemini') { + const { apiKey, isBYOK } = await resolveGeminiKey(workspaceId) + return { + modelName: embeddingModel, + pricingId: info.pricingId, + isBYOK, + tokenizerProvider: info.tokenizerProvider, + maxItemsPerRequest: GEMINI_MAX_ITEMS_PER_REQUEST, + buildRequest: buildGeminiProvider(embeddingModel, apiKey), + } } - return { - useAzure: false, - apiUrl: 'https://api.openai.com/v1/embeddings', - headers: { - Authorization: `Bearer ${openaiApiKey}`, - 'Content-Type': 'application/json', - }, - modelName: embeddingModel, - isBYOK, - } + throw new Error(`Unknown embedding provider for model ${embeddingModel}`) } -const EMBEDDING_REQUEST_TIMEOUT_MS = 60_000 - async function callEmbeddingAPI( inputs: string[], - config: EmbeddingConfig + provider: ResolvedProvider, + inputType: EmbeddingInputType ): Promise<{ embeddings: number[][]; totalTokens: number }> { return retryWithExponentialBackoff( async () => { - const useDimensions = supportsCustomDimensions(config.modelName) - - const requestBody = config.useAzure - ? 
{ - input: inputs, - encoding_format: 'float', - ...(useDimensions && { dimensions: EMBEDDING_DIMENSIONS }), - } - : { - input: inputs, - model: config.modelName, - encoding_format: 'float', - ...(useDimensions && { dimensions: EMBEDDING_DIMENSIONS }), - } + const request = provider.buildRequest(inputs, inputType) const controller = new AbortController() const timeout = setTimeout(() => controller.abort(), EMBEDDING_REQUEST_TIMEOUT_MS) - const response = await fetch(config.apiUrl, { + const response = await fetch(request.apiUrl, { method: 'POST', - headers: config.headers, - body: JSON.stringify(requestBody), + headers: request.headers, + body: JSON.stringify(request.body), signal: controller.signal, }).finally(() => clearTimeout(timeout)) @@ -147,11 +286,18 @@ async function callEmbeddingAPI( ) } - const data: EmbeddingAPIResponse = await response.json() - return { - embeddings: data.data.map((item) => item.embedding), - totalTokens: data.usage.total_tokens, - } + const json = await response.json() + const embeddings = request.parse(json) + const usage = (json as { usage?: { total_tokens?: number } }).usage + const totalTokens = + usage?.total_tokens ?? + // Gemini does not return usage.total_tokens — estimate with the provider's tokenizer + inputs.reduce( + (sum, text) => sum + estimateTokenCount(text, provider.tokenizerProvider).count, + 0 + ) + + return { embeddings, totalTokens } }, { maxRetries: 3, @@ -167,9 +313,15 @@ async function callEmbeddingAPI( ) } -/** - * Process batches with controlled concurrency - */ +function splitByItemLimit(items: T[], limit: number): T[][] { + if (items.length <= limit) return [items] + const result: T[][] = [] + for (let i = 0; i < items.length; i += limit) { + result.push(items.slice(i, i + limit)) + } + return result +} + async function processWithConcurrency( items: T[], concurrency: number, @@ -194,28 +346,31 @@ export interface GenerateEmbeddingsResult { totalTokens: number isBYOK: boolean modelName: string + /** Pricing identifier for use with calculateCost / EMBEDDING_MODEL_PRICING. */ + pricingId: string } /** * Generate embeddings for multiple texts with token-aware batching and parallel processing. - * Returns embeddings alongside actual token count, model name, and whether a workspace BYOK key - * was used (vs. the platform's shared key) — enabling callers to make correct billing decisions. */ export async function generateEmbeddings( texts: string[], - embeddingModel = 'text-embedding-3-small', + embeddingModel: string = DEFAULT_EMBEDDING_MODEL, workspaceId?: string | null ): Promise { - const config = await getEmbeddingConfig(embeddingModel, workspaceId) + const provider = await resolveProvider(embeddingModel, workspaceId) - const batches = batchByTokenLimit(texts, MAX_TOKENS_PER_REQUEST, embeddingModel) + const tokenBatches = batchByTokenLimit(texts, MAX_TOKENS_PER_REQUEST, embeddingModel) + const batches = provider.maxItemsPerRequest + ? 
tokenBatches.flatMap((batch) => splitByItemLimit(batch, provider.maxItemsPerRequest!)) + : tokenBatches const batchResults = await processWithConcurrency( batches, MAX_CONCURRENT_BATCHES, async (batch, i) => { try { - return await callEmbeddingAPI(batch, config) + return await callEmbeddingAPI(batch, provider, 'document') } catch (error) { logger.error(`Failed to generate embeddings for batch ${i + 1}/${batches.length}:`, error) throw error @@ -235,25 +390,24 @@ export async function generateEmbeddings( return { embeddings: allEmbeddings, totalTokens, - isBYOK: config.isBYOK, - modelName: config.modelName, + isBYOK: provider.isBYOK, + modelName: provider.modelName, + pricingId: provider.pricingId, } } /** - * Generate embedding for a single search query + * Generate embedding for a single search query. */ export async function generateSearchEmbedding( query: string, - embeddingModel = 'text-embedding-3-small', + embeddingModel: string = DEFAULT_EMBEDDING_MODEL, workspaceId?: string | null ): Promise { - const config = await getEmbeddingConfig(embeddingModel, workspaceId) + const provider = await resolveProvider(embeddingModel, workspaceId) - logger.info( - `Using ${config.useAzure ? 'Azure OpenAI' : 'OpenAI'} for search embedding generation` - ) + logger.info(`Using ${provider.modelName} for search embedding generation`) - const { embeddings } = await callEmbeddingAPI([query], config) + const { embeddings } = await callEmbeddingAPI([query], provider, 'query') return embeddings[0] } diff --git a/apps/sim/lib/knowledge/reranker-models.ts b/apps/sim/lib/knowledge/reranker-models.ts new file mode 100644 index 00000000000..5c3d17bff89 --- /dev/null +++ b/apps/sim/lib/knowledge/reranker-models.ts @@ -0,0 +1,18 @@ +/** + * Client-safe registry of Cohere rerank models supported by the platform. + * Kept free of server imports so it can be imported into UI / block code. + */ + +/** Cohere rerank model identifiers we accept. Must match Cohere's model ids exactly. */ +export const SUPPORTED_RERANKER_MODELS = [ + 'rerank-v4.0-pro', + 'rerank-v4.0-fast', + 'rerank-v3.5', +] as const +export type RerankerModelId = (typeof SUPPORTED_RERANKER_MODELS)[number] + +export const DEFAULT_RERANKER_MODEL: RerankerModelId = 'rerank-v4.0-fast' + +export function isSupportedRerankerModel(model: string): model is RerankerModelId { + return (SUPPORTED_RERANKER_MODELS as readonly string[]).includes(model) +} diff --git a/apps/sim/lib/knowledge/reranker.ts b/apps/sim/lib/knowledge/reranker.ts new file mode 100644 index 00000000000..54b2ae02c91 --- /dev/null +++ b/apps/sim/lib/knowledge/reranker.ts @@ -0,0 +1,163 @@ +import { createLogger } from '@sim/logger' +import { getBYOKKey } from '@/lib/api-key/byok' +import { getRotatingApiKey } from '@/lib/core/config/api-keys' +import { env } from '@/lib/core/config/env' +import { isRetryableError, retryWithExponentialBackoff } from '@/lib/knowledge/documents/utils' +import { + DEFAULT_RERANKER_MODEL, + isSupportedRerankerModel, + type RerankerModelId, + SUPPORTED_RERANKER_MODELS, +} from '@/lib/knowledge/reranker-models' + +export { + DEFAULT_RERANKER_MODEL, + isSupportedRerankerModel, + type RerankerModelId, + SUPPORTED_RERANKER_MODELS, +} + +const logger = createLogger('Reranker') + +const RERANK_REQUEST_TIMEOUT_MS = 30_000 + +/** + * Cohere bills per "search unit" = one query with up to 100 documents. + * We cap at 100 so each rerank call costs exactly 1 unit and matches + * `RERANK_MODEL_PRICING` in `providers/models.ts`. 
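+ * For illustration: an uncapped request with 150 candidate documents would bill
+ * as two search units; capping keeps every call at exactly one.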
diff --git a/apps/sim/lib/knowledge/reranker.ts b/apps/sim/lib/knowledge/reranker.ts
new file mode 100644
index 00000000000..54b2ae02c91
--- /dev/null
+++ b/apps/sim/lib/knowledge/reranker.ts
@@ -0,0 +1,163 @@
+import { createLogger } from '@sim/logger'
+import { getBYOKKey } from '@/lib/api-key/byok'
+import { getRotatingApiKey } from '@/lib/core/config/api-keys'
+import { env } from '@/lib/core/config/env'
+import { isRetryableError, retryWithExponentialBackoff } from '@/lib/knowledge/documents/utils'
+import {
+  DEFAULT_RERANKER_MODEL,
+  isSupportedRerankerModel,
+  type RerankerModelId,
+  SUPPORTED_RERANKER_MODELS,
+} from '@/lib/knowledge/reranker-models'
+
+export {
+  DEFAULT_RERANKER_MODEL,
+  isSupportedRerankerModel,
+  type RerankerModelId,
+  SUPPORTED_RERANKER_MODELS,
+}
+
+const logger = createLogger('Reranker')
+
+const RERANK_REQUEST_TIMEOUT_MS = 30_000
+
+/**
+ * Cohere bills per "search unit" = one query with up to 100 documents.
+ * We cap at 100 so each rerank call costs exactly 1 unit and matches
+ * `RERANK_MODEL_PRICING` in `providers/models.ts`. The search route also
+ * caps `candidateTopK` at 100, so this is a defensive ceiling.
+ */
+const MAX_DOCUMENTS_PER_RERANK = 100
+
+export interface RerankItem {
+  /** Stable identifier so callers can correlate ranked results back to source rows. */
+  id: string
+  text: string
+}
+
+export interface RerankedResult<T> {
+  item: T
+  relevanceScore: number
+}
+
+export interface RerankResponse<T> {
+  results: RerankedResult<T>[]
+  /** True when a workspace-supplied (BYOK) Cohere key was used. Callers should skip platform billing in that case. */
+  isBYOK: boolean
+}
+
+class RerankAPIError extends Error {
+  public status: number
+  constructor(message: string, status: number) {
+    super(message)
+    this.name = 'RerankAPIError'
+    this.status = status
+  }
+}
+
+async function resolveCohereKey(
+  workspaceId?: string | null
+): Promise<{ apiKey: string; isBYOK: boolean }> {
+  if (workspaceId) {
+    const byokResult = await getBYOKKey(workspaceId, 'cohere')
+    if (byokResult) {
+      logger.info('Using workspace BYOK key for Cohere reranker')
+      return { apiKey: byokResult.apiKey, isBYOK: true }
+    }
+  }
+  if (env.COHERE_API_KEY) {
+    return { apiKey: env.COHERE_API_KEY, isBYOK: false }
+  }
+  try {
+    return { apiKey: getRotatingApiKey('cohere'), isBYOK: false }
+  } catch {
+    throw new Error(
+      'No Cohere API key configured. Set COHERE_API_KEY_1/2/3 (rotation) or COHERE_API_KEY.'
+    )
+  }
+}
+
+interface CohereRerankResponse {
+  results: Array<{ index: number; relevance_score: number }>
+}
+
+/**
+ * Rerank documents against a query using Cohere's `/v2/rerank` endpoint.
+ * Returns the items in descending order of relevance, capped at `topN`.
+ */
+export async function rerank<T extends RerankItem>(
+  query: string,
+  items: T[],
+  options: {
+    model: string
+    topN?: number
+    workspaceId?: string | null
+  }
+): Promise<RerankResponse<T>> {
+  if (items.length === 0) return { results: [], isBYOK: false }
+
+  if (!isSupportedRerankerModel(options.model)) {
+    throw new Error(`Unsupported reranker model: ${options.model}`)
+  }
+
+  const { apiKey, isBYOK } = await resolveCohereKey(options.workspaceId)
+  const cappedItems =
+    items.length > MAX_DOCUMENTS_PER_RERANK ? items.slice(0, MAX_DOCUMENTS_PER_RERANK) : items
+  if (items.length > MAX_DOCUMENTS_PER_RERANK) {
+    logger.warn(`Rerank input capped from ${items.length} to ${MAX_DOCUMENTS_PER_RERANK} documents`)
+  }
+  const documents = cappedItems.map((it) => it.text)
+
+  const response = await retryWithExponentialBackoff(
+    async () => {
+      const controller = new AbortController()
+      const timeout = setTimeout(() => controller.abort(), RERANK_REQUEST_TIMEOUT_MS)
+
+      const res = await fetch('https://api.cohere.com/v2/rerank', {
+        method: 'POST',
+        headers: {
+          Authorization: `Bearer ${apiKey}`,
+          'Content-Type': 'application/json',
+        },
+        body: JSON.stringify({
+          model: options.model,
+          query,
+          documents,
+          top_n: options.topN ?? cappedItems.length,
+        }),
+        signal: controller.signal,
+      }).finally(() => clearTimeout(timeout))
+
+      if (!res.ok) {
+        const errorText = await res.text()
+        throw new RerankAPIError(
+          `Cohere rerank failed: ${res.status} ${res.statusText} - ${errorText}`,
+          res.status
+        )
+      }
+
+      return (await res.json()) as CohereRerankResponse
+    },
+    {
+      maxRetries: 3,
+      initialDelayMs: 500,
+      maxDelayMs: 5000,
+      retryCondition: (error: unknown) => {
+        if (error instanceof RerankAPIError) {
+          return error.status === 429 || error.status >= 500
+        }
+        return isRetryableError(error)
+      },
+    }
+  )
+
+  return {
+    results: response.results
      .filter((r) => r.index >= 0 && r.index < cappedItems.length)
      .map((r) => ({
        item: cappedItems[r.index],
        relevanceScore: r.relevance_score,
      })),
+    isBYOK,
+  }
+}
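Putting `rerank` together end to end, a hedged usage sketch (the candidate rows, ids, and query are made up; the call shape follows the signature above):

import { rerank } from '@/lib/knowledge/reranker'

async function example() {
  // Hypothetical candidates, e.g. the top rows returned by vector search.
  const candidates = [
    { id: 'chunk-1', text: 'Refunds are issued within 30 days of purchase.' },
    { id: 'chunk-2', text: 'Standard shipping takes 3 to 5 business days.' },
  ]

  const { results, isBYOK } = await rerank('what is the refund window?', candidates, {
    model: 'rerank-v3.5',
    topN: 1,
    workspaceId: null, // skips the BYOK lookup; falls back to COHERE_API_KEY or rotating keys
  })

  // Results come back sorted by relevance, each carrying Cohere's relevance_score.
  for (const r of results) {
    console.log(r.item.id, r.relevanceScore, 'BYOK:', isBYOK)
  }
}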
diff --git a/apps/sim/lib/knowledge/service.ts b/apps/sim/lib/knowledge/service.ts
index 00f4326a063..0ba5de8162a 100644
--- a/apps/sim/lib/knowledge/service.ts
+++ b/apps/sim/lib/knowledge/service.ts
@@ -250,8 +250,6 @@ export async function updateKnowledgeBase(
   if (updates.workspaceId !== undefined) updateData.workspaceId = updates.workspaceId
   if (updates.chunkingConfig !== undefined) {
     updateData.chunkingConfig = updates.chunkingConfig
-    updateData.embeddingModel = 'text-embedding-3-small'
-    updateData.embeddingDimension = 1536
   }
 
   if (updates.name !== undefined) {
diff --git a/apps/sim/lib/knowledge/types.ts b/apps/sim/lib/knowledge/types.ts
index 6fe1a8bbaff..03e3475c285 100644
--- a/apps/sim/lib/knowledge/types.ts
+++ b/apps/sim/lib/knowledge/types.ts
@@ -34,7 +34,7 @@ export interface CreateKnowledgeBaseData {
   name: string
   description?: string
   workspaceId: string
-  embeddingModel: 'text-embedding-3-small'
+  embeddingModel: string
   embeddingDimension: 1536
   chunkingConfig: ChunkingConfig
   userId: string
diff --git a/apps/sim/providers/models.ts b/apps/sim/providers/models.ts
index 05f50e9aeec..04942c74113 100644
--- a/apps/sim/providers/models.ts
+++ b/apps/sim/providers/models.ts
@@ -3023,12 +3023,33 @@ export const EMBEDDING_MODEL_PRICING: Record<string, ModelPricing> = {
     output: 0.0,
     updatedAt: '2026-04-01',
   },
+  'gemini-embedding-001': {
+    input: 0.15, // $0.15 per 1M tokens
+    output: 0.0,
+    updatedAt: '2026-04-29',
+  },
 }
 
 export function getEmbeddingModelPricing(modelId: string): ModelPricing | null {
   return EMBEDDING_MODEL_PRICING[modelId] || null
 }
 
+/**
+ * Cohere rerank pricing in USD per single search unit (one query × ≤100 docs).
+ * Sim caps every rerank request to ≤100 documents, so each call = 1 unit.
+ */
+export const RERANK_MODEL_PRICING: Record<string, { perSearchUnit: number; updatedAt: string }> = {
+  'rerank-v4.0-pro': { perSearchUnit: 0.0025, updatedAt: '2026-04-29' },
+  'rerank-v4.0-fast': { perSearchUnit: 0.002, updatedAt: '2026-04-29' },
+  'rerank-v3.5': { perSearchUnit: 0.002, updatedAt: '2026-04-29' },
+}
+
+export function getRerankModelPricing(
+  modelId: string
+): { perSearchUnit: number; updatedAt: string } | null {
+  return RERANK_MODEL_PRICING[modelId] || null
+}
+
 export function getModelsWithReasoningEffort(): string[] {
   const models: string[] = []
   for (const provider of Object.values(PROVIDER_DEFINITIONS)) {
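Because every rerank request is capped at 100 documents, billing reduces to one search unit per call. A small sketch of the arithmetic a billing path might do (`getRerankModelPricing` is the helper added above; the surrounding code is illustrative):

import { getRerankModelPricing } from '@/providers/models'

// One capped rerank call = exactly one search unit.
const searchUnits = 1
const pricing = getRerankModelPricing('rerank-v4.0-fast')
const rerankerCost = (pricing?.perSearchUnit ?? 0) * searchUnits
// rerank-v4.0-fast: 0.002 USD x 1 unit = $0.002 per search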
diff --git a/apps/sim/tools/knowledge/search.ts b/apps/sim/tools/knowledge/search.ts
index af82111adc8..276241caab4 100644
--- a/apps/sim/tools/knowledge/search.ts
+++ b/apps/sim/tools/knowledge/search.ts
@@ -1,3 +1,4 @@
+import { DEFAULT_RERANKER_MODEL, SUPPORTED_RERANKER_MODELS } from '@/lib/knowledge/reranker-models'
 import type { KnowledgeSearchResponse } from '@/tools/knowledge/types'
 import { enrichKBTagFiltersSchema } from '@/tools/schema-enrichers'
 import { parseTagFilters } from '@/tools/shared/tags'
@@ -41,6 +42,18 @@ export const knowledgeSearchTool: ToolConfig = {
         },
       },
     },
+    rerankerEnabled: {
+      type: 'boolean',
+      required: false,
+      visibility: 'user-only',
+      description: 'Whether to apply Cohere reranking to vector search results',
+    },
+    rerankerModel: {
+      type: 'string',
+      required: false,
+      visibility: 'user-only',
+      description: `Cohere rerank model to use (one of: ${SUPPORTED_RERANKER_MODELS.join(', ')})`,
+    },
   },
 
   schemaEnrichment: {
@@ -65,11 +78,18 @@ export const knowledgeSearchTool: ToolConfig = {
     // Parse tag filters from various formats (array, JSON string)
     const structuredFilters = parseTagFilters(params.tagFilters)
 
+    const rerankerEnabled = params.rerankerEnabled === true || params.rerankerEnabled === 'true'
+    const rerankerModel =
+      typeof params.rerankerModel === 'string' && params.rerankerModel.length > 0
+        ? params.rerankerModel
+        : DEFAULT_RERANKER_MODEL
+
     const requestBody = {
       knowledgeBaseIds,
       query: params.query,
       topK: params.topK ? Math.max(1, Math.min(100, Number(params.topK))) : 10,
       ...(structuredFilters.length > 0 && { tagFilters: structuredFilters }),
+      ...(rerankerEnabled && { rerankerEnabled: true, rerankerModel }),
       ...(workflowId && { workflowId }),
     }
 
@@ -83,9 +103,25 @@ export const knowledgeSearchTool: ToolConfig = {
     // Restructure cost: extract tokens/model to top level for logging
     let costFields: Record<string, unknown> = {}
     if (data.cost && typeof data.cost === 'object') {
-      const { tokens, model, input, output: outputCost, total } = data.cost
+      const {
+        tokens,
+        model,
+        input,
+        output: outputCost,
+        total,
+        rerankerCost,
+        rerankerModel,
+        rerankerSearchUnits,
+      } = data.cost
       costFields = {
-        cost: { input, output: outputCost, total },
+        cost: {
+          input,
+          output: outputCost,
+          total,
+          ...(typeof rerankerCost === 'number' && { rerankerCost }),
+          ...(typeof rerankerModel === 'string' && { rerankerModel }),
+          ...(typeof rerankerSearchUnits === 'number' && { rerankerSearchUnits }),
+        },
        ...(tokens && { tokens }),
        ...(model && { model }),
      }
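For reference, roughly the body `execute` would produce when a user enables reranking without picking a model (the knowledge base id and query are hypothetical; the string 'true' form is accepted because UI params may arrive as strings):

// Hypothetical request body assembled by the tool above.
const exampleRequestBody = {
  knowledgeBaseIds: ['kb_123'], // hypothetical id
  query: 'refund policy',
  topK: 10, // default when params.topK is unset
  rerankerEnabled: true, // params.rerankerEnabled === 'true' coerces here
  rerankerModel: 'rerank-v4.0-fast', // DEFAULT_RERANKER_MODEL fallback
}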
diff --git a/apps/sim/tools/knowledge/types.ts b/apps/sim/tools/knowledge/types.ts
index 3fa87ccaad7..31005d638d7 100644
--- a/apps/sim/tools/knowledge/types.ts
+++ b/apps/sim/tools/knowledge/types.ts
@@ -40,6 +40,7 @@ export interface KnowledgeSearchResult {
   chunkIndex: number
   metadata: Record<string, unknown>
   similarity: number
+  rerankerScore?: number
 }
 
 export interface KnowledgeSearchResponse {
@@ -52,18 +53,16 @@ export interface KnowledgeSearchResponse {
       input: number
       output: number
       total: number
-      tokens: {
-        prompt: number
-        completion: number
-        total: number
-      }
-      model: string
-      pricing: {
-        input: number
-        output: number
-        updatedAt: string
-      }
+      rerankerCost?: number
+      rerankerModel?: string
+      rerankerSearchUnits?: number
+    }
+    tokens?: {
+      prompt: number
+      completion: number
+      total: number
     }
+    model?: string
   }
   error?: string
 }
diff --git a/apps/sim/tools/types.ts b/apps/sim/tools/types.ts
index 7758f6facc2..535761646bc 100644
--- a/apps/sim/tools/types.ts
+++ b/apps/sim/tools/types.ts
@@ -17,6 +17,7 @@ export type BYOKProviderId =
   | 'linkup'
   | 'brandfetch'
   | 'parallel_ai'
+  | 'cohere'
 
 export type HttpMethod = 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH' | 'HEAD'
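With the cost block reshaped and the reranker fields optional, downstream consumers should feature-detect them rather than assume their presence. A minimal sketch (the structural type below mirrors only the fields touched by this change, not the full response):

// Mirrors just the reshaped cost fields from KnowledgeSearchResponse.
type SearchCost = {
  input: number
  output: number
  total: number
  rerankerCost?: number
  rerankerModel?: string
  rerankerSearchUnits?: number
}

function summarizeSearchCost(cost: SearchCost): string {
  const rerankPart =
    typeof cost.rerankerCost === 'number'
      ? ` (incl. $${cost.rerankerCost} for ${cost.rerankerSearchUnits ?? 1} rerank unit(s) via ${
          cost.rerankerModel ?? 'unknown model'
        })`
      : ''
  return `total: $${cost.total}${rerankPart}`
}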