In [None]:
# jupyter nbconvert --to script train_metrics1.ipynb

In [None]:
import warnings
# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation/numerical warnings raised by the libraries imported below.
warnings.filterwarnings('ignore')

import sys
import os

# Resolve the parent of the current working directory and append it to sys.path so that the
# project-local packages imported later (`utils`, `models`) can be found.
# Assumes the notebook is launched from the "trainers" directory — TODO confirm.
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(ROOT_DIR)


In [None]:

import pandas as pd
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score
from sklearn.neighbors import NearestNeighbors

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

# Import PyTorch Metric Learning
from pytorch_metric_learning import losses, miners, reducers, distances
from pytorch_metric_learning.utils import accuracy_calculator

import albumentations as A
from albumentations.pytorch import ToTensorV2 #np.array -> torch.tensor (B, 3, H, W)
import timm

from utils.dataset import TripletDataset
from models.CNN_model import EmbeddingModel
from utils.tools import AverageMeter
from utils.warmup import GradualWarmupSchedulerV2

In [None]:
# Number of target classes; passed to EmbeddingModel in run() and used as `minlength` in get_sampler().
num_classes = 3
root_dir = '../datasets/'
csv_train_file = 'train_data_with_folds.csv'
# Class names in index order; position i names the i-th probability column of the OOF CSV.
class_list = ['normal', 'preplus', 'plus']
# Maps class name -> integer index (its position in class_list).
label_dict = {cls: i for i, cls in enumerate(class_list)}

# Training table; later code reads its 'fold' and 'path' columns.
df = pd.read_csv(os.path.join(root_dir, csv_train_file))

