Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: copilot unit & e2e test #6649

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/backend/server/src/config/affine.self.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ if (env.R2_OBJECT_STORAGE_ACCOUNT_ID) {

AFFiNE.plugins.use('copilot', {
openai: {},
fal: {},
});
AFFiNE.plugins.use('redis');
AFFiNE.plugins.use('payment', {
Expand Down
251 changes: 131 additions & 120 deletions packages/backend/server/src/plugins/copilot/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ export interface ChatEvent {
data: string;
}

type CheckResult = {
model: string | undefined;
hasAttachment?: boolean;
};

@Controller('/api/copilot')
export class CopilotController {
private readonly logger = new Logger(CopilotController.name);
Expand All @@ -53,17 +58,26 @@ export class CopilotController {
private readonly storage: CopilotStorage
) {}

private async hasAttachment(sessionId: string, messageId: string) {
private async checkRequest(
userId: string,
sessionId: string,
messageId?: string
): Promise<CheckResult> {
await this.chatSession.checkQuota(userId);
const session = await this.chatSession.get(sessionId);
if (!session) {
if (!session || session.config.userId !== userId) {
throw new BadRequestException('Session not found');
}

const message = await session.getMessageById(messageId);
if (Array.isArray(message.attachments) && message.attachments.length) {
return true;
const ret: CheckResult = { model: session.model };

if (messageId) {
const message = await session.getMessageById(messageId);
ret.hasAttachment =
Array.isArray(message.attachments) && !!message.attachments.length;
}
return false;

return ret;
}

private async appendSessionMessage(
Expand Down Expand Up @@ -107,9 +121,7 @@ export class CopilotController {
@Query('messageId') messageId: string,
@Query() params: Record<string, string | string[]>
): Promise<string> {
await this.chatSession.checkQuota(user.id);

const model = await this.chatSession.get(sessionId).then(s => s?.model);
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
Expand Down Expand Up @@ -155,60 +167,58 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
await this.chatSession.checkQuota(user.id);
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}

const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;

return from(
provider.generateTextStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(data => ({ type: 'message' as const, id: messageId, data }))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
session.push({
role: 'assistant',
content: values.join(''),
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
} catch (err) {
return of({
type: 'error' as const,
data: this.handleError(err),
});
}

const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}

const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;

return from(
provider.generateTextStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(data => ({ type: 'message' as const, id: sessionId, data }))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
session.push({
role: 'assistant',
content: values.join(''),
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
}

@Sse('/chat/:sessionId/images')
Expand All @@ -220,75 +230,76 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
await this.chatSession.checkQuota(user.id);
const { model, hasAttachment } = await this.checkRequest(
user.id,
sessionId,
messageId
);
const provider = this.provider.getProviderByCapability(
hasAttachment
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}

const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;

const handleRemoteLink = this.storage.handleRemoteLink.bind(
this.storage,
user.id,
sessionId
);

return from(
provider.generateImagesStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
mergeMap(handleRemoteLink),
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(attachment => ({
type: 'attachment' as const,
id: messageId,
data: attachment,
}))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(attachments => {
session.push({
role: 'assistant',
content: '',
attachments: attachments,
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
} catch (err) {
return of({
type: 'error' as const,
data: this.handleError(err),
});
}

const hasAttachment = await this.hasAttachment(sessionId, messageId);
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
hasAttachment
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}

const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;

const handleRemoteLink = this.storage.handleRemoteLink.bind(
this.storage,
user.id,
sessionId
);

return from(
provider.generateImagesStream(session.finish(params), session.model, {
signal: this.getSignal(req),
user: user.id,
})
).pipe(
mergeMap(handleRemoteLink),
connect(shared$ =>
merge(
// actual chat event stream
shared$.pipe(
map(attachment => ({
type: 'attachment' as const,
id: sessionId,
data: attachment,
}))
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(attachments => {
session.push({
role: 'assistant',
content: '',
attachments: attachments,
createdAt: new Date(),
});
return from(session.save());
}),
switchMap(() => EMPTY)
)
)
),
catchError(err =>
of({
type: 'error' as const,
data: this.handleError(err),
})
)
);
}

@Get('/unsplash/photos')
Expand Down
3 changes: 2 additions & 1 deletion packages/backend/server/src/plugins/copilot/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,12 @@ export class PromptService {
return null;
}

async set(name: string, messages: PromptMessage[]) {
async set(name: string, model: string, messages: PromptMessage[]) {
return await this.db.aiPrompt
.create({
data: {
name,
model,
messages: {
create: messages.map((m, idx) => ({
idx,
Expand Down
4 changes: 4 additions & 0 deletions packages/backend/server/src/plugins/copilot/providers/fal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ export class FalProvider
return !!config.apiKey;
}

get type(): CopilotProviderType {
return FalProvider.type;
}

getCapabilities(): CopilotCapability[] {
return FalProvider.capabilities;
}
Expand Down
10 changes: 7 additions & 3 deletions packages/backend/server/src/plugins/copilot/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
PromptMessage,
} from '../types';

const DEFAULT_DIMENSIONS = 256;
export const DEFAULT_DIMENSIONS = 256;

const SIMPLE_IMAGE_URL_REGEX = /^(https?:\/\/|data:image\/)/;

Expand Down Expand Up @@ -59,6 +59,10 @@ export class OpenAIProvider
return !!config.apiKey;
}

get type(): CopilotProviderType {
return OpenAIProvider.type;
}

getCapabilities(): CopilotCapability[] {
return OpenAIProvider.capabilities;
}
Expand All @@ -67,7 +71,7 @@ export class OpenAIProvider
return this.availableModels.includes(model);
}

private chatToGPTMessage(
protected chatToGPTMessage(
messages: PromptMessage[]
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
// filter redundant fields
Expand All @@ -92,7 +96,7 @@ export class OpenAIProvider
});
}

private checkParams({
protected checkParams({
messages,
embeddings,
model,
Expand Down
4 changes: 3 additions & 1 deletion packages/backend/server/src/plugins/copilot/resolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,9 @@ export class CopilotResolver {
return new TooManyRequestsException('Server is busy');
}
const session = await this.chatSession.get(options.sessionId);
if (!session) return new BadRequestException('Session not found');
if (!session || session.config.userId !== user.id) {
return new BadRequestException('Session not found');
}

if (options.blobs) {
options.attachments = options.attachments || [];
Expand Down