In [None]:
import torch
# Environment sanity check: PyTorch build plus the CUDA devices visible to this kernel.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("GPU:", "No GPU" if not torch.cuda.is_available() else torch.cuda.get_device_name(0))
print("Number of GPUs available:", torch.cuda.device_count())


PyTorch version: 2.5.1
CUDA available: True
GPU: NVIDIA A100 80GB PCIe MIG 7g.80gb
Number of GPUs available: 1


In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet34
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import os
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch.utils.data import Dataset  # Import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.transforms import functional as F


In [5]:
class BagSSLModel(nn.Module):
    """ResNet-34 encoder followed by a 2-layer MLP projector.

    Args:
        embedding_dim (int): Output size of the projector, i.e. the dimension of the
            embeddings returned by ``forward``.
        hidden_dim (int): Width of the projector's hidden layer (default 4096, the
            value previously hard-coded).
    """

    def __init__(self, embedding_dim=128, hidden_dim=4096):
        super().__init__()
        # Randomly initialised backbone. `weights=None` is the non-deprecated spelling of
        # `pretrained=False`; module names (and so checkpoint keys) are unchanged.
        self.encoder = resnet34(weights=None)
        # Take the feature width from the backbone's head instead of hard-coding 512.
        feature_dim = self.encoder.fc.in_features
        self.encoder.fc = nn.Identity()  # remove the classification head

        # Projector: Linear -> BN -> ReLU -> Linear -> BN (same layer order/indices as before).
        self.projector = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
        )

    def forward(self, x):
        """Encode a batch of images and project it to ``embedding_dim``-sized embeddings."""
        features = self.encoder(x)  # backbone features (classification head removed)
        return self.projector(features)

In [6]:
import torch.nn.functional as F

def contrastive_loss(z1, z2, temperature=0.07, symmetric=False):
    """InfoNCE loss between two batches of paired embeddings.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` form the positive pair. Every other row
    of ``z2`` acts as a negative for row ``i`` of ``z1``.

    Args:
        z1, z2: Embedding tensors of shape (batch, dim) for the two augmented views.
        temperature (float): Divisor applied to the cosine similarities.
        symmetric (bool): If True, also score the z2 -> z1 direction and average the two
            losses. Defaults to False, which reproduces the original one-directional loss.

    Returns:
        Scalar loss tensor.
    """
    # Local import: the notebook's imports cell rebinds the global name ``F`` to
    # torchvision.transforms.functional (an unrelated module). Binding it here keeps the
    # loss correct regardless of which cells were (re-)run last.
    import torch.nn.functional as F

    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    # Only the z1-vs-z2 block of the full 2N x 2N similarity matrix is ever used,
    # so compute just that block.
    logits = z1 @ z2.T / temperature
    # Positives sit on the diagonal. Build the targets on the embeddings' own device
    # instead of assuming the default CUDA device.
    labels = torch.arange(z1.size(0), device=z1.device)

    loss = F.cross_entropy(logits, labels)
    if symmetric:
        loss = 0.5 * (loss + F.cross_entropy(logits.T, labels))
    return loss


In [7]:
# Show where relative paths such as ./data and ./checkpoints will resolve.
print("Current working directory:", os.getcwd())

Current working directory: /home/synaderi


In [8]:
#checkpoint_dir = "/home/synaderi/checkpoints"
#os.makedirs(checkpoint_dir, exist_ok=True)
#print(f"Checkpoints will be saved in: {checkpoint_dir}")

Checkpoints will be saved in: /home/synaderi/checkpoints


In [9]:
# Define checkpoint directory.
# It resolves to "<cwd>/checkpoints", the same directory the resume cell ("./checkpoints")
# and the training loop use. Defining it here, instead of relying on the commented-out
# cell above, keeps Restart & Run All from failing with a NameError.
checkpoint_dir = os.path.abspath("checkpoints")
if os.path.isdir(checkpoint_dir):
    print(f"Directory already exists: {checkpoint_dir}")
else:
    os.makedirs(checkpoint_dir, exist_ok=True)
    print(f"Created directory: {checkpoint_dir}")


Directory already exists: /home/synaderi/checkpoints


In [8]:
# Build the network on the GPU, then attach SGD (momentum 0.9, weight decay 1e-4)
# and a cosine learning-rate schedule that decays over T_max = 600 scheduler steps.
model = BagSSLModel(embedding_dim=128).cuda()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.3,
    momentum=0.9,
    weight_decay=1e-4,
)
scheduler = CosineAnnealingLR(optimizer, T_max=600)  # stepped once per epoch by the training loop




In [9]:
from PIL import ImageOps

def solarize(img, threshold=128):
    """Solarize a PIL image.

    Thin wrapper around ``PIL.ImageOps.solarize``: pixel values above ``threshold``
    are inverted and the remaining pixels are left as they are.

    Args:
        img (PIL.Image.Image): Image to transform.
        threshold (int): Pixel-value threshold that triggers inversion.

    Returns:
        PIL.Image.Image: The solarized image.
    """
    solarized = ImageOps.solarize(img, threshold)
    return solarized


# Per-channel CIFAR-10 statistics shared by both view pipelines.
# NOTE(review): this std is the widely copied [0.2023, 0.1994, 0.2010]. The per-pixel
# training-set std is usually quoted as roughly [0.2470, 0.2435, 0.2616]. It is kept
# unchanged so inputs stay consistent with already-trained checkpoints. Confirm before changing.
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]


