**FIRST ATTEMPT TO APPLY SELF-SUPERVISED LEARNING (SSL) TO THE SENTINEL-2 DATASET**

Reference tutorial: https://docs.lightly.ai/tutorials/package/tutorial_simsiam_esa.html

In [1]:
# Load the pycodestyle_magic extension; `%pycodestyle_on` (next cell)
# enables PEP 8 style checking for the code cells that follow.
%load_ext pycodestyle_magic

In [2]:
%pycodestyle_on

In [3]:
# Reproducibility: seed passed to the dataset splits (`seed=SEED`) below.
SEED = 42

***

***

# Split the dataset folder

In [4]:
import splitfolders

# Root folder of the raw Sentinel-2 JPEG dataset that is split below;
# presumably one sub-folder per class (the split output reports 29 classes).
data_dir_initial = 'datasets/Sentinel2GlobalLULC_full_raw/Sentinel2LULC_JPEG/'

## Ratio (imbalanced)

In [5]:
# Ratio split: 70% train, 10% val, 20% test (per class, so class imbalance
# is preserved). For a train/val split only, pass a 2-tuple such as
# `ratio=(.8, .2)`.
splitfolders.ratio(
    data_dir_initial,
    output=f"datasets/Sentinel2GlobalLULC_full_ratio_seed={SEED}",
    seed=SEED,
    ratio=(.7, .1, .2),
    group_prefix=None,  # Default value.
    move=False,  # Default value: copy the files instead of moving them.
)

Copying files: 194877 files [00:46, 4158.47 files/s]


## Fixed (balanced)

In [6]:
# Fixed split: 250 images per class go to val and 100 to test; the remaining
# images go to train, which `oversample=True` then oversamples (see the
# split-folders docs for the exact balancing rule).
# - A single number, e.g. `fixed=10`, only creates a train/val split.
# - Three numbers, e.g. `(300, 100, 100)`, also limit the training set.
splitfolders.fixed(
    data_dir_initial,
    output=f"datasets/Sentinel2GlobalLULC_full_balanced_seed={SEED}",
    seed=SEED,
    fixed=(250, 100),
    oversample=True,
    group_prefix=None,  # Default value.
    move=False,  # Default value: copy the files instead of moving them.
)

Copying files: 194877 files [00:28, 6776.75 files/s] 
Oversampling: 29 classes [00:25,  1.12 classes/s]


***

***

# Imports

In [None]:
import torch
import torchvision

import lightly
from lightly.models.modules.heads import SimSiamPredictionHead,SimSiamProjectionHead

import matplotlib.pyplot as plt
import numpy as np
import math
import random

In [None]:
# Hyperparameters.
input_size = 224  # Alternative tried: 256.
batch_size = 32   # Alternative tried: 128.
num_workers = 8
epochs = 5

# Dimension of the embeddings; must match the flattened backbone output
# (512 for the ResNet-18 backbone built below).
num_ftrs = 512
# Dimension of the output of the prediction and projection heads.
out_dim = proj_hidden_dim = 512
# The prediction head uses a bottleneck architecture.
pred_hidden_dim = 128

# Seed torch, numpy and `random` with the notebook-wide SEED, so that a
# single constant (defined at the top) controls reproducibility.
seed = SEED
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

***

***

# Loading dataset

In [None]:
# Root of the train/val/test split used for SSL training and evaluation.
# NOTE(review): the ratio split above writes to
# "datasets/Sentinel2GlobalLULC_full_ratio_seed=<SEED>", not to this folder;
# confirm which directory is intended.
data_dir_target = 'datasets/Sentinel2GlobalLULC_ratio'
# The cells below read `data_dir`, which was otherwise never defined
# (NameError on a fresh kernel): alias it to the target directory.
data_dir = data_dir_target

## Custom transforms

In [None]:
# Define the augmentations for self-supervised learning.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_size, input_size)),
    torchvision.transforms.RandomResizedCrop(size=input_size, scale=(0.2, 1.0)),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.RandomVerticalFlip(p=0.5),
    torchvision.transforms.GaussianBlur(21),
    torchvision.transforms.ToTensor(),
])

