## Model Training
This notebook trains a UNet model using the Pytorch and segmentation-models-pytorch package. In a first step multiple models are trained based on different settings of the alpha value in the focal loss function. During this step data from both study areas is used together. In a second step the alpha parameter with the best performance on the validation data of both study areas is selected. This model is then fine-tuned independently for each study area using only the training data of the specific study area. This results in two final models, one for each study area. 

#### Inputs:
* *data_path*: Matching image pairs (256 $\times$ 256 pixels) of KH-9 image and crater label split into training, validation and test sets (hdf5 file)
* *data_path_sa*: Same as *data_path* but split into the separate study areas (hdf5 files)

#### Parameters:
* *study_areas*: Names of the study_areas
* *n_classes*: Number of classes in the labelled image tiles = len(crater_ids) + 2 (boundary and background classes)
* *backbone*: Backbone used for the UNet
* *n_epochs*: Maximum number of epochs for model training
* *batch_size*: Batch size used during model training
* *early_stopping_patience*: Model training is stopped if there has not been any improvement for this many epochs
* *alpha_cv*: Settings for the alpha parameter in the focal loss function to use during cross validation
* *crater_ids*: Integers that represent craters in the labelled image tiles
* *crater_classes*: Names of the crater classes

#### Outputs:
* *model_path_cv*: Models trained during cross validation 
* *model_path_sa*: Best models for each study area after fine-tuning

#### Config parameters that were added to the config file based on notebook results
* *alpha_best*: Setting for the alpha parameter in the focal loss function with the highest F1-score on the validation data
* *model_path_cv_best*: Path where the model trained using the *alpha_best* parameter is stored


#### Created paper content:
* **Alpha best**: Setting for the alpha parameter in the focal loss function with the highest F1-score on the validation data

In [1]:
import torch
import pandas as pd
import numpy as np
import segmentation_models_pytorch as smp
import torchvision.transforms.v2 as transforms
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset
from torchvision.datapoints import Mask

from evaluation import evaluate_crater_accuracy, evaluate_pixel_accuracy
from utils import create_dir, load_config, load_data, apply_min_max_scaling

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CustomDataset(Dataset):
    def __init__(self, images, masks, transform=None):
        self.images = images
        self.masks = masks
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]

        if self.transform:
            # set to type Mask to ensure the transform functions know
            # to treat it as a label
            image, mask = self.transform(image, Mask(mask))

        return image, mask


def create_dataset(x, y, transform=None):
    x_tensor = torch.FloatTensor(x).permute(0, 3, 1, 2)
    y_tensor = torch.LongTensor(y.argmax(axis=-1))
    dataset = CustomDataset(x_tensor, y_tensor, transform=transform)
    return dataset


def create_dataset_loader(x, y, batch_size, transform, shuffle=True):
    # create training and validation loaders
    dataset = create_dataset(x, y, transform=transform)
    dataset_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle
    )

    return dataset_loader


def load_data_for_training(data_path, batch_size):
    # load the training data
    x_train, y_train, x_val, y_val = load_data(
        data_path, "x_train", "y_train", "x_val", "y_val")
    x_train = apply_min_max_scaling(x_train)
    x_val = apply_min_max_scaling(x_val)

    # define augmentations
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.8, contrast=0.8),
        transforms.RandomRotation(degrees=30),
    ])

    train_loader = create_dataset_loader(
        x_train, y_train,
        batch_size=batch_size,
        transform=transform,
        shuffle=True
    )
    val_loader = create_dataset_loader(
        x_val, y_val,
        batch_size=batch_size,
        transform=None,
        shuffle=False
    )
    return train_loader, val_loader

In [3]:
# Code from https://saturncloud.io/blog/how-to-use-class-weights-with-focal-loss-in-pytorch-for-imbalanced-multiclass-classification/#:~:text=Focal%20loss%20works%20by%20down,performance%20on%20the%20minority%20class.
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction="none")
        pt = torch.exp(-ce_loss)
        focal_part = (1 - pt) ** self.gamma
        loss = (self.alpha[targets] * focal_part * ce_loss).mean()
        return loss

