diff --git a/.changeset/quick-walls-switch.md b/.changeset/quick-walls-switch.md
new file mode 100644
index 000000000..31ff2a41a
--- /dev/null
+++ b/.changeset/quick-walls-switch.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Let user change indexes in LlamaCloud projects
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index bb04a7539..34930cad5 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -160,6 +160,17 @@ const getVectorDBEnvs = (
          description:
            "The organization ID for the LlamaCloud project (uses default organization if not specified - Python only)",
        },
+        ...(framework === "nextjs"
+          ? // activate index selector per default (not needed for non-NextJS backends as it's handled by createFrontendEnvFile)
+            [
+              {
+                name: "NEXT_PUBLIC_USE_LLAMACLOUD",
+                description:
+                  "Lets the user change indexes in LlamaCloud projects",
+                value: "true",
+              },
+            ]
+          : []),
      ];
    case "chroma":
      const envs = [
@@ -493,6 +504,7 @@ export const createFrontendEnvFile = async (
   root: string,
   opts: {
     customApiPath?: string;
+    vectorDb?: TemplateVectorDB;
   },
 ) => {
   const defaultFrontendEnvs = [
@@ -503,6 +515,11 @@ export const createFrontendEnvFile = async (
       value: opts.customApiPath
         ? opts.customApiPath
         : "http://localhost:8000/api/chat",
     },
+    {
+      name: "NEXT_PUBLIC_USE_LLAMACLOUD",
+      description: "Lets the user change indexes in LlamaCloud projects",
+      value: opts.vectorDb === "llamacloud" ? "true" : "false",
+    },
   ];
   const content = renderEnvVar(defaultFrontendEnvs);
   await fs.writeFile(path.join(root, ".env"), content);
diff --git a/helpers/index.ts b/helpers/index.ts
index ae83f56cc..79f3c6873 100644
--- a/helpers/index.ts
+++ b/helpers/index.ts
@@ -209,6 +209,7 @@ export const installTemplate = async (
     // this is a frontend for a full-stack app, create .env file with model information
     await createFrontendEnvFile(props.root, {
       customApiPath: props.customApiPath,
+      vectorDb: props.vectorDb,
     });
   }
 };
diff --git a/templates/components/engines/python/agent/__init__.py b/templates/components/engines/python/agent/__init__.py
index 17e36236b..fb8d410c6 100644
--- a/templates/components/engines/python/agent/__init__.py
+++ b/templates/components/engines/python/agent/__init__.py
@@ -6,7 +6,7 @@
 from app.engine.index import get_index
 
 
-def get_chat_engine(filters=None):
+def get_chat_engine(filters=None, params=None):
     system_prompt = os.getenv("SYSTEM_PROMPT")
     top_k = os.getenv("TOP_K", "3")
     tools = []
diff --git a/templates/components/engines/python/chat/__init__.py b/templates/components/engines/python/chat/__init__.py
index f885ed132..7d8df55ae 100644
--- a/templates/components/engines/python/chat/__init__.py
+++ b/templates/components/engines/python/chat/__init__.py
@@ -3,11 +3,11 @@
 from fastapi import HTTPException
 
 
-def get_chat_engine(filters=None):
+def get_chat_engine(filters=None, params=None):
     system_prompt = os.getenv("SYSTEM_PROMPT")
     top_k = os.getenv("TOP_K", 3)
 
-    index = get_index()
+    index = get_index(params)
     if index is None:
         raise HTTPException(
             status_code=500,
diff --git a/templates/components/engines/typescript/agent/chat.ts b/templates/components/engines/typescript/agent/chat.ts
index f3b438408..c9868ff71 100644
--- a/templates/components/engines/typescript/agent/chat.ts
+++ b/templates/components/engines/typescript/agent/chat.ts
@@ -10,12 +10,12 @@ import path from "node:path";
 import { getDataSource } from "./index";
 import { createTools } from "./tools";
 
-export async function createChatEngine(documentIds?: string[]) {
+export async function createChatEngine(documentIds?: string[], params?: any) {
   const tools: BaseToolWithCall[] = [];
 
   // Add a query engine tool if we have a data source
   // Delete this code if you don't have a data source
-  const index = await getDataSource();
+  const index = await getDataSource(params);
   if (index) {
     tools.push(
       new QueryEngineTool({
diff --git a/templates/components/engines/typescript/chat/chat.ts b/templates/components/engines/typescript/chat/chat.ts
index 1144256a5..d2badd713 100644
--- a/templates/components/engines/typescript/chat/chat.ts
+++ b/templates/components/engines/typescript/chat/chat.ts
@@ -6,8 +6,8 @@ import {
 } from "llamaindex";
 import { getDataSource } from "./index";
 
-export async function createChatEngine(documentIds?: string[]) {
-  const index = await getDataSource();
+export async function createChatEngine(documentIds?: string[], params?: any) {
+  const index = await getDataSource(params);
   if (!index) {
     throw new Error(
       `StorageContext is empty - call 'npm run generate' to generate the storage first`,
diff --git a/templates/components/llamaindex/typescript/streaming/service.ts b/templates/components/llamaindex/typescript/streaming/service.ts
index 6b6c4206c..91001e916 100644
--- a/templates/components/llamaindex/typescript/streaming/service.ts
+++ b/templates/components/llamaindex/typescript/streaming/service.ts
@@ -7,19 +7,51 @@
 const LLAMA_CLOUD_OUTPUT_DIR = "output/llamacloud";
 const LLAMA_CLOUD_BASE_URL = "https://cloud.llamaindex.ai/api/v1";
 const FILE_DELIMITER = "$"; // delimiter between pipelineId and filename
 
-interface LlamaCloudFile {
+type LlamaCloudFile = {
   name: string;
   file_id: string;
   project_id: string;
-}
+};
+
+type LLamaCloudProject = {
+  id: string;
+  organization_id: string;
+  name: string;
+  is_default: boolean;
+};
+
+type LLamaCloudPipeline = {
+  id: string;
+  name: string;
+  project_id: string;
+};
 
 export class LLamaCloudFileService {
+  private static readonly headers = {
+    Accept: "application/json",
+    Authorization: `Bearer ${process.env.LLAMA_CLOUD_API_KEY}`,
+  };
+
+  public static async getAllProjectsWithPipelines() {
+    try {
+      const projects = await LLamaCloudFileService.getAllProjects();
+      const pipelines = await LLamaCloudFileService.getAllPipelines();
+      return projects.map((project) => ({
+        ...project,
+        pipelines: pipelines.filter((p) => p.project_id === project.id),
+      }));
+    } catch (error) {
+      console.error("Error listing projects and pipelines:", error);
+      return [];
+    }
+  }
+
   public static async downloadFiles(nodes: NodeWithScore[]) {
-    const files = this.nodesToDownloadFiles(nodes);
+    const files = LLamaCloudFileService.nodesToDownloadFiles(nodes);
     if (!files.length) return;
     console.log("Downloading files from LlamaCloud...");
     for (const file of files) {
-      await this.downloadFile(file.pipelineId, file.fileName);
+      await LLamaCloudFileService.downloadFile(file.pipelineId, file.fileName);
     }
   }
 
@@ -59,13 +91,19 @@ export class LLamaCloudFileService {
 
   private static async downloadFile(pipelineId: string, fileName: string) {
     try {
-      const downloadedName = this.toDownloadedName(pipelineId, fileName);
+      const downloadedName = LLamaCloudFileService.toDownloadedName(
+        pipelineId,
+        fileName,
+      );
       const downloadedPath = path.join(LLAMA_CLOUD_OUTPUT_DIR, downloadedName);
 
       // Check if file already exists
       if (fs.existsSync(downloadedPath)) return;
 
-      const urlToDownload = await this.getFileUrlByName(pipelineId, fileName);
+      const urlToDownload = await LLamaCloudFileService.getFileUrlByName(
+        pipelineId,
+        fileName,
+      );
       if (!urlToDownload) throw new Error("File not found in LlamaCloud");
 
       const file = fs.createWriteStream(downloadedPath);
@@ -93,10 +131,13 @@ private static async getFileUrlByName(
     pipelineId: string,
     name: string,
   ): Promise<string | null> {
-    const files = await this.getAllFiles(pipelineId);
+    const files = await LLamaCloudFileService.getAllFiles(pipelineId);
     const file = files.find((file) => file.name === name);
     if (!file) return null;
-    return await this.getFileUrlById(file.project_id, file.file_id);
+    return await LLamaCloudFileService.getFileUrlById(
+      file.project_id,
+      file.file_id,
+    );
   }
 
   private static async getFileUrlById(
@@ -104,11 +145,10 @@ private static async getFileUrlById(
     fileId: string,
   ): Promise<string> {
     const url = `${LLAMA_CLOUD_BASE_URL}/files/${fileId}/content?project_id=${projectId}`;
-    const headers = {
-      Accept: "application/json",
-      Authorization: `Bearer ${process.env.LLAMA_CLOUD_API_KEY}`,
-    };
-    const response = await fetch(url, { method: "GET", headers });
+    const response = await fetch(url, {
+      method: "GET",
+      headers: LLamaCloudFileService.headers,
+    });
     const data = (await response.json()) as { url: string };
     return data.url;
   }
 
@@ -117,12 +157,31 @@ private static async getAllFiles(
     pipelineId: string,
   ): Promise<LlamaCloudFile[]> {
     const url = `${LLAMA_CLOUD_BASE_URL}/pipelines/${pipelineId}/files`;
-    const headers = {
-      Accept: "application/json",
-      Authorization: `Bearer ${process.env.LLAMA_CLOUD_API_KEY}`,
-    };
-    const response = await fetch(url, { method: "GET", headers });
+    const response = await fetch(url, {
+      method: "GET",
+      headers: LLamaCloudFileService.headers,
+    });
     const data = await response.json();
     return data;
   }
+
+  private static async getAllProjects(): Promise<LLamaCloudProject[]> {
+    const url = `${LLAMA_CLOUD_BASE_URL}/projects`;
+    const response = await fetch(url, {
+      method: "GET",
+      headers: LLamaCloudFileService.headers,
+    });
+    const data = (await response.json()) as LLamaCloudProject[];
+    return data;
+  }
+
+  private static async getAllPipelines(): Promise<LLamaCloudPipeline[]> {
+    const url = `${LLAMA_CLOUD_BASE_URL}/pipelines`;
+    const response = await fetch(url, {
+      method: "GET",
+      headers: LLamaCloudFileService.headers,
+    });
+    const data = (await response.json()) as LLamaCloudPipeline[];
+    return data;
+  }
 }
diff --git a/templates/components/vectordbs/python/llamacloud/index.py b/templates/components/vectordbs/python/llamacloud/index.py
index da73434f2..e54e8ca9d 100644
--- a/templates/components/vectordbs/python/llamacloud/index.py
+++ b/templates/components/vectordbs/python/llamacloud/index.py
@@ -5,10 +5,11 @@
 
 logger = logging.getLogger("uvicorn")
 
-
-def get_index():
-    name = os.getenv("LLAMA_CLOUD_INDEX_NAME")
-    project_name = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
+def get_index(params=None):
+    config_params = params or {}
+    pipeline_config = config_params.get("llamaCloudPipeline", {})
+    name = pipeline_config.get("pipeline", os.getenv("LLAMA_CLOUD_INDEX_NAME"))
+    project_name = pipeline_config.get("project", os.getenv("LLAMA_CLOUD_PROJECT_NAME"))
     api_key = os.getenv("LLAMA_CLOUD_API_KEY")
     base_url = os.getenv("LLAMA_CLOUD_BASE_URL")
     organization_id = os.getenv("LLAMA_CLOUD_ORGANIZATION_ID")
diff --git a/templates/components/vectordbs/python/none/index.py b/templates/components/vectordbs/python/none/index.py
index f7949d66c..65fd5ad5a 100644
--- a/templates/components/vectordbs/python/none/index.py
+++ b/templates/components/vectordbs/python/none/index.py
@@ -17,7 +17,7 @@ def get_storage_context(persist_dir: str) -> StorageContext:
     return StorageContext.from_defaults(persist_dir=persist_dir)
 
 
-def get_index():
+def get_index(params=None):
     storage_dir = os.getenv("STORAGE_DIR", "storage")
     # check if storage already exists
     if not os.path.exists(storage_dir):
diff --git a/templates/components/vectordbs/typescript/astra/index.ts b/templates/components/vectordbs/typescript/astra/index.ts
index e29ed3531..38c5bbbdd 100644
--- a/templates/components/vectordbs/typescript/astra/index.ts
+++ b/templates/components/vectordbs/typescript/astra/index.ts
@@ -3,7 +3,7 @@ import { VectorStoreIndex } from "llamaindex";
 import { AstraDBVectorStore } from "llamaindex/storage/vectorStore/AstraDBVectorStore";
 import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const store = new AstraDBVectorStore();
   await store.connect(process.env.ASTRA_DB_COLLECTION!);
diff --git a/templates/components/vectordbs/typescript/chroma/index.ts b/templates/components/vectordbs/typescript/chroma/index.ts
index 1d36e643b..fbc7b4bf2 100644
--- a/templates/components/vectordbs/typescript/chroma/index.ts
+++ b/templates/components/vectordbs/typescript/chroma/index.ts
@@ -3,7 +3,7 @@ import { VectorStoreIndex } from "llamaindex";
 import { ChromaVectorStore } from "llamaindex/storage/vectorStore/ChromaVectorStore";
 import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const chromaUri = `http://${process.env.CHROMA_HOST}:${process.env.CHROMA_PORT}`;
 
diff --git a/templates/components/vectordbs/typescript/llamacloud/index.ts b/templates/components/vectordbs/typescript/llamacloud/index.ts
index 3f0875ccd..413d97b66 100644
--- a/templates/components/vectordbs/typescript/llamacloud/index.ts
+++ b/templates/components/vectordbs/typescript/llamacloud/index.ts
@@ -1,12 +1,26 @@
 import { LlamaCloudIndex } from "llamaindex/cloud/LlamaCloudIndex";
-import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
-  checkRequiredEnvVars();
+type LlamaCloudDataSourceParams = {
+  llamaCloudPipeline?: {
+    project: string;
+    pipeline: string;
+  };
+};
+
+export async function getDataSource(params?: LlamaCloudDataSourceParams) {
+  const { project, pipeline } = params?.llamaCloudPipeline ?? {};
+  const projectName = project ?? process.env.LLAMA_CLOUD_PROJECT_NAME;
+  const pipelineName = pipeline ?? process.env.LLAMA_CLOUD_INDEX_NAME;
+  const apiKey = process.env.LLAMA_CLOUD_API_KEY;
+  if (!projectName || !pipelineName || !apiKey) {
+    throw new Error(
+      "Set project, pipeline, and api key in the params or as environment variables.",
+    );
+  }
   const index = new LlamaCloudIndex({
-    name: process.env.LLAMA_CLOUD_INDEX_NAME!,
-    projectName: process.env.LLAMA_CLOUD_PROJECT_NAME!,
-    apiKey: process.env.LLAMA_CLOUD_API_KEY,
+    name: pipelineName,
+    projectName,
+    apiKey,
     baseUrl: process.env.LLAMA_CLOUD_BASE_URL,
   });
   return index;
diff --git a/templates/components/vectordbs/typescript/milvus/index.ts b/templates/components/vectordbs/typescript/milvus/index.ts
index c290175f0..91275b11e 100644
--- a/templates/components/vectordbs/typescript/milvus/index.ts
+++ b/templates/components/vectordbs/typescript/milvus/index.ts
@@ -2,7 +2,7 @@ import { VectorStoreIndex } from "llamaindex";
 import { MilvusVectorStore } from "llamaindex/storage/vectorStore/MilvusVectorStore";
 import { checkRequiredEnvVars, getMilvusClient } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const milvusClient = getMilvusClient();
   const store = new MilvusVectorStore({ milvusClient });
diff --git a/templates/components/vectordbs/typescript/mongo/index.ts b/templates/components/vectordbs/typescript/mongo/index.ts
index effb8f921..efa35fa53 100644
--- a/templates/components/vectordbs/typescript/mongo/index.ts
+++ b/templates/components/vectordbs/typescript/mongo/index.ts
@@ -4,7 +4,7 @@ import { MongoDBAtlasVectorSearch } from "llamaindex/storage/vectorStore/MongoDB
 import { MongoClient } from "mongodb";
 import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const client = new MongoClient(process.env.MONGO_URI!);
   const store = new MongoDBAtlasVectorSearch({
diff --git a/templates/components/vectordbs/typescript/none/index.ts b/templates/components/vectordbs/typescript/none/index.ts
index 64b289750..fecc76f45 100644
--- a/templates/components/vectordbs/typescript/none/index.ts
+++ b/templates/components/vectordbs/typescript/none/index.ts
@@ -2,7 +2,7 @@ import { SimpleDocumentStore, VectorStoreIndex } from "llamaindex";
 import { storageContextFromDefaults } from "llamaindex/storage/StorageContext";
 import { STORAGE_CACHE_DIR } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   const storageContext = await storageContextFromDefaults({
     persistDir: `${STORAGE_CACHE_DIR}`,
   });
diff --git a/templates/components/vectordbs/typescript/pg/index.ts b/templates/components/vectordbs/typescript/pg/index.ts
index 787cae74d..75bcd4038 100644
--- a/templates/components/vectordbs/typescript/pg/index.ts
+++ b/templates/components/vectordbs/typescript/pg/index.ts
@@ -7,7 +7,7 @@ import {
   checkRequiredEnvVars,
 } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const pgvs = new PGVectorStore({
     connectionString: process.env.PG_CONNECTION_STRING,
diff --git a/templates/components/vectordbs/typescript/pinecone/index.ts b/templates/components/vectordbs/typescript/pinecone/index.ts
index 15072cfff..66a22d46e 100644
--- a/templates/components/vectordbs/typescript/pinecone/index.ts
+++ b/templates/components/vectordbs/typescript/pinecone/index.ts
@@ -3,7 +3,7 @@ import { VectorStoreIndex } from "llamaindex";
 import { PineconeVectorStore } from "llamaindex/storage/vectorStore/PineconeVectorStore";
 import { checkRequiredEnvVars } from "./shared";
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const store = new PineconeVectorStore();
   return await VectorStoreIndex.fromVectorStore(store);
diff --git a/templates/components/vectordbs/typescript/qdrant/index.ts b/templates/components/vectordbs/typescript/qdrant/index.ts
index 0233d0882..a9d87ab8c 100644
--- a/templates/components/vectordbs/typescript/qdrant/index.ts
+++ b/templates/components/vectordbs/typescript/qdrant/index.ts
@@ -5,7 +5,7 @@ import { checkRequiredEnvVars, getQdrantClient } from "./shared";
 
 dotenv.config();
 
-export async function getDataSource() {
+export async function getDataSource(params?: any) {
   checkRequiredEnvVars();
   const collectionName = process.env.QDRANT_COLLECTION;
   const store = new QdrantVectorStore({
diff --git a/templates/types/streaming/express/src/controllers/chat-config.controller.ts b/templates/types/streaming/express/src/controllers/chat-config.controller.ts
index 4481e10d1..af843c2c5 100644
--- a/templates/types/streaming/express/src/controllers/chat-config.controller.ts
+++ b/templates/types/streaming/express/src/controllers/chat-config.controller.ts
@@ -1,4 +1,5 @@
 import { Request, Response } from "express";
+import { LLamaCloudFileService } from "./llamaindex/streaming/service";
 
 export const chatConfig = async (_req: Request, res: Response) => {
   let starterQuestions = undefined;
@@ -12,3 +13,14 @@ export const chatConfig = async (_req: Request, res: Response) => {
     starterQuestions,
   });
 };
+
+export const chatLlamaCloudConfig = async (_req: Request, res: Response) => {
+  const config = {
+    projects: await LLamaCloudFileService.getAllProjectsWithPipelines(),
+    pipeline: {
+      pipeline: process.env.LLAMA_CLOUD_INDEX_NAME,
+      project: process.env.LLAMA_CLOUD_PROJECT_NAME,
+    },
+  };
+  return res.status(200).json(config);
+};
diff --git a/templates/types/streaming/express/src/controllers/chat.controller.ts b/templates/types/streaming/express/src/controllers/chat.controller.ts
index 95228e8d8..50b70789c 100644
--- a/templates/types/streaming/express/src/controllers/chat.controller.ts
+++ b/templates/types/streaming/express/src/controllers/chat.controller.ts
@@ -17,7 +17,7 @@ export const chat = async (req: Request, res: Response) => {
   const vercelStreamData = new StreamData();
   const streamTimeout = createStreamTimeout(vercelStreamData);
   try {
-    const { messages }: { messages: Message[] } = req.body;
+    const { messages, data }: { messages: Message[]; data?: any } = req.body;
     const userMessage = messages.pop();
     if (!messages || !userMessage || userMessage.role !== "user") {
       return res.status(400).json({
@@ -46,7 +46,7 @@ export const chat = async (req: Request, res: Response) => {
       },
     );
     const ids = retrieveDocumentIds(allAnnotations);
-    const chatEngine = await createChatEngine(ids);
+    const chatEngine = await createChatEngine(ids, data);
 
     // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
     const userMessageContent = convertMessageContent(
diff --git a/templates/types/streaming/express/src/routes/chat.route.ts b/templates/types/streaming/express/src/routes/chat.route.ts
index 96711d502..5044efcb4 100644
--- a/templates/types/streaming/express/src/routes/chat.route.ts
+++ b/templates/types/streaming/express/src/routes/chat.route.ts
@@ -1,5 +1,8 @@
 import express, { Router } from "express";
-import { chatConfig } from "../controllers/chat-config.controller";
+import {
+  chatConfig,
+  chatLlamaCloudConfig,
+} from "../controllers/chat-config.controller";
 import { chatRequest } from "../controllers/chat-request.controller";
 import { chatUpload } from "../controllers/chat-upload.controller";
 import { chat } from "../controllers/chat.controller";
@@ -11,6 +14,7 @@ initSettings();
 llmRouter.route("/").post(chat);
 llmRouter.route("/request").post(chatRequest);
 llmRouter.route("/config").get(chatConfig);
+llmRouter.route("/config/llamacloud").get(chatLlamaCloudConfig);
 llmRouter.route("/upload").post(chatUpload);
 
 export default llmRouter;
diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py
index e2cffadf7..cb7036d92 100644
--- a/templates/types/streaming/fastapi/app/api/routers/chat.py
+++ b/templates/types/streaming/fastapi/app/api/routers/chat.py
@@ -52,8 +52,9 @@ async def chat(
 
     doc_ids = data.get_chat_document_ids()
     filters = generate_filters(doc_ids)
+    params = data.data or {}
     logger.info("Creating chat engine with filters", filters.dict())
-    chat_engine = get_chat_engine(filters=filters)
+    chat_engine = get_chat_engine(filters=filters, params=params)
 
     event_handler = EventCallbackHandler()
     chat_engine.callback_manager.handlers.append(event_handler)  # type: ignore
@@ -125,3 +126,22 @@ async def chat_config() -> ChatConfig:
     if conversation_starters and conversation_starters.strip():
         starter_questions = conversation_starters.strip().split("\n")
     return ChatConfig(starter_questions=starter_questions)
+
+
+@r.get("/config/llamacloud")
+async def chat_llama_cloud_config():
+    projects = LLamaCloudFileService.get_all_projects_with_pipelines()
+    pipeline = os.getenv("LLAMA_CLOUD_INDEX_NAME")
+    project = os.getenv("LLAMA_CLOUD_PROJECT_NAME")
+    pipeline_config = (
+        {
+            "pipeline": pipeline,
+            "project": project,
+        }
+        if pipeline and project
+        else None
+    )
+    return {
+        "projects": projects,
+        "pipeline": pipeline_config,
+    }
diff --git a/templates/types/streaming/fastapi/app/api/routers/models.py b/templates/types/streaming/fastapi/app/api/routers/models.py
index b510e848a..c9ea1adb7 100644
--- a/templates/types/streaming/fastapi/app/api/routers/models.py
+++ b/templates/types/streaming/fastapi/app/api/routers/models.py
@@ -75,6 +75,7 @@ class Message(BaseModel):
 
 class ChatData(BaseModel):
     messages: List[Message]
+    data: Any = None
 
     class Config:
         json_schema_extra = {
@@ -237,7 +238,7 @@ class ChatConfig(BaseModel):
     starter_questions: Optional[List[str]] = Field(
         default=None,
         description="List of starter questions",
-        serialization_alias="starterQuestions"
+        serialization_alias="starterQuestions",
     )
 
     class Config:
diff --git a/templates/types/streaming/fastapi/app/api/services/llama_cloud.py b/templates/types/streaming/fastapi/app/api/services/llama_cloud.py
index ea03e64bb..852ae7cee 100644
--- a/templates/types/streaming/fastapi/app/api/services/llama_cloud.py
+++ b/templates/types/streaming/fastapi/app/api/services/llama_cloud.py
@@ -14,6 +14,32 @@ class LLamaCloudFileService:
 
     DOWNLOAD_FILE_NAME_TPL = "{pipeline_id}${filename}"
 
+    @classmethod
+    def get_all_projects(cls) -> List[Dict[str, Any]]:
+        url = f"{cls.LLAMA_CLOUD_URL}/projects"
+        return cls._make_request(url)
+
+    @classmethod
+    def get_all_pipelines(cls) -> List[Dict[str, Any]]:
+        url = f"{cls.LLAMA_CLOUD_URL}/pipelines"
+        return cls._make_request(url)
+
+    @classmethod
+    def get_all_projects_with_pipelines(cls) -> List[Dict[str, Any]]:
+        try:
+            projects = cls.get_all_projects()
+            pipelines = cls.get_all_pipelines()
+            return [
+                {
+                    **project,
+                    "pipelines": [p for p in pipelines if p["project_id"] == project["id"]],
+                }
+                for project in projects
+            ]
+        except Exception as error:
+            logger.error(f"Error listing projects and pipelines: {error}")
+            return []
+
     @classmethod
     def _get_files(cls, pipeline_id: str) -> List[Dict[str, Any]]:
         url = f"{cls.LLAMA_CLOUD_URL}/pipelines/{pipeline_id}/files"
diff --git a/templates/types/streaming/fastapi/app/engine/index.py b/templates/types/streaming/fastapi/app/engine/index.py
index 2dbc589b1..e1adcb803 100644
--- a/templates/types/streaming/fastapi/app/engine/index.py
+++ b/templates/types/streaming/fastapi/app/engine/index.py
@@ -6,7 +6,7 @@
 logger = logging.getLogger("uvicorn")
 
 
-def get_index():
+def get_index(params=None):
     logger.info("Connecting vector store...")
     store = get_vector_store()
     # Load the index from the vector store
diff --git a/templates/types/streaming/nextjs/app/api/chat/config/llamacloud/route.ts b/templates/types/streaming/nextjs/app/api/chat/config/llamacloud/route.ts
new file mode 100644
index 000000000..e40409f17
--- /dev/null
+++ b/templates/types/streaming/nextjs/app/api/chat/config/llamacloud/route.ts
@@ -0,0 +1,16 @@
+import { NextResponse } from "next/server";
+import { LLamaCloudFileService } from "../../llamaindex/streaming/service";
+
+/**
+ * This API exposes LlamaCloud config from backend environment variables to the frontend
+ */
+export async function GET() {
+  const config = {
+    projects: await LLamaCloudFileService.getAllProjectsWithPipelines(),
+    pipeline: {
+      pipeline: process.env.LLAMA_CLOUD_INDEX_NAME,
+      project: process.env.LLAMA_CLOUD_PROJECT_NAME,
+    },
+  };
+  return NextResponse.json(config, { status: 200 });
+}
diff --git a/templates/types/streaming/nextjs/app/api/chat/route.ts b/templates/types/streaming/nextjs/app/api/chat/route.ts
index 792ecb7b8..adfccf13b 100644
--- a/templates/types/streaming/nextjs/app/api/chat/route.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/route.ts
@@ -27,7 +27,7 @@ export async function POST(request: NextRequest) {
 
   try {
     const body = await request.json();
-    const { messages }: { messages: Message[] } = body;
+    const { messages, data }: { messages: Message[]; data?: any } = body;
     const userMessage = messages.pop();
     if (!messages || !userMessage || userMessage.role !== "user") {
       return NextResponse.json(
@@ -59,7 +59,7 @@ export async function POST(request: NextRequest) {
       },
     );
     const ids = retrieveDocumentIds(allAnnotations);
-    const chatEngine = await createChatEngine(ids);
+    const chatEngine = await createChatEngine(ids, data);
 
     // Convert message content from Vercel/AI format to LlamaIndex/OpenAI format
     const userMessageContent = convertMessageContent(
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
index 01c7c0b1b..4c582966a 100644
--- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-input.tsx
@@ -1,4 +1,5 @@
 import { JSONValue } from "ai";
+import { useState } from "react";
 import { Button } from "../button";
 import { DocumentPreview } from "../document-preview";
 import FileUploader from "../file-uploader";
@@ -6,6 +7,7 @@ import { Input } from "../input";
 import UploadImagePreview from "../upload-image-preview";
 import { ChatHandler } from "./chat.interface";
 import { useFile } from "./hooks/use-file";
+import { LlamaCloudSelector } from "./widgets/LlamaCloudSelector";
 
 const ALLOWED_EXTENSIONS = ["png", "jpg", "jpeg", "csv", "pdf", "txt", "docx"];
 
@@ -34,6 +36,7 @@ export default function ChatInput(
     reset,
     getAnnotations,
   } = useFile();
+  const [requestData, setRequestData] = useState<any>();
 
   // default submit function does not handle including annotations in the message
   // so we need to use append function to submit new message with annotations
@@ -42,12 +45,15 @@ export default function ChatInput(
     annotations: JSONValue[] | undefined,
   ) => {
     e.preventDefault();
-    props.append!({
-      content: props.input,
-      role: "user",
-      createdAt: new Date(),
-      annotations,
-    });
+    props.append!(
+      {
+        content: props.input,
+        role: "user",
+        createdAt: new Date(),
+        annotations,
+      },
+      { data: requestData },
+    );
     props.setInput!("");
   };
 
@@ -57,7 +63,7 @@ export default function ChatInput(
       handleSubmitWithAnnotations(e, annotations);
       return reset();
     }
-    props.handleSubmit(e);
+    props.handleSubmit(e, { data: requestData });
   };
 
   const handleUploadFile = async (file: File) => {
@@ -109,6 +115,9 @@ export default function ChatInput(
           disabled: props.isLoading,
         }}
       />
+      {process.env.NEXT_PUBLIC_USE_LLAMACLOUD === "true" && (
+        <LlamaCloudSelector setRequestData={setRequestData} />
+      )}
       <Button type="submit" disabled={props.isLoading || !props.input.trim()}>
         Send message
       </Button>
diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/widgets/LlamaCloudSelector.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/widgets/LlamaCloudSelector.tsx
new file mode 100644
index 000000000..aa995c91f
--- /dev/null
+++ b/templates/types/streaming/nextjs/app/components/ui/chat/widgets/LlamaCloudSelector.tsx
@@ -0,0 +1,151 @@
+import { Loader2 } from "lucide-react";
+import { useEffect, useState } from "react";
+import {
+  Select,
+  SelectContent,
+  SelectGroup,
+  SelectItem,
+  SelectLabel,
+  SelectTrigger,
+  SelectValue,
+} from "../../select";
+import { useClientConfig } from "../hooks/use-config";
+
+type LLamaCloudPipeline = {
+  id: string;
+  name: string;
+};
+
+type LLamaCloudProject = {
+  id: string;
+  organization_id: string;
+  name: string;
+  is_default: boolean;
+  pipelines: Array<LLamaCloudPipeline>;
+};
+
+type PipelineConfig = {
+  project: string; // project name
+  pipeline: string; // pipeline name
+};
+
+type LlamaCloudConfig = {
+  projects?: LLamaCloudProject[];
+  pipeline?: PipelineConfig;
+};
+
+export interface LlamaCloudSelectorProps {
+  setRequestData: React.Dispatch<any>;
+}
+
+export function LlamaCloudSelector({
+  setRequestData,
+}: LlamaCloudSelectorProps) {
+  const { backend } = useClientConfig();
+  const [config, setConfig] = useState<LlamaCloudConfig>();
+
+  useEffect(() => {
+    if (process.env.NEXT_PUBLIC_USE_LLAMACLOUD === "true" && !config) {
+      fetch(`${backend}/api/chat/config/llamacloud`)
+        .then((response) => response.json())
+        .then((data) => {
+          setConfig(data);
+          setRequestData({
+            llamaCloudPipeline: data.pipeline,
+          });
+        })
+        .catch((error) => console.error("Error fetching config", error));
+    }
+  }, [backend, config, setRequestData]);
+
+  const setPipeline = (pipelineConfig?: PipelineConfig) => {
+    setConfig((prevConfig: any) => ({
+      ...prevConfig,
+      pipeline: pipelineConfig,
+    }));
+    setRequestData((prevData: any) => {
+      if (!prevData) return { llamaCloudPipeline: pipelineConfig };
+      return {
+        ...prevData,
+        llamaCloudPipeline: pipelineConfig,
+      };
+    });
+  };
+
+  const handlePipelineSelect = async (value: string) => {
+    setPipeline(JSON.parse(value) as PipelineConfig);
+  };
+
+  if (!config) {
+    return (
+      <div className="flex items-center justify-center p-3">
+        <Loader2 className="h-4 w-4 animate-spin" />
+      </div>
+    );
+  }
+  if (!isValid(config)) {
+    return (
+      <p className="text-red-500">
+        Invalid LlamaCloud configuration. Check console logs.
+      </p>
+    );
+  }
+  const { projects, pipeline } = config;
+
+  return (
+    <Select
+      onValueChange={handlePipelineSelect}
+      defaultValue={JSON.stringify(pipeline)}
+    >
+      <SelectTrigger className="w-[200px]">
+        <SelectValue placeholder="Select a pipeline" />
+      </SelectTrigger>
+      <SelectContent>
+        {projects.map((project) => (
+          <SelectGroup key={project.id}>
+            <SelectLabel className="capitalize">
+              Project: {project.name}
+            </SelectLabel>
+            {project.pipelines.map((pipeline) => (
+              <SelectItem
+                key={pipeline.id}
+                value={JSON.stringify({
+                  project: project.name,
+                  pipeline: pipeline.name,
+                })}
+              >
+                <span className="pl-2">{pipeline.name}</span>
+              </SelectItem>
+            ))}
+          </SelectGroup>
+        ))}
+      </SelectContent>
+    </Select>
+  );
+}
+
+function isValid(config: LlamaCloudConfig): boolean {
+  const { projects, pipeline } = config;
+  if (!projects?.length) return false;
+  if (!pipeline) return false;
+  const matchedProject = projects.find(
+    (project: LLamaCloudProject) => project.name === pipeline.project,
+  );
+  if (!matchedProject) {
+    console.error(
+      `LlamaCloud project ${pipeline.project} not found. Check LLAMA_CLOUD_PROJECT_NAME variable`,
+    );
+    return false;
+  }
+  const pipelineExists = matchedProject.pipelines.some(
+    (p) => p.name === pipeline.pipeline,
+  );
+  if (!pipelineExists) {
+    console.error(
+      `LlamaCloud pipeline ${pipeline.pipeline} not found. Check LLAMA_CLOUD_INDEX_NAME variable`,
+    );
+    return false;
+  }
+  return true;
+}
diff --git a/templates/types/streaming/nextjs/app/components/ui/select.tsx b/templates/types/streaming/nextjs/app/components/ui/select.tsx
new file mode 100644
index 000000000..c01b068ba
--- /dev/null
+++ b/templates/types/streaming/nextjs/app/components/ui/select.tsx
@@ -0,0 +1,159 @@
+"use client";
+
+import * as SelectPrimitive from "@radix-ui/react-select";
+import { Check, ChevronDown, ChevronUp } from "lucide-react";
+import * as React from "react";
+import { cn } from "./lib/utils";
+
+const Select = SelectPrimitive.Root;
+
+const SelectGroup = SelectPrimitive.Group;
+
+const SelectValue = SelectPrimitive.Value;
+
+const SelectTrigger = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Trigger>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Trigger>
+>(({ className, children, ...props }, ref) => (
+  <SelectPrimitive.Trigger
+    ref={ref}
+    className={cn(
+      "flex h-10 w-full items-center justify-between rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
+      className,
+    )}
+    {...props}
+  >
+    {children}
+    <SelectPrimitive.Icon asChild>
+      <ChevronDown className="h-4 w-4 opacity-50" />
+    </SelectPrimitive.Icon>
+  </SelectPrimitive.Trigger>
+));
+SelectTrigger.displayName = SelectPrimitive.Trigger.displayName;
+
+const SelectScrollUpButton = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.ScrollUpButton>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollUpButton>
+>(({ className, ...props }, ref) => (
+  <SelectPrimitive.ScrollUpButton
+    ref={ref}
+    className={cn(
+      "flex cursor-default items-center justify-center py-1",
+      className,
+    )}
+    {...props}
+  >
+    <ChevronUp className="h-4 w-4" />
+  </SelectPrimitive.ScrollUpButton>
+));
+SelectScrollUpButton.displayName = SelectPrimitive.ScrollUpButton.displayName;
+
+const SelectScrollDownButton = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.ScrollDownButton>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.ScrollDownButton>
+>(({ className, ...props }, ref) => (
+  <SelectPrimitive.ScrollDownButton
+    ref={ref}
+    className={cn(
+      "flex cursor-default items-center justify-center py-1",
+      className,
+    )}
+    {...props}
+  >
+    <ChevronDown className="h-4 w-4" />
+  </SelectPrimitive.ScrollDownButton>
+));
+SelectScrollDownButton.displayName =
+  SelectPrimitive.ScrollDownButton.displayName;
+
+const SelectContent = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Content>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Content>
+>(({ className, children, position = "popper", ...props }, ref) => (
+  <SelectPrimitive.Portal>
+    <SelectPrimitive.Content
+      ref={ref}
+      className={cn(
+        "relative z-50 max-h-96 min-w-[8rem] overflow-hidden rounded-md border bg-popover text-popover-foreground shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
+        position === "popper" &&
+          "data-[side=bottom]:translate-y-1 data-[side=left]:-translate-x-1 data-[side=right]:translate-x-1 data-[side=top]:-translate-y-1",
+        className,
+      )}
+      position={position}
+      {...props}
+    >
+      <SelectScrollUpButton />
+      <SelectPrimitive.Viewport
+        className={cn(
+          "p-1",
+          position === "popper" &&
+            "h-[var(--radix-select-trigger-height)] w-full min-w-[var(--radix-select-trigger-width)]",
+        )}
+      >
+        {children}
+      </SelectPrimitive.Viewport>
+      <SelectScrollDownButton />
+    </SelectPrimitive.Content>
+  </SelectPrimitive.Portal>
+));
+SelectContent.displayName = SelectPrimitive.Content.displayName;
+
+const SelectLabel = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Label>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Label>
+>(({ className, ...props }, ref) => (
+  <SelectPrimitive.Label
+    ref={ref}
+    className={cn("py-1.5 pl-8 pr-2 text-sm font-semibold", className)}
+    {...props}
+  />
+));
+SelectLabel.displayName = SelectPrimitive.Label.displayName;
+
+const SelectItem = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Item>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Item>
+>(({ className, children, ...props }, ref) => (
+  <SelectPrimitive.Item
+    ref={ref}
+    className={cn(
+      "relative flex w-full cursor-default select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none focus:bg-accent focus:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50",
+      className,
+    )}
+    {...props}
+  >
+    <span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
+      <SelectPrimitive.ItemIndicator>
+        <Check className="h-4 w-4" />
+      </SelectPrimitive.ItemIndicator>
+    </span>
+    <SelectPrimitive.ItemText>{children}</SelectPrimitive.ItemText>
+  </SelectPrimitive.Item>
+));
+SelectItem.displayName = SelectPrimitive.Item.displayName;
+
+const SelectSeparator = React.forwardRef<
+  React.ElementRef<typeof SelectPrimitive.Separator>,
+  React.ComponentPropsWithoutRef<typeof SelectPrimitive.Separator>
+>(({ className, ...props }, ref) => (
+  <SelectPrimitive.Separator
+    ref={ref}
+    className={cn("-mx-1 my-1 h-px bg-muted", className)}
+    {...props}
+  />
+));
+SelectSeparator.displayName = SelectPrimitive.Separator.displayName;
+
+export {
+  Select,
+  SelectContent,
+  SelectGroup,
+  SelectItem,
+  SelectLabel,
+  SelectScrollDownButton,
+  SelectScrollUpButton,
+  SelectSeparator,
+  SelectTrigger,
+  SelectValue,
+};
diff --git a/templates/types/streaming/nextjs/package.json b/templates/types/streaming/nextjs/package.json
index 5d429a417..b0b8bd577 100644
--- a/templates/types/streaming/nextjs/package.json
+++ b/templates/types/streaming/nextjs/package.json
@@ -15,6 +15,7 @@
     "@llamaindex/pdf-viewer": "^1.1.3",
     "@radix-ui/react-collapsible": "^1.0.3",
     "@radix-ui/react-hover-card": "^1.0.7",
+    "@radix-ui/react-select": "^2.1.1",
    "@radix-ui/react-slot": "^1.0.2",
    "ai": "^3.0.21",
    "ajv": "^8.12.0",
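
Usage sketch (not part of the patch): after a pipeline is picked in the LlamaCloudSelector, the chat input sends the extra `data` field shown below via the Vercel AI SDK's append/handleSubmit options; both backends forward it to createChatEngine / get_chat_engine and fall back to LLAMA_CLOUD_PROJECT_NAME / LLAMA_CLOUD_INDEX_NAME when it is absent. The URL and the project/pipeline names here are hypothetical; the payload shape is taken from this patch (ChatData.data / llamaCloudPipeline).

// Hypothetical request reproducing what the frontend sends after selecting a pipeline.
const body = {
  messages: [{ role: "user", content: "What does my index contain?" }],
  data: {
    llamaCloudPipeline: {
      project: "my-project", // hypothetical project name
      pipeline: "my-pipeline", // hypothetical pipeline name
    },
  },
};

await fetch("http://localhost:3000/api/chat", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify(body),
});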