In [1]:
# 1. Import Libraries and Setup
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from torch.optim import AdamW
import os
import sys

# Fix OpenMP conflict error
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Optimize CUDA memory
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Add current directory to path to import loader
sys.path.append(os.getcwd())
from ouhands_loader import OuhandsDS

# Setup Device
device = (
    torch.device("cuda") if torch.cuda.is_available()
    else torch.device("mps") if torch.backends.mps.is_available()
    else torch.device("cpu")
)
print("Using device:", device)

Using device: cuda


In [2]:
# 2. Prepare DataLoaders

# Standard ImageNet normalization
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

batch_size = 32 # Adjusted for memory safety
num_workers = 0 # Windows compatibility

train_ds = OuhandsDS(split='train', transform=transform)
val_ds = OuhandsDS(split='validation', transform=transform)
test_ds = OuhandsDS(split='test', transform=transform)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

print(f"Train size: {len(train_ds)}")
print(f"Val size: {len(val_ds)}")
print(f"Test size: {len(test_ds)}")

Loaded 1600 samples for train split
Class distribution: {'A': 160, 'B': 160, 'C': 160, 'D': 160, 'E': 160, 'F': 160, 'H': 160, 'I': 160, 'J': 160, 'K': 160}
Loaded 400 samples for validation split
Class distribution: {'A': 40, 'B': 40, 'C': 40, 'D': 40, 'E': 40, 'F': 40, 'H': 40, 'I': 40, 'J': 40, 'K': 40}
Loaded 1000 samples for test split
Class distribution: {'A': 100, 'B': 100, 'C': 100, 'D': 100, 'E': 100, 'F': 100, 'H': 100, 'I': 100, 'J': 100, 'K': 100}
Train size: 1600
Val size: 400
Test size: 1000
Loaded 400 samples for validation split
Class distribution: {'A': 40, 'B': 40, 'C': 40, 'D': 40, 'E': 40, 'F': 40, 'H': 40, 'I': 40, 'J': 40, 'K': 40}
Loaded 1000 samples for test split
Class distribution: {'A': 100, 'B': 100, 'C': 100, 'D': 100, 'E': 100, 'F': 100, 'H': 100, 'I': 100, 'J': 100, 'K': 100}
Train size: 1600
Val size: 400
Test size: 1000


In [3]:
# 3. Define Gated Fusion Model

class GatedFusionModel(nn.Module):
    def __init__(self, num_classes=10, common_dim=512, simclr_path=None):
        super(GatedFusionModel, self).__init__()
        
        # --- 1. Backbones ---
        
        # SimCLR (ResNet50)
        self.simclr_backbone = models.resnet50(weights=None)
        self.simclr_backbone.fc = nn.Identity()
        self._load_simclr_weights(simclr_path)
        
        # DINO (ViT-S/16)
        print("Loading DINO ViT-S/16...")
        self.dino_backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
        
        # --- 2. Projections ---
        # SimCLR output: 2048 -> 512
        self.proj_s = nn.Linear(2048, common_dim)
        # DINO output: 384 -> 512
        self.proj_d = nn.Linear(384, common_dim)
        
        # --- 3. Gating Mechanism ---
        # Input: Concatenated [u_D; u_S] (512 + 512 = 1024)
        # Output: Gate vector g (512)
        self.gate_layer = nn.Linear(common_dim * 2, common_dim)
        self.sigmoid = nn.Sigmoid()
        
        # --- 4. Classifier ---
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(common_dim, num_classes)
        )

    def _load_simclr_weights(self, path):
        if path and os.path.exists(path):
            print(f"Loading SimCLR weights from {path}")
            try:
                checkpoint = torch.load(path, map_location="cpu")
                state_dict = checkpoint.get('state_dict', checkpoint)
                
                new_state_dict = {}
                for k, v in state_dict.items():
                    name = k.replace("module.", "")
                    if name.startswith("backbone."):
                        name = name.replace("backbone.", "")
                    if name.startswith("resnet."):
                        name = name.replace("resnet.", "")
                    new_state_dict[name] = v
                
                msg = self.simclr_backbone.load_state_dict(new_state_dict, strict=False)
                print(f"SimCLR weights loaded: {msg}")
            except Exception as e:
                print(f"Error loading SimCLR: {e}")
        else:
            print("SimCLR weights not found, using random init.")

    def forward(self, x):
        # 1. Extract Features
        # SimCLR
        h_s = self.simclr_backbone(x) # (B, 2048)
        # DINO
        h_d = self.dino_backbone(x)   # (B, 384)
        
        # 2. Project
        u_s = self.proj_s(h_s) # (B, 512)
        u_d = self.proj_d(h_d) # (B, 512)
        
        # 3. Compute Gate
        # Concatenate
        concat = torch.cat([u_d, u_s], dim=1) # (B, 1024)
        # Gate vector
        g = self.sigmoid(self.gate_layer(concat)) # (B, 512)
        
        # 4. Fuse
        # u_fused = g * u_D + (1 - g) * u_S
        u_fused = g * u_d + (1 - g) * u_s
        
        # 5. Classify
        logits = self.classifier(u_fused)
        
        return logits

