In [1]:
import sys
import os
# Make the repository root importable using os.path operations (portable on
# Windows). The notebook lives in flashface/all_finetune, so the root is two
# directory levels above the current working directory.
current_dir = os.path.abspath('')
package_dir = os.path.dirname(os.path.dirname(current_dir))
# Keep package_dir as the first sys.path entry, but only once: re-running this
# cell must not pile up duplicate entries.
sys.path[:] = [package_dir] + [entry for entry in sys.path if entry != package_dir]

# Diagnostics to confirm the path setup before the project imports run.
print(f"Current working directory: {os.getcwd()}")
print(f"Package directory: {package_dir}")
print(f"Python path: {sys.path[:3]}")
print(f"Contents of package_dir: {os.listdir(package_dir) if os.path.exists(package_dir) else 'Directory not found'}")

import copy
import random
import numpy as np

import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torchvision.transforms as T

from config import cfg
from models import sd_v1_ref_unet
from ops.context_diffusion import ContextGaussianDiffusion
from ldm import data, models, ops
from ldm.models.vae import sd_v1_vae
from ldm.utils import load_model_weights
from utils import Compose, PadToSquare, get_padding, seed_everything
from ldm.models.retinaface import retinaface as retinaface_func, crop_face

# Import functions from demo_gradio.py
from demo_gradio import generate, encode_text
from enhanced_transforms import EnhancedPadToSquare, create_face_transforms

from PIL import Image, ImageDraw

Current working directory: x:\dev\ComfyUI-FlashFace\flashface\all_finetune
Package directory: x:\dev\ComfyUI-FlashFace
Python path: ['x:\\dev\\ComfyUI-FlashFace', 'C:\\Python\\Python13\\python313.zip', 'C:\\Python\\Python13\\DLLs']
Contents of package_dir: ['.git', '.github', '.gitignore', '.venv', 'cache', 'docs', 'example_workflows', 'figs', 'flashface', 'install_dependencies.py', 'ldm', 'LICENSE', 'nodes', 'pyproject.toml', 'readme.md', 'requirements-comfy.txt', 'requirements.txt', 'setup.bat', 'setup.sh', '__init__.py']


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Notebook configuration -------------------------------------------------
SKIP_LOAD = False  # Set to True to skip loading models that aren't downloaded yet
DEBUG_VIEW = False
# Misspelled legacy alias of SKIP_LOAD. The model-loading guard further down
# still reads this name, so it is kept and tied to SKIP_LOAD to stop the two
# flags from drifting apart.
SKEP_LOAD = SKIP_LOAD
LOAD_FLAG = True  # when True, the checkpoint at weight_path is loaded into the UNet
DEFAULT_INPUT_IMAGES = 4
MAX_INPUT_IMAGES = 4
SIZE = 768
with_lora = False
enable_encoder = False  # forwarded to sd_v1_ref_unet(enable_encoder=...)
with_pos_mask = True

# model path (built with os.path, like the path setup in the first cell)
weight_path = os.path.join(package_dir, 'cache', 'flashface.ckpt')

# Detect available device
gpu = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {gpu}")

# Use enhanced transforms for better image quality
padding_to_square = EnhancedPadToSquare(224, padding_mode='constant', fill=0)

# Use enhanced transforms for retinaface with better anti-aliasing.
# detect_face() assumes this 640-pixel square when it undoes padding/scaling.
retinaface_transforms = T.Compose([
    EnhancedPadToSquare(size=640, padding_mode='constant', fill=0),
    T.ToTensor()
])

# Build the frozen face detector. The factory is already imported in the first
# cell as `retinaface_func`, so it is used directly instead of re-importing
# `retinaface` here only to overwrite that name with the model instance.
retinaface = retinaface_func(pretrained=True,
                             device=gpu).eval().requires_grad_(False)

# Initialize face transforms (needed for the generate function)
face_transforms = create_face_transforms(size=224)


