In [None]:
# ------ Copyright (C) 2024 University of Strathclyde and Author ------
# ---------------- Author: Robert Cowlishaw ------------------------
# ----------- e-mail: robert.cowlishaw.2017@uni.strath.ac.uk ----------------

# Centralised Crop Type Detection with UNet - Model Tester

The metrics produced here are used to compare loss ratios, the number of crop classes and softmax thresholding values, together with the associated F1-scores of the model output. This notebook is designed to test, and produce metrics for, the model produced by `centralised_model.ipynb`.

In [None]:
#@title {vertical-output: true}

!pip install datasets -q

In [None]:
#@title {vertical-output: true}

from pathlib import Path
# Clone the project repository only when a local "smart_dao" directory does not exist yet.
my_file = Path("smart_dao")
if not my_file.is_dir():
    !git clone https://github.com/strath-ace/smart-dao smart_dao

# Copy the model definition (UNET.py) and helper module (utils.py) into the
# working directory so the `from utils import *` / `from UNET import UNet`
# imports in the next cell can resolve them.
!cp smart_dao/agriculture-federated-learning/centralised-unet/UNET.py .
!cp smart_dao/agriculture-federated-learning/centralised-unet/utils.py .

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import json
import random
from functools import reduce
import itertools
from collections import defaultdict
import time
import copy
import json
from tqdm import tqdm
import sklearn.metrics

from datasets import load_dataset

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
from torchvision.transforms import ToTensor
import torch.optim as optim
from torch.optim import lr_scheduler

from utils import *
from UNET import UNet

In [None]:
#@title {vertical-output: true}

# Load the "train" split of the crop-type dataset.
dataset = load_dataset("0x365/eo-crop-type-belgium", split="train")

# Return torch tensors when rows/columns are accessed.
ds = dataset.with_format("torch")

# Fix the shuffle seed so the 80/20 train/test partition - and therefore every
# metric computed from it - is the same on each run of the notebook.
# NOTE(review): for an unbiased evaluation this split must match the one used
# when the model was trained in centralised_model.ipynb; confirm the seed there.
SPLIT_SEED = 42
dataset_new = ds.train_test_split(test_size=0.2, seed=SPLIT_SEED)

In [None]:
# Class balance estimated from the first 1000 rows of the training split:
# print the fraction of label values that fall into each distinct class.
sample_labels = dataset_new["train"].select(np.arange(1000))["label"]
uni, counts = torch.unique(sample_labels, return_counts=True)
counts = counts.numpy()
print(counts / counts.sum())

In [None]:
#@title {vertical-output: true}

# Keys of the per-band entries read from every batch; they are stacked in this
# order along dim 1 to build the model input.
# NOTE(review): these look like Sentinel-2 band identifiers - confirm against the dataset card.
bands = ["B02","B03","B04","B05","B06","B07","B08","B11","B12"]

def dice_loss(pred, target, smooth=1.):
    """Mean soft Dice loss, reducing over dims 2 and 3 of the inputs.

    Args:
        pred: predicted scores; must have at least 4 dims
            (assumed (batch, class, H, W) - TODO confirm with the caller).
        target: tensor broadcast-compatible with ``pred``
            (presumably one-hot ground truth - verify against the caller).
        smooth: constant added to numerator and denominator so that an
            empty prediction/target pair does not divide by zero.

    Returns:
        Scalar tensor: ``1 - dice`` averaged over every remaining element.
    """
    pred = pred.contiguous()
    target = target.contiguous()

    def _reduce(t):
        # Sum dim 2, then the dim that moves into position 2 (original dim 3).
        return t.sum(dim=2).sum(dim=2)

    overlap = _reduce(pred * target)
    denominator = _reduce(pred) + _reduce(target) + smooth
    dice = (2. * overlap + smooth) / denominator

    return (1 - dice).mean()