test_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_size, input_size)),
    torchvision.transforms.ToTensor(),
])

## ImageFolder

In [None]:
# Load the train, val and test splits (ImageFolder expects one sub-folder
# per class inside each split folder).
train_data = torchvision.datasets.ImageFolder(f'{data_dir}/train/')
val_data = torchvision.datasets.ImageFolder(f'{data_dir}/val/')
test_data = torchvision.datasets.ImageFolder(f'{data_dir}/test/')

# Wrap for lightly. The train set gets no transform here: the collate
# function of the training dataloader applies `train_transform` instead.
train_data_lightly = lightly.data.LightlyDataset.from_torch_dataset(train_data)
test_data_lightly = lightly.data.LightlyDataset.from_torch_dataset(
    test_data, transform=test_transform)

## Dataloaders

In [None]:
# The collate function applies `train_transform` to build the two augmented
# views (x0, x1) that the training loop unpacks from each batch.
collate_fn_train = lightly.data.collate.BaseCollateFunction(train_transform)

# Training loader: shuffled, incomplete last batch dropped.
dataloader_train_simsiam = torch.utils.data.DataLoader(
    train_data_lightly,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=num_workers,
    collate_fn=collate_fn_train,
)

# Test loader for embeddings: fixed order, every sample kept.
dataloader_test = torch.utils.data.DataLoader(
    test_data_lightly,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    drop_last=False,
)

## Check balance and size

In [None]:
# Samples per class in the train split, printed as (class indices, counts).
train_classes, train_counts = np.unique(train_data.targets, return_counts=True)
print((train_classes, train_counts))

In [None]:
# Samples per class in the val split, printed as (class indices, counts).
val_classes, val_counts = np.unique(val_data.targets, return_counts=True)
print((val_classes, val_counts))

In [None]:
# Number of samples in each split, in the order train, val, test.
for split in (train_data, val_data, test_data):
    print(len(split.targets))

## See some samples

In [None]:
# (Disabled) Preview one augmented training image with its label.
# The training loader yields ((x0, x1), labels, filenames), so
# `images[0][0]` is the first image of the first view.
# import matplotlib.pyplot as plt


# # Accessing Data and Targets in a PyTorch DataLoader
# for images, labels, names in dataloader_train_simsiam:
#     img = images[0][0]
#     print(img.shape)
#     # img = images[0].squeeze()
#     # print(img)
#     label = labels[0]
#     plt.title("Label: " + str(int(label)))
#     plt.imshow(torch.permute(img,(1, 2, 0)))
#     plt.show()
#     break

In [None]:
# (Disabled) Alternative preview.
# NOTE(review): this would fail if re-enabled. The training loader yields
# ((x0, x1), labels, filenames), so unpacking a batch into two names raises
# a ValueError. Unpack as `(x0, x1), train_labels, _ = ...` instead.
# # Display image and label.
# train_features, train_labels = next(iter(dataloader_train_simsiam))

# print(f"Features shape of the current batch is {train_features.size()}")
# print(f"Labels shape of the current batch shape is {train_labels.size()}")

# img = train_features[0].squeeze()
# label = train_labels[0]
# plt.title("Label: " + str(int(label)))
# plt.imshow(torch.permute(img,(1, 2, 0)), cmap="gray")
# plt.show()
# print(f"Label: {label}")

# SimSiam model

## Creation

