Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow undefined new model #6933

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ jobs:
env:
CARGO_TARGET_DIR: '${{ github.workspace }}/target'
DATABASE_URL: postgresql://affine:affine@localhost:5432/affine
COPILOT_OPENAI_API_KEY: ${{ secrets.COPILOT_OPENAI_API_KEY }}

- name: Upload server test coverage results
uses: codecov/codecov-action@v4
Expand Down
6 changes: 3 additions & 3 deletions packages/backend/server/src/plugins/copilot/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ export class CopilotController {
@Query() params: Record<string, string | string[]>
): Promise<string> {
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
const provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
Expand Down Expand Up @@ -179,7 +179,7 @@ export class CopilotController {
): Promise<Observable<ChatEvent>> {
try {
const { model } = await this.checkRequest(user.id, sessionId);
const provider = this.provider.getProviderByCapability(
const provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
Expand Down Expand Up @@ -246,7 +246,7 @@ export class CopilotController {
sessionId,
messageId
);
const provider = this.provider.getProviderByCapability(
const provider = await this.provider.getProviderByCapability(
hasAttachment
? CopilotCapability.ImageToImage
: CopilotCapability.TextToImage,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export class FalProvider
return FalProvider.capabilities;
}

isModelAvailable(model: string): boolean {
async isModelAvailable(model: string): Promise<boolean> {
return this.availableModels.includes(model);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export function registerCopilotProvider<
const providerConfig = config.plugins.copilot?.[type];
if (!provider.assetsConfig(providerConfig as C)) {
throw new Error(
`Invalid configuration for copilot provider ${type}: ${providerConfig}`
`Invalid configuration for copilot provider ${type}: ${JSON.stringify(providerConfig)}`
);
}
const instance = new provider(providerConfig as C);
Expand Down Expand Up @@ -116,11 +116,11 @@ export class CopilotProviderService {
return this.cachedProviders.get(provider)!;
}

getProviderByCapability<C extends CopilotCapability>(
async getProviderByCapability<C extends CopilotCapability>(
capability: C,
model?: string,
prefer?: CopilotProviderType
): CapabilityToCopilotProvider[C] | null {
): Promise<CapabilityToCopilotProvider[C] | null> {
const providers = PROVIDER_CAPABILITY_MAP.get(capability);
if (Array.isArray(providers) && providers.length) {
let selectedProvider: CopilotProviderType | undefined = prefer;
Expand All @@ -137,7 +137,7 @@ export class CopilotProviderService {
const provider = this.getProvider(selectedProvider);
if (provider.getCapabilities().includes(capability)) {
if (model) {
if (provider.isModelAvailable(model)) {
if (await provider.isModelAvailable(model)) {
return provider as CapabilityToCopilotProvider[C];
}
} else {
Expand Down
19 changes: 17 additions & 2 deletions packages/backend/server/src/plugins/copilot/providers/openai.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import assert from 'node:assert';

import { Logger } from '@nestjs/common';
import { ClientOptions, OpenAI } from 'openai';

import {
Expand Down Expand Up @@ -51,7 +52,9 @@ export class OpenAIProvider
'dall-e-3',
];

private readonly logger = new Logger(OpenAIProvider.type);
private readonly instance: OpenAI;
private existsModels: string[] | undefined;

constructor(config: ClientOptions) {
assert(OpenAIProvider.assetsConfig(config));
Expand All @@ -70,8 +73,20 @@ export class OpenAIProvider
return OpenAIProvider.capabilities;
}

isModelAvailable(model: string): boolean {
return this.availableModels.includes(model);
async isModelAvailable(model: string): Promise<boolean> {
const knownModels = this.availableModels.includes(model);
if (knownModels) return true;

if (!this.existsModels) {
try {
this.existsModels = await this.instance.models
.list()
.then(({ data }) => data.map(m => m.id));
} catch (e) {
this.logger.error('Failed to fetch online model list', e);
}
}
return !!this.existsModels?.includes(model);
}

protected chatToGPTMessage(
Expand Down
2 changes: 1 addition & 1 deletion packages/backend/server/src/plugins/copilot/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ export type CopilotImageOptions = z.infer<typeof CopilotImageOptionsSchema>;
export interface CopilotProvider {
readonly type: CopilotProviderType;
getCapabilities(): CopilotCapability[];
isModelAvailable(model: string): boolean;
isModelAvailable(model: string): Promise<boolean>;
}

export interface CopilotTextToTextProvider extends CopilotProvider {
Expand Down
60 changes: 46 additions & 14 deletions packages/backend/server/tests/copilot.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ test.beforeEach(async t => {
plugins: {
copilot: {
openai: {
apiKey: '1',
apiKey: process.env.COPILOT_OPENAI_API_KEY ?? '1',
},
fal: {
apiKey: '1',
Expand Down Expand Up @@ -368,7 +368,9 @@ test('should be able to get provider', async t => {
const { provider } = t.context;

{
const p = provider.getProviderByCapability(CopilotCapability.TextToText);
const p = await provider.getProviderByCapability(
CopilotCapability.TextToText
);
t.is(
p?.type.toString(),
'openai',
Expand All @@ -377,7 +379,7 @@ test('should be able to get provider', async t => {
}

{
const p = provider.getProviderByCapability(
const p = await provider.getProviderByCapability(
CopilotCapability.TextToEmbedding
);
t.is(
Expand All @@ -388,7 +390,9 @@ test('should be able to get provider', async t => {
}

{
const p = provider.getProviderByCapability(CopilotCapability.TextToImage);
const p = await provider.getProviderByCapability(
CopilotCapability.TextToImage
);
t.is(
p?.type.toString(),
'fal',
Expand All @@ -397,7 +401,9 @@ test('should be able to get provider', async t => {
}

{
const p = provider.getProviderByCapability(CopilotCapability.ImageToImage);
const p = await provider.getProviderByCapability(
CopilotCapability.ImageToImage
);
t.is(
p?.type.toString(),
'fal',
Expand All @@ -406,7 +412,9 @@ test('should be able to get provider', async t => {
}

{
const p = provider.getProviderByCapability(CopilotCapability.ImageToText);
const p = await provider.getProviderByCapability(
CopilotCapability.ImageToText
);
t.is(
p?.type.toString(),
'openai',
Expand All @@ -417,7 +425,7 @@ test('should be able to get provider', async t => {
// text-to-image use fal by default, but this case can use
// model dall-e-3 to select openai provider
{
const p = provider.getProviderByCapability(
const p = await provider.getProviderByCapability(
CopilotCapability.TextToImage,
'dall-e-3'
);
Expand All @@ -427,24 +435,48 @@ test('should be able to get provider', async t => {
'should get provider support text-to-image and model'
);
}

// gpt4o is not defined now, but it already published by openai
// we should check from online api if it is available
{
const p = await provider.getProviderByCapability(
CopilotCapability.ImageToText,
'gpt-4o'
);
t.is(
p?.type.toString(),
'openai',
'should get provider support text-to-image and model'
);
}

// if a model is not defined and not available in online api
// it should return null
{
const p = await provider.getProviderByCapability(
CopilotCapability.ImageToText,
'gpt-4-not-exist'
);
t.falsy(p, 'should not get provider');
}
});

test('should be able to register test provider', async t => {
const { provider } = t.context;
registerCopilotProvider(MockCopilotTestProvider);

const assertProvider = (cap: CopilotCapability) => {
const p = provider.getProviderByCapability(cap, 'test');
const assertProvider = async (cap: CopilotCapability) => {
const p = await provider.getProviderByCapability(cap, 'test');
t.is(
p?.type,
CopilotProviderType.Test,
`should get test provider with ${cap}`
);
};

assertProvider(CopilotCapability.TextToText);
assertProvider(CopilotCapability.TextToEmbedding);
assertProvider(CopilotCapability.TextToImage);
assertProvider(CopilotCapability.ImageToImage);
assertProvider(CopilotCapability.ImageToText);
await assertProvider(CopilotCapability.TextToText);
await assertProvider(CopilotCapability.TextToEmbedding);
await assertProvider(CopilotCapability.TextToImage);
await assertProvider(CopilotCapability.ImageToImage);
await assertProvider(CopilotCapability.ImageToText);
});
2 changes: 1 addition & 1 deletion packages/backend/server/tests/utils/copilot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export class MockCopilotTestProvider
return MockCopilotTestProvider.capabilities;
}

override isModelAvailable(model: string): boolean {
override async isModelAvailable(model: string): Promise<boolean> {
return this.availableModels.includes(model);
}

Expand Down
Loading