def detect_face(imgs=None):
    """Detect and crop exactly one face from each input image.

    Each image is run through ``retinaface_transforms`` (padded to a 640-pixel
    square), the whole batch is passed to ``retinaface.detect``, and the
    returned boxes/keypoints are mapped back to the original image coordinates
    before ``crop_face`` is called.

    Args:
        imgs: sequence of PIL images (the code reads ``.size``, ``.width`` and
            ``.height``). ``None`` or an empty sequence yields an empty list.

    Returns:
        list: the face crops returned by ``crop_face``, one per input image,
        in input order.

    Raises:
        ValueError: if an image does not yield exactly one face crop.
    """
    # Nothing to detect: return early instead of failing on len(None) or on
    # torch.stack of an empty list.
    if imgs is None or len(imgs) == 0:
        return []
    pil_imgs = list(imgs)

    # detection on the padded 640x640 batch
    batch = torch.stack([retinaface_transforms(img) for img in pil_imgs]).to(gpu)
    boxes, kpts = retinaface.detect(batch, min_thr=0.6)

    face_imgs = []
    for i, pil_img in enumerate(pil_imgs):
        # Same geometry the transform applied: the longer side was scaled to
        # 640, then the image was padded to a square.
        scale = 640 / max(pil_img.size)
        left, top, _, _ = get_padding(round(scale * pil_img.width),
                                      round(scale * pil_img.height), 640)

        # undo padding (x coordinates shift by `left`, y coordinates by `top`)
        boxes[i][:, [0, 2]] -= left
        boxes[i][:, [1, 3]] -= top
        kpts[i][:, :, 0] -= left
        kpts[i][:, :, 1] -= top

        # undo scaling
        boxes[i][:, :4] /= scale
        kpts[i][:, :, :2] /= scale

        # crop faces; every reference image must contain exactly one face
        crops = crop_face(pil_img, boxes[i], kpts[i])
        if len(crops) != 1:
            raise ValueError(
                f'Expected exactly 1 face in image {i}, but detected {len(crops)}')

        face_imgs += crops

    return face_imgs


# Wrapper function to handle reference_faces list
def generate_with_faces(pos_prompt, neg_prompt=None, steps=30, face_bbox=None,
                        lamda_feat=1.2, face_guidence=3.2, num_sample=1, text_control_scale=7.5,
                        seed=-1, step_to_launch_face_guidence=750, reference_faces=None,
                        need_detect=True, lamda_feat_before_ref_guidence=0.85):
    """Call ``generate`` with a list of reference faces instead of four slots.

    ``generate`` takes ``reference_face_1`` .. ``reference_face_4``; this
    wrapper spreads ``reference_faces`` over those slots, filling unused slots
    with ``None``. Entries beyond the fourth are ignored. All other arguments
    are forwarded unchanged.

    Args:
        face_bbox: forwarded to ``generate``. When omitted (``None``), a fresh
            ``[0.3, 0.1, 0.6, 0.4]`` list is created for every call so that no
            default list object is shared between calls.
        reference_faces: optional sequence of reference images; ``None`` or an
            empty sequence means no reference faces.

    Returns:
        Whatever ``generate`` returns.

    Note:
        Relies on the notebook globals ``generate``, ``clip``,
        ``clip_tokenizer``, ``autoencoder``, ``unet`` and ``diffusion``; if
        the model-loading block was skipped these are undefined and the call
        raises ``NameError``.
    """
    if face_bbox is None:
        face_bbox = [0.3, 0.1, 0.6, 0.4]

    # Pad with None up to the four slots, then cut to exactly four entries.
    faces = list(reference_faces) if reference_faces else []
    ref_face_1, ref_face_2, ref_face_3, ref_face_4 = (faces + [None] * 4)[:4]

    return generate(pos_prompt=pos_prompt, neg_prompt=neg_prompt, steps=steps,
                    face_bbox=face_bbox, lamda_feat=lamda_feat, face_guidence=face_guidence,
                    num_sample=num_sample, text_control_scale=text_control_scale, seed=seed,
                    step_to_launch_face_guidence=step_to_launch_face_guidence,
                    reference_face_1=ref_face_1, reference_face_2=ref_face_2,
                    reference_face_3=ref_face_3, reference_face_4=ref_face_4,
                    need_detect=need_detect, lamda_feat_before_ref_guidence=lamda_feat_before_ref_guidence,
                    clip_model=clip, clip_tokenizer=clip_tokenizer,
                    autoencoder_model=autoencoder, unet_model=unet, diffusion_model=diffusion)


