In [1]:
# 1. Import Libraries and Setup
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from torch.optim import AdamW
import os
import sys

# Fix OpenMP conflict error
# NOTE(review): both variables are set after `import torch` above; confirm the
# libraries still read them in time, or move this block above the imports.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Optimize CUDA memory
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Add current directory to path to import loader.
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
_notebook_dir = os.getcwd()
if _notebook_dir not in sys.path:
    sys.path.append(_notebook_dir)
from ouhands_loader import OuhandsDS

# Setup Device
# Preference order: CUDA first, then the MPS backend, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [2]:
# 2. Prepare DataLoaders

# Channel statistics shared by the training and evaluation pipelines
# (the ImageNet values used by the standard validation transform below).
_normalize_kwargs = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# DINO-style augmentation for Linear Probing / Fine-tuning
# The authors recommend RandomResizedCrop and Flip for downstream classification.
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(**_normalize_kwargs),
]
train_transform = transforms.Compose(_train_steps)

# Standard ImageNet Validation Transform: deterministic resize + centre crop,
# no random augmentation.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(**_normalize_kwargs),
]
val_transform = transforms.Compose(_val_steps)

batch_size = 32 # Adjusted for memory safety
num_workers = 0 # Windows compatibility

# Apply different transforms to Train vs Val/Test.
# Datasets are built in train -> validation -> test order.
train_ds, val_ds, test_ds = (
    OuhandsDS(split=split_name, transform=split_transform)
    for split_name, split_transform in (
        ('train', train_transform),
        ('validation', val_transform),
        ('test', val_transform),
    )
)

# Only the training loader shuffles; validation/test keep dataset order.
_loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)

# Report the size of each split.
for _split_label, _split_ds in (("Train", train_ds), ("Val", val_ds), ("Test", test_ds)):
    print(f"{_split_label} size: {len(_split_ds)}")

Loaded 1600 samples for train split
Class distribution: {'A': 160, 'B': 160, 'C': 160, 'D': 160, 'E': 160, 'F': 160, 'H': 160, 'I': 160, 'J': 160, 'K': 160}
Loaded 400 samples for validation split
Class distribution: {'A': 40, 'B': 40, 'C': 40, 'D': 40, 'E': 40, 'F': 40, 'H': 40, 'I': 40, 'J': 40, 'K': 40}
Loaded 1000 samples for test split
Class distribution: {'A': 100, 'B': 100, 'C': 100, 'D': 100, 'E': 100, 'F': 100, 'H': 100, 'I': 100, 'J': 100, 'K': 100}
Train size: 1600
Val size: 400
Test size: 1000
Loaded 1000 samples for test split
Class distribution: {'A': 100, 'B': 100, 'C': 100, 'D': 100, 'E': 100, 'F': 100, 'H': 100, 'I': 100, 'J': 100, 'K': 100}
Train size: 1600
Val size: 400
Test size: 1000


In [3]:
# 3. Define Gated Fusion Model (DINO + SwAV)

