From 18ca5664c76cefbe3b4703cdc37f2c9d56e4cdd2 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 12 Apr 2023 04:54:33 -0700 Subject: [PATCH] Add support for predictions.cancel endpoint --- index.d.ts | 1 + index.js | 1 + index.test.ts | 35 +++++++++++++++++++++++++++++++++++ lib/predictions.js | 13 +++++++++++++ 4 files changed, 50 insertions(+) diff --git a/index.d.ts b/index.d.ts index 02a2f8d..5671f62 100644 --- a/index.d.ts +++ b/index.d.ts @@ -113,6 +113,7 @@ declare module 'replicate' { webhook_events_filter?: WebhookEventType[]; }): Promise; get(prediction_id: string): Promise; + cancel(prediction_id: string): Promise; list(): Promise>; }; diff --git a/index.js b/index.js index aae4639..ff4959a 100644 --- a/index.js +++ b/index.js @@ -54,6 +54,7 @@ class Replicate { this.predictions = { create: predictions.create.bind(this), get: predictions.get.bind(this), + cancel: predictions.cancel.bind(this), list: predictions.list.bind(this), }; diff --git a/index.test.ts b/index.test.ts index 89bcc8d..614cfbd 100644 --- a/index.test.ts +++ b/index.test.ts @@ -150,6 +150,41 @@ describe('Replicate client', () => { // Add more tests for error handling, edge cases, etc. }); + describe('predictions.cancel', () => { + test('Calls the correct API route with the correct payload', async () => { + nock(BASE_URL) + .post('/predictions/ufawqhfynnddngldkgtslldrkq/cancel') + .reply(200, { + id: 'ufawqhfynnddngldkgtslldrkq', + version: + '5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa', + urls: { + get: 'https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq', + cancel: + 'https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel', + }, + created_at: '2022-04-26T22:13:06.224088Z', + started_at: '2022-04-26T22:13:06.224088Z', + completed_at: '2022-04-26T22:14:06.224088Z', + status: 'canceled', + input: { + text: 'Alice', + }, + output: null, + error: null, + logs: null, + metrics: {}, + }); + + const prediction = await client.predictions.cancel( + 'ufawqhfynnddngldkgtslldrkq' + ); + expect(prediction.status).toBe('canceled'); + }); + + // Add more tests for error handling, edge cases, etc. + }); + describe('predictions.list', () => { test('Calls the correct API route with the correct payload', async () => { nock(BASE_URL) diff --git a/lib/predictions.js b/lib/predictions.js index d654a73..47f2b51 100644 --- a/lib/predictions.js +++ b/lib/predictions.js @@ -39,6 +39,18 @@ async function getPrediction(prediction_id) { }); } +/** + * Cancel a prediction by ID + * + * @param {string} prediction_id - Required. The training ID + * @returns {Promise} Resolves with the data for the training + */ +async function cancelPrediction(prediction_id) { + return this.request(`/predictions/${prediction_id}/cancel`, { + method: 'POST', + }); +} + /** * List all predictions * @@ -53,5 +65,6 @@ async function listPredictions() { module.exports = { create: createPrediction, get: getPrediction, + cancel: cancelPrediction, list: listPredictions, };