In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import models

In [None]:
def load_pretrained_encoder(checkpoint_path, backbone='resnet50'):
    """Load pretrained encoder for downstream tasks"""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    
    # Initialize encoder
    if backbone == 'resnet50':
        encoder = models.resnet50(pretrained=False)
        encoder.fc = nn.Identity()
    elif backbone == 'resnet18':
        encoder = models.resnet18(pretrained=False)
        encoder.fc = nn.Identity()
    
    # Load weights
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    
    return encoder

def create_linear_classifier(encoder, num_classes, freeze_encoder=True):
    """Create a linear classifier on top of the pretrained encoder"""
    if freeze_encoder:
        for param in encoder.parameters():
            param.requires_grad = False
    
    # Get encoder output dimension
    if 'resnet50' in encoder:
        encoder_dim = 2048
    elif 'resnet18' in encoder:
        encoder_dim = 512
        print("ResNet 18 initialized")
    else:
        encoder_dim = 512  # default
    
    classifier = nn.Sequential(
        encoder,
        nn.Linear(encoder_dim, num_classes)
    )
    
    return classifier

In [3]:
class ProjectionMLP(nn.Module):
    """Projection MLP for SimSiam"""
    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=2048):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(hidden_dim, output_dim),
            nn.BatchNorm1d(output_dim, affine=False)  # No bias/scale in final BN
        )
    
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

class PredictionMLP(nn.Module):
    """Prediction MLP for SimSiam"""
    def __init__(self, input_dim=2048, hidden_dim=512, output_dim=2048):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

class SimSiam(nn.Module):
    """
    SimSiam model implementation
    """
    def __init__(self, backbone='resnet50', proj_dim=2048, pred_dim=512):
        super().__init__()
        
        # Backbone encoder
        if backbone == 'resnet50':
            self.encoder = models.resnet50(pretrained=False)
            self.encoder.fc = nn.Identity()  # Remove classification head
            encoder_dim = 2048
        elif backbone == 'resnet18':
            self.encoder = models.resnet18(pretrained=False)
            self.encoder.fc = nn.Identity()
            encoder_dim = 512
        else:
            raise ValueError(f"Backbone {backbone} not supported")
        
        # Projection head
        self.projector = ProjectionMLP(encoder_dim, proj_dim, proj_dim)
        
        # Prediction head
        self.predictor = PredictionMLP(proj_dim, pred_dim, proj_dim)
    
    def forward(self, x1, x2):
        # Encode both views
        z1 = self.projector(self.encoder(x1))
        z2 = self.projector(self.encoder(x2))
        
        # Predict
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        
        return p1, p2, z1.detach(), z2.detach()

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimSiam(backbone='resnet18').to(device)



In [5]:
# --- Load pretrained encoder ---
encoder = load_pretrained_encoder('/home/asharab/Documents/Masters/SSL_PROJECT/Model/checkpoints_new_backbone/simsiam_encoder.pth', backbone='resnet18')
model = create_linear_classifier(encoder, num_classes=100, freeze_encoder=True).to('cuda')

# --- CIFAR-100 Dataloaders ---
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(p=0.5),          # Random horizontal flip
    transforms.RandomCrop(224, padding=4),           # Random crop with padding
    transforms.ColorJitter(brightness=0.2, 
                           contrast=0.2, 
                           saturation=0.2, 
                           hue=0.1),                   # Random color jitter
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

train_set = datasets.CIFAR100(root='./data', train=True, transform=transform, download=False)
test_set = datasets.CIFAR100(root='./data', train=False, transform=transform, download=False)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)

# --- Training setup ---
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model[-1].parameters(), lr=0.001, weight_decay=1e-4)
#scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)



In [None]:
model

In [7]:
from torch.utils.tensorboard import SummaryWriter
import torch
from tqdm import tqdm

writer = SummaryWriter(log_dir='./runs/linear_eval_experiment')

# --- Train only linear head ---
for epoch in range(100):
    model.train()
    correct, total, total_loss = 0, 0, 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1:03d}", leave=False)
    for x, y in pbar:
        x, y = x.to('cuda'), y.to('cuda')

        logits = model(x)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

        pbar.set_postfix({
            'Loss': f"{loss.item():.4f}",
            'Acc': f"{100 * correct / total:.2f}%"
        })

    #scheduler.step()
    acc = 100 * correct / total
    writer.add_scalar('Train/Loss', total_loss / len(train_loader), epoch + 1)
    writer.add_scalar('Train/Accuracy', acc, epoch + 1)

    print(f"Epoch {epoch+1:3d} | Loss: {total_loss:.4f} | Train Acc: {acc:.2f}%")

# --- Evaluation on Test Set ---
model.eval()
correct, total = 0, 0

pbar = tqdm(test_loader, desc="Evaluating", leave=False)
with torch.no_grad():
    for x, y in pbar:
        x, y = x.to('cuda'), y.to('cuda')
        preds = model(x).argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

test_acc = 100 * correct / total
writer.add_scalar('Test/Accuracy', test_acc, 100)

print(f"\nFinal Linear Evaluation Accuracy on CIFAR-100: {test_acc:.2f}%")