class GatedFusionSwAVModel(nn.Module):
    """Classifier that blends SwAV (ResNet50) and DINO (ViT-S/16) features with a learned gate.

    Both backbones are fetched through ``torch.hub``. Each backbone's output is
    linearly projected to ``common_dim``; a sigmoid gate computed from the
    concatenated projections mixes them element-wise before a dropout + linear
    classification head.

    Args:
        num_classes: Number of output logits.
        common_dim: Width of the shared projection space (and of the gate).
    """

    def __init__(self, num_classes=10, common_dim=512):
        super().__init__()

        # --- 1. Backbones ---

        # SwAV (ResNet50): replace the final fc layer with an identity so the
        # backbone returns its pooled features instead of class logits.
        print("Loading SwAV ResNet50 from torch.hub...")
        swav = torch.hub.load('facebookresearch/swav:main', 'resnet50')
        swav.fc = nn.Identity()
        self.swav_backbone = swav

        # DINO (ViT-S/16)
        print("Loading DINO ViT-S/16 from torch.hub...")
        self.dino_backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')

        # --- 2. Projections into the shared space ---
        self.proj_sw = nn.Linear(2048, common_dim)  # SwAV: 2048 -> common_dim
        self.proj_d = nn.Linear(384, common_dim)    # DINO: 384 -> common_dim

        # --- 3. Gating Mechanism ---
        # Takes the concatenation [u_D; u_Sw] (2 * common_dim) and produces a
        # gate vector of size common_dim.
        self.gate_layer = nn.Linear(2 * common_dim, common_dim)
        self.sigmoid = nn.Sigmoid()

        # --- 4. Classifier ---
        head_layers = [nn.Dropout(0.3), nn.Linear(common_dim, num_classes)]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for the image batch ``x``.

        The same input tensor is fed to both backbones.
        """
        # 1. Backbone features
        swav_feat = self.swav_backbone(x)   # (B, 2048)
        dino_feat = self.dino_backbone(x)   # (B, 384)

        # 2. Project both into the common space
        swav_proj = self.proj_sw(swav_feat)  # (B, common_dim)
        dino_proj = self.proj_d(dino_feat)   # (B, common_dim)

        # 3. Gate computed from the concatenated projections (DINO first)
        gate = self.sigmoid(self.gate_layer(torch.cat((dino_proj, swav_proj), dim=1)))

        # 4. Element-wise convex combination: g * u_D + (1 - g) * u_Sw
        fused = gate * dino_proj + (1 - gate) * swav_proj

        # 5. Classification head
        return self.classifier(fused)

# Initialize Model
# Build the 10-class fusion network, move it to the selected device, and
# print its module tree.
model = GatedFusionSwAVModel(num_classes=10)
model = model.to(device)
print(model)

Loading SwAV ResNet50 from torch.hub...


Using cache found in C:\Users\24912/.cache\torch\hub\facebookresearch_swav_main


Loading DINO ViT-S/16 from torch.hub...


Using cache found in C:\Users\24912/.cache\torch\hub\facebookresearch_dino_main


GatedFusionSwAVModel(
  (swav_backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Seq

  return t.to(


In [4]:
# 4. Training Loop

# Loss, optimizer and learning-rate schedule shared by the training loop.
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(params=model.parameters(), lr=1e-4, weight_decay=1e-4)
# NOTE(review): T_max=20 while the training run in this cell uses
# num_epochs = 15, so the cosine schedule ends before its minimum — confirm
# this is intentional.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20)

def train_one_epoch(model, loader, optimizer, device, loss_fn=None):
    """Train ``model`` for one pass over ``loader`` and return epoch metrics.

    Args:
        model: Module to optimise; switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before the forward pass.
        loss_fn: Optional loss callable ``loss_fn(outputs, labels)``. It is
            expected to return the batch *mean* (as ``nn.CrossEntropyLoss``
            does by default) because the value is re-weighted by batch size.
            When omitted, the notebook-level ``criterion`` is used.

    Returns:
        ``(mean_loss_per_sample, accuracy_percent)`` over the whole epoch.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if loss_fn is None:
        # Backward-compatible default: the loss object defined at notebook level.
        loss_fn = criterion

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so a smaller final batch
        # does not skew the epoch average.
        n_samples = labels.size(0)
        running_loss += loss.item() * n_samples
        predicted = outputs.argmax(dim=1)
        total += n_samples
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("train_one_epoch received an empty loader (no samples).")

    return running_loss / total, 100. * correct / total

