In [None]:
# ------ Copyright (C) 2024 University of Strathclyde and Author ------
# ---------------- Author: Robert Cowlishaw ------------------------
# ----------- e-mail: robert.cowlishaw.2017@uni.strath.ac.uk ----------------

# Centralised Crop Type Detection with UNet - Model Trainer

This notebook trains the centralised UNet crop-type model so that different loss ratios, numbers of crop classes and softmax threshold values can be compared through the F1-scores of the resulting outputs. This file only trains the model; `centralised_model_testset.ipynb` can then be used to test the saved model and produce metrics.

In [None]:
#@title {vertical-output: true}

!pip install datasets -q

In [None]:
#@title {vertical-output: true}

from pathlib import Path
# Clone the smart-dao repository only when no local "smart_dao" directory exists yet.
my_file = Path("smart_dao")
if not my_file.is_dir():
    !git clone https://github.com/strath-ace/smart-dao smart_dao

# Copy the UNet definition and the helper module next to the notebook so that
# `from UNET import UNet` and `from utils import *` below can find them.
!cp smart_dao/agriculture-federated-learning/centralised-unet/UNET.py .
!cp smart_dao/agriculture-federated-learning/centralised-unet/utils.py .

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import json
import random
from functools import reduce
import itertools
from collections import defaultdict
import time
import copy
import json
from tqdm import tqdm
import sklearn.metrics

from datasets import load_dataset

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
from torchvision.transforms import ToTensor
import torch.optim as optim
from torch.optim import lr_scheduler

from utils import *
from UNET import UNet

In [None]:
#@title {vertical-output: true}

# Fetch the "train" split of the Belgian crop-type dataset from the Hugging Face Hub.
dataset = load_dataset("0x365/eo-crop-type-belgium", split="train")

# Have the dataset return torch tensors when indexed.
ds = dataset.with_format("torch")

# Hold out 20 % of the samples for testing. The seed is fixed so that the same
# samples end up in the test set on every run; without it each kernel restart
# produced a different partition and results were not comparable between runs.
SPLIT_SEED = 42
ds_split = ds.train_test_split(test_size=0.2, seed=SPLIT_SEED)

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """In-memory segmentation dataset built once from band/label samples.

    Every sample of ``data`` is pre-processed in ``__init__`` and cached in
    ``self.processed_data`` as a tuple ``(inputs, labels, labels_one_hot)``:

    * ``inputs``  - the entries of ``self.bands`` stacked along a new leading
      axis, cast to float and divided by 255
      (assumes 8-bit band values - TODO confirm).
    * ``labels``  - integer class map; any value above ``num_classes - 1`` is
      replaced by 0.
    * ``labels_one_hot`` - one-hot encoding of ``labels`` with the class axis
      moved to the front, as float.

    Args:
        data: iterable of mapping-like samples providing the band keys listed
            in ``self.bands`` plus a ``"label"`` key. Only element ``[0]`` of
            each entry is used.
        num_classes: number of classes kept; larger label ids are mapped to 0.
    """

    def __init__(self, data, num_classes=21):
        self.data = data
        self.bands = ["B02","B03","B04","B05","B06","B07","B08","B11","B12"]
        self.num_classes = num_classes
        self.processed_data = []

        for sample in tqdm(data, desc="Building dataset"):
            inputs = torch.stack([sample[k][0] for k in self.bands], dim=0).float() / 255
            # clone() so the clamp below never writes into the source sample
            # (``.long()`` returns the very same tensor when it is already int64).
            labels = sample["label"][0].long().clone()
            labels[labels > self.num_classes-1] = 0
            labels_one_hot = F.one_hot(labels, num_classes=self.num_classes).permute(2, 0, 1).float()
            self.processed_data.append((inputs, labels, labels_one_hot))

    def __getitem__(self, idx):
        """Return the cached ``(inputs, labels, labels_one_hot)`` tuple at ``idx``."""
        return self.processed_data[idx]

    def __len__(self):
        """Number of cached samples (the same container ``__getitem__`` indexes)."""
        return len(self.processed_data)

    def unique(self):
        """Return the distinct values found in the raw ``"label"`` column of ``self.data``."""
        return torch.unique(self.data["label"])

    def unqiue(self):
        """Misspelled original name, kept so existing callers keep working.

        Previously this called the non-existent ``torch.unqiue`` and always
        raised ``AttributeError``; it now forwards to :meth:`unique`.
        """
        return self.unique()

