# Virtual Try On using Looky

This notebook demonstrates how to perform virtual try-on using Looky.


## Import libraries

Run the cells below to import the libraries.

In [None]:
import subprocess
from time import sleep
from typing import Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import supervision as sv
from dask.distributed import Client
from dask_jobqueue import PBSCluster
from PIL import Image
from skimage.morphology import remove_small_objects

## Initialize cluster

Run the cells below to initialize the cluster.

In [None]:
# `hostname -I` prints the addresses assigned to this host separated by
# whitespace; the first one is used as the Dask scheduler address.
parts = subprocess.check_output(["hostname", "-I"]).decode().split()
if not parts:
    raise RuntimeError(
        "`hostname -I` returned no address; cannot choose a scheduler host"
    )

cluster = PBSCluster(
    queue="gen_S",
    account="LOOKY",
    cores=48,
    memory="96GB",
    processes=1,
    # Remove the generated header lines containing these strings; the
    # replacement directives are supplied in `job_extra_directives` below.
    job_directives_skip=["select=", "walltime="],
    # NOTE(review): these look like NQSV-flavoured options (see NQSV_MPI_VER);
    # confirm against the site scheduler documentation.
    job_extra_directives=[
        "-V",
        "-l elapstim_req=00:30:00",
        "-T openmpi",
        "-v NQSV_MPI_VER=4.1.8/gcc11.4.0-cuda12.8.1",
    ],
    scheduler_options={"host": parts[0]},
    # `job_script_prologue` is the current name of the deprecated `env_extra`.
    # These commands run in the job script before the worker starts: the worker
    # moves to the parent of the submit directory (worker-side code loads
    # "./weights/..." relative to it) and Hugging Face libraries are put in
    # offline mode.
    job_script_prologue=[
        "cd $PBS_O_WORKDIR/..",
        "export HF_HUB_OFFLINE=1",
        "export HF_DATASETS_OFFLINE=1",
        "export TRANSFORMERS_OFFLINE=1",
    ],
    interface="eno1",
)

cluster.scale(jobs=1)
# Presumably gives the submitted job a moment to start before connecting.
sleep(10)

client = Client(cluster)
client

## Define functions

Run the cells below to define functions.

In [None]:
# Copied from transformers.models.detr.image_processing_detr.masks_to_boxes
def masks_to_boxes(masks: np.ndarray) -> np.ndarray:
    """
    Compute the tight bounding box of every mask in a stack.

    Args:
        masks: array in format `[number_masks, height, width]`; non-zero
            entries mark the pixels that belong to each mask.

    Returns:
        boxes: array in format `[number_masks, 4]` holding
            `(x_min, y_min, x_max, y_max)` per mask. A mask without any set
            pixel yields `x_min = y_min = 1e8` and `x_max = y_max = 0`. An
            input with no elements yields an array of shape `(0, 4)`.
    """
    if masks.size == 0:
        return np.zeros((0, 4))

    height, width = masks.shape[-2:]
    # see https://github.com/pytorch/pytorch/issues/50276
    row_grid, col_grid = np.meshgrid(
        np.arange(0, height, dtype=np.float32),
        np.arange(0, width, dtype=np.float32),
        indexing="ij",
    )
    # True wherever a pixel is NOT part of the mask.
    outside = ~(np.array(masks, dtype=bool))

    def _extent(grid):
        """Return (smallest, largest) coordinate of `grid` covered by each mask."""
        weighted = masks * np.expand_dims(grid, axis=0)
        count = weighted.shape[0]
        largest = weighted.reshape(count, -1).max(-1)
        # Pixels outside the mask are replaced by a huge value so that they
        # can never win the minimum.
        padded = np.ma.array(weighted, mask=outside).filled(fill_value=1e8)
        smallest = padded.reshape(padded.shape[0], -1).min(-1)
        return smallest, largest

    x_min, x_max = _extent(col_grid)
    y_min, y_max = _extent(row_grid)

    return np.stack([x_min, y_min, x_max, y_max], 1)