# Build every model object that generate_with_faces() forwards to generate().
# The whole block is skipped when any of DEBUG_VIEW / SKEP_LOAD / SKIP_LOAD is
# set; in that case this cell does not define clip, clip_tokenizer,
# autoencoder, unet or diffusion, and generate_with_faces() will fail with
# NameError unless they were defined some other way.
if not DEBUG_VIEW and not SKEP_LOAD and not SKIP_LOAD:
    clip_tokenizer = data.CLIPTokenizer(padding='eos')
    # Instantiate the model class named by cfg.clip_model, freeze it, and keep
    # only its `.textual` sub-module (presumably the text encoder -- confirm).
    clip = getattr(models, cfg.clip_model)(
        pretrained=True).eval().requires_grad_(False).textual.to(gpu)
    # Frozen VAE in eval mode.
    autoencoder = sd_v1_vae(
        pretrained=True).eval().requires_grad_(False).to(gpu)

    unet = sd_v1_ref_unet(pretrained=True,
                          version='sd-v1-5_nonema',
                          enable_encoder=enable_encoder).to(gpu)

    # NOTE(review): replace_input_conv() is defined outside this notebook; it
    # presumably swaps the first conv layer -- confirm. The second .to(gpu)
    # below runs after it, so anything it created ends up on the device too.
    unet.replace_input_conv()
    unet = unet.eval().requires_grad_(False).to(gpu)
    unet.share_cache['num_pairs'] = cfg.num_pairs

    if LOAD_FLAG:
        # Weights are read onto the CPU and then copied into the UNet;
        # strict=True makes missing or unexpected keys an error.
        model_weight = load_model_weights(weight_path, device='cpu')
        msg = unet.load_state_dict(model_weight, strict=True)
        print(msg)

    # diffusion
    # `ops` here is ldm.ops (bound by `from ldm import data, models, ops`).
    sigmas = ops.noise_schedule(schedule=cfg.schedule,
                                n=cfg.num_timesteps,
                                beta_min=cfg.scale_min,
                                beta_max=cfg.scale_max)
    diffusion = ContextGaussianDiffusion(sigmas=sigmas,
                                         prediction_type=cfg.prediction_type)
    diffusion.num_pairs = cfg.num_pairs
    print("model initialized")

Using device: cpu
<All keys matched successfully>
<All keys matched successfully>
model initialized
model initialized


In [9]:
# Development helper: re-execute edited project modules without restarting
# the kernel.
import importlib
import demo_gradio
import flashface.all_finetune.ops.context_diffusion
# NOTE(review): importlib.reload() re-runs the module code but does not update
# objects created earlier, e.g. the `diffusion` instance built in the model
# loading cell keeps its original class. Also, the first cell imports
# ContextGaussianDiffusion via `ops.context_diffusion`; if that is a separate
# sys.modules entry from `flashface.all_finetune.ops.context_diffusion`, this
# reload does not refresh it -- confirm which import path demo_gradio uses.
importlib.reload(flashface.all_finetune.ops.context_diffusion)
importlib.reload(demo_gradio)
# Re-bind the names so later cells call the reloaded functions.
from demo_gradio import generate, encode_text

In [11]:
# Demo: three reference photos from the "age" example set.
# NOTE(review): this cell is a debugging variant -- face guidance is switched
# off below, so it does not use the recommended ID-fidelity settings.
face_imgs = [Image.open(f"{package_dir}/example_workflows/age/{i+1}.png").convert("RGB") for i in range(3)]
need_detect = True
pos_prompt = 'A beautiful young asian woman, in a traditional chinese outfit, long hair, complete with a classic hairpin, on the street , white skin, soft light'
num_samples = 4
# center face position
face_bbox = [0.3, 0.2, 0.6, 0.5]
# Larger values of these three parameters give more fidelity but less diversity.
lamda_feat = 1.2
face_guidence = 0.0  # Temporarily disable face guidance to test basic generation
step_to_launch_face_guidence = 750

steps = 25
default_text_control_scale = 7.5

default_seed = 0