# Pre-process both halves of the split into in-memory datasets
# (CustomDataset's default num_classes=21 is used for both).
ds_train = CustomDataset(ds_split["train"])#, num_classes=11)
ds_test = CustomDataset(ds_split["test"])

# Print the training dataset object as a quick check that construction finished.
print(ds_train)

In [None]:
# Visual sanity check of the first five training samples:
# one row per sample, one column per input channel, label map in the last column.
fig, axs = plt.subplots(5, 10, figsize=(30,15))

for row in range(5):
    channels, label_map, _ = ds_train[row]

    for col, channel in enumerate(channels):
        axs[row, col].imshow(channel)
    axs[row, -1].imshow(label_map)

In [None]:
#@title {vertical-output: true}

# Band keys present in each dataset sample (same list CustomDataset stacks as input channels).
bands = ["B02","B03","B04","B05","B06","B07","B08","B11","B12"]

####################

# Data batch loader sizes
train_batch_size = 48
test_batch_size = 48

# Training epochs
training_epochs = 50

# Number of output classes; also used below to size the UNet
num_classes = 21

####################

train_dataloader = DataLoader(ds_train, batch_size=train_batch_size, shuffle=True)#, collate_fn=custom_collate_fn)
test_dataloader = DataLoader(ds_test, batch_size=test_batch_size, shuffle=True)#, collate_fn=custom_collate_fn)

