Skip to content

Commit

Permalink
added low-res-fix using tile-- let's see if it fixes or not
Browse files Browse the repository at this point in the history
  • Loading branch information
usamaehsan committed Nov 19, 2023
1 parent 8903b91 commit 802386b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 13 deletions.
70 changes: 57 additions & 13 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,17 @@ def setup(self):
torch_dtype=torch.float16,
local_files_only=True,
).to("cuda")

self.tile_pipe= StableDiffusionControlNetPipeline(
vae=self.pipe.vae,
text_encoder=self.pipe.text_encoder,
tokenizer=self.pipe.tokenizer,
unet=self.pipe.unet,
scheduler=self.pipe.scheduler,
safety_checker=self.pipe.safety_checker,
feature_extractor=self.pipe.feature_extractor,
controlnet=self.controlnets['tile'],
)

self.canny = CannyDetector()

Expand Down Expand Up @@ -224,6 +235,17 @@ def make_inpaint_condition(self, image, image_mask):

return image

def resize_for_condition_image(self, input_image: Image, resolution: int):
input_image = input_image.convert("RGB")
W, H = input_image.size
k = float(resolution) / min(H, W)
H *= k
W *= k
H = int(round(H / 64.0)) * 64
W = int(round(W / 64.0)) * 64
img = input_image.resize((W, H), resample=Image.LANCZOS)
return img

def build_pipe(
self, inputs, max_width, max_height, guess_mode=False
):
Expand All @@ -237,10 +259,9 @@ def build_pipe(
init_image= None
got_size= False
for name, [image, conditioning_scale, mask_image] in inputs.items():
print(name)
if image is None:
continue

print(name)
image = Image.open(image)
if not got_size:
image= resize_image(image, max_width, max_height)
Expand Down Expand Up @@ -403,6 +424,15 @@ def predict(
sorted_controlnets: str = Input(
description="Comma seperated string of controlnet names, list of names: tile, inpainting, lineart,depth ,scribble , brightness /// example value: tile, inpainting, lineart ", default="tile, inpainting, lineart"
),
low_res_fix: bool = Input(
description="Using controlnet tile- after image generation", default=False
),
low_res_fix_resolution: int = Input(
description="controlnet tile resolution- after image generation", default=768
),
low_res_fix_steps: int = Input(
description="controlnet tile resolution- after image generation", default=10
),
) -> List[Path]:
if len(MISSING_WEIGHTS) > 0:
raise Exception("missing weights")
Expand Down Expand Up @@ -454,18 +484,32 @@ def predict(

if output.nsfw_content_detected and output.nsfw_content_detected[0]:
continue

output_path = f"/tmp/seed-{this_seed}.png"
# if consistency_decoder:
# print("Running consistency decoder...")
# start = time.time()
# sample = self.consistency_decoder(
# output.images[0].unsqueeze(0) / self.pipe.vae.config.scaling_factor
# )
# print("Consistency decoder took", time.time() - start, "seconds")
# save_image(sample, output_path)
# else:
output.images[0].save(output_path)
if low_res_fix:
seed= torch.manual_seed(0)
condition_image = self.resize_for_condition_image(output.images[0], low_res_fix_resolution)
tile_output= self.tile_pipe(
prompt="best quality",
negative_prompt="blur, lowres, bad anatomy, bad hands, cropped, worst quality",
num_inference_steps= low_res_fix_steps,
width=condition_image.size[0],
height=condition_image.size[1],
controlnet_conditioning_scale=1.0,
generator=seed,
)
tile_output.images[0].save(output_path)
else:
# if consistency_decoder:
# print("Running consistency decoder...")
# start = time.time()
# sample = self.consistency_decoder(
# output.images[0].unsqueeze(0) / self.pipe.vae.config.scaling_factor
# )
# print("Consistency decoder took", time.time() - start, "seconds")
# save_image(sample, output_path)
# else:
output.images[0].save(output_path)

output_paths.append(Path(output_path))

Expand Down
1 change: 1 addition & 0 deletions script/download_weights
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ for name, model in AUX_IDS.items():
aux = ControlNetModel.from_pretrained(
model,
cache_dir=TMP_CACHE,
torch_dtype=torch.float16
)
aux = aux.half()
aux.save_pretrained(os.path.join(CONTROLNET_CACHE, name))
Expand Down

0 comments on commit 802386b

Please sign in to comment.