Skip to content

Commit

Permalink
feat: add img2img loopback (#331)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 22, 2023
1 parent 7b0095a commit 00fb64b
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 1 deletion.
76 changes: 75 additions & 1 deletion api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,70 @@ def parse_prompt(params: ImageParams) -> Tuple[List[Tuple[str, float]], List[Tup
return loras, inversions


def run_loopback(
job: WorkerContext,
server: ServerContext,
params: ImageParams,
image: Image.Image,
progress: ProgressCallback,
inversions: List[Tuple[str, float]],
loras: List[Tuple[str, float]],
) -> Image.Image:
if params.loopback == 0:
return image

# load img2img pipeline once
pipe_type = "lpw" if params.lpw() else "img2img"
pipe = load_pipeline(
server,
pipe_type,
params.model,
params.scheduler,
job.get_device(),
inversions=inversions,
loras=loras,
)

def loopback_iteration(source: Image.Image):
if params.lpw():
logger.debug("using LPW pipeline for loopback")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
source,
params.prompt,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=params.steps,
strength=params.strength,
eta=params.eta,
callback=progress,
)
return result.images[0]
else:
logger.debug("using img2img pipeline for loopback")
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
source,
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=1,
num_inference_steps=params.steps,
strength=params.strength,
eta=params.eta,
callback=progress,
)
return result.images[0]

for _i in range(params.loopback):
image = loopback_iteration(image)

return image


def run_highres(
job: WorkerContext,
server: ServerContext,
Expand All @@ -58,7 +122,7 @@ def run_highres(
progress: ProgressCallback,
inversions: List[Tuple[str, float]],
loras: List[Tuple[str, float]],
) -> None:
) -> Image.Image:
if highres.scale <= 1:
return image

Expand Down Expand Up @@ -137,6 +201,7 @@ def highres_tile(tile: Image.Image, dims):
)
return result.images[0]
else:
logger.debug("using img2img pipeline for highres")
rng = np.random.RandomState(params.seed)
result = highres_pipe(
params.prompt,
Expand Down Expand Up @@ -232,6 +297,15 @@ def run_txt2img_pipeline(
del pipe

for image, output in image_outputs:
image = run_loopback(
job,
server,
params,
progress,
inversions,
loras,
)

image = run_highres(
job,
server,
Expand Down
5 changes: 5 additions & 0 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class ImageParams:
control: Optional[NetworkModel]
input_prompt: str
input_negative_prompt: str
loopback: int

def __init__(
self,
Expand All @@ -191,6 +192,7 @@ def __init__(
control: Optional[NetworkModel] = None,
input_prompt: Optional[str] = None,
input_negative_prompt: Optional[str] = None,
loopback: int = 0,
) -> None:
self.model = model
self.pipeline = pipeline
Expand All @@ -205,6 +207,7 @@ def __init__(
self.control = control
self.input_prompt = input_prompt or prompt
self.input_negative_prompt = input_negative_prompt or negative_prompt
self.loopback = loopback

def lpw(self):
return self.pipeline == "lpw"
Expand All @@ -224,6 +227,7 @@ def tojson(self) -> Dict[str, Optional[Param]]:
"control": self.control.name if self.control is not None else "",
"input_prompt": self.input_prompt,
"input_negative_prompt": self.input_negative_prompt,
"loopback": self.loopback,
}

def with_args(self, **kwargs):
Expand All @@ -241,6 +245,7 @@ def with_args(self, **kwargs):
kwargs.get("control", self.control),
kwargs.get("input_prompt", self.input_prompt),
kwargs.get("input_negative_prompt", self.input_negative_prompt),
kwargs.get("loopback", self.loopback),
)


Expand Down
8 changes: 8 additions & 0 deletions api/onnx_web/server/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,13 @@ def pipeline_from_request(
get_config_value("eta", "max"),
get_config_value("eta", "min"),
)
loopback = get_and_clamp_int(
request.args,
"loopback",
get_config_value("loopback"),
get_config_value("loopback", "max"),
get_config_value("loopback", "min"),
)
steps = get_and_clamp_int(
request.args,
"steps",
Expand Down Expand Up @@ -145,6 +152,7 @@ def pipeline_from_request(
negative_prompt=negative_prompt,
batch=batch,
control=control,
loopback=loopback,
)
size = Size(width, height)
return (device, params, size)
Expand Down
6 changes: 6 additions & 0 deletions api/params.json
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@
"max": 512,
"step": 8
},
"loopback": {
"default": 0,
"min": 0,
"max": 10,
"step": 1
},
"model": {
"default": "stable-diffusion-onnx-v1-5",
"keys": []
Expand Down
6 changes: 6 additions & 0 deletions gui/examples/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@
"max": 512,
"step": 8
},
"loopback": {
"default": 0,
"min": 0,
"max": 10,
"step": 1
},
"model": {
"default": "stable-diffusion-onnx-v1-5",
"keys": []
Expand Down
2 changes: 2 additions & 0 deletions gui/src/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ export interface Txt2ImgParams extends BaseImgParams {
export interface Img2ImgParams extends BaseImgParams {
source: Blob;

loopback: number;
sourceFilter?: string;
strength: number;
}
Expand Down Expand Up @@ -518,6 +519,7 @@ export function makeClient(root: string, f = fetch): ApiClient {
const url = makeImageURL(root, 'img2img', params);
appendModelToURL(url, model);

url.searchParams.append('loopback', params.loopback.toFixed(FIXED_INTEGER));
url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT));

if (doesExist(params.sourceFilter)) {
Expand Down
13 changes: 13 additions & 0 deletions gui/src/components/tab/Img2Img.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ export function Img2Img() {
const source = useStore(state, (s) => s.img2img.source);
const sourceFilter = useStore(state, (s) => s.img2img.sourceFilter);
const strength = useStore(state, (s) => s.img2img.strength);
const loopback = useStore(state, (s) => s.img2img.loopback);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setImg2Img = useStore(state, (s) => s.setImg2Img);
// eslint-disable-next-line @typescript-eslint/unbound-method
Expand Down Expand Up @@ -112,6 +113,18 @@ export function Img2Img() {
});
}}
/>
<NumericField
label={t('parameter.loopback')}
min={params.loopback.min}
max={params.loopback.max}
step={params.loopback.step}
value={loopback}
onChange={(value) => {
setImg2Img({
loopback: value,
});
}}
/>
</Stack>
<HighresControl />
<UpscaleControl />
Expand Down
2 changes: 2 additions & 0 deletions gui/src/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ export function createStateSlices(server: ServerParams) {
const createImg2ImgSlice: Slice<Img2ImgSlice> = (set) => ({
img2img: {
...base,
loopback: server.loopback.default,
source: null,
sourceFilter: '',
strength: server.strength.default,
Expand All @@ -273,6 +274,7 @@ export function createStateSlices(server: ServerParams) {
set({
img2img: {
...base,
loopback: server.loopback.default,
source: null,
sourceFilter: '',
strength: server.strength.default,
Expand Down

0 comments on commit 00fb64b

Please sign in to comment.