def dice_loss_less(pred, target, smooth=1.):
    """Mean soft Dice loss, reducing over dims 1 and 2 of the inputs.

    Same computation as ``dice_loss`` but for inputs with one dimension
    fewer: the two dims that are summed are 1 and 2 instead of 2 and 3.

    Args:
        pred: predicted scores; must have at least 3 dims
            (assumed (batch, H, W) - TODO confirm with the caller).
        target: tensor broadcast-compatible with ``pred``.
        smooth: constant added to numerator and denominator so that an
            empty prediction/target pair does not divide by zero.

    Returns:
        Scalar tensor: ``1 - dice`` averaged over every remaining element.
    """
    pred = pred.contiguous()
    target = target.contiguous()

    def _reduce(t):
        # Sum dim 1, then the dim that moves into position 1 (original dim 2).
        return t.sum(dim=1).sum(dim=1)

    overlap = _reduce(pred * target)
    denominator = _reduce(pred) + _reduce(target) + smooth
    dice = (2. * overlap + smooth) / denominator

    return (1 - dice).mean()

####################
# Settings

# Batch sizes used by the two data loaders
train_batch_size = 32
test_batch_size = 32

# Number of training epochs
training_epochs = 100

scheduler_on = False
NUM_CLASSES = 21

####################

# Both splits are iterated in shuffled order.
train_dataloader = DataLoader(
    dataset_new["train"],
    batch_size=train_batch_size,
    shuffle=True,
)
test_dataloader = DataLoader(
    dataset_new["test"],
    batch_size=test_batch_size,
    shuffle=True,
)

print("Training data loaded - Batch Size:", train_batch_size, " - Number Batches:", len(train_dataloader))
print("Test data loaded     - Batch Size:", test_batch_size, "  - Number Batches:", len(test_dataloader))

# Use the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Distinct label values present in the whole dataset
unique_labels = ds["label"].unique()
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Directory that holds the saved model and its training log
my_file = Path("model_save")
my_file.mkdir(exist_ok=True)

# Build the UNet and move it to the selected device
model = UNet(NUM_CLASSES).to(device)


In [None]:
# The trained weights (and the training log read further down) must be placed
# in model_save/ before this cell runs. When they live on Google Drive in a
# Google Colab session, the copy looks something like this:
#
#   !cp drive/MyDrive/bla_bla_bla/model_temp model_save/.
#   !cp drive/MyDrive/bla_bla_bla/learning_temp.json model_save/.
#
# Alternatively, simply keep them in the working directory under model_save/.
model_path = Path("model_save/model_temp")
if not model_path.is_file():
    # Stop with an actionable message only when the weights are really missing,
    # so that "Run All" works once the files are in place.
    raise FileNotFoundError(
        "Need to add model loader from wherever it is saved: "
        f"expected the trained weights at '{model_path}'"
    )

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))

In [None]:
#@title {vertical-output: true}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Take a single batch from the test loader for a visual sanity check.
data = next(iter(test_dataloader))

# Select index 0 along dim 1 of every band, stack the bands along dim 1 and
# scale by 1/255 (assumes 8-bit pixel values - TODO confirm).
inputs = torch.stack([data[k][:,0] for k in bands], dim=1).float() / 255
label = data["label"][:,0].numpy()

inputs = inputs.to(device)

# Inference only: evaluation mode makes layers such as dropout / batch-norm
# (if the model has any) behave deterministically, and no_grad avoids building
# an autograd graph that is never used.
model.eval()
with torch.no_grad():
    pred = F.softmax(model(inputs), dim=1)

color_scheme = "gist_ncar"

def argmax_with_threshold(tensor, threshold, default_value):
    """Argmax over dim 1 with a fallback for low-confidence positions.

    Args:
        tensor: scores with the candidates along dim 1.
        threshold: a position counts as "undecided" when every value along
            dim 1 is strictly below this number.
        default_value: index written at the undecided positions.

    Returns:
        Tensor of indices with dim 1 removed.
    """
    winners = torch.argmax(tensor, dim=1)
    undecided = (tensor < threshold).all(dim=1)
    winners[undecided] = default_value

    return winners

