Skip to content

Commit 3b23dfa

Browse files
zekeSBrandeis
andauthored
[Inference] Add image-to-image support for Replicate (#1564)
A redo of #1427 cc @SBrandeis @Vaibhavs10 --------- Co-authored-by: SBrandeis <simon@huggingface.co>
1 parent 86ec6ef commit 3b23dfa

File tree

3 files changed

+76
-2
lines changed

3 files changed

+76
-2
lines changed

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
138138
"text-to-image": new Replicate.ReplicateTextToImageTask(),
139139
"text-to-speech": new Replicate.ReplicateTextToSpeechTask(),
140140
"text-to-video": new Replicate.ReplicateTextToVideoTask(),
141+
"image-to-image": new Replicate.ReplicateImageToImageTask(),
141142
},
142143
sambanova: {
143144
conversational: new Sambanova.SambanovaConversationalTask(),

packages/inference/src/providers/replicate.ts

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,16 @@
1616
*/
1717
import { InferenceClientProviderOutputError } from "../errors.js";
1818
import { isUrl } from "../lib/isUrl.js";
19-
import type { BodyParams, HeaderParams, UrlParams } from "../types.js";
19+
import type { BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types.js";
2020
import { omit } from "../utils/omit.js";
21-
import { TaskProviderHelper, type TextToImageTaskHelper, type TextToVideoTaskHelper } from "./providerHelper.js";
21+
import {
22+
TaskProviderHelper,
23+
type ImageToImageTaskHelper,
24+
type TextToImageTaskHelper,
25+
type TextToVideoTaskHelper,
26+
} from "./providerHelper.js";
27+
import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js";
28+
import { base64FromBytes } from "../utils/base64FromBytes.js";
2229
export interface ReplicateOutput {
2330
output?: string | string[];
2431
}
@@ -152,3 +159,57 @@ export class ReplicateTextToVideoTask extends ReplicateTask implements TextToVid
152159
throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-video API");
153160
}
154161
}
162+
163+
export class ReplicateImageToImageTask extends ReplicateTask implements ImageToImageTaskHelper {
164+
override preparePayload(params: BodyParams<ImageToImageArgs>): Record<string, unknown> {
165+
return {
166+
input: {
167+
...omit(params.args, ["inputs", "parameters"]),
168+
...params.args.parameters,
169+
input_image: params.args.inputs, // This will be processed in preparePayloadAsync
170+
},
171+
version: params.model.includes(":") ? params.model.split(":")[1] : undefined,
172+
};
173+
}
174+
175+
async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
176+
const { inputs, ...restArgs } = args;
177+
178+
// Convert Blob to base64 data URL
179+
const bytes = new Uint8Array(await inputs.arrayBuffer());
180+
const base64 = base64FromBytes(bytes);
181+
const imageInput = `data:${inputs.type || "image/jpeg"};base64,${base64}`;
182+
183+
return {
184+
...restArgs,
185+
inputs: imageInput,
186+
};
187+
}
188+
189+
override async getResponse(response: ReplicateOutput): Promise<Blob> {
190+
if (
191+
typeof response === "object" &&
192+
!!response &&
193+
"output" in response &&
194+
Array.isArray(response.output) &&
195+
response.output.length > 0 &&
196+
typeof response.output[0] === "string"
197+
) {
198+
const urlResponse = await fetch(response.output[0]);
199+
return await urlResponse.blob();
200+
}
201+
202+
if (
203+
typeof response === "object" &&
204+
!!response &&
205+
"output" in response &&
206+
typeof response.output === "string" &&
207+
isUrl(response.output)
208+
) {
209+
const urlResponse = await fetch(response.output);
210+
return await urlResponse.blob();
211+
}
212+
213+
throw new InferenceClientProviderOutputError("Received malformed response from Replicate image-to-image API");
214+
}
215+
}

packages/inference/test/InferenceClient.spec.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,18 @@ describe.skip("InferenceClient", () => {
12771277

12781278
expect(res).toBeInstanceOf(Blob);
12791279
});
1280+
1281+
it("imageToImage - FLUX Kontext Dev", async () => {
1282+
const res = await client.imageToImage({
1283+
model: "black-forest-labs/flux-kontext-dev",
1284+
provider: "replicate",
1285+
inputs: new Blob([readTestFile("stormtrooper_depth.png")], { type: "image/png" }),
1286+
parameters: {
1287+
prompt: "Change the stormtrooper armor to golden color while keeping the same pose and helmet design",
1288+
},
1289+
});
1290+
expect(res).toBeInstanceOf(Blob);
1291+
});
12801292
},
12811293
TIMEOUT
12821294
);

0 commit comments

Comments
 (0)