In [None]:
class SimSiam(torch.nn.Module):
    """SimSiam network: backbone -> projection head -> prediction head.

    Args:
        backbone: Feature extractor; its output is flattened to
            ``num_ftrs`` features per sample.
        num_ftrs: Size of the flattened backbone output.
        proj_hidden_dim: Hidden size of the projection head.
        pred_hidden_dim: Hidden (bottleneck) size of the prediction head.
        out_dim: Output size of both heads.
    """

    def __init__(
        self, backbone, num_ftrs, proj_hidden_dim, pred_hidden_dim, out_dim
    ):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimSiamProjectionHead(
            num_ftrs, proj_hidden_dim, out_dim)
        self.prediction_head = SimSiamPredictionHead(
            out_dim, pred_hidden_dim, out_dim)

    def forward(self, x):
        """Return ``(z, p)``: the detached projection and the prediction.

        ``z`` is detached (stop-gradient), so the loss cannot
        back-propagate through it.
        """
        features = torch.flatten(self.backbone(x), start_dim=1)
        projection = self.projection_head(features)
        prediction = self.prediction_head(projection)
        return projection.detach(), prediction

In [None]:
# ResNet-18 trained from scratch: `weights=None` loads no pretrained
# weights (using pretrained ones would speed up training).
resnet = torchvision.models.resnet18(weights=None)  # no pretrained weights
# Drop the last child module (the classification layer) so the backbone
# outputs pooled features; SimSiam.forward flattens them to `num_ftrs`.
backbone = torch.nn.Sequential(*list(resnet.children())[:-1])
model = SimSiam(backbone, num_ftrs, proj_hidden_dim, pred_hidden_dim, out_dim)

SimSiam uses a symmetric negative cosine similarity loss and therefore does not require any negative samples. The criterion and the optimizer are built in the training setup below.



In [None]:
from torchinfo import summary

# Layer-by-layer summary of the backbone for one batch of 3-channel images.
backbone_input_shape = (batch_size, 3, input_size, input_size)
summary(backbone, input_size=backbone_input_shape)

## Training setup

In [None]:
# SimSiam is trained with a symmetric negative cosine similarity loss.
criterion = lightly.loss.NegativeCosineSimilarity()

# Linear learning-rate scaling: base lr of 0.05 per 256 samples.
base_lr = 0.05
lr = base_lr * batch_size / 256

# SGD with momentum and weight decay.
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9,
                            weight_decay=5e-4)

## Training

In [None]:
import time

### Loop

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
# Put the model in training mode explicitly: the embedding cell below calls
# `model.eval()`, so re-running this cell afterwards would otherwise train
# with BatchNorm in inference mode.
model.train()

# Smoothing factor of the exponential moving averages of loss and std.
w = 0.9
avg_loss = 0.
avg_output_std = 0.
for e in range(epochs):
    t0 = time.time()

    for (x0, x1), _, _ in dataloader_train_simsiam:
        # Move both augmented views of the batch to the device.
        x0 = x0.to(device)
        x1 = x1.to(device)

        # Run the model on both views: projections (z0, z1) and
        # predictions (p0, p1).
        z0, p0 = model(x0)
        z1, p1 = model(x1)

        # Symmetric negative cosine similarity: each view's prediction is
        # compared with the other view's projection.
        loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))

        # Clear stale gradients before back-propagating (also safe if a
        # previous run of this cell was interrupted mid-step).
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Per-dimension std of the L2-normalised predictions; tracked to
        # check whether the embeddings are collapsing.
        output = torch.nn.functional.normalize(p0.detach(), dim=1)
        output_std = torch.std(output, 0).mean()

        avg_loss = w * avg_loss + (1 - w) * loss.item()
        avg_output_std = w * avg_output_std + (1 - w) * output_std.item()

    # The level of collapse is large if the standard deviation of the l2
    # normalized output is much smaller than 1 / sqrt(dim).
    collapse_level = max(0., 1 - math.sqrt(out_dim) * avg_output_std)
    print(f'[Epoch {e:3d}] '
          f'Loss = {avg_loss:.2f} | '
          f'Collapse Level: {collapse_level:.2f} / 1.00 | '
          f'Duration: {(time.time() - t0):.2f} s')

### Saving the pretrained model

In [None]:
import os

# Save only the backbone weights (the embedding step below uses only
# `model.backbone`). torch.save does not create missing folders, so make
# sure the target folder exists first.
os.makedirs('pytorch_models', exist_ok=True)
torch.save(backbone.state_dict(), 'pytorch_models/simsiam_backbone_resnet18')

