In [5]:
import copy
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision

from lightly.data import LightlyDataset
from lightly.loss import NTXentLoss
from lightly.models import ResNetGenerator
from lightly.models.modules.heads import MoCoProjectionHead
from lightly.models.utils import (
    batch_shuffle,
    batch_unshuffle,
    deactivate_requires_grad,
    update_momentum,
)
from lightly.transforms import MoCoV2Transform, utils

from lightly.loss import NTXentLoss
from lightly.models.modules import MoCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.moco_transform import MoCoV2Transform
from lightly.utils.scheduler import cosine_schedule

In [2]:
num_workers = 8
batch_size = 256  # STL-10 is larger, so a smaller batch size may be better for memory
memory_bank_size = 4096
seed = 1
max_epochs = 100

path_to_data = "../data"

In [51]:
import copy

import torch
import torchvision
from torch import nn
from torchvision.datasets import STL10

from lightly.loss import NTXentLoss
from lightly.models.modules import MoCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.moco_transform import MoCoV2Transform
from lightly.utils.scheduler import cosine_schedule


class MoCo(nn.Module):
    def __init__(self, backbone):
        super().__init__()

        self.backbone = backbone
        self.projection_head = MoCoProjectionHead(512, 512, 128)

        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)

        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

    def forward(self, x):
        query = self.backbone(x).flatten(start_dim=1)
        query = self.projection_head(query)
        return query

    def forward_momentum(self, x):
        key = self.backbone_momentum(x).flatten(start_dim=1)
        key = self.projection_head_momentum(key).detach()
        return key


resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = MoCo(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


transform = MoCoV2Transform(input_size=96)
dataset = torchvision.datasets.STL10(
    root='../data', split='train+unlabeled', download=True, transform=transform
)

# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder", transform=transform)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)



Files already downloaded and verified


In [52]:
criterion = NTXentLoss(memory_bank_size=(4096, 128))
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

epochs = 200

print("Starting Training")
for epoch in range(epochs):
    total_loss = 0
    momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
    for batch in dataloader:
        x_query, x_key = batch[0]
        update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)
        update_momentum(
            model.projection_head, model.projection_head_momentum, m=momentum_val
        )
        x_query = x_query.to(device)
        x_key = x_key.to(device)
        query = model(x_query)
        key = model.forward_momentum(x_key)
        loss = criterion(query, key)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

Starting Training


epoch: 00, loss: 7.89435
epoch: 01, loss: 7.44988
epoch: 02, loss: 7.29025
epoch: 03, loss: 7.20489
epoch: 04, loss: 7.15354
epoch: 05, loss: 7.11737
epoch: 06, loss: 7.09379
epoch: 07, loss: 7.07331
epoch: 08, loss: 7.05743
epoch: 09, loss: 7.04059
epoch: 10, loss: 7.02780
epoch: 11, loss: 7.01487
epoch: 12, loss: 7.00247
epoch: 13, loss: 6.99387
epoch: 14, loss: 6.98306
epoch: 15, loss: 6.97650
epoch: 16, loss: 6.96730
epoch: 17, loss: 6.95951
epoch: 18, loss: 6.95395
epoch: 19, loss: 6.94624
epoch: 20, loss: 6.93977
epoch: 21, loss: 6.93605
epoch: 22, loss: 6.92766
epoch: 23, loss: 6.92371
epoch: 24, loss: 6.91980
epoch: 25, loss: 6.91480
epoch: 26, loss: 6.91063
epoch: 27, loss: 6.90630
epoch: 28, loss: 6.90125
epoch: 29, loss: 6.89666
epoch: 30, loss: 6.89313
epoch: 31, loss: 6.89133
epoch: 32, loss: 6.88518
epoch: 33, loss: 6.88483
epoch: 34, loss: 6.88132
epoch: 35, loss: 6.87815
epoch: 36, loss: 6.87500
epoch: 37, loss: 6.87133
epoch: 38, loss: 6.86969
epoch: 39, loss: 6.86777


In [53]:
new_backbone  = nn.Sequential(*list(model.backbone.children())[:-1])
torch.save(new_backbone.state_dict(), 'models/backbone_weights_200.pth')

