In [6]:
# 1. Import Libraries and Setup Device
import os
# Fix OpenMP conflict error
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torch.optim import AdamW
import sys
import numpy as np
from PIL import Image, ImageOps, ImageFilter
import random

# Add current directory to path to import loader
sys.path.append(os.getcwd())
from ouhands_loader import OuhandsDS

# Setup Device
device = (
    torch.device("cuda") if torch.cuda.is_available()
    else torch.device("mps") if torch.backends.mps.is_available()
    else torch.device("cpu")
)
print("Using device:", device)

Using device: cuda


In [7]:
# 2. Define Barlow Twins Augmentations

class BarlowTwinsTransform:
    """
    Generates two augmented views of the same image.
    Includes: RandomResizedCrop, HorizontalFlip, ColorJitter, Grayscale, Solarization, GaussianBlur.
    """
    def __init__(self, size=224):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size, scale=(0.4, 1.0)), # Less aggressive crop for small dataset
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=23)], p=0.1), # Blur
            transforms.RandomSolarize(threshold=128, p=0.0), # Disable solarize for hand gestures (might destroy shape)
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        # Simple transform for validation/testing (no augmentation)
        self.test_transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        # Return two augmented views
        return self.transform(x), self.transform(x)

# 3. Dataset Wrapper
class BarlowTwinsOuhandsDS(Dataset):
    def __init__(self, split='train', root_dir=r'D:\Courses\Csc2503\proj\archive'):
        # Initialize base dataset
        # We disable default transform in base_ds to get PIL images
        self.base_ds = OuhandsDS(
            root_dir=root_dir,
            split=split,
            transform=lambda x: x 
        )
        self.augmentor = BarlowTwinsTransform(size=224)
        self.split = split

    def __len__(self):
        return len(self.base_ds)

    def __getitem__(self, idx):
        # Get raw PIL image and label
        img, label = self.base_ds[idx]
        
        if self.split == 'train':
            # Training: Return two augmented views + label
            view1, view2 = self.augmentor(img)
            return view1, view2, label
        else:
            # Validation/Test: Return one standard view + label
            # We just use the test_transform defined in augmentor
            view = self.augmentor.test_transform(img)
            return view, view, label # Return duplicate view to keep signature consistent

# Create Datasets
train_ds = BarlowTwinsOuhandsDS(split='train')
val_ds = BarlowTwinsOuhandsDS(split='validation')
test_ds = BarlowTwinsOuhandsDS(split='test')

# Create DataLoaders
batch_size = 32
num_workers = 0 # Windows compatibility

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print(f"Train size: {len(train_ds)}")

Loaded 1600 samples for train split
Class distribution: {'A': 160, 'B': 160, 'C': 160, 'D': 160, 'E': 160, 'F': 160, 'H': 160, 'I': 160, 'J': 160, 'K': 160}
Loaded 400 samples for validation split
Class distribution: {'A': 40, 'B': 40, 'C': 40, 'D': 40, 'E': 40, 'F': 40, 'H': 40, 'I': 40, 'J': 40, 'K': 40}
Loaded 1000 samples for test split
Class distribution: {'A': 100, 'B': 100, 'C': 100, 'D': 100, 'E': 100, 'F': 100, 'H': 100, 'I': 100, 'J': 100, 'K': 100}
Train size: 1600
Loaded 1000 samples for test split
Class distribution: {'A': 100, 'B': 100, 'C': 100, 'D': 100, 'E': 100, 'F': 100, 'H': 100, 'I': 100, 'J': 100, 'K': 100}
Train size: 1600


In [8]:
# 4. Define Barlow Twins Loss

class BarlowTwinsLoss(nn.Module):
    def __init__(self, lambda_param=0.005, vector_dim=2048):
        super(BarlowTwinsLoss, self).__init__()
        self.lambda_param = lambda_param
        self.vector_dim = vector_dim
        self.bn = nn.BatchNorm1d(vector_dim, affine=False) # BN without learnable params

    def forward(self, z1, z2):
        # z1, z2: (Batch, vector_dim)
        
        # Empirical cross-correlation matrix
        c = self.bn(z1).T @ self.bn(z2)
        
        # Sum the cross-correlation matrix between all gpus (if distributed)
        # Here we are single GPU, so just normalize by batch size
        c.div_(z1.size(0))

        # Loss
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        loss = on_diag + self.lambda_param * off_diag
        return loss

def off_diagonal(x):
    # return a flattened view of the off-diagonal elements of a square matrix
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

In [9]:
# 5. Define Hybrid Model (DINO + Barlow Twins)