def get_upper_body_mask(
    segmentation: np.ndarray,
    keypoints: np.ndarray,
    scores: np.ndarray,
    target_size: Tuple[int, int],
):
    """
    Build the rectangular agnostic mask used for upper-body try-on.

    Two boxes are computed and merged: `box1` bounds the pixels labelled
    upper-clothes / dress / coat (ids 5, 6, 7) and `box2` bounds selected pose
    keypoints whose score exceeds 0.3. The merged box is padded, clipped to
    the image, filled with 255, and the hat / hair / sunglasses / face pixels
    (ids 1, 2, 4, 13) are cleared again.

    Args:
        segmentation: per-pixel label map.
        keypoints: pose keypoints; x and y are multiplied by the image width
            and height, so they are assumed to be normalised — TODO confirm.
            Rows are sliced as body `[:18]`, face `[24:92]`, one hand
            `[92:113]` and the other hand `[113:]`.
        scores: per-keypoint confidence, indexed like `keypoints`.
        target_size: `(height, width)` of the output mask.

    Returns:
        `(agnostic_mask, boxes)`: a PIL image built from a uint8 array
        (255 inside the region, 0 elsewhere) and an array stacking `box1` and
        `box2` in xyxy order.
    """
    height, width = target_size

    # Scale a private copy so that the caller's array is left untouched.
    scaled = keypoints.copy()
    scaled[..., 0] *= width
    scaled[..., 1] *= height

    torso, torso_conf = scaled[:18], scores[:18]
    face, face_conf = scaled[24:92], scores[24:92]
    right_hand, right_hand_conf = scaled[92:113], scores[92:113]
    left_hand, left_hand_conf = scaled[113:], scores[113:]

    def confident(points, confidences, indices):
        """Rows of `points` at `indices` whose confidence exceeds 0.3."""
        return [points[i] for i in indices if confidences[i] > 0.3]

    # upper-clothes / dress / coat
    garment = remove_small_objects(np.isin(segmentation, [5, 6, 7]), min_size=300)
    box1 = masks_to_boxes(garment[None, ...])[0]

    # Candidate coordinates for each edge of the keypoint box.
    left_edge = [p[0] for p in confident(torso, torso_conf, [2, 3, 4])]
    top_edge = [p[1] for p in confident(torso, torso_conf, [2, 5])]
    right_edge = [p[0] for p in confident(torso, torso_conf, [5, 6, 7])]
    bottom_edge = [p[1] for p in confident(torso, torso_conf, [8, 11])]

    # Hand points may extend the box in every direction.
    for hand, hand_conf in (
        (left_hand, left_hand_conf),
        (right_hand, right_hand_conf),
    ):
        for point in confident(hand, hand_conf, [5, 9, 13]):
            left_edge.append(point[0])
            right_edge.append(point[0])
            top_edge.append(point[1])
            bottom_edge.append(point[1])

    # Two face points can only raise the top edge.
    top_edge.extend(p[1] for p in confident(face, face_conf, [5, 11]))

    # An edge without any confident keypoint stays at +/- infinity.
    box2 = [
        min([float("inf")] + left_edge),
        min([float("inf")] + top_edge),
        max([float("-inf")] + right_edge),
        max([float("-inf")] + bottom_edge),
    ]

    # Union of both boxes, padded and clipped to the image.
    x_min = max(0, min(box1[0], box2[0]) - 20)
    y_min = max(0, min(box1[1], box2[1]) - 10)
    x_max = min(width, max(box1[2], box2[2]) + 25)
    y_max = max(box1[3], box2[3])
    # Bottom padding is only added when the garment box reaches at least as
    # low as the keypoint box.
    y_max = min(height, y_max) if box2[3] > box1[3] else min(height, y_max + 25)

    left, top, right, bottom = int(x_min), int(y_min), int(x_max), int(y_max)

    canvas = np.zeros((height, width), dtype=np.uint8)
    canvas[top:bottom, left:right] = 255

    # hat / hair / sunglasses / face
    canvas[np.isin(segmentation, [1, 2, 4, 13])] = 0

    return Image.fromarray(canvas), np.array([box1, box2])


