In [2]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

In [3]:
# TODO: Implement FashionDataset once we know what the dataset looks like
class FashionDataset(Dataset):
    """Placeholder dataset of triplet samples for metric learning.

    ``train`` unpacks every batch as ``(anchor, positive, negative)``, so once
    implemented ``__getitem__`` should return such a triplet (presumably three
    image tensors of identical shape -- TODO confirm once the data format is known).

    Until it is implemented, both accessors raise ``NotImplementedError``. This
    makes the notebook fail with an explicit message rather than an opaque
    ``TypeError`` caused by ``__len__`` returning ``None``.
    """

    def __init__(self) -> None:
        super().__init__()

    def __len__(self) -> int:
        """Return the number of triplets in the dataset."""
        raise NotImplementedError("FashionDataset.__len__ is not implemented yet")

    def __getitem__(self, index: int):
        """Return the ``(anchor, positive, negative)`` triplet at ``index``."""
        raise NotImplementedError("FashionDataset.__getitem__ is not implemented yet")

In [None]:
def _triplet_loss(
    model: nn.Module,
    loss_fcn: nn.Module,
    anchor: torch.Tensor,
    positive: torch.Tensor,
    negative: torch.Tensor,
    device: "torch.device | str",
) -> torch.Tensor:
    """Move one triplet batch to ``device``, embed each part with ``model`` and return ``loss_fcn`` on the features."""
    # Tensor.to() returns a new tensor rather than moving in place, so use the returned value.
    anchor_features = model(anchor.to(device))
    positive_features = model(positive.to(device))
    negative_features = model(negative.to(device))
    return loss_fcn(anchor_features, positive_features, negative_features)


def train(
    model: nn.Module,
    train_data: DataLoader,
    val_data: DataLoader,
    loss_fcn: nn.Module,
    epochs: int = 10,
    lr: float = 0.001,
    device: "torch.device | str | None" = None,
    log_dir: str = "./logs/Landmark V1.0",
) -> None:
    """Train ``model`` as a triplet embedding network and validate after each epoch.

    Every batch from ``train_data`` / ``val_data`` is unpacked as
    ``(anchor, positive, negative)``. The model embeds all three, and
    ``loss_fcn(anchor_features, positive_features, negative_features)`` is
    minimised with Adam. The per-batch training loss and the per-epoch mean
    validation loss are written to TensorBoard; the latter is also printed.

    Args:
        model: Network mapping an input batch to a feature batch.
        train_data: Iterable of ``(anchor, positive, negative)`` batches used for training.
        val_data: Iterable of batches of the same form used for validation.
        loss_fcn: Triplet criterion, e.g. ``nn.TripletMarginLoss()``.
        epochs: Number of passes over ``train_data``.
        lr: Adam learning rate.
        device: Device to run on. If given, the model is moved there (in place).
            If ``None``, the device of the model's first parameter is used
            (CPU for a parameter-less model).
        log_dir: TensorBoard log directory.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    else:
        # nn.Module.to moves parameters in place (unlike Tensor.to).
        model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Tensorboard logger
    logger = SummaryWriter(log_dir)
    # Monotonic global step for per-batch scalars. A per-epoch batch index would
    # restart at 0 every epoch and overwrite earlier points.
    step = 0

    try:
        for epoch in range(epochs):
            # Set model to training mode
            model.train()

            for anchor, positive, negative in tqdm(train_data, desc=f"Epoch {epoch}", unit="batch"):
                optimizer.zero_grad()

                loss = _triplet_loss(model, loss_fcn, anchor, positive, negative, device)

                # Log a plain float so TensorBoard does not hold on to the tensor.
                logger.add_scalar("Triplet Loss", loss.item(), step)
                step += 1

                loss.backward()
                optimizer.step()

            # Evaluate the model on the validation data
            model.eval()

            validation_loss = 0.0
            num_batches = 0

            with torch.no_grad():
                for anchor, positive, negative in tqdm(val_data, desc="Testing", unit="batch"):
                    # TODO: Add some more metrics to the validation loop (i.e. euclidean distance between a & p, a & n, etc...)
                    loss = _triplet_loss(model, loss_fcn, anchor, positive, negative, device)
                    validation_loss += loss.item()
                    num_batches += 1

            # An empty validation set yields NaN instead of a ZeroDivisionError.
            mean_validation_loss = validation_loss / num_batches if num_batches else float("nan")

            logger.add_scalar("Validation Loss", mean_validation_loss, epoch)

            print(f"Epoch {epoch} Validation Loss: {mean_validation_loss}")
    finally:
        # Flush pending events and release the log file even if training fails.
        logger.close()

In [None]:
# Load the dataset
dataset = FashionDataset()

# Split into disjoint training / validation subsets so the validation loss is
# measured on samples the model never trains on.
VAL_FRACTION = 0.2
BATCH_SIZE = 32
SPLIT_SEED = 42  # fixed so the split is reproducible across kernel restarts

num_val = int(len(dataset) * VAL_FRACTION)
train_set, val_set = random_split(
    dataset,
    [len(dataset) - num_val, num_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Define the training dataloader
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

# Define the validation dataloader
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
# Set the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model (no pretrained weights requested) and move it to the selected
# device; nn.Module.to moves the parameters in place and returns the module.
model = models.resnet50().to(device)

In [None]:
# Train the ResNet-50 embedder with a triplet margin loss for 100 epochs.
train(
    model=model,
    train_data=train_loader,
    val_data=val_loader,
    loss_fcn=nn.TripletMarginLoss(),
    epochs=100,
)