### Embeddings for the samples of the test dataset

In [None]:
embeddings = []
filenames = []

# Inference only: eval mode for the model and no autograd bookkeeping.
model.eval()
with torch.no_grad():
    for x, _, fnames in dataloader_test:
        # Embed the images with the self-supervised backbone.
        y = model.backbone(x.to(device)).flatten(start_dim=1)
        # Move each batch back to the CPU right away so device memory does
        # not grow with the size of the test set.
        embeddings.append(y.cpu())
        # `extend` is linear overall, whereas `filenames = filenames + ...`
        # re-copied the whole list on every batch.
        filenames.extend(fnames)

# One row per test image, in the same (unshuffled) order as `filenames`.
embeddings = torch.cat(embeddings, dim=0).numpy()

# Scatter Plot and Nearest Neighbors

In [None]:
# for plotting
import os
from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.offsetbox as osb
from matplotlib import rcParams as rcp

# for resizing images to thumbnails
import torchvision.transforms.functional as functional

# for clustering and 2d representations
from sklearn import random_projection

In [None]:
# for the scatter plot we want to transform the images to a two-dimensional
# vector space using a random Gaussian projection
projection = random_projection.GaussianRandomProjection(n_components=2)
embeddings_2d = projection.fit_transform(embeddings)

# normalize the embeddings to fit in the [0, 1] square
M = np.max(embeddings_2d, axis=0)
m = np.min(embeddings_2d, axis=0)
embeddings_2d = (embeddings_2d - m) / (M - m)

In [None]:
def get_scatter_plot_with_thumbnails(min_dist=2e-3):
    """Creates a scatter plot with image overlays.

    Test images are visited in random order. A thumbnail is drawn at its
    2-D embedding only if its squared distance to every thumbnail already
    drawn is at least ``min_dist``, so the overlays do not pile up.

    Reads the notebook globals ``embeddings_2d``, ``filenames`` and
    ``data_dir``.

    Args:
        min_dist: Minimum squared distance, in the normalised [0, 1]
            space, between two drawn thumbnails.
    """
    fig = plt.figure(figsize=(15, 15))
    fig.suptitle('Scatter Plot of the Sentinel-2 Dataset')
    ax = fig.add_subplot(1, 1, 1)

    # Shuffle images and find out which images to show. The initial point
    # (1, 1) also keeps thumbnails away from the top-right corner.
    shown_images_idx = []
    shown_images = np.array([[1., 1.]])
    iterator = list(range(embeddings_2d.shape[0]))
    np.random.shuffle(iterator)
    for i in iterator:
        dist = np.sum((embeddings_2d[i] - shown_images) ** 2, 1)
        if np.min(dist) < min_dist:
            continue
        shown_images = np.r_[shown_images, [embeddings_2d[i]]]
        shown_images_idx.append(i)

    # Thumbnail size in pixels; loop-invariant, so compute it once.
    thumbnail_size = int(rcp['figure.figsize'][0] * 10.)
    for idx in shown_images_idx:
        path = os.path.join(data_dir + '/test/', filenames[idx])
        # Close the image file once the thumbnail has been created.
        with Image.open(path) as img:
            thumbnail = np.array(functional.resize(img, thumbnail_size))
        img_box = osb.AnnotationBbox(
            osb.OffsetImage(thumbnail, cmap=plt.cm.gray_r),
            embeddings_2d[idx],
            pad=0.2,
        )
        ax.add_artist(img_box)

    # Set the aspect ratio so the unit square is drawn as a square.
    ratio = 1. / ax.get_data_ratio()
    ax.set_aspect(ratio, adjustable='box')


# get a scatter plot with thumbnail overlays
get_scatter_plot_with_thumbnails()

In [None]:
# List of subdirectories (classes)
# Collect the class sub-folders of the test split (sorted within each
# directory level) and print their names.
directory_list = []
test_root = data_dir + '/test/'
for walk_root, walk_dirs, _walk_files in os.walk(test_root):
    for dirname in sorted(walk_dirs):
        class_dir = os.path.join(walk_root, dirname)
        directory_list.append(class_dir)
        print(dirname)

