In [5]:

# train SimCLR on CIFAR-10
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from data import PairTransform, transform
from model import SimCLR, nt_xent

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

bsz = 1024

train_dataset = torchvision.datasets.CIFAR10(root='./data', transform=PairTransform(transform), download=True)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=bsz, shuffle=True)

# Define model, optimizer, and scheduler
model = SimCLR(resnet=18, out_dim=128, projection="nonlinear").to(device)
criterion = nt_xent
optimizer  = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(trainloader), eta_min=0, last_epoch=-1)

model.train()
for epoch in range(100):
    for i, (images, labels) in enumerate(trainloader):
        optimizer.zero_grad() 
        h_i, h_j, z_i, z_j  = model(images[0].to(device), images[1].to(device))
        loss = criterion(z_i, z_j, t=0.5)
        loss.backward()
        optimizer.step()
        scheduler.step()
        if i % 10 == 0:
            print(f'Epoch {epoch}, Iteration {i}, Loss: {loss.item()}')

Files already downloaded and verified
Epoch 0, Iteration 0, Loss: 6.900012969970703
Epoch 0, Iteration 10, Loss: 6.378054141998291
Epoch 0, Iteration 20, Loss: 6.375323295593262
Epoch 0, Iteration 30, Loss: 6.294527530670166
Epoch 0, Iteration 40, Loss: 6.292476654052734
Epoch 0, Iteration 50, Loss: 6.212110996246338
Epoch 0, Iteration 60, Loss: 6.18182373046875
Epoch 0, Iteration 70, Loss: 6.207594871520996
Epoch 0, Iteration 80, Loss: 6.162230491638184
Epoch 0, Iteration 90, Loss: 6.167133331298828
Epoch 1, Iteration 0, Loss: 6.218161582946777
Epoch 1, Iteration 10, Loss: 6.172418594360352
Epoch 1, Iteration 20, Loss: 6.183294296264648


KeyboardInterrupt: 

In [14]:
# Evaluate the model by training a linear classifier on top of the learned features

# Define linear classifier
class LinearClassifier(nn.Module):
    def __init__(self, feature_extractor, in_dim, out_dim, freeze=True):
        super(LinearClassifier, self).__init__()
        # copy the feature extractor
        self.feature_extractor = feature_extractor
        # freeze the feature extractor
        self.feature_extractor.requires_grad_(not freeze)
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.fc(self.feature_extractor(x))
    

# Load data
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=bsz, shuffle=True)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=bsz, shuffle=False)

# Define model, optimizer, and scheduler
lin_classifier = LinearClassifier(model.resnet, in_dim=512, out_dim=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lin_classifier.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(trainloader), eta_min=0, last_epoch=-1)

# Train
for epoch in range(10):
    lin_classifier.train()
    for i, (images, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = lin_classifier(images.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        scheduler.step()
        if i % 10 == 0:
            print(f'Epoch {epoch}, Iteration {i}, Loss: {loss.item()}')


# Evaluate
lin_classifier.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        outputs = lin_classifier(images.to(device))
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(device)).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')

Files already downloaded and verified
Files already downloaded and verified
Epoch 0, Iteration 0, Loss: 2.504277467727661
Epoch 0, Iteration 10, Loss: 2.271807909011841
Epoch 0, Iteration 20, Loss: 2.170149803161621
Epoch 0, Iteration 30, Loss: 2.124979019165039
Epoch 0, Iteration 40, Loss: 2.0536441802978516
Epoch 0, Iteration 50, Loss: 2.0429749488830566
Epoch 0, Iteration 60, Loss: 2.035769462585449
Epoch 1, Iteration 0, Loss: 2.0489633083343506
Epoch 1, Iteration 10, Loss: 2.0272107124328613
Epoch 1, Iteration 20, Loss: 2.024613380432129
Epoch 1, Iteration 30, Loss: 1.9897041320800781
Epoch 1, Iteration 40, Loss: 1.9635157585144043
Epoch 1, Iteration 50, Loss: 1.8909162282943726
Epoch 1, Iteration 60, Loss: 1.8146157264709473
Epoch 2, Iteration 0, Loss: 1.7867932319641113
Epoch 2, Iteration 10, Loss: 1.7435411214828491
Epoch 2, Iteration 20, Loss: 1.6680068969726562
Epoch 2, Iteration 30, Loss: 1.6417274475097656
Epoch 2, Iteration 40, Loss: 1.6209174394607544
Epoch 2, Iteration 50

In [17]:
# Save the model
torch.save(model.state_dict(), './models/model.pth')