def get_lower_body_mask(
    segmentation: np.ndarray,
    keypoints: np.ndarray,
    scores: np.ndarray,
    target_size: Tuple[int, int],
):
    """
    Build the rectangular agnostic mask used for lower-body try-on.

    Two boxes are computed and merged: `box1` bounds the pixels labelled
    dress / pants / skirt (ids 6, 9, 12) and `box2` bounds selected body and
    foot keypoints whose score exceeds 0.3. The merged box is padded, clipped
    to the image, filled with 255, and the shoe pixels (ids 18, 19) are
    cleared again.

    Args:
        segmentation: per-pixel label map.
        keypoints: pose keypoints; x and y are multiplied by the image width
            and height, so they are assumed to be normalised — TODO confirm.
            Rows are sliced as body `[:18]` and two feet `[18:21]`, `[21:24]`.
        scores: per-keypoint confidence, indexed like `keypoints`.
        target_size: `(height, width)` of the output mask.

    Returns:
        `(agnostic_mask, boxes)`: a PIL image built from a uint8 array
        (255 inside the region, 0 elsewhere) and an array stacking `box1` and
        `box2` in xyxy order.
    """
    height, width = target_size

    # Scale a private copy so that the caller's array is left untouched.
    keypoints = keypoints.copy()
    keypoints[..., 0] *= width
    keypoints[..., 1] *= height

    body_keypoints = keypoints[:18]
    body_scores = scores[:18]

    right_foot_keypoints = keypoints[18:21]
    right_foot_scores = scores[18:21]

    left_foot_keypoints = keypoints[21:24]
    left_foot_scores = scores[21:24]

    # dress / pants / skirt
    mask = np.isin(segmentation, [6, 9, 12])
    mask = remove_small_objects(mask, min_size=300)
    box1 = masks_to_boxes(mask[None, ...])[0]

    # An edge without any confident keypoint stays at +/- infinity.
    x_min = y_min = float("inf")
    x_max = y_max = float("-inf")

    # Presumably hips / knees / ankles in OpenPose-18 order — confirm against
    # the pose detector. They bound the box horizontally on both sides.
    for i in [8, 9, 10, 11, 12, 13]:
        if body_scores[i] > 0.3:
            x = body_keypoints[i, 0]
            x_min = min(x_min, x)
            x_max = max(x_max, x)

    # Top edge.
    for i in [8, 11]:
        if body_scores[i] > 0.3:
            y_min = min(y_min, body_keypoints[i, 1])

    # Bottom edge.
    for i in [10, 13]:
        if body_scores[i] > 0.3:
            y_max = max(y_max, body_keypoints[i, 1])

    # Every foot point widens the box.
    for i in [0, 1, 2]:
        if right_foot_scores[i] > 0.3:
            x = right_foot_keypoints[i, 0]
            x_min = min(x_min, x)
            x_max = max(x_max, x)

        if left_foot_scores[i] > 0.3:
            x = left_foot_keypoints[i, 0]
            x_min = min(x_min, x)
            x_max = max(x_max, x)

    # Only the last point of each foot can lower the bottom edge.
    if right_foot_scores[2] > 0.3:
        y_max = max(y_max, right_foot_keypoints[2, 1])
    if left_foot_scores[2] > 0.3:
        y_max = max(y_max, left_foot_keypoints[2, 1])

    box2 = [x_min, y_min, x_max, y_max]

    # Union of both boxes.
    x_min = min(box1[0], box2[0])
    y_min = min(box1[1], box2[1])
    x_max = max(box1[2], box2[2])
    y_max = max(box1[3], box2[3])

    # Pad and clip to the image. The top padding is larger when the garment
    # box starts above the keypoint box.
    x_min = max(0, x_min - 30)
    if box2[1] > box1[1]:
        y_min = max(0, y_min - 50)
    else:
        y_min = max(0, y_min - 5)
    x_max = min(width, x_max + 30)
    y_max = min(height, y_max + 10)

    x_min = int(x_min)
    y_min = int(y_min)
    x_max = int(x_max)
    y_max = int(y_max)

    agnostic_mask = np.zeros((height, width), dtype=np.uint8)
    agnostic_mask[y_min:y_max, x_min:x_max] = 255

    # left-shoe / right-shoe
    agnostic_mask[np.isin(segmentation, [18, 19])] = 0

    agnostic_mask = Image.fromarray(agnostic_mask)
    boxes = np.array([box1, box2])

    return agnostic_mask, boxes