print("Training data loaded - Batch Size:", train_batch_size, " - Number Batches:", len(train_dataloader))
print("Test data loaded     - Batch Size:", test_batch_size, "  - Number Batches:", len(test_dataloader))

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(ds_train)
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid (synchronous
# kernel launches) and slows training - confirm it is still wanted here.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Output directory for checkpoints and metrics; exist_ok avoids the separate
# is_dir() check and makes re-running the cell safe.
my_file = Path("model_save")
my_file.mkdir(parents=True, exist_ok=True)

# Size the network from the constant above instead of repeating the literal 21,
# so changing num_classes in one place keeps the model consistent.
model = UNet(num_classes).to(device)

In [None]:
# Adam optimiser over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=0.001)
# StepLR scheduler configured with step_size=25 and gamma=0.1.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.1)
# Passed to train_model as use_scheduler; while False the scheduler is never stepped.
scheduler_on = False

def dice_loss(pred, target, smooth=1.):
    """Soft Dice loss averaged over all remaining (non-reduced) positions.

    Both tensors are reduced by summing over axis 2 twice in a row, i.e. over
    axes 2 and 3 of the input, so they need at least four dimensions. In this
    notebook it is called with softmax probabilities and one-hot labels.

    Args:
        pred: predicted scores/probabilities.
        target: tensor multiplied element-wise with ``pred`` (same shape expected).
        smooth: constant added to numerator and denominator; keeps the ratio
            defined when both sums are zero.

    Returns:
        Scalar tensor: mean of ``1 - (2*intersection + smooth) / (sum(pred) + sum(target) + smooth)``.
    """
    def _sum_axes_2_and_3(tensor):
        # Two successive reductions over axis 2 collapse the last two axes of a 4-D tensor.
        return tensor.sum(dim=2).sum(dim=2)

    probs = pred.contiguous()
    truth = target.contiguous()

    overlap = _sum_axes_2_and_3(probs * truth)
    denominator = _sum_axes_2_and_3(probs) + _sum_axes_2_and_3(truth) + smooth
    dice_score = (2. * overlap + smooth) / denominator

    return (1 - dice_score).mean()

# Cross-entropy term of the training loss; train_model reads this as a global.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
#@title {vertical-output: true}

# String tags embedded in the names of the files written to model_save/.
# NOTE(review): presumably "09" mirrors the 0.9 cross-entropy weight used in
# train_model and "20" the number of crop classes excluding background - confirm,
# and keep them in sync by hand when either setting changes.
ratio = "09"
classes = "20"

def train_model(model, optimizer, scheduler, dataloaders, use_scheduler=False, num_epochs=25, multi_scale=0.9):
    """Train ``model`` on a weighted sum of cross-entropy and Dice loss.

    Per batch the loss is ``multi_scale * CE + (1 - multi_scale) * Dice`` where
    CE comes from the module-level ``criterion`` and Dice from ``dice_loss``
    applied to the softmax of the model output. After every epoch the running
    history is written to ``model_save/learning_temp.json`` and, whenever the
    epoch's mean training loss is the lowest seen so far, the weights are
    checkpointed to ``model_save/model_temp``. At the end the best checkpoint is
    reloaded and saved under a name built from the module-level ``classes`` and
    ``ratio`` tags.

    Args:
        model: network to train; it is moved data on ``cuda:0`` if available, else CPU
            (the model itself is expected to already live on that device).
        optimizer: optimiser updating ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch when
            ``use_scheduler`` is true.
        dataloaders: iterable of ``(inputs, labels, labels_one_hot)`` batches.
        use_scheduler: whether to call ``scheduler.step()`` after each epoch.
        num_epochs: number of passes over ``dataloaders``.
        multi_scale: weight of the cross-entropy term (Dice gets the remainder).
            Defaults to the previously hard-coded 0.9.

    Returns:
        ``(model, stored_data)`` - the model holding the best checkpointed
        weights (or its current weights when nothing was checkpointed) and a
        dict ``{"epoch": {"1": {"loss", "loss_ce", "loss_d", "time"}, ...}}``.

    Raises:
        ValueError: if an epoch yields no samples.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Defined up front so the reload after the loop can never hit an unbound
    # name (previously this happened with num_epochs=0 or an all-NaN loss).
    temp_model_path = "model_save/model_temp"
    checkpoint_saved = False

    best_loss = 1e10
    stored_data = {"epoch": {}}
    for epoch in range(num_epochs):
        epoch_key = str(epoch+1)
        description = 'Epoch {}/{}'.format(epoch+1, num_epochs)

        since = time.time()

        model.train()

        # Sample-weighted running sums, kept as plain Python floats so the
        # history is always JSON-serialisable.
        sum_loss = 0.0
        sum_loss_ce = 0.0
        sum_loss_d = 0.0
        totaler = 0

        for inputs, labels, labels_one_hot in tqdm(dataloaders, desc=description):

            inputs = inputs.to(device)
            labels = labels.to(device)
            labels_one_hot = labels_one_hot.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            outputs = model(inputs)

            # Cross-entropy works on the raw outputs, Dice on their softmax.
            loss_ce = criterion(outputs, labels)
            loss_d = dice_loss(F.softmax(outputs, dim=1), labels_one_hot)

            loss = loss_ce * multi_scale + loss_d * (1-multi_scale)

            loss.backward()
            optimizer.step()

            batch_size = inputs.size(0)
            sum_loss += loss.item() * batch_size
            sum_loss_ce += loss_ce.item() * batch_size
            sum_loss_d += loss_d.item() * batch_size
            totaler += batch_size

        if totaler == 0:
            raise ValueError("dataloaders yielded no samples; cannot compute an epoch loss")

        if use_scheduler:
            scheduler.step()

        time_elapsed = time.time() - since

        epoch_stats = {"loss": sum_loss/totaler, "loss_ce": sum_loss_ce/totaler, "loss_d": sum_loss_d/totaler, "time": round(time_elapsed, 1)}
        stored_data["epoch"][epoch_key] = epoch_stats
        print("Loss:", epoch_stats["loss"], "-----", "Loss CE:", epoch_stats["loss_ce"], "-----", "Loss D:", epoch_stats["loss_d"])

        save_json("model_save/learning_temp.json", stored_data)
        if epoch_stats["loss"] < best_loss:
            best_loss = epoch_stats["loss"]
            torch.save(model.state_dict(), temp_model_path)
            checkpoint_saved = True

            # Can add a save line in for google colab here:
            # !cp model_save/* drive/MyDrive/bla_bla_bla/.

    # The tracked quantity is the mean loss on the training loader (no
    # validation pass is run here), so the message says "train", not "val".
    print('Best train loss: {:4f}'.format(best_loss))
    save_json("model_save/learning_classes_"+classes+"_ratio_"+ratio+".json", stored_data)
    if checkpoint_saved:
        model.load_state_dict(torch.load(temp_model_path))
    perm_model_path = "model_save/model___classes_"+classes+"_ratio_"+ratio
    torch.save(model.state_dict(), perm_model_path)

    # Can add a save line in for google colab here:
    # !cp model_save/* drive/MyDrive/bla_bla_bla/.

    return model, stored_data


# Run training; returns the model with the best checkpointed weights loaded
# and the per-epoch loss history.
model_trained, learning_metrics = train_model(model, optimizer, exp_lr_scheduler, train_dataloader, use_scheduler=scheduler_on, num_epochs=training_epochs)

In [None]:
# @title {vertical-output: true}

def _argmax_with_threshold(tensor, threshold, default_value):
    """Arg-max over dim 1, replaced by ``default_value`` where every dim-1 entry is below ``threshold``."""
    below_threshold = (tensor < threshold).all(dim=1)
    argmax_indices = torch.argmax(tensor, dim=1)
    argmax_indices[below_threshold] = default_value
    return argmax_indices


def test_model(model, test_dataloader, num_classes, threshold, image_path="output_view.png"):
    """Evaluate ``model`` on ``test_dataloader`` and aggregate per-class metrics.

    Predictions are the arg-max of the softmax over dim 1; when ``threshold``
    is non-zero, positions whose softmax values are all below it are assigned
    class 0. Labels above ``num_classes - 1`` are mapped to 0 before scoring.

    Args:
        model: network to evaluate. It is switched to eval mode for the
            duration of the call and its previous mode is restored afterwards.
        test_dataloader: iterable of ``(inputs, labels, _)`` batches.
        num_classes: number of classes scored (ids ``0 .. num_classes-1``).
        threshold: softmax threshold; 0 disables thresholding.
        image_path: unused; kept for interface compatibility.

    Returns:
        ``(big_metrics, big_confusion)`` where ``big_metrics`` holds per-class
        precision / recall / F1 / supports (undefined values reported as 0)
        plus macro F1, support-weighted F1 (with and without class 0) and
        accuracy, and ``big_confusion`` is the ``num_classes x num_classes``
        confusion matrix with each row divided by that class's true support.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    class_ids = np.arange(num_classes)
    confusion = np.zeros((num_classes, 2, 2))
    big_confusion = np.zeros((num_classes, num_classes))
    number_in_class = np.zeros(num_classes)
    number_in_pred = np.zeros(num_classes)
    total_number = 0

    # Inference must run in eval mode and without autograd bookkeeping;
    # otherwise train-mode layers are used and a graph is kept for every batch.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, labels, _ in tqdm(test_dataloader, desc="Testing on test subset "):
                probs = F.softmax(model(inputs.to(device)), dim=1)

                if threshold == 0:
                    pred = torch.argmax(probs, dim=1)
                else:
                    pred = _argmax_with_threshold(probs, threshold, 0)

                flat_pred = pred.cpu().numpy().flatten().astype(int)
                # Labels are only needed on the host, so they are not moved to the device.
                flat_label = labels.cpu().numpy().flatten().astype(int)
                flat_label[flat_label > num_classes-1] = 0

                big_confusion += sklearn.metrics.confusion_matrix(flat_label, flat_pred, labels=class_ids)
                confusion += sklearn.metrics.multilabel_confusion_matrix(flat_label, flat_pred, labels=class_ids)

                total_number += len(flat_label)

                # minlength/slice keep the counts aligned with num_classes whatever
                # ids happen to occur in this batch.
                number_in_class += np.bincount(flat_label, minlength=num_classes)[:num_classes]
                number_in_pred += np.bincount(flat_pred, minlength=num_classes)[:num_classes]
    finally:
        model.train(was_training)

    # sklearn lays each per-class matrix out as [[TN, FP], [FN, TP]].
    # The previous code read FP and FN from each other's cell, which swapped
    # precision and recall (F1 is symmetric and was unaffected).
    TP = confusion[:, 1, 1]
    FP = confusion[:, 0, 1]
    FN = confusion[:, 1, 0]

    # numpy division by zero yields NaN (it does not raise), so undefined
    # metrics show up as NaN here and are handled explicitly below.
    with np.errstate(divide="ignore", invalid="ignore"):
        precision_li = TP / (TP + FP)
        recall_li = TP / (TP + FN)
        F1_li = TP / (TP + (0.5 * (FP + FN)))

        # Aggregates skip classes whose F1 is undefined.
        F1_macro = np.nanmean(F1_li)
        F1_weighted = (np.nansum(F1_li*number_in_class))/total_number
        F1_weighted_no_background = (np.nansum(F1_li[1:]*number_in_class[1:]))/np.nansum(number_in_class[1:])
        accuracy = TP.sum()/total_number

    # Report undefined per-class values as 0. (`arr == np.nan` is never true,
    # so the earlier replacement was a no-op; np.isnan is required.)
    precision_li[np.isnan(precision_li)] = 0
    recall_li[np.isnan(recall_li)] = 0
    F1_li[np.isnan(F1_li)] = 0

    big_metrics = {
        "per_class": {
            "precision": precision_li.tolist(),
            "recall": recall_li.tolist(),
            "F1-score": F1_li.tolist(),
            "support_true": number_in_class.tolist(),
            "support_pred": number_in_pred.tolist()
        },
        "f1-macro": F1_macro,
        "f1-weighted": F1_weighted,
        "f1-weighted-no-background": F1_weighted_no_background,
        "accuracy": accuracy
    }

    # Row-normalise the confusion matrix; rows of classes that never occur stay as they are.
    has_support = number_in_class > 0
    big_confusion[has_support] = big_confusion[has_support] / number_in_class[has_support, None]

    return big_metrics, big_confusion


# Reload the weights written at the end of training. map_location lets a
# checkpoint saved on GPU be loaded when only the CPU is available.
perm_model_path = "model_save/model___classes_"+classes+"_ratio_"+ratio
model.load_state_dict(torch.load(perm_model_path, map_location=device))

# Evaluate once per softmax threshold. Each pair is (tag used in the output
# filenames, threshold value); the class count comes from num_classes instead
# of a repeated literal so it cannot drift from the model definition.
for threshold_tag, threshold in (("0", 0), ("05", 0.5), ("075", 0.75)):
    big_metrics, big_confusion = test_model(model, test_dataloader, num_classes, threshold)
    run_tag = "classes_"+classes+"_threshold_"+threshold_tag+"_ratio_"+ratio
    save_json("model_save/model_stats_"+run_tag+".json", big_metrics)
    np.save("model_save/confusion_"+run_tag, big_confusion)

# Can add a save line in for google colab here:
# !cp model_save/* drive/MyDrive/bla_bla_bla/.