# code for checkpoints and early stopping https://machinelearningmastery.com/managing-a-pytorch-training-process-with-checkpoints-and-early-stopping/
def checkpoint(model, filename):
    torch.save(model.state_dict(), filename)


def resume(model, filename):
    model.load_state_dict(torch.load(filename))


def train_model(
    train_loader,
    val_loader,
    n_classes,
    backbone,
    alpha,
    n_epochs,
    early_stopping_patience,
    out_path,
    fine_tune_model=False,
    input_model=None,
    seed=1234
):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    create_dir(out_path, is_file=True)

    # create or load model
    if fine_tune_model:
        assert input_model is not None, "Need to provide an input model when fine_tune_model is True"
        model = torch.load(input_model, map_location=device)
        # set a lower learning rate for fine-tuning
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    else:
        # Create new model
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        model = smp.Unet(
            encoder_name=backbone,
            encoder_weights="imagenet",
            in_channels=1,
            classes=n_classes,
            activation=None
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Choose a loss function and optimizer
    alpha = torch.tensor(alpha).to(device)
    loss = FocalLoss(alpha=alpha)

    # Train the model
    model.to(device)
    lowest_val_loss = 999
    best_epoch = -1
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0

        # train the model
        for images, masks in train_loader:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss_value = loss(outputs, masks)
            epoch_loss += loss_value.item()

            loss_value.backward()
            optimizer.step()

        epoch_loss /= len(train_loader)

        # calculate validation loss
        model.eval()
        epoch_val_loss = 0

        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss_value = loss(outputs, masks)
            epoch_val_loss += loss_value.item()

        epoch_val_loss /= len(train_loader)
        print(
            f"Epoch [{epoch + 1}/{n_epochs}] Loss: {epoch_loss:.4f} Validation Loss: {epoch_val_loss:.4f}")

        if epoch_val_loss < lowest_val_loss:
            lowest_val_loss = epoch_val_loss
            best_epoch = epoch
            checkpoint(model, out_path)
        elif epoch - best_epoch > early_stopping_patience:
            print("Early stopped training at epoch %d" % epoch)
            break  # terminate the training loop

    # load best model and save the full model not just the model state
    resume(model, out_path)
    torch.save(model, out_path)

    return model


def pred_val_data(model, data_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    pred_list = []
    with torch.no_grad():
        for images, _ in data_loader:
            images = images.to(device)
            outputs = model(images)
            pred_list.append(outputs.cpu().numpy())

    # Concatenate the predictions and masks along the batch dimension to get the final result
    pred = np.concatenate(pred_list, axis=0)
    pred = pred.transpose((0, 2, 3, 1))

    return pred

In [4]:
config = load_config("../config.yaml")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.deterministic = True

In [5]:
train_loader, val_loader = load_data_for_training(
    config.get("data_path"),
    batch_size=config.get("batch_size")
)

for setting_name, alpha_cv in config.get("alpha_cv").items():

    model = train_model(
        train_loader=train_loader,
        val_loader=val_loader,
        n_classes=config.get("n_classes"),
        backbone=config.get("backbone"),
        n_epochs=config.get("n_epochs"),
        alpha=alpha_cv,
        early_stopping_patience=config.get("early_stopping_patience"),
        out_path=config.get("model_path_cv").format(cv_setting=setting_name)
    )

    del (model)
    torch.cuda.empty_cache()
del (train_loader, val_loader)
torch.cuda.empty_cache()

Directory already exists: ../outputs/models/cv/equal
Epoch [1/1000] Loss: 0.0944 Validation Loss: 0.0121
Epoch [2/1000] Loss: 0.0304 Validation Loss: 0.0118
Epoch [3/1000] Loss: 0.0283 Validation Loss: 0.0097
Epoch [4/1000] Loss: 0.0271 Validation Loss: 0.0099
Epoch [5/1000] Loss: 0.0252 Validation Loss: 0.0092
Epoch [6/1000] Loss: 0.0251 Validation Loss: 0.0095
Epoch [7/1000] Loss: 0.0235 Validation Loss: 0.0093
Epoch [8/1000] Loss: 0.0230 Validation Loss: 0.0085
Epoch [9/1000] Loss: 0.0232 Validation Loss: 0.0079
Epoch [10/1000] Loss: 0.0221 Validation Loss: 0.0082
Epoch [11/1000] Loss: 0.0223 Validation Loss: 0.0088
Epoch [12/1000] Loss: 0.0222 Validation Loss: 0.0078
Epoch [13/1000] Loss: 0.0222 Validation Loss: 0.0077
Epoch [14/1000] Loss: 0.0216 Validation Loss: 0.0089
Epoch [15/1000] Loss: 0.0217 Validation Loss: 0.0078
Epoch [16/1000] Loss: 0.0214 Validation Loss: 0.0077
Epoch [17/1000] Loss: 0.0203 Validation Loss: 0.0070
Epoch [18/1000] Loss: 0.0200 Validation Loss: 0.0069
Ep

### Model Selection

In [7]:
results_pixel = list()
results_crater = list()

# load the validation data
x_val, y_val = load_data(config.get("data_path"), "x_val", "y_val")
x_val = apply_min_max_scaling(x_val)
val_loader = create_dataset_loader(
    x_val, y_val,
    batch_size=config.get("batch_size"),
    transform=None,
    shuffle=False
)

for setting_name, alpha_cv in config.get("alpha_cv").items():
    print(setting_name)

    # load the trained model
    model_path = config.get("model_path_cv").format(cv_setting=setting_name)
    model = torch.load(
        config.get("model_path_cv").format(cv_setting=setting_name),
        map_location=device
    )

    # predict on the validation data
    pred = pred_val_data(model, val_loader)

    # evaluate pixel accuracy
    res_pixel = evaluate_pixel_accuracy(
        pred, y_val,
        crater_ids=config.get("crater_ids")
    )

    # evaluate crater accuracy
    res_crater, cm = evaluate_crater_accuracy(
        pred, y_val,
        crater_classes=config.get("crater_classes"),
        crater_ids=config.get("crater_ids"),
        min_crater_area=config.get("min_crater_area"),
        threshold=0.5,
        plot_cm=False
    )

    # append to list of results for the cross-validation
    results_pixel.append(res_pixel)
    results_crater.append(res_crater)

    del (model)
    torch.cuda.empty_cache()
del (val_loader)
torch.cuda.empty_cache()

equal
factor_2
factor_3
factor_4
factor_5
factor_6


In [8]:
results_cv = pd.DataFrame(
    [df["craters_combined"] for df in results_crater],
    index=config.get("alpha_cv").keys(),
)
results_cv

Unnamed: 0,Precision,Recall,F1-Score,N
equal,0.734,0.557,0.634,1880.0
factor_2,0.684,0.569,0.621,1880.0
factor_3,0.68,0.614,0.645,1880.0
factor_4,0.577,0.619,0.597,1880.0
factor_5,0.59,0.638,0.613,1880.0
factor_6,0.553,0.643,0.595,1880.0


### Add parameters of best cv run to the config file

`alpha_best: [1, 3, 3, 3, 3, 3, 3]`

`model_path_cv_best: ../outputs/models/cv/factor_3/model.pth`


### Model Fine-tuning
Select best model from cross-validation, fine-tune the model parameters independently for each study area

In [12]:
study_areas = config.get("study_areas").keys()

for study_area in study_areas:
    print(study_area)
    train_loader, val_loader = load_data_for_training(
        config.get("data_path_sa").format(study_area=study_area),
        batch_size=config.get("batch_size")
    )

    model = train_model(
        train_loader=train_loader,
        val_loader=val_loader,
        n_classes=config.get("n_classes"),
        backbone=config.get("backbone"),
        n_epochs=config.get("n_epochs"),
        alpha=config.get("alpha_best"),
        early_stopping_patience=config.get("early_stopping_patience"),
        out_path=config.get("model_path_sa").format(study_area=study_area),
        fine_tune_model=True,
        input_model=config.get("model_path_cv_best")
    )

    del (model, train_loader, val_loader)
    torch.cuda.empty_cache()

quang_tri
Directory already exists: ../outputs/models/quang_tri
Epoch [1/1000] Loss: 0.0472 Validation Loss: 0.0199
Epoch [2/1000] Loss: 0.0454 Validation Loss: 0.0199
Epoch [3/1000] Loss: 0.0468 Validation Loss: 0.0204
Epoch [4/1000] Loss: 0.0443 Validation Loss: 0.0205
Epoch [5/1000] Loss: 0.0435 Validation Loss: 0.0204
Epoch [6/1000] Loss: 0.0430 Validation Loss: 0.0199
Epoch [7/1000] Loss: 0.0435 Validation Loss: 0.0203
Epoch [8/1000] Loss: 0.0426 Validation Loss: 0.0202
Epoch [9/1000] Loss: 0.0458 Validation Loss: 0.0202
Epoch [10/1000] Loss: 0.0446 Validation Loss: 0.0199
Epoch [11/1000] Loss: 0.0426 Validation Loss: 0.0199
Epoch [12/1000] Loss: 0.0435 Validation Loss: 0.0206
Epoch [13/1000] Loss: 0.0424 Validation Loss: 0.0200
Epoch [14/1000] Loss: 0.0431 Validation Loss: 0.0204
Epoch [15/1000] Loss: 0.0415 Validation Loss: 0.0199
Epoch [16/1000] Loss: 0.0414 Validation Loss: 0.0203
Epoch [17/1000] Loss: 0.0421 Validation Loss: 0.0201
Epoch [18/1000] Loss: 0.0432 Validation Loss

In [14]:
study_areas = config.get("study_areas").keys()
results_pixel = list()
results_crater = list()

for study_area in study_areas:
    # load the validation data
    x_val, y_val = load_data(
        config.get("data_path_sa").format(study_area=study_area),
        "x_val", "y_val"
    )
    x_val = apply_min_max_scaling(x_val)
    val_loader = create_dataset_loader(
        x_val, y_val,
        batch_size=config.get("batch_size"),
        transform=None,
        shuffle=False
    )
    # load the trained model
    model = torch.load(
        config.get("model_path_sa").format(study_area=study_area),
        map_location=device
    )

    # predict on the validation data
    pred = pred_val_data(model, val_loader)

    # evaluate pixel accuracy
    res_pixel = evaluate_pixel_accuracy(
        pred, y_val,
        crater_ids=config.get("crater_ids")
    )

    # evaluate crater accuracy
    res_crater, cm = evaluate_crater_accuracy(
        pred, y_val,
        crater_classes=config.get("crater_classes"),
        crater_ids=config.get("crater_ids"),
        min_crater_area=config.get("min_crater_area"),
        threshold=0.5,
        plot_cm=False
    )

    # append to list of results
    results_pixel.append(res_pixel)
    results_crater.append(res_crater)

    del (model, val_loader)
    torch.cuda.empty_cache()

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [15]:
results_sa = pd.DataFrame(
    [df["craters_combined"] for df in results_crater], index=study_areas
)
results_sa

Unnamed: 0,Precision,Recall,F1-Score,N
quang_tri,0.664,0.656,0.66,1738.0
tri_border_area,0.648,0.324,0.432,142.0


In [16]:
results_crater[0]

Unnamed: 0,0,1,2,3,4,5,craters_combined
Precision,0.0,0.674,0.515,0.286,0.289,0.385,0.664
Recall,0.0,0.783,0.444,0.049,0.468,0.271,0.656
F1-Score,0.0,0.725,0.477,0.084,0.357,0.318,0.66
N,577.0,844.0,426.0,163.0,139.0,166.0,1738.0


In [17]:
results_crater[1]

Unnamed: 0,0,1,2,3,4,5,craters_combined
Precision,0.0,0.7,0.667,0.0,0.357,0.0,0.648
Recall,0.0,0.432,0.4,0.0,0.185,0.0,0.324
F1-Score,0.0,0.534,0.5,0.0,0.244,0.0,0.432
N,25.0,81.0,10.0,13.0,27.0,11.0,142.0