In [None]:
def get_transforms(image_size):
    """Build the Albumentations pipelines for training and validation.

    Both pipelines resize to an ``image_size`` x ``image_size`` square, normalize with
    the per-channel constants below, and convert the result to a tensor. The train-time
    augmentations are currently switched off, so both pipelines apply the same steps.

    Args:
        image_size: Target height and width handed to ``A.Resize``.

    Returns:
        Tuple ``(transforms_train, transforms_val)`` of ``A.Compose`` objects.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)

    def _resize():
        return A.Resize(image_size, image_size)

    def _normalize():
        return A.Normalize(mean=channel_mean, std=channel_std)

    train_steps = [_resize()]
    # Disabled augmentations, kept for reference (they would go between resize and normalize):
    # A.ImageCompression(quality_lower=80, quality_upper=100, p=0.25),
    # A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.5),
    # A.Flip(p=0.5),
    # A.RandomRotate90(p=0.5),
    # A.RandomBrightnessContrast(p=0.5),
    # A.CoarseDropout(num_holes_range=(1,1), hole_height_range=(8, 32), hole_width_range=(8, 32), p=0.25),
    train_steps.extend([_normalize(), ToTensorV2()])

    val_steps = [_resize(), _normalize(), ToTensorV2()]

    return A.Compose(train_steps), A.Compose(val_steps)

class UnNormalize(object):
    """Invert a per-channel ``(x - mean) / std`` normalization.

    Args:
        mean: Per-channel values that were subtracted during normalization.
        std: Per-channel values that were divided by during normalization.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Normalized image; its first dimension is walked
                channel by channel alongside ``mean`` and ``std``.
        Returns:
            Tensor: A new tensor holding ``tensor * std + mean`` per channel.
                The input tensor is left unmodified.
        """
        # Work on a copy: the in-place ops below would otherwise overwrite the caller's
        # tensor, so un-normalizing the same image twice would silently corrupt it.
        restored = tensor.clone()
        for channel, m, s in zip(restored, self.mean, self.std):
            # Inverse of the normalize step `t.sub_(m).div_(s)`.
            channel.mul_(s).add_(m)
        return restored


unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

def get_sampler(dataset): # WeightedRandomSampler
    """Create a class-balancing ``WeightedRandomSampler`` for ``dataset``.

    The second element of every ``dataset[idx]`` item is read as that sample's label.
    Each sample is weighted by the inverse of its class count, and the sampler draws
    ``len(dataset)`` indices with replacement.

    NOTE(review): indexing the dataset item by item runs whatever loading/transform
    work ``__getitem__`` does for every sample — verify this is acceptable for large sets.

    Args:
        dataset: Indexable object supporting ``len()``.

    Returns:
        WeightedRandomSampler over the dataset indices.
    """
    targets = []
    for position in range(len(dataset)):
        targets.append(dataset[position][1])

    # `minlength` keeps one slot per class even if a class is absent from this split.
    per_class_count = np.bincount(targets, minlength=num_classes)
    # The small epsilon avoids a division by zero for absent classes.
    inverse_frequency = 1.0 / (per_class_count + 1e-6)
    per_sample_weight = [inverse_frequency[t] for t in targets]

    return WeightedRandomSampler(
        weights=per_sample_weight,
        num_samples=len(per_sample_weight),
        replacement=True,
    )


# Plotting function
def plot_fold_history(fold, history, run_dir):
    """Draw, save and show the AUC and loss curves recorded for one fold.

    Args:
        fold: Fold identifier, used in the panel titles and the output file name.
        history: Dict with per-epoch lists ``train_loss``, ``val_loss``, ``train_auc``,
            ``val_auc``, ``triplet_loss`` and the scalars ``best_val_auc``,
            ``best_val_auc_epoch``, ``best_val_loss``, ``best_val_loss_epoch``.
        run_dir: Directory that receives ``fold{fold}_history.png``.
    """
    epoch_axis = [step + 1 for step in range(len(history['train_loss']))]
    _fig, (auc_ax, loss_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: AUC curves with the best validation epoch highlighted.
    top_auc = history['best_val_auc']
    top_auc_epoch = history['best_val_auc_epoch']
    auc_ax.plot(epoch_axis, history['train_auc'], '-o', label='Train AUC', color='skyblue')
    auc_ax.plot(epoch_axis, history['val_auc'], '-o', label='Val AUC', color='lightcoral')
    auc_ax.scatter(top_auc_epoch, top_auc, s=200, color='lightcoral')
    auc_ax.text(top_auc_epoch, top_auc, f'max {top_auc:.4f}', size=12)
    auc_ax.set_xlabel('Epoch')
    auc_ax.set_ylabel('AUC')
    auc_ax.set_title(f'Fold {fold} Auc')
    auc_ax.legend()

    # Right panel: loss curves with the lowest validation loss highlighted.
    low_loss = history['best_val_loss']
    low_loss_epoch = history['best_val_loss_epoch']
    loss_ax.plot(epoch_axis, history['train_loss'], '-o', label='Train Loss', color='skyblue')
    loss_ax.plot(epoch_axis, history['val_loss'], '-o', label='Val Loss', color='lightcoral')
    loss_ax.plot(epoch_axis, history['triplet_loss'], '-o', label='Triplet Loss', color='green')
    loss_ax.scatter(low_loss_epoch, low_loss, s=200, color='lightcoral')
    loss_ax.text(low_loss_epoch, low_loss, f'min {low_loss:.4f}', size=12)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title(f'Fold {fold} Loss')
    loss_ax.legend()

    plt.tight_layout()
    plot_path = os.path.join(run_dir, f"fold{fold}_history.png")
    plt.savefig(plot_path)
    print(f"History plot saved: {plot_path}")
    plt.show()

class EarlyStopper:
    """Patience counter over a monitored metric.

    With ``use_loss=False`` larger values are better; with ``use_loss=True`` smaller
    values are better. Any strict improvement over the best value seen so far becomes
    the new best and clears the counter. A value worse than the best by more than
    ``min_delta`` increments the counter. Values in between change nothing.
    """

    def __init__(self, min_delta=0, patience=1, use_loss=False):
        self.min_delta = min_delta
        self.patience = patience
        self.use_loss = use_loss
        # Start from the worst possible value so the first observation always counts as best.
        self.best_metric = float('inf') if use_loss else -float('inf')
        self.count = 0

    def early_stop(self, metric):
        """Record ``metric`` and return True once ``patience`` degradations have accumulated."""
        if self.use_loss:
            improved = metric < self.best_metric
            degraded = metric > self.best_metric + self.min_delta
        else:
            improved = metric > self.best_metric
            degraded = metric + self.min_delta < self.best_metric

        if improved:
            self.best_metric = metric
            self.count = 0
            return False

        if not degraded:
            # Within the tolerance band: neither progress nor a strike.
            return False

        self.count += 1
        return self.count >= self.patience
        


In [None]:
def compute_recall_at_k(embeddings, labels, k=1):
    """
    Compute Recall@K for given embeddings and labels.

    A sample counts as a hit when at least one of its ``k`` nearest neighbours
    (cosine distance, the sample itself excluded) carries the same label.

    Args:
        embeddings: Array-like of shape (n_samples, dim).
        labels: Array-like of n_samples labels (treated as 1-D).
        k: Number of neighbours inspected per sample; clamped to ``n_samples - 1``.

    Returns:
        float: Fraction of samples with a same-label neighbour; ``0.0`` when there
        are fewer than two samples or ``k < 1`` (no neighbours to inspect).
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)

    n_samples = len(labels)
    if n_samples < 2 or k < 1:
        return 0.0

    # kneighbors() raises when asked for more neighbours than exist besides the query point.
    n_neighbors = min(k, n_samples - 1)
    nn_model = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine')
    nn_model.fit(embeddings)

    # Querying without X returns the neighbours of the fitted points with each point
    # excluded from its own result. Slicing off column 0 instead is wrong whenever
    # duplicate/tied embeddings put another sample ahead of the self-match.
    neighbor_idxs = nn_model.kneighbors(return_distance=False)

    hits = (labels[neighbor_idxs] == labels[:, None]).any(axis=1)
    recall_at_k = int(hits.sum()) / n_samples
    return recall_at_k

