diff --git a/replicate/prediction.py b/replicate/prediction.py index 84afaae9..54f5db66 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -52,10 +52,24 @@ def cancel(self): class PredictionCollection(Collection): model = Prediction - def create(self, version: Version, input: Dict[str, Any]) -> Prediction: + def create( + self, + version: Version, + input: Dict[str, Any], + webhook_completed: Optional[str] = None, + ) -> Prediction: input = encode_json(input, upload_file=upload_file) + body = { + "version": version.id, + "input": input, + } + if webhook_completed is not None: + body["webhook_completed"] = webhook_completed + resp = self._client._request( - "POST", "/v1/predictions", json={"version": version.id, "input": input} + "POST", + "/v1/predictions", + json=body, ) obj = resp.json() obj["version"] = version diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 47c1dceb..4d0dcef8 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -13,7 +13,13 @@ def test_cancel(): responses.post( "https://api.replicate.com/v1/predictions", match=[ - matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) + matchers.json_params_matcher( + { + "version": "v1", + "input": {"text": "world"}, + "webhook_completed": "https://example.com/webhook", + } + ), ], json={ "id": "p1", @@ -33,7 +39,11 @@ def test_cancel(): }, ) - prediction = client.predictions.create(version=version, input={"text": "world"}) + prediction = client.predictions.create( + version=version, + input={"text": "world"}, + webhook_completed="https://example.com/webhook", + ) rsp = responses.post("https://api.replicate.com/v1/predictions/p1/cancel", json={}) prediction.cancel()