diff --git a/index.d.ts b/index.d.ts index 45a2430..1e417ae 100644 --- a/index.d.ts +++ b/index.d.ts @@ -156,7 +156,9 @@ declare module "replicate" { identifier: `${string}/${string}` | `${string}/${string}:${string}`, options: { input: object; - wait?: boolean | number | { interval?: number }; + wait?: + | { mode: "block"; interval?: number; timeout?: number } + | { mode: "poll"; interval?: number }; webhook?: string; webhook_events_filter?: WebhookEventType[]; signal?: AbortSignal; @@ -189,7 +191,6 @@ declare module "replicate" { wait( prediction: Prediction, options?: { - mode?: "poll"; interval?: number; }, stop?: (prediction: Prediction) => Promise @@ -215,7 +216,7 @@ declare module "replicate" { stream?: boolean; webhook?: string; webhook_events_filter?: WebhookEventType[]; - wait?: boolean | number | { mode?: "poll"; interval?: number }; + wait?: number | boolean; } ): Promise; }; @@ -304,7 +305,7 @@ declare module "replicate" { stream?: boolean; webhook?: string; webhook_events_filter?: WebhookEventType[]; - wait?: boolean | number | { mode?: "poll"; interval?: number }; + wait?: boolean | number; } & ({ version: string } | { model: string }) ): Promise; get(prediction_id: string): Promise; diff --git a/index.js b/index.js index ac4b815..a5755d9 100644 --- a/index.js +++ b/index.js @@ -48,7 +48,7 @@ class Replicate { * @param {string} options.userAgent - Identifier of your app * @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1 * @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch` - * @param {boolean} [options.useFileOutput] - Set to `true` to return `FileOutput` objects from `run` instead of URLs, defaults to false. + * @param {boolean} [options.useFileOutput] - Set to `false` to disable `FileOutput` objects from `run` instead of URLs, defaults to true. * @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use */ constructor(options = {}) { @@ -60,7 +60,7 @@ class Replicate { this.baseUrl = options.baseUrl || "https://api.replicate.com/v1"; this.fetch = options.fetch || globalThis.fetch; this.fileEncodingStrategy = options.fileEncodingStrategy || "default"; - this.useFileOutput = options.useFileOutput || false; + this.useFileOutput = options.useFileOutput === false ? false : true; this.accounts = { current: accounts.current.bind(this), @@ -133,8 +133,7 @@ class Replicate { * @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version" * @param {object} options * @param {object} options.input - Required. An object with the model inputs - * @param {object} [options.wait] - Options for waiting for the prediction to finish. If `wait` is explicitly true, the function will block and wait for the prediction to finish. - * @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500 + * @param {{mode: "block", timeout?: number, interval?: number} | {mode: "poll", interval?: number }} [options.wait] - Options for waiting for the prediction to finish. If `wait` is explicitly true, the function will block and wait for the prediction to finish. * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) * @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction @@ -144,23 +143,22 @@ class Replicate { * @returns {Promise} - Resolves with the output of running the model */ async run(ref, options, progress) { - const { wait, signal, ...data } = options; + const { wait = { mode: "block" }, signal, ...data } = options; const identifier = ModelVersionIdentifier.parse(ref); - const isBlocking = typeof wait === "boolean" || typeof wait === "number"; let prediction; if (identifier.version) { prediction = await this.predictions.create({ ...data, version: identifier.version, - wait: isBlocking ? wait : false, + wait: wait.mode === "block" ? wait.timeout ?? true : false, }); } else if (identifier.owner && identifier.name) { prediction = await this.predictions.create({ ...data, model: `${identifier.owner}/${identifier.name}`, - wait: isBlocking ? wait : false, + wait: wait.mode === "block" ? wait.timeout ?? true : false, }); } else { throw new Error("Invalid model version identifier"); @@ -171,11 +169,11 @@ class Replicate { progress(prediction); } - const isDone = isBlocking && prediction.status !== "starting"; + const isDone = wait.mode === "block" && prediction.status !== "starting"; if (!isDone) { prediction = await this.wait( prediction, - isBlocking ? {} : wait, + { interval: wait.mode === "poll" ? wait.interval : undefined }, async (updatedPrediction) => { // Call progress callback with the updated prediction object if (progress) { diff --git a/index.test.ts b/index.test.ts index 08fbf7c..c67035a 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1310,7 +1310,7 @@ describe("Replicate client", () => { "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { input: { text: "Hello, world!" }, - wait: { interval: 1 }, + wait: { mode: "poll", interval: 1 }, }, (prediction) => { const progress = parseProgressFromLogs(prediction); @@ -1402,7 +1402,7 @@ describe("Replicate client", () => { "replicate/hello-world", { input: { text: "Hello, world!" }, - wait: { interval: 1 }, + wait: { mode: "poll", interval: 1 }, }, progress ); @@ -1448,12 +1448,18 @@ describe("Replicate client", () => { }); await expect( - client.run("a/b-1.0:abc123", { input: { text: "Hello, world!" } }) + client.run("a/b-1.0:abc123", { + wait: { mode: "poll" }, + input: { text: "Hello, world!" }, + }) ).resolves.not.toThrow(); }); test("Throws an error for invalid identifiers", async () => { - const options = { input: { text: "Hello, world!" } }; + const options = { + wait: { mode: "poll" } as { mode: "poll" }, + input: { text: "Hello, world!" }, + }; // @ts-expect-error await expect(client.run("owner:abc123", options)).rejects.toThrow(); @@ -1469,6 +1475,7 @@ describe("Replicate client", () => { await client.run( "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { + wait: { mode: "poll" }, input: { text: "Alice", }, @@ -1492,7 +1499,7 @@ describe("Replicate client", () => { }) .reply(201, { id: "ufawqhfynnddngldkgtslldrkq", - status: "processing", + status: "starting", }) .persist() .get("/predictions/ufawqhfynnddngldkgtslldrkq") @@ -1510,6 +1517,7 @@ describe("Replicate client", () => { const output = await client.run( "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { + wait: { mode: "poll" }, input: { text: "Hello, world!" }, signal, }, @@ -1524,7 +1532,7 @@ describe("Replicate client", () => { expect(onProgress).toHaveBeenNthCalledWith( 1, expect.objectContaining({ - status: "processing", + status: "starting", }) ); expect(onProgress).toHaveBeenNthCalledWith( @@ -1580,6 +1588,7 @@ describe("Replicate client", () => { const output = (await client.run( "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { + wait: { mode: "poll" }, input: { text: "Hello, world!" }, } )) as FileOutput; @@ -1631,6 +1640,7 @@ describe("Replicate client", () => { const output = (await client.run( "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { + wait: { mode: "poll" }, input: { text: "Hello, world!" }, } )) as unknown as string; @@ -1677,6 +1687,7 @@ describe("Replicate client", () => { const [output] = (await client.run( "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { + wait: { mode: "poll" }, input: { text: "Hello, world!" }, } )) as FileOutput[]; @@ -1724,6 +1735,7 @@ describe("Replicate client", () => { const output = (await client.run( "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { + wait: { mode: "poll" }, input: { text: "Hello, world!" }, } )) as FileOutput; diff --git a/package-lock.json b/package-lock.json index f4c62d7..dd0188a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "replicate", - "version": "0.34.1", + "version": "1.0.0-beta.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "replicate", - "version": "0.34.1", + "version": "1.0.0-beta.1", "license": "Apache-2.0", "devDependencies": { "@biomejs/biome": "^1.4.1", diff --git a/package.json b/package.json index 1bc8a04..2572d52 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "replicate", - "version": "0.34.1", + "version": "1.0.0-beta.1", "description": "JavaScript client for Replicate", "repository": "github:replicate/replicate-javascript", "homepage": "https://github.com/replicate/replicate-javascript#readme",