def evaluate(model, loader, device, loss_fn=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Module to evaluate; switched to eval mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        loss_fn: Optional loss callable ``loss_fn(outputs, labels)``. It is
            expected to return the batch *mean* (as ``nn.CrossEntropyLoss``
            does by default) because the value is re-weighted by batch size.
            When omitted, the notebook-level ``criterion`` is used.

    Returns:
        ``(mean_loss_per_sample, accuracy_percent)`` over all samples.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if loss_fn is None:
        # Backward-compatible default: the loss object defined at notebook level.
        loss_fn = criterion

    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = loss_fn(outputs, labels)

            # Weight the batch-mean loss by batch size so the result is a
            # true per-sample average even with a smaller last batch.
            n_samples = labels.size(0)
            running_loss += loss.item() * n_samples
            predicted = outputs.argmax(dim=1)
            total += n_samples
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("evaluate received an empty loader (no samples).")

    return running_loss / total, 100. * correct / total

# Run Training
# Each epoch: one training pass, one validation pass, one scheduler step.
# The weights are checkpointed whenever validation accuracy strictly improves.
num_epochs = 15
best_acc = 0.0
_best_ckpt_path = "best_gated_fusion_swav_model.pth"

print("Starting Training...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device)
    val_loss, val_acc = evaluate(model, val_loader, device)
    scheduler.step()

    print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.2f}%")

    # No improvement on the best validation accuracy so far: nothing to save.
    if val_acc <= best_acc:
        continue

    best_acc = val_acc
    torch.save(model.state_dict(), _best_ckpt_path)
    print("Saved Best Model!")

print(f"Training Complete. Best Val Acc: {best_acc:.2f}%")

Starting Training...
Epoch 1/15


Training: 100%|██████████| 50/50 [01:11<00:00,  1.43s/it]
Training: 100%|██████████| 50/50 [01:11<00:00,  1.43s/it]
Evaluating: 100%|██████████| 13/13 [00:06<00:00,  2.16it/s]



Train Loss: 2.5412 | Acc: 11.38%
Val Loss: 2.3160 | Acc: 10.00%
Saved Best Model!
Epoch 2/15


Training: 100%|██████████| 50/50 [01:09<00:00,  1.38s/it]
Training: 100%|██████████| 50/50 [01:09<00:00,  1.38s/it]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.19it/s]



Train Loss: 1.9372 | Acc: 32.06%
Val Loss: 0.9669 | Acc: 69.00%
Saved Best Model!
Epoch 3/15


Training: 100%|██████████| 50/50 [01:09<00:00,  1.40s/it]
Training: 100%|██████████| 50/50 [01:09<00:00,  1.40s/it]
Evaluating: 100%|██████████| 13/13 [00:06<00:00,  2.13it/s]



Train Loss: 1.0136 | Acc: 65.88%
Val Loss: 0.4331 | Acc: 86.25%
Saved Best Model!
Epoch 4/15
Saved Best Model!
Epoch 4/15


Training: 100%|██████████| 50/50 [01:09<00:00,  1.39s/it]
Training: 100%|██████████| 50/50 [01:09<00:00,  1.39s/it]
Evaluating: 100%|██████████| 13/13 [00:06<00:00,  2.16it/s]



Train Loss: 0.6955 | Acc: 76.25%
Val Loss: 0.3815 | Acc: 88.25%
Saved Best Model!
Epoch 5/15


Training: 100%|██████████| 50/50 [01:08<00:00,  1.38s/it]
Training: 100%|██████████| 50/50 [01:08<00:00,  1.38s/it]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.18it/s]



Train Loss: 0.5474 | Acc: 80.69%
Val Loss: 0.2437 | Acc: 91.75%
Saved Best Model!
Epoch 6/15


Training: 100%|██████████| 50/50 [01:10<00:00,  1.41s/it]
Training: 100%|██████████| 50/50 [01:10<00:00,  1.41s/it]
Evaluating: 100%|██████████| 13/13 [00:06<00:00,  2.05it/s]



Train Loss: 0.5263 | Acc: 82.75%
Val Loss: 0.1620 | Acc: 94.50%
Saved Best Model!
Epoch 7/15
Saved Best Model!
Epoch 7/15


Training: 100%|██████████| 50/50 [01:10<00:00,  1.42s/it]
Training: 100%|██████████| 50/50 [01:10<00:00,  1.42s/it]
Evaluating: 100%|██████████| 13/13 [00:06<00:00,  2.06it/s]
Evaluating: 100%|██████████| 13/13 [00:06<00:00,  2.06it/s]


Train Loss: 0.4554 | Acc: 83.50%
Val Loss: 0.1759 | Acc: 93.50%
Epoch 8/15


Training: 100%|██████████| 50/50 [01:10<00:00,  1.41s/it]
Training: 100%|██████████| 50/50 [01:10<00:00,  1.41s/it]
Evaluating: 100%|██████████| 13/13 [00:06<00:00,  2.04it/s]



Train Loss: 0.4110 | Acc: 85.94%
Val Loss: 0.1710 | Acc: 95.25%
Saved Best Model!
Epoch 9/15
Saved Best Model!
Epoch 9/15


Training: 100%|██████████| 50/50 [01:10<00:00,  1.41s/it]
Training: 100%|██████████| 50/50 [01:10<00:00,  1.41s/it]
Evaluating: 100%|██████████| 13/13 [00:06<00:00,  2.05it/s]



