**LOADING THE PRETRAINED SSL MODEL AND TRAINING A CLASSIFIER ON TOP OF IT**

***

***

# Imports

In [1]:
import torch
import torchvision
from torchinfo import summary

import lightly
from lightly.models.modules.heads import SimSiamPredictionHead,SimSiamProjectionHead

import matplotlib.pyplot as plt
import numpy as np
import math
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hyperparamenters.
input_size = 224  # input_size = 256
batch_size = 32   # batch_size = 128
num_workers = 8
epochs = 5

# Seed torch and numpy.
seed = 1
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

***

***

# Loading dataset

In [3]:
data_dir = 'datasets/Sentinel2GlobalLULC_ratio'

## Custom tranforms

In [4]:
# Define the augmentations for self-supervised learning.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_size, input_size)),
    torchvision.transforms.RandomResizedCrop(size=input_size, scale=(0.2, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.RandomVerticalFlip(p=0.5),
    torchvision.transforms.GaussianBlur(21),
    torchvision.transforms.ToTensor(),
])

test_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_size, input_size)),
    torchvision.transforms.ToTensor(),
])

## ImageFolder

In [5]:
# Loading both datasets.
train_data = torchvision.datasets.ImageFolder(data_dir + '/train/')

val_data = torchvision.datasets.ImageFolder(data_dir + '/val/')

test_data = torchvision.datasets.ImageFolder(data_dir + '/test/')

train_data_lightly = lightly.data.LightlyDataset.from_torch_dataset(train_data)
test_data_lightly = lightly.data.LightlyDataset.from_torch_dataset(test_data,
                                                                   transform=test_transform)

## Dataloaders

In [6]:
# Define the augmentations for self-supervised learning.
collate_fn_train = lightly.data.collate.BaseCollateFunction(train_transform)

# Create a dataloader for training and embedding.
dataloader_train_simsiam = torch.utils.data.DataLoader(
    train_data_lightly,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn_train,
    drop_last=True,
    num_workers=num_workers
)

dataloader_test = torch.utils.data.DataLoader(
    test_data_lightly,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

## Check balance and size

In [7]:
# Check samples per class in train dataset.
print(np.unique(train_data.targets, return_counts=True))

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]), array([9800, 3259, 6208, 9800, 8355, 3105,  943, 9800, 7306, 4466, 2015,
        396,  880, 9800, 2739, 2710, 9793,  291,  340, 2943, 9800, 9800,
       9800, 1402,  589,  714,  247,  289, 8813]))


In [8]:
# Check samples per class in val dataset.
print(np.unique(val_data.targets, return_counts=True))

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28]), array([1400,  465,  886, 1400, 1193,  443,  134, 1400, 1043,  638,  288,
         56,  125, 1400,  391,  387, 1399,   41,   48,  420, 1400, 1400,
       1400,  200,   84,  102,   35,   41, 1259]))


In [9]:
print(len(train_data.targets))
print(len(val_data.targets))
print(len(test_data.targets))

136403
19478
38996


***

***

# Loading the model

In [10]:
resnet = torchvision.models.resnet18(weights=None)
model = torch.nn.Sequential(*list(resnet.children())[:-1])
model.load_state_dict(torch.load('pytorch_models/simsiam_backbone_resnet18'))
# model.eval() or model.train()

<All keys matched successfully>

In [11]:
num_classes = len(np.unique(train_data.targets, return_counts=False))
print(num_classes)

29


In [12]:
summary(model, input_size=(batch_size, 3, input_size, input_size))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [32, 512, 1, 1]           --
├─Conv2d: 1-1                            [32, 64, 112, 112]        9,408
├─BatchNorm2d: 1-2                       [32, 64, 112, 112]        128
├─ReLU: 1-3                              [32, 64, 112, 112]        --
├─MaxPool2d: 1-4                         [32, 64, 56, 56]          --
├─Sequential: 1-5                        [32, 64, 56, 56]          --
│    └─BasicBlock: 2-1                   [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-1                  [32, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-2             [32, 64, 56, 56]          128
│    │    └─ReLU: 3-3                    [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-4                  [32, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-5             [32, 64, 56, 56]          128
│    │    └─ReLU: 3-6                    [32, 64, 56, 56]          --
│

In [15]:
model = torch.nn.Sequential(
    model,
    torch.nn.Linear(in_features=32*512, out_features=num_classes),
    torch.nn.Softmax(dim=1),
)

In [16]:
summary(model, input_size=(batch_size, 3, input_size, input_size))

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Sequential: 1, Conv2d: 2, BatchNorm2d: 2, ReLU: 2, MaxPool2d: 2, Sequential: 2, BasicBlock: 3, Conv2d: 4, BatchNorm2d: 4, ReLU: 4, Conv2d: 4, BatchNorm2d: 4, ReLU: 4, BasicBlock: 3, Conv2d: 4, BatchNorm2d: 4, ReLU: 4, Conv2d: 4, BatchNorm2d: 4, ReLU: 4, Sequential: 2, BasicBlock: 3, Conv2d: 4, BatchNorm2d: 4, ReLU: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Conv2d: 5, BatchNorm2d: 5, ReLU: 4, BasicBlock: 3, Conv2d: 4, BatchNorm2d: 4, ReLU: 4, Conv2d: 4, BatchNorm2d: 4, ReLU: 4, Sequential: 2, BasicBlock: 3, Conv2d: 4, BatchNorm2d: 4, ReLU: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Conv2d: 5, BatchNorm2d: 5, ReLU: 4, BasicBlock: 3, Conv2d: 4, BatchNorm2d: 4, ReLU: 4, Conv2d: 4, BatchNorm2d: 4, ReLU: 4, Sequential: 2, BasicBlock: 3, Conv2d: 4, BatchNorm2d: 4, ReLU: 4, Conv2d: 4, BatchNorm2d: 4, Sequential: 4, Conv2d: 5, BatchNorm2d: 5, ReLU: 4, BasicBlock: 3, Conv2d: 4, BatchNorm2d: 4, ReLU: 4, Conv2d: 4, BatchNorm2d: 4, ReLU: 4, AdaptiveAvgPool2d: 2]