In [1]:
import os
import random

import pandas as pd
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

In [2]:
# Seed every RNG the notebook draws from. FashionDataset picks anchor/positive
# roles and negative rows with the stdlib `random` module, so seeding torch
# alone would leave the sampled triplets different on every run.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when torch can see one, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

device

'cuda'

In [3]:
# Map labels to their corresponding directories
DIRECTORY_MAP = ["upper_body", "lower_body", "dresses"]


class FashionDataset(Dataset):
    """Dataset that yields (anchor, positive, negative) image triplets.

    The pairs file is a tab-separated table without a header whose columns are
    read as ``model``, ``garment`` and ``label``. ``label`` indexes
    ``DIRECTORY_MAP`` to select the sub-directory of ``root`` that holds both
    images of the row (under ``cropped_images``).

    For every item the model and garment image of one row are used as anchor
    and positive (which one is the anchor is chosen at random), and the
    negative is the model or garment image of a different, randomly chosen row.
    Only the row differs; the negative may carry the same label as the anchor.
    """

    def __init__(self, root: str, pairs: str) -> None:
        """Read the pairs table and set up the image transforms.

        Args:
            root: Root directory of the dataset.
            pairs: Path to the tab-separated pairs file.
        """
        super().__init__()

        # Every image is resized to 256x192 and converted to a tensor
        self.transforms = transforms.Compose(
            [transforms.Resize((256, 192)), transforms.ToTensor()]
        )

        # Root directory of the dataset
        self.root = root

        # Load in the paired data
        self.data = pd.read_csv(
            pairs, delimiter="\t", header=None, names=["model", "garment", "label"]
        )

    def __len__(self) -> int:
        """Return the number of rows in the pairs table."""
        return len(self.data)

    def _load_image(self, label, filename):
        """Open ``filename`` from the directory mapped to ``label`` as an RGB image."""
        path = os.path.join(self.root, DIRECTORY_MAP[label], "cropped_images", filename)

        # convert() returns a new image, so the file can be closed right away
        with Image.open(path) as image:
            return image.convert("RGB")

    def __getitem__(self, index: int) -> tuple:
        """Build the triplet anchored at row ``index``.

        Args:
            index: Position of the anchor row in the pairs table.

        Returns:
            A tuple ``(anchor, positive, negative)`` of transformed images.

        Raises:
            ValueError: If the table has fewer than two rows, because no
                negative can be drawn from a different row.
        """
        row_count = len(self.data)

        # Without this guard the rejection loop below would never terminate
        if row_count < 2:
            raise ValueError(
                "FashionDataset needs at least two rows to sample a negative, "
                f"got {row_count}"
            )

        model, garment, label = self.data.iloc[index]

        # Randomly decide which image of the pair is the anchor
        if random.choice([True, False]):
            anchor_name, positive_name = model, garment
        else:
            anchor_name, positive_name = garment, model

        anchor = self._load_image(label, anchor_name)
        positive = self._load_image(label, positive_name)

        # Normalise negative indices (e.g. dataset[-1]) so the comparison below
        # really excludes the anchor row
        anchor_position = index % row_count

        # Randomly sample a negative row (ensuring it is not the anchor's row)
        while (negative_index := random.randrange(0, row_count)) == anchor_position:
            pass

        negative_model, negative_garment, negative_label = self.data.iloc[
            negative_index
        ]

        # Random choice between the garment and the model image of that row
        if random.choice([True, False]):
            negative_name = negative_garment
        else:
            negative_name = negative_model

        negative = self._load_image(negative_label, negative_name)

        # Resize & convert to tensors
        return (
            self.transforms(anchor),
            self.transforms(positive),
            self.transforms(negative),
        )

In [4]:
def _safe_ratio(numerator: float, denominator: float) -> float:
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator != 0 else float("nan")


def _embed_triplet(model: nn.Module, batch, device: str):
    """Move an (anchor, positive, negative) batch to ``device`` and embed each part."""
    anchor, positive, negative = (images.to(device) for images in batch)
    return model(anchor), model(positive), model(negative)


