Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: pick copilot provider depend on model #6540

Merged
merged 1 commit into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions packages/backend/server/src/data/migrations/utils/prompts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ export const prompts: Prompt[] = [
model: '110602490-lcm-sd15-i2i',
messages: [],
},
{
name: 'debug:action:fal-sdturbo',
action: 'image',
model: 'fast-turbo-diffusion',
messages: [],
},
{
name: 'Summary',
action: 'Summary',
Expand Down
15 changes: 11 additions & 4 deletions packages/backend/server/src/plugins/copilot/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ export class CopilotController {
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
Expand Down Expand Up @@ -139,8 +141,10 @@ export class CopilotController {
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
CopilotCapability.TextToText
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
Expand Down Expand Up @@ -194,10 +198,13 @@ export class CopilotController {
@Query('messageId') messageId: string | undefined,
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
const hasAttachment = await this.hasAttachment(sessionId, messageId);
const model = await this.chatSession.get(sessionId).then(s => s?.model);
const provider = this.provider.getProviderByCapability(
(await this.hasAttachment(sessionId, messageId))
hasAttachment
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage
: CopilotCapability.TextToImage,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
Expand Down
41 changes: 29 additions & 12 deletions packages/backend/server/src/plugins/copilot/providers/fal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
CopilotCapability,
CopilotImageToImageProvider,
CopilotProviderType,
CopilotTextToImageProvider,
PromptMessage,
} from '../types';

Expand All @@ -12,17 +13,24 @@ export type FalConfig = {
};

export type FalResponse = {
detail: Array<{ msg: string }>;
images: Array<{ url: string }>;
};

export class FalProvider implements CopilotImageToImageProvider {
export class FalProvider
implements CopilotTextToImageProvider, CopilotImageToImageProvider
{
static readonly type = CopilotProviderType.FAL;
static readonly capabilities = [CopilotCapability.ImageToImage];
static readonly capabilities = [
CopilotCapability.TextToImage,
CopilotCapability.ImageToImage,
];

readonly availableModels = [
// text to image
'fast-turbo-diffusion',
// image to image
// https://blog.fal.ai/building-applications-with-real-time-stable-diffusion-apis/
'110602490-lcm-sd15-i2i',
'lcm-sd15-i2i',
];

constructor(private readonly config: FalConfig) {
Expand All @@ -37,6 +45,10 @@ export class FalProvider implements CopilotImageToImageProvider {
return FalProvider.capabilities;
}

isModelAvailable(model: string): boolean {
return this.availableModels.includes(model);
}

// ====== image to image ======
async generateImages(
messages: PromptMessage[],
Expand All @@ -50,21 +62,20 @@ export class FalProvider implements CopilotImageToImageProvider {
if (!this.availableModels.includes(model)) {
throw new Error(`Invalid model: ${model}`);
}
if (!content) {
throw new Error('Prompt is required');
}
if (!Array.isArray(attachments) || !attachments.length) {
throw new Error('Attachments is required');

// prompt attachments require at least one
if (!content && (!Array.isArray(attachments) || !attachments.length)) {
throw new Error('Prompt or Attachments is empty');
}

const data = (await fetch(`https://${model}.gateway.alpha.fal.ai/`, {
const data = (await fetch(`https://fal.run/fal-ai/${model}`, {
method: 'POST',
headers: {
Authorization: `key ${this.config.apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
image_url: attachments[0],
image_url: attachments?.[0],
prompt: content,
sync_mode: true,
seed: 42,
Expand All @@ -73,7 +84,13 @@ export class FalProvider implements CopilotImageToImageProvider {
signal: options.signal,
}).then(res => res.json())) as FalResponse;

return data.images.map(image => image.url);
if (!data.images?.length) {
const error = data.detail?.[0]?.msg;
throw new Error(
error ? `Invalid message: ${error}` : 'No images generated'
);
}
return data.images?.map(image => image.url) || [];
}

async *generateImagesStream(
Expand Down
33 changes: 26 additions & 7 deletions packages/backend/server/src/plugins/copilot/providers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,17 +118,36 @@ export class CopilotProviderService {

getProviderByCapability<C extends CopilotCapability>(
capability: C,
model?: string,
prefer?: CopilotProviderType
): CapabilityToCopilotProvider[C] | null {
const providers = PROVIDER_CAPABILITY_MAP.get(capability);
if (Array.isArray(providers) && providers.length) {
const selectedCapability =
prefer && providers.includes(prefer) ? prefer : providers[0];

const provider = this.getProvider(selectedCapability);
assert(provider.getCapabilities().includes(capability));

return provider as CapabilityToCopilotProvider[C];
let selectedProvider: CopilotProviderType | undefined = prefer;
let currentIndex = -1;

if (!selectedProvider) {
currentIndex = 0;
selectedProvider = providers[currentIndex];
}

while (selectedProvider) {
// find first provider that supports the capability and model
if (providers.includes(selectedProvider)) {
const provider = this.getProvider(selectedProvider);
if (provider.getCapabilities().includes(capability)) {
if (model) {
if (provider.isModelAvailable(model)) {
return provider as CapabilityToCopilotProvider[C];
}
} else {
return provider as CapabilityToCopilotProvider[C];
}
}
}
currentIndex += 1;
selectedProvider = providers[currentIndex];
}
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ export class OpenAIProvider
return OpenAIProvider.capabilities;
}

isModelAvailable(model: string): boolean {
return this.availableModels.includes(model);
}

private chatToGPTMessage(
messages: PromptMessage[]
): OpenAI.Chat.Completions.ChatCompletionMessageParam[] {
Expand Down
1 change: 1 addition & 0 deletions packages/backend/server/src/plugins/copilot/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ export enum CopilotCapability {

export interface CopilotProvider {
getCapabilities(): CopilotCapability[];
isModelAvailable(model: string): boolean;
}

export interface CopilotTextToTextProvider extends CopilotProvider {
Expand Down