class DINOBarlowTwins(nn.Module):
    def __init__(self, num_classes=10, projector_dim=2048):
        super(DINOBarlowTwins, self).__init__()
        
        # Backbone: DINO ViT-S/16
        print("Loading DINO ViT-S/16 backbone...")
        self.backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
        embed_dim = 384
        
        # Classifier Head (Supervised)
        self.classifier = nn.Linear(embed_dim, num_classes)
        
        # Projector Head (Barlow Twins)
        # 3-layer MLP: Linear -> BN -> ReLU -> Linear -> BN -> ReLU -> Linear
        self.projector = nn.Sequential(
            nn.Linear(embed_dim, projector_dim),
            nn.BatchNorm1d(projector_dim),
            nn.ReLU(),
            nn.Linear(projector_dim, projector_dim),
            nn.BatchNorm1d(projector_dim),
            nn.ReLU(),
            nn.Linear(projector_dim, projector_dim)
        )

    def forward(self, x):
        # Backbone features
        features = self.backbone(x) # (B, 384)
        
        # Classification logits
        logits = self.classifier(features)
        
        # Projections for Barlow Twins
        projections = self.projector(features)
        
        return logits, projections

model = DINOBarlowTwins(num_classes=10).to(device)

Loading DINO ViT-S/16 backbone...


Using cache found in C:\Users\24912/.cache\torch\hub\facebookresearch_dino_main


In [10]:
# 6. Training Loop (Hybrid Loss)

# Losses
criterion_ce = nn.CrossEntropyLoss()
criterion_bt = BarlowTwinsLoss(lambda_param=0.005, vector_dim=2048).to(device)

# Optimizer
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# Weight for Barlow Twins Loss
alpha = 0.1 # Balance between CE and BT loss

def train_one_epoch(model, loader, optimizer, device):
    model.train()
    running_loss = 0.0
    running_ce = 0.0
    running_bt = 0.0
    correct = 0
    total = 0
    
    for view1, view2, labels in tqdm(loader, desc="Training"):
        view1, view2, labels = view1.to(device), view2.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass for both views
        logits1, proj1 = model(view1)
        logits2, proj2 = model(view2)
        
        # 1. Supervised Loss (CrossEntropy)
        # We can use logits from both views or just one. Let's use both for robustness.
        loss_ce = (criterion_ce(logits1, labels) + criterion_ce(logits2, labels)) / 2
        
        # 2. Self-Supervised Loss (Barlow Twins)
        loss_bt = criterion_bt(proj1, proj2)
        
        # Total Loss
        loss = loss_ce + alpha * loss_bt
        
        loss.backward()
        optimizer.step()
        
        # Metrics
        running_loss += loss.item() * labels.size(0)
        running_ce += loss_ce.item() * labels.size(0)
        running_bt += loss_bt.item() * labels.size(0)
        
        # Accuracy (using view1)
        _, predicted = logits1.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
    epoch_loss = running_loss / total
    epoch_acc = 100. * correct / total
    return epoch_loss, running_ce/total, running_bt/total, epoch_acc

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for view1, _, labels in tqdm(loader, desc="Evaluating"):
            view1, labels = view1.to(device), labels.to(device)
            logits, _ = model(view1)
            _, predicted = logits.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
    return 100. * correct / total

# Run Training
num_epochs = 10
best_acc = 0.0

print(f"Starting training with Alpha={alpha} (BT Weight)...")

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    
    loss, loss_ce, loss_bt, train_acc = train_one_epoch(model, train_loader, optimizer, device)
    val_acc = evaluate(model, val_loader, device)
    
    scheduler.step()
    
    print(f"Train Loss: {loss:.4f} (CE: {loss_ce:.4f}, BT: {loss_bt:.4f}) | Acc: {train_acc:.2f}%")
    print(f"Val Acc: {val_acc:.2f}%")
    
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_dino_bt_model.pth")
        print("Saved Best Model!")

print(f"Training Complete. Best Val Acc: {best_acc:.2f}%")

Starting training with Alpha=0.1 (BT Weight)...
Epoch 1/10


Training: 100%|██████████| 50/50 [00:52<00:00,  1.05s/it]
Training: 100%|██████████| 50/50 [00:52<00:00,  1.05s/it]
Evaluating: 100%|██████████| 13/13 [00:06<00:00,  1.91it/s]



Train Loss: 117.7415 (CE: 2.5591, BT: 1151.8237) | Acc: 23.50%
Val Acc: 33.25%
Saved Best Model!
Epoch 2/10


Training: 100%|██████████| 50/50 [00:42<00:00,  1.17it/s]
Training: 100%|██████████| 50/50 [00:42<00:00,  1.17it/s]
Evaluating: 100%|██████████| 13/13 [00:04<00:00,  2.73it/s]



Train Loss: 95.0413 (CE: 1.0860, BT: 939.5533) | Acc: 60.62%
Val Acc: 58.25%
Saved Best Model!
Epoch 3/10