# Initialize Model
simclr_path = r"D:\Courses\Csc2503\proj\CSC2503-Project\notebooks\checkpoint_100"
model = GatedFusionModel(num_classes=10, simclr_path=simclr_path).to(device)
print(model)

Loading SimCLR weights from D:\Courses\Csc2503\proj\CSC2503-Project\notebooks\checkpoint_100
SimCLR weights loaded: _IncompatibleKeys(missing_keys=['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.downsample.0.weight', 'layer1.0.downsample.1.weight', 'layer1.0.downsample.1.bias', 'layer1.0.downsample.1.running_mean', 'layer1.0.downsample.1.running_var', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.conv2.weight', 'layer1.1.bn2.weight', 'layer1.1.bn2.bias', 'layer1.1.bn2.running_mean', 

  checkpoint = torch.load(path, map_location="cpu")
Using cache found in C:\Users\24912/.cache\torch\hub\facebookresearch_dino_main


GatedFusionModel(
  (simclr_backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Seque

  return t.to(


In [4]:
# 4. Training Loop

criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

def train_one_epoch(model, loader, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in tqdm(loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
    return running_loss / total, 100. * correct / total

def evaluate(model, loader, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
    return running_loss / total, 100. * correct / total

# Run Training
num_epochs = 20
best_acc = 0.0

print("Starting Training...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device)
    val_loss, val_acc = evaluate(model, val_loader, device)
    
    scheduler.step()
    
    print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.2f}%")
    
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_gated_fusion_model.pth")
        print("Saved Best Model!")

print(f"Training Complete. Best Val Acc: {best_acc:.2f}%")

Starting Training...
Epoch 1/20


Training: 100%|██████████| 50/50 [00:48<00:00,  1.04it/s]
Training: 100%|██████████| 50/50 [00:48<00:00,  1.04it/s]
Evaluating: 100%|██████████| 13/13 [00:08<00:00,  1.48it/s]



Train Loss: 2.6080 | Acc: 9.94%
Val Loss: 2.3970 | Acc: 10.00%
Saved Best Model!
Epoch 2/20


Training: 100%|██████████| 50/50 [00:33<00:00,  1.51it/s]
Training: 100%|██████████| 50/50 [00:33<00:00,  1.51it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.26it/s]



Train Loss: 2.4003 | Acc: 12.19%
Val Loss: 2.5774 | Acc: 10.75%
Saved Best Model!
Epoch 3/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.52it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.52it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.36it/s]



Train Loss: 2.3203 | Acc: 15.44%
Val Loss: 2.4723 | Acc: 11.50%
Saved Best Model!
Epoch 4/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.52it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.52it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.34it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.34it/s]


Train Loss: 2.1463 | Acc: 20.56%
Val Loss: 2.8408 | Acc: 8.75%
Epoch 5/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.52it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.52it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.34it/s]



Train Loss: 1.7718 | Acc: 35.12%
Val Loss: 2.5718 | Acc: 17.75%
Saved Best Model!
Epoch 6/20


Training: 100%|██████████| 50/50 [00:33<00:00,  1.52it/s]
Training: 100%|██████████| 50/50 [00:33<00:00,  1.52it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.35it/s]



Train Loss: 1.2707 | Acc: 53.94%
Val Loss: 2.3881 | Acc: 22.50%
Saved Best Model!
Epoch 7/20


Training: 100%|██████████| 50/50 [00:33<00:00,  1.48it/s]
Training: 100%|██████████| 50/50 [00:33<00:00,  1.48it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.35it/s]



Train Loss: 0.7974 | Acc: 72.25%
Val Loss: 3.2309 | Acc: 23.75%
Saved Best Model!
Epoch 8/20
Saved Best Model!
Epoch 8/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.38it/s]



Train Loss: 0.4579 | Acc: 84.75%
Val Loss: 2.2043 | Acc: 42.50%
Saved Best Model!
Epoch 9/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.38it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.38it/s]


Train Loss: 0.2726 | Acc: 91.56%
Val Loss: 2.8430 | Acc: 40.50%
Epoch 10/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.53it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.53it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.35it/s]



Train Loss: 0.1572 | Acc: 96.12%
Val Loss: 1.9011 | Acc: 47.25%
Saved Best Model!
Epoch 11/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.38it/s]



Train Loss: 0.0785 | Acc: 98.44%
Val Loss: 1.9051 | Acc: 50.50%
Saved Best Model!
Epoch 12/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.37it/s]



