Skip to content

Commit

Permalink
prodia.ts: inpainting + sdxl + types fix
Browse files Browse the repository at this point in the history
  • Loading branch information
montyanderson committed Oct 10, 2023
1 parent 0cafc44 commit 4f72c83
Showing 1 changed file with 88 additions and 23 deletions.
111 changes: 88 additions & 23 deletions prodia.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
/* Job Responses */

type ProdiaJobBase = { job: string };

export type ProdiaJobQueued = ProdiaJobBase & { status: "queued" };
export type ProdiaJobGenerating = ProdiaJobBase & { status: "generating" };
export type ProdiaJobFailed = ProdiaJobBase & { status: "failed" };
export type ProdiaJobSucceeded = ProdiaJobBase & {
status: "succeeded";
imageUrl: string;
};
export type ProdiaJobQueued = { imageUrl: undefined; status: "queued" };
export type ProdiaJobGenerating = { imageUrl: undefined; status: "generating" };
export type ProdiaJobFailed = { imageUrl: undefined; status: "failed" };
export type ProdiaJobSucceeded = { imageUrl: string; status: "succeeded" };

export type ProdiaJob =
| ProdiaJobQueued
| ProdiaJobGenerating
| ProdiaJobFailed
| ProdiaJobSucceeded;

const a = {} as unknown as ProdiaJob;

const { imageUrl } = a;

/* Generation Requests */

export type ProdiaGenerateRequest = {
Expand All @@ -30,8 +29,9 @@ export type ProdiaGenerateRequest = {
aspect_ratio?: "square" | "portrait" | "landscape";
};

export type ProdiaTransformRequest = {
imageUrl: string;
type ImageInput = { imageUrl: string } | { imageData: string };

export type ProdiaTransformRequest = ImageInput & {
prompt: string;
model?: string;
denoising_strength?: number;
Expand All @@ -43,8 +43,7 @@ export type ProdiaTransformRequest = {
sampler?: string;
};

export type ProdiaControlnetRequest = {
imageUrl: string;
export type ProdiaControlnetRequest = ImageInput & {
controlnet_model: string;
controlnet_module?: string;
threshold_a?: number;
Expand All @@ -61,6 +60,36 @@ export type ProdiaControlnetRequest = {
height?: number;
};

type MaskInput = { maskUrl: string } | { maskData: string };

export type ProdiaInpaintingRequest = ImageInput &
MaskInput & {
prompt: string;
model?: string;
denoising_strength?: number;
negative_prompt?: string;
steps?: number;
cfg_scale?: number;
seed?: number;
upscale?: boolean;
mask_blur: number;
inpainting_fill: number;
inpainting_mask_invert: number;
inpainting_full_res: string;
sampler?: string;
};

export type ProdiaXlGenerateRequest = {
prompt: string;
model?: string;
negative_prompt?: string;
steps?: number;
cfg_scale?: number;
seed?: number;
upscale?: boolean;
sampler?: string;
};

/* Constructor Definions */

export type Prodia = ReturnType<typeof createProdia>;
Expand All @@ -74,17 +103,17 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => {
const base = _base || "https://api.prodia.com/v1";

const headers = {
"X-Prodia-Key": apiKey,
"X-Prodia-Key": apiKey
};

const generate = async (params: ProdiaGenerateRequest) => {
const response = await fetch(`${base}/sd/generate`, {
method: "POST",
headers: {
...headers,
"Content-Type": "application/json",
"Content-Type": "application/json"
},
body: JSON.stringify(params),
body: JSON.stringify(params)
});

if (response.status !== 200) {
Expand All @@ -99,9 +128,9 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => {
method: "POST",
headers: {
...headers,
"Content-Type": "application/json",
"Content-Type": "application/json"
},
body: JSON.stringify(params),
body: JSON.stringify(params)
});

if (response.status !== 200) {
Expand All @@ -116,9 +145,43 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => {
method: "POST",
headers: {
...headers,
"Content-Type": "application/json",
"Content-Type": "application/json"
},
body: JSON.stringify(params)
});

if (response.status !== 200) {
throw new Error(`Bad Prodia Response: ${response.status}`);
}

return (await response.json()) as ProdiaJobQueued;
};

const inpainting = async (params: ProdiaInpaintingRequest) => {
const response = await fetch(`${base}/sd/inpainting`, {
method: "POST",
headers: {
...headers,
"Content-Type": "application/json"
},
body: JSON.stringify(params)
});

if (response.status !== 200) {
throw new Error(`Bad Prodia Response: ${response.status}`);
}

return (await response.json()) as ProdiaJobQueued;
};

const xlGenerate = async (params: ProdiaXlGenerateRequest) => {
const response = await fetch(`${base}/sdxl/generate`, {
method: "POST",
headers: {
...headers,
"Content-Type": "application/json"
},
body: JSON.stringify(params),
body: JSON.stringify(params)
});

if (response.status !== 200) {
Expand All @@ -130,7 +193,7 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => {

const getJob = async (jobId: string) => {
const response = await fetch(`${base}/job/${jobId}`, {
headers,
headers
});

if (response.status !== 200) {
Expand All @@ -157,7 +220,7 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => {

const listModels = async () => {
const response = await fetch(`${base}/models/list`, {
headers,
headers
});

if (response.status !== 200) {
Expand All @@ -171,8 +234,10 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => {
generate,
transform,
controlnet,
inpainting,
xlGenerate,
wait,
getJob,
listModels,
listModels
};
};

0 comments on commit 4f72c83

Please sign in to comment.