Training: 100%|██████████| 50/50 [00:42<00:00,  1.17it/s]
Training: 100%|██████████| 50/50 [00:42<00:00,  1.17it/s]
Evaluating: 100%|██████████| 13/13 [00:04<00:00,  2.69it/s]



Train Loss: 88.2861 (CE: 0.6765, BT: 876.0967) | Acc: 74.38%
Val Acc: 63.75%
Saved Best Model!
Epoch 4/10


Training: 100%|██████████| 50/50 [00:42<00:00,  1.17it/s]
Training: 100%|██████████| 50/50 [00:42<00:00,  1.17it/s]
Evaluating: 100%|██████████| 13/13 [00:04<00:00,  2.65it/s]



Train Loss: 84.3170 (CE: 0.4550, BT: 838.6198) | Acc: 83.94%
Val Acc: 70.00%
Saved Best Model!
Epoch 5/10


Training: 100%|██████████| 50/50 [00:43<00:00,  1.14it/s]
Training: 100%|██████████| 50/50 [00:43<00:00,  1.14it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.57it/s]



Train Loss: 81.7091 (CE: 0.3636, BT: 813.4554) | Acc: 87.69%
Val Acc: 79.00%
Saved Best Model!
Epoch 6/10


Training: 100%|██████████| 50/50 [00:42<00:00,  1.17it/s]
Training: 100%|██████████| 50/50 [00:42<00:00,  1.17it/s]
Evaluating: 100%|██████████| 13/13 [00:04<00:00,  2.70it/s]
Evaluating: 100%|██████████| 13/13 [00:04<00:00,  2.70it/s]


Train Loss: 80.6224 (CE: 0.2821, BT: 803.4031) | Acc: 90.94%
Val Acc: 69.75%
Epoch 7/10


Training: 100%|██████████| 50/50 [00:43<00:00,  1.16it/s]
Training: 100%|██████████| 50/50 [00:43<00:00,  1.16it/s]
Evaluating: 100%|██████████| 13/13 [00:04<00:00,  2.68it/s]



Train Loss: 80.0966 (CE: 0.2177, BT: 798.7896) | Acc: 93.31%
Val Acc: 81.25%
Saved Best Model!
Epoch 8/10


Training: 100%|██████████| 50/50 [00:43<00:00,  1.15it/s]
Training: 100%|██████████| 50/50 [00:43<00:00,  1.15it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.59it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.59it/s]


Train Loss: 78.5920 (CE: 0.1663, BT: 784.2570) | Acc: 96.00%
Val Acc: 75.75%
Epoch 9/10


Training: 100%|██████████| 50/50 [00:43<00:00,  1.16it/s]
Training: 100%|██████████| 50/50 [00:43<00:00,  1.16it/s]
Evaluating: 100%|██████████| 13/13 [00:04<00:00,  2.62it/s]



Train Loss: 77.8319 (CE: 0.1393, BT: 776.9261) | Acc: 96.94%
Val Acc: 82.00%
Saved Best Model!
Epoch 10/10


Training: 100%|██████████| 50/50 [00:43<00:00,  1.16it/s]
Training: 100%|██████████| 50/50 [00:43<00:00,  1.16it/s]
Evaluating: 100%|██████████| 13/13 [00:04<00:00,  2.70it/s]

Train Loss: 77.0165 (CE: 0.1331, BT: 768.8340) | Acc: 96.75%
Val Acc: 81.50%
Training Complete. Best Val Acc: 82.00%





In [11]:
# 7. Final Evaluation
from sklearn.metrics import f1_score

if os.path.exists("best_dino_bt_model.pth"):
    model.load_state_dict(torch.load("best_dino_bt_model.pth"))
    print("Loaded best model.")

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for view1, _, labels in tqdm(test_loader, desc="Testing"):
        view1 = view1.to(device)
        logits, _ = model(view1)
        _, preds = logits.max(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

test_acc = np.mean(np.array(all_preds) == np.array(all_labels)) * 100
test_f1 = f1_score(all_labels, all_preds, average='macro')

print("\n" + "="*30)
print(f"FINAL RESULTS (DINO + Barlow Twins)")
print("="*30)
print(f"{'Metric':<15} | {'Value':<10}")
print("-" * 30)
print(f"{'Top-1 Acc (%)':<15} | {test_acc:.2f}")
print(f"{'Macro-F1':<15} | {test_f1:.4f}")
print("="*30)

  model.load_state_dict(torch.load("best_dino_bt_model.pth"))


Loaded best model.


Testing: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s]


FINAL RESULTS (DINO + Barlow Twins)
Metric          | Value     
------------------------------
Top-1 Acc (%)   | 65.80
Macro-F1        | 0.6552