def _evaluate(
    model: nn.Module,
    test_data: DataLoader,
    loss_fcn: nn.Module,
    device: str,
    description: str,
):
    """Run one pass over ``test_data`` and return the validation metrics.

    Returns a dict with the mean per-batch loss, the AN/AP Euclidean distance
    ratio and the AP/AN cosine similarity ratio, or ``None`` when ``test_data``
    yields no batches.
    """
    loss_total = 0.0
    distance_ap = 0.0
    distance_an = 0.0
    similarity_ap = 0.0
    similarity_an = 0.0
    batch_count = 0

    model.eval()

    with torch.no_grad():
        for batch in tqdm(
            test_data, description, unit="batch", total=len(test_data)
        ):
            anchor_features, positive_features, negative_features = _embed_triplet(
                model, batch, device
            )

            loss_total += loss_fcn(
                anchor_features, positive_features, negative_features
            ).item()

            # Euclidean distance for the positive and negative pairs
            distance_ap += torch.norm(
                anchor_features - positive_features, dim=1
            ).sum().item()
            distance_an += torch.norm(
                anchor_features - negative_features, dim=1
            ).sum().item()

            # Cosine similarity for the positive and negative pairs. Summed per
            # sample, like the distances above, so a smaller final batch does
            # not get the same weight as a full one.
            similarity_ap += torch.cosine_similarity(
                anchor_features, positive_features
            ).sum().item()
            similarity_an += torch.cosine_similarity(
                anchor_features, negative_features
            ).sum().item()

            batch_count += 1

    if batch_count == 0:
        return None

    return {
        "loss": loss_total / batch_count,
        "distance_ratio": _safe_ratio(distance_an, distance_ap),
        "similarity_ratio": _safe_ratio(similarity_ap, similarity_an),
    }


def train(
    model: nn.Module,
    train_data: DataLoader,
    test_data: DataLoader,
    loss_fcn: nn.Module,
    epochs: int = 10,
    device: str = "cpu",
    log_dir: str = "./logs",
    output_dir: str = "./models",
    model_name: str = "ResNet50",
):
    """Train ``model`` on triplets, logging to TensorBoard and saving checkpoints.

    Each epoch runs one optimisation pass over ``train_data`` with Adam
    (lr=1e-3), one evaluation pass over ``test_data``, and then saves the
    model's ``state_dict`` as ``checkpoint-<epoch + 1>.pt``.

    Args:
        model: Network that maps a batch of images to a batch of embeddings.
            It is moved to ``device`` before the optimizer is created.
        train_data: Loader yielding ``(anchor, positive, negative)`` batches.
        test_data: Loader yielding ``(anchor, positive, negative)`` batches
            used for validation. If it yields no batches, validation logging
            is skipped for that epoch.
        loss_fcn: Callable taking the three embedding batches and returning a
            scalar loss.
        epochs: Number of epochs to run.
        device: Device the model and the batches are moved to.
        log_dir: Parent directory of the TensorBoard run directory.
        output_dir: Parent directory of the checkpoint directory.
        model_name: Name of the run; used as sub-directory of both ``log_dir``
            and ``output_dir``.
    """
    run_log_dir = os.path.join(log_dir, model_name)
    checkpoint_dir = os.path.join(output_dir, model_name)

    # Create the log & output directories if they don't exist
    # (creating checkpoint_dir also creates output_dir)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Keep the model on the same device the batches are sent to; done before
    # the optimizer is built so it tracks the parameters that are trained
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    print(f"View logs by running tensorboard --logdir={log_dir}/{model_name}")

    # Tensorboard logger
    logger = SummaryWriter(run_log_dir)

    try:
        for epoch in range(epochs):

            # Set model to training mode
            model.train()

            for i, batch in tqdm(
                enumerate(train_data),
                f"Epoch {epoch} Training",
                unit="batch",
                total=len(train_data),
            ):
                optimizer.zero_grad()

                # Forward pass
                anchor_features, positive_features, negative_features = (
                    _embed_triplet(model, batch, device)
                )

                # Compute the loss
                loss = loss_fcn(anchor_features, positive_features, negative_features)

                # Log the loss as a plain float; the step counts batches across epochs
                logger.add_scalar(
                    "Train/Triplet Loss", loss.item(), i + epoch * len(train_data)
                )

                # Backward pass
                loss.backward()
                optimizer.step()

            # Evaluate the model on the testing data
            metrics = _evaluate(
                model, test_data, loss_fcn, device, f"Epoch {epoch} Evaluation"
            )

            if metrics is None:
                print(f"Epoch {epoch} Validation skipped: test_data yielded no batches")
            else:
                # Log validation metrics
                logger.add_scalar("Test/Triplet Loss", metrics["loss"], epoch)
                logger.add_scalar(
                    "Test/Euclidean Distance Ratio (AN/AP)",
                    metrics["distance_ratio"],
                    epoch,
                )
                logger.add_scalar(
                    "Test/Cosine Similarity Ratio (AP/AN)",
                    metrics["similarity_ratio"],
                    epoch,
                )

                print(f"Epoch {epoch} Validation Loss: {metrics['loss']}")

            # Save the model
            torch.save(
                model.state_dict(),
                os.path.join(checkpoint_dir, f"checkpoint-{epoch + 1}.pt"),
            )
    finally:
        # Flush and close the event file even when training is interrupted
        logger.close()

