Skip to content

Commit

Permalink
feat: switch generate provider based on attachment
Browse files Browse the repository at this point in the history
  • Loading branch information
darkskygit committed Apr 9, 2024
1 parent 31a723d commit 76f2196
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
22 changes: 19 additions & 3 deletions packages/backend/server/src/plugins/copilot/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import { Public } from '../../core/auth';
import { CurrentUser } from '../../core/auth/current-user';
import { ProviderService } from './providers';
import { ChatSession, ChatSessionService } from './session';
import { CopilotCapability, CopilotProviderType } from './types';
import { CopilotCapability } from './types';

export interface ChatEvent {
type: 'attachment' | 'message';
Expand All @@ -39,6 +39,21 @@ export class CopilotController {
private readonly provider: ProviderService
) {}

private async hasAttachment(sessionId: string, messageId?: string) {
const session = await this.chatSession.get(sessionId);
if (!session) {
throw new BadRequestException('Session not found');
}

if (messageId) {
const message = await session.getMessageById(messageId);
if (Array.isArray(message.attachments) && message.attachments.length) {
return true;
}
}
return false;
}

private async appendSessionMessage(
sessionId: string,
message?: string,
Expand Down Expand Up @@ -180,8 +195,9 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
const provider = this.provider.getProviderByCapability(
CopilotCapability.ImageToImage,
CopilotProviderType.FAL
(await this.hasAttachment(sessionId, messageId))
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
Expand Down
8 changes: 8 additions & 0 deletions packages/backend/server/src/plugins/copilot/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ export class ChatSession implements AsyncDisposable {
this.state.messages.push(message);
}

async getMessageById(messageId: string) {
const message = await this.messageCache.get(messageId);
if (!message || message.sessionId !== this.state.sessionId) {
throw new Error(`Message not found: ${messageId}`);
}
return message;
}

async pushByMessageId(messageId: string) {
const message = await this.messageCache.get(messageId);
if (!message || message.sessionId !== this.state.sessionId) {
Expand Down

0 comments on commit 76f2196

Please sign in to comment.