diff --git a/README.md b/README.md
index 75ab0d2..5341fd4 100644
--- a/README.md
+++ b/README.md
@@ -811,3 +811,11 @@ You can call this method directly to make other requests to the API.
 ## TypeScript
 
 The `Replicate` constructor and all `replicate.*` methods are fully typed.
+
+## Vendored Dependencies
+
+We have a few dependencies that have been bundled into the vendor directory rather than adding external npm dependencies.
+
+These have been generated using bundlejs.com and copied into the appropriate directory along with the license and repository information.
+
+* [eventsource-parser/stream](https://bundlejs.com/?bundle&q=eventsource-parser%40latest%2Fstream&config=%7B%22esbuild%22%3A%7B%22format%22%3A%22cjs%22%2C%22minify%22%3Afalse%2C%22platform%22%3A%22neutral%22%7D%7D)
diff --git a/biome.json b/biome.json
index 807b901..ecb665f 100644
--- a/biome.json
+++ b/biome.json
@@ -1,5 +1,8 @@
 {
   "$schema": "https://biomejs.dev/schemas/1.0.0/schema.json",
+  "files": {
+    "ignore": [".wrangler", "vendor/*"]
+  },
   "formatter": {
     "indentStyle": "space",
     "indentWidth": 2
diff --git a/index.d.ts b/index.d.ts
index 8dc998a..abf68dc 100644
--- a/index.d.ts
+++ b/index.d.ts
@@ -279,7 +279,7 @@ declare module "replicate" {
       signature?: string;
     },
     secret: string
-  ): boolean;
+  ): Promise<boolean>;
 
   export function parseProgressFromLogs(logs: Prediction | string): {
     percentage: number;
diff --git a/index.js b/index.js
index 24376fe..1996536 100644
--- a/index.js
+++ b/index.js
@@ -1,6 +1,6 @@
 const ApiError = require("./lib/error");
 const ModelVersionIdentifier = require("./lib/identifier");
-const { Stream } = require("./lib/stream");
+const { createReadableStream } = require("./lib/stream");
 const {
   withAutomaticRetries,
   validateWebhook,
@@ -289,7 +289,11 @@ class Replicate {
 
     if (prediction.urls && prediction.urls.stream) {
       const { signal } = options;
-      const stream = new Stream(prediction.urls.stream, { signal });
+      const stream = createReadableStream({
+        url: prediction.urls.stream,
+        fetch: this.fetch,
+        options: { signal },
+      });
       yield* stream;
     } else {
       throw new Error("Prediction does not support streaming");
diff --git a/index.test.ts b/index.test.ts
index 97abc6f..6624eb2 100644
--- a/index.test.ts
+++ b/index.test.ts
@@ -7,7 +7,8 @@ import Replicate, {
   parseProgressFromLogs,
 } from "replicate";
 import nock from "nock";
-import fetch from "cross-fetch";
+import { createReadableStream } from "./lib/stream";
+import { PassThrough } from "node:stream";
 
 let client: Replicate;
 const BASE_URL = "https://api.replicate.com/v1";
@@ -21,7 +22,6 @@ describe("Replicate client", () => {
 
   beforeEach(() => {
     client = new Replicate({ auth: "test-token" });
-    client.fetch = fetch;
 
     unmatched = [];
     nock.emitter.on("no match", handleNoMatch);
@@ -251,7 +251,7 @@ describe("Replicate client", () => {
       let actual: Record<string, any> | undefined;
       nock(BASE_URL)
         .post("/predictions")
-        .reply(201, (uri: string, body: Record<string, any>) => {
+        .reply(201, (_uri: string, body: Record<string, any>) => {
           actual = body;
           return body;
         });
@@ -1010,8 +1010,6 @@ describe("Replicate client", () => {
     });
 
     test("Calls the correct API routes for a model", async () => {
-      const firstPollingRequest = true;
-
       nock(BASE_URL)
         .post("/models/replicate/hello-world/predictions")
         .reply(201, {
@@ -1187,4 +1185,314 @@ describe("Replicate client", () => {
   });
 
   // Continue with tests for other methods
+
+  describe("createReadableStream", () => {
+    function createStream(body: string | NodeJS.ReadableStream, status = 200) {
+      const streamEndpoint = "https://stream.replicate.com";
+      nock(streamEndpoint)
+        .get("/fake_stream")
+        .matchHeader("Accept", "text/event-stream")
+        .reply(status, body);
+
+      return createReadableStream({
+        url: `${streamEndpoint}/fake_stream`,
+        fetch: fetch,
+      });
+    }
+
+    test("consumes a server sent event stream", async () => {
+      const stream = createStream(
+        `
+      event: output
+      id: EVENT_1
+      data: hello world
+
+      event: done
+      id: EVENT_2
+      data: {}
+
+    `.replace(/^[ ]+/gm, "")
+      );
+
+      const iterator =
stream[Symbol.asyncIterator](); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "output", id: "EVENT_1", data: "hello world" }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_2", data: "{}" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + expect(await iterator.next()).toEqual({ done: true }); + }); + + test("consumes multiple events", async () => { + const stream = createStream( + ` + event: output + id: EVENT_1 + data: hello world + + event: output + id: EVENT_2 + data: hello dave + + event: done + id: EVENT_3 + data: {} + + `.replace(/^[ ]+/gm, "") + ); + + const iterator = stream[Symbol.asyncIterator](); + + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "output", id: "EVENT_1", data: "hello world" }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "output", id: "EVENT_2", data: "hello dave" }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_3", data: "{}" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + expect(await iterator.next()).toEqual({ done: true }); + }); + + test("ignores unexpected characters", async () => { + const stream = createStream( + ` + : hi + + event: output + id: EVENT_1 + data: hello world + + event: done + id: EVENT_2 + data: {} + + `.replace(/^[ ]+/gm, "") + ); + + const iterator = stream[Symbol.asyncIterator](); + + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "output", id: "EVENT_1", data: "hello world" }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_2", data: "{}" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + expect(await iterator.next()).toEqual({ done: true }); + }); + + test("supports multiple lines of output in a single event", async () => { + const stream = createStream( + ` + : hi + + event: output + 
id: EVENT_1 + data: hello, + data: this is a new line, + data: and this is a new line too + + event: done + id: EVENT_2 + data: {} + + `.replace(/^[ ]+/gm, "") + ); + + const iterator = stream[Symbol.asyncIterator](); + + expect(await iterator.next()).toEqual({ + done: false, + value: { + event: "output", + id: "EVENT_1", + data: "hello,\nthis is a new line,\nand this is a new line too", + }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_2", data: "{}" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + expect(await iterator.next()).toEqual({ done: true }); + }); + + test("supports the server writing data lines in multiple chunks", async () => { + const body = new PassThrough(); + const stream = createStream(body); + + // Create a stream of data chunks split on the pipe character for readability. + const data = ` + event: output + id: EVENT_1 + data: hello,| + data: this is a new line,| + data: and this is a new line too + + event: done + id: EVENT_2 + data: {} + + `.replace(/^[ ]+/gm, ""); + + const chunks = data.split("|"); + + // Consume the iterator in parallel to writing it. + const reading = new Promise((resolve, reject) => { + (async () => { + const iterator = stream[Symbol.asyncIterator](); + expect(await iterator.next()).toEqual({ + done: false, + value: { + event: "output", + id: "EVENT_1", + data: "hello,\nthis is a new line,\nand this is a new line too", + }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_2", data: "{}" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + })().then(resolve, reject); + }); + + // Write the chunks to the stream at an interval. 
+ const writing = new Promise((resolve, reject) => { + (async () => { + for await (const chunk of chunks) { + body.write(chunk); + await new Promise((resolve) => setTimeout(resolve, 1)); + } + body.end(); + resolve(null); + })().then(resolve, reject); + }); + + // Wait for both promises to resolve. + await Promise.all([reading, writing]); + }); + + test("supports the server writing data in a complete mess", async () => { + const body = new PassThrough(); + const stream = createStream(body); + + // Create a stream of data chunks split on the pipe character for readability. + const data = ` + : hi + + ev|ent: output + id: EVENT_1 + data: hello, + data: this |is a new line,| + data: and this is |a new line too + + event: d|one + id: EVENT|_2 + data: {} + + `.replace(/^[ ]+/gm, ""); + + const chunks = data.split("|"); + + // Consume the iterator in parallel to writing it. + const reading = new Promise((resolve, reject) => { + (async () => { + const iterator = stream[Symbol.asyncIterator](); + expect(await iterator.next()).toEqual({ + done: false, + value: { + event: "output", + id: "EVENT_1", + data: "hello,\nthis is a new line,\nand this is a new line too", + }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_2", data: "{}" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + })().then(resolve, reject); + }); + + // Write the chunks to the stream at an interval. + const writing = new Promise((resolve, reject) => { + (async () => { + for await (const chunk of chunks) { + body.write(chunk); + await new Promise((resolve) => setTimeout(resolve, 1)); + } + body.end(); + resolve(null); + })().then(resolve, reject); + }); + + // Wait for both promises to resolve. 
+ await Promise.all([reading, writing]); + }); + + test("supports ending without a done", async () => { + const stream = createStream( + ` + event: output + id: EVENT_1 + data: hello world + + `.replace(/^[ ]+/gm, "") + ); + + const iterator = stream[Symbol.asyncIterator](); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "output", id: "EVENT_1", data: "hello world" }, + }); + expect(await iterator.next()).toEqual({ done: true }); + }); + + test("an error event in the stream raises an exception", async () => { + const stream = createStream( + ` + event: output + id: EVENT_1 + data: hello world + + event: error + id: EVENT_2 + data: An unexpected error occurred + + `.replace(/^[ ]+/gm, "") + ); + + const iterator = stream[Symbol.asyncIterator](); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "output", id: "EVENT_1", data: "hello world" }, + }); + await expect(iterator.next()).rejects.toThrowError( + "An unexpected error occurred" + ); + expect(await iterator.next()).toEqual({ done: true }); + }); + + test("an error when fetching the stream raises an exception", async () => { + const stream = createStream("{}", 500); + const iterator = stream[Symbol.asyncIterator](); + await expect(iterator.next()).rejects.toThrowError( + "Request to https://stream.replicate.com/fake_stream failed with status 500" + ); + expect(await iterator.next()).toEqual({ done: true }); + }); + }); }); diff --git a/lib/stream.js b/lib/stream.js index 012d6d0..a97642d 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -1,14 +1,9 @@ // Attempt to use readable-stream if available, attempt to use the built-in stream module. 
-let Readable;
-try {
-  Readable = require("readable-stream").Readable;
-} catch (e) {
-  try {
-    Readable = require("stream").Readable;
-  } catch (e) {
-    Readable = null;
-  }
-}
+
+const ApiError = require("./error");
+const {
+  EventSourceParserStream,
+} = require("../vendor/eventsource-parser/stream");
 
 /**
  * A server-sent event.
@@ -42,98 +37,56 @@ class ServerSentEvent {
 }
 
 /**
- * A stream of server-sent events.
+ * Create a new stream of server-sent events.
+ *
+ * @param {object} config
+ * @param {string} config.url The URL to connect to.
+ * @param {typeof fetch} [config.fetch] The fetch implementation to use.
+ * @param {object} [config.options] The EventSource options.
+ * @returns {ReadableStream<ServerSentEvent> & AsyncIterable<ServerSentEvent>}
  */
-class Stream extends Readable {
-  /**
-   * Create a new stream of server-sent events.
-   *
-   * @param {string} url The URL to connect to.
-   * @param {object} options The fetch options.
-   */
-  constructor(url, options) {
-    if (!Readable) {
-      throw new Error(
-        "Readable streams are not supported. Please use Node.js 18 or later, or install the readable-stream package."
- ); - } - - super(); - this.url = url; - this.options = options; - - this.event = null; - this.data = []; - this.lastEventId = null; - this.retry = null; - } - - decode(line) { - if (!line) { - if (!this.event && !this.data.length && !this.lastEventId) { - return null; +function createReadableStream({ url, fetch, options = {} }) { + return new ReadableStream({ + async start(controller) { + const init = { + ...options, + headers: { + ...options.headers, + Accept: "text/event-stream", + }, + }; + const response = await fetch(url, init); + + if (!response.ok) { + const text = await response.text(); + const request = new Request(url, init); + controller.error( + new ApiError( + `Request to ${url} failed with status ${response.status}`, + request, + response + ) + ); } - const sse = new ServerSentEvent( - this.event, - this.data.join("\n"), - this.lastEventId - ); - - this.event = null; - this.data = []; - this.retry = null; - - return sse; - } - - if (line.startsWith(":")) { - return null; - } - - const [field, value] = line.split(": "); - if (field === "event") { - this.event = value; - } else if (field === "data") { - this.data.push(value); - } else if (field === "id") { - this.lastEventId = value; - } - - return null; - } - - async *[Symbol.asyncIterator]() { - const response = await fetch(this.url, { - ...this.options, - headers: { - Accept: "text/event-stream", - }, - }); - - for await (const chunk of response.body) { - const decoder = new TextDecoder("utf-8"); - const text = decoder.decode(chunk); - const lines = text.split("\n"); - for (const line of lines) { - const sse = this.decode(line); - if (sse) { - if (sse.event === "error") { - throw new Error(sse.data); - } - - yield sse; - - if (sse.event === "done") { - return; - } + const stream = response.body + .pipeThrough(new TextDecoderStream()) + .pipeThrough(new EventSourceParserStream()); + for await (const event of stream) { + if (event.event === "error") { + controller.error(new Error(event.data)); + } 
else {
+            controller.enqueue(
+              new ServerSentEvent(event.event, event.data, event.id)
+            );
+          }
+        }
-      }
-    }
+      controller.close();
+    },
+  });
 }
 
 module.exports = {
-  Stream,
+  createReadableStream,
   ServerSentEvent,
 };
diff --git a/lib/util.js b/lib/util.js
index a9406e3..9d9c5af 100644
--- a/lib/util.js
+++ b/lib/util.js
@@ -26,7 +26,7 @@ const ApiError = require("./error");
 /**
  * Validate a webhook signature
  *
- * @returns {boolean} - True if the signature is valid
+ * @returns {Promise<boolean>} - True if the signature is valid
  * @throws {Error} - If the request is missing required headers, body, or secret
  */
 async function validateWebhook(requestData, secret) {
diff --git a/package-lock.json b/package-lock.json
index 8ba3c03..c3a18b4 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -14,7 +14,7 @@
         "@typescript-eslint/eslint-plugin": "^5.56.0",
         "cross-fetch": "^3.1.5",
         "jest": "^29.6.2",
-        "nock": "^13.3.0",
+        "nock": "^14.0.0-beta.4",
         "publint": "^0.2.7",
         "ts-jest": "^29.1.0",
         "typescript": "^5.0.2"
@@ -3948,12 +3948,6 @@
         "url": "https://github.com/sponsors/sindresorhus"
       }
     },
-    "node_modules/lodash": {
-      "version": "4.17.21",
-      "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
-      "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
-      "dev": true
-    },
     "node_modules/lodash.memoize": {
       "version": "4.1.2",
       "resolved": "https://registry.npmjs.org/lodash.memoize/-/lodash.memoize-4.1.2.tgz",
@@ -4083,18 +4077,17 @@
       "dev": true
     },
     "node_modules/nock": {
-      "version": "13.3.0",
-      "resolved": "https://registry.npmjs.org/nock/-/nock-13.3.0.tgz",
-      "integrity": "sha512-HHqYQ6mBeiMc+N038w8LkMpDCRquCHWeNmN3v6645P3NhN2+qXOBqvPqo7Rt1VyCMzKhJ733wZqw5B7cQVFNPg==",
+      "version": "14.0.0-beta.4",
+      "resolved": "https://registry.npmjs.org/nock/-/nock-14.0.0-beta.4.tgz",
+      "integrity": "sha512-N9GIOnNFas/TtdCQpavpi6A6SyVVInkD/vrUCF2u51vlE2wSnqfPifVli6xSX8l6Lz/3sdSwPusE9n3KPDDh0g==",
       "dev": true,
       "dependencies": {
"debug": "^4.1.0", "json-stringify-safe": "^5.0.1", - "lodash": "^4.17.21", "propagate": "^2.0.0" }, "engines": { - "node": ">= 10.13" + "node": ">= 18" } }, "node_modules/node-fetch": { diff --git a/package.json b/package.json index 098bb3a..c2ed3c9 100644 --- a/package.json +++ b/package.json @@ -16,6 +16,7 @@ "index.d.ts", "index.js", "lib/**/*.js", + "vendor/**/*", "package.json" ], "engines": { @@ -41,7 +42,7 @@ "@typescript-eslint/eslint-plugin": "^5.56.0", "cross-fetch": "^3.1.5", "jest": "^29.6.2", - "nock": "^13.3.0", + "nock": "^14.0.0-beta.4", "publint": "^0.2.7", "ts-jest": "^29.1.0", "typescript": "^5.0.2" diff --git a/tsconfig.json b/tsconfig.json index 7a564ee..b699d79 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -2,9 +2,10 @@ "compilerOptions": { "esModuleInterop": true, "noEmit": true, - "strict": true + "strict": true, + "allowJs": true }, "exclude": [ "**/node_modules" ] -} +} \ No newline at end of file diff --git a/vendor/eventsource-parser/stream.js b/vendor/eventsource-parser/stream.js new file mode 100644 index 0000000..88465da --- /dev/null +++ b/vendor/eventsource-parser/stream.js @@ -0,0 +1,198 @@ +// Source: https://github.com/rexxars/eventsource-parser/tree/v1.1.2 +// +// MIT License +// +// Copyright (c) 2024 Espen Hovlandsdal +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +var __defProp = Object.defineProperty; +var __getOwnPropDesc = Object.getOwnPropertyDescriptor; +var __getOwnPropNames = Object.getOwnPropertyNames; +var __hasOwnProp = Object.prototype.hasOwnProperty; +var __export = (target, all) => { + for (var name in all) + __defProp(target, name, { get: all[name], enumerable: true }); +}; +var __copyProps = (to, from, except, desc) => { + if ((from && typeof from === "object") || typeof from === "function") { + for (let key of __getOwnPropNames(from)) + if (!__hasOwnProp.call(to, key) && key !== except) + __defProp(to, key, { + get: () => from[key], + enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable, + }); + } + return to; +}; +var __toCommonJS = (mod) => + __copyProps(__defProp({}, "__esModule", { value: true }), mod); + +// /input.ts +var input_exports = {}; +__export(input_exports, { + EventSourceParserStream: () => EventSourceParserStream, +}); +module.exports = __toCommonJS(input_exports); + +// http-url:https://unpkg.com/eventsource-parser@1.1.2/dist/index.js +function createParser(onParse) { + let isFirstChunk; + let buffer; + let startingPosition; + let startingFieldLength; + let eventId; + let eventName; + let data; + reset(); + return { + feed, + reset, + }; + function reset() { + isFirstChunk = true; + buffer = ""; + startingPosition = 0; + startingFieldLength = -1; + eventId = void 0; + eventName = void 0; + data = ""; + } + function feed(chunk) { + buffer = buffer ? 
buffer + chunk : chunk; + if (isFirstChunk && hasBom(buffer)) { + buffer = buffer.slice(BOM.length); + } + isFirstChunk = false; + const length = buffer.length; + let position = 0; + let discardTrailingNewline = false; + while (position < length) { + if (discardTrailingNewline) { + if (buffer[position] === "\n") { + ++position; + } + discardTrailingNewline = false; + } + let lineLength = -1; + let fieldLength = startingFieldLength; + let character; + for ( + let index = startingPosition; + lineLength < 0 && index < length; + ++index + ) { + character = buffer[index]; + if (character === ":" && fieldLength < 0) { + fieldLength = index - position; + } else if (character === "\r") { + discardTrailingNewline = true; + lineLength = index - position; + } else if (character === "\n") { + lineLength = index - position; + } + } + if (lineLength < 0) { + startingPosition = length - position; + startingFieldLength = fieldLength; + break; + } else { + startingPosition = 0; + startingFieldLength = -1; + } + parseEventStreamLine(buffer, position, fieldLength, lineLength); + position += lineLength + 1; + } + if (position === length) { + buffer = ""; + } else if (position > 0) { + buffer = buffer.slice(position); + } + } + function parseEventStreamLine(lineBuffer, index, fieldLength, lineLength) { + if (lineLength === 0) { + if (data.length > 0) { + onParse({ + type: "event", + id: eventId, + event: eventName || void 0, + data: data.slice(0, -1), + // remove trailing newline + }); + data = ""; + eventId = void 0; + } + eventName = void 0; + return; + } + const noValue = fieldLength < 0; + const field = lineBuffer.slice( + index, + index + (noValue ? 
lineLength : fieldLength) + ); + let step = 0; + if (noValue) { + step = lineLength; + } else if (lineBuffer[index + fieldLength + 1] === " ") { + step = fieldLength + 2; + } else { + step = fieldLength + 1; + } + const position = index + step; + const valueLength = lineLength - step; + const value = lineBuffer.slice(position, position + valueLength).toString(); + if (field === "data") { + data += value ? "".concat(value, "\n") : "\n"; + } else if (field === "event") { + eventName = value; + } else if (field === "id" && !value.includes("\0")) { + eventId = value; + } else if (field === "retry") { + const retry = parseInt(value, 10); + if (!Number.isNaN(retry)) { + onParse({ + type: "reconnect-interval", + value: retry, + }); + } + } + } +} +var BOM = [239, 187, 191]; +function hasBom(buffer) { + return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode); +} + +// http-url:https://unpkg.com/eventsource-parser@1.1.2/dist/stream.js +var EventSourceParserStream = class extends TransformStream { + constructor() { + let parser; + super({ + start(controller) { + parser = createParser((event) => { + if (event.type === "event") { + controller.enqueue(event); + } + }); + }, + transform(chunk) { + parser.feed(chunk); + }, + }); + } +};