**FIRST ATTEMPT TO APPLY SSL TO THE SENTINEL-2 DATASET**

Reference tutorial: https://docs.lightly.ai/tutorials/package/tutorial_simsiam_esa.html

In [None]:
# Load the pycodestyle_magic IPython extension (provides %pycodestyle_on).
%load_ext pycodestyle_magic

In [None]:
# Presumably turns on pycodestyle checks for every cell executed afterwards.
%pycodestyle_on

***

***

# Imports

## Packages and modules

In [None]:
# Main.
import utils

# OS module.
import os

# PyTorch.
import torch
import torchvision
from torchinfo import summary

# Data management.
import numpy as np
import pandas as pd

# Lightly.
import lightly

# Training checks.
from datetime import datetime
import time
import copy
import math

from lightly.utils.debug import std_of_l2_normalized

# Showing images in the notebook.
import IPython

# For plotting.
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.offsetbox as osb
from matplotlib import rcParams as rcp
import seaborn as sns

# For resizing images to thumbnails.
import torchvision.transforms.functional as functional

# For clustering and 2d representations.
from sklearn import random_projection

## Settings

In [None]:
# Where figures and trained model checkpoints are written.
output_dir_fig = 'figures/'
output_dir_model = 'pytorch_models/'

# Rebalance the classes with a weighted sampler when True.
handle_imb_classes = True

# Experiment hyperparameters: number of epochs and batch size. The object
# also exposes the device, seed and data-loader helpers used further down.
exp = utils.Experiment(
    epochs=2,
    batch_size=64,
)
print(exp.device)

In [None]:
# Size of the backbone embeddings (ResNet-18 features before its fc layer).
num_ftrs = 512

# Hidden size of the projection head; the prediction and projection heads
# both output vectors of this same size.
proj_hidden_dim = 512
out_dim = proj_hidden_dim

# The prediction head is a bottleneck: it narrows to this width in between.
pred_hidden_dim = 128

## Reproducibility

In [None]:
# Configure the experiment for reproducible runs. The implementation lives
# in utils.Experiment (presumably it seeds the Python/NumPy/PyTorch RNGs).
exp.reproducibility()

***

***

# Loading dataset

In [None]:
# Root folder holding one sub-folder per dataset variant.
datasets_dir = 'datasets/'

# Full paths of the dataset variants, leaving out the first two entries
# (0_Raw and Clothing-dataset), which are not used here.
# NOTE(review): the slice assumes utils.listdir_fullpath returns a stable
# (e.g. sorted) order - confirm, otherwise the selection below may change.
data_dirs = utils.listdir_fullpath(datasets_dir)[2:]
for path in data_dirs:
    print(path)

# The dataset variant used throughout the notebook.
data_dir_target = data_dirs[1]
print('\nSelected: ' + data_dir_target)

# Ratio tag: the part of the folder name between the first "(" and the
# first ")", parentheses included.
open_idx = data_dir_target.index("(")
close_idx = data_dir_target.index(")")
ratio = data_dir_target[open_idx:close_idx + 1]
print(ratio)

# Per-split normalization statistics stored for this dataset.
mean, std = utils.load_mean_std_values(data_dir_target)
print(mean)
print(std)

## Custom transforms (with dataset-specific normalization)

Define the augmentations for self-supervised learning.

In [None]:
# Data augmentations for the self-supervised training views. The random
# crop outputs exp.input_size pixels so that training views have the same
# resolution as the val/test images (and as the backbone summary input).
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((exp.input_size, exp.input_size)),
    torchvision.transforms.RandomResizedCrop(exp.input_size,
                                             scale=(0.2, 1.)),
    torchvision.transforms.RandomApply([
        torchvision.transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),  # not strengthened
    torchvision.transforms.RandomGrayscale(p=0.2),
    # torchvision.transforms.RandomApply([
    #     simsiam.loader.GaussianBlur([.1, 2.])
    # ], p=0.5),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean['train'], std['train'])
])

# Deterministic preprocessing for the val dataset. Normalization is part
# of the model's input pipeline, so every split is scaled with the
# *training* statistics (the ones the model was trained on); using a
# split's own statistics would shift its inputs and leak its distribution
# into preprocessing.
val_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((exp.input_size, exp.input_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean['train'], std['train'])
])

# Deterministic preprocessing for the test dataset (training statistics,
# for the same reason as above).
test_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((exp.input_size, exp.input_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean['train'], std['train'])
])

## ImageFolder

In [None]:
# One torchvision ImageFolder per split (one sub-folder per class).
train_data, val_data, test_data = (
    torchvision.datasets.ImageFolder(f'{data_dir_target}/{split}/')
    for split in ('train', 'val', 'test')
)

# Lightly wrappers. The training split gets no transform here because the
# collate function applies the random augmentations (two views per image);
# val/test get their deterministic preprocessing.
train_data_lightly = lightly.data.LightlyDataset.from_torch_dataset(
    train_data)
val_data_lightly = lightly.data.LightlyDataset.from_torch_dataset(
    val_data, transform=val_transform)
test_data_lightly = lightly.data.LightlyDataset.from_torch_dataset(
    test_data, transform=test_transform)

## Dealing with imbalanced data

In [None]:
if handle_imb_classes:

    # Class index of every training sample.
    train_sample_labels = train_data.targets

    # Number of samples per class, ordered by class index.
    class_sample_count = np.unique(train_sample_labels,
                                   return_counts=True)[1]
    print(class_sample_count)

    # Inverse class frequency, expanded to one weight per sample (not per
    # class) so that, on average, every class is drawn equally often.
    # Indexing by label assumes contiguous labels 0..K-1 (ImageFolder).
    weight = 1. / class_sample_count
    samples_weight = weight[np.asarray(train_sample_labels, dtype=np.intp)]

    # float64 tensor of per-sample weights for the sampler.
    samples_weight = torch.from_numpy(samples_weight).double()

    # Draws len(samples_weight) indices per epoch according to the weights
    # (with replacement, the sampler's default).
    sampler = torch.utils.data.WeightedRandomSampler(
        samples_weight,
        len(samples_weight)
    )
    # DataLoader rejects shuffle=True together with a custom sampler.
    shuffle = False

else:
    sampler = None
    shuffle = True

print(f'Sampler: {sampler}')
print(f'Shuffle: {shuffle}')

## Collate functions

PyTorch uses a collate function to combine individual samples into a batch.

BaseCollateFunction (the base class) takes a batch of images as input and transforms each image into two different augmented views using random transforms. The two views are returned as a pair of batches `(x0, x1)`, each with the same length as the input batch.

In [None]:
# Lightly's base collate function: for every image in a batch it produces
# two independently augmented views using train_transform. This is what
# provides the (x0, x1) pairs consumed by the training loop.
collate_fn_train = lightly.data.collate.BaseCollateFunction(train_transform)

## PyTorch dataloaders

In [None]:
# Settings shared by the three loaders (batch size and seeded workers).
_loader_kwargs = dict(
    batch_size=exp.batch_size,
    num_workers=exp.num_workers,
    worker_init_fn=exp.seed_worker,
    generator=exp.g,
)

# Training loader: augmented view pairs via the collate function, the
# incomplete last batch dropped, and either shuffling or weighted sampling.
dataloader_train_simsiam = torch.utils.data.DataLoader(
    train_data_lightly,
    sampler=sampler,
    shuffle=shuffle,
    collate_fn=collate_fn_train,
    drop_last=True,
    **_loader_kwargs,
)

# Embedding loaders (val and test): fixed order, every sample kept.
dataloader_val = torch.utils.data.DataLoader(
    val_data_lightly, shuffle=False, drop_last=False, **_loader_kwargs)
dataloader_test = torch.utils.data.DataLoader(
    test_data_lightly, shuffle=False, drop_last=False, **_loader_kwargs)

del _loader_kwargs

## Check the balance and size of the dataset

In [None]:
# Samples per class in the train dataset (ordered by class index).
print(np.unique(train_data.targets, return_counts=True)[1])

In [None]:
# Samples per class in the val dataset (ordered by class index).
print(np.unique(val_data.targets, return_counts=True)[1])

In [None]:
# Samples per class in the test dataset (ordered by class index).
print(np.unique(test_data.targets, return_counts=True)[1])

In [None]:
print(f'N batches in train dataset: {len(dataloader_train_simsiam)}')

# Number of samples in the train, val and test datasets (one per line).
print(*(len(split.targets) for split in (train_data, val_data, test_data)),
      sep='\n')

## Check the distribution of samples in the dataloader (lightly dataset)

## See some samples (lightly dataset)

In [None]:
# Take a single batch from the training loader. `images` is the pair of
# augmented views (x0, x1); show the first image of the first view.
images, labels, names = next(iter(dataloader_train_simsiam))
img = images[0][0]
label = labels[0]
print(images[0].shape)
print(labels.shape)

# The views were normalized with the training statistics; undo that so
# imshow receives values in [0, 1] instead of clipping them.
img_mean = torch.as_tensor(mean['train'], dtype=img.dtype).view(-1, 1, 1)
img_std = torch.as_tensor(std['train'], dtype=img.dtype).view(-1, 1, 1)
img_show = (img * img_std + img_mean).clamp(0, 1)

plt.title("Label: " + str(int(label)))
plt.imshow(torch.permute(img_show, (1, 2, 0)))
plt.show()

# Self-supervised models

In [None]:
# Self-supervised method to train: 'simsiam' or 'simclr'.
model_name = 'simclr'

## Creation

In [None]:
from models import SimSiam
from models import SimCLRModel

Copied from Lightly tutorials

## Backbone net (w/ ResNet18)

Unlike the tutorial, the ResNet is trained from scratch, without pretrained weights (for now).

In [None]:
# ResNet-18 trained from scratch (no pretrained weights).
resnet = torchvision.models.resnet18(
    # weights=torchvision.models.ResNet18_Weights.DEFAULT
    weights=None
)

# Drop the final fully-connected layer: the remaining layers map an image
# to a pooled feature map that is used as the embedding backbone.
backbone = torch.nn.Sequential(*list(resnet.children())[:-1])

# Wrap the backbone with the heads of the selected self-supervised method.
if model_name == 'simsiam':
    model = SimSiam(backbone, num_ftrs, proj_hidden_dim,
                    pred_hidden_dim, out_dim)
elif model_name == 'simclr':
    # Width of the backbone output (input size of the removed fc layer).
    hidden_dim = resnet.fc.in_features
    model = SimCLRModel(backbone, hidden_dim)
else:
    # Fail here instead of with a NameError on `model` further down.
    raise ValueError(f"Unknown model_name {model_name!r}; "
                     "expected 'simsiam' or 'simclr'.")

In [None]:
# Layer-by-layer summary (torchinfo) of the backbone for one batch of
# 3-channel images of size exp.input_size x exp.input_size.
summary(
    model.backbone,
    input_size=(exp.batch_size, 3, exp.input_size, exp.input_size),
    device=exp.device
)

## Training setup

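SimSiam uses a symmetric negative cosine similarity loss and therefore does not require any negative samples (SimCLR, the other supported model, uses a contrastive loss that does). The loss criterion is defined inside each model class (`model.criterion`); here we only build the optimizer.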



In [None]:
# Fixed learning rate. The linear-scaling alternative is kept for reference:
# lr = 0.05 * exp.batch_size / 256
lr = 0.2

# SGD with momentum and weight decay over all model parameters.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=lr,
    weight_decay=1e-4,
    momentum=0.9,
)

## Training

### Loop

In [None]:
print(f"Using {exp.device} device")
model.to(exp.device)

# Exponential moving averages (weight w, SimSiam only) of the loss, of the
# std of the normalized predictions and of the representation std used to
# monitor collapse. They are carried over from one epoch to the next.
w = 0.9
avg_loss = 0.
avg_output_std = 0.
avg_rep_collapse = 0.

# Lowest epoch loss so far; the model is saved whenever it improves.
lowest_loss = math.inf
lowest_rep_collapse = math.inf

# Main training loop.
for e in range(exp.epochs):

    # Epoch timer.
    t0 = time.time()

    # Training mode.
    model.train()

    # Sum of per-sample losses and number of samples seen in this epoch.
    running_train_loss = 0.
    n_train_samples = 0

    # Each batch holds two augmented views (x0, x1) of the same images.
    for batch_id, ((x0, x1), _, _) in enumerate(dataloader_train_simsiam):

        # Move images to the GPU.
        x0 = x0.to(exp.device)
        x1 = x1.to(exp.device)

        if model_name == 'simsiam':
            # Projections (z) and predictions (p) of both views; the loss
            # compares each projection with the other view's prediction.
            z0, p0 = model(x0)
            z1, p1 = model(x1)
            loss = 0.5 * (model.criterion(z0, p1) + model.criterion(z1, p0))
        elif model_name == 'simclr':
            z0 = model(x0)
            z1 = model(x1)
            loss = model.criterion(z0, z1)
        else:
            raise ValueError(f"Unknown model_name {model_name!r}")

        # Backpropagation and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # model.criterion presumably returns the batch-mean loss; weight
        # it by the actual number of samples in this batch.
        n_batch = x0.shape[0]
        running_train_loss += loss.item() * n_batch
        n_train_samples += n_batch

        if model_name == 'simsiam':
            # Per-dimension std of the L2-normalized predictions, used to
            # check whether the embeddings are collapsing.
            output = torch.nn.functional.normalize(p0.detach(), dim=1)
            output_std = torch.std(output, 0).mean()
            rep_std = std_of_l2_normalized(output).item()

            avg_loss = w * avg_loss + (1 - w) * loss.item()
            avg_output_std = w * avg_output_std + (1 - w) * output_std.item()
            avg_rep_collapse = w * avg_rep_collapse + (1 - w) * rep_std

        # Report every 250 batches; the running loss is the mean over all
        # samples processed so far in this epoch.
        if (batch_id + 1) % 250 == 0:
            running_avg = running_train_loss / n_train_samples
            if model_name == 'simsiam':
                print(f'T[{e + 1}, {batch_id + 1:5d}] | '
                      f'Loss: {avg_loss:.4f} | '
                      f'Running train loss: {running_avg:.4f} | '
                      f'Representation std: {avg_rep_collapse:.4f}')
            else:
                print(f'T[{e + 1}, {batch_id + 1:5d}] | '
                      f'Running train loss: {running_avg:.4f}')

    # Mean loss over the samples actually used in this epoch (drop_last
    # can make this fewer than the sampler length). If no full batch was
    # produced the loss is undefined (NaN) and nothing is saved.
    if n_train_samples:
        epoch_train_loss = running_train_loss / n_train_samples
    else:
        epoch_train_loss = math.nan

    # An evaluation step used to run here; it was removed because it was
    # not working correctly.

    # Save the model whenever the epoch loss improves.
    save_model = epoch_train_loss < lowest_loss
    if save_model:

        # Update the best values.
        lowest_loss = epoch_train_loss
        lowest_rep_collapse = avg_rep_collapse

        # Move the model to CPU before saving it, then back to the device.
        model.to('cpu')
        model.save(e,
                   epoch_train_loss,
                   handle_imb_classes,
                   ratio,
                   output_dir_model,
                   avg_rep_collapse=lowest_rep_collapse)
        model.to(exp.device)

    # Per-epoch statistics.
    print(f'[Epoch {e:3d}] | '
          f'Train loss: {epoch_train_loss:.4f} | '
          f'Duration: {(time.time()-t0):.2f} s | '
          f'Saved: {save_model}')

In [None]:
# Print the module structure of the trained model.
print(model)

Collapse level: the closer to zero, the better (0 means no collapse).

Representation std (`avg_rep_collapse`, computed with `std_of_l2_normalized`): a value close to 0 indicates that the representations have collapsed. A value close to 1/sqrt(d), where d is the number of representation dimensions, indicates that the representations are stable.

### Checking the weights of the last model

In [None]:
# First module of the backbone (presumably ResNet's first conv layer) and
# the weights of its output filter 63 (the last one if it has 64 filters).
print(model.backbone[0])
print(model.backbone[0].weight[63])

***

***

# Reduce dimensionality

## Calculate embeddings

In [None]:
# Backbone embeddings and class labels of every validation image.
emb_batches = []
label_batches = []

# Inference only: evaluation mode and no gradient tracking.
model.eval()
with torch.no_grad():
    for x, y, _ in dataloader_val:
        x = x.to(exp.device)
        y = y.to(exp.device)
        emb_batches.append(model.backbone(x).flatten(start_dim=1))
        label_batches.append(y)

# Stack the batches and bring them to NumPy on the CPU.
embeddings = torch.cat(emb_batches, dim=0).to('cpu').numpy()
labels = torch.cat(label_batches, dim=0).to('cpu').numpy()

# (n_samples, embedding_dim) and (n_samples,).
print(np.shape(embeddings))
print(np.shape(labels))

# PCA

https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html

In [None]:
# PCA of the validation embeddings via a project helper (presumably
# coloured by class label), with the 2-D result saved.
# NOTE(review): str(model) is used as the output filename. For a plain
# nn.Module this is the multi-line architecture repr - confirm the model
# classes override __str__, otherwise the filename is unusable.
utils.pca_computation_and_plot(embeddings,
                               labels,
                               exp.seed,
                               plot='all',
                               filename=str(model),
                               save_2d=True)

# t-SNE

https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html

In [None]:
# t-SNE of the validation embeddings via a project helper (presumably
# coloured by class label), with the 2-D result saved.
# NOTE(review): str(model) is used as the output filename - see the PCA
# cell; confirm it is a short, filesystem-safe string.
utils.tsne_computation_and_plot(embeddings,
                                labels,
                                exp.seed,
                                plot='all',
                                filename=str(model),
                                save_2d=True)

***

***

# Check each model's performance/collapse on val data

In [None]:
def get_scatter_plot_with_thumbnails_axes(ax, title='', points=None,
                                          names=None, image_dir=None):
    """
    Creates a scatter plot with image overlays
    that are plotted in a particular ax position.

    Thumbnails are drawn at their 2-D embedding positions, visiting the
    points in random order; a point is skipped when it lies too close
    (squared distance < 2e-3) to one already drawn.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    title : str
        Title of the subplot.
    points : ndarray of shape (N, 2), optional
        2-D embeddings, expected in [0, 1]. Defaults to the global
        ``embeddings_2d``.
    names : sequence of str, optional
        Image file names relative to ``image_dir``, aligned with
        ``points``. Defaults to the global ``filenames``.
    image_dir : str, optional
        Folder the names are relative to. Defaults to the global
        ``data_dir_test``.
    """
    if points is None:
        points = embeddings_2d
    if names is None:
        names = filenames
    if image_dir is None:
        image_dir = data_dir_test

    # Keep only points sufficiently far from every point kept so far. The
    # initial (1, 1) entry keeps the top-right corner free.
    shown_images_idx = []
    shown_images = np.array([[1., 1.]])
    for i in np.random.permutation(points.shape[0]):
        dist = np.sum((points[i] - shown_images) ** 2, 1)
        if np.min(dist) < 2e-3:
            continue
        shown_images = np.r_[shown_images, [points[i]]]
        shown_images_idx.append(i)

    # Thumbnail size scales with the default figure width (loop-invariant).
    thumbnail_size = int(rcp['figure.figsize'][0] * 2.5)  # 2.

    # Plot image overlays; each file is closed as soon as it is resized.
    for idx in shown_images_idx:
        path = os.path.join(image_dir, names[idx])
        with Image.open(path) as img:
            thumb = np.array(functional.resize(img, thumbnail_size))
        img_box = osb.AnnotationBbox(
            osb.OffsetImage(thumb, cmap=plt.cm.gray_r),
            points[idx],
            pad=0.2,
        )
        ax.add_artist(img_box)

    # Square plot area regardless of the data range.
    ax.set_aspect(1. / ax.get_data_ratio(), adjustable='box')
    ax.title.set_text(title)

In [None]:
# Validation dataset (thumbnails are read from here).
data_dir_test = data_dir_target + '/val/'
print(data_dir_test)

# Saved model weights to compare (one subplot per file), numbered by their
# index in model_list.
# root_dir = 'pytorch_models/simsiam/'
root_dir = 'pytorch_models/simclr/'
model_list = []
for root, dirs, files in os.walk(root_dir):
    for filename in sorted(files):
        model_list.append(os.path.join(root, filename))
for i, model_path in enumerate(model_list):
    print(f'{i:02}: {os.path.basename(model_path)}')

if not model_list:
    raise FileNotFoundError(f'No saved models found in {root_dir!r}')

# Plot setup: ncols columns and as many rows as needed.
ncols = 5
nrows = math.ceil(len(model_list) / ncols)

fig, axes = plt.subplots(nrows=nrows,
                         ncols=ncols,
                         figsize=(12*ncols, 12*nrows))

# Flat array of Axes; hide the grid cells left over after the last model.
axes = np.ravel(axes)
for ax in axes[len(model_list):]:
    ax.axis('off')

# One subplot per saved model. The loop variable is `model_path` so the
# global `model_name` (the SSL method) is not overwritten.
for model_id, model_path in enumerate(model_list):

    # Load the backbone weights (mapped to CPU first; load_state_dict then
    # copies them onto the model's current device).
    model.backbone.load_state_dict(torch.load(model_path, map_location='cpu'))

    # Embed every validation image with the backbone.
    embeddings = []
    filenames = []
    model.eval()
    with torch.no_grad():
        for x, _, fnames in dataloader_val:
            x = x.to(exp.device)
            embeddings.append(model.backbone(x).flatten(start_dim=1).cpu())
            filenames.extend(fnames)
    embeddings = torch.cat(embeddings, dim=0).numpy()

    # 2-D random Gaussian projection of the embeddings for the scatter plot.
    projection = random_projection.GaussianRandomProjection(
        n_components=2,
        random_state=exp.seed
    )
    embeddings_2d = projection.fit_transform(embeddings)

    # Rescale each axis to [0, 1]. A constant axis (e.g. collapsed
    # embeddings) has M == m, so its range is taken as 1 to avoid 0/0.
    M = np.max(embeddings_2d, axis=0)
    m = np.min(embeddings_2d, axis=0)
    embeddings_2d = (embeddings_2d - m) / np.where(M > m, M - m, 1.)

    # Scatter plot with thumbnail overlays (reads the globals set above).
    # NOTE(review): the [49:104] title slice depends on the exact path and
    # filename layout of the saved models.
    get_scatter_plot_with_thumbnails_axes(axes[model_id],
                                          title=model_path[49:104])

    # Show progress.
    print(f'Subplot of model-{model_id} done!',
          end='\r',
          flush=True)

# Save figure under a timestamped name.
timestamp = f'{datetime.now():%Y_%m_%d_%H_%M_%S}'
fig.savefig(f'{output_dir_fig}models_knn_{timestamp}.pdf',
            bbox_inches='tight')

***

***

# Embeddings for the samples of the test dataset

## Setup

In [None]:
# Test dataset folder; the thumbnails and nearest-neighbor images in the
# following cells are read from here.
data_dir_test = data_dir_target + '/test/'
print(data_dir_test)

In [None]:
# Load the backbone weights of one saved model.
# NOTE(review): this takes entry `idx` of model_list (sorted by file
# name), which is not necessarily the best model - set idx accordingly.
idx = 0
print(model_list[idx])
model.backbone.load_state_dict(torch.load(model_list[idx]))

## Compute embeddings

In [None]:
# Backbone embeddings of every test image, plus the image file names (as
# provided by the Lightly dataset) in the same order.
embeddings = []
filenames = []

# Inference only: evaluation mode and no gradient tracking.
model.eval()
with torch.no_grad():
    for x, _, fnames in dataloader_test:

        # Move the images to the GPU.
        x = x.to(exp.device)

        # Embed the images with the backbone.
        embeddings.append(model.backbone(x).flatten(start_dim=1))

        # Extend in place: total cost stays linear in the number of files.
        filenames.extend(fnames)

# Concatenate the embeddings and convert to numpy.
embeddings = torch.cat(embeddings, dim=0).cpu().numpy()

## Projection to 2D space

In [None]:
# Project the test embeddings to 2-D with a seeded random Gaussian
# projection for the scatter plot.
projection = random_projection.GaussianRandomProjection(
    n_components=2,
    random_state=exp.seed
)
embeddings_2d = projection.fit_transform(embeddings)

# Rescale each axis to [0, 1]. A constant axis (e.g. collapsed embeddings)
# has M == m; its range is taken as 1 instead of dividing by zero.
M = np.max(embeddings_2d, axis=0)
m = np.min(embeddings_2d, axis=0)
embeddings_2d = (embeddings_2d - m) / np.where(M > m, M - m, 1.)

## Scatter plots

In [None]:
# One 15x15-inch figure with a single Axes, then the thumbnails placed at
# their 2-D embedding positions.
fig, ax = plt.subplots(figsize=(15, 15))
get_scatter_plot_with_thumbnails_axes(
    ax, title='Scatter Plot of the Sentinel-2 Dataset')

## Nearest Neighbors

### Pick one random sample per class

In [None]:
# Class sub-folders of the test split, sorted by name. Only the immediate
# children of data_dir_test are class folders, so one directory scan is
# enough. os.walk would also descend into every class folder and would
# pick up any nested sub-directory as an extra "class".
with os.scandir(data_dir_test) as entries:
    directory_list = sorted(entry.path for entry in entries
                            if entry.is_dir())

In [None]:
# One randomly chosen image per class, stored as a path relative to the
# test folder ("<class>/<file>"). It must match an entry of `filenames`,
# since the nearest-neighbor plot looks it up with filenames.index.
example_images = []
for classes in directory_list:

    # os.listdir order is arbitrary; sort it so the (seeded) random choice
    # is reproducible across machines.
    random_file = np.random.choice(sorted(os.listdir(classes)))
    path_to_random_file = os.path.join(classes, random_file)

    # relpath does not depend on the string 'test/' occurring exactly once
    # in the full path.
    example_images.append(os.path.relpath(path_to_random_file,
                                          data_dir_test))
    # print(example_images)

### Look for similar images

In [None]:
def get_image_as_np_array(filename: str):
    """
    Loads the image with filename and returns it as a numpy array.

    The file is opened in a ``with`` block, so its handle is released
    immediately. ``np.asarray`` copies the decoded pixels out of the PIL
    image, so the returned array stays valid after the file is closed.
    """
    with Image.open(filename) as img:
        return np.asarray(img)


def get_image_as_np_array_with_frame(filename: str, w: int = 5):
    """
    Returns an image as a numpy array with a black frame of width w.

    The frame is zero padding of ``w`` pixels on each side of the two
    spatial axes. A trailing channel axis, if present, is left unchanged,
    so grayscale, RGB and RGBA images are all supported, and ``w=0``
    returns the image without a frame. The result has dtype uint8.

    Raises
    ------
    ValueError
        If ``w`` is negative.
    """
    if w < 0:
        raise ValueError(f'Frame width must be non-negative, got {w}.')

    img = get_image_as_np_array(filename)

    # (before, after) padding for rows and columns; none for channels.
    pad_width = [(w, w), (w, w)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width).astype(np.uint8, copy=False)


def plot_nearest_neighbors_nxn(example_image: str, i: int,
                               n_subplots: int = 6):
    """
    Plots the example image and its nearest neighbors.

    Neighbors are ranked by squared Euclidean distance between embeddings.
    ``n_subplots`` images are shown in total: the example image first
    (with a black frame), then its ``n_subplots - 1`` nearest neighbors.
    They are laid out on an n x n grid with n = ceil(sqrt(n_subplots)),
    which is 3 x 3 for the default of 6.

    Reads the globals ``filenames``, ``embeddings`` (row-aligned with
    ``filenames``) and ``data_dir_test``.

    Parameters
    ----------
    example_image : str
        Entry of ``filenames`` used as the query image.
    i : int
        Zero-based index of the plot; shown 1-based in the title.
    n_subplots : int, optional
        Total number of images shown, example included (default 6).

    Raises
    ------
    ValueError
        If ``n_subplots`` < 1 or ``example_image`` is not in ``filenames``.
    """
    if n_subplots < 1:
        raise ValueError(f'n_subplots must be >= 1, got {n_subplots}.')
    side = math.ceil(math.sqrt(n_subplots))

    # Initialize empty figure.
    fig = plt.figure(figsize=(10, 10))
    fig.suptitle(f"Nearest Neighbor Plot {i + 1}")

    # Row of the query image.
    example_idx = filenames.index(example_image)

    # Squared Euclidean distance from every embedding to the query's.
    distances = np.sum((embeddings - embeddings[example_idx]) ** 2, axis=-1)

    # Closest images first. The query itself is always placed first, even
    # when other images are also at distance 0 (e.g. collapsed embeddings).
    order = np.argsort(distances, kind='stable')
    order = order[order != example_idx]
    nearest_neighbors = np.concatenate(([example_idx],
                                        order[:n_subplots - 1]))

    # Show images.
    for plot_offset, plot_idx in enumerate(nearest_neighbors):
        ax = fig.add_subplot(side, side, plot_offset + 1)

        # Full path of the corresponding image.
        fname = os.path.join(data_dir_test, filenames[plot_idx])
        if plot_offset == 0:
            ax.set_title("Example Image")
            ax.imshow(get_image_as_np_array_with_frame(fname))
        else:
            ax.imshow(get_image_as_np_array(fname))

        # Images only, no axis ticks or labels.
        ax.axis("off")

In [None]:
# One nearest-neighbor figure per class, using the randomly picked example
# image of that class (example_images holds one entry per class folder).
for i, example_image in enumerate(example_images):
    plot_nearest_neighbors_nxn(example_image, i)