In [None]:
# import necessary dependencies
import argparse
import os, sys
import time
import datetime
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

In [None]:
# specify the device for computation
#############################################
# your code here
import torch


device = torch.device("cuda")
print(device)
!nvidia-smi
# import torch
# print(torch.__version__)  # Should show PyTorch version
# print(torch.cuda.is_available())  # Should return True
# print(torch.version.cuda)  # Should return 11.8
# print(torch.cuda.get_device_name(0))  # Should return your GPU name

#############################################

# Model Setup

In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet18
from torchvision import models

class Encoder(nn.Module):
    """ResNet-18 backbone adapted to 32x32 CIFAR-10 images; returns flat 512-d features.

    The stem conv is replaced by a 3x3 / stride-1 conv and max-pooling is removed so the
    small input is not downsampled too early. The final FC layer is dropped.

    Args:
        pretrained: if True (default, the original behaviour), start from torchvision's
            ImageNet-1k ResNet-18 weights, which must be available to torchvision. Pass
            False to build a randomly initialised backbone without fetching weights, e.g.
            when a checkpoint will be loaded over it with ``load_state_dict`` anyway.

    Note: ``conv1`` is replaced after the weights are loaded, so the new 3x3 stem is
    randomly initialised in both cases.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # Only ResNet-18 works as-is: its last stage has 512 channels, which is what the
        # projection / rotation heads expect (in_dim=512). A ResNet-50 would give 2048.
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        resnet = models.resnet18(weights=weights)

        # Modify first conv and remove maxpool for small CIFAR-10 images
        resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        resnet.maxpool = nn.Identity()

        # Keep every child except the final FC layer -> output [B, 512, 1, 1]
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

    def forward(self, x):
        x = self.backbone(x)          # [B, 512, 1, 1]
        return torch.flatten(x, 1)    # [B, 512]

In [None]:
class ProjectionHead(nn.Module):
    """SimCLR projection MLP: (Linear-BN-ReLU) x 2 followed by a final Linear.

    Maps encoder features of size ``in_dim`` to an ``out_dim`` embedding for the
    contrastive loss. All layers live in ``self.net`` (indices 0-6), which fixes the
    state_dict keys that saved checkpoints rely on.
    """

    def __init__(self, in_dim=512, hidden_dim=512, out_dim=64):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected

In [None]:
class RotationHead(nn.Module):
    """Single linear layer producing one logit per rotation class (default 4)."""

    def __init__(self, in_dim=512, num_classes=4):
        super().__init__()
        self.fc = nn.Linear(in_features=in_dim, out_features=num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits

In [None]:
class SimCLRRotNetModel(nn.Module):
    """Shared encoder with two task heads.

    ``task='simclr'`` returns the projection-head embedding used by the contrastive
    loss; ``task='rotnet'`` returns the rotation-head logits.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.projection_head = ProjectionHead()
        self.rotation_head = RotationHead()

    def forward(self, x, task='simclr'):
        """Encode ``x`` and apply the head selected by ``task``.

        Raises:
            ValueError: if ``task`` is not 'simclr' or 'rotnet'.
        """
        if task == 'simclr':
            head = self.projection_head
        elif task == 'rotnet':
            head = self.rotation_head
        else:
            # Reject before running the encoder: in train mode a forward pass would
            # otherwise update BatchNorm running statistics for a call that then fails.
            raise ValueError("Task must be 'simclr' or 'rotnet'")
        features = self.encoder(x)
        return head(features)

In [None]:
def test_encoder():
    """Smoke test: Encoder maps a CIFAR-10-shaped batch (2, 3, 32, 32) to (2, 512)."""
    model = Encoder()
    dummy_input = torch.randn(2, 3, 32, 32)  # CIFAR-10 shaped input
    output = model(dummy_input)
    # Check the full shape, not only the batch size: the heads are built with in_dim=512.
    assert output.shape == (2, 512), f"Encoder output shape mismatch: {tuple(output.shape)}"
    print(f"✅ Encoder output shape: {output.shape}")

def test_projection_head():
    """Smoke test: ProjectionHead(in_dim=512, out_dim=128) maps (2, 512) to (2, 128)."""
    head = ProjectionHead(in_dim=512, out_dim=128)
    projected = head(torch.randn(2, 512))
    expected_shape = (2, 128)
    assert projected.shape == expected_shape, "Projection head output mismatch"
    print(f"✅ ProjectionHead output shape: {projected.shape}")

def test_rotation_head():
    """Smoke test: RotationHead(512, 4) maps (2, 512) features to (2, 4) logits."""
    head = RotationHead(in_dim=512, num_classes=4)
    logits = head(torch.randn(2, 512))
    expected_shape = (2, 4)
    assert logits.shape == expected_shape, "Rotation head output mismatch"
    print(f"✅ RotationHead output shape: {logits.shape}")

def test_full_model():
    """Smoke test: both task branches of SimCLRRotNetModel produce the expected shapes."""
    model = SimCLRRotNetModel()
    dummy_input = torch.randn(2, 3, 32, 32)

    # Test SimCLR head (ProjectionHead defaults to out_dim=64)
    out_simclr = model(dummy_input, task='simclr')
    assert out_simclr.shape == (2, 64), f"SimCLR output mismatch: {tuple(out_simclr.shape)}"
    print(f"✅ SimCLR head output: {out_simclr.shape}")

    # Test RotNet head (4 rotation classes)
    out_rotnet = model(dummy_input, task='rotnet')
    assert out_rotnet.shape == (2, 4), "RotNet output mismatch"
    print(f"✅ RotNet head output: {out_rotnet.shape}")

if __name__ == "__main__":
    # In an IPython kernel __name__ is normally "__main__", so these run with the cell.
    for smoke_test in (test_encoder, test_projection_head, test_rotation_head, test_full_model):
        smoke_test()

## Data Preprocessing and Loading

In [None]:
import torchvision.transforms as T

# CIFAR-10 per-channel statistics shared by every normalised pipeline in this notebook
# (the evaluation transform uses the same values).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

simclr_transform = T.Compose([
    T.RandomResizedCrop(size=32, scale=(0.5, 1.0)),  # lighter resize crop
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(0.5, 0.5, 0.5, 0.2),  # strong but lighter
    T.RandomGrayscale(p=0.2),
    T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),  # lighter blur for 32x32
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# RotNet input: no spatial augmentation (the four rotations are applied per batch later),
# but normalised like the SimCLR views and the evaluation images. A bare ToTensor() here
# fed the shared encoder un-normalised [0, 1] images for the rotation task only.
rotnet_base_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

In [None]:
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# Single-view CIFAR-10 training set: each item is (augmented image, label).
# NOTE: train_loader is rebuilt further down from the multi-view dataset. The training loops
# unpack (x_i, x_j, x_rot) triples, which this loader does not produce.
train_set = CIFAR10(
    root='./data',
    train=True,
    transform=simclr_transform,
    download=True,
)
train_loader = DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=2,
)


In [None]:
def rotate_batch(x):
    """Stack the 0/90/180/270-degree rotations of a batch and return matching labels.

    Rotation is applied over dims 2 and 3 (the spatial dims of an NCHW batch). The output
    holds the four rotated copies one after another, and the label tensor holds ``k`` for
    every sample of the k-th copy, on the same device as ``x``.
    """
    n = x.size(0)
    rotated_views = [torch.rot90(x, k=k, dims=[2, 3]) for k in range(4)]
    rotation_ids = [k for k in range(4) for _ in range(n)]
    return torch.cat(rotated_views), torch.tensor(rotation_ids, device=x.device)

In [None]:
from torchvision.datasets import CIFAR10

class SimCLRRotNetDataset(torch.utils.data.Dataset):
    """Wrap a labelled dataset so each item is (simclr_view_1, simclr_view_2, rotnet_base).

    Labels from ``base_dataset`` are discarded. Two independent draws of
    ``simclr_transform`` give the contrastive pair; ``rotnet_transform`` produces the
    un-rotated image, and the rotations are applied batch-wise during training.
    """

    def __init__(self, base_dataset, simclr_transform, rotnet_transform):
        self.base_dataset = base_dataset
        self.simclr_transform = simclr_transform
        self.rotnet_transform = rotnet_transform

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        image, _label = self.base_dataset[index]
        augment = self.simclr_transform
        first_view, second_view = augment(image), augment(image)
        return first_view, second_view, self.rotnet_transform(image)

In [None]:
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

# CIFAR-10 training split without a transform; SimCLRRotNetDataset applies the two SimCLR
# augmentations and the RotNet base transform to each raw image.
base_dataset = CIFAR10(root='./data', train=True, download=True)
multi_view_dataset = SimCLRRotNetDataset(
    base_dataset=base_dataset,
    simclr_transform=simclr_transform,
    rotnet_transform=rotnet_base_transform,
)

# Each batch is (x_i, x_j, x_rot); the batch size is fixed at 256 here.
train_loader = DataLoader(
    multi_view_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

## Training Setup and Loop

In [None]:
# SimCLR Loss Function
import torch.nn.functional as F

def nt_xent_loss(z_i, z_j, temperature=0.5):
    """NT-Xent (normalised temperature-scaled cross-entropy) loss used by SimCLR.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` are the two views of one image (a
    positive pair). Every other row of the concatenated 2N batch is a negative.

    Args:
        z_i, z_j: projections of shape [N, D] for the two augmented views.
        temperature: softmax temperature applied to the cosine similarities.

    Returns:
        Scalar loss averaged over all 2N anchors.
    """
    z_i = F.normalize(z_i, dim=1)
    z_j = F.normalize(z_j, dim=1)
    representations = torch.cat([z_i, z_j], dim=0)  # [2N, D]
    batch_size = z_i.size(0)

    logits = torch.matmul(representations, representations.T) / temperature  # [2N, 2N]

    # Exclude each anchor's similarity with itself. Masking the diagonal with -inf (rather
    # than dropping it and reshaping to [2N, 2N-1]) keeps column indices aligned with the
    # full matrix. The previous reshape-based version kept targets 0..N-1 for every row,
    # which pointed the first N anchors at the wrong column.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=logits.device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # The positive for anchor k is its other view: column k + N for the first half of the
    # rows, column k - N for the second half.
    idx = torch.arange(batch_size, device=logits.device)
    targets = torch.cat([idx + batch_size, idx], dim=0)
    return F.cross_entropy(logits, targets)

# RotNet Loss Function: 4-way rotation classification
rotation_loss_fn = nn.CrossEntropyLoss()

In [None]:
pip install warmup_scheduler

In [None]:
import os
import time
import warnings

import torch
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
from warmup_scheduler import GradualWarmupScheduler

CHECKPOINT_FOLDER = '/content/drive/My Drive/ECE661/Project/hybrid_model/training_models/cosine_warmup_scheduler'
LEARNING_RATE = 0.05
WARMUP_EPOCHS = 10
NUM_EPOCHS = 200
TEMPERATURE = 0.2
# The loss below uses ROT_LOSS_WEIGHT, which was never defined, so this cell raised
# NameError on a fresh kernel.
# NOTE(review): 0.5 matches the starting rotation weight in train_rotnet_simclr; confirm it
# is the value the existing checkpoints in CHECKPOINT_FOLDER were trained with.
ROT_LOSS_WEIGHT = 0.5
scaler = GradScaler(device='cuda')
start_epoch = 0

# 🧠 Define model and optimizer before loading checkpoint
model = SimCLRRotNetModel().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=1e-4)

# Set up schedulers.
# NOTE(review): only `scheduler` is stepped in the loop below. `gradual_warmup` is built but
# never stepped, so there is no warmup phase, and with T_max=5 over a 200-epoch run the
# cosine schedule cycles instead of decaying once. Confirm this is intended.
scheduler = CosineAnnealingLR(optimizer, T_max=5, eta_min=0.001)
gradual_warmup = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=WARMUP_EPOCHS, after_scheduler=scheduler)

# 🔁 Try to resume from the newest checkpoint (they are written every 10 epochs)
resume_path = None
for ckpt_epoch in range((NUM_EPOCHS // 10) * 10, 0, -10):
    candidate = os.path.join(CHECKPOINT_FOLDER, f"SimCLR_RotNet_epoch{ckpt_epoch}.pth")
    if os.path.exists(candidate):
        resume_path = candidate
        break

if resume_path is None:
    print("No checkpoint found. Starting training from scratch.")
else:
    checkpoint = torch.load(resume_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    # Replay one scheduler step per completed epoch, so its counter equals start_epoch
    # exactly as in an uninterrupted run. The old `last_epoch = start_epoch - 1` assignment
    # was one step behind. The optimizer state is loaded afterwards, so the saved learning
    # rate is what remains in the param groups.
    with warnings.catch_warnings():
        # PyTorch may warn about stepping a scheduler before optimizer.step(); expected here.
        warnings.simplefilter("ignore")
        for _ in range(start_epoch):
            scheduler.step()
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    print(f"🔁 Resumed from checkpoint: {resume_path} (epoch {start_epoch})")

# 🔁 Training loop
num_batches = len(train_loader)
for epoch in range(start_epoch, NUM_EPOCHS):
    model.train()
    total_loss = 0
    total_simclr_loss = 0
    total_rotnet_loss = 0
    start_time = time.time()

    for step, (x_i, x_j, x_rot) in enumerate(train_loader, 1):
        x_i, x_j, x_rot = x_i.to(device), x_j.to(device), x_rot.to(device)
        optimizer.zero_grad()

        with autocast(device_type='cuda'):
            z_i = model(x_i, task='simclr')
            z_j = model(x_j, task='simclr')
            loss_simclr = nt_xent_loss(z_i, z_j, temperature=TEMPERATURE)

            x_rot_batch, rot_labels = rotate_batch(x_rot)
            rot_logits = model(x_rot_batch, task='rotnet')
            loss_rotnet = rotation_loss_fn(rot_logits, rot_labels)

            loss = loss_simclr + ROT_LOSS_WEIGHT * loss_rotnet

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        total_simclr_loss += loss_simclr.item()
        total_rotnet_loss += loss_rotnet.item()

        if step % 20 == 0 or step == num_batches:
            print(f"[Step {step}/{num_batches}] Total: {loss.item():.4f} | SimCLR: {loss_simclr.item():.4f} | RotNet: {loss_rotnet.item():.4f}")

    avg_total_loss = total_loss / num_batches
    avg_simclr_loss = total_simclr_loss / num_batches
    avg_rotnet_loss = total_rotnet_loss / num_batches
    epoch_time = time.time() - start_time

    scheduler.step()  # 🔁 Update learning rate
    current_lr = scheduler.get_last_lr()[0]

    print(f"[Epoch {epoch+1}/{NUM_EPOCHS}] Avg Total: {avg_total_loss:.4f} | "
          f"SimCLR: {avg_simclr_loss:.4f} | "
          f"RotNet: {avg_rotnet_loss:.4f} | ⏱ Time: {epoch_time:.2f}s\n"
          f"Learning Rate (Epoch {epoch+1}): {current_lr:.6f}\n")

    # 💾 Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'avg_total_loss': avg_total_loss,
            'avg_simclr_loss': avg_simclr_loss,
            'avg_rotnet_loss': avg_rotnet_loss
        }
        save_path = os.path.join(CHECKPOINT_FOLDER, f"SimCLR_RotNet_epoch{epoch+1}.pth")
        os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)
        torch.save(checkpoint, save_path)
        print(f"✅ Saved checkpoint and losses at: {save_path}")

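# Updating SimCLR+RotNet Training Loop
### Decay the RotNet loss weight linearly from 0.5 toward 0 over the first 25 epochs, then drop the rotation task and freeze the rotation head; warm up the learning rate, then cosine-anneal it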

In [None]:
import os
import time
import warnings


def _find_latest_checkpoint(checkpoint_folder, batch_size, max_epoch):
    """Return the newest ``SimCLR_RotNet_bs{batch_size}_epoch{N}.pth`` with N <= max_epoch, or None."""
    # Checkpoints are only written every 10 epochs, so only multiples of 10 are probed.
    for ckpt_epoch in range((max_epoch // 10) * 10, 0, -10):
        path = os.path.join(checkpoint_folder, f"SimCLR_RotNet_bs{batch_size}_epoch{ckpt_epoch}.pth")
        if os.path.exists(path):
            return path
    return None


def _freeze_rotation_head(model):
    """Turn off gradients for every parameter of ``model.rotation_head``."""
    for param in model.rotation_head.parameters():
        param.requires_grad = False


def train_rotnet_simclr(
    model,
    optimizer,
    scaler,
    train_loader,
    rotate_batch,
    rotation_loss_fn,
    nt_xent_loss,
    gradual_warmup,
    device,
    BATCH_SIZE,
    CHECKPOINT_FOLDER,
    NUM_EPOCHS,
    TEMPERATURE,
    rot_epochs=25,
    initial_rot_weight=0.5,
):
    """Jointly train SimCLR + RotNet, resuming from the newest checkpoint if one exists.

    For the first ``rot_epochs`` epochs the loss is ``simclr + w * rotnet``, with ``w``
    decaying linearly from ``initial_rot_weight`` toward 0. After that the rotation branch
    is skipped and the rotation head is frozen.

    ``gradual_warmup`` is stepped once per epoch and must be freshly constructed (not yet
    stepped), so that resuming can replay it to the checkpoint epoch. Every 10 epochs a
    checkpoint with the model, optimizer and scaler state plus the epoch-average losses is
    written as ``SimCLR_RotNet_bs{BATCH_SIZE}_epoch{N}.pth``.

    Args:
        model: module exposing ``rotation_head`` and accepting ``task='simclr'|'rotnet'``.
        optimizer, scaler: optimizer and AMP grad scaler (scaler needs scale/step/update).
        train_loader: iterable of (x_i, x_j, x_rot) batches.
        rotate_batch: callable returning (rotated_batch, rotation_labels).
        rotation_loss_fn, nt_xent_loss: rotation and contrastive losses.
        gradual_warmup: per-epoch LR scheduler.
        device: device (or device string) batches are moved to. CUDA autocast is enabled
            only when this is a CUDA device.
        BATCH_SIZE: tag used in checkpoint names. A warning is emitted when
            ``train_loader.batch_size`` disagrees.
        CHECKPOINT_FOLDER: where checkpoints are read from and written to.
        NUM_EPOCHS: total number of epochs; also the newest checkpoint epoch searched.
        TEMPERATURE: NT-Xent temperature.
        rot_epochs: number of initial epochs that include the rotation loss.
        initial_rot_weight: rotation-loss weight at epoch 0.
    """
    device_type = torch.device(device).type

    loader_batch_size = getattr(train_loader, "batch_size", None)
    if loader_batch_size is not None and loader_batch_size != BATCH_SIZE:
        warnings.warn(
            f"train_loader.batch_size={loader_batch_size} but checkpoints are tagged "
            f"bs{BATCH_SIZE}; the tag will not reflect the real batch size."
        )

    # 🔁 Try to resume from the newest checkpoint for this batch size
    start_epoch = 0
    ckpt_path = _find_latest_checkpoint(CHECKPOINT_FOLDER, BATCH_SIZE, NUM_EPOCHS)
    if ckpt_path is None:
        print("No checkpoint found. Starting training from scratch.")
    else:
        checkpoint = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        # Replay one step per completed epoch so the scheduler, including the cosine
        # scheduler it hands off to after warmup, ends up where an uninterrupted run would.
        # Assigning `last_epoch` directly (as before) only moved the warmup counter. The
        # optimizer state is loaded afterwards so the saved learning rate is what remains.
        with warnings.catch_warnings():
            # PyTorch may warn about stepping a scheduler before optimizer.step(); expected.
            warnings.simplefilter("ignore")
            for _ in range(start_epoch):
                gradual_warmup.step()
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        print(f"Resumed from checkpoint: {ckpt_path} (epoch {start_epoch})")

    if start_epoch >= rot_epochs:
        _freeze_rotation_head(model)

    # 🔁 Training loop
    print(f"Begin Training for Batch Size={BATCH_SIZE}")
    num_batches = len(train_loader)
    for epoch in range(start_epoch, NUM_EPOCHS):
        model.train()
        total_loss = 0
        total_simclr_loss = 0
        total_rotnet_loss = 0
        start_time = time.time()

        compute_rot = epoch < rot_epochs
        if compute_rot:
            current_rot_weight = initial_rot_weight * (1 - epoch / rot_epochs)
        else:
            current_rot_weight = 0.0
            _freeze_rotation_head(model)

        for step, (x_i, x_j, x_rot) in enumerate(train_loader, 1):
            x_i, x_j, x_rot = x_i.to(device), x_j.to(device), x_rot.to(device)
            optimizer.zero_grad()

            with autocast(device_type=device_type, enabled=(device_type == 'cuda')):
                z_i = model(x_i, task='simclr')
                z_j = model(x_j, task='simclr')
                loss_simclr = nt_xent_loss(z_i, z_j, temperature=TEMPERATURE)

                if compute_rot:
                    x_rot_batch, rot_labels = rotate_batch(x_rot)
                    rot_logits = model(x_rot_batch, task='rotnet')
                    loss_rotnet = rotation_loss_fn(rot_logits, rot_labels)
                else:
                    loss_rotnet = torch.tensor(0.0, device=device)

                loss = loss_simclr + current_rot_weight * loss_rotnet

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            total_simclr_loss += loss_simclr.item()
            total_rotnet_loss += loss_rotnet.item()

            if step % 20 == 0 or step == num_batches:
                print(f"[Step {step}/{num_batches}] Total: {loss.item():.4f} | SimCLR: {loss_simclr.item():.4f} | RotNet: {loss_rotnet.item():.4f}")

        avg_total_loss = total_loss / num_batches
        avg_simclr_loss = total_simclr_loss / num_batches
        avg_rotnet_loss = total_rotnet_loss / num_batches
        epoch_time = time.time() - start_time

        gradual_warmup.step()  # 🔁 Update learning rate
        current_lr = optimizer.param_groups[0]['lr']

        print(f"[Epoch {epoch+1}/{NUM_EPOCHS}] Avg Total: {avg_total_loss:.4f} | "
              f"SimCLR: {avg_simclr_loss:.4f} | "
              f"RotNet: {avg_rotnet_loss:.4f} | Time: {epoch_time:.2f}s\n"
              f"Learning Rate (Epoch {epoch+1}): {current_lr:.6f}\n")

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'avg_total_loss': avg_total_loss,
                'avg_simclr_loss': avg_simclr_loss,
                'avg_rotnet_loss': avg_rotnet_loss
            }
            save_path = os.path.join(CHECKPOINT_FOLDER, f"SimCLR_RotNet_bs{BATCH_SIZE}_epoch{epoch+1}.pth")
            os.makedirs(CHECKPOINT_FOLDER, exist_ok=True)
            torch.save(checkpoint, save_path)
            print(f"✅ Saved checkpoint and losses at: {save_path}")

In [None]:
# SETUP
import torch
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
from warmup_scheduler import GradualWarmupScheduler

CHECKPOINT_FOLDER = '/content/drive/My Drive/ECE661/Project/hybrid_model/training_models/stronger_augmentation'
LEARNING_RATE = 0.025
ETA_MIN = 0.0125
WARMUP_EPOCHS = 10
NUM_EPOCHS = 100
WEIGHT_DECAY = 1e-4
TEMPERATURE = 0.2  # lower temperature -> sharper contrastive softmax
scaler = GradScaler(device='cuda')
start_epoch = 0

# 🧠 Define model and optimizer before loading checkpoint
model = SimCLRRotNetModel().to(device)
# Use the WEIGHT_DECAY constant (it was defined but the literal 1e-4 was passed instead).
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=WEIGHT_DECAY)

# Sanity check: every encoder parameter is trainable (nothing here freezes the encoder).
assert all(param.requires_grad for param in model.encoder.parameters())

# 🔒 Put the encoder's BatchNorm layers in eval mode.
# NOTE(review): train_rotnet_simclr calls model.train() at the start of every epoch, which
# switches these layers back to training mode, so this has no effect during training.
# Confirm whether BN statistics are really meant to be frozen.
for m in model.encoder.modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        m.eval()

# Cosine annealing after warmup, toward ETA_MIN.
# NOTE(review): T_max equals NUM_EPOCHS, but this scheduler is only stepped after warmup,
# so it likely does not fully reach ETA_MIN within the run.
cosine_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS,
    eta_min=ETA_MIN
)

# Warmup over WARMUP_EPOCHS epochs up to LEARNING_RATE, then hand off to the cosine schedule.
# NOTE(review): with multiplier=1.0 the warmup_scheduler package appears to ramp up from 0
# rather than from LEARNING_RATE / 10; confirm.
gradual_warmup = GradualWarmupScheduler(
    optimizer,
    multiplier=1.0,  # keep LR at base (LEARNING_RATE) after warmup
    total_epoch=WARMUP_EPOCHS,
    after_scheduler=cosine_scheduler
)

# Train Models at Batch Sizes 64 and 128

In [None]:
# Inspect which checkpoints exist for the stronger-augmentation run.
BATCH_SIZE = 256
CHECKPOINT_FOLDER = '/content/drive/My Drive/ECE661/Project/hybrid_model/training_models/stronger_augmentation'

# Sorted so the listing is stable and files are grouped by their bs tag.
sorted(os.listdir(CHECKPOINT_FOLDER))

In [None]:
# Batch Size = 64
BATCH_SIZE = 64
CHECKPOINT_FOLDER = '/content/drive/My Drive/ECE661/Project/hybrid_model/training_models/stronger_augmentation'

# Fresh model / optimizer / AMP scaler / LR schedule for this run, so it does not inherit
# weights, momentum buffers or an already-advanced schedule from another run or a re-run.
model = SimCLRRotNetModel().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=WEIGHT_DECAY)
scaler = GradScaler(device='cuda')
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=ETA_MIN)
gradual_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=WARMUP_EPOCHS,
                                        after_scheduler=cosine_scheduler)

# The shared train_loader uses batch_size=256. Rebuild it so the real batch size matches the
# bs tag written into the checkpoint names.
# NOTE(review): checkpoints already saved under this bs tag came from the 256-sample loader;
# move them aside to retrain at the true batch size.
train_loader = DataLoader(multi_view_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, drop_last=True)

train_rotnet_simclr(
    model=model,
    optimizer=optimizer,
    scaler=scaler,
    train_loader=train_loader,
    rotate_batch=rotate_batch,  # returns rotated batch and labels
    rotation_loss_fn=rotation_loss_fn,  # nn.CrossEntropyLoss()
    nt_xent_loss=nt_xent_loss,  # contrastive loss
    gradual_warmup=gradual_warmup,
    device=device,
    BATCH_SIZE=BATCH_SIZE,
    CHECKPOINT_FOLDER=CHECKPOINT_FOLDER,
    NUM_EPOCHS=100,
    TEMPERATURE=TEMPERATURE
)

In [None]:
# Batch Size = 128
BATCH_SIZE = 128
CHECKPOINT_FOLDER = '/content/drive/My Drive/ECE661/Project/hybrid_model/training_models/stronger_augmentation'

# Fresh model / optimizer / AMP scaler / LR schedule. Without this, the run would continue
# from whatever model, momentum and exhausted LR schedule the previous cell left behind.
model = SimCLRRotNetModel().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=WEIGHT_DECAY)
scaler = GradScaler(device='cuda')
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=ETA_MIN)
gradual_warmup = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=WARMUP_EPOCHS,
                                        after_scheduler=cosine_scheduler)

# Rebuild the loader so the real batch size matches the bs tag in the checkpoint names.
# NOTE(review): checkpoints already saved under this bs tag were not trained at this batch
# size; move them aside to retrain correctly.
train_loader = DataLoader(multi_view_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, drop_last=True)

train_rotnet_simclr(
    model=model,
    optimizer=optimizer,
    scaler=scaler,
    train_loader=train_loader,
    rotate_batch=rotate_batch,  # returns rotated batch and labels
    rotation_loss_fn=rotation_loss_fn,  # nn.CrossEntropyLoss()
    nt_xent_loss=nt_xent_loss,  # contrastive loss
    gradual_warmup=gradual_warmup,
    device=device,
    BATCH_SIZE=BATCH_SIZE,
    CHECKPOINT_FOLDER=CHECKPOINT_FOLDER,
    NUM_EPOCHS=100,
    TEMPERATURE=TEMPERATURE
)

# Linear Evaluation Loop

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import time

def train_linear_classifier(model_name, encoder, train_features, train_labels, test_features, test_labels,
                            batch_size=256, num_epochs=20, lr=0.1, num_classes=10, compute_device=None):
    """
    Train a linear classifier on pre-extracted (frozen) features and report test accuracy.

    Args:
        model_name: label used in the progress messages.
        encoder: encoder the features came from. It is put in eval mode and its parameters
            are frozen; it is not called here because the features are precomputed.
        train_features, train_labels: [N, D] features and matching class indices.
        test_features, test_labels: held-out features and labels used for the accuracy.
        batch_size, num_epochs, lr: settings for SGD (momentum 0.9, weight decay 1e-4).
        num_classes: number of output classes (10 for CIFAR-10).
        compute_device: device to train on; defaults to the notebook-level ``device``.

    Returns:
        Test accuracy as a float in [0, 1].
    """
    target_device = compute_device if compute_device is not None else device
    print(f"[{model_name}] Evaluating...")

    # Freeze encoder
    encoder.eval()
    for param in encoder.parameters():
        param.requires_grad = False

    # Create classifier and move to device
    classifier = nn.Linear(train_features.shape[1], num_classes).to(target_device)

    # Optimizer and loss
    optimizer = optim.SGD(classifier.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    # DataLoaders over the in-memory feature tensors
    train_loader = DataLoader(TensorDataset(train_features, train_labels), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(TensorDataset(test_features, test_labels), batch_size=batch_size)

    # Training loop
    classifier.train()
    for _ in range(num_epochs):
        for x, y in train_loader:
            x, y = x.to(target_device), y.to(target_device)
            optimizer.zero_grad()
            loss = criterion(classifier(x), y)
            loss.backward()
            optimizer.step()

    # Evaluation
    classifier.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(target_device), y.to(target_device)
            preds = classifier(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = correct / total
    print(f"[{model_name}] Final Test Accuracy: {acc * 100:.2f}%\n")
    return acc


In [None]:
from torch.amp import autocast
def getEncoder(model):
    """Switch ``model.encoder`` to eval mode and return it (the same object, not a copy)."""
    encoder = model.encoder
    encoder.eval()
    return encoder

def extractFeatures(dataloader, encoder, compute_device=None):
    """Run ``encoder`` over every (images, targets) batch and return (features, labels).

    Features are computed without gradients, under mixed precision when running on CUDA,
    then cast to float32 and moved to the CPU. Labels are concatenated as given.

    Args:
        dataloader: iterable of (images, targets) batches.
        encoder: module mapping images to feature vectors; it is put in eval mode.
        compute_device: device to run on; defaults to the notebook-level ``device``.
    """
    target_device = torch.device(compute_device if compute_device is not None else device)
    use_amp = target_device.type == "cuda"
    encoder.eval()
    features, labels = [], []

    with torch.no_grad():
        for images, targets in dataloader:
            images = images.to(target_device, non_blocking=True)
            # Only enable CUDA autocast when computing on a GPU.
            with autocast(device_type=target_device.type, enabled=use_amp):
                outputs = encoder(images)

            features.append(outputs.float().cpu())  # back to float32, then to CPU
            labels.append(targets)

    return torch.cat(features), torch.cat(labels)

In [None]:
from torch.utils.data import DataLoader
import torchvision


DATA_ROOT = '/content/data'

# Deterministic evaluation pipeline: tensor conversion plus the normalisation used in training.
transform_eval = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

trainDataset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, transform=transform_eval, download=True
)
testDataset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=False, transform=transform_eval, download=True
)

# No shuffling: feature extraction visits every sample once, in a fixed order.
train_loader_eval = DataLoader(trainDataset, batch_size=256, num_workers=2, shuffle=False)
test_loader_eval = DataLoader(testDataset, batch_size=256, num_workers=2, shuffle=False)

In [None]:
import os
import pandas as pd
import numpy as np
import torch  # Make sure torch is imported

BATCH_SIZES = [256]
EPOCHS = np.arange(20, 201, 20)


def _subsample_labeled(features, labels, percent_labeled, seed):
    """Keep a reproducible random ``percent_labeled``% of (features, labels); 100 keeps all."""
    if percent_labeled >= 100:
        return features, labels
    n_keep = max(1, int(round(len(labels) * percent_labeled / 100)))
    generator = torch.Generator().manual_seed(seed)
    keep = torch.randperm(len(labels), generator=generator)[:n_keep]
    return features[keep], labels[keep]


def linear_eval(train_loader_eval, test_loader_eval, percent_labeled, CHECKPOINT_FOLDER, seed=0):
    """Linear-probe every saved SimCLR+RotNet checkpoint and write a CSV summary.

    For each batch size in ``BATCH_SIZES`` and epoch in ``EPOCHS``, the checkpoint
    ``SimCLR_RotNet_bs{bs}_epoch{epoch}.pth`` is loaded if it exists and its frozen encoder's
    features are extracted. A linear classifier is then trained on ``percent_labeled``% of the
    training features and scored on the full test set.

    Args:
        train_loader_eval, test_loader_eval: deterministic CIFAR-10 loaders.
        percent_labeled: percentage (0, 100] of labelled training samples to use. It is also
            part of the output CSV name. The subset is the same for every checkpoint.
        CHECKPOINT_FOLDER: folder containing the checkpoints.
        seed: seed for the labelled-subset selection.

    Raises:
        ValueError: if ``percent_labeled`` is outside (0, 100].
    """
    if not 0 < percent_labeled <= 100:
        raise ValueError(f"percent_labeled must be in (0, 100], got {percent_labeled}")

    RESULTS_FOLDER = '/content/drive/My Drive/ECE661/Project/hybrid_model/linear_evals'
    os.makedirs(RESULTS_FOLDER, exist_ok=True)  # Ensure results folder exists

    all_results = []

    for bs in BATCH_SIZES:
        print("=" * 150)
        print(f"Starting Linear Evaluation for Batch Size {bs}...\n")
        for epoch in EPOCHS:
            # 1. Load Pretrained Model Checkpoint
            checkpoint_name = f"SimCLR_RotNet_bs{bs}_epoch{epoch}.pth"
            checkpoint_path = os.path.join(CHECKPOINT_FOLDER, checkpoint_name)
            if not os.path.exists(checkpoint_path):
                print(f"⚠️ Checkpoint not found at {checkpoint_path}")
                continue

            checkpoint = torch.load(checkpoint_path, map_location=device)
            model = SimCLRRotNetModel().to(device)
            model.load_state_dict(checkpoint['model_state_dict'])

            # 2. Extract and Freeze Encoder
            encoder = model.encoder
            encoder.eval()
            for param in encoder.parameters():
                param.requires_grad = False

            # 3. Extract features; the test set always stays complete
            train_features, train_labels = extractFeatures(train_loader_eval, encoder)
            test_features, test_labels = extractFeatures(test_loader_eval, encoder)
            train_features, train_labels = _subsample_labeled(
                train_features, train_labels, percent_labeled, seed
            )

            # 4. Train linear classifier
            accuracy = train_linear_classifier(
                checkpoint_name,
                encoder,
                train_features, train_labels,
                test_features, test_labels,
                batch_size=256,
                num_epochs=20,
                lr=0.1
            )

            all_results.append({
                "epoch": int(epoch),
                "batch_size": bs,
                "acc": accuracy
            })

    if all_results:
        df = pd.DataFrame(all_results)
        save_path = os.path.join(RESULTS_FOLDER, f"hybrid_model_{percent_labeled}percent_data_accuracy_summary.csv")
        df.to_csv(save_path, index=False)
        print(f"All results saved to {save_path}")
    else:
        print("No results to save.")


In [None]:
# Linear-probe every checkpoint of the stronger-augmentation run using 100% of the labels.
CHECKPOINT_FOLDER = '/content/drive/My Drive/ECE661/Project/hybrid_model/training_models/stronger_augmentation'
linear_eval(
    train_loader_eval,
    test_loader_eval,
    percent_labeled=100,
    CHECKPOINT_FOLDER=CHECKPOINT_FOLDER,
)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Load all results.
# linear_eval writes the hybrid summary under ".../ECE661/..." (no space); reading it from
# ".../ECE 661/..." pointed at a different folder.
hybrid_df = pd.read_csv("/content/drive/My Drive/ECE661/Project/hybrid_model/linear_evals/hybrid_model_100percent_data_accuracy_summary.csv")
# NOTE(review): the SimCLR / RotNet summaries are still read from "ECE 661" (with a space);
# confirm these files really live there.
simclr_df = pd.read_csv("/content/drive/My Drive/ECE 661/Project/linear_evals/LRset2_100percent_labeled_data_accuracy_summary.csv")
rotnet_df = pd.read_csv("/content/drive/My Drive/ECE 661/Project/final_rotnet_eval_resultsCSV.csv")

# Filter for batch_size = 256
simclr_df = simclr_df[simclr_df["batch_size"] == 256]
rotnet_df = rotnet_df[rotnet_df["batch_size"] == 256]

# Bars are paired by row position and labelled with hybrid_df's epochs, so the three tables
# must have the same number of rows.
# NOTE(review): this also assumes the SimCLR / RotNet rows are in the same epoch order.
row_counts = {"Hybrid": len(hybrid_df), "SimCLR": len(simclr_df), "RotNet": len(rotnet_df)}
assert len(set(row_counts.values())) == 1, f"Result tables have different lengths: {row_counts}"

# Create consistent x positions
x = np.arange(len(hybrid_df))
bar_width = 0.25

fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(x - bar_width, hybrid_df["acc"].to_numpy(), width=bar_width, label="Hybrid")
ax.bar(x, simclr_df["avg_accuracy"].to_numpy(), width=bar_width, label="SimCLR")
ax.bar(x + bar_width, rotnet_df["accuracy"].to_numpy(), width=bar_width, label="RotNet")

ax.set_xlabel("Training Epochs")
ax.set_ylabel("Top 1 Accuracy")
ax.set_title("Top 1 Accuracy vs Epoch (Batch Size = 256)")
ax.set_xticks(x)
ax.set_xticklabels(hybrid_df["epoch"])
ax.set_ylim(0.6, 0.9)
ax.grid(axis='y')
ax.legend()
fig.tight_layout()

fig.savefig("figure_hybridComp2.svg")
plt.show()





In [None]:
# 1. Load Pretrained Model Checkpoint (epoch 90 of the stronger-augmentation run).
PROBE_FOLDER = '/content/drive/My Drive/ECE661/Project/hybrid_model/training_models/stronger_augmentation'
PROBE_EPOCH = 90
PROBE_BATCH_SIZE = 256  # NOTE(review): assumed to be the bs tag of the run to probe; confirm

# Try the original untagged name first, then the bs-tagged name that train_rotnet_simclr
# writes ("SimCLR_RotNet_bs{batch}_epoch{epoch}.pth"), so either layout works.
checkpoint_candidates = [
    os.path.join(PROBE_FOLDER, f"SimCLR_RotNet_epoch{PROBE_EPOCH}.pth"),
    os.path.join(PROBE_FOLDER, f"SimCLR_RotNet_bs{PROBE_BATCH_SIZE}_epoch{PROBE_EPOCH}.pth"),
]
checkpoint_path = next((path for path in checkpoint_candidates if os.path.exists(path)), None)
if checkpoint_path is None:
    raise FileNotFoundError(f"No epoch-{PROBE_EPOCH} checkpoint found; tried: {checkpoint_candidates}")

checkpoint = torch.load(checkpoint_path, map_location=device)
model = SimCLRRotNetModel().to(device)
model.load_state_dict(checkpoint['model_state_dict'])

# 2. Extract and Freeze Encoder
encoder = model.encoder.to(device)
encoder.eval()
for param in encoder.parameters():
    param.requires_grad = False  # freeze encoder

# 3. Get features
train_features, train_labels = extractFeatures(train_loader_eval, encoder)
test_features, test_labels = extractFeatures(test_loader_eval, encoder)

# 4. Run linear eval
accuracy = train_linear_classifier("SimCLR+RotNet", encoder,
                                   train_features, train_labels,
                                   test_features, test_labels,
                                   batch_size=256, num_epochs=20, lr=0.1)