# Most likely class per pixel, as a numpy array for plotting.
pred = torch.argmax(pred, dim=1)
pred = pred.cpu().numpy()

# Collapse every label above 10 (this includes the value 255, presumably a
# no-data marker) into class 0 so the colour scale below covers 0-10.
label[label > 10] = 0

# Show up to four examples; smaller batches no longer raise an IndexError.
n_examples = min(4, len(pred))
fig, axs = plt.subplots(n_examples, 2, figsize=(7, 3.5 * n_examples), squeeze=False)

axs[0,0].set_title("Truth")
axs[0,1].set_title("Prediction")
for i in range(n_examples):
    axs[i,0].imshow(label[i], vmin=0, vmax=10, cmap=color_scheme)
    axs[i,1].imshow(pred[i], vmin=0, vmax=10, cmap=color_scheme)

In [None]:
#@title {vertical-output: true}

learning_rater = load_json("model_save/learning_temp.json")
print(learning_rater)

# Per-epoch records keyed by the epoch number as a string ("1", "2", ...).
epoch_log = learning_rater.get("epoch", {})

epochs_seen = []
loss = []
loss_ce = []
loss_d = []
for i in range(1, training_epochs + 1):
    try:
        entry = epoch_log[str(i)]
        # Read all three values before appending anything, so an incomplete
        # record can never leave the three curves with different lengths.
        values = (entry["loss"], entry["loss_ce"], entry["loss_d"])
    except KeyError:
        # Epoch not logged (e.g. training stopped early) - skip it.
        continue
    epochs_seen.append(i)
    loss.append(values[0])
    loss_ce.append(values[1])
    loss_d.append(values[2])


plt.plot(epochs_seen, loss, label="Total Loss")
plt.plot(epochs_seen, loss_ce, label="Cross Entropy Loss")
plt.plot(epochs_seen, loss_d, label="Dice Loss")
plt.yscale("log")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

In [None]:
# @title {vertical-output: true}

