diff --git a/index.d.ts b/index.d.ts index 76df9d4..ab40ed8 100644 --- a/index.d.ts +++ b/index.d.ts @@ -89,6 +89,10 @@ declare module "replicate" { retry?: number; } + export interface WebhookSecret { + key: string; + } + export default class Replicate { constructor(options?: { auth?: string; @@ -233,5 +237,26 @@ declare module "replicate" { cancel(training_id: string): Promise; list(): Promise>; }; + + webhooks: { + default: { + secret: { + get(): Promise; + }; + }; + }; } + + export function validateWebhook( + requestData: + | Request + | { + id?: string; + timestamp?: string; + body: string; + secret?: string; + signature?: string; + }, + secret: string + ): boolean; } diff --git a/index.js b/index.js index a85ea4e..13207ca 100644 --- a/index.js +++ b/index.js @@ -1,7 +1,7 @@ const ApiError = require("./lib/error"); const ModelVersionIdentifier = require("./lib/identifier"); const { Stream } = require("./lib/stream"); -const { withAutomaticRetries } = require("./lib/util"); +const { withAutomaticRetries, validateWebhook } = require("./lib/util"); const accounts = require("./lib/accounts"); const collections = require("./lib/collections"); @@ -10,6 +10,7 @@ const hardware = require("./lib/hardware"); const models = require("./lib/models"); const predictions = require("./lib/predictions"); const trainings = require("./lib/trainings"); +const webhooks = require("./lib/webhooks"); const packageJSON = require("./package.json"); @@ -90,6 +91,14 @@ class Replicate { cancel: trainings.cancel.bind(this), list: trainings.list.bind(this), }; + + this.webhooks = { + default: { + secret: { + get: webhooks.default.secret.get.bind(this), + }, + }, + }; } /** @@ -364,3 +373,4 @@ class Replicate { } module.exports = Replicate; +module.exports.validateWebhook = validateWebhook; diff --git a/index.test.ts b/index.test.ts index d50ccb4..106cc58 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1,5 +1,10 @@ import { expect, jest, test } from "@jest/globals"; -import Replicate, { ApiError, Model, Prediction } from "replicate"; +import Replicate, { + ApiError, + Model, + Prediction, + validateWebhook, +} from "replicate"; import nock from "nock"; import fetch from "cross-fetch"; @@ -996,5 +1001,39 @@ describe("Replicate client", () => { }); }); + describe("webhooks.default.secret.get", () => { + test("Calls the correct API route", async () => { + nock(BASE_URL).get("/webhooks/default/secret").reply(200, { + key: "whsec_5WbX5kEWLlfzsGNjH64I8lOOqUB6e8FH", + }); + + const secret = await client.webhooks.default.secret.get(); + expect(secret.key).toBe("whsec_5WbX5kEWLlfzsGNjH64I8lOOqUB6e8FH"); + }); + + test("Can be used to validate webhook", async () => { + // Test case from https://github.com/svix/svix-webhooks/blob/b41728cd98a7e7004a6407a623f43977b82fcba4/javascript/src/webhook.test.ts#L190-L200 + const request = new Request("http://test.host/webhook", { + method: "POST", + headers: { + "Content-Type": "application/json", + "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", + "Webhook-Timestamp": "1614265330", + "Webhook-Signature": + "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=", + }, + body: `{"test": 2432232314}`, + }); + + // This is a test secret and should not be used in production + const secret = "whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw"; + + const isValid = await validateWebhook(request, secret); + expect(isValid).toBe(true); + }); + + // Add more tests for error handling, edge cases, etc. + }); + // Continue with tests for other methods }); diff --git a/lib/util.js b/lib/util.js index 7b12633..6bd70ec 100644 --- a/lib/util.js +++ b/lib/util.js @@ -1,5 +1,93 @@ +const crypto = require("node:crypto"); + const ApiError = require("./error"); +/** + * @see {@link validateWebhook} + * @overload + * @param {object} requestData - The request data + * @param {string} requestData.id - The webhook ID header from the incoming request. + * @param {string} requestData.timestamp - The webhook timestamp header from the incoming request. + * @param {string} requestData.body - The raw body of the incoming webhook request. + * @param {string} requestData.secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method. + * @param {string} requestData.signature - The webhook signature header from the incoming request, comprising one or more space-delimited signatures. + */ + +/** + * @see {@link validateWebhook} + * @overload + * @param {object} requestData - The request object + * @param {object} requestData.headers - The request headers + * @param {string} requestData.headers["webhook-id"] - The webhook ID header from the incoming request + * @param {string} requestData.headers["webhook-timestamp"] - The webhook timestamp header from the incoming request + * @param {string} requestData.headers["webhook-signature"] - The webhook signature header from the incoming request, comprising one or more space-delimited signatures + * @param {string} requestData.body - The raw body of the incoming webhook request + * @param {string} secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method + */ + +/** + * Validate a webhook signature + * + * @returns {boolean} - True if the signature is valid + * @throws {Error} - If the request is missing required headers, body, or secret + */ +async function validateWebhook(requestData, secret) { + let { id, timestamp, body, signature } = requestData; + const signingSecret = secret || requestData.secret; + + if (requestData && requestData.headers && requestData.body) { + id = requestData.headers.get("webhook-id"); + timestamp = requestData.headers.get("webhook-timestamp"); + signature = requestData.headers.get("webhook-signature"); + body = requestData.body; + } + + if (body instanceof ReadableStream || body.readable) { + try { + const chunks = []; + for await (const chunk of body) { + chunks.push(Buffer.from(chunk)); + } + body = Buffer.concat(chunks).toString("utf8"); + } catch (err) { + throw new Error(`Error reading body: ${err.message}`); + } + } else if (body instanceof Buffer) { + body = body.toString("utf8"); + } else if (typeof body !== "string") { + throw new Error("Invalid body type"); + } + + if (!id || !timestamp || !signature) { + throw new Error("Missing required webhook headers"); + } + + if (!body) { + throw new Error("Missing required body"); + } + + if (!signingSecret) { + throw new Error("Missing required secret"); + } + + const signedContent = `${id}.${timestamp}.${body}`; + + const secretBytes = Buffer.from(signingSecret.split("_")[1], "base64"); + + const computedSignature = crypto + .createHmac("sha256", secretBytes) + .update(signedContent) + .digest("base64"); + + const expectedSignatures = signature + .split(" ") + .map((sig) => sig.split(",")[1]); + + return expectedSignatures.some( + (expectedSignature) => expectedSignature === computedSignature + ); +} + /** * Automatically retry a request if it fails with an appropriate status code. * @@ -68,4 +156,4 @@ async function withAutomaticRetries(request, options = {}) { return request(); } -module.exports = { withAutomaticRetries }; +module.exports = { validateWebhook, withAutomaticRetries }; diff --git a/lib/webhooks.js b/lib/webhooks.js new file mode 100644 index 0000000..f1324ec --- /dev/null +++ b/lib/webhooks.js @@ -0,0 +1,20 @@ +/** + * Get the default webhook signing secret + * + * @returns {Promise} Resolves with the signing secret for the default webhook + */ +async function getDefaultWebhookSecret() { + const response = await this.request("/webhooks/default/secret", { + method: "GET", + }); + + return response.json(); +} + +module.exports = { + default: { + secret: { + get: getDefaultWebhookSecret, + }, + }, +};