Skip to content

Commit 9c8400d

Browse files
authoredAug 15, 2024
Merge pull request #2 from copilot-extensions/include-models-in-funcs
Include list of models in tool call system prompt
2 parents c0bd65e + fe031c7 commit 9c8400d

File tree

5 files changed

+44
-17
lines changed

5 files changed

+44
-17
lines changed
 

‎src/functions.ts

+4-6
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ export interface RunnerResponse {
1111
messages: OpenAI.ChatCompletionMessageParam[];
1212
}
1313

14-
export class Tool {
14+
export abstract class Tool {
1515
modelsAPI: ModelsAPI;
16+
static definition: OpenAI.FunctionDefinition;
1617

1718
constructor(modelsAPI: ModelsAPI) {
1819
this.modelsAPI = modelsAPI;
@@ -24,12 +25,9 @@ export class Tool {
2425
function: this.definition,
2526
};
2627
}
27-
static definition: OpenAI.FunctionDefinition;
2828

29-
async execute(
29+
abstract execute(
3030
messages: OpenAI.ChatCompletionMessageParam[],
3131
args: object
32-
): Promise<RunnerResponse> {
33-
throw new Error("Not implemented");
34-
}
32+
): Promise<RunnerResponse>;
3533
}

‎src/functions/describe-model.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ export class describeModel extends Tool {
1111
model: {
1212
type: "string",
1313
description:
14-
'The model to describe. Looks like "publisher/model-name".',
14+
'The model to describe. Looks like "registry/model-name". For example, `azureml/Phi-3-medium-128k-instruct` or `azure-openai/gpt-4o',
1515
},
1616
},
1717
required: ["model"],

‎src/functions/recommend-model.ts

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import OpenAI from "openai";
22
import { RunnerResponse, defaultModel, Tool } from "../functions";
3+
import { ModelsAPI } from "../models-api";
34

45
export class recommendModel extends Tool {
56
static definition = {

‎src/index.ts

+32-10
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@ import { listModels } from "./functions/list-models";
77
import { RunnerResponse } from "./functions";
88
import { recommendModel } from "./functions/recommend-model";
99
import { ModelsAPI } from "./models-api";
10-
11-
// List of functions that are available to be called
12-
const functions = [listModels, describeModel, executeModel, recommendModel];
13-
1410
const app = express();
1511

1612
app.post("/", verifySignatureMiddleware, express.json(), async (req, res) => {
@@ -21,17 +17,44 @@ app.post("/", verifySignatureMiddleware, express.json(), async (req, res) => {
2117
return;
2218
}
2319

20+
// List of functions that are available to be called
21+
const modelsAPI = new ModelsAPI(apiKey);
22+
const functions = [listModels, describeModel, executeModel, recommendModel];
23+
2424
// Use the Copilot API to determine which function to execute
2525
const capiClient = new OpenAI({
2626
baseURL: "https://api.githubcopilot.com",
2727
apiKey,
2828
});
2929

30+
// Prepend a system message that includes the list of models, so that
31+
// tool calls can better select the right model to use.
32+
const models = await modelsAPI.listModels();
33+
const toolCallMessages = [
34+
{
35+
role: "system",
36+
content: [
37+
"You are an extension of GitHub Copilot, built to interact with GitHub Models.",
38+
"GitHub Models is a language model playground, where you can experiment with different models and see how they respond to your prompts.",
39+
"Here is a list of some of the models available to the user:",
40+
JSON.stringify(
41+
models.map((model) => ({
42+
name: model.name,
43+
publisher: model.publisher,
44+
registry: model.model_registry,
45+
description: model.summary,
46+
}))
47+
),
48+
].join("\n"),
49+
},
50+
...req.body.messages,
51+
].concat(req.body.messages);
52+
3053
console.time("tool-call");
3154
const toolCaller = await capiClient.chat.completions.create({
3255
stream: false,
3356
model: "gpt-4",
34-
messages: req.body.messages,
57+
messages: toolCallMessages,
3558
tool_choice: "auto",
3659
tools: functions.map((f) => f.tool),
3760
});
@@ -63,20 +86,19 @@ app.post("/", verifySignatureMiddleware, express.json(), async (req, res) => {
6386
const args = JSON.parse(functionToCall.arguments);
6487

6588
console.time("function-exec");
66-
const modelsAPI = new ModelsAPI(apiKey);
6789
let functionCallRes: RunnerResponse;
6890
try {
6991
console.log("Executing function", functionToCall.name);
70-
const klass = functions.find(
92+
const funcClass = functions.find(
7193
(f) => f.definition.name === functionToCall.name
7294
);
73-
if (!klass) {
95+
if (!funcClass) {
7496
throw new Error("Unknown function");
7597
}
7698

7799
console.log("\t with args", args);
78-
const inst = new klass(modelsAPI);
79-
functionCallRes = await inst.execute(req.body.messages, args);
100+
const func = new funcClass(modelsAPI);
101+
functionCallRes = await func.execute(req.body.messages, args);
80102
} catch (err) {
81103
console.error(err);
82104
res.status(500).end();

‎src/models-api.ts

+6
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ export type ModelSchemaParameter = {
3434

3535
export class ModelsAPI {
3636
inference: OpenAI;
37+
private _models: Model[] | null = null;
3738

3839
constructor(apiKey: string) {
3940
this.inference = new OpenAI({
@@ -67,6 +68,10 @@ export class ModelsAPI {
6768
}
6869

6970
async listModels(): Promise<Model[]> {
71+
if (this._models) {
72+
return this._models;
73+
}
74+
7075
const modelsRes = await fetch(
7176
"https://modelcatalog.azure-api.net/v1/models"
7277
);
@@ -75,6 +80,7 @@ export class ModelsAPI {
7580
}
7681

7782
const models = (await modelsRes.json()) as Model[];
83+
this._models = models;
7884
return models;
7985
}
8086
}

0 commit comments

Comments
 (0)
Failed to load comments.