imgs = generate_with_faces(
    pos_prompt=pos_prompt,
    neg_prompt=None,
    steps=steps,
    face_bbox=face_bbox,
    lamda_feat=lamda_feat,
    face_guidence=face_guidence,
    num_sample=num_samples,
    text_control_scale=default_text_control_scale,
    seed=default_seed,
    step_to_launch_face_guidence=step_to_launch_face_guidence,
    reference_faces=face_imgs,
    need_detect=need_detect,
)


# Contact sheet: tile 0 is reserved for the reference faces, tiles 1..N hold
# the generated images side by side.
img_size = imgs[0].size
num_imgs = len(imgs)
save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))
for i, img in enumerate(imgs):
    save_img.paste(img, ((i + 1) * img_size[0], 0))

# Paste the reference faces into tile 0 as a 2x2 grid of half-size cells
# (the layout has room for at most four references).
resize_w = img_size[0] // 2
resize_h = img_size[1] // 2

# `ref_idx` instead of `id`, so the builtin id() is not overwritten in the
# notebook namespace.
for ref_idx, ref_img in enumerate(face_imgs):
    # Resize the reference, keeping its aspect ratio, to fit (resize_w, resize_h).
    w_ratio = resize_w / ref_img.size[0]
    h_ratio = resize_h / ref_img.size[1]
    ratio = min(w_ratio, h_ratio)
    ref_img = ref_img.resize(
        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))

    if ref_idx < 2:
        # top row
        save_img.paste(ref_img, (ref_idx * resize_w, 0))
    else:
        # bottom row
        save_img.paste(ref_img, ((ref_idx - 2) * resize_w, resize_h))

display(save_img)



0
final pos_prompt:  A beautiful young asian woman, in a traditional chinese outfit, long hair, complete with a classic hairpin, on the street , white skin, soft light, best quality, masterpiece,ultra-detailed, UHD 4K, photographic
final neg_prompt:  blurry, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face
detected 3 faces
[0.3, 0.2, 0.6, 0.5]
detected 3 faces
[0.3, 0.2, 0.6, 0.5]


TypeError: RefStableUNet._patch_attention_forward.<locals>.patched_forward() takes from 2 to 4 positional arguments but 6 were given

In [None]:
# Recommended hyper-parameters to obtain stable ID Fidelity
# Demo: three reference photos from the "age" example set.
face_imgs = [Image.open(f"{package_dir}/example_workflows/age/{i+1}.png").convert("RGB") for i in range(3)]
need_detect = True
pos_prompt = 'A very  old woman with short wavy hair'
num_samples = 4
# center face position
face_bbox = [0.3, 0.1, 0.6, 0.4]
# Larger values of these three parameters give more fidelity but less diversity.
lamda_feat = 1
face_guidence = 2.5
step_to_launch_face_guidence = 750

steps = 25
default_text_control_scale = 8.5

default_seed = 0


imgs = generate_with_faces(
    pos_prompt=pos_prompt,
    neg_prompt=None,
    steps=steps,
    face_bbox=face_bbox,
    lamda_feat=lamda_feat,
    face_guidence=face_guidence,
    num_sample=num_samples,
    text_control_scale=default_text_control_scale,
    seed=default_seed,
    step_to_launch_face_guidence=step_to_launch_face_guidence,
    reference_faces=face_imgs,
    need_detect=need_detect,
)


# Contact sheet: tile 0 is reserved for the reference faces, tiles 1..N hold
# the generated images side by side.
img_size = imgs[0].size
num_imgs = len(imgs)
save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))
for i, img in enumerate(imgs):
    save_img.paste(img, ((i + 1) * img_size[0], 0))

# Paste the reference faces into tile 0 as a 2x2 grid of half-size cells
# (the layout has room for at most four references).
resize_w = img_size[0] // 2
resize_h = img_size[1] // 2

# `ref_idx` instead of `id`, so the builtin id() is not overwritten in the
# notebook namespace.
for ref_idx, ref_img in enumerate(face_imgs):
    # Resize the reference, keeping its aspect ratio, to fit (resize_w, resize_h).
    w_ratio = resize_w / ref_img.size[0]
    h_ratio = resize_h / ref_img.size[1]
    ratio = min(w_ratio, h_ratio)
    ref_img = ref_img.resize(
        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))

    if ref_idx < 2:
        # top row
        save_img.paste(ref_img, (ref_idx * resize_w, 0))
    else:
        # bottom row
        save_img.paste(ref_img, ((ref_idx - 2) * resize_w, resize_h))