# print(directory_list)

In [None]:
# Pick one random example image per class. Names are stored relative to the
# test folder ("<class>/<file>"), the same form used for `filenames` when
# paths are rebuilt with os.path.join(data_dir + '/test/', ...).
test_root = data_dir + '/test/'
example_images = []
for classes in directory_list:
    # Sort the listing: os.listdir order is arbitrary, so the seeded random
    # choice would otherwise not be reproducible across machines.
    random_file = np.random.choice(sorted(os.listdir(classes)))
    path_to_random_file = os.path.join(classes, random_file)
    # Derive the relative name from the path instead of slicing off a
    # hard-coded 40 characters (which only matched one `data_dir` length).
    example_images.append(os.path.relpath(path_to_random_file, test_root))
    print(path_to_random_file)

# example_images = example_images[:3]
# print(example_images)

In [None]:
def get_image_as_np_array(filename: str):
    """
    Loads the image with filename and returns it as a numpy array.

    The file is opened in a context manager, so the handle is closed
    once the pixel data has been converted.
    """
    with Image.open(filename) as img:
        return np.asarray(img)


def get_image_as_np_array_with_frame(filename: str, w: int = 5):
    """
    Returns an image as a numpy array with a black frame of width w.

    Works for ``w == 0`` (no frame) and keeps whatever channel dimension
    the image has (e.g. RGB or single-channel). The result is ``uint8``.
    """
    img = get_image_as_np_array(filename)
    ny, nx = img.shape[:2]
    # Black canvas with `w` pixels of padding on every side; any trailing
    # channel dimension of the input is preserved.
    framed_img = np.zeros((ny + 2 * w, nx + 2 * w) + img.shape[2:],
                          dtype=np.uint8)
    # Explicit end indices: the slice `w:-w` would be empty when w == 0.
    framed_img[w:w + ny, w:w + nx] = img
    return framed_img


def plot_nearest_neighbors_nxn(example_image: str, i: int,
                               n_subplots: int = 6):
    """
    Plots the example image and its nearest neighbors on a square grid.

    Args:
        example_image: Name of the example as stored in ``filenames``.
        i: Zero-based index of the example, shown as ``i + 1`` in the title.
        n_subplots: Total number of images shown: the example (framed)
            plus ``n_subplots - 1`` nearest neighbors. The default of 6
            shows five neighbors on a 3x3 grid.

    Reads the notebook globals ``filenames``, ``embeddings`` and
    ``data_dir``.
    """
    # Smallest square grid that fits all images (3x3 for the default).
    n_side = math.ceil(math.sqrt(n_subplots))
    fig = plt.figure(figsize=(10, 10))
    fig.suptitle(f"Nearest Neighbor Plot {i + 1}")
    # Row of the example in the embedding matrix.
    example_idx = filenames.index(example_image)
    # Squared Euclidean distance of every embedding to the example.
    distances = embeddings - embeddings[example_idx]
    distances = np.power(distances, 2).sum(-1).squeeze()
    # Put the example first explicitly: with duplicate embeddings, a plain
    # argsort could place another zero-distance image in the framed slot.
    order = np.argsort(distances, kind="stable")
    neighbors = order[order != example_idx][:n_subplots - 1]
    nearest_neighbors = np.concatenate(([example_idx], neighbors))
    for plot_offset, plot_idx in enumerate(nearest_neighbors):
        ax = fig.add_subplot(n_side, n_side, plot_offset + 1)
        fname = os.path.join(data_dir + '/test/', filenames[plot_idx])
        if plot_offset == 0:
            ax.set_title("Example Image")
            ax.imshow(get_image_as_np_array_with_frame(fname))
        else:
            ax.imshow(get_image_as_np_array(fname))
        ax.axis("off")


# Show the nearest-neighbor grid for one example image per class.
for plot_number, image_name in enumerate(example_images):
    plot_nearest_neighbors_nxn(image_name, plot_number)