In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm

In [None]:
# Root directory for the dataset
root = "data/DressCode/"

# Map labels to their corresponding directories (the list index is the label)
DIRECTORY_MAP = ["upper_body", "lower_body", "dresses"]

# Human-readable name for every garment class.
# NOTE(review): presumably indexed by the detector's class id, like CLASS_MAP — confirm
# against the class list the YOLO weights were trained with.
CLASS_TO_NAME = [
    "short sleeve top",
    "long sleeve top",
    "short sleeve outwear",
    "long sleeve outwear",
    "vest",
    "sling",
    "shorts",
    "trousers",
    "skirt",
    "short sleeve dress",
    "long sleeve dress",
    "vest dress",
    "sling dress",
]

# Map every class id to a label (an index into DIRECTORY_MAP):
# six upper-body classes, three lower-body classes, four dress classes.
CLASS_MAP = [0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2]

# Both tables describe the same classes, so they must stay the same length
assert len(CLASS_MAP) == len(CLASS_TO_NAME), "CLASS_MAP and CLASS_TO_NAME are out of sync"

In [None]:
# Load the tab-separated training pairs; the file has no header row, so the
# three columns are named explicitly
pairs_path = os.path.join(root, "train_pairs_cropped.txt")
pairs = pd.read_csv(
    pairs_path,
    sep="\t",
    header=None,
    names=["model", "garment", "label"],
)

pairs.head()

In [None]:
# Seed torch's random number generator
torch.manual_seed(42)

# Run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

device

In [None]:
# Load in the encoder network (architecture only; the weights come from the checkpoint)
encoder = models.resnet50()

# Load the weights. `map_location` remaps the stored tensors onto the active device,
# so a checkpoint written on a GPU machine also loads on a CPU-only one.
checkpoint_path = "models/ResNet50 Cosine Similarity Loss Margin 0.2/checkpoint-6.pt"
encoder.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Send the model to the device
encoder = encoder.to(device)

# Define the transformations for the network.
# The pipeline is bound to the name `transforms` because later cells call it under
# that name, which hides the module imported at the top of the notebook. Importing
# the module under its own alias here keeps this cell re-runnable.
import torchvision.transforms as T

transforms = T.Compose([T.Resize((256, 192)), T.ToTensor()])

In [None]:
def calculate_features(image: Image.Image) -> torch.Tensor:
    """
    Get the features for a given image.

    Args:
        image: The PIL image to encode. It is resized and converted to a tensor
            by the notebook-level `transforms` pipeline before encoding.

    Returns:
        The encoder output for the image as a CPU tensor. The batch dimension
        added before the forward pass is not removed.
    """
    # Set the model to evaluation mode
    encoder.eval()

    # Resize & convert to tensor, add a batch dimension, and move to the device
    batch = transforms(image).unsqueeze(0).to(device)

    # No gradients are needed for feature extraction
    with torch.no_grad():
        return encoder(batch).cpu()

In [None]:
# Per category: the encoded garments (NUM_IMAGES x NUM_FEATURES once concatenated)
# and, in the same order, the row of `pairs` each one came from
features = {category: [] for category in DIRECTORY_MAP}
feature_indices = {category: [] for category in DIRECTORY_MAP}

encoder.eval()

progress = tqdm(
    enumerate(pairs.values),
    desc="Calculating Features",
    total=len(pairs),
    unit="image",
)

for i, (model, garment, label) in progress:
    category = DIRECTORY_MAP[label]

    # Load in the garment image
    garment_path = os.path.join(root, category, "cropped_images", garment)
    garment_image = Image.open(garment_path).convert("RGB")

    # Encode it and remember which row of `pairs` it belongs to
    features[category].append(calculate_features(garment_image))
    feature_indices[category].append(i)

# Join the per-image tensors into a single tensor per category
for category in DIRECTORY_MAP:
    features[category] = torch.cat(features[category])

# Turn the row numbers into arrays so they can be indexed with an array of positions
for category in DIRECTORY_MAP:
    feature_indices[category] = np.array(feature_indices[category])

tuple(features[category].shape for category in DIRECTORY_MAP)