def get_full_body_mask(
    segmentation: np.ndarray,
    keypoints: np.ndarray,
    scores: np.ndarray,
    target_size: Tuple[int, int],
):
    """
    Build the rectangular agnostic mask used for full-body try-on.

    Two boxes are computed and merged: `box1` bounds the pixels labelled
    upper-clothes / dress / coat / pants / skirt (ids 5, 6, 7, 9, 12) and
    `box2` bounds selected body, hand, foot and face keypoints whose score
    exceeds 0.3. The merged box is padded, clipped to the image, filled with
    255, and the hat / hair / sunglasses / face / shoe pixels
    (ids 1, 2, 4, 13, 18, 19) are cleared again.

    Args:
        segmentation: per-pixel label map.
        keypoints: pose keypoints; x and y are multiplied by the image width
            and height, so they are assumed to be normalised — TODO confirm.
            Rows are sliced as body `[:18]`, feet `[18:21]` / `[21:24]`,
            face `[24:92]` and hands `[92:113]` / `[113:]`.
        scores: per-keypoint confidence, indexed like `keypoints`.
        target_size: `(height, width)` of the output mask.

    Returns:
        `(agnostic_mask, boxes)`: a PIL image built from a uint8 array
        (255 inside the region, 0 elsewhere) and an array stacking `box1` and
        `box2` in xyxy order.
    """
    height, width = target_size

    # Scale a private copy so that the caller's array is left untouched.
    scaled = keypoints.copy()
    scaled[..., 0] *= width
    scaled[..., 1] *= height

    body, body_conf = scaled[:18], scores[:18]
    right_foot, right_foot_conf = scaled[18:21], scores[18:21]
    left_foot, left_foot_conf = scaled[21:24], scores[21:24]
    face, face_conf = scaled[24:92], scores[24:92]
    right_hand, right_hand_conf = scaled[92:113], scores[92:113]
    left_hand, left_hand_conf = scaled[113:], scores[113:]

    def confident(points, confidences, indices):
        """Rows of `points` at `indices` whose confidence exceeds 0.3."""
        return [points[i] for i in indices if confidences[i] > 0.3]

    # upper-clothes / dress / coat / pants / skirt
    garment = remove_small_objects(
        np.isin(segmentation, [5, 6, 7, 9, 12]), min_size=300
    )
    box1 = masks_to_boxes(garment[None, ...])[0]

    # Candidate coordinates for each edge of the keypoint box.
    left_edge = [p[0] for p in confident(body, body_conf, [2, 3, 4, 8, 9, 10])]
    right_edge = [p[0] for p in confident(body, body_conf, [5, 6, 7, 11, 12, 13])]
    top_edge = [p[1] for p in confident(body, body_conf, [2, 5])]
    bottom_edge = [p[1] for p in confident(body, body_conf, [10, 13])]

    # NOTE(review): body indices 2-4 / 8-10 and the hand slice [92:113] feed
    # the left edge, whereas the foot slice [18:21] (named "right foot") feeds
    # the right edge, and vice versa for the other side. Confirm which side
    # each foot slice really holds.
    for point in confident(right_hand, right_hand_conf, [5, 9, 13, 17]):
        left_edge.append(point[0])
        top_edge.append(point[1])
    for point in confident(left_hand, left_hand_conf, [5, 9, 13, 17]):
        right_edge.append(point[0])
        top_edge.append(point[1])

    for point in confident(right_foot, right_foot_conf, [0, 1, 2]):
        right_edge.append(point[0])
        bottom_edge.append(point[1])
    for point in confident(left_foot, left_foot_conf, [0, 1, 2]):
        left_edge.append(point[0])
        bottom_edge.append(point[1])

    # Two face points can only raise the top edge.
    top_edge.extend(p[1] for p in confident(face, face_conf, [5, 11]))

    # An edge without any confident keypoint stays at +/- infinity.
    box2 = [
        min([float("inf")] + left_edge),
        min([float("inf")] + top_edge),
        max([float("-inf")] + right_edge),
        max([float("-inf")] + bottom_edge),
    ]

    # Union of both boxes, padded and clipped to the image. Top padding is
    # only added when the garment box starts above the keypoint box.
    top_padding = 20 if box2[1] > box1[1] else 0
    x_min = max(0, min(box1[0], box2[0]) - 50)
    y_min = max(0, min(box1[1], box2[1]) - top_padding)
    x_max = min(width, max(box1[2], box2[2]) + 50)
    y_max = min(height, max(box1[3], box2[3]) + 10)

    left, top, right, bottom = int(x_min), int(y_min), int(x_max), int(y_max)

    canvas = np.zeros((height, width), dtype=np.uint8)
    canvas[top:bottom, left:right] = 255

    # hat / hair / sunglasses / face / left-shoe / right-shoe
    canvas[np.isin(segmentation, [1, 2, 4, 13, 18, 19])] = 0

    return Image.fromarray(canvas), np.array([box1, box2])

## Visualize images

Run the cells below to visualize images.