In [None]:
def train_epoch(model, loader, optimizer, criterion, triplet_loss, miner, device, loss_type):
    """Run one optimisation pass over ``loader`` and report epoch-level metrics.

    Args:
        model: Module whose forward pass returns ``(logits, embeddings)``.
        loader: Iterable of ``(img, label)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Classification loss, called as ``criterion(logits, targets)``.
        triplet_loss: Metric loss, called as ``triplet_loss(embeddings, targets, hard_pairs)``.
        miner: Called as ``miner(embeddings, targets)`` to pick the tuples fed to the triplet loss.
        device: Device the batches are moved to.
        loss_type: ``"ce_only"`` or ``"triplet_only"``; any other value trains on
            ``0.5 * ce + 0.5 * triplet``.

    Returns:
        Tuple ``(avg_total_loss, train_auc, avg_triplet_loss, avg_ce_loss)``. ``train_auc`` is
        the one-vs-rest ROC AUC over the whole epoch, or ``0.0`` if it cannot be computed.
    """
    train_loss_meter = AverageMeter()
    triplet_loss_meter = AverageMeter()
    ce_loss_meter = AverageMeter()
    model.train()

    PROBS = []
    TARGETS = []

    for img, label in tqdm(loader, desc='Training'):

        optimizer.zero_grad()
        inputs = img.to(device)
        targets = label.to(device)

        # Get both logits and embeddings
        logits, embeddings = model(inputs)

        # Calculate classification loss
        ce_loss = criterion(logits, targets)

        hard_pairs = miner(embeddings, targets)
        if len(hard_pairs[0]) == 0:
            # Nothing was mined for this batch: record a zero triplet term.
            trip_loss = torch.tensor(0.0, device=device)
        else:
            trip_loss = triplet_loss(embeddings, targets, hard_pairs)

        if loss_type == "ce_only":
            loss = ce_loss
        elif loss_type == "triplet_only":
            loss = trip_loss
        else:
            loss = 0.5 * ce_loss + 0.5 * trip_loss

        # In "triplet_only" mode an empty mining result leaves `loss` as a constant with no
        # autograd graph; calling backward() on it raises, so skip the update for that batch.
        if loss.requires_grad:
            loss.backward()
            optimizer.step()

        batch_n = inputs.size(0)
        train_loss_meter.update(loss.item(), batch_n)
        ce_loss_meter.update(ce_loss.item(), batch_n)
        triplet_loss_meter.update(trip_loss.item(), batch_n)

        with torch.no_grad():
            probs = F.softmax(logits.float(), dim=1).cpu().numpy()
            PROBS.append(probs)
            TARGETS.append(targets.cpu().numpy())

    # Concatenate all predictions and targets
    PROBS = np.concatenate(PROBS)
    TARGETS = np.concatenate(TARGETS)

    # Sanity checks only: problems are reported, not raised.
    if not np.allclose(PROBS.sum(axis=1), 1.0, atol=1e-5):
        print("PROBS not summing to 1!")
    if np.any(TARGETS < 0) or np.any(TARGETS >= num_classes):
        print(f"Invalid TARGETS values: {TARGETS}")

    # Compute AUC over entire epoch
    try:
        train_auc = roc_auc_score(y_true=TARGETS, y_score=PROBS, multi_class='ovr')
    except ValueError as e:
        # Report the failure and fall back to 0.0 rather than aborting the epoch.
        print(f"Sample of PROBS: {PROBS[0]}, sum: {PROBS.sum(axis=1)}")
        print(f"Error: {e}")
        train_auc = 0.0

    return train_loss_meter.avg, train_auc, triplet_loss_meter.avg, ce_loss_meter.avg

