Skip to content

Commit

Permalink
predict.py works
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxkib committed Apr 22, 2024
1 parent ac60ab7 commit e7ea6b0
Showing 1 changed file with 60 additions and 2 deletions.
62 changes: 60 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

from cog import BasePredictor, Input, Path
from typing import List
from PIL import Image
Expand All @@ -9,7 +10,8 @@
import torch
import torch.cuda.amp as amp
import torchvision.transforms as T

import subprocess
import time

sys.path.append("flashface/all_finetune")
from config import cfg
Expand Down Expand Up @@ -44,14 +46,35 @@

retinaface = retinaface(pretrained=True, device="cuda").eval().requires_grad_(False)

def download_weights(url, dest):
start = time.time()
print("[!] Initiating download from URL: ", url)
print("[~] Destination path: ", dest)
command = ["pget", "-vf", url, dest]
if ".tar" in url:
command.append("-x")
try:
subprocess.check_call(command, close_fds=False)
except subprocess.CalledProcessError as e:
print(
f"[ERROR] Failed to download weights. Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}."
)
raise
print("[+] Download completed in: ", time.time() - start, "seconds")


def detect_face(imgs=None):

# read images
pil_imgs = imgs
b = len(pil_imgs)
vis_pil_imgs = copy.deepcopy(pil_imgs)

# convert RGBA images to RGB
for i in range(b):
if pil_imgs[i].mode == 'RGBA':
pil_imgs[i] = pil_imgs[i].convert('RGB')

# detection
imgs = torch.stack([retinaface_transforms(u) for u in pil_imgs]).to(gpu)
boxes, kpts = retinaface.detect(imgs, min_thr=0.6)
Expand Down Expand Up @@ -349,13 +372,26 @@ def predict(
description="Default negative prompt postfix",
default="blurry, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face",
),
output_format: str = Input(
description="Format of the output images",
choices=["webp", "jpg", "png"],
default="webp",
),
output_quality: int = Input(
description="Quality of the output images, from 0 to 100. 100 is best quality, 0 is lowest quality.",
default=80,
ge=0,
le=100,
),
) -> List[Path]:

reference_face_1 = Image.open(str(reference_face_1)) if reference_face_1 is not None else None
reference_face_2 = Image.open(str(reference_face_2)) if reference_face_2 is not None else None
reference_face_3 = Image.open(str(reference_face_3)) if reference_face_3 is not None else None
reference_face_4 = Image.open(str(reference_face_4)) if reference_face_4 is not None else None



print(f"[!] ({type(pos_prompt)}) pos_prompt={pos_prompt}")
print(f"[!] ({type(neg_prompt)}) neg_prompt={neg_prompt}")
print(f"[!] ({type(steps)}) steps={steps}")
Expand All @@ -373,7 +409,7 @@ def predict(
print(f"[!] ({type(default_pos_prompt)}) default_pos_prompt={default_pos_prompt}")
print(f"[!] ({type(default_neg_prompt)}) default_neg_prompt={default_neg_prompt}")

return generate(
faces = generate(
pos_prompt=pos_prompt,
neg_prompt=neg_prompt,
steps=steps,
Expand All @@ -391,3 +427,25 @@ def predict(
default_pos_prompt=default_pos_prompt,
default_neg_prompt=default_neg_prompt,
)

saved_image_paths = []
for index, face in enumerate(faces):
extension = output_format.lower()
extension = "jpeg" if extension == "jpg" else extension
image_path = f"./image_{index}.{extension}" # Adjusted path format to include variable extension

print(f"[~] Saving to {image_path}...")
print(f"[~] Output format: {extension.upper()}")
if output_format != "png":
print(f"[~] Output quality: {output_quality}")

save_params = {"format": extension.upper()}
if output_format != "png":
save_params["quality"] = output_quality
save_params["optimize"] = True

face.save(image_path, **save_params)
print(f"Saved image {index} at {image_path}")
saved_image_paths.append(image_path)

return [Path(f) for f in saved_image_paths]

0 comments on commit e7ea6b0

Please sign in to comment.