Skip to content

Commit

Permalink
test: copilot unit & e2e test (#6649)
Browse files Browse the repository at this point in the history
  • Loading branch information
darkskygit committed Apr 26, 2024
1 parent f015a11 commit 850bbee
Show file tree
Hide file tree
Showing 12 changed files with 1,145 additions and 134 deletions.
1 change: 1 addition & 0 deletions packages/backend/server/src/config/affine.self.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ if (env.R2_OBJECT_STORAGE_ACCOUNT_ID) {

AFFiNE.plugins.use('copilot', {
openai: {},
fal: {},
});
AFFiNE.plugins.use('redis');
AFFiNE.plugins.use('payment', {
Expand Down
251 changes: 131 additions & 120 deletions packages/backend/server/src/plugins/copilot/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ export interface ChatEvent {
data: string;
}

type CheckResult = {
model: string | undefined;
hasAttachment?: boolean;
};

@Controller('/api/copilot')
export class CopilotController {
private readonly logger = new Logger(CopilotController.name);
Expand All @@ -53,17 +58,26 @@ export class CopilotController {
private readonly storage: CopilotStorage
) {}

private async hasAttachment(sessionId: string, messageId: string) {
private async checkRequest(
userId: string,
sessionId: string,
messageId?: string
): Promise<CheckResult> {
await this.chatSession.checkQuota(userId);
const session = await this.chatSession.get(sessionId);
if (!session) {
if (!session || session.config.userId !== userId) {
throw new BadRequestException('Session not found');
}

const message = await session.getMessageById(messageId);
if (Array.isArray(message.attachments) && message.attachments.length) {
return true;
const ret: CheckResult = { model: session.model };

if (messageId) {
const message = await session.getMessageById(messageId);
ret.hasAttachment =
Array.isArray(message.attachments) && !!message.attachments.length;
}
return false;

return ret;
}

private async appendSessionMessage(
Expand Down Expand Up @@ -107,9 +121,7 @@ export class CopilotController {
@Query('messageId') messageId: string,
@Query() params: Record<string, string | string[]>
): Promise<string> {
await this.chatSession.checkQuota(user.id);

const model = await this.chatSession.get(sessionId).then(s => s?.model);
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
Expand Down Expand Up @@ -155,60 +167,58 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
await this.chatSession.checkQuota(user.id);
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}

const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;

return from(
provider.generateTextStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(data => ({ type: 'message' as const, id: messageId, data }))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
session.push({
role: 'assistant',
content: values.join(''),
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
} catch (err) {
return of({
type: 'error' as const,
data: this.handleError(err),
});
}

const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}

const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;

return from(
provider.generateTextStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(data => ({ type: 'message' as const, id: sessionId, data }))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
session.push({
role: 'assistant',
content: values.join(''),
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
}

@Sse('/chat/:sessionId/images')
Expand All @@ -220,75 +230,76 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
await this.chatSession.checkQuota(user.id);
const { model, hasAttachment } = await this.checkRequest(
user.id,
sessionId,
messageId
);
const provider = this.provider.getProviderByCapability(
hasAttachment
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}

const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;

const handleRemoteLink = this.storage.handleRemoteLink.bind(
this.storage,
user.id,
sessionId
);

return from(
provider.generateImagesStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
mergeMap(handleRemoteLink),
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(attachment => ({
type: 'attachment' as const,
id: messageId,
data: attachment,
}))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(attachments => {
session.push({
role: 'assistant',
content: '',
attachments: attachments,
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
} catch (err) {
return of({
type: 'error' as const,
data: this.handleError(err),
});
}

const hasAttachment = await this.hasAttachment(sessionId, messageId);
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
hasAttachment
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}

const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;

const handleRemoteLink = this.storage.handleRemoteLink.bind(
this.storage,
user.id,
sessionId
);

return from(
provider.generateImagesStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
mergeMap(handleRemoteLink),
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(attachment => ({
type: 'attachment' as const,
id: sessionId,
data: attachment,
}))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(attachments => {
session.push({
role: 'assistant',
content: '',
attachments: attachments,
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
}

@Get('/unsplash/photos')
Expand Down
3 changes: 2 additions & 1 deletion packages/backend/server/src/plugins/copilot/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,12 @@ export class PromptService {
return null;
}

async set(name: string, messages: PromptMessage[]) {
async set(name: string, model: string, messages: PromptMessage[]) {
return await this.db.aiPrompt
.create({
data: {
name,
model,
messages: {
create: messages.map((m, idx) => ({
idx,
Expand Down
4 changes: 4 additions & 0 deletions packages/backend/server/src/plugins/copilot/providers/fal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ export class FalProvider
return !!config.apiKey;
}

get type(): CopilotProviderType {
return FalProvider.type;
}

getCapabilities(): CopilotCapability[] {
return FalProvider.capabilities;
}
Expand Down
10 changes: 7 additions & 3 deletions packages/backend/server/src/plugins/copilot/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
PromptMessage,
} from '../types';

const DEFAULT_DIMENSIONS = 256;
export const DEFAULT_DIMENSIONS = 256;

const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;

Expand Down Expand Up @@ -59,6 +59,10 @@ export class OpenAIProvider
return !!config.apiKey;
}

get type(): CopilotProviderType {
return OpenAIProvider.type;
}

getCapabilities(): CopilotCapability[] {
return OpenAIProvider.capabilities;
}
Expand All @@ -67,7 +71,7 @@ export class OpenAIProvider
return this.availableModels.includes(model);
}

private chatToGPTMessage(
protected chatToGPTMessage(
messages: PromptMessage[]
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
// filter redundant fields
Expand All @@ -92,7 +96,7 @@ export class OpenAIProvider
});
}

private checkParams({
protected checkParams({
messages,
embeddings,
model,
Expand Down
4 changes: 3 additions & 1 deletion packages/backend/server/src/plugins/copilot/resolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,9 @@ export class CopilotResolver {
return new TooManyRequestsException('Server is busy');
}
const session = await this.chatSession.get(options.sessionId);
if (!session) return new BadRequestException('Session not found');
if (!session || session.config.userId !== user.id) {
return new BadRequestException('Session not found');
}

if (options.blobs) {
options.attachments = options.attachments || [];
Expand Down

0 comments on commit 850bbee

Please sign in to comment.