Skip to content

Commit

Permalink
Add support for validating webhooks (#200)
Browse files Browse the repository at this point in the history
* Add support for validating webhooks

* Use test case from official Svix repo

* Add comment about test secret
  • Loading branch information
mattt committed Feb 16, 2024
1 parent c6fbd33 commit c0d2a01
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 3 deletions.
25 changes: 25 additions & 0 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ declare module "replicate" {
retry?: number;
}

export interface WebhookSecret {
key: string;
}

export default class Replicate {
constructor(options?: {
auth?: string;
Expand Down Expand Up @@ -233,5 +237,26 @@ declare module "replicate" {
cancel(training_id: string): Promise<Training>;
list(): Promise<Page<Training>>;
};

webhooks: {
default: {
secret: {
get(): Promise<WebhookSecret>;
};
};
};
}

export function validateWebhook(
requestData:
| Request
| {
id?: string;
timestamp?: string;
body: string;
secret?: string;
signature?: string;
},
secret: string
): boolean;
}
12 changes: 11 additions & 1 deletion index.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
const ApiError = require("./lib/error");
const ModelVersionIdentifier = require("./lib/identifier");
const { Stream } = require("./lib/stream");
const { withAutomaticRetries } = require("./lib/util");
const { withAutomaticRetries, validateWebhook } = require("./lib/util");

const accounts = require("./lib/accounts");
const collections = require("./lib/collections");
Expand All @@ -10,6 +10,7 @@ const hardware = require("./lib/hardware");
const models = require("./lib/models");
const predictions = require("./lib/predictions");
const trainings = require("./lib/trainings");
const webhooks = require("./lib/webhooks");

const packageJSON = require("./package.json");

Expand Down Expand Up @@ -90,6 +91,14 @@ class Replicate {
cancel: trainings.cancel.bind(this),
list: trainings.list.bind(this),
};

this.webhooks = {
default: {
secret: {
get: webhooks.default.secret.get.bind(this),
},
},
};
}

/**
Expand Down Expand Up @@ -364,3 +373,4 @@ class Replicate {
}

module.exports = Replicate;
module.exports.validateWebhook = validateWebhook;
41 changes: 40 additions & 1 deletion index.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { expect, jest, test } from "@jest/globals";
import Replicate, { ApiError, Model, Prediction } from "replicate";
import Replicate, {
ApiError,
Model,
Prediction,
validateWebhook,
} from "replicate";
import nock from "nock";
import fetch from "cross-fetch";

Expand Down Expand Up @@ -996,5 +1001,39 @@ describe("Replicate client", () => {
});
});

describe("webhooks.default.secret.get", () => {
test("Calls the correct API route", async () => {
nock(BASE_URL).get("/webhooks/default/secret").reply(200, {
key: "whsec_5WbX5kEWLlfzsGNjH64I8lOOqUB6e8FH",
});

const secret = await client.webhooks.default.secret.get();
expect(secret.key).toBe("whsec_5WbX5kEWLlfzsGNjH64I8lOOqUB6e8FH");
});

test("Can be used to validate webhook", async () => {
// Test case from https://github.com/svix/svix-webhooks/blob/b41728cd98a7e7004a6407a623f43977b82fcba4/javascript/src/webhook.test.ts#L190-L200
const request = new Request("http://test.host/webhook", {
method: "POST",
headers: {
"Content-Type": "application/json",
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
"Webhook-Timestamp": "1614265330",
"Webhook-Signature":
"v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=",
},
body: `{"test": 2432232314}`,
});

// This is a test secret and should not be used in production
const secret = "whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw";

const isValid = await validateWebhook(request, secret);
expect(isValid).toBe(true);
});

// Add more tests for error handling, edge cases, etc.
});

// Continue with tests for other methods
});
90 changes: 89 additions & 1 deletion lib/util.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,93 @@
const crypto = require("node:crypto");

const ApiError = require("./error");

/**
* @see {@link validateWebhook}
* @overload
* @param {object} requestData - The request data
* @param {string} requestData.id - The webhook ID header from the incoming request.
* @param {string} requestData.timestamp - The webhook timestamp header from the incoming request.
* @param {string} requestData.body - The raw body of the incoming webhook request.
* @param {string} requestData.secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method.
* @param {string} requestData.signature - The webhook signature header from the incoming request, comprising one or more space-delimited signatures.
*/

/**
* @see {@link validateWebhook}
* @overload
* @param {object} requestData - The request object
* @param {object} requestData.headers - The request headers
* @param {string} requestData.headers["webhook-id"] - The webhook ID header from the incoming request
* @param {string} requestData.headers["webhook-timestamp"] - The webhook timestamp header from the incoming request
* @param {string} requestData.headers["webhook-signature"] - The webhook signature header from the incoming request, comprising one or more space-delimited signatures
* @param {string} requestData.body - The raw body of the incoming webhook request
* @param {string} secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method
*/

/**
* Validate a webhook signature
*
* @returns {boolean} - True if the signature is valid
* @throws {Error} - If the request is missing required headers, body, or secret
*/
async function validateWebhook(requestData, secret) {
let { id, timestamp, body, signature } = requestData;
const signingSecret = secret || requestData.secret;

if (requestData && requestData.headers && requestData.body) {
id = requestData.headers.get("webhook-id");
timestamp = requestData.headers.get("webhook-timestamp");
signature = requestData.headers.get("webhook-signature");
body = requestData.body;
}

if (body instanceof ReadableStream || body.readable) {
try {
const chunks = [];
for await (const chunk of body) {
chunks.push(Buffer.from(chunk));
}
body = Buffer.concat(chunks).toString("utf8");
} catch (err) {
throw new Error(`Error reading body: ${err.message}`);
}
} else if (body instanceof Buffer) {
body = body.toString("utf8");
} else if (typeof body !== "string") {
throw new Error("Invalid body type");
}

if (!id || !timestamp || !signature) {
throw new Error("Missing required webhook headers");
}

if (!body) {
throw new Error("Missing required body");
}

if (!signingSecret) {
throw new Error("Missing required secret");
}

const signedContent = `${id}.${timestamp}.${body}`;

const secretBytes = Buffer.from(signingSecret.split("_")[1], "base64");

const computedSignature = crypto
.createHmac("sha256", secretBytes)
.update(signedContent)
.digest("base64");

const expectedSignatures = signature
.split(" ")
.map((sig) => sig.split(",")[1]);

return expectedSignatures.some(
(expectedSignature) => expectedSignature === computedSignature
);
}

/**
* Automatically retry a request if it fails with an appropriate status code.
*
Expand Down Expand Up @@ -68,4 +156,4 @@ async function withAutomaticRetries(request, options = {}) {
return request();
}

module.exports = { withAutomaticRetries };
module.exports = { validateWebhook, withAutomaticRetries };
20 changes: 20 additions & 0 deletions lib/webhooks.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/**
* Get the default webhook signing secret
*
* @returns {Promise<object>} Resolves with the signing secret for the default webhook
*/
async function getDefaultWebhookSecret() {
const response = await this.request("/webhooks/default/secret", {
method: "GET",
});

return response.json();
}

module.exports = {
default: {
secret: {
get: getDefaultWebhookSecret,
},
},
};

0 comments on commit c0d2a01

Please sign in to comment.