def test_model(model, test_dataloader, num_classes, image_path="output_view.png",
               softmax_threshold=None, default_class=0):
    """Run the model over a dataloader and collect segmentation metrics.

    Every batch is expected to be a mapping that holds one entry per name in
    the module-level ``bands`` list plus a ``"label"`` entry. Index 0 along
    dim 1 is taken from each entry; the bands are stacked along dim 1 and
    scaled by 1/255 before being passed to the model.

    Args:
        model: torch module returning per-class scores along dim 1.
        test_dataloader: iterable of batches as described above.
        num_classes: number of classes the metrics are computed for.
        image_path: unused; kept so existing callers keep working.
        softmax_threshold: optional. When given, positions where every
            softmax value is below the threshold are assigned
            ``default_class`` instead of the argmax class.
        default_class: class used for those low-confidence positions.

    Returns:
        (big_metrics, big_confusion) where ``big_metrics`` is a
        JSON-serialisable dict with per-class precision / recall / F1 /
        support and the aggregate F1 and accuracy values, and
        ``big_confusion`` is a (num_classes, num_classes) array
        (rows = truth, columns = prediction) in which every row with at
        least one true sample is divided by its number of true samples.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    class_ids = np.arange(num_classes)

    confusion = np.zeros((num_classes, 2, 2))
    big_confusion = np.zeros((num_classes, num_classes))
    number_in_class = np.zeros(num_classes)
    number_in_pred = np.zeros(num_classes)
    total_number = 0

    # Evaluate in inference mode (dropout / batch-norm, if present, behave
    # deterministically) and without autograd; restore the mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data in tqdm(test_dataloader, desc="Testing on test subset "):
                inputs = torch.stack([data[k][:,0] for k in bands], dim=1).float() / 255
                probs = F.softmax(model(inputs.to(device)), dim=1)

                pred = torch.argmax(probs, dim=1)
                if softmax_threshold is not None:
                    # No class is confident enough -> fall back to default_class.
                    pred[(probs < softmax_threshold).all(dim=1)] = default_class

                flat_pred = pred.cpu().numpy().flatten().astype(int)
                flat_label = data["label"][:,0].cpu().numpy().flatten().astype(int)

                # Labels above 10 are folded into class 0.
                # NOTE(review): 10 is hard-coded independently of num_classes - confirm.
                flat_label[flat_label > 10] = 0

                big_confusion += sklearn.metrics.confusion_matrix(flat_label, flat_pred, labels=class_ids)
                confusion += sklearn.metrics.multilabel_confusion_matrix(flat_label, flat_pred, labels=class_ids)

                total_number += len(flat_label)
                number_in_class += np.bincount(flat_label, minlength=num_classes)
                number_in_pred += np.bincount(flat_pred, minlength=num_classes)
    finally:
        model.train(was_training)

    # sklearn lays each per-class 2x2 matrix out as [[TN, FP], [FN, TP]].
    FP = confusion[:, 0, 1]
    FN = confusion[:, 1, 0]
    TP = confusion[:, 1, 1]

    # numpy division by zero yields NaN (it does not raise), so silence the
    # warnings here and deal with the NaNs explicitly below.
    with np.errstate(divide="ignore", invalid="ignore"):
        precision_li = TP / (TP + FP)
        recall_li = TP / (TP + FN)
        F1_li = TP / (TP + (0.5 * (FP + FN)))

        # Aggregates skip classes whose F1 is undefined (no true samples and
        # no predictions); index 0 is excluded for the no-background variant.
        F1_macro = np.nanmean(F1_li)
        F1_weighted = np.nansum(F1_li * number_in_class) / total_number
        F1_weighted_no_background = np.nansum(F1_li[1:] * number_in_class[1:]) / np.nansum(number_in_class[1:])
        accuracy = TP.sum() / total_number

    # Report undefined per-class scores as 0. NaN never compares equal to
    # itself, so np.isnan is required to find them.
    precision_li[np.isnan(precision_li)] = 0
    recall_li[np.isnan(recall_li)] = 0
    F1_li[np.isnan(F1_li)] = 0

    big_metrics = {
        "per_class": {
            "precision": precision_li.tolist(),
            "recall": recall_li.tolist(),
            "F1-score": F1_li.tolist(),
            "support_true": number_in_class.tolist(),
            "support_pred": number_in_pred.tolist()
        },
        "f1-macro": float(F1_macro),
        "f1-weighted": float(F1_weighted),
        "f1-weighted-no-background": float(F1_weighted_no_background),
        "accuracy": float(accuracy)
    }

    # Divide each row that has true samples by its number of true samples.
    has_support = number_in_class > 0
    big_confusion[has_support] /= number_in_class[has_support, None]

    return big_metrics, big_confusion

# Reload the saved weights before testing; map_location lets a checkpoint that
# was saved on a GPU be loaded when only the CPU is available.
model.load_state_dict(torch.load("model_save/model_temp", map_location=device))
big_metrics, big_confusion = test_model(model, test_dataloader, NUM_CLASSES)

# Persist the full metrics dictionary next to the notebook.
save_json("model_stats.json", big_metrics)

# Summary: weighted F1, weighted F1 without class 0, then the per-class F1 list.
print()
print("#############")
print(big_metrics["f1-weighted"])
print(big_metrics["f1-weighted-no-background"])
print(big_metrics["per_class"]["F1-score"])
print()

In [None]:
# @title {vertical-output: true}

print(big_confusion.shape)
fig = plt.figure(figsize=(10,10))

# Draw the heat map from the unrounded values, colour scale fixed to [0, 1].
heatmap = plt.imshow(big_confusion, vmin=0, vmax=1)

# Round to two decimals and write each value into its cell
# (x = column = prediction, y = row = truth).
big_confusion = np.around(big_confusion, 2)
for (row, col), value in np.ndenumerate(big_confusion):
    plt.text(col, row, value, ha='center', va='center', color='white')

plt.colorbar(heatmap)
plt.title("Confusion Matrix (What is predicted)")
plt.xlabel("Prediction")
plt.ylabel("Truth")