# --- Save final model weights ---
torch.save(model.state_dict(), 'linear_head_final_weights.pth')
print("Final model weights saved as 'linear_head_final_weights.pth'")

writer.close()

                                                                                     

Epoch   1 | Loss: 1564.8314 | Train Acc: 10.25%


                                                                                     

Epoch   2 | Loss: 1525.6230 | Train Acc: 12.11%


                                                                                     

Epoch   3 | Loss: 1495.9007 | Train Acc: 13.37%


                                                                                     

Epoch   4 | Loss: 1472.7487 | Train Acc: 14.21%


                                                                                     

Epoch   5 | Loss: 1454.9549 | Train Acc: 15.45%


                                                                                     

Epoch   6 | Loss: 1438.6132 | Train Acc: 15.97%


                                                                                     

Epoch   7 | Loss: 1424.2438 | Train Acc: 16.53%


                                                                                     

Epoch   8 | Loss: 1411.9967 | Train Acc: 17.00%


                                                                                     

Epoch   9 | Loss: 1401.6240 | Train Acc: 17.41%


                                                                                     

Epoch  10 | Loss: 1391.3848 | Train Acc: 18.06%


                                                                                     

Epoch  11 | Loss: 1384.1990 | Train Acc: 18.37%


                                                                                     

Epoch  12 | Loss: 1375.7403 | Train Acc: 18.65%


                                                                                     

Epoch  13 | Loss: 1366.4100 | Train Acc: 19.23%


                                                                                     

Epoch  14 | Loss: 1359.1835 | Train Acc: 19.53%


                                                                                     

Epoch  15 | Loss: 1353.5501 | Train Acc: 19.88%


                                                                                     

Epoch  16 | Loss: 1349.3949 | Train Acc: 20.01%


                                                                                     

Epoch  17 | Loss: 1342.5256 | Train Acc: 20.16%


                                                                                     

Epoch  18 | Loss: 1337.0120 | Train Acc: 20.52%


                                                                                     

Epoch  19 | Loss: 1332.7876 | Train Acc: 20.47%


                                                                                     

Epoch  20 | Loss: 1327.3714 | Train Acc: 20.90%


                                                                                     

Epoch  21 | Loss: 1323.1409 | Train Acc: 21.29%


                                                                                     

Epoch  22 | Loss: 1318.9280 | Train Acc: 21.58%


                                                                                     

Epoch  23 | Loss: 1316.3019 | Train Acc: 21.40%


                                                                                     

Epoch  24 | Loss: 1310.9756 | Train Acc: 21.75%


                                                                                     

Epoch  25 | Loss: 1309.4923 | Train Acc: 21.86%


                                                                                     

Epoch  26 | Loss: 1305.3580 | Train Acc: 21.92%


                                                                                     

Epoch  27 | Loss: 1301.0556 | Train Acc: 22.31%


                                                                                     

Epoch  28 | Loss: 1298.3053 | Train Acc: 22.22%


                                                                                     

Epoch  29 | Loss: 1295.7632 | Train Acc: 22.57%


                                                                                     

Epoch  30 | Loss: 1293.0457 | Train Acc: 22.68%


                                                                                     

Epoch  31 | Loss: 1290.9124 | Train Acc: 22.59%


                                                                                     

Epoch  32 | Loss: 1287.9069 | Train Acc: 22.80%


                                                                                     

Epoch  33 | Loss: 1286.2482 | Train Acc: 22.90%


                                                                                     

Epoch  34 | Loss: 1281.9306 | Train Acc: 23.04%


                                                                                     

Epoch  35 | Loss: 1280.7855 | Train Acc: 22.87%


                                                                                     

Epoch  36 | Loss: 1279.7741 | Train Acc: 23.19%


                                                                                     

Epoch  37 | Loss: 1276.2310 | Train Acc: 23.37%


                                                                                     

Epoch  38 | Loss: 1275.1296 | Train Acc: 23.49%


                                                                                     

Epoch  39 | Loss: 1272.2063 | Train Acc: 23.46%


                                                                                     

Epoch  40 | Loss: 1269.1476 | Train Acc: 23.60%


                                                                                     

Epoch  41 | Loss: 1266.8187 | Train Acc: 23.77%


                                                                                     

Epoch  42 | Loss: 1265.7874 | Train Acc: 23.89%


                                                                                     

Epoch  43 | Loss: 1264.0325 | Train Acc: 23.86%


                                                                                     

Epoch  44 | Loss: 1261.2584 | Train Acc: 24.04%


                                                                                     

Epoch  45 | Loss: 1259.4628 | Train Acc: 24.27%


                                                                                     

Epoch  46 | Loss: 1258.9974 | Train Acc: 24.48%


                                                                                     

Epoch  47 | Loss: 1257.1762 | Train Acc: 24.37%


                                                                                     

Epoch  48 | Loss: 1255.4228 | Train Acc: 24.43%


                                                                                     

Epoch  49 | Loss: 1252.4486 | Train Acc: 24.41%


                                                                                     

