Skip to content

Commit

Permalink
feat: extract common session process
Browse files Browse the repository at this point in the history
  • Loading branch information
darkskygit committed Apr 1, 2024
1 parent 06df99f commit 9d8a8fb
Showing 1 changed file with 47 additions and 37 deletions.
84 changes: 47 additions & 37 deletions packages/backend/server/src/plugins/copilot/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import {
import { Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth/current-user';
import { ProviderService } from './providers';
import { ChatSessionService } from './session';
import { ChatSession, ChatSessionService } from './session';
import { CopilotCapability } from './types';

export interface ChatEvent {
Expand All @@ -39,32 +39,52 @@ export class CopilotController {
private readonly provider: ProviderService
) {}

private async appendSessionMessage(
sessionId: string,
message?: string,
messageId?: string
): Promise<ChatSession> {
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}

if (messageId) {
await session.pushByMessageId(messageId);
} else {
if (!message || !message.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(message),
createdAt: new Date(),
});
}
return session;
}

@Public()
@Get('/chat/:sessionId')
async chat(
@CurrentUser() user: CurrentUser | undefined,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') content: string
@Query('message') message: string | undefined,
@Query('messageId') messageId: string | undefined
): Promise<string> {
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}
if (!content || !content.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(content),
createdAt: new Date(),
});

const session = await this.appendSessionMessage(
sessionId,
message,
messageId
);

try {
const content = await provider.generateText(
Expand Down Expand Up @@ -97,7 +117,7 @@ export class CopilotController {
@CurrentUser() user: CurrentUser | undefined,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('message') content: string | undefined,
@Query('message') message: string | undefined,
@Query('messageId') messageId: string | undefined
): Promise<Observable<ChatEvent>> {
const provider = this.provider.getProviderByCapability(
Expand All @@ -106,23 +126,12 @@ export class CopilotController {
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}

if (messageId) {
await session.pushByMessageId(messageId);
} else {
if (!content || !content.trim()) {
throw new BadRequestException('Message is empty');
}
session.push({
role: 'user',
content: decodeURIComponent(content),
createdAt: new Date(),
});
}
const session = await this.appendSessionMessage(
sessionId,
message,
messageId
);

return from(
provider.generateTextStream(session.finish(), session.model, {
Expand Down Expand Up @@ -160,20 +169,21 @@ export class CopilotController {
@CurrentUser() user: CurrentUser | undefined,
@Req() req: Request,
@Param('sessionId') sessionId: string,
@Query('messageId') messageId: string
@Query('message') message: string | undefined,
@Query('messageId') messageId: string | undefined
): Promise<Observable<ChatEvent>> {
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToImage
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}

await session.pushByMessageId(messageId);
const session = await this.appendSessionMessage(
sessionId,
message,
messageId
);

return from(
provider.generateImagesStream(session.finish(), session.model, {
Expand Down

0 comments on commit 9d8a8fb

Please sign in to comment.