In [105]:
!python main_vicreg.py --batch-size 1024  --mlp '1024-1024-512'

Namespace(arch='resnet50', base_lr=0.1, batch_size=1024, cov_coeff=1.0, device='cuda', dist_url='env://', epochs=100, exp_dir=PosixPath('exp'), local_rank=-1, log_freq_time=60, mlp='1024-1024-512', num_workers=4, sim_coeff=25.0, std_coeff=25.0, wd=0.0001, world_size=1)
main_vicreg.py --batch-size 1024 --mlp 1024-1024-512
Files already downloaded and verified
{"epoch": 1, "step": 63, "loss": 20.535282135009766, "time": 60, "lr": 0.051428571428571435}
{"epoch": 2, "step": 132, "loss": 19.910831451416016, "time": 120, "lr": 0.10775510204081634}
{"epoch": 4, "step": 198, "loss": 19.32766342163086, "time": 180, "lr": 0.16163265306122448}
{"epoch": 5, "step": 267, "loss": 18.44390869140625, "time": 241, "lr": 0.21795918367346942}
{"epoch": 6, "step": 335, "loss": 17.307598114013672, "time": 301, "lr": 0.27346938775510204}
{"epoch": 8, "step": 402, "loss": 16.366077423095703, "time": 361, "lr": 0.3281632653061225}
{"epoch": 9, "step": 471, "loss": 15.74398136138916, "time": 422, "lr": 0.38448

In [106]:
import torch
import torchvision.models as models

# Load the SSL model and its backbone
ssl_model_state_dict = torch.load("exp/resnet_backbone.pth")

In [107]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Define the data transformations
transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])

# Load the CIFAR-10 dataset with 1% labeled data
trainset = datasets.CIFAR10(root="../data", train=True, download=True, transform=transform)
# trainset_1percent = torch.utils.data.Subset(trainset, range(0, 500))  # Only use 1% of the labels (i.e., 500 out of 50,000)

# # Load the CIFAR-10 dataset with 10% labeled data
# trainset_10percent = torch.utils.data.Subset(trainset, range(0, 5000)) 


Files already downloaded and verified


In [163]:
# Define the percentage of labels to take from each class
percent_per_class = 0.1

# Get the list of classes
classes = trainset.classes

# Create a dictionary to hold the indices of the images for each class
indices = {}
for c in classes:
    indices[c] = []

# Populate the dictionary with the indices of the images for each class
for i in range(len(trainset)):
    _, label = trainset[i]
    indices[classes[label]].append(i)

# Create a list of indices to use for each class based on the percentage
subset_indices = []
for c in classes:
    num_images = len(indices[c])
    num_subset_images = int(num_images * percent_per_class)
    subset_indices.extend(indices[c][:num_subset_images])

# Create a subset sampler using the subset indices
subset_sampler = torch.utils.data.sampler.SubsetRandomSampler(subset_indices)



In [124]:
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])

In [109]:
testset = datasets.CIFAR10(root='../data', train=False,
                                       download=True, transform=test_transform)

Files already downloaded and verified


In [133]:
# Create the dataloader using the subset sampler
cifar10_1pct_train_loader = torch.utils.data.DataLoader(trainset, batch_size=256, sampler=subset_sampler)


In [164]:
# Create the dataloader using the subset sampler
cifar10_10pct_train_loader = torch.utils.data.DataLoader(trainset, batch_size=256, sampler=subset_sampler)

In [111]:
testloader = data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)

In [165]:
import torch.nn as nn
from custom_resnet import custom_Resnet,Resnet_block
head = nn.Sequential(
    nn.Linear(in_features=128, out_features=64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(in_features=64, out_features=32),
    nn.BatchNorm1d(32),
    nn.ReLU(),
    nn.Linear(in_features=32, out_features=10)
)

backbone= custom_Resnet(Resnet_block,32,[13,13,13])
backbone.load_state_dict(ssl_model_state_dict)
# Combine the backbone and the head

model = nn.Sequential(
    backbone,
    head
)

# Freeze the backbone weights
for param in backbone.parameters():
    param.requires_grad = False

# Define the optimizer
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)

# Define the loss function
criterion = nn.CrossEntropyLoss()


In [166]:
num_epochs=20
device =torch.device("cuda")
model.to(device)
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(cifar10_10pct_train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
   
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / len(cifar10_10pct_train_loader)))

[1,    20] loss: 0.442
[2,    20] loss: 0.410
[3,    20] loss: 0.393
[4,    20] loss: 0.382
[5,    20] loss: 0.374
[6,    20] loss: 0.367
[7,    20] loss: 0.364
[8,    20] loss: 0.360
[9,    20] loss: 0.357
[10,    20] loss: 0.355
[11,    20] loss: 0.353
[12,    20] loss: 0.354
[13,    20] loss: 0.352
[14,    20] loss: 0.349
[15,    20] loss: 0.350
[16,    20] loss: 0.348
[17,    20] loss: 0.347
[18,    20] loss: 0.348
[19,    20] loss: 0.343
[20,    20] loss: 0.346


In [129]:
print(running_loss/100)#10%

0.34679766654968264


In [136]:
print(running_loss/100)#1%

0.038683040142059325


In [167]:
def evaluate_top_5(model, test_loader,device):
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = outputs.topk(k=5, dim=1)
            labels = labels.view(-1, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total * 100
    return accuracy



In [159]:
evaluate_top_5(model,testloader,device)#1%

85.65

In [168]:
evaluate_top_5(model,testloader,device)#10%

90.83

In [130]:
import torch
import torch.nn.functional as F

def evaluate(model, dataloader, device):
    # Set model to evaluation mode
    model.eval()

    # Initialize variables to track accuracy and loss
    total_correct = 0
    total_loss = 0
    total_samples = 0

    # Disable gradient computation (speeds up inference)
    with torch.no_grad():
        # Iterate over batches in the dataloader
        for images, labels in dataloader:
            # Move data to the specified device
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass to get logits
            logits = model(images)

            # Compute cross-entropy loss
            loss = F.cross_entropy(logits, labels, reduction='sum')

            # Compute predictions and accuracy
            predictions = torch.argmax(logits, dim=1)
            total_correct += torch.sum(predictions == labels)
            total_loss += loss.item()
            total_samples += images.shape[0]

    # Compute average accuracy and loss
    avg_accuracy = total_correct / total_samples
    avg_loss = total_loss / total_samples

    return avg_accuracy.item(), avg_loss


In [137]:
evaluate(model,testloader,device)

(0.39989998936653137, 1.8280329650878906)

In [131]:
evaluate(model,testloader,device)

(0.4715999960899353, 1.5065254426956176)