In [None]:
# Sample to process; the person and garment images share the same file stem,
# so a single id keeps the pair in sync.
# Candidate ids: 000001 / 003940 / 007427
sample_id = "007427"
image = Image.open(f"../data/test/person/{sample_id}.jpg").convert("RGB")
garment_image = Image.open(f"../data/test/garment/{sample_id}.jpg").convert("RGB")

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

for ax, picture, title in zip(axes, (image, garment_image), ("Person", "Garment")):
    ax.imshow(picture)
    ax.axis("off")
    ax.set_title(title)

fig.tight_layout()
plt.show()

## Estimate pose

Run the cells below to estimate pose.

In [None]:
def estimate_pose(image):
    """
    Run the DWpose detector on `image`.

    This function is submitted to the Dask cluster, so its imports live inside
    the body and are resolved where the function runs rather than in the
    notebook kernel.

    Args:
        image: the person image passed to the detector.

    Returns:
        `(pose_image, keypoints, scores)` exactly as returned by the detector.
    """
    from accelerate import Accelerator
    from looky.dwpose import DWposeDetector

    device = Accelerator().device

    detector = DWposeDetector(device=device)

    pose_image, keypoints, scores = detector(image)

    return pose_image, keypoints, scores


future = client.submit(estimate_pose, image)
pose_image, keypoints, scores = future.result()

# Keep the first entry only — presumably the first detected person; this
# raises IndexError if the detector returned nothing. TODO confirm.
keypoints = keypoints[0]
scores = scores[0]

plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.imshow(image)
plt.axis("off")
plt.title("Person")

plt.subplot(1, 2, 2)
plt.imshow(pose_image)
plt.axis("off")
plt.title("Pose")

plt.tight_layout()
plt.show()

## Parse human 

Run the cells below to run human parsing.

In [None]:
def parse_human(image):
    """
    Segment `image` into human-parsing labels with Mask2Former.

    This function is submitted to the Dask cluster, so its imports live inside
    the body and are resolved where the function runs rather than in the
    notebook kernel.

    Args:
        image: PIL image of the person.

    Returns:
        Label map as a numpy array of shape `(image.height, image.width)`.
    """
    import torch
    from accelerate import Accelerator
    from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

    device = Accelerator().device

    # No `torch_dtype` here: an image processor holds no weights, so the
    # argument has no effect on it.
    image_processor = AutoImageProcessor.from_pretrained(
        "./weights/human_parsing",
        local_files_only=True,
    )
    model = Mask2FormerForUniversalSegmentation.from_pretrained(
        "./weights/human_parsing", local_files_only=True
    )
    model.eval()
    model.to(device)

    inputs = image_processor(image, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    outputs = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[(image.height, image.width)]
    )[0]

    segmentation = outputs.cpu().numpy()

    return segmentation


