diff --git a/packages/backend/server/src/config/affine.self.ts b/packages/backend/server/src/config/affine.self.ts index 87adcee24d85..43597d525032 100644 --- a/packages/backend/server/src/config/affine.self.ts +++ b/packages/backend/server/src/config/affine.self.ts @@ -45,6 +45,7 @@ if (env.R2_OBJECT_STORAGE_ACCOUNT_ID) { AFFiNE.plugins.use('copilot', { openai: {}, + fal: {}, }); AFFiNE.plugins.use('redis'); AFFiNE.plugins.use('payment', { diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index f5f63042bec4..0caddcc52910 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -42,6 +42,11 @@ export interface ChatEvent { data: string; } +type CheckResult = { + model: string | undefined; + hasAttachment?: boolean; +}; + @Controller('/api/copilot') export class CopilotController { private readonly logger = new Logger(CopilotController.name); @@ -53,17 +58,26 @@ export class CopilotController { private readonly storage: CopilotStorage ) {} - private async hasAttachment(sessionId: string, messageId: string) { + private async checkRequest( + userId: string, + sessionId: string, + messageId?: string + ): Promise { + await this.chatSession.checkQuota(userId); const session = await this.chatSession.get(sessionId); - if (!session) { + if (!session || session.config.userId !== userId) { throw new BadRequestException('Session not found'); } - const message = await session.getMessageById(messageId); - if (Array.isArray(message.attachments) && message.attachments.length) { - return true; + const ret: CheckResult = { model: session.model }; + + if (messageId) { + const message = await session.getMessageById(messageId); + ret.hasAttachment = + Array.isArray(message.attachments) && !!message.attachments.length; } - return false; + + return ret; } private async appendSessionMessage( @@ -107,9 +121,7 @@ export class CopilotController { @Query('messageId') messageId: string, @Query() params: Record ): Promise { - await this.chatSession.checkQuota(user.id); - - const model = await this.chatSession.get(sessionId).then(s => s?.model); + const { model } = await this.checkRequest(user.id, sessionId); const provider = this.provider.getProviderByCapability( CopilotCapability.TextToText, model @@ -155,60 +167,58 @@ export class CopilotController { @Query() params: Record ): Promise> { try { - await this.chatSession.checkQuota(user.id); + const { model } = await this.checkRequest(user.id, sessionId); + const provider = this.provider.getProviderByCapability( + CopilotCapability.TextToText, + model + ); + if (!provider) { + throw new InternalServerErrorException('No provider available'); + } + + const session = await this.appendSessionMessage(sessionId, messageId); + delete params.messageId; + + return from( + provider.generateTextStream(session.finish(params), session.model, { + signal: this.getSignal(req), + user: user.id, + }) + ).pipe( + connect(shared$ => + merge( + // actual chat event stream + shared$.pipe( + map(data => ({ type: 'message' as const, id: messageId, data })) + ), + // save the generated text to the session + shared$.pipe( + toArray(), + concatMap(values => { + session.push({ + role: 'assistant', + content: values.join(''), + createdAt: new Date(), + }); + return from(session.save()); + }), + switchMap(() => EMPTY) + ) + ) + ), + catchError(err => + of({ + type: 'error' as const, + data: this.handleError(err), + }) + ) + ); } catch (err) { return of({ type: 'error' as const, data: this.handleError(err), }); } - - const model = await this.chatSession.get(sessionId).then(s => s?.model); - const provider = this.provider.getProviderByCapability( - CopilotCapability.TextToText, - model - ); - if (!provider) { - throw new InternalServerErrorException('No provider available'); - } - - const session = await this.appendSessionMessage(sessionId, messageId); - delete params.messageId; - - return from( - provider.generateTextStream(session.finish(params), session.model, { - signal: this.getSignal(req), - user: user.id, - }) - ).pipe( - connect(shared$ => - merge( - // actual chat event stream - shared$.pipe( - map(data => ({ type: 'message' as const, id: sessionId, data })) - ), - // save the generated text to the session - shared$.pipe( - toArray(), - concatMap(values => { - session.push({ - role: 'assistant', - content: values.join(''), - createdAt: new Date(), - }); - return from(session.save()); - }), - switchMap(() => EMPTY) - ) - ) - ), - catchError(err => - of({ - type: 'error' as const, - data: this.handleError(err), - }) - ) - ); } @Sse('/chat/:sessionId/images') @@ -220,75 +230,76 @@ export class CopilotController { @Query() params: Record ): Promise> { try { - await this.chatSession.checkQuota(user.id); + const { model, hasAttachment } = await this.checkRequest( + user.id, + sessionId, + messageId + ); + const provider = this.provider.getProviderByCapability( + hasAttachment + ? CopilotCapability.ImageToImage + : CopilotCapability.TextToImage, + model + ); + if (!provider) { + throw new InternalServerErrorException('No provider available'); + } + + const session = await this.appendSessionMessage(sessionId, messageId); + delete params.messageId; + + const handleRemoteLink = this.storage.handleRemoteLink.bind( + this.storage, + user.id, + sessionId + ); + + return from( + provider.generateImagesStream(session.finish(params), session.model, { + signal: this.getSignal(req), + user: user.id, + }) + ).pipe( + mergeMap(handleRemoteLink), + connect(shared$ => + merge( + // actual chat event stream + shared$.pipe( + map(attachment => ({ + type: 'attachment' as const, + id: messageId, + data: attachment, + })) + ), + // save the generated text to the session + shared$.pipe( + toArray(), + concatMap(attachments => { + session.push({ + role: 'assistant', + content: '', + attachments: attachments, + createdAt: new Date(), + }); + return from(session.save()); + }), + switchMap(() => EMPTY) + ) + ) + ), + catchError(err => + of({ + type: 'error' as const, + data: this.handleError(err), + }) + ) + ); } catch (err) { return of({ type: 'error' as const, data: this.handleError(err), }); } - - const hasAttachment = await this.hasAttachment(sessionId, messageId); - const model = await this.chatSession.get(sessionId).then(s => s?.model); - const provider = this.provider.getProviderByCapability( - hasAttachment - ? CopilotCapability.ImageToImage - : CopilotCapability.TextToImage, - model - ); - if (!provider) { - throw new InternalServerErrorException('No provider available'); - } - - const session = await this.appendSessionMessage(sessionId, messageId); - delete params.messageId; - - const handleRemoteLink = this.storage.handleRemoteLink.bind( - this.storage, - user.id, - sessionId - ); - - return from( - provider.generateImagesStream(session.finish(params), session.model, { - signal: this.getSignal(req), - user: user.id, - }) - ).pipe( - mergeMap(handleRemoteLink), - connect(shared$ => - merge( - // actual chat event stream - shared$.pipe( - map(attachment => ({ - type: 'attachment' as const, - id: sessionId, - data: attachment, - })) - ), - // save the generated text to the session - shared$.pipe( - toArray(), - concatMap(attachments => { - session.push({ - role: 'assistant', - content: '', - attachments: attachments, - createdAt: new Date(), - }); - return from(session.save()); - }), - switchMap(() => EMPTY) - ) - ) - ), - catchError(err => - of({ - type: 'error' as const, - data: this.handleError(err), - }) - ) - ); } @Get('/unsplash/photos') diff --git a/packages/backend/server/src/plugins/copilot/prompt.ts b/packages/backend/server/src/plugins/copilot/prompt.ts index 74c51127e897..06b9d5eccc34 100644 --- a/packages/backend/server/src/plugins/copilot/prompt.ts +++ b/packages/backend/server/src/plugins/copilot/prompt.ts @@ -193,11 +193,12 @@ export class PromptService { return null; } - async set(name: string, messages: PromptMessage[]) { + async set(name: string, model: string, messages: PromptMessage[]) { return await this.db.aiPrompt .create({ data: { name, + model, messages: { create: messages.map((m, idx) => ({ idx, diff --git a/packages/backend/server/src/plugins/copilot/providers/fal.ts b/packages/backend/server/src/plugins/copilot/providers/fal.ts index 67a4dab869f9..b6b1731b7d07 100644 --- a/packages/backend/server/src/plugins/copilot/providers/fal.ts +++ b/packages/backend/server/src/plugins/copilot/providers/fal.ts @@ -41,6 +41,10 @@ export class FalProvider return !!config.apiKey; } + get type(): CopilotProviderType { + return FalProvider.type; + } + getCapabilities(): CopilotCapability[] { return FalProvider.capabilities; } diff --git a/packages/backend/server/src/plugins/copilot/providers/openai.ts b/packages/backend/server/src/plugins/copilot/providers/openai.ts index c522120df7c8..21ef0eea5ebd 100644 --- a/packages/backend/server/src/plugins/copilot/providers/openai.ts +++ b/packages/backend/server/src/plugins/copilot/providers/openai.ts @@ -13,7 +13,7 @@ import { PromptMessage, } from '../types'; -const DEFAULT_DIMENSIONS = 256; +export const DEFAULT_DIMENSIONS = 256; const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/; @@ -59,6 +59,10 @@ export class OpenAIProvider return !!config.apiKey; } + get type(): CopilotProviderType { + return OpenAIProvider.type; + } + getCapabilities(): CopilotCapability[] { return OpenAIProvider.capabilities; } @@ -67,7 +71,7 @@ export class OpenAIProvider return this.availableModels.includes(model); } - private chatToGPTMessage( + protected chatToGPTMessage( messages: PromptMessage[] ): OpenAI.Chat.Completions.ChatCompletionMessageParam[] { // filter redundant fields @@ -92,7 +96,7 @@ export class OpenAIProvider }); } - private checkParams({ + protected checkParams({ messages, embeddings, model, diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts index 5670fd728433..e003ed45a2d4 100644 --- a/packages/backend/server/src/plugins/copilot/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/resolver.ts @@ -278,7 +278,9 @@ export class CopilotResolver { return new TooManyRequestsException('Server is busy'); } const session = await this.chatSession.get(options.sessionId); - if (!session) return new BadRequestException('Session not found'); + if (!session || session.config.userId !== user.id) { + return new BadRequestException('Session not found'); + } if (options.blobs) { options.attachments = options.attachments || []; diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index c7e28c22c14c..971ffcf9b383 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -81,7 +81,7 @@ export class ChatSession implements AsyncDisposable { } pop() { - this.state.messages.pop(); + return this.state.messages.pop(); } private takeMessages(): ChatMessage[] { @@ -115,7 +115,7 @@ export class ChatSession implements AsyncDisposable { Object.keys(params).length ? params : messages[0]?.params || {}, this.config.sessionId ), - ...messages.filter(m => m.content || m.attachments?.length), + ...messages.filter(m => m.content?.trim() || m.attachments?.length), ]; } diff --git a/packages/backend/server/src/plugins/copilot/types.ts b/packages/backend/server/src/plugins/copilot/types.ts index 805f0b338707..2e707f96c70f 100644 --- a/packages/backend/server/src/plugins/copilot/types.ts +++ b/packages/backend/server/src/plugins/copilot/types.ts @@ -15,6 +15,7 @@ export interface CopilotConfig { openai: OpenAIClientOptions; fal: FalConfig; unsplashKey: string; + test: never; } export enum AvailableModels { @@ -130,6 +131,8 @@ export type ListHistoriesOptions = { export enum CopilotProviderType { FAL = 'fal', OpenAI = 'openai', + // only for test + Test = 'test', } export enum CopilotCapability { @@ -141,6 +144,7 @@ export enum CopilotCapability { } export interface CopilotProvider { + readonly type: CopilotProviderType; getCapabilities(): CopilotCapability[]; isModelAvailable(model: string): boolean; } diff --git a/packages/backend/server/tests/copilot.e2e.ts b/packages/backend/server/tests/copilot.e2e.ts new file mode 100644 index 000000000000..cd1d9e5437fb --- /dev/null +++ b/packages/backend/server/tests/copilot.e2e.ts @@ -0,0 +1,382 @@ +/// + +import { randomUUID } from 'node:crypto'; + +import { INestApplication } from '@nestjs/common'; +import type { TestFn } from 'ava'; +import ava from 'ava'; +import Sinon from 'sinon'; + +import { AuthService } from '../src/core/auth'; +import { WorkspaceModule } from '../src/core/workspaces'; +import { ConfigModule } from '../src/fundamentals/config'; +import { CopilotModule } from '../src/plugins/copilot'; +import { PromptService } from '../src/plugins/copilot/prompt'; +import { + CopilotProviderService, + registerCopilotProvider, +} from '../src/plugins/copilot/providers'; +import { CopilotStorage } from '../src/plugins/copilot/storage'; +import { + acceptInviteById, + createTestingApp, + createWorkspace, + inviteUser, + signUp, +} from './utils'; +import { + chatWithImages, + chatWithText, + chatWithTextStream, + createCopilotMessage, + createCopilotSession, + getHistories, + MockCopilotTestProvider, + textToEventStream, +} from './utils/copilot'; + +const test = ava as TestFn<{ + auth: AuthService; + app: INestApplication; + prompt: PromptService; + provider: CopilotProviderService; + storage: CopilotStorage; +}>; + +test.beforeEach(async t => { + const { app } = await createTestingApp({ + imports: [ + ConfigModule.forRoot({ + plugins: { + copilot: { + openai: { + apiKey: '1', + }, + fal: { + apiKey: '1', + }, + }, + }, + }), + WorkspaceModule, + CopilotModule, + ], + }); + + const auth = app.get(AuthService); + const prompt = app.get(PromptService); + const storage = app.get(CopilotStorage); + + t.context.app = app; + t.context.auth = auth; + t.context.prompt = prompt; + t.context.storage = storage; +}); + +let token: string; +const promptName = 'prompt'; +test.beforeEach(async t => { + const { app, prompt } = t.context; + const user = await signUp(app, 'test', 'darksky@affine.pro', '123456'); + token = user.token.token; + + registerCopilotProvider(MockCopilotTestProvider); + + await prompt.set(promptName, 'test', [ + { role: 'system', content: 'hello {{word}}' }, + ]); +}); + +test.afterEach.always(async t => { + await t.context.app.close(); +}); + +// ==================== session ==================== + +test('should create session correctly', async t => { + const { app } = t.context; + + const assertCreateSession = async ( + workspaceId: string, + error: string, + asserter = async (x: any) => { + t.truthy(await x, error); + } + ) => { + await asserter( + createCopilotSession(app, token, workspaceId, randomUUID(), promptName) + ); + }; + + { + const { id } = await createWorkspace(app, token); + await assertCreateSession( + id, + 'should be able to create session with cloud workspace that user can access' + ); + } + + { + await assertCreateSession( + randomUUID(), + 'should be able to create session with local workspace' + ); + } + + { + const { + token: { token }, + } = await signUp(app, 'test', 'test@affine.pro', '123456'); + const { id } = await createWorkspace(app, token); + await assertCreateSession(id, '', async x => { + await t.throwsAsync( + x, + { instanceOf: Error }, + 'should not able to create session with cloud workspace that user cannot access' + ); + }); + + const inviteId = await inviteUser( + app, + token, + id, + 'darksky@affine.pro', + 'Admin' + ); + await acceptInviteById(app, id, inviteId, false); + await assertCreateSession( + id, + 'should able to create session after user have permission' + ); + } +}); + +test('should be able to use test provider', async t => { + const { app } = t.context; + + const { id } = await createWorkspace(app, token); + t.truthy( + await createCopilotSession(app, token, id, randomUUID(), promptName), + 'failed to create session' + ); +}); + +// ==================== message ==================== + +test('should create message correctly', async t => { + const { app } = t.context; + + { + const { id } = await createWorkspace(app, token); + const sessionId = await createCopilotSession( + app, + token, + id, + randomUUID(), + promptName + ); + const messageId = await createCopilotMessage(app, token, sessionId); + t.truthy(messageId, 'should be able to create message with valid session'); + } + + { + await t.throwsAsync( + createCopilotMessage(app, token, randomUUID()), + { instanceOf: Error }, + 'should not able to create message with invalid session' + ); + } +}); + +// ==================== chat ==================== + +test('should be able to chat with api', async t => { + const { app, storage } = t.context; + + Sinon.stub(storage, 'handleRemoteLink').resolvesArg(2); + + const { id } = await createWorkspace(app, token); + const sessionId = await createCopilotSession( + app, + token, + id, + randomUUID(), + promptName + ); + const messageId = await createCopilotMessage(app, token, sessionId); + const ret = await chatWithText(app, token, sessionId, messageId); + t.is(ret, 'generate text to text', 'should be able to chat with text'); + + const ret2 = await chatWithTextStream(app, token, sessionId, messageId); + t.is( + ret2, + textToEventStream('generate text to text stream', messageId), + 'should be able to chat with text stream' + ); + + const ret3 = await chatWithImages(app, token, sessionId, messageId); + t.is( + ret3, + textToEventStream( + ['https://example.com/image.jpg'], + messageId, + 'attachment' + ), + 'should be able to chat with images' + ); + + Sinon.restore(); +}); + +test('should reject message from different session', async t => { + const { app } = t.context; + + const { id } = await createWorkspace(app, token); + const sessionId = await createCopilotSession( + app, + token, + id, + randomUUID(), + promptName + ); + const anotherSessionId = await createCopilotSession( + app, + token, + id, + randomUUID(), + promptName + ); + const anotherMessageId = await createCopilotMessage( + app, + token, + anotherSessionId + ); + await t.throwsAsync( + chatWithText(app, token, sessionId, anotherMessageId), + { instanceOf: Error }, + 'should reject message from different session' + ); +}); + +test('should reject request from different user', async t => { + const { app } = t.context; + + const { id } = await createWorkspace(app, token); + const sessionId = await createCopilotSession( + app, + token, + id, + randomUUID(), + promptName + ); + + // should reject message from different user + { + const { token } = await signUp(app, 'a1', 'a1@affine.pro', '123456'); + await t.throwsAsync( + createCopilotMessage(app, token.token, sessionId), + { instanceOf: Error }, + 'should reject message from different user' + ); + } + + // should reject chat from different user + { + const messageId = await createCopilotMessage(app, token, sessionId); + { + const { token } = await signUp(app, 'a2', 'a2@affine.pro', '123456'); + await t.throwsAsync( + chatWithText(app, token.token, sessionId, messageId), + { instanceOf: Error }, + 'should reject chat from different user' + ); + } + } +}); + +// ==================== history ==================== + +test('should be able to list history', async t => { + const { app } = t.context; + + const { id: workspaceId } = await createWorkspace(app, token); + const sessionId = await createCopilotSession( + app, + token, + workspaceId, + randomUUID(), + promptName + ); + + const messageId = await createCopilotMessage(app, token, sessionId); + await chatWithText(app, token, sessionId, messageId); + + const histories = await getHistories(app, token, { workspaceId }); + t.deepEqual( + histories.map(h => h.messages.map(m => m.content)), + [['generate text to text']], + 'should be able to list history' + ); +}); + +test('should reject request that user have not permission', async t => { + const { app } = t.context; + + const { + token: { token: anotherToken }, + } = await signUp(app, 'a1', 'a1@affine.pro', '123456'); + const { id: workspaceId } = await createWorkspace(app, anotherToken); + + // should reject request that user have not permission + { + await t.throwsAsync( + getHistories(app, token, { workspaceId }), + { instanceOf: Error }, + 'should reject request that user have not permission' + ); + } + + // should able to list history after user have permission + { + const inviteId = await inviteUser( + app, + anotherToken, + workspaceId, + 'darksky@affine.pro', + 'Admin' + ); + await acceptInviteById(app, workspaceId, inviteId, false); + + t.deepEqual( + await getHistories(app, token, { workspaceId }), + [], + 'should able to list history after user have permission' + ); + } + + { + const sessionId = await createCopilotSession( + app, + anotherToken, + workspaceId, + randomUUID(), + promptName + ); + + const messageId = await createCopilotMessage(app, anotherToken, sessionId); + await chatWithText(app, anotherToken, sessionId, messageId); + + const histories = await getHistories(app, anotherToken, { workspaceId }); + t.deepEqual( + histories.map(h => h.messages.map(m => m.content)), + [['generate text to text']], + 'should able to list history' + ); + + t.deepEqual( + await getHistories(app, token, { workspaceId }), + [], + 'should not list history created by another user' + ); + } +}); diff --git a/packages/backend/server/tests/copilot.spec.ts b/packages/backend/server/tests/copilot.spec.ts index 75b023ec3d47..40f4bae9eceb 100644 --- a/packages/backend/server/tests/copilot.spec.ts +++ b/packages/backend/server/tests/copilot.spec.ts @@ -5,17 +5,28 @@ import type { TestFn } from 'ava'; import ava from 'ava'; import { AuthService } from '../src/core/auth'; -import { QuotaManagementService, QuotaModule } from '../src/core/quota'; +import { QuotaModule } from '../src/core/quota'; import { ConfigModule } from '../src/fundamentals/config'; import { CopilotModule } from '../src/plugins/copilot'; import { PromptService } from '../src/plugins/copilot/prompt'; +import { + CopilotProviderService, + registerCopilotProvider, +} from '../src/plugins/copilot/providers'; +import { ChatSessionService } from '../src/plugins/copilot/session'; +import { + CopilotCapability, + CopilotProviderType, +} from '../src/plugins/copilot/types'; import { createTestingModule } from './utils'; +import { MockCopilotTestProvider } from './utils/copilot'; const test = ava as TestFn<{ auth: AuthService; - quotaManager: QuotaManagementService; module: TestingModule; prompt: PromptService; + provider: CopilotProviderService; + session: ChatSessionService; }>; test.beforeEach(async t => { @@ -27,6 +38,9 @@ test.beforeEach(async t => { openai: { apiKey: '1', }, + fal: { + apiKey: '1', + }, }, }, }), @@ -35,26 +49,37 @@ test.beforeEach(async t => { ], }); - const quotaManager = module.get(QuotaManagementService); const auth = module.get(AuthService); const prompt = module.get(PromptService); + const provider = module.get(CopilotProviderService); + const session = module.get(ChatSessionService); t.context.module = module; - t.context.quotaManager = quotaManager; t.context.auth = auth; t.context.prompt = prompt; + t.context.provider = provider; + t.context.session = session; }); test.afterEach.always(async t => { await t.context.module.close(); }); +let userId: string; +test.beforeEach(async t => { + const { auth } = t.context; + const user = await auth.signUp('test', 'darksky@affine.pro', '123456'); + userId = user.id; +}); + +// ==================== prompt ==================== + test('should be able to manage prompt', async t => { const { prompt } = t.context; t.is((await prompt.list()).length, 0, 'should have no prompt'); - await prompt.set('test', [ + await prompt.set('test', 'test', [ { role: 'system', content: 'hello' }, { role: 'user', content: 'hello' }, ]); @@ -91,7 +116,7 @@ test('should be able to render prompt', async t => { content: 'hello world', }; - await prompt.set('test', [msg]); + await prompt.set('test', 'test', [msg]); const testPrompt = await prompt.get('test'); t.assert(testPrompt, 'should have prompt'); t.is( @@ -126,7 +151,7 @@ test('should be able to render listed prompt', async t => { links: ['https://affine.pro', 'https://github.com/toeverything/affine'], }; - await prompt.set('test', [msg]); + await prompt.set('test', 'test', [msg]); const testPrompt = await prompt.get('test'); t.is( @@ -135,3 +160,265 @@ test('should be able to render listed prompt', async t => { 'should render the prompt' ); }); + +// ==================== session ==================== + +test('should be able to manage chat session', async t => { + const { prompt, session } = t.context; + + await prompt.set('prompt', 'model', [ + { role: 'system', content: 'hello {{word}}' }, + ]); + + const sessionId = await session.create({ + docId: 'test', + workspaceId: 'test', + userId, + promptName: 'prompt', + }); + t.truthy(sessionId, 'should create session'); + + const s = (await session.get(sessionId))!; + t.is(s.config.sessionId, sessionId, 'should get session'); + t.is(s.config.promptName, 'prompt', 'should have prompt name'); + t.is(s.model, 'model', 'should have model'); + + const params = { word: 'world' }; + + s.push({ role: 'user', content: 'hello', createdAt: new Date() }); + // @ts-expect-error + const finalMessages = s.finish(params).map(({ createdAt: _, ...m }) => m); + t.deepEqual( + finalMessages, + [ + { content: 'hello world', params, role: 'system' }, + { content: 'hello', role: 'user' }, + ], + 'should generate the final message' + ); + await s.save(); + + const s1 = (await session.get(sessionId))!; + t.deepEqual( + // @ts-expect-error + s1.finish(params).map(({ createdAt: _, ...m }) => m), + finalMessages, + 'should same as before message' + ); + t.deepEqual( + // @ts-expect-error + s1.finish({}).map(({ createdAt: _, ...m }) => m), + [ + { content: 'hello ', params: {}, role: 'system' }, + { content: 'hello', role: 'user' }, + ], + 'should generate different message with another params' + ); +}); + +test('should be able to process message id', async t => { + const { prompt, session } = t.context; + + await prompt.set('prompt', 'model', [ + { role: 'system', content: 'hello {{word}}' }, + ]); + + const sessionId = await session.create({ + docId: 'test', + workspaceId: 'test', + userId, + promptName: 'prompt', + }); + const s = (await session.get(sessionId))!; + + const textMessage = (await session.createMessage({ + sessionId, + content: 'hello', + }))!; + const anotherSessionMessage = (await session.createMessage({ + sessionId: 'another-session-id', + }))!; + + await t.notThrowsAsync( + s.pushByMessageId(textMessage), + 'should push by message id' + ); + await t.throwsAsync( + s.pushByMessageId(anotherSessionMessage), + { + instanceOf: Error, + }, + 'should throw error if push by another session message id' + ); + await t.throwsAsync( + s.pushByMessageId('invalid'), + { instanceOf: Error }, + 'should throw error if push by invalid message id' + ); +}); + +test('should be able to generate with message id', async t => { + const { prompt, session } = t.context; + + await prompt.set('prompt', 'model', [ + { role: 'system', content: 'hello {{word}}' }, + ]); + + // text message + { + const sessionId = await session.create({ + docId: 'test', + workspaceId: 'test', + userId, + promptName: 'prompt', + }); + const s = (await session.get(sessionId))!; + + const message = (await session.createMessage({ + sessionId, + content: 'hello', + }))!; + + await s.pushByMessageId(message); + const finalMessages = s + .finish({ word: 'world' }) + .map(({ content }) => content); + t.deepEqual(finalMessages, ['hello world', 'hello']); + } + + // attachment message + { + const sessionId = await session.create({ + docId: 'test', + workspaceId: 'test', + userId, + promptName: 'prompt', + }); + const s = (await session.get(sessionId))!; + + const message = (await session.createMessage({ + sessionId, + attachments: ['https://affine.pro/example.jpg'], + }))!; + + await s.pushByMessageId(message); + const finalMessages = s + .finish({ word: 'world' }) + .map(({ attachments }) => attachments); + t.deepEqual(finalMessages, [ + // system prompt + undefined, + // user prompt + ['https://affine.pro/example.jpg'], + ]); + } + + // empty message + { + const sessionId = await session.create({ + docId: 'test', + workspaceId: 'test', + userId, + promptName: 'prompt', + }); + const s = (await session.get(sessionId))!; + + const message = (await session.createMessage({ + sessionId, + }))!; + + await s.pushByMessageId(message); + const finalMessages = s + .finish({ word: 'world' }) + .map(({ content }) => content); + // empty message should be filtered + t.deepEqual(finalMessages, ['hello world']); + } +}); + +// ==================== provider ==================== + +test('should be able to get provider', async t => { + const { provider } = t.context; + + { + const p = provider.getProviderByCapability(CopilotCapability.TextToText); + t.is( + p?.type.toString(), + 'openai', + 'should get provider support text-to-text' + ); + } + + { + const p = provider.getProviderByCapability( + CopilotCapability.TextToEmbedding + ); + t.is( + p?.type.toString(), + 'openai', + 'should get provider support text-to-embedding' + ); + } + + { + const p = provider.getProviderByCapability(CopilotCapability.TextToImage); + t.is( + p?.type.toString(), + 'fal', + 'should get provider support text-to-image' + ); + } + + { + const p = provider.getProviderByCapability(CopilotCapability.ImageToImage); + t.is( + p?.type.toString(), + 'fal', + 'should get provider support image-to-image' + ); + } + + { + const p = provider.getProviderByCapability(CopilotCapability.ImageToText); + t.is( + p?.type.toString(), + 'openai', + 'should get provider support image-to-text' + ); + } + + // text-to-image use fal by default, but this case can use + // model dall-e-3 to select openai provider + { + const p = provider.getProviderByCapability( + CopilotCapability.TextToImage, + 'dall-e-3' + ); + t.is( + p?.type.toString(), + 'openai', + 'should get provider support text-to-image and model' + ); + } +}); + +test('should be able to register test provider', async t => { + const { provider } = t.context; + registerCopilotProvider(MockCopilotTestProvider); + + const assertProvider = (cap: CopilotCapability) => { + const p = provider.getProviderByCapability(cap, 'test'); + t.is( + p?.type, + CopilotProviderType.Test, + `should get test provider with ${cap}` + ); + }; + + assertProvider(CopilotCapability.TextToText); + assertProvider(CopilotCapability.TextToEmbedding); + assertProvider(CopilotCapability.TextToImage); + assertProvider(CopilotCapability.ImageToImage); + assertProvider(CopilotCapability.ImageToText); +}); diff --git a/packages/backend/server/tests/utils/copilot.ts b/packages/backend/server/tests/utils/copilot.ts new file mode 100644 index 000000000000..18df53783dd5 --- /dev/null +++ b/packages/backend/server/tests/utils/copilot.ts @@ -0,0 +1,305 @@ +import { randomBytes } from 'node:crypto'; + +import { INestApplication } from '@nestjs/common'; +import request from 'supertest'; + +import { + DEFAULT_DIMENSIONS, + OpenAIProvider, +} from '../../src/plugins/copilot/providers/openai'; +import { + CopilotCapability, + CopilotImageToImageProvider, + CopilotImageToTextProvider, + CopilotProviderType, + CopilotTextToEmbeddingProvider, + CopilotTextToImageProvider, + CopilotTextToTextProvider, + PromptMessage, +} from '../../src/plugins/copilot/types'; +import { gql } from './common'; +import { handleGraphQLError } from './utils'; + +export class MockCopilotTestProvider + extends OpenAIProvider + implements + CopilotTextToTextProvider, + CopilotTextToEmbeddingProvider, + CopilotTextToImageProvider, + CopilotImageToImageProvider, + CopilotImageToTextProvider +{ + override readonly availableModels = ['test']; + static override readonly capabilities = [ + CopilotCapability.TextToText, + CopilotCapability.TextToEmbedding, + CopilotCapability.TextToImage, + CopilotCapability.ImageToImage, + CopilotCapability.ImageToText, + ]; + + override get type(): CopilotProviderType { + return CopilotProviderType.Test; + } + + override getCapabilities(): CopilotCapability[] { + return MockCopilotTestProvider.capabilities; + } + + override isModelAvailable(model: string): boolean { + return this.availableModels.includes(model); + } + + // ====== text to text ====== + + override async generateText( + messages: PromptMessage[], + model: string = 'test', + _options: { + temperature?: number; + maxTokens?: number; + signal?: AbortSignal; + user?: string; + } = {} + ): Promise { + this.checkParams({ messages, model }); + return 'generate text to text'; + } + + override async *generateTextStream( + messages: PromptMessage[], + model: string = 'gpt-3.5-turbo', + options: { + temperature?: number; + maxTokens?: number; + signal?: AbortSignal; + user?: string; + } = {} + ): AsyncIterable { + this.checkParams({ messages, model }); + + const result = 'generate text to text stream'; + for await (const message of result) { + yield message; + if (options.signal?.aborted) { + break; + } + } + } + + // ====== text to embedding ====== + + override async generateEmbedding( + messages: string | string[], + model: string, + options: { + dimensions: number; + signal?: AbortSignal; + user?: string; + } = { dimensions: DEFAULT_DIMENSIONS } + ): Promise { + messages = Array.isArray(messages) ? messages : [messages]; + this.checkParams({ embeddings: messages, model }); + + return [Array.from(randomBytes(options.dimensions)).map(v => v % 128)]; + } + + // ====== text to image ====== + override async generateImages( + messages: PromptMessage[], + _model: string = 'test', + _options: { + signal?: AbortSignal; + user?: string; + } = {} + ): Promise> { + const { content: prompt } = messages.pop() || {}; + if (!prompt) { + throw new Error('Prompt is required'); + } + + return ['https://example.com/image.jpg']; + } + + override async *generateImagesStream( + messages: PromptMessage[], + model: string = 'dall-e-3', + options: { + signal?: AbortSignal; + user?: string; + } = {} + ): AsyncIterable { + const ret = await this.generateImages(messages, model, options); + for (const url of ret) { + yield url; + } + } +} + +export async function createCopilotSession( + app: INestApplication, + userToken: string, + workspaceId: string, + docId: string, + promptName: string +): Promise { + const res = await request(app.getHttpServer()) + .post(gql) + .auth(userToken, { type: 'bearer' }) + .set({ 'x-request-id': 'test', 'x-operation-name': 'test' }) + .send({ + query: ` + mutation createCopilotSession($options: CreateChatSessionInput!) { + createCopilotSession(options: $options) + } + `, + variables: { options: { workspaceId, docId, promptName } }, + }) + .expect(200); + + handleGraphQLError(res); + + return res.body.data.createCopilotSession; +} + +export async function createCopilotMessage( + app: INestApplication, + userToken: string, + sessionId: string, + content?: string, + attachments?: string[], + blobs?: ArrayBuffer[], + params?: Record +): Promise { + const res = await request(app.getHttpServer()) + .post(gql) + .auth(userToken, { type: 'bearer' }) + .set({ 'x-request-id': 'test', 'x-operation-name': 'test' }) + .send({ + query: ` + mutation createCopilotMessage($options: CreateChatMessageInput!) { + createCopilotMessage(options: $options) + } + `, + variables: { + options: { sessionId, content, attachments, blobs, params }, + }, + }) + .expect(200); + + handleGraphQLError(res); + + return res.body.data.createCopilotMessage; +} + +export async function chatWithText( + app: INestApplication, + userToken: string, + sessionId: string, + messageId: string, + prefix = '' +): Promise { + const res = await request(app.getHttpServer()) + .get(`/api/copilot/chat/${sessionId}${prefix}?messageId=${messageId}`) + .auth(userToken, { type: 'bearer' }) + .expect(200); + + return res.text; +} + +export async function chatWithTextStream( + app: INestApplication, + userToken: string, + sessionId: string, + messageId: string +) { + return chatWithText(app, userToken, sessionId, messageId, '/stream'); +} + +export async function chatWithImages( + app: INestApplication, + userToken: string, + sessionId: string, + messageId: string +) { + return chatWithText(app, userToken, sessionId, messageId, '/images'); +} + +export function textToEventStream( + content: string | string[], + id: string, + event = 'message' +): string { + return ( + Array.from(content) + .map(x => `\nevent: ${event}\nid: ${id}\ndata: ${x}`) + .join('\n') + '\n\n' + ); +} + +type ChatMessage = { + role: string; + content: string; + attachments: string[] | null; + createdAt: string; +}; + +type History = { + sessionId: string; + tokens: number; + action: string | null; + createdAt: string; + messages: ChatMessage[]; +}; + +export async function getHistories( + app: INestApplication, + userToken: string, + variables: { + workspaceId: string; + docId?: string; + options?: { + sessionId?: string; + action?: boolean; + limit?: number; + skip?: number; + }; + } +): Promise { + const res = await request(app.getHttpServer()) + .post(gql) + .auth(userToken, { type: 'bearer' }) + .set({ 'x-request-id': 'test', 'x-operation-name': 'test' }) + .send({ + query: ` + query getCopilotHistories( + $workspaceId: String! + $docId: String + $options: QueryChatHistoriesInput + ) { + currentUser { + copilot(workspaceId: $workspaceId) { + histories(docId: $docId, options: $options) { + sessionId + tokens + action + createdAt + messages { + role + content + attachments + createdAt + } + } + } + } + } + `, + variables, + }) + .expect(200); + + handleGraphQLError(res); + + return res.body.data.currentUser?.copilot?.histories || []; +} diff --git a/packages/backend/server/tests/utils/utils.ts b/packages/backend/server/tests/utils/utils.ts index 88351d2df9b6..49c588b4ec04 100644 --- a/packages/backend/server/tests/utils/utils.ts +++ b/packages/backend/server/tests/utils/utils.ts @@ -5,6 +5,7 @@ import { Test, TestingModuleBuilder } from '@nestjs/testing'; import { PrismaClient } from '@prisma/client'; import cookieParser from 'cookie-parser'; import graphqlUploadExpress from 'graphql-upload/graphqlUploadExpress.mjs'; +import type { Response } from 'supertest'; import { AppModule, FunctionalityModules } from '../../src/app.module'; import { AuthGuard, AuthModule } from '../../src/core/auth'; @@ -136,3 +137,12 @@ export async function createTestingApp(moduleDef: TestingModuleMeatdata = {}) { app, }; } + +export function handleGraphQLError(resp: Response) { + const { errors } = resp.body; + if (errors) { + const cause = errors[0]; + const stacktrace = cause.extensions?.stacktrace; + throw new Error(stacktrace ? stacktrace.join('\n') : cause.message, cause); + } +}