From bb4e12c874b20dbe865728910cf44994c305457f Mon Sep 17 00:00:00 2001
From: Gregor Martynus <39992+gr2m@users.noreply.github.com>
Date: Mon, 2 Sep 2024 17:32:22 -0700
Subject: [PATCH] feat: `options.messages` for `prompt()` (#49)

---
 README.md           |  29 +++-
 index.d.ts          |  50 ++++--
 index.test-d.ts     |  38 +++--
 lib/prompt.js       |  43 +++--
 test/prompt.test.js | 371 +++++++++++++++++++++++++++++---------------
 5 files changed, 366 insertions(+), 165 deletions(-)

diff --git a/README.md b/README.md
index e9a6863..df00d0a 100644
--- a/README.md
+++ b/README.md
@@ -321,7 +321,34 @@ const { message } = await prompt("What is the capital of France?", {
 console.log(message.content);
 ```
 
-⚠️ Not all of the arguments below are implemented yet.
+In order to pass a history of messages, pass them as `options.messages`:
+
+```js
+const { message } = await prompt("What about Spain?", {
+  model: "gpt-4",
+  token: process.env.TOKEN,
+  messages: [
+    { role: "user", content: "What is the capital of France?" },
+    { role: "assistant", content: "The capital of France is Paris." },
+  ],
+});
+```
+
+Alternatively, skip the `message` argument and pass all messages as `options.messages`:
+
+```js
+const { message } = await prompt({
+  model: "gpt-4",
+  token: process.env.TOKEN,
+  messages: [
+    { role: "user", content: "What is the capital of France?" },
+    { role: "assistant", content: "The capital of France is Paris." },
+    { role: "user", content: "What about Spain?" },
+  ],
+});
+```
+
+⚠️ Not all of the arguments below are implemented yet. See [#5](https://github.com/copilot-extensions/preview-sdk.js/issues/5) sub issues for progress.
 
 ```js
 await prompt({
diff --git a/index.d.ts b/index.d.ts
index f41811e..84d8edf 100644
--- a/index.d.ts
+++ b/index.d.ts
@@ -75,9 +75,7 @@ type ResponseEvent<T extends ResponseEventType = "text"> =
 
 type CopilotAckResponseEventData = {
   choices: [{
-    delta: {
-      content: "", role: "assistant"
-    }
+    delta: InteropMessage<"assistant">
   }]
 }
 
@@ -92,9 +90,7 @@ type CopilotDoneResponseEventData = {
 
 type CopilotTextResponseEventData = {
   choices: [{
-    delta: {
-      content: string, role: "assistant"
-    }
+    delta: InteropMessage<"assistant">
   }]
 }
 type CopilotConfirmationResponseEventData = {
@@ -134,7 +130,7 @@ interface CopilotReference {
 
 export interface CopilotRequestPayload {
   copilot_thread_id: string
-  messages: Message[]
+  messages: CopilotMessage[]
   stop: any
   top_p: number
   temperature: number
@@ -146,14 +142,10 @@ export interface CopilotRequestPayload {
 }
 
 export interface OpenAICompatibilityPayload {
-  messages: {
-    role: string
-    name?: string
-    content: string
-  }[]
+  messages: InteropMessage[]
 }
 
-export interface Message {
+export interface CopilotMessage {
   role: string
   content: string
   copilot_references: MessageCopilotReference[]
@@ -167,6 +159,14 @@ export interface Message {
     "type": "function"
   }[]
   name?: string
+  [key: string]: unknown
+}
+
+export interface InteropMessage<TRole extends string = string> {
+  role: TRole
+  content: string
+  name?: string
+  [key: string]: unknown
 }
 
 export interface MessageCopilotReference {
@@ -254,10 +254,23 @@ export interface GetUserConfirmationInterface {
 
 // prompt
 
-/** model names supported by Copilot API */
+/** 
+ * model names supported by Copilot API
+ * 
+ * Based on https://api.githubcopilot.com/models from 2024-09-02
+ */
 export type ModelName =
-  | "gpt-4"
   | "gpt-3.5-turbo"
+  | "gpt-3.5-turbo-0613"
+  | "gpt-4"
+  | "gpt-4-0613"
+  | "gpt-4-o-preview"
+  | "gpt-4o"
+  | "gpt-4o-2024-05-13"
+  | "text-embedding-3-small"
+  | "text-embedding-3-small-inference"
+  | "text-embedding-ada-002"
+  | "text-embedding-ada-002-index"
 
 export interface PromptFunction {
   type: "function"
@@ -274,6 +287,7 @@ export type PromptOptions = {
   model: ModelName
   token: string
   tools?: PromptFunction[]
+  messages?: InteropMessage[]
   request?: {
     fetch?: Function
   }
@@ -281,11 +295,15 @@ export type PromptOptions = {
 
 export type PromptResult = {
   requestId: string
-  message: Message
+  message: CopilotMessage
 }
 
+// https://stackoverflow.com/a/69328045
+type WithRequired<T, K extends keyof T> = T & { [P in K]-?: T[P] }
+
 interface PromptInterface {
   (userPrompt: string, options: PromptOptions): Promise<PromptResult>;
+  (options: WithRequired<PromptOptions, "messages">): Promise<PromptResult>;
 }
 
 // exported methods
diff --git a/index.test-d.ts b/index.test-d.ts
index 267d21c..0ddf2ef 100644
--- a/index.test-d.ts
+++ b/index.test-d.ts
@@ -17,6 +17,7 @@ import {
   getUserMessage,
   getUserConfirmation,
   type VerificationPublicKey,
+  type InteropMessage,
   CopilotRequestPayload,
   prompt,
 } from "./index.js";
@@ -79,11 +80,10 @@ export function createAckEventTest() {
   expectType<() => string>(event.toString);
   expectType<string>(event.toString());
 
+
   expectType<{
     choices: [{
-      delta: {
-        content: "", role: "assistant"
-      }
+      delta: InteropMessage<"assistant">
     }]
   }>(event.data);
 
@@ -98,9 +98,7 @@ export function createTextEventTest() {
 
   expectType<{
     choices: [{
-      delta: {
-        content: string, role: "assistant"
-      }
+      delta: InteropMessage<"assistant">
     }]
   }>(event.data);
 
@@ -243,6 +241,7 @@ export function transformPayloadForOpenAICompatibilityTest(payload: CopilotReque
       content: string;
       role: string;
       name?: string
+      [key: string]: unknown
     }[]
   }
   >(result);
@@ -307,12 +306,33 @@ export async function promptWithToolsTest() {
         function: {
           name: "",
           description: "",
-          parameters: {
-
-          },
+          parameters: {},
           strict: true,
         }
       }
     ]
   })
+}
+
+export async function promptWithMessageAndMessages() {
+  await prompt("What about Spain?", {
+    model: "gpt-4",
+    token: 'secret',
+    messages: [
+      { role: "user", content: "What is the capital of France?" },
+      { role: "assistant", content: "The capital of France is Paris." },
+    ],
+  });
+}
+
+export async function promptWithoutMessageButMessages() {
+  await prompt({
+    model: "gpt-4",
+    token: 'secret',
+    messages: [
+      { role: "user", content: "What is the capital of France?" },
+      { role: "assistant", content: "The capital of France is Paris." },
+      { role: "user", content: "What about Spain?" },
+    ],
+  });
 }
\ No newline at end of file
diff --git a/lib/prompt.js b/lib/prompt.js
index 77bbab9..e7c9db0 100644
--- a/lib/prompt.js
+++ b/lib/prompt.js
@@ -2,12 +2,32 @@
 
 /** @type {import('..').PromptInterface} */
 export async function prompt(userPrompt, promptOptions) {
-  const promptFetch = promptOptions.request?.fetch || fetch;
+  const options = typeof userPrompt === "string" ? promptOptions : userPrompt;
 
-  const systemMessage = promptOptions.tools
+  const promptFetch = options.request?.fetch || fetch;
+
+  const systemMessage = options.tools
     ? "You are a helpful assistant. Use the supplied tools to assist the user."
     : "You are a helpful assistant.";
 
+  const messages = [
+    {
+      role: "system",
+      content: systemMessage,
+    },
+  ];
+
+  if (options.messages) {
+    messages.push(...options.messages);
+  }
+
+  if (typeof userPrompt === "string") {
+    messages.push({
+      role: "user",
+      content: userPrompt,
+    });
+  }
+
   const response = await promptFetch(
     "https://api.githubcopilot.com/chat/completions",
     {
@@ -16,22 +36,13 @@ export async function prompt(userPrompt, promptOptions) {
         accept: "application/json",
         "content-type": "application/json; charset=UTF-8",
         "user-agent": "copilot-extensions/preview-sdk.js",
-        authorization: `Bearer ${promptOptions.token}`,
+        authorization: `Bearer ${options.token}`,
       },
       body: JSON.stringify({
-        messages: [
-          {
-            role: "system",
-            content: systemMessage,
-          },
-          {
-            role: "user",
-            content: userPrompt,
-          },
-        ],
-        model: promptOptions.model,
-        toolChoice: promptOptions.tools ? "auto" : undefined,
-        tools: promptOptions.tools,
+        messages: messages,
+        model: options.model,
+        toolChoice: options.tools ? "auto" : undefined,
+        tools: options.tools,
       }),
     }
   );
diff --git a/test/prompt.test.js b/test/prompt.test.js
index fb8d723..2b83e12 100644
--- a/test/prompt.test.js
+++ b/test/prompt.test.js
@@ -1,145 +1,270 @@
-import { test } from "node:test";
+import { test, suite } from "node:test";
 
 import { MockAgent } from "undici";
 
 import { prompt } from "../index.js";
 
-test("smoke", (t) => {
-  t.assert.equal(typeof prompt, "function");
-});
+suite("prompt", () => {
+  test("smoke", (t) => {
+    t.assert.equal(typeof prompt, "function");
+  });
 
-test("minimal usage", async (t) => {
-  const mockAgent = new MockAgent();
-  function fetchMock(url, opts) {
-    opts ||= {};
-    opts.dispatcher = mockAgent;
-    return fetch(url, opts);
-  }
-
-  mockAgent.disableNetConnect();
-  const mockPool = mockAgent.get("https://api.githubcopilot.com");
-  mockPool
-    .intercept({
-      method: "post",
-      path: `/chat/completions`,
-      body: JSON.stringify({
-        messages: [
-          {
-            role: "system",
-            content: "You are a helpful assistant.",
-          },
-          {
-            role: "user",
-            content: "What is the capital of France?",
-          },
-        ],
-        model: "gpt-4",
-      }),
-    })
-    .reply(
-      200,
-      {
-        choices: [
-          {
-            message: {
-              content: "<response text>",
+  test("minimal usage", async (t) => {
+    const mockAgent = new MockAgent();
+    function fetchMock(url, opts) {
+      opts ||= {};
+      opts.dispatcher = mockAgent;
+      return fetch(url, opts);
+    }
+
+    mockAgent.disableNetConnect();
+    const mockPool = mockAgent.get("https://api.githubcopilot.com");
+    mockPool
+      .intercept({
+        method: "post",
+        path: `/chat/completions`,
+        body: JSON.stringify({
+          messages: [
+            {
+              role: "system",
+              content: "You are a helpful assistant.",
             },
+            {
+              role: "user",
+              content: "What is the capital of France?",
+            },
+          ],
+          model: "gpt-4",
+        }),
+      })
+      .reply(
+        200,
+        {
+          choices: [
+            {
+              message: {
+                content: "<response text>",
+              },
+            },
+          ],
+        },
+        {
+          headers: {
+            "content-type": "application/json",
+            "x-request-id": "<request-id>",
           },
-        ],
+        }
+      );
+
+    const result = await prompt("What is the capital of France?", {
+      token: "secret",
+      model: "gpt-4",
+      request: { fetch: fetchMock },
+    });
+
+    t.assert.deepEqual(result, {
+      requestId: "<request-id>",
+      message: {
+        content: "<response text>",
       },
-      {
-        headers: {
-          "content-type": "application/json",
-          "x-request-id": "<request-id>",
+    });
+  });
+
+  test("options.messages", async (t) => {
+    const mockAgent = new MockAgent();
+    function fetchMock(url, opts) {
+      opts ||= {};
+      opts.dispatcher = mockAgent;
+      return fetch(url, opts);
+    }
+
+    mockAgent.disableNetConnect();
+    const mockPool = mockAgent.get("https://api.githubcopilot.com");
+    mockPool
+      .intercept({
+        method: "post",
+        path: `/chat/completions`,
+        body: JSON.stringify({
+          messages: [
+            { role: "system", content: "You are a helpful assistant." },
+            { role: "user", content: "What is the capital of France?" },
+            { role: "assistant", content: "The capital of France is Paris." },
+            { role: "user", content: "What about Spain?" },
+          ],
+          model: "gpt-4",
+        }),
+      })
+      .reply(
+        200,
+        {
+          choices: [
+            {
+              message: {
+                content: "<response text>",
+              },
+            },
+          ],
         },
-      }
-    );
+        {
+          headers: {
+            "content-type": "application/json",
+            "x-request-id": "<request-id>",
+          },
+        }
+      );
 
-  const result = await prompt("What is the capital of France?", {
-    token: "secret",
-    model: "gpt-4",
-    request: { fetch: fetchMock },
-  });
+    const result = await prompt("What about Spain?", {
+      model: "gpt-4",
+      token: "secret",
+      messages: [
+        { role: "user", content: "What is the capital of France?" },
+        { role: "assistant", content: "The capital of France is Paris." },
+      ],
+      request: { fetch: fetchMock },
+    });
 
-  t.assert.deepEqual(result, {
-    requestId: "<request-id>",
-    message: {
-      content: "<response text>",
-    },
+    t.assert.deepEqual(result, {
+      requestId: "<request-id>",
+      message: {
+        content: "<response text>",
+      },
+    });
   });
-});
 
-test("function calling", async (t) => {
-  const mockAgent = new MockAgent();
-  function fetchMock(url, opts) {
-    opts ||= {};
-    opts.dispatcher = mockAgent;
-    return fetch(url, opts);
-  }
-
-  mockAgent.disableNetConnect();
-  const mockPool = mockAgent.get("https://api.githubcopilot.com");
-  mockPool
-    .intercept({
-      method: "post",
-      path: `/chat/completions`,
-      body: JSON.stringify({
-        messages: [
-          {
-            role: "system",
-            content:
-              "You are a helpful assistant. Use the supplied tools to assist the user.",
-          },
-          { role: "user", content: "Call the function" },
-        ],
-        model: "gpt-4",
-        toolChoice: "auto",
-        tools: [
-          {
-            type: "function",
-            function: { name: "the_function", description: "The function" },
-          },
-        ],
-      }),
-    })
-    .reply(
-      200,
-      {
-        choices: [
-          {
-            message: {
-              content: "<response text>",
+  test("single options argument", async (t) => {
+    const mockAgent = new MockAgent();
+    function fetchMock(url, opts) {
+      opts ||= {};
+      opts.dispatcher = mockAgent;
+      return fetch(url, opts);
+    }
+
+    mockAgent.disableNetConnect();
+    const mockPool = mockAgent.get("https://api.githubcopilot.com");
+    mockPool
+      .intercept({
+        method: "post",
+        path: `/chat/completions`,
+        body: JSON.stringify({
+          messages: [
+            { role: "system", content: "You are a helpful assistant." },
+            { role: "user", content: "What is the capital of France?" },
+            { role: "assistant", content: "The capital of France is Paris." },
+            { role: "user", content: "What about Spain?" },
+          ],
+          model: "gpt-4",
+        }),
+      })
+      .reply(
+        200,
+        {
+          choices: [
+            {
+              message: {
+                content: "<response text>",
+              },
             },
+          ],
+        },
+        {
+          headers: {
+            "content-type": "application/json",
+            "x-request-id": "<request-id>",
           },
-        ],
+        }
+      );
+
+    const result = await prompt({
+      model: "gpt-4",
+      token: "secret",
+      messages: [
+        { role: "user", content: "What is the capital of France?" },
+        { role: "assistant", content: "The capital of France is Paris." },
+        { role: "user", content: "What about Spain?" },
+      ],
+      request: { fetch: fetchMock },
+    });
+
+    t.assert.deepEqual(result, {
+      requestId: "<request-id>",
+      message: {
+        content: "<response text>",
       },
-      {
-        headers: {
-          "content-type": "application/json",
-          "x-request-id": "<request-id>",
+    });
+  });
+
+  test("function calling", async (t) => {
+    const mockAgent = new MockAgent();
+    function fetchMock(url, opts) {
+      opts ||= {};
+      opts.dispatcher = mockAgent;
+      return fetch(url, opts);
+    }
+
+    mockAgent.disableNetConnect();
+    const mockPool = mockAgent.get("https://api.githubcopilot.com");
+    mockPool
+      .intercept({
+        method: "post",
+        path: `/chat/completions`,
+        body: JSON.stringify({
+          messages: [
+            {
+              role: "system",
+              content:
+                "You are a helpful assistant. Use the supplied tools to assist the user.",
+            },
+            { role: "user", content: "Call the function" },
+          ],
+          model: "gpt-4",
+          toolChoice: "auto",
+          tools: [
+            {
+              type: "function",
+              function: { name: "the_function", description: "The function" },
+            },
+          ],
+        }),
+      })
+      .reply(
+        200,
+        {
+          choices: [
+            {
+              message: {
+                content: "<response text>",
+              },
+            },
+          ],
         },
-      }
-    );
-
-  const result = await prompt("Call the function", {
-    token: "secret",
-    model: "gpt-4",
-    tools: [
-      {
-        type: "function",
-        function: {
-          name: "the_function",
-          description: "The function",
+        {
+          headers: {
+            "content-type": "application/json",
+            "x-request-id": "<request-id>",
+          },
+        }
+      );
+
+    const result = await prompt("Call the function", {
+      token: "secret",
+      model: "gpt-4",
+      tools: [
+        {
+          type: "function",
+          function: {
+            name: "the_function",
+            description: "The function",
+          },
         },
-      },
-    ],
-    request: { fetch: fetchMock },
-  });
+      ],
+      request: { fetch: fetchMock },
+    });
 
-  t.assert.deepEqual(result, {
-    requestId: "<request-id>",
-    message: {
-      content: "<response text>",
-    },
+    t.assert.deepEqual(result, {
+      requestId: "<request-id>",
+      message: {
+        content: "<response text>",
+      },
+    });
   });
 });