Train Loss: 0.0423 | Acc: 99.12%
Val Loss: 1.8545 | Acc: 51.50%
Saved Best Model!
Epoch 13/20


Training: 100%|██████████| 50/50 [00:33<00:00,  1.48it/s]
Training: 100%|██████████| 50/50 [00:33<00:00,  1.48it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.33it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.33it/s]


Train Loss: 0.0361 | Acc: 99.19%
Val Loss: 1.9475 | Acc: 51.25%
Epoch 14/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.55it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.55it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.37it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.37it/s]


Train Loss: 0.0266 | Acc: 99.44%
Val Loss: 1.8750 | Acc: 47.75%
Epoch 15/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.36it/s]



Train Loss: 0.0202 | Acc: 99.94%
Val Loss: 1.9695 | Acc: 52.75%
Saved Best Model!
Epoch 16/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.55it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.55it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.37it/s]



Train Loss: 0.0155 | Acc: 99.88%
Val Loss: 1.8574 | Acc: 55.75%
Saved Best Model!
Epoch 17/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.52it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.52it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.17it/s]



Train Loss: 0.0124 | Acc: 99.88%
Val Loss: 1.8041 | Acc: 56.00%
Saved Best Model!
Epoch 18/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.53it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.53it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.37it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.37it/s]


Train Loss: 0.0153 | Acc: 99.94%
Val Loss: 1.7623 | Acc: 54.25%
Epoch 19/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.37it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.37it/s]


Train Loss: 0.0132 | Acc: 99.88%
Val Loss: 1.7788 | Acc: 53.50%
Epoch 20/20


Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Training: 100%|██████████| 50/50 [00:32<00:00,  1.54it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.23it/s]

Train Loss: 0.0102 | Acc: 99.94%
Val Loss: 1.8478 | Acc: 54.50%
Training Complete. Best Val Acc: 56.00%





In [5]:
# 5. Final Evaluation
from sklearn.metrics import f1_score
import numpy as np

if os.path.exists("best_gated_fusion_model.pth"):
    model.load_state_dict(torch.load("best_gated_fusion_model.pth"))
    print("Loaded best model.")

model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing"):
        images = images.to(device)
        outputs = model(images)
        _, preds = outputs.max(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

test_acc = np.mean(np.array(all_preds) == np.array(all_labels)) * 100
test_f1 = f1_score(all_labels, all_preds, average='macro')

print("\n" + "="*30)
print(f"FINAL RESULTS (Gated Fusion)")
print("="*30)
print(f"{'Metric':<15} | {'Value':<10}")
print("-" * 30)
print(f"{'Top-1 Acc (%)':<15} | {test_acc:.2f}")
print(f"{'Macro-F1':<15} | {test_f1:.4f}")
print("="*30)

  model.load_state_dict(torch.load("best_gated_fusion_model.pth"))


Loaded best model.


Testing: 100%|██████████| 32/32 [00:22<00:00,  1.40it/s]


FINAL RESULTS (Gated Fusion)
Metric          | Value     
------------------------------
Top-1 Acc (%)   | 34.40
Macro-F1        | 0.3470