display(save_img)



In [None]:
# Demo: three reference photos from the "age" example set.
face_imgs = [
    Image.open(f"{package_dir}/example_workflows/age/{i+1}.png").convert("RGB") for i in range(3)
]
need_detect = True

pos_prompt = """The cute, beautiful baby girl with medium length brown hair and  pink bow, in the studio """
# No negative prompt for this example.
neg_prompt = None
# Explicit target face box (not the all-zero "no face position" box).
face_bbox = [0.3, 0.2, 0.6, 0.6]


# Larger values of these three parameters give more fidelity but less diversity.
lamda_feat = 1.2
face_guidence = 2
step_to_launch_face_guidence = 700

steps = 50
default_text_control_scale = 7.5

default_seed = 0


imgs = generate_with_faces(
    pos_prompt=pos_prompt,
    neg_prompt=neg_prompt,
    steps=steps,
    face_bbox=face_bbox,
    lamda_feat=lamda_feat,
    face_guidence=face_guidence,
    num_sample=4,
    text_control_scale=default_text_control_scale,
    seed=default_seed,
    step_to_launch_face_guidence=step_to_launch_face_guidence,
    reference_faces=face_imgs,
    need_detect=need_detect,
)


# Contact sheet: tile 0 is reserved for the reference faces, tiles 1..N hold
# the generated images side by side.
img_size = imgs[0].size
num_imgs = len(imgs)
save_img = Image.new("RGB", (img_size[0] * (num_imgs + 1), img_size[1]))
for i, img in enumerate(imgs):
    save_img.paste(img, ((i + 1) * img_size[0], 0))

# Paste the reference faces into tile 0 as a 2x2 grid of half-size cells
# (the layout has room for at most four references).
resize_w = img_size[0] // 2
resize_h = img_size[1] // 2

# `ref_idx` instead of `id`, so the builtin id() is not overwritten in the
# notebook namespace.
for ref_idx, ref_img in enumerate(face_imgs):
    # Resize the reference, keeping its aspect ratio, to fit (resize_w, resize_h).
    w_ratio = resize_w / ref_img.size[0]
    h_ratio = resize_h / ref_img.size[1]
    ratio = min(w_ratio, h_ratio)
    ref_img = ref_img.resize(
        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio))
    )

    if ref_idx < 2:
        # top row
        save_img.paste(ref_img, (ref_idx * resize_w, 0))
    else:
        # bottom row
        save_img.paste(ref_img, ((ref_idx - 2) * resize_w, resize_h))

display(save_img)

In [None]:
# Demo: a single reference image (avatar.png).
face_imgs = [Image.open(f"{package_dir}/example_workflows/avatar.png").convert("RGB")]
need_detect = True
pos_prompt = "A handsome young man with long brown hair is sitting in the desert"
num_samples = 2
# No face position
face_bbox = [0., 0., 0., 0.]
# Larger values of these three parameters give more fidelity but less diversity.
lamda_feat = 0.9
face_guidence = 2.5
step_to_launch_face_guidence = 700

steps = 50
default_text_control_scale = 7.5

default_seed = 0


imgs = generate_with_faces(
    pos_prompt=pos_prompt,
    neg_prompt=None,
    steps=steps,
    face_bbox=face_bbox,
    lamda_feat=lamda_feat,
    face_guidence=face_guidence,
    num_sample=num_samples,
    text_control_scale=default_text_control_scale,
    seed=default_seed,
    step_to_launch_face_guidence=step_to_launch_face_guidence,
    reference_faces=face_imgs,
    need_detect=need_detect,
)


# Contact sheet: tile 0 is reserved for the reference faces, tiles 1..N hold
# the generated images side by side.
img_size = imgs[0].size
num_imgs = len(imgs)
save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))
for i, img in enumerate(imgs):
    save_img.paste(img, ((i + 1) * img_size[0], 0))

# Paste the reference faces into tile 0 as a 2x2 grid of half-size cells
# (the layout has room for at most four references).
resize_w = img_size[0] // 2
resize_h = img_size[1] // 2