In [5]:
# Triplet datasets built from the DressCode train and test pair listings
train_data = FashionDataset(
    root="data/DressCode",
    pairs="data/DressCode/train_pairs_cropped.txt",
)
test_data = FashionDataset(
    root="data/DressCode",
    pairs="data/DressCode/test_pairs_paired_cropped.txt",
)

# Batches of 48 triplets: the training loader reshuffles every epoch, the
# validation loader keeps the file order
train_loader = DataLoader(dataset=train_data, batch_size=48, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=48, shuffle=False)

In [6]:
# Build a ResNet-50 (no weights argument is passed, so torchvision's default
# initialisation applies) and place it on the selected device
model = models.resnet50().to(device)

In [7]:
# The margin is named once and reused in the run name, so the TensorBoard run
# and checkpoint directory are always labelled with the margin actually used.
MARGIN = 0.2

train(
    model,
    train_loader,
    test_loader,
    # Triplet loss over cosine distance (1 - cosine similarity)
    nn.TripletMarginWithDistanceLoss(
        distance_function=lambda x, y: 1 - torch.cosine_similarity(x, y),
        margin=MARGIN,
    ),
    epochs=10,
    device=device,
    model_name=f"ResNet50 Cosine Similarity Loss Margin {MARGIN}",
)

View logs by running tensorboard --logdir=./logs/ResNet50 Cosine Similarity Loss Margin 0.2


Epoch 0 Training: 100%|██████████| 1003/1003 [12:44<00:00,  1.31batch/s]
Epoch 0 Evaluation: 100%|██████████| 113/113 [01:01<00:00,  1.84batch/s]


Epoch 0 Validation Loss: 0.028443099930882454


Epoch 1 Training: 100%|██████████| 1003/1003 [12:52<00:00,  1.30batch/s]
Epoch 1 Evaluation: 100%|██████████| 113/113 [01:02<00:00,  1.82batch/s]


Epoch 1 Validation Loss: 0.023438626900315285


Epoch 2 Training: 100%|██████████| 1003/1003 [12:49<00:00,  1.30batch/s]
Epoch 2 Evaluation: 100%|██████████| 113/113 [01:00<00:00,  1.85batch/s]


Epoch 2 Validation Loss: 0.018526531755924225


Epoch 3 Training: 100%|██████████| 1003/1003 [12:43<00:00,  1.31batch/s]
Epoch 3 Evaluation: 100%|██████████| 113/113 [01:02<00:00,  1.80batch/s]


Epoch 3 Validation Loss: 0.023565111681818962


Epoch 4 Training: 100%|██████████| 1003/1003 [13:02<00:00,  1.28batch/s]
Epoch 4 Evaluation: 100%|██████████| 113/113 [01:01<00:00,  1.83batch/s]


Epoch 4 Validation Loss: 0.02030663751065731


Epoch 5 Training: 100%|██████████| 1003/1003 [12:35<00:00,  1.33batch/s]
Epoch 5 Evaluation: 100%|██████████| 113/113 [01:01<00:00,  1.82batch/s]


Epoch 5 Validation Loss: 0.017346670851111412


Epoch 6 Training: 100%|██████████| 1003/1003 [12:49<00:00,  1.30batch/s]
Epoch 6 Evaluation: 100%|██████████| 113/113 [01:12<00:00,  1.56batch/s]


Epoch 6 Validation Loss: 0.014622221700847149


Epoch 7 Training:   6%|▌         | 61/1003 [00:52<13:33,  1.16batch/s]


KeyboardInterrupt: 