|
16 | 16 | */
|
17 | 17 | import { InferenceClientProviderOutputError } from "../errors.js";
|
18 | 18 | import { isUrl } from "../lib/isUrl.js";
|
19 |
| -import type { BodyParams, HeaderParams, UrlParams } from "../types.js"; |
| 19 | +import type { BodyParams, HeaderParams, RequestArgs, UrlParams } from "../types.js"; |
20 | 20 | import { omit } from "../utils/omit.js";
|
21 |
| -import { TaskProviderHelper, type TextToImageTaskHelper, type TextToVideoTaskHelper } from "./providerHelper.js"; |
| 21 | +import { |
| 22 | + TaskProviderHelper, |
| 23 | + type ImageToImageTaskHelper, |
| 24 | + type TextToImageTaskHelper, |
| 25 | + type TextToVideoTaskHelper, |
| 26 | +} from "./providerHelper.js"; |
| 27 | +import type { ImageToImageArgs } from "../tasks/cv/imageToImage.js"; |
| 28 | +import { base64FromBytes } from "../utils/base64FromBytes.js"; |
22 | 29 | export interface ReplicateOutput {
|
23 | 30 | output?: string | string[];
|
24 | 31 | }
|
@@ -152,3 +159,57 @@ export class ReplicateTextToVideoTask extends ReplicateTask implements TextToVid
|
152 | 159 | throw new InferenceClientProviderOutputError("Received malformed response from Replicate text-to-video API");
|
153 | 160 | }
|
154 | 161 | }
|
| 162 | + |
| 163 | +export class ReplicateImageToImageTask extends ReplicateTask implements ImageToImageTaskHelper { |
| 164 | + override preparePayload(params: BodyParams<ImageToImageArgs>): Record<string, unknown> { |
| 165 | + return { |
| 166 | + input: { |
| 167 | + ...omit(params.args, ["inputs", "parameters"]), |
| 168 | + ...params.args.parameters, |
| 169 | + input_image: params.args.inputs, // This will be processed in preparePayloadAsync |
| 170 | + }, |
| 171 | + version: params.model.includes(":") ? params.model.split(":")[1] : undefined, |
| 172 | + }; |
| 173 | + } |
| 174 | + |
| 175 | + async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> { |
| 176 | + const { inputs, ...restArgs } = args; |
| 177 | + |
| 178 | + // Convert Blob to base64 data URL |
| 179 | + const bytes = new Uint8Array(await inputs.arrayBuffer()); |
| 180 | + const base64 = base64FromBytes(bytes); |
| 181 | + const imageInput = `data:${inputs.type || "image/jpeg"};base64,${base64}`; |
| 182 | + |
| 183 | + return { |
| 184 | + ...restArgs, |
| 185 | + inputs: imageInput, |
| 186 | + }; |
| 187 | + } |
| 188 | + |
| 189 | + override async getResponse(response: ReplicateOutput): Promise<Blob> { |
| 190 | + if ( |
| 191 | + typeof response === "object" && |
| 192 | + !!response && |
| 193 | + "output" in response && |
| 194 | + Array.isArray(response.output) && |
| 195 | + response.output.length > 0 && |
| 196 | + typeof response.output[0] === "string" |
| 197 | + ) { |
| 198 | + const urlResponse = await fetch(response.output[0]); |
| 199 | + return await urlResponse.blob(); |
| 200 | + } |
| 201 | + |
| 202 | + if ( |
| 203 | + typeof response === "object" && |
| 204 | + !!response && |
| 205 | + "output" in response && |
| 206 | + typeof response.output === "string" && |
| 207 | + isUrl(response.output) |
| 208 | + ) { |
| 209 | + const urlResponse = await fetch(response.output); |
| 210 | + return await urlResponse.blob(); |
| 211 | + } |
| 212 | + |
| 213 | + throw new InferenceClientProviderOutputError("Received malformed response from Replicate image-to-image API"); |
| 214 | + } |
| 215 | +} |
0 commit comments