# `ref_idx` instead of `id`, so the builtin id() is not overwritten in the
# notebook namespace.
for ref_idx, ref_img in enumerate(face_imgs):
    # Resize the reference, keeping its aspect ratio, to fit (resize_w, resize_h).
    w_ratio = resize_w / ref_img.size[0]
    h_ratio = resize_h / ref_img.size[1]
    ratio = min(w_ratio, h_ratio)
    ref_img = ref_img.resize(
        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))

    if ref_idx < 2:
        # top row
        save_img.paste(ref_img, (ref_idx * resize_w, 0))
    else:
        # bottom row
        save_img.paste(ref_img, ((ref_idx - 2) * resize_w, resize_h))

display(save_img)

In [None]:
# Demo: a single reference image (snow_white.png).
face_imgs = [Image.open(f"{package_dir}/example_workflows/snow_white.png").convert("RGB")]
need_detect = True
pos_prompt = "Full body photo of a beautiful young women sitting in the office, medium length wavy hair, wearinig red bow hairpin on the top of head"
num_samples = 2
# No face position
face_bbox = [0., 0., 0., 0.]
# Larger values of these three parameters give more fidelity but less diversity.
lamda_feat = 1
face_guidence = 2
step_to_launch_face_guidence = 600

steps = 50
default_text_control_scale = 7.5

default_seed = 0


imgs = generate_with_faces(
    pos_prompt=pos_prompt,
    neg_prompt=None,
    steps=steps,
    face_bbox=face_bbox,
    lamda_feat=lamda_feat,
    face_guidence=face_guidence,
    num_sample=num_samples,
    text_control_scale=default_text_control_scale,
    seed=default_seed,
    step_to_launch_face_guidence=step_to_launch_face_guidence,
    reference_faces=face_imgs,
    need_detect=need_detect,
)


# Contact sheet: tile 0 is reserved for the reference faces, tiles 1..N hold
# the generated images side by side.
img_size = imgs[0].size
num_imgs = len(imgs)
save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))
for i, img in enumerate(imgs):
    save_img.paste(img, ((i + 1) * img_size[0], 0))

# Paste the reference faces into tile 0 as a 2x2 grid of half-size cells
# (the layout has room for at most four references).
resize_w = img_size[0] // 2
resize_h = img_size[1] // 2

# `ref_idx` instead of `id`, so the builtin id() is not overwritten in the
# notebook namespace.
for ref_idx, ref_img in enumerate(face_imgs):
    # Resize the reference, keeping its aspect ratio, to fit (resize_w, resize_h).
    w_ratio = resize_w / ref_img.size[0]
    h_ratio = resize_h / ref_img.size[1]
    ratio = min(w_ratio, h_ratio)
    ref_img = ref_img.resize(
        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))

    if ref_idx < 2:
        # top row
        save_img.paste(ref_img, (ref_idx * resize_w, 0))
    else:
        # bottom row
        save_img.paste(ref_img, ((ref_idx - 2) * resize_w, resize_h))

display(save_img)

In [None]:
# Demo: a single reference image (eren_jaeger.png).

face_imgs = [Image.open(f"{package_dir}/example_workflows/eren_jaeger.png").convert("RGB")]
need_detect = True
pos_prompt =  "A handsome, attractive, sleek young man sitting on the beach, wearing black long trench coat, man bun hair,  heavily clouded, sunset, sea in the background"
# remove beard
neg_prompt = "beard"
# No face position
face_bbox = [0., 0., 0., 0.]
# Larger values of these three parameters give more fidelity but less diversity.
lamda_feat = 0.9
face_guidence = 2
step_to_launch_face_guidence = 600
num_samples = 2
steps = 50
default_text_control_scale = 7.5

default_seed = 0


imgs = generate_with_faces(
    pos_prompt=pos_prompt,
    neg_prompt=neg_prompt,
    steps=steps,
    face_bbox=face_bbox,
    lamda_feat=lamda_feat,
    face_guidence=face_guidence,
    num_sample=num_samples,
    text_control_scale=default_text_control_scale,
    seed=default_seed,
    step_to_launch_face_guidence=step_to_launch_face_guidence,
    reference_faces=face_imgs,
    need_detect=need_detect,
)