In [None]:
# Write the feature tensors and their row indices to disk
for payload, destination in (
    (features, "data/DressCode/train_features.pt"),
    (feature_indices, "data/DressCode/train_feature_indices.pt"),
):
    torch.save(payload, destination)

In [None]:
from ultralytics import YOLO

# Build the YOLO model from the weights file stored with the other models
yolo_weights = "models/yolov8m.pt"
yolo = YOLO(yolo_weights)

In [None]:
def get_similar_images(image: Image.Image, label: int, n: int = 5) -> list[Image.Image]:
    """
    Get the n most similar images to the given image.

    Args:
        image: The query image; it is encoded with `calculate_features`.
        label: Index into `DIRECTORY_MAP` selecting the garment category to search.
        n: Maximum number of images to return. Fewer are returned when the
            category holds fewer than `n` images.

    Returns:
        The matching garment images, most similar first, with their pixel data
        already loaded.

    Raises:
        ValueError: If `n` is negative.
    """
    # A negative n would make the slice below drop items from the end instead
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    category = DIRECTORY_MAP[label]

    # Calculate the features for the image
    image_features = calculate_features(image)

    # Calculate the cosine similarity against every stored image of this category
    similarities = torch.cosine_similarity(image_features, features[category])

    # Positions, within the category, of the n most similar images
    top_positions = torch.argsort(similarities, descending=True)[:n].numpy()

    # Translate the category positions back into row numbers of `pairs`
    pair_rows = feature_indices[category][top_positions]

    similar_images = []

    for row in pair_rows:
        garment_path = os.path.join(
            root, category, "cropped_images", pairs.iloc[row]["garment"]
        )

        # Image.open is lazy and keeps the file open; read the pixels inside the
        # context manager so the handle is released before the image is returned
        with Image.open(garment_path) as similar_image:
            similar_image.load()

        similar_images.append(similar_image)

    return similar_images

In [None]:
def display_similar_images(image: Image.Image, similar_images: list[Image.Image]) -> None:
    """
    Display the similar images.

    Args:
        image: The anchor image, drawn in the first panel.
        similar_images: Images drawn in the following panels, in the given order.
            May be empty, in which case only the anchor is shown.
    """
    # squeeze=False always returns a 2-D array of axes; without it a one-panel
    # figure (empty `similar_images`) yields a bare Axes that cannot be indexed
    fig, axes_grid = plt.subplots(
        1, len(similar_images) + 1, figsize=(20, 10), squeeze=False
    )
    axs = axes_grid[0]

    # Display the anchor image
    axs[0].imshow(image)
    axs[0].set_title("Anchor Image")

    for i, similar_image in enumerate(similar_images, 1):
        axs[i].imshow(similar_image)
        axs[i].set_title(f"Similar Image {i}")

    plt.show()

In [None]:
def infer(image: Image.Image, min_confidence: float = 0.5) -> None:
    """
    Detect garments in an image and display similar stored garments for each one.

    Every detection above the confidence threshold is cropped out of the image,
    shown, and passed to `get_similar_images`; the crop and its matches are then
    drawn with `display_similar_images`.

    Args:
        image: The image to run the detector on.
        min_confidence: Detections whose confidence is not strictly greater than
            this value are ignored.
    """
    # Get the predictions (one image goes in, so take the first result)
    predictions = yolo.predict(image)[0]

    # Get the predicted detections
    detections = predictions.boxes

    # Threshold the predictions
    detections = detections[detections.conf > min_confidence]

    for detection in detections:
        # Pass the corners to PIL as the plain (left, upper, right, lower)
        # tuple that Image.crop documents
        left, upper, right, lower = detection.xyxy.cpu().numpy().squeeze().tolist()

        image_cropped = image.crop((left, upper, right, lower))

        image_cropped.show()

        # Map the detected class id to its garment category label
        cls = detection.cls.int().item()
        label = CLASS_MAP[cls]

        similar_images = get_similar_images(image_cropped, label)

        display_similar_images(image_cropped, similar_images)

In [None]:
# Load the query image. A forward slash is used as the separator because it works
# on every platform, whereas a backslash in a string literal starts an escape sequence.
image = Image.open("data/jasper.PNG").convert("RGB")

infer(image)