def build_byol_augmentation(blur_p, solarize_p):
    """Build one BYOL-style view pipeline for 32x32 CIFAR-10 PIL images.

    The two views share crop, flip, colour-jitter and grayscale settings. They differ
    only in how often Gaussian blur and solarization are applied.

    Args:
        blur_p (float): Probability of Gaussian blur; 1.0 applies it unconditionally.
        solarize_p (float): Probability of solarization; 0 omits the step entirely.

    Returns:
        transforms.Compose: Pipeline mapping a PIL image to a normalized float tensor.
    """
    blur = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
    steps = [
        transforms.RandomResizedCrop(size=32, scale=(0.08, 1.0), ratio=(3/4, 4/3)),  # random crop
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
            p=0.8,
        ),  # colour jittering
        transforms.RandomGrayscale(p=0.2),  # colour dropping
        # Wrap in RandomApply only when needed, so the always-blur pipeline applies
        # GaussianBlur directly, exactly as before.
        blur if blur_p >= 1.0 else transforms.RandomApply([blur], p=blur_p),
    ]
    if solarize_p > 0:
        # RandomSolarize replaces RandomApply([Lambda(lambda img: solarize(...))]). It performs
        # the same PIL solarization but, unlike a lambda, is picklable. That matters if the
        # transforms are ever shipped to spawn-based DataLoader worker processes.
        steps.append(transforms.RandomSolarize(threshold=128, p=solarize_p))
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
    ]
    return transforms.Compose(steps)


# T: Gaussian blur always applied, no solarization.
augmentations_T = build_byol_augmentation(blur_p=1.0, solarize_p=0.0)
# T': Gaussian blur with probability 0.1, solarization with probability 0.2.
augmentations_T_prime = build_byol_augmentation(blur_p=0.1, solarize_p=0.2)


In [10]:

# Custom Dataset for BYOL
class BYOLDataset(Dataset):
    """Turn a labelled dataset into pairs of independently augmented views.

    The label of each wrapped item is discarded; only the image is used.

    Args:
        dataset: Indexable dataset whose items unpack as ``(image, label)``.
        transform1: Callable producing the first view (pipeline T).
        transform2: Callable producing the second view (pipeline T').
    """

    def __init__(self, dataset, transform1, transform2):
        self.dataset = dataset
        self.transform1 = transform1
        self.transform2 = transform2

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(transform1(image), transform2(image))`` for item ``index``."""
        image, _label = self.dataset[index]
        # Tuple elements evaluate left to right, so T still runs before T'.
        return self.transform1(image), self.transform2(image)

# Create BYOL dataset and DataLoader: two augmented views per CIFAR-10 training image
# (labels are dropped by BYOLDataset).
train_dataset = BYOLDataset(
    CIFAR10(root="./data", train=True, download=True, transform=None),
    augmentations_T,
    augmentations_T_prime,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True,          # page-locked host buffers -> faster host-to-GPU copies
    persistent_workers=True,  # keep the 4 workers alive across epochs instead of restarting them each epoch
)


Files already downloaded and verified




In [11]:
# Specify the checkpoint path
checkpoint_dir = "./checkpoints"
latest_checkpoint = None

# Find the latest checkpoint file
if os.path.exists(checkpoint_dir):
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")]
    if checkpoints:
        latest_checkpoint = os.path.join(checkpoint_dir, sorted(checkpoints, key=lambda x: int(x.split('_')[1].split('.')[0]))[-1])

# Load the checkpoint
start_epoch = 0
if latest_checkpoint:
    print(f"Loading checkpoint: {latest_checkpoint}")
    checkpoint = torch.load(latest_checkpoint)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch']
    print(f"Resuming training from epoch {start_epoch}")

Loading checkpoint: ./checkpoints/epoch_100.pth
Resuming training from epoch 100


  checkpoint = torch.load(latest_checkpoint)


In [12]:
# Training Loop
# The epoch count comes from the cosine scheduler (T_max=600 at creation, restored on
# resume), so the loop bound, the log line and the LR schedule cannot drift apart.
num_epochs = scheduler.T_max
for epoch in range(start_epoch, num_epochs):
    model.train()
    total_loss = 0.0
    for images1, images2 in train_loader:
        # Move to GPU (non_blocking lets the copy overlap compute when the loader pins memory)
        images1 = images1.cuda(non_blocking=True)
        images2 = images2.cuda(non_blocking=True)

        # Forward pass: both views go through the same network
        z1 = model(images1)
        z2 = model(images2)
        loss = contrastive_loss(z1, z2)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # One cosine-annealing step per epoch
    scheduler.step()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch + 1}.pth")
        # Write to a temporary file, then rename atomically. An interrupted save then never
        # leaves a truncated epoch_N.pth for the resume cell to pick up.
        tmp_path = checkpoint_path + ".tmp"
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': total_loss,    # sum of per-batch losses (unchanged meaning)
            'avg_loss': avg_loss,  # mean per-batch loss, the value printed above
        }, tmp_path)
        os.replace(tmp_path, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")


Epoch [101/600], Loss: 1.8201
Epoch [102/600], Loss: 1.8321
Epoch [103/600], Loss: 1.8379
Epoch [104/600], Loss: 1.8217
Epoch [105/600], Loss: 1.8086
Epoch [106/600], Loss: 1.7929
Epoch [107/600], Loss: 1.7805
Epoch [108/600], Loss: 1.7886


: 