Epoch  50 | Loss: 1254.3177 | Train Acc: 24.53%


                                                                                     

Epoch  51 | Loss: 1251.7317 | Train Acc: 24.25%


                                                                                     

Epoch  52 | Loss: 1250.7123 | Train Acc: 24.51%


                                                                                     

Epoch  53 | Loss: 1249.3769 | Train Acc: 24.75%


                                                                                     

Epoch  54 | Loss: 1247.9113 | Train Acc: 24.77%


                                                                                     

Epoch  55 | Loss: 1246.2712 | Train Acc: 24.83%


                                                                                     

Epoch  56 | Loss: 1244.2286 | Train Acc: 24.87%


                                                                                     

Epoch  57 | Loss: 1243.5862 | Train Acc: 25.04%


                                                                                     

Epoch  58 | Loss: 1242.1724 | Train Acc: 25.10%


                                                                                     

Epoch  59 | Loss: 1240.2042 | Train Acc: 25.11%


                                                                                     

Epoch  60 | Loss: 1241.0772 | Train Acc: 24.87%


                                                                                     

Epoch  61 | Loss: 1240.2349 | Train Acc: 25.28%


                                                                                     

Epoch  62 | Loss: 1237.2846 | Train Acc: 25.30%


                                                                                     

Epoch  63 | Loss: 1238.0135 | Train Acc: 25.37%


                                                                                     

Epoch  64 | Loss: 1235.3368 | Train Acc: 25.53%


                                                                                     

Epoch  65 | Loss: 1233.0140 | Train Acc: 25.42%


                                                                                     

Epoch  66 | Loss: 1234.7908 | Train Acc: 25.38%


                                                                                     

Epoch  67 | Loss: 1233.4680 | Train Acc: 25.12%


                                                                                     

Epoch  68 | Loss: 1231.3132 | Train Acc: 25.48%


                                                                                     

Epoch  69 | Loss: 1231.2526 | Train Acc: 25.49%


                                                                                     

Epoch  70 | Loss: 1230.4300 | Train Acc: 25.70%


                                                                                     

Epoch  71 | Loss: 1230.4892 | Train Acc: 25.62%


                                                                                     

Epoch  72 | Loss: 1227.0666 | Train Acc: 25.83%


                                                                                     

Epoch  73 | Loss: 1228.3751 | Train Acc: 25.71%


                                                                                     

Epoch  74 | Loss: 1228.6901 | Train Acc: 25.50%


                                                                                     

Epoch  75 | Loss: 1225.1891 | Train Acc: 25.81%


                                                                                     

Epoch  76 | Loss: 1224.7026 | Train Acc: 25.76%


                                                                                     

Epoch  77 | Loss: 1223.4743 | Train Acc: 26.02%


                                                                                     

Epoch  78 | Loss: 1222.8716 | Train Acc: 25.90%


                                                                                     

Epoch  79 | Loss: 1223.0687 | Train Acc: 26.02%


                                                                                     

Epoch  80 | Loss: 1221.0755 | Train Acc: 26.04%


                                                                                     

Epoch  81 | Loss: 1219.0473 | Train Acc: 26.27%


                                                                                     

Epoch  82 | Loss: 1221.9821 | Train Acc: 26.07%


                                                                                     

Epoch  83 | Loss: 1220.2295 | Train Acc: 26.19%


                                                                                     

Epoch  84 | Loss: 1219.8065 | Train Acc: 26.13%


                                                                                     

Epoch  85 | Loss: 1217.9880 | Train Acc: 26.18%


                                                                                     

Epoch  86 | Loss: 1217.4365 | Train Acc: 26.26%


                                                                                     

Epoch  87 | Loss: 1217.1754 | Train Acc: 26.23%


                                                                                     

Epoch  88 | Loss: 1217.0381 | Train Acc: 26.14%


                                                                                     

Epoch  89 | Loss: 1216.3169 | Train Acc: 26.30%


                                                                                     

Epoch  90 | Loss: 1214.9948 | Train Acc: 26.39%


                                                                                     

Epoch  91 | Loss: 1216.8765 | Train Acc: 26.19%


                                                                                     

Epoch  92 | Loss: 1214.2124 | Train Acc: 26.31%


                                                                                     

Epoch  93 | Loss: 1214.2456 | Train Acc: 26.31%


                                                                                     

Epoch  94 | Loss: 1213.4153 | Train Acc: 26.53%


                                                                                     

Epoch  95 | Loss: 1213.3175 | Train Acc: 26.37%


                                                                                     

Epoch  96 | Loss: 1212.5822 | Train Acc: 26.63%


                                                                                     

Epoch  97 | Loss: 1211.8219 | Train Acc: 26.55%


                                                                                     

Epoch  98 | Loss: 1208.5280 | Train Acc: 26.82%


                                                                                     

Epoch  99 | Loss: 1210.8917 | Train Acc: 26.40%


                                                                                     

Epoch 100 | Loss: 1210.1002 | Train Acc: 26.71%


                                                           


Final Linear Evaluation Accuracy on CIFAR-100: 24.62%
Final model weights saved as 'linear_head_final_weights.pth'