# Contact sheet: tile 0 is reserved for the reference faces, tiles 1..N hold
# the generated images side by side.
img_size = imgs[0].size
num_imgs = len(imgs)
save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))
for i, img in enumerate(imgs):
    save_img.paste(img, ((i + 1) * img_size[0], 0))

# Paste the reference faces into tile 0 as a 2x2 grid of half-size cells
# (the layout has room for at most four references).
resize_w = img_size[0] // 2
resize_h = img_size[1] // 2

# `ref_idx` instead of `id`, so the builtin id() is not overwritten in the
# notebook namespace.
for ref_idx, ref_img in enumerate(face_imgs):
    # Resize the reference, keeping its aspect ratio, to fit (resize_w, resize_h).
    w_ratio = resize_w / ref_img.size[0]
    h_ratio = resize_h / ref_img.size[1]
    ratio = min(w_ratio, h_ratio)
    ref_img = ref_img.resize(
        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))

    if ref_idx < 2:
        # top row
        save_img.paste(ref_img, (ref_idx * resize_w, 0))
    else:
        # bottom row
        save_img.paste(ref_img, ((ref_idx - 2) * resize_w, resize_h))

display(save_img)

In [None]:
# ordinary people: four reference photos from the man_face example set

face_imgs = [Image.open(f"{package_dir}/example_workflows/man_face/{i+1}.png").convert("RGB") for i in range(4)]
need_detect = True
pos_prompt =  "An handsome young man, with cowboy hat, long hair, full body, standing in the forest, sunset"
# remove beard
neg_prompt = "beard"
# No face position
face_bbox = [0., 0., 0., 0.]
# Larger values of these three parameters give more fidelity but less diversity.
lamda_feat = 0.85
face_guidence = 2
step_to_launch_face_guidence = 600
num_samples = 2
steps = 50
default_text_control_scale = 7.5

default_seed = 0


imgs = generate_with_faces(
    pos_prompt=pos_prompt,
    neg_prompt=neg_prompt,
    steps=steps,
    face_bbox=face_bbox,
    lamda_feat=lamda_feat,
    face_guidence=face_guidence,
    num_sample=num_samples,
    text_control_scale=default_text_control_scale,
    seed=default_seed,
    step_to_launch_face_guidence=step_to_launch_face_guidence,
    reference_faces=face_imgs,
    need_detect=need_detect,
)


# Contact sheet: tile 0 is reserved for the reference faces, tiles 1..N hold
# the generated images side by side.
img_size = imgs[0].size
num_imgs = len(imgs)
save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))
for i, img in enumerate(imgs):
    save_img.paste(img, ((i + 1) * img_size[0], 0))

# Paste the four reference faces into tile 0 as a 2x2 grid of half-size cells.
resize_w = img_size[0] // 2
resize_h = img_size[1] // 2

# `ref_idx` instead of `id`, so the builtin id() is not overwritten in the
# notebook namespace.
for ref_idx, ref_img in enumerate(face_imgs):
    # Resize the reference, keeping its aspect ratio, to fit (resize_w, resize_h).
    w_ratio = resize_w / ref_img.size[0]
    h_ratio = resize_h / ref_img.size[1]
    ratio = min(w_ratio, h_ratio)
    ref_img = ref_img.resize(
        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))

    if ref_idx < 2:
        # top row
        save_img.paste(ref_img, (ref_idx * resize_w, 0))
    else:
        # bottom row
        save_img.paste(ref_img, ((ref_idx - 2) * resize_w, resize_h))

display(save_img)

In [None]:
# ordinary people: four reference photos from the woman_face example set

face_imgs = [Image.open(f"{package_dir}/example_workflows/woman_face/{i+1}.png").convert("RGB") for i in range(4)]
need_detect = True
pos_prompt =  'A beautiful young woman with short curly hair in the garden holding a flower'
# No negative prompt for this example.
neg_prompt = None
# No face position
face_bbox = [0., 0., 0., 0.]
# Larger values of these three parameters give more fidelity but less diversity.
lamda_feat = 1
face_guidence = 2.3
step_to_launch_face_guidence = 600
num_samples = 2
steps = 50
default_text_control_scale = 7.5

default_seed = 0