def segmentation_to_detections(segmentation, min_area=300):
    """
    Convert a label map into one `sv.Detections` entry per connected region.

    Label 0 is skipped. For every other label the external contours are
    extracted, and each contour with an area of at least `min_area` becomes a
    detection with its bounding box, its filled contour as mask and the label
    as class id.

    Args:
        segmentation: 2-D label map.
        min_area: contours with a smaller area are dropped.

    Returns:
        `sv.Detections`; empty when no region passes the area filter.
    """
    boxes = []
    masks = []
    class_ids = []

    for label in np.unique(segmentation):
        if label == 0:
            continue

        label_mask = (segmentation == label).astype(np.uint8) * 255
        contours, _ = cv2.findContours(
            label_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        for contour in contours:
            if cv2.contourArea(contour) < min_area:
                continue

            x, y, w, h = cv2.boundingRect(contour)

            region = np.zeros_like(segmentation, dtype=np.uint8)
            cv2.drawContours(region, [contour], -1, 255, thickness=-1)

            boxes.append([x, y, x + w, y + h])
            masks.append(region.astype(bool))
            class_ids.append(label)

    # np.stack() raises on an empty list, so handle "nothing found" explicitly.
    if not boxes:
        return sv.Detections.empty()

    return sv.Detections(
        xyxy=np.array(boxes),
        mask=np.stack(masks),
        class_id=np.array(class_ids),
    )


future = client.submit(parse_human, image)
segmentation = future.result()

detections = segmentation_to_detections(segmentation)

mask_annotator = sv.MaskAnnotator()
label_annotator = sv.LabelAnnotator(text_position=sv.Position.BOTTOM_RIGHT)

id2label = {
    0: "background",
    1: "hat",
    2: "hair",
    3: "glove",
    4: "sunglasses",
    5: "upper-clothes",
    6: "dress",
    7: "coat",
    8: "socks",
    9: "pants",
    10: "torso-skin",
    11: "scarf",
    12: "skirt",
    13: "face",
    14: "left-arm",
    15: "right-arm",
    16: "left-leg",
    17: "right-leg",
    18: "left-shoe",
    19: "right-shoe",
}

labels = [id2label[int(class_id)] for class_id in detections.class_id]

annotated_image = mask_annotator.annotate(image.copy(), detections)
annotated_image = label_annotator.annotate(
    annotated_image,
    detections,
    labels=labels,
)

plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.imshow(image)
plt.axis("off")
plt.title("Person")

plt.subplot(1, 2, 2)
plt.imshow(annotated_image)
plt.axis("off")
plt.title("Human parsing")

plt.tight_layout()
plt.show()

## Generate mask

Run the cells below to generate mask.

In [None]:
# upper-body / lower-body / full-body
category = "full-body"

width, height = image.size
# The mask builders expect (height, width), whereas PIL reports (width, height).
target_size = (height, width)

mask_builders = {
    "upper-body": get_upper_body_mask,
    "lower-body": get_lower_body_mask,
    "full-body": get_full_body_mask,
}

# Fail loudly on a misspelt category instead of silently reusing a mask left
# over from an earlier run (or hitting a NameError further down).
if category not in mask_builders:
    raise ValueError(
        f"Unknown category {category!r}; expected one of {sorted(mask_builders)}"
    )

agnostic_mask, boxes = mask_builders[category](
    segmentation, keypoints, scores, target_size
)

# Row 0 is the garment (segmentation) box and row 1 the keypoint box; distinct
# class ids give them distinct colours.
detections = sv.Detections(
    xyxy=boxes,
    class_id=np.array([0, 1]),
)

box_annotator = sv.BoxAnnotator()

annotated_image = box_annotator.annotate(
    scene=image.copy(),
    detections=detections,
)

# Grey out the region selected by the agnostic mask.
masked_image = image.copy()
masked_image.paste(Image.new("RGB", image.size, "gray"), mask=agnostic_mask)

plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.imshow(image)
plt.axis("off")
plt.title("Person")

plt.subplot(1, 3, 2)
plt.imshow(annotated_image)
plt.axis("off")
plt.title("Boxes")

plt.subplot(1, 3, 3)
plt.imshow(masked_image)
plt.axis("off")
plt.title("Masked person")

plt.tight_layout()
plt.show()

## Run virtual try-on

Run the cells below to try on.

In [None]:
def run_try_on(
    image,
    mask_image,
    garment_image,
    pose_image,
    height=1024,
    width=768,
    num_inference_steps=20,
    guidance_scale=2.0,
):
    """
    Run the virtual try-on pipeline.

    This function is submitted to the Dask cluster, so its imports live inside
    the body and are resolved where the function runs rather than in the
    notebook kernel.

    Args:
        image: person image.
        mask_image: agnostic mask selecting the region to repaint.
        garment_image: garment image.
        pose_image: rendered pose image.
        height: `height` argument forwarded to the pipeline.
        width: `width` argument forwarded to the pipeline.
        num_inference_steps: forwarded to the pipeline.
        guidance_scale: forwarded to the pipeline.

    Returns:
        The `.images` attribute of the pipeline output.
    """
    import torch
    from accelerate import Accelerator

    from looky.pipelines.pipeline_virtual_try_on import VirtualTryOnPipeline

    device = Accelerator().device

    pipe = VirtualTryOnPipeline.from_pretrained(
        "./weights/virtual_try_on",
        torch_dtype=torch.bfloat16,
        local_files_only=True,
    )
    pipe.to(device)

    # NOTE(review): no seed or generator is passed, so results presumably
    # differ between runs — confirm whether the pipeline accepts one.
    images = pipe(
        image=image,
        mask_image=mask_image,
        garment_image=garment_image,
        pose_image=pose_image,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images

    return images


future = client.submit(
    run_try_on,
    image,
    agnostic_mask,
    garment_image,
    pose_image,
)
images = future.result()
generated_image = images[0]

plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.imshow(image)
plt.axis("off")
plt.title("Person")

plt.subplot(1, 2, 2)
plt.imshow(generated_image)
plt.axis("off")
plt.title("Try-on result")

plt.tight_layout()
plt.show()