## Load libs

In [3]:
import cv2
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import numpy as np 
import albumentations as A
import matplotlib.pyplot as plt
from albumentations.pytorch import ToTensorV2 #np.array -> torch.tensor
from pytorch_metric_learning import losses, miners
from torch.utils.data import DataLoader, Dataset, Sampler

import warnings
import itertools
import random
import yaml

## Define configs (Miner, Loss)

In [4]:
class UnNormalize(object):
  def __init__(self, mean, std):
    self.mean = mean
    self.std = std

  def __call__(self, tensor):
    """
    Args:
      tensor (Tensor): Tensor image of size (C, H, W) to be normalized'
    Returns:
      Tensor: Normalized image
    """
    for t, m, s in zip(tensor, self.mean, self.std):
      t.mul_(s).add_(m) #in-place operation (not make a copy but change object directly)
      #The normalize code -> t.sub_(m).div_(s)
    return tensor

unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

## Load config

In [5]:
def load_config(config_path):
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    return config

configs = load_config('../configs/config.yaml')

## Define Dataset

In [6]:
class TripletDataset(Dataset):
    def __init__(self, root_dir, transform=None, is_train=True):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
            is_train (bool): Whether this is training set or test set
        """
        self.root_dir = root_dir
        self.transform = transform
        self.is_train = is_train
        
        # Define class folders
        self.class_folders = ['normal', 'preplus', 'plus']
        
        # Create a list of (image_path, class_label) tuples
        self.samples = []
        for class_idx, class_name in enumerate(self.class_folders):
            class_dir = os.path.join(root_dir, class_name)
            if os.path.isdir(class_dir):
                for img_name in os.listdir(class_dir):
                    if img_name.endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                        img_path = os.path.join(class_dir, img_name)
                        self.samples.append((img_path, class_idx))
        
        # Shuffle samples
        random.shuffle(self.samples)
        
        # Split into train/test if needed
        if not is_train:
            # Use last 20% for testing
            split_idx = int(len(self.samples) * 0.8)
            self.samples = self.samples[split_idx:]
        else:
            # Use first 80% for training
            split_idx = int(len(self.samples) * 0.8)
            self.samples = self.samples[:split_idx]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        
        # Read image
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # Apply transformations
        if self.transform:
            transformed = self.transform(image=img)
            img = transformed["image"]
        
        return img, label

## Define transformations

In [7]:
def get_transforms(is_train=True):
    if is_train:
        return A.Compose([
            A.Resize(224, 224),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])
    else:
        return A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])

## Define loss and miner functions

In [8]:
def get_loss_and_miner(config):
    # Initialize miner
    if 'triplet_margin_miner' in config['mining']:
        miner_config = config['mining']['triplet_margin_miner'][0]  # Use first config
        miner = miners.TripletMarginMiner(
            margin=miner_config['m'],
            type_of_triplets="all"
        )
    elif 'batch_easy_hard_miner' in config['mining']:
        miner_config = config['mining']['batch_easy_hard_miner'][0]  # Use first config
        miner = miners.BatchEasyHardMiner(
            pos_strategy=miner_config['pos_strategy'],
            neg_strategy=miner_config['neg_strategy']
        )
    else:
        miner = None
    
    # Initialize loss
    if 'triplet_loss' in config['loss']:
        loss_fn = losses.TripletMarginLoss(margin=0.2)
    elif 'ntxent_loss' in config['loss']:
        loss_config = config['loss']['ntxent_loss']
        loss_fn = losses.NTXentLoss(temperature=loss_config['temperature'])
    else:
        loss_fn = losses.TripletMarginLoss(margin=0.2)  # Default
    
    return loss_fn, miner

## Model

In [9]:
class EmbeddingNet(nn.Module):
    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.backbone = models.resnet18(pretrained=True)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, 128)

    def forward(self, x):
        x = self.backbone(x)
        x = nn.functional.normalize(x, p=2, dim=1)
        return x

# Setup data loaders

In [10]:
def setup_data_loaders(configs):
    # Define dataset paths - adjust these paths as needed
    data_root = '../../ROP-o/Triplet-data/data'
    
    # Create datasets
    train_dataset = TripletDataset(
        root_dir=data_root,
        transform=get_transforms(is_train=True),
        is_train=True
    )
    
    test_dataset = TripletDataset(
        root_dir=data_root,
        transform=get_transforms(is_train=False),
        is_train=False
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=configs['training']['batch_size'],
        shuffle=True,
        num_workers=os.cpu_count() or 4,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=configs['training']['batch_size'],
        shuffle=False,
        num_workers=os.cpu_count() or 4,
        pin_memory=True
    )
    
    return train_loader, test_loader

# Test model

In [11]:
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load data
batch_size = configs['training']['batch_size']
n_workers = os.cpu_count()
print("num_workers =", n_workers)

trainloader, testloader = setup_data_loaders(configs)

# model
model = EmbeddingNet().to(device)

# loss, miner
loss_fn, miner = get_loss_and_miner(configs)

# optimizer 
optimizer = optim.Adam(model.parameters(), lr=configs['training']['learning_rate'])

num_workers = 16


ValueError: num_samples should be a positive integer value, but got num_samples=0

# metrics

In [None]:
class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

# Training function

In [None]:
def train_epoch(model, train_loader, optimizer, loss_fn, miner, device):
    model.train()
    losses = AverageMeter()
    
    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        embeddings = model(data)
        
        # Get triplets using miner if available
        if miner:
            hard_pairs = miner(embeddings, labels)
            loss = loss_fn(embeddings, labels, hard_pairs)
        else:
            loss = loss_fn(embeddings, labels)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        # Update statistics
        losses.update(loss.item(), data.size(0))
    
    return losses.avg

# Evaluation function

In [None]:
def evaluate(model, test_loader, loss_fn, miner, device):
    model.eval()
    losses = AverageMeter()
    
    with torch.no_grad():
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device)
            
            # Forward pass
            embeddings = model(data)
            
            # Calculate loss
            if miner:
                hard_pairs = miner(embeddings, labels)
                loss = loss_fn(embeddings, labels, hard_pairs)
            else:
                loss = loss_fn(embeddings, labels)
            
            # Update statistics
            losses.update(loss.item(), data.size(0))
    
    return losses.avg

## Training

In [None]:
epochs = configs['training']['epochs']
best_loss = float('inf')

# Training loop
for e in range(epochs):
    # Train for one epoch
    train_loss = train_epoch(model, trainloader, optimizer, loss_fn, miner, device)
    
    # Evaluate on test set
    test_loss = evaluate(model, testloader, loss_fn, miner, device)
    
    # Print progress
    print(f'Epoch {e+1}/{epochs}:')
    print(f'  Train Loss: {train_loss:.4f}')
    print(f'  Test Loss: {test_loss:.4f}')
    
    # Save best model
    if test_loss < best_loss:
        best_loss = test_loss
        # Save model checkpoint
        checkpoint_dir = '../log/checkpoints'
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save({
            'epoch': e,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
        }, os.path.join(checkpoint_dir, 'best_model.pth'))
        print(f'  Saved new best model with loss: {best_loss:.4f}')

## Evaluation

In [None]:
def visualize_embeddings(model, test_loader, device, n_samples=100):
    """Visualize embeddings using t-SNE"""
    from sklearn.manifold import TSNE
    import matplotlib.pyplot as plt
    
    model.eval()
    embeddings_list = []
    labels_list = []
    
    with torch.no_grad():
        for data, labels in test_loader:
            if len(embeddings_list) >= n_samples:
                break
                
            data = data.to(device)
            embeddings = model(data)
            
            embeddings_list.append(embeddings.cpu().numpy())
            labels_list.append(labels.numpy())
    
    # Concatenate all embeddings and labels
    embeddings_array = np.vstack(embeddings_list)
    labels_array = np.concatenate(labels_list)
    
    # Limit to n_samples
    if len(embeddings_array) > n_samples:
        embeddings_array = embeddings_array[:n_samples]
        labels_array = labels_array[:n_samples]
    
    # Apply t-SNE
    tsne = TSNE(n_components=2, random_state=42)
    embeddings_2d = tsne.fit_transform(embeddings_array)
    
    # Plot
    plt.figure(figsize=(10, 8))
    for i, label in enumerate(['normal', 'preplus', 'plus']):
        mask = labels_array == i
        plt.scatter(
            embeddings_2d[mask, 0],
            embeddings_2d[mask, 1],
            label=label,
            alpha=0.7
        )
    
    plt.legend()
    plt.title('t-SNE visualization of embeddings')
    plt.show()

In [None]:
# Load best model for evaluation
checkpoint_path = '../log/checkpoints/best_model.pth'
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']+1} with loss: {checkpoint['loss']:.4f}")
    
    # Visualize embeddings
    visualize_embeddings(model, testloader, device)
else:
    print("No saved model found. Using the last trained model.")
    visualize_embeddings(model, testloader, device)