diff --git a/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql b/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql index 837d9601ead5a..1c41993c5cbaa 100644 --- a/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql +++ b/packages/backend/server/migrations/20240402100608_ai_prompt_session_metadata/migration.sql @@ -26,6 +26,7 @@ CREATE TABLE "ai_prompts_messages" ( "idx" INTEGER NOT NULL, "role" "AiPromptRole" NOT NULL, "content" TEXT NOT NULL, + "attachments" JSON, "params" JSON, "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP ); @@ -47,6 +48,8 @@ CREATE TABLE "ai_sessions_messages" ( "session_id" VARCHAR(36) NOT NULL, "role" "AiPromptRole" NOT NULL, "content" TEXT NOT NULL, + "attachments" JSON, + "params" JSON, "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, "updated_at" TIMESTAMPTZ(6) NOT NULL, diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index f9f5ae06967e3..920268a1c10ce 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -430,15 +430,16 @@ enum AiPromptRole { } model AiPromptMessage { - promptId Int @map("prompt_id") @db.Integer + promptId Int @map("prompt_id") @db.Integer // if a group of prompts contains multiple sentences, idx specifies the order of each sentence - idx Int @db.Integer + idx Int @db.Integer // system/assistant/user - role AiPromptRole + role AiPromptRole // prompt content - content String @db.Text - params Json? @db.Json - createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + content String @db.Text + attachments Json? @db.Json + params Json? @db.Json + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) prompt AiPrompt @relation(fields: [promptId], references: [id], onDelete: Cascade) @@ -462,12 +463,14 @@ model AiPrompt { } model AiSessionMessage { - id String @id @default(uuid()) @db.VarChar(36) - sessionId String @map("session_id") @db.VarChar(36) - role AiPromptRole - content String @db.Text - createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) - updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6) + id String @id @default(uuid()) @db.VarChar(36) + sessionId String @map("session_id") @db.VarChar(36) + role AiPromptRole + content String @db.Text + attachments Json? @db.Json + params Json? @db.Json + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + updatedAt DateTime @updatedAt @map("updated_at") @db.Timestamptz(6) session AiSession @relation(fields: [sessionId], references: [id], onDelete: Cascade) diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index f2cf446b82e9e..75f9cb31aa1aa 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -24,7 +24,7 @@ import { Public } from '../../core/auth'; import { CurrentUser } from '../../core/auth/current-user'; import { ProviderService } from './providers'; import { ChatSession, ChatSessionService } from './session'; -import { CopilotCapability } from './types'; +import { CopilotCapability, CopilotProviderType } from './types'; export interface ChatEvent { type: 'attachment' | 'message'; @@ -180,7 +180,8 @@ export class CopilotController { @Query() params: Record ): Promise> { const provider = this.provider.getProviderByCapability( - CopilotCapability.TextToImage + CopilotCapability.ImageToImage, + CopilotProviderType.FAL ); if (!provider) { throw new InternalServerErrorException('No provider available'); diff --git a/packages/backend/server/src/plugins/copilot/providers/fal.ts b/packages/backend/server/src/plugins/copilot/providers/fal.ts index 487da3b63bc1f..addb8d8b1489e 100644 --- a/packages/backend/server/src/plugins/copilot/providers/fal.ts +++ b/packages/backend/server/src/plugins/copilot/providers/fal.ts @@ -11,6 +11,10 @@ export type FalConfig = { apiKey: string; }; +export type FalResponse = { + images: Array<{ url: string }>; +}; + export class FalProvider implements CopilotImageToImageProvider { static readonly type = CopilotProviderType.FAL; static readonly capabilities = [CopilotCapability.ImageToImage]; @@ -53,7 +57,7 @@ export class FalProvider implements CopilotImageToImageProvider { throw new Error('Attachments is required'); } - const data = await fetch(`https://${model}.gateway.alpha.fal.ai/`, { + const data = (await fetch(`https://${model}.gateway.alpha.fal.ai/`, { method: 'POST', headers: { Authorization: `key ${this.config.apiKey}`, @@ -67,9 +71,9 @@ export class FalProvider implements CopilotImageToImageProvider { enable_safety_checks: false, }), signal: options.signal, - }).then(res => res.json()); + }).then(res => res.json())) as FalResponse; - return data.images[0].url; + return data.images.map(image => image.url); } async *generateImagesStream(