Train Loss: 0.4359 | Acc: 84.81%
Val Loss: 0.1412 | Acc: 96.00%
Saved Best Model!
Epoch 10/15


Training: 100%|██████████| 50/50 [01:10<00:00,  1.41s/it]
Training: 100%|██████████| 50/50 [01:10<00:00,  1.41s/it]
Evaluating: 100%|██████████| 13/13 [00:06<00:00,  2.03it/s]
Evaluating: 100%|██████████| 13/13 [00:06<00:00,  2.03it/s]


Train Loss: 0.3645 | Acc: 87.44%
Val Loss: 0.1525 | Acc: 95.00%
Epoch 11/15


Training: 100%|██████████| 50/50 [01:01<00:00,  1.23s/it]
Training: 100%|██████████| 50/50 [01:01<00:00,  1.23s/it]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.23it/s]



Train Loss: 0.3608 | Acc: 87.44%
Val Loss: 0.0912 | Acc: 97.75%
Saved Best Model!
Epoch 12/15


Training: 100%|██████████| 50/50 [01:01<00:00,  1.22s/it]
Training: 100%|██████████| 50/50 [01:01<00:00,  1.22s/it]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.25it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.25it/s]


Train Loss: 0.3706 | Acc: 86.50%
Val Loss: 0.0948 | Acc: 97.25%
Epoch 13/15


Training: 100%|██████████| 50/50 [01:01<00:00,  1.23s/it]
Training: 100%|██████████| 50/50 [01:01<00:00,  1.23s/it]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.23it/s]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.23it/s]


Train Loss: 0.3727 | Acc: 87.62%
Val Loss: 0.0870 | Acc: 97.25%
Epoch 14/15


Training: 100%|██████████| 50/50 [01:01<00:00,  1.23s/it]
Training: 100%|██████████| 50/50 [01:01<00:00,  1.23s/it]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.25it/s]



Train Loss: 0.3603 | Acc: 86.62%
Val Loss: 0.0776 | Acc: 98.00%
Saved Best Model!
Epoch 15/15


Training: 100%|██████████| 50/50 [01:01<00:00,  1.23s/it]
Training: 100%|██████████| 50/50 [01:01<00:00,  1.23s/it]
Evaluating: 100%|██████████| 13/13 [00:05<00:00,  2.26it/s]

Train Loss: 0.3551 | Acc: 87.94%
Val Loss: 0.0779 | Acc: 97.50%
Training Complete. Best Val Acc: 98.00%





In [5]:
# 5. Final Evaluation
from sklearn.metrics import f1_score
import numpy as np

# Restore the best checkpoint written by the training loop, if present.
checkpoint_path = "best_gated_fusion_swav_model.pth"
if os.path.exists(checkpoint_path):
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs; it also avoids the torch.load warning
    # about the unsafe default. map_location places the tensors on the active
    # device even if the checkpoint was saved from a different one.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print("Loaded best model.")
else:
    # Make the fallback explicit rather than silently scoring whatever
    # weights happen to be in memory.
    print(f"WARNING: {checkpoint_path} not found; evaluating the weights currently in memory.")

model.eval()
all_preds = []
all_labels = []

# Collect predictions and ground-truth labels over the whole test split.
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing"):
        images = images.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().tolist())
        # Labels are never moved to the device, so they can be read directly.
        all_labels.extend(labels.tolist())

# Top-1 accuracy in percent and macro-averaged F1 over the classes.
test_acc = np.mean(np.array(all_preds) == np.array(all_labels)) * 100
test_f1 = f1_score(all_labels, all_preds, average='macro')

print("\n" + "="*30)
print("FINAL RESULTS (DINO + SwAV Gated Fusion)")
print("="*30)
print(f"{'Metric':<15} | {'Value':<10}")
print("-" * 30)
print(f"{'Top-1 Acc (%)':<15} | {test_acc:.2f}")
print(f"{'Macro-F1':<15} | {test_f1:.4f}")
print("="*30)

  model.load_state_dict(torch.load("best_gated_fusion_swav_model.pth"))


Loaded best model.


Testing: 100%|██████████| 32/32 [00:14<00:00,  2.15it/s]


FINAL RESULTS (DINO + SwAV Gated Fusion)
Metric          | Value     
------------------------------
Top-1 Acc (%)   | 85.70
Macro-F1        | 0.8599



