<a href="https://colab.research.google.com/github/spacebasie/multiagent-ssl/blob/main/vicreg_stef.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install lightly
!pip install lightly

Collecting lightly
  Downloading lightly-1.5.21-py3-none-any.whl.metadata (37 kB)
Collecting hydra-core>=1.0.0 (from lightly)
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting lightly_utils~=0.0.0 (from lightly)
  Downloading lightly_utils-0.0.2-py3-none-any.whl.metadata (1.4 kB)
Collecting pytorch_lightning>=1.0.4 (from lightly)
  Downloading pytorch_lightning-2.5.2-py3-none-any.whl.metadata (21 kB)
Collecting aenum>=3.1.11 (from lightly)
  Downloading aenum-3.1.16-py3-none-any.whl.metadata (3.8 kB)
Collecting torchmetrics>=0.7.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading torchmetrics-1.7.3-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning>=1.0.4->lightly)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->lightly)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecti

**Import dependencies**

In [2]:
# This example requires the following dependencies to be installed:
import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from lightly.loss import VICRegLoss
from lightly.models.modules.heads import VICRegProjectionHead
from lightly.transforms.vicreg_transform import VICRegTransform
import torch.nn.functional as F




Define the VICReg model and its forward pass

In [3]:
# --- 1. Model Definition ---
class VICReg(nn.Module):
    """Backbone encoder followed by a VICReg projection head.

    Args:
        backbone: Feature-extractor module. Its output is flattened from
            dimension 1 onwards before being used.
        proj_input_dim: Size of the flattened backbone output; used as the
            input dimension of the projection head.
    """

    def __init__(self, backbone, proj_input_dim=512):
        super().__init__()
        self.backbone = backbone
        head = VICRegProjectionHead(
            input_dim=proj_input_dim,
            hidden_dim=2048,
            output_dim=2048,
        )
        self.projection_head = head

    def forward_backbone(self, x):
        """Return the flattened backbone features, bypassing the projection head."""
        features = self.backbone(x)
        return features.flatten(start_dim=1)

    def forward(self, x):
        """Return the projection-head embedding of ``x``."""
        return self.projection_head(self.forward_backbone(x))


In [4]:
# --- 0. VICReg Loss Definition (New Block) ---
class VICRegLoss(nn.Module):
    """VICReg objective: weighted sum of invariance, variance and covariance terms.

    Args:
        lambda_ (float): Weight of the invariance (mean-squared-error) term.
        mu (float): Weight of the variance term.
        nu (float): Weight of the covariance term.
        epsilon (float): Added to the per-dimension variance before the square
            root, for numerical stability.
    """

    def __init__(self, lambda_=25.0, mu=25.0, nu=1.0, epsilon=1e-4):
        super().__init__()
        self.lambda_, self.mu, self.nu = lambda_, mu, nu
        self.epsilon = epsilon

    def forward(self, z_a, z_b):
        """Return the scalar VICReg loss for two batches of embeddings.

        Both inputs are indexed as ``(batch, dim)``; the batch size and the
        feature count are taken from ``z_a``.
        """
        batch_size, num_features = z_a.shape
        # Boolean mask selecting every off-diagonal entry of a (dim, dim) matrix.
        off_diagonal = ~torch.eye(num_features, device=z_a.device).bool()

        def variance_penalty(z):
            # Hinge on the per-dimension standard deviation: only std < 1 is penalised.
            std = torch.sqrt(z.var(dim=0) + self.epsilon)
            return torch.mean(F.relu(1 - std))

        def covariance_penalty(z):
            # Squared off-diagonal covariance, summed and scaled by the feature count.
            centered = z - z.mean(dim=0)
            cov = (centered.T @ centered) / (batch_size - 1)
            return cov[off_diagonal].pow_(2).sum() / num_features

        # Invariance: the two views of the same image should map to the same point.
        invariance = F.mse_loss(z_a, z_b)
        # Variance: keep every embedding dimension spread out across the batch.
        variance = variance_penalty(z_a) + variance_penalty(z_b)
        # Covariance: decorrelate the embedding dimensions.
        covariance = covariance_penalty(z_a) + covariance_penalty(z_b)

        return (self.lambda_ * invariance) + (self.mu * variance) + (self.nu * covariance)

Define the VICReg pre-training loop: it builds the loss and an SGD optimizer, trains the model for the requested number of epochs, and saves a plot of the per-epoch loss.

In [5]:
# --- 2. Pretraining ---
def vicreg_pretraining(model, dataloader, epochs, device, lambda_, mu, nu):
    """Run VICReg self-supervised pre-training and return the loss history.

    Args:
        model: Network whose call ``model(x)`` returns the embeddings fed to the loss.
        dataloader: Sized iterable of batches; ``batch[0]`` must unpack into the
            two augmented views ``(x0, x1)``.
        epochs: Number of passes over ``dataloader``.
        device: Device the model and the views are moved to.
        lambda_: Weight of the invariance term, forwarded to ``VICRegLoss``.
        mu: Weight of the variance term, forwarded to ``VICRegLoss``.
        nu: Weight of the covariance term, forwarded to ``VICRegLoss``.

    Returns:
        list[float]: Average loss of every completed epoch. If the loss becomes
        NaN, training stops immediately and the history collected so far is
        returned (no plot is written in that case).

    Raises:
        ValueError: If ``dataloader`` contains no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader yields no batches; cannot pre-train.")

    # Pass hyperparameters to the loss function
    criterion = VICRegLoss(lambda_=lambda_, mu=mu, nu=nu)
    # Using a safer learning rate to prevent 'nan' loss instability
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
    model.to(device)
    # Make sure BatchNorm/dropout layers are in training mode even if the model
    # was previously switched to eval mode by an evaluation routine.
    model.train()

    epoch_losses = []
    print(f"Starting VICReg Pre-training for lambda={lambda_}, mu={mu}, nu={nu}")
    for epoch in range(epochs):
        # Accumulated as a tensor so the device is not synchronised every batch.
        running_loss = 0.0
        for batch_idx, batch in enumerate(dataloader):
            # Unpack the two augmented views from the batch tuple
            x0, x1 = batch[0]
            x0 = x0.to(device)
            x1 = x1.to(device)

            loss = criterion(model(x0), model(x1))

            # Safety check for numerical instability
            if torch.isnan(loss):
                print(f"Loss became NaN at epoch {epoch}, batch {batch_idx}. Stopping training.")
                print("This is likely caused by a learning rate that is too high. Try reducing it further.")
                return epoch_losses

            running_loss += loss.detach()

            # Clear gradients first so nothing stale leaks into this step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        avg_loss = (running_loss / num_batches).item()
        epoch_losses.append(avg_loss)
        print(f"Epoch: {epoch:02}, Loss: {avg_loss:.5f}")

    print("Pre-training Finished.")

    # Save the loss curve. The file name encodes lambda and mu only, so runs
    # that differ solely in nu write to the same file.
    if epoch_losses:
        plt.figure(figsize=(10, 6))
        plt.plot(epoch_losses, marker='o', linestyle='-')
        plt.title(f'VICReg Training Loss (λ={lambda_}, μ={mu}, ν={nu})')
        plt.xlabel('Epoch')
        plt.ylabel('Average Loss')
        plt.grid(True)
        plt.savefig(f"vicreg_loss_l{lambda_}_m{mu}.png")
        plt.close()

    return epoch_losses

In [6]:
# --- 3. Linear Evaluation ---
def linear_evaluation(model, proj_output_dim, train_loader, test_loader, epochs, device, num_classes=10):
    """Train a linear classifier on frozen backbone features and report test accuracy.

    Args:
        model: Object exposing ``backbone`` (an ``nn.Module``) and
            ``forward_backbone(images)``, which returns the feature batch.
            It is not moved to ``device`` here.
        proj_output_dim: Input width of the linear classifier. Despite the name,
            it must equal the feature size returned by ``forward_backbone``.
        train_loader: Sized iterable of ``(images, labels)`` training batches.
        test_loader: Iterable of ``(images, labels)`` evaluation batches.
        epochs: Number of classifier training epochs.
        device: Device used for the classifier and the batches.
        num_classes: Number of classifier outputs (default 10, as for CIFAR-10).

    Returns:
        float: Top-1 accuracy on ``test_loader`` in percent.

    Side effects:
        Sets ``requires_grad = False`` on every backbone parameter and leaves
        ``model`` in eval mode.
    """
    print("\nStarting Linear Evaluation")

    # Freeze the backbone weights ...
    for param in model.backbone.parameters():
        param.requires_grad = False
    # ... and its normalisation statistics. In train mode BatchNorm updates its
    # running mean/var on every forward pass (even under no_grad), so the
    # backbone would keep drifting and the classifier would be trained on
    # batch-statistic features but tested on running-statistic features.
    model.eval()

    # The input dimension to the classifier must match the backbone's output dimension
    classifier = nn.Linear(proj_output_dim, num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    # Training the linear classifier
    classifier.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Features come from the frozen backbone; no graph is needed for them.
            with torch.no_grad():
                representations = model.forward_backbone(images)

            optimizer.zero_grad()
            loss = criterion(classifier(representations), labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Classifier Training Epoch: {epoch:02}, Loss: {avg_loss:.5f}")

    # Evaluate the classifier
    print("\nEvaluating on Test Set...")
    classifier.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            logits = classifier(model.forward_backbone(images))
            predicted = logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Final Test Accuracy: {accuracy:.2f}%")
    return accuracy

**Main Execution**

Define the kNN evaluation function.

In [7]:
# --- 4. kNN Evaluation ---
def knn_evaluation(model, train_loader, test_loader, device, k=200, temperature=0.1):
    """Classify test samples by majority vote among their nearest training features.

    Features come from ``model.forward_backbone`` and are L2-normalised, so the
    dot product used for neighbour search is the cosine similarity.

    Args:
        model: Object exposing ``backbone`` (an ``nn.Module``), ``eval()`` and
            ``forward_backbone(images)``. It is not moved to ``device`` here.
        train_loader: Iterable of ``(images, labels)`` batches forming the
            reference set.
        test_loader: Iterable of ``(images, labels)`` batches to classify.
        device: Device used for the batches and the similarity computation.
        k: Number of neighbours that vote. Clamped to the number of training
            samples, because ``topk`` fails when ``k`` exceeds it.
        temperature: Divides the similarities before neighbour selection. The
            vote is unweighted, so for any positive value this does not
            change the prediction.

    Returns:
        float: Top-1 accuracy on ``test_loader`` in percent.

    Side effects:
        Puts ``model`` in eval mode and sets ``requires_grad = False`` on every
        backbone parameter.
    """
    print("\nStarting kNN Evaluation")

    # Freeze the model backbone
    model.eval()
    for param in model.backbone.parameters():
        param.requires_grad = False

    # Gather features from the training set
    train_features = []
    train_labels = []
    print("Gathering training features for kNN...")
    with torch.no_grad():
        for images, labels in train_loader:
            train_features.append(model.forward_backbone(images.to(device)))
            train_labels.append(labels)

    train_features = F.normalize(torch.cat(train_features, dim=0), dim=1)
    # Labels must live on the same device as the neighbour indices used to gather them.
    train_labels = torch.cat(train_labels, dim=0).to(device)

    # topk raises if k is larger than the number of candidates; vote with all of them instead.
    num_train = train_features.size(0)
    if k > num_train:
        print(f"Requested k={k} exceeds the {num_train} training samples; using k={num_train}.")
        k = num_train

    # Evaluate on the test set
    correct = 0
    total = 0
    print("Evaluating on test set using kNN...")
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            test_features = F.normalize(model.forward_backbone(images), dim=1)

            # Cosine similarity of every test sample to every training sample
            similarity_matrix = torch.matmul(test_features, train_features.T) / temperature

            # Indices of the k most similar training samples per test sample
            _, indices = similarity_matrix.topk(k, dim=1, largest=True, sorted=True)

            # Unweighted majority vote over the neighbours' labels
            k_neighbor_labels = train_labels[indices]
            predictions = torch.mode(k_neighbor_labels, dim=1).values

            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    accuracy = 100 * correct / total
    print(f"Final kNN Test Accuracy: {accuracy:.2f}%")
    return accuracy

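Main execution: load CIFAR-10, pre-train VICReg for each hyperparameter combination, then run both the linear and the kNN evaluation and print a summary of the results.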

In [None]:
# --- Main Execution ---
# --- Main Execution ---
# End-to-end experiment: for every (lambda, mu) pair in the grid, pre-train a
# fresh VICReg model on CIFAR-10, evaluate its backbone with a linear probe and
# with kNN, and finally print the runs ranked by linear accuracy.
if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"
    PRETRAIN_EPOCHS = 100  # epochs of self-supervised pre-training per run
    EVAL_EPOCHS = 50       # epochs used to train the linear classifier
    BATCH_SIZE = 256       # shared by the pre-training and evaluation loaders

    # --- Hyperparameter Grid Search ---
    # Define the grid of hyperparameters to search.
    # The paper suggests lambda = mu = 25 and nu = 1 as a good starting point.
    # NOTE(review): the grid currently holds a single point (lambda = mu = 30),
    # so exactly one run is executed.
    lambda_values = [30]
    mu_values = [30]
    nu_value = 1.0  # Held fixed; nu is not part of the search

    results = []
    print("Starting Hyperparameter Grid Search for VICReg...")

    # --- Data Loading (define it once outside the loop) ---
    # Pre-training data: VICRegTransform is applied to every image; the
    # training loop unpacks batch[0] into two augmented views.
    transform_vicreg = VICRegTransform(input_size=32)
    pretrain_dataset = torchvision.datasets.CIFAR10(
        "datasets/cifar10", download=True, transform=transform_vicreg
    )
    # drop_last=True keeps every pre-training batch at exactly BATCH_SIZE samples.
    pretrain_dataloader = torch.utils.data.DataLoader(
        pretrain_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2
    )
    # Evaluation data: no augmentation, only tensor conversion and per-channel
    # normalisation with the constants below.
    transform_eval = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    train_dataset_eval = torchvision.datasets.CIFAR10(
        "datasets/cifar10", download=True, train=True, transform=transform_eval
    )
    test_dataset_eval = torchvision.datasets.CIFAR10(
        "datasets/cifar10", download=True, train=False, transform=transform_eval
    )
    train_loader_eval = torch.utils.data.DataLoader(
        train_dataset_eval, batch_size=BATCH_SIZE, shuffle=True, num_workers=2
    )
    test_loader_eval = torch.utils.data.DataLoader(
        test_dataset_eval, batch_size=BATCH_SIZE, shuffle=False, num_workers=2
    )

    for lambda_ in lambda_values:
        for mu in mu_values:
            print("-" * 60)
            print(f"Testing hyperparameters: lambda={lambda_}, mu={mu}, nu={nu_value}")
            print("-" * 60)

            # --- Model Initialization ---
            # Re-initialize the model for each run to ensure a fair, independent trial
            BACKBONE = torchvision.models.resnet18()
            PROJ_INPUT_DIM = 512 # ResNet-18 outputs 512 features
            # Drop the last child module of the ResNet so the backbone yields
            # features rather than class logits.
            backbone = nn.Sequential(*list(BACKBONE.children())[:-1])
            model = VICReg(backbone, proj_input_dim=PROJ_INPUT_DIM).to(device)

            # --- Training ---
            # Pass the current hyperparameters to the pretraining function
            vicreg_pretraining(model, pretrain_dataloader, PRETRAIN_EPOCHS, device, lambda_=lambda_, mu=mu, nu=nu_value)

            # --- Linear Evaluation ---
            # The classifier sits on backbone features, so the argument named
            # `proj_output_dim` receives the backbone feature size.
            linear_acc = linear_evaluation(model, proj_output_dim=PROJ_INPUT_DIM, train_loader=train_loader_eval, test_loader=test_loader_eval, epochs=EVAL_EPOCHS, device=device)

            # --- kNN evaluation ---
            # Uses the function's default k and temperature.
            knn_acc = knn_evaluation(model, train_loader_eval, test_loader_eval, device)

            # Store results for this run
            results.append({
                'lambda': lambda_,
                'mu': mu,
                'nu': nu_value,
                'linear_accuracy': linear_acc,
                'knn_accuracy': knn_acc
            })

    # --- Print Final Results Summary ---
    print("\n\n" + "=" * 60)
    print("Hyperparameter Tuning Grid Search Results")
    print("=" * 60)
    # Sort results by the metric you care about most (e.g., linear accuracy)
    sorted_results = sorted(results, key=lambda x: x['linear_accuracy'], reverse=True)
    for res in sorted_results:
        print(f"λ={res['lambda']:<4}, μ={res['mu']:<4}, ν={res['nu']} -> "
              f"Linear Accuracy: {res['linear_accuracy']:.2f}%, "
              f"kNN Accuracy: {res['knn_accuracy']:.2f}%")

    # The list is sorted in descending order, so the first entry is the best run.
    best_run = sorted_results[0]
    print("\nBest Performing Hyperparameters (by Linear Accuracy):")
    print(f"λ={best_run['lambda']}, μ={best_run['mu']}, ν={best_run['nu']} with "
          f"Linear Accuracy: {best_run['linear_accuracy']:.2f}% and "
          f"kNN Accuracy: {best_run['knn_accuracy']:.2f}%")

Starting Hyperparameter Grid Search for VICReg...




------------------------------------------------------------
Testing hyperparameters: lambda=30, mu=30, nu=1.0
------------------------------------------------------------
Starting VICReg Pre-training for lambda=30, mu=30, nu=1.0
Epoch: 00, Loss: 46.05739
Epoch: 01, Loss: 45.13380
Epoch: 02, Loss: 44.65610
Epoch: 03, Loss: 44.29099
Epoch: 04, Loss: 44.02333
Epoch: 05, Loss: 43.75507
Epoch: 06, Loss: 43.53495
Epoch: 07, Loss: 43.34589
Epoch: 08, Loss: 43.22565
Epoch: 09, Loss: 43.03850
Epoch: 10, Loss: 42.90893
Epoch: 11, Loss: 42.76710
Epoch: 12, Loss: 42.64973
Epoch: 13, Loss: 42.57952
Epoch: 14, Loss: 42.48263
Epoch: 15, Loss: 42.38155
Epoch: 16, Loss: 42.31141
Epoch: 17, Loss: 42.20211
Epoch: 18, Loss: 42.16481
Epoch: 19, Loss: 42.05097
Epoch: 20, Loss: 42.04647
Epoch: 21, Loss: 41.93783
Epoch: 22, Loss: 41.90476
Epoch: 23, Loss: 41.83053
Epoch: 24, Loss: 41.72695
Epoch: 25, Loss: 41.76415
Epoch: 26, Loss: 41.61401
Epoch: 27, Loss: 41.61245
Epoch: 28, Loss: 41.56062
Epoch: 29, Loss: