In [None]:
# Spectral-index formulas, keyed by their short names.
# The set was picked by hand so that the indices contrast the classes in different ways.
# Every formula is handed to LandcoverDataset through `indexes=` and is counted
# as one extra model input channel (see INPUT_CHANNELS).
INDEXES = dict(
    BI="((S1 + R) - (N + B))/((S1 + R) + (N + B))",
    BNDVI="(N - B)/(N + B)",
    MGRVI="(G ** 2.0 - R ** 2.0) / (G ** 2.0 + R ** 2.0)",
    NDCI="(RE1 - R)/(RE1 + R)",
    NLI="((N ** 2) - R)/((N ** 2) + R)",
)

In [None]:
import json
from pathlib import Path
from random import choice as c

import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from matplotlib.ticker import AutoMinorLocator

from LandcoverDataset import LandcoverDataset as LD
from Model import CNN
from utils import *

# Loading the datasets into RAM

In [None]:
BATCH = 8

# Keyword sets shared between splits.
# Fitting splits keep LandcoverDataset's default `transforms`; evaluation splits
# switch it off explicitly.
# NOTE(review): `transforms` presumably toggles augmentation and `n_random`
# presumably controls how many random samples are drawn — confirm in LandcoverDataset.
_FIT_SPLIT = {"batch_size": BATCH, "n_random": 200}
_EVAL_SPLIT = {"batch_size": BATCH, "transforms": False, "n_random": 50}

# --- Source domain (USA) ---
TRAIN = LD(
    img_path="datasets/USA/train/images/",
    mask_path="datasets/USA/train/masks/",
    indexes=list(INDEXES.values()),
    **_FIT_SPLIT,
)
VALIDATION = LD(
    img_path="datasets/USA/valid/images/",
    mask_path="datasets/USA/valid/masks/",
    indexes=list(INDEXES.values()),
    **_EVAL_SPLIT,
)

# --- Target domain (Russia) ---
TEST = LD(
    img_path="datasets/Russia/test/images/",
    mask_path="datasets/Russia/test/masks/",
    indexes=list(INDEXES.values()),
    **_EVAL_SPLIT,
)
RUSSIA = LD(
    img_path="datasets/Russia/train/images/",
    mask_path="datasets/Russia/train/masks/",
    indexes=list(INDEXES.values()),
    **_FIT_SPLIT,
)

## Some insights about data

In [None]:
# Draw one random image/mask pair and print what tif_info reports for each file,
# numbered 1 (image) and 2 (mask).
image_name, mask_name = TRAIN.rand_samp_names()
for ordinal, tif_path in enumerate((image_name, mask_name), start=1):
    print(ordinal, tif_info(tif_path))

# Item 3 is the dataset-level summary; getinfo() stays the last expression so
# anything it returns is shown as the cell output.
print(3)
TRAIN.getinfo()

## Just look at it

In [None]:
# Preview one training sample. No `index` argument is given, so plot_sample
# falls back to whatever it renders by default.
# NOTE(review): 90 is presumably the sample's position in the dataset — confirm
# against LandcoverDataset.plot_sample.
TRAIN.plot_sample(90)

In [None]:
# Same sample, rendered through the NLI index. The formula is looked up in
# INDEXES (identical string to the previous literal) so this preview cannot
# drift from the expression the datasets are actually built with.
TRAIN.plot_sample(90, index=INDEXES["NLI"])

In [None]:
# Same sample, rendered through the MGRVI index. The formula is looked up in
# INDEXES (identical string to the previous literal) so this preview cannot
# drift from the expression the datasets are actually built with.
TRAIN.plot_sample(90, index=INDEXES["MGRVI"])

# The model

In [None]:
# One input channel per spectral index on top of 10 base channels.
# NOTE(review): the 10 presumably is the number of raw bands per image — confirm
# against LandcoverDataset.
INPUT_CHANNELS = 10 + len(INDEXES)
N_CLASSES = 5


def _conv_bn_relu(channels_in, channels_out):
    """Return a 3x3 same-padding convolution followed by batch norm and ReLU."""
    return [
        nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
        nn.BatchNorm2d(channels_out),
        nn.ReLU(),
    ]


# Layers are created strictly in network order, stage by stage.
layers = []

# Stage 1: 32 feature maps, then halve the spatial resolution.
layers += _conv_bn_relu(INPUT_CHANNELS, 32)
layers += _conv_bn_relu(32, 32)
layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