In [54]:
class ClassificationNet(nn.Module):
    def __init__(self, backbone, num_classes):
        super(ClassificationNet, self).__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        features = self.backbone(x)
        pooled_features = nn.AdaptiveAvgPool2d((1, 1))(features)
        pooled_features = pooled_features.view(pooled_features.size(0), -1)
        output = self.classifier(pooled_features)
        return output

classification_model = ClassificationNet(new_backbone, num_classes=10).to(device)

In [55]:
from torchvision import transforms
classification_transform = transforms.Compose([
    # transforms.RandomResizedCrop(96),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # RGB for classification
])

In [56]:
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import STL10
stl10_train = STL10(root='../data', split='train', download=True, transform=classification_transform)
stl10_test = STL10(root='../data', split='test', download=True, transform=classification_transform)

# Fine-tuning: Load training data for classification task
train_loader = DataLoader(stl10_train, batch_size=64, shuffle=True)

# Testing: Load test data for final evaluation
test_loader = DataLoader(stl10_test, batch_size=64, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


In [57]:
criterion = nn.CrossEntropyLoss()  # multi-class classification
optimizer = torch.optim.Adam(classification_model.parameters(), lr=1e-4)

# Training Loop
num_epochs = 150
for epoch in range(num_epochs):
    classification_model.train()  
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = classification_model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # if (epoch + 1) % 10 == 0:
    #     torch.save(classification_model.state_dict(), f'models/downstream/classification_model_weights_epoch_{epoch+1}.pth')



PATH = 'models/downstream/classification_model_weights_final_200.pth'
torch.save(classification_model.state_dict(), PATH)

Epoch [1/150], Loss: 1.1250
Epoch [2/150], Loss: 0.7289
Epoch [3/150], Loss: 0.5277
Epoch [4/150], Loss: 0.3535
Epoch [5/150], Loss: 0.1954
Epoch [6/150], Loss: 0.1410
Epoch [7/150], Loss: 0.1201
Epoch [8/150], Loss: 0.0611
Epoch [9/150], Loss: 0.0964
Epoch [10/150], Loss: 0.0590
Epoch [11/150], Loss: 0.0520
Epoch [12/150], Loss: 0.0229
Epoch [13/150], Loss: 0.0086
Epoch [14/150], Loss: 0.0284
Epoch [15/150], Loss: 0.0232
Epoch [16/150], Loss: 0.0605
Epoch [17/150], Loss: 0.0467
Epoch [18/150], Loss: 0.0409
Epoch [19/150], Loss: 0.0115
Epoch [20/150], Loss: 0.0057
Epoch [21/150], Loss: 0.0064
Epoch [22/150], Loss: 0.0094
Epoch [23/150], Loss: 0.0046
Epoch [24/150], Loss: 0.0108
Epoch [25/150], Loss: 0.1085
Epoch [26/150], Loss: 0.0453
Epoch [27/150], Loss: 0.0313
Epoch [28/150], Loss: 0.0083
Epoch [29/150], Loss: 0.0102
Epoch [30/150], Loss: 0.0632
Epoch [31/150], Loss: 0.0195
Epoch [32/150], Loss: 0.0105
Epoch [33/150], Loss: 0.0551
Epoch [34/150], Loss: 0.0288
Epoch [35/150], Loss: 0

In [58]:
# Evaluation
classification_model.eval()  # Set model to evaluation mode
correct = 0
top_5_correct = 0
top_3_correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = classification_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        _, predicted_3 = torch.topk(outputs.data, k=3, dim=1)
        correct_3 = predicted_3.eq(labels.unsqueeze(1).expand_as(predicted_3))
        top_3_correct += correct_3.any(dim=1).sum().item()

        _, predicted_5 = torch.topk(outputs.data, k=5, dim=1)
        correct_5 = predicted_5.eq(labels.unsqueeze(1).expand_as(predicted_5))
        top_5_correct += correct_5.any(dim=1).sum().item()



accuracy = 100 * correct / total
top_5 = 100 * top_5_correct / total
top_3 = 100 * top_3_correct / total
print(f'Top-1 Accuracy of the model on the test set: {accuracy:.2f}%')
print(f'Top-5 Accuracy of the model on the test set: {top_5:.2f}%')
print(f'Top-3 Accuracy of the model on the test set: {top_3:.2f}%')

Top-1 Accuracy of the model on the test set: 71.85%
Top-5 Accuracy of the model on the test set: 96.89%
Top-3 Accuracy of the model on the test set: 91.45%