# Validation epoch
def val_epoch(model, loader, criterion, triplet_loss, miner, device, loss_type):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    The triplet loss is called directly on ``(embeddings, targets)``; ``miner`` is accepted
    for signature parity with ``train_epoch`` but is not used here.

    Args:
        model: Module whose forward pass returns ``(logits, embeddings)``.
        loader: Iterable of ``(img, label)`` batches.
        criterion: Classification loss, called as ``criterion(logits, targets)``.
        triplet_loss: Metric loss, called as ``triplet_loss(embeddings, targets)``.
        miner: Unused.
        device: Device the batches are moved to.
        loss_type: ``"ce_only"`` or ``"triplet_only"``; any other value reports
            ``0.5 * ce + 0.5 * triplet``.

    Returns:
        Tuple ``(avg_total_loss, val_auc, avg_triplet_loss, avg_ce_loss, recall_at_1)``.
        ``val_auc`` falls back to ``0.0`` when ``roc_auc_score`` raises ``ValueError``.
    """
    model.eval()
    total_meter = AverageMeter()
    trip_meter = AverageMeter()
    ce_meter = AverageMeter()

    prob_chunks, target_chunks, embedding_chunks = [], [], []

    with torch.no_grad():
        for img, label in loader:
            inputs, targets = img.to(device), label.to(device)
            logits, embeddings = model(inputs)

            # Both loss terms are always computed so each can be tracked separately.
            ce_loss = criterion(logits, targets)
            trip_loss = triplet_loss(embeddings, targets)

            single_term = {"ce_only": ce_loss, "triplet_only": trip_loss}
            if loss_type in single_term:
                loss = single_term[loss_type]
            else:
                loss = 0.5 * ce_loss + 0.5 * trip_loss

            batch_n = inputs.size(0)
            trip_value = trip_loss.item() if trip_loss != 0 else 0
            total_meter.update(loss.item(), batch_n)
            ce_meter.update(ce_loss.item(), batch_n)
            trip_meter.update(trip_value, batch_n)

            prob_chunks.append(F.softmax(logits.float(), dim=1).cpu().numpy())
            target_chunks.append(targets.cpu().numpy())
            embedding_chunks.append(embeddings.cpu().numpy())

    all_probs = np.concatenate(prob_chunks)
    all_targets = np.concatenate(target_chunks)
    all_embeddings = np.concatenate(embedding_chunks)

    try:
        val_auc = roc_auc_score(all_targets, all_probs, multi_class='ovr')
    except ValueError as e:
        print(f"Val AUC failed: {e}, Unique targets: {np.unique(all_targets)}, Probs shape: {all_probs.shape}")
        val_auc = 0.0

    recall_at_k = compute_recall_at_k(all_embeddings, all_targets, k=1)
    print(f"Validation Recall@1: {recall_at_k:.5f}")

    return total_meter.avg, val_auc, trip_meter.avg, ce_meter.avg, recall_at_k

In [None]:
def run(fold, df, root_dir, transforms_train, transforms_val, num_workers, n_epochs, device, batch_size, lr, run_dir, backbone_name):
    """Train and validate one cross-validation fold and return its out-of-fold outputs.

    Rows with ``df['fold'] == fold`` form the validation split; all other rows are used for
    training. The weights with the highest validation AUC are written to ``run_dir`` together
    with the final weights, a history plot and the validation embeddings.

    Args:
        fold: Value of the ``fold`` column to hold out.
        df: DataFrame with at least the ``fold`` and ``path`` columns.
        root_dir: Dataset root passed to ``TripletDataset``.
        transforms_train: Transform applied to training samples.
        transforms_val: Transform applied to validation samples.
        num_workers: DataLoader worker count.
        n_epochs: Maximum number of epochs (early stopping may end training sooner).
        device: Device used for the model and batches.
        batch_size: Batch size of both loaders.
        lr: Learning rate handed to Adam.
        run_dir: Existing directory that receives checkpoints, plot and embeddings.
        backbone_name: Backbone identifier passed to ``EmbeddingModel``.

    Returns:
        Tuple ``(oof_preds, oof_targets, oof_names, oof_folds, oof_embeddings)`` for the
        validation split, predicted with the best-AUC weights.
    """
    train_df = df[df['fold'] != fold].reset_index(drop=True)
    val_df = df[df['fold'] == fold].reset_index(drop=True)

    # Datasets
    # NOTE(review): the validation dataset is also built with mode 'train'; presumably that
    # mode only means "labelled data" — confirm against TripletDataset.
    train_ds = TripletDataset(root_dir, train_df, 'train', transform=transforms_train)
    val_ds = TripletDataset(root_dir, val_df, 'train', transform=transforms_val)

    # Sampler
    train_sampler = get_sampler(train_ds)

    # Data loaders (the sampler decides the training order, hence shuffle=False)
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
        sampler=train_sampler
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
    )

    # Model, optimizer, criterion
    model = EmbeddingModel(num_classes, backbone_name, 512).to(device)
    optimizer = Adam(model.parameters(), lr=lr)

    # Classification loss
    criterion = nn.CrossEntropyLoss()

    # Distance function shared by the miner and the triplet loss
    distance = distances.CosineSimilarity()

    # Miner selecting easy positives and semi-hard negatives within each batch
    miner = miners.BatchEasyHardMiner(
        pos_strategy="easy",
        neg_strategy="semihard",
        distance=distance
    )

    # Triplet loss with margin
    triplet_loss = losses.TripletMarginLoss(
        margin=1,
        distance=distance,
        reducer=reducers.AvgNonZeroReducer()
    )

    # Learning rate scheduler
    scheduler_cosine = CosineAnnealingWarmRestarts(optimizer, T_0=10)
    scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler_cosine)

    # History tracking ('triplet_loss' and 'ce_loss' hold the training-side values)
    history = {
        'train_loss': [], 'val_loss': [],
        'train_auc': [], 'val_auc': [],
        'triplet_loss': [], 'ce_loss': [],
        'learning_rates': [],
        'best_val_auc': 0, 'best_val_auc_epoch': 0,
        'best_val_loss': float('inf'), 'best_val_loss_epoch': 0
    }

    print(f"Fold {fold}: =========================================")

    early_stopping_active = False
    es_ce = None
    es_trip = None
    best_model_state_dict = None
    best_model_filename = None

    for epoch in range(1, n_epochs + 1):
        current_lr = optimizer.param_groups[0]['lr']
        history['learning_rates'].append(current_lr)

        print(f"\nEP {epoch}/{n_epochs} (LR: {current_lr:.6f}):")
        train_loss, train_auc, train_triplet_loss, train_ce_loss = train_epoch(model, train_loader, optimizer, criterion, triplet_loss, miner, device, loss_type="CE_trip")
        val_loss, val_auc, val_triplet_loss, val_ce_loss, recall_at_k = val_epoch(model, val_loader, criterion, triplet_loss, miner, device, loss_type='CE_trip')

        print(f"Train AUC: {train_auc:.4f}, CE Loss: {train_ce_loss:.4f}, Triplet Loss: {train_triplet_loss:.6f}")
        print(f"Val AUC: {val_auc:.4f}, CE Loss: {val_ce_loss:.4f}, Triplet Loss: {val_triplet_loss:.6f}")

        # Update history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_auc'].append(train_auc)
        history['val_auc'].append(val_auc)
        history['triplet_loss'].append(train_triplet_loss)
        history['ce_loss'].append(train_ce_loss)

        if val_auc > history['best_val_auc']:
            history['best_val_auc'] = val_auc
            history['best_val_auc_epoch'] = epoch
            # state_dict() hands back references to the live parameters, which later epochs
            # keep updating; clone every tensor so the snapshot really is this epoch's weights.
            best_model_state_dict = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            print(f"New best AUC: {val_auc:.4f} at epoch {epoch}.")

        if val_loss < history['best_val_loss']:
            history['best_val_loss'] = val_loss
            history['best_val_loss_epoch'] = epoch

        # Step the scheduler (use warmup scheduler during warmup, then cosine)
        # NOTE(review): the warmup wrapper also receives `after_scheduler`; confirm that
        # stepping the cosine scheduler directly afterwards is the intended hand-over.
        if epoch <= scheduler_warmup.total_epoch:
            scheduler_warmup.step()
        else:
            scheduler_cosine.step()

        if epoch == scheduler_warmup.total_epoch:
            early_stopping_active = True
            # Start fresh early stoppers so scores seen during warmup do not count
            es_ce = EarlyStopper(min_delta=1e-5, patience=10)
            es_trip = EarlyStopper(min_delta=1e-5, patience=10, use_loss=True)
            print("Warmup complete. Early stopping now active.")

        # Only check early stopping if it's active
        if early_stopping_active:
            # Update BOTH trackers every epoch. Chaining the calls with `and` would
            # short-circuit and leave the triplet tracker untouched until the AUC one fires.
            stop_on_auc = es_ce.early_stop(val_auc)
            stop_on_triplet = es_trip.early_stop(train_triplet_loss)
            if stop_on_auc and stop_on_triplet:
                print(f"Early stopping triggered at epoch {epoch}")
                break

    # Save the final model and best model
    if best_model_state_dict is not None:
        best_model_filename = os.path.join(run_dir, f"fold{fold}_best_auc{history['best_val_auc']:.4f}_ep{history['best_val_auc_epoch']}.pth")
        torch.save(best_model_state_dict, best_model_filename)
        print(f"Best model saved: {best_model_filename}")

    final_checkpoint_path = os.path.join(run_dir, f"fold{fold}_final.pth")
    torch.save(model.state_dict(), final_checkpoint_path)
    print(f"Final model saved: {final_checkpoint_path}")

    # Plot history
    plot_fold_history(fold, history, run_dir)

    # Compute OOF predictions with the best weights. They are restored from the in-memory
    # snapshot; if no epoch ever improved the AUC there is no best checkpoint on disk, so
    # the final weights are used instead of failing on a missing file.
    if best_model_state_dict is not None:
        model.load_state_dict(best_model_state_dict)
    else:
        print("No epoch improved the validation AUC; using final weights for OOF predictions.")
    model.eval()
    oof_preds = []
    oof_targets = []
    oof_embeddings = []
    with torch.no_grad():
        for img, label in val_loader:
            inputs = img.to(device)
            targets = label.to(device)
            logits, embeddings = model(inputs)
            probs = F.softmax(logits.float(), dim=1).cpu().numpy()
            oof_preds.append(probs)
            oof_targets.append(targets.cpu().numpy())
            oof_embeddings.append(embeddings.cpu().numpy())

    oof_preds = np.concatenate(oof_preds)
    oof_targets = np.concatenate(oof_targets)
    oof_embeddings = np.concatenate(oof_embeddings)

    # Save embeddings for later visualization or analysis
    embeddings_path = os.path.join(run_dir, f'fold_{fold}_embeddings.npy')
    np.save(embeddings_path, oof_embeddings)

    # The validation loader is not shuffled, so rows line up with val_df order.
    oof_names = val_df['path'].values
    oof_folds = np.full(len(oof_targets), fold)
    return oof_preds, oof_targets, oof_names, oof_folds, oof_embeddings

# Visualization function for embeddings using t-SNE
def visualize_embeddings(embeddings, labels, title="t-SNE Visualization of Embeddings", run_dir=None):
    """Project embeddings to 2-D with t-SNE and show a scatter plot coloured by class.

    Args:
        embeddings: Array-like of shape (n_samples, dim).
        labels: Array-like of n_samples class labels.
        title: Figure title; with spaces replaced by underscores it is also the PNG file name.
        run_dir: Directory to save the PNG into. When ``None`` the figure is only shown.
    """
    from sklearn.manifold import TSNE

    # Boolean masking below needs real arrays; plain lists would not broadcast.
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)

    # t-SNE requires perplexity < n_samples, so the library default of 30 fails on small
    # folds. Larger inputs keep the default value.
    perplexity = min(30.0, float(len(embeddings) - 1))

    # Apply t-SNE
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    embeddings_2d = tsne.fit_transform(embeddings)

    # Plot
    plt.figure(figsize=(10, 8))
    for class_idx in np.unique(labels):
        class_mask = labels == class_idx
        plt.scatter(
            embeddings_2d[class_mask, 0],
            embeddings_2d[class_mask, 1],
            label=f'Class {class_idx}'
        )

    plt.title(title)
    plt.legend()
    plt.tight_layout()
    # The default run_dir=None used to crash in os.path.join; saving is now optional.
    if run_dir is not None:
        plot_path = os.path.join(run_dir, f"{title.replace(' ', '_')}.png")
        plt.savefig(plot_path)
    plt.show()


In [None]:
# ---- Run configuration ----
IMG_SIZE= 600  # side length passed to get_transforms() below
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 16
n_epochs = 100  # upper bound per fold; run() can stop earlier through early stopping
num_workers = 0
print(f"Num workers = {num_workers}")
folds=[0,1,2,3,4]  # values of the 'fold' column held out in turn, one run() call each
lr = 3e-4  # learning rate handed to Adam inside run()
backbone_name='resnet50'

# Pick the first "run<N>" directory under checkpoint_dir that does not exist yet and create it,
# so the outputs of earlier runs are not overwritten.
checkpoint_dir = '../checkpoints'
run_num = 1
while True:
    run_dir = os.path.join(checkpoint_dir, f"run{run_num}")
    if not os.path.exists(run_dir):
        os.makedirs(run_dir)
        break
    run_num += 1
print(f"Checkpoints will be saved in: {run_dir}")

# Per-fold out-of-fold (OOF) results are collected here and concatenated after the loop.
oof_preds_all = []
oof_targets_all = []
oof_names_all = []
oof_folds_all = []
# NOTE: never appended to below; the matching concatenate/visualize lines are commented out.
oof_embeddings_all = []

# The same transform objects are reused for every fold.
transforms_train, transforms_val = get_transforms(IMG_SIZE)

# Train one model per fold and keep its predictions on the held-out rows.
for fold in folds:
    oof_preds, oof_targets, oof_names, oof_folds, oof_embeddings = run(
        fold, df, root_dir, transforms_train, transforms_val, num_workers, n_epochs, device, batch_size, lr, run_dir, backbone_name
    )
    oof_preds_all.append(oof_preds)
    oof_targets_all.append(oof_targets)
    oof_names_all.append(oof_names)
    oof_folds_all.append(oof_folds)
    
    # Visualize embeddings for this fold
    visualize_embeddings(oof_embeddings, oof_targets, f"Fold {fold} Embeddings", run_dir)

# Concatenate OOF data (rows end up ordered fold by fold, following `folds`)
oof_preds_all = np.concatenate(oof_preds_all)
oof_targets_all = np.concatenate(oof_targets_all)
oof_names_all = np.concatenate(oof_names_all)
oof_folds_all = np.concatenate(oof_folds_all)
# oof_embeddings_all = np.concatenate(oof_embeddings_all)

# Visualize all embeddings together
# visualize_embeddings(oof_embeddings_all, oof_targets_all, "All Folds Embeddings", run_dir)

# Compute overall OOF AUC (one-vs-rest) across the predictions of all folds
auc = roc_auc_score(oof_targets_all, oof_preds_all, multi_class='ovr')
print(f'Overall OOF AUC = {auc:.3f}')

# Save OOF to CSV with class probabilities
# Create a dictionary for the DataFrame
df_oof_dict = {
    'image_name': oof_names_all,
    'target': oof_targets_all,
    'fold': oof_folds_all
}

# Add probability columns for each class: column i of the predictions is named after class_list[i]
for i, class_name in enumerate(class_list):
    df_oof_dict[f'prob_{class_name}'] = oof_preds_all[:, i]

# Create DataFrame
df_oof = pd.DataFrame(df_oof_dict)

# Save to run_dir
oof_path = os.path.join(run_dir, 'oof_triplet.csv')
df_oof.to_csv(oof_path, index=False)
print(f"OOF saved to: {oof_path}")
print(df_oof.head())