# Stage 2: 64 feature maps, then halve the spatial resolution again.
layers += _conv_bn_relu(32, 64)
layers += _conv_bn_relu(64, 64)
layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

# Stage 3: 128 feature maps, no further pooling.
layers += _conv_bn_relu(64, 128)
layers += _conv_bn_relu(128, 128)

# Head: a 1x1 convolution maps to one channel per class, and the result is
# resized to a fixed 512x512 grid with nearest-neighbour upsampling.
layers.append(nn.Conv2d(128, N_CLASSES, kernel_size=1))
layers.append(nn.Upsample(size=(512, 512), mode='nearest'))

model = CNN(layers)

## Hyperparameters

In [None]:
# Values forwarded to model.train() and embedded in the checkpoint file name.
learning_rate = 1e-3
n_epochs = 20
# NOTE(review): presumably the validation score a model must exceed before
# CNN.train saves it — confirm in Model.py.
saving_threshold = 0.51

## Start training

In [None]:
# Fit the model on the USA training split and validate on the USA validation
# split, using the hyperparameters defined in the previous cell.
# NOTE(review): how saving_threshold is applied is decided inside CNN.train —
# confirm in Model.py.
model.train(TRAIN, VALIDATION, n_epochs, learning_rate, saving_threshold)

In [None]:
# Visual sanity check of the trained model on the validation split.
# NOTE(review): the second argument (5) is presumably a sample index or a
# sample count — confirm against LandcoverDataset.plot_prediction.
VALIDATION.plot_prediction(model, 5)

## Give it a name and pickle it

In [None]:
# The checkpoint name encodes the hyperparameters used for this training run;
# later cells rebuild the same path from `modelname`.
modelname = f"lr={learning_rate},n_epochs={n_epochs},thres={saving_threshold}"

# Create the target folder up front so a missing directory cannot discard a
# finished training run (a no-op when the folder already exists).
Path("models").mkdir(parents=True, exist_ok=True)
model.pickle(f"models/{modelname}.torch")

In [None]:
# Load the checkpoint stored under the name built from the current
# hyperparameters.
# NOTE(review): unpickle presumably deserialises with pickle/torch.load, which
# can execute arbitrary code — only load checkpoint files you trust.
model.unpickle(f"models/{modelname}.torch")

## Here is the training history (it is pickled alongside the model)

In [None]:
# Plot the training history kept on the model object.
model.plot_training_history()

# Experiment
Compare the model's performance on a new domain after one epoch of fine-tuning, with and without freezing part of its weights.

In [None]:
# Result series, one entry per repeat.
test_v = []           # source checkpoint evaluated on VALIDATION (USA)
test_t = []           # source checkpoint evaluated on TEST (Russia), no fine-tuning
test_just = []        # after fine-tuning on RUSSIA with every parameter trainable
test_frozen_0x6 = []  # after fine-tuning on RUSSIA with parameter tensors 0..6 frozen
LEARNING_RATE = 1e-3
N_REPEATS = 10
FINE_TUNE_EPOCHS = 1
# Passed as the saving threshold during fine-tuning.
# NOTE(review): presumably 2 is unreachable for the monitored score, so no
# checkpoint is written while fine-tuning — confirm in CNN.train.
NO_SAVE_THRESHOLD = 2
# Parameter tensors with enumerate() index <= this value are frozen.
# NOTE(review): `<= 6` covers SEVEN tensors (0..6). If model.model exposes the
# layers in definition order, that is conv1 weight/bias, bn1 weight/bias,
# conv2 weight/bias and bn2 weight, leaving bn2 bias trainable — confirm this is
# intended, given that the series is labelled "Frozen 6".
LAST_FROZEN_PARAM = 6
CHECKPOINT_PATH = f"models/{modelname}.torch"


def restore_checkpoint():
    """Reload the source-domain checkpoint and make every parameter trainable.

    The frozen branch below sets ``requires_grad = False`` on the live model.
    Loading weights does not necessarily reset those flags (PyTorch's
    ``load_state_dict`` copies values only), so without the explicit reset the
    freeze from one repeat could leak into the unfrozen fine-tune of the next.
    """
    model.unpickle(CHECKPOINT_PATH)
    for param in model.model.parameters():
        param.requires_grad = True