imgs = generate_with_faces(
    pos_prompt=pos_prompt,
    neg_prompt=neg_prompt,
    steps=steps,
    face_bbox=face_bbox,
    lamda_feat=lamda_feat,
    face_guidence=face_guidence,
    num_sample=num_samples,
    text_control_scale=default_text_control_scale,
    seed=default_seed,
    step_to_launch_face_guidence=step_to_launch_face_guidence,
    reference_faces=face_imgs,
    need_detect=need_detect,
)


# Contact sheet: tile 0 is reserved for the reference faces, tiles 1..N hold
# the generated images side by side.
img_size = imgs[0].size
num_imgs = len(imgs)
save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))
for i, img in enumerate(imgs):
    save_img.paste(img, ((i + 1) * img_size[0], 0))

# Paste the four reference faces into tile 0 as a 2x2 grid of half-size cells.
resize_w = img_size[0] // 2
resize_h = img_size[1] // 2

# `ref_idx` instead of `id`, so the builtin id() is not overwritten in the
# notebook namespace.
for ref_idx, ref_img in enumerate(face_imgs):
    # Resize the reference, keeping its aspect ratio, to fit (resize_w, resize_h).
    w_ratio = resize_w / ref_img.size[0]
    h_ratio = resize_h / ref_img.size[1]
    ratio = min(w_ratio, h_ratio)
    ref_img = ref_img.resize(
        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))

    if ref_idx < 2:
        # top row
        save_img.paste(ref_img, (ref_idx * resize_w, 0))
    else:
        # bottom row
        save_img.paste(ref_img, ((ref_idx - 2) * resize_w, resize_h))

display(save_img)

In [None]:
# details: four reference photos from the details_face example set


face_imgs = [Image.open(f"{package_dir}/example_workflows/details_face/{i+1}.jpeg").convert("RGB") for i in range(4)]
need_detect = True
pos_prompt =  'A beautiful young woman stands in the street,  wearing earing and white skirt and  hat, thin body, sunny day'
# Negative prompt: steer away from bangs.
neg_prompt = 'Bangs'
# left top corner
face_bbox = [0.1, 0.1, 0.5, 0.5]

# Larger values of these three parameters give more fidelity but less diversity.
lamda_feat = 1.3
face_guidence = 3.2
step_to_launch_face_guidence = 800

steps = 50
default_text_control_scale = 8

default_seed = 0
num_samples = 2

imgs = generate_with_faces(
    pos_prompt=pos_prompt,
    neg_prompt=neg_prompt,
    steps=steps,
    face_bbox=face_bbox,
    lamda_feat=lamda_feat,
    face_guidence=face_guidence,
    num_sample=num_samples,
    text_control_scale=default_text_control_scale,
    seed=default_seed,
    step_to_launch_face_guidence=step_to_launch_face_guidence,
    reference_faces=face_imgs,
    need_detect=need_detect,
)


# Contact sheet: tile 0 is reserved for the reference faces, tiles 1..N hold
# the generated images side by side.
img_size = imgs[0].size
num_imgs = len(imgs)
save_img = Image.new('RGB', (img_size[0] * (num_imgs + 1), img_size[1]))
for i, img in enumerate(imgs):
    save_img.paste(img, ((i + 1) * img_size[0], 0))

# Paste the four reference faces into tile 0 as a 2x2 grid of half-size cells.
resize_w = img_size[0] // 2
resize_h = img_size[1] // 2

# `ref_idx` instead of `id`, so the builtin id() is not overwritten in the
# notebook namespace.
for ref_idx, ref_img in enumerate(face_imgs):
    # Resize the reference, keeping its aspect ratio, to fit (resize_w, resize_h).
    w_ratio = resize_w / ref_img.size[0]
    h_ratio = resize_h / ref_img.size[1]
    ratio = min(w_ratio, h_ratio)
    ref_img = ref_img.resize(
        (int(ref_img.size[0] * ratio), int(ref_img.size[1] * ratio)))

    if ref_idx < 2:
        # top row
        save_img.paste(ref_img, (ref_idx * resize_w, 0))
    else:
        # bottom row
        save_img.paste(ref_img, ((ref_idx - 2) * resize_w, resize_h))

display(save_img)