for _ in range(N_REPEATS):
    # Baselines: the source-domain model without any adaptation.
    restore_checkpoint()
    test_v.append(model.test(VALIDATION))
    test_t.append(model.test(TEST))

    # Plain fine-tuning on the target domain. TEST is the set handed to
    # model.train as its validation data.
    restore_checkpoint()
    model.train(RUSSIA, TEST, FINE_TUNE_EPOCHS, LEARNING_RATE, NO_SAVE_THRESHOLD)
    test_just.append(model.test(TEST))

    # Fine-tuning with the first parameter tensors frozen.
    restore_checkpoint()
    for i, param in enumerate(model.model.parameters()):
        if i <= LAST_FROZEN_PARAM:
            param.requires_grad = False
    model.train(RUSSIA, TEST, FINE_TUNE_EPOCHS, LEARNING_RATE, NO_SAVE_THRESHOLD)
    test_frozen_0x6.append(model.test(TEST))

# Do not leave the shared model partially frozen for later cells.
for param in model.model.parameters():
    param.requires_grad = True

In [None]:
# Gather the four result series under the keys the analysis cells read.
data = {
    'test_v': test_v,
    'test_t': test_t,
    'test_just': test_just,
    'test_frozen_0x6': test_frozen_0x6,
}

# Create the output folder first: the experiment above is long-running, and a
# missing directory would otherwise raise here and lose the results.
results_path = Path('experiment_results/results.json')
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open('w') as f:
    json.dump(data, f, indent=4)
# To reuse previously saved results instead of rerunning the experiment,
# load them with:
# with open('experiment_results/results.json', 'r') as f:
#     data = json.load(f)

In [None]:
def calculate_stats(results, metrics=None):
    """Aggregate per-run metric dictionaries into mean/std series.

    Parameters
    ----------
    results : dict
        Maps a series name to a list of runs; every run is a mapping from
        metric name to a number.
    metrics : iterable of str, optional
        Metric names to aggregate. Defaults to ``test_accuracy``,
        ``test_recall``, ``test_precision`` and ``test_f1``.

    Returns
    -------
    dict
        ``{metric: {'mean': [...], 'std': [...]}}``. Each list holds one value
        per key of ``results``, in the iteration order of ``results``. The
        spread is ``np.std`` with its default settings (population standard
        deviation).

    Raises
    ------
    KeyError
        If a run does not contain one of the requested metrics.
    """
    if metrics is None:
        metrics = ['test_accuracy', 'test_recall', 'test_precision', 'test_f1']
    else:
        # Materialise once so one-shot iterables can be traversed per series.
        metrics = list(metrics)

    stats = {metric: {'mean': [], 'std': []} for metric in metrics}

    for runs in results.values():
        for metric in metrics:
            values = [run[metric] for run in runs]
            stats[metric]['mean'].append(np.mean(values))
            stats[metric]['std'].append(np.std(values))

    return stats

In [None]:
stats = calculate_stats(data)

# One label and colour per result series, in the insertion order of `data`
# (test_v, test_t, test_just, test_frozen_0x6).
labels = ['Validation', 'Baseline', 'Just', 'Frozen 6']
metrics = ['test_accuracy', 'test_recall', 'test_precision', 'test_f1']
colors = ['#41BA9BFF', '#616161FF', '#FB6E52FF', '#4FC0E8FF']

# One horizontal bar chart per metric; the row count follows the metric list.
fig, axes = plt.subplots(len(metrics), 1, figsize=(10, 6))

for ax, metric in zip(axes, metrics):
    values = stats[metric]['mean']
    errs = stats[metric]['std']

    ax.barh(labels, values, color=colors, xerr=errs, capsize=5)

    # Zoom the x axis to the span covered by the bars and their error bars,
    # plus a small margin, so differences between series stay visible.
    min_value = min(v - e for v, e in zip(values, errs))
    max_value = max(v + e for v, e in zip(values, errs))
    ax.set_xlim(min_value - 0.05, max_value + 0.05)

    ax.set_xlabel(metric.replace('test_', '').capitalize())

    ax.xaxis.set_major_locator(plt.AutoLocator())
    ax.xaxis.set_minor_locator(AutoMinorLocator())

fig.tight_layout()

# Create the output folder first so saving cannot fail on a missing directory.
figure_path = Path('assets/2.svg')
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path)