In [1]:
import sys
import torch
import os
import json
import time
import random

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize

import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
import seaborn as sn

from tqdm import tqdm

import cv2

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import albumentations as album
import gc
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils

# Release Python garbage and cached CUDA memory left over from earlier runs
# in this kernel before new models/tensors are allocated.
for release_memory in (gc.collect, torch.cuda.empty_cache):
    release_memory()

In [3]:
# Set seed for reproducibility
def set_seed(seed, use_gpu=True):
    """Seed Python, NumPy and PyTorch RNGs; optionally make cuDNN deterministic.

    Args:
        seed: integer seed applied to every RNG.
        use_gpu: when truthy, also seed all CUDA devices and switch cuDNN to
            deterministic, non-benchmarking mode. The message printed in that
            case reflects only this flag, not an independent CUDA check.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not use_gpu:
        return
    print("CUDA is available")
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


SEED = 12345
USE_SEED = True

if USE_SEED:
    set_seed(SEED, use_gpu=torch.cuda.is_available())

CUDA is available


In [4]:
DATA_PATH = "./data"

# One metadata CSV (256-px patches) per dataset split, all stored under DATA_PATH.
train_metadata_df, valid_metadata_df, test_metadata_df = (
    pd.read_csv(os.path.join(DATA_PATH, csv_name))
    for csv_name in (
        "train_metadata_patches256.csv",
        "valid_metadata_patches256.csv",
        "test_metadata_patches256.csv",
    )
)


In [5]:
def _prefix_image_and_mask_paths(df):
    """Root the "sat_image_path" and "mask_path" columns of `df` at DATA_PATH (in place).

    The CSV paths are presumably relative to DATA_PATH. A marker stored in
    ``df.attrs`` records that the prefix was already applied, so re-running this
    cell is a no-op instead of producing "./data/./data/..." paths.
    """
    if df.attrs.get("paths_prefixed_with") == DATA_PATH:
        return
    for column in ("sat_image_path", "mask_path"):
        df[column] = df[column].apply(lambda rel_path: os.path.join(DATA_PATH, rel_path))
    df.attrs["paths_prefixed_with"] = DATA_PATH


for metadata_df in (train_metadata_df, valid_metadata_df, test_metadata_df):
    _prefix_image_and_mask_paths(metadata_df)


In [6]:
class_dict = pd.read_csv(os.path.join(DATA_PATH, "class_dict.csv"))
# Class index i corresponds to row i of class_dict.csv.
class_names = class_dict["name"].tolist()
class_rgb_values = class_dict[["r", "g", "b"]].values.tolist()

format_spec = "{:<20} {:<20}"
print(format_spec.format("class name:", "class RGB values:"), "\n")

for name, rgb in zip(class_names, class_rgb_values):
    print(format_spec.format(str(name), str(rgb)))


class name:          class RGB values:    

built-up             [255, 0, 0]         
farmland             [0, 255, 0]         
forest               [0, 255, 255]       
meadow               [255, 255, 0]       
water                [0, 0, 255]         
unknown              [0, 0, 0]           


In [7]:
class DeepGlobeLandCover(Dataset):
    """Satellite-image / colour-mask pairs listed in a metadata DataFrame.

    Each item is ``(image, mask)``:
      * ``image`` - the (optionally transformed) image as a tensor permuted with
        (2, 0, 1), i.e. HWC -> CHW for a channel-last array;
      * ``mask`` - a 2-D tensor of class indices from ``rgb_to_dense_encoding``.

    Args:
        df: DataFrame with ``sat_image_path`` and ``mask_path`` columns.
        class_rgb_values: per-class colours, indexed by class id.
        transform: optional albumentations-style callable taking ``image=`` and
            ``mask=`` keywords and returning a dict with the same keys. It must
            leave the mask as an HxWx3 colour array (dense encoding runs after it).
    """

    def __init__(self, df, class_rgb_values=None, transform=None):
        self.image_paths = df["sat_image_path"].tolist()
        self.mask_paths = df["mask_path"].tolist()
        self.class_rgb_values = class_rgb_values
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, i):
        # cv2.imread returns HxWx3 uint8 arrays in BGR channel order.
        # NOTE(review): no BGR->RGB conversion is done, so class colours are
        # matched against BGR pixels - confirm class_dict.csv uses that order.
        image, mask = cv2.imread(self.image_paths[i]), cv2.imread(self.mask_paths[i])

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        dense_mask = rgb_to_dense_encoding(mask, self.class_rgb_values)
        return torch.from_numpy(image).permute(2, 0, 1), torch.from_numpy(dense_mask)

In [8]:
def rgb_to_dense_encoding(mask_rgb, class_rgb_values):
    """Convert an HxWx3 colour mask into an HxW uint8 array of class indices.

    Pixels equal to ``class_rgb_values[i]`` get value ``i``; pixels that match
    no class colour are left at 0.

    Raises:
        ValueError: if ``mask_rgb`` is None (e.g. a failed ``cv2.imread``) or is
            not a 3-D array with exactly 3 channels.
    """
    if mask_rgb is None:
        raise ValueError("mask_rgb is None (was the mask file read successfully?)")
    # Check ndim first so 2-D input raises ValueError rather than IndexError.
    if mask_rgb.ndim != 3 or mask_rgb.shape[2] != 3:
        raise ValueError("The number of channels must be 3")

    mask_dense = np.zeros((mask_rgb.shape[0], mask_rgb.shape[1]), dtype="uint8")
    for i, rgb in enumerate(class_rgb_values):
        mask_dense[np.all(np.equal(mask_rgb, rgb), axis=-1)] = i
    return mask_dense

def dense_to_rgb_encoding(mask_dense, class_rgb_values):
    """Convert an HxW array of class indices back into an HxWx3 uint8 colour mask.

    Index ``i`` is painted with ``class_rgb_values[i]``; indices without a
    colour stay black (0, 0, 0).

    Raises:
        ValueError: if ``mask_dense`` is not 2-D.
    """
    if np.ndim(mask_dense) != 2:
        raise ValueError("The number of channels must be 1")

    mask_rgb = np.zeros((mask_dense.shape[0], mask_dense.shape[1], 3), dtype="uint8")
    for i, rgb in enumerate(class_rgb_values):
        mask_rgb[mask_dense == i, :] = rgb
    return mask_rgb

In [9]:
def task_pixels(mask_path):
    """Count pixels per class in one mask file.

    The mask is read exactly as DeepGlobeLandCover reads it (cv2.imread, no
    colour conversion), so the class ids match those used in training.

    Returns:
        (unique, counts) from np.unique over the dense class-index mask: the
        class ids present and the number of pixels of each.
    """
    mask = cv2.imread(mask_path)
    mask = rgb_to_dense_encoding(mask, class_rgb_values)
    unique, counts = np.unique(mask, return_counts=True)
    return (unique, counts)


# Training-set class distribution as printed by the FREQ=True run of this cell
# (the full scan takes ~10 minutes); used as a cache when FREQ is False.
CACHED_CLASS_FREQ = np.array([0.15084652, 0.34913148, 0.06891692, 0.00762519, 0.08419963, 0.33928026])

FREQ = True
if FREQ:
    total_pixels_per_class = np.zeros(len(class_rgb_values))
    for unique, counts in (task_pixels(path) for path in tqdm(train_metadata_df["mask_path"].tolist())):
        # `unique` holds distinct class ids, so fancy-indexed += is safe here.
        total_pixels_per_class[unique] += counts
    total_pixels = total_pixels_per_class.sum()
    distr = total_pixels_per_class / total_pixels
    assert np.isclose(distr.sum(), 1), "Sum of distribution is not equal to 1"
    print(distr)
    freq = distr
else:
    freq = CACHED_CLASS_FREQ

# Median-frequency balancing: weight_c = median(freq) / freq_c.
class_weight = np.median(freq)/freq

100%|██████████| 55566/55566 [10:40<00:00, 86.79it/s] 


[0.15084652 0.34913148 0.06891692 0.00762519 0.08419963 0.33928026]


In [10]:
def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype("float32")

def get_training_augmentation():
    return album.Compose([
        album.RandomRotate90(p=1),
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
        album.Transpose(p=0.5),
        album.ColorJitter(p=0.5),
        album.Normalize(mean=[79.75/255, 112.14/255, 110.33/255], std=[79.75/255, 112.14/255, 110.33/255])
    ])

def get_validation_augmentation():
    return album.Compose([
        album.Resize(height=128, width=128),
        album.Normalize(mean=[79.75/255, 112.14/255, 110.33/255], std=[79.75/255, 112.14/255, 110.33/255])
    ])

def get_preprocessing(preprocessing_fn):
    return album.Compose([
        album.Lambda(image=preprocessing_fn),
        album.Lambda(image=to_tensor, mask=to_tensor)
    ])


In [11]:
train_transform = get_training_augmentation()
valid_transform = get_validation_augmentation()

ENCODER = "resnet50"
ENCODER_WEIGHTS = "imagenet"
# The model must emit raw logits: the training criterion is nn.CrossEntropyLoss,
# which applies log-softmax internally, so a "softmax2d" head would softmax twice
# and flatten the gradients. The metric code applies F.softmax on its own.
ACTIVATION = None

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)


def _with_preprocessing(augmentation):
    """Chain an augmentation pipeline with the encoder-specific preprocessing."""
    return album.Compose([augmentation, get_preprocessing(preprocessing_fn)])


train_dataset = DeepGlobeLandCover(train_metadata_df, class_rgb_values, transform=_with_preprocessing(train_transform))
valid_dataset = DeepGlobeLandCover(valid_metadata_df, class_rgb_values, transform=_with_preprocessing(valid_transform))
test_dataset = DeepGlobeLandCover(test_metadata_df, class_rgb_values, transform=_with_preprocessing(valid_transform))


In [12]:
BATCH_SIZE = 8
# Only the training loader drops its last incomplete batch (presumably to avoid
# a tiny final batch during training). The evaluation loaders keep every sample
# so that the accumulated confusion-matrix metrics cover the whole split.
train_iterator = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, drop_last=True)
valid_iterator = DataLoader(valid_dataset, batch_size=BATCH_SIZE, drop_last=False)
test_iterator = DataLoader(test_dataset, batch_size=BATCH_SIZE, drop_last=False)


In [13]:
# DeepLabV3+ with one output channel per class, placed on the GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(class_names),
    activation=ACTIVATION,
).to(device)

In [14]:
# Class-weighted cross-entropy (weights from the class-frequency cell); `.to`
# moves the registered weight buffer to the model's device.
criterion = nn.CrossEntropyLoss(
    weight=torch.from_numpy(class_weight).float(),
).to(device)
optimizer = optim.Adam(model.parameters(), weight_decay=1e-5, lr=1e-3)


In [15]:
from torch.cuda.amp import GradScaler, autocast

# Shared loss scaler for mixed-precision training; `train` reads this global.
scaler = GradScaler(enabled=True)

In [16]:
def compute_CM_elements(y_logits, y, num_classes=6):
    """Per-class confusion-matrix counts for a batch of segmentation predictions.

    Args:
        y_logits: model scores with the class axis at dim 1 (e.g. N x C x H x W);
            each pixel's prediction is the argmax over that axis.
        y: integer (long) target class indices with the same shape as the
            prediction, i.e. ``y_logits`` without its class axis.
        num_classes: number of classes used for one-hot encoding (default 6,
            the number of rows in this notebook's class_dict).

    Returns:
        (TP, FP, TN, FN): 1-D integer tensors of length ``num_classes``, each
        summed over every non-class axis.
    """
    y_pred = F.softmax(y_logits, dim=1).argmax(dim=1)
    pred_pos = F.one_hot(y_pred, num_classes=num_classes).bool()
    true_pos = F.one_hot(y, num_classes=num_classes).bool()
    # one_hot appends the class axis last; reduce over every other axis.
    reduce_dims = tuple(range(pred_pos.dim() - 1))
    TP = torch.sum(pred_pos & true_pos, dim=reduce_dims)
    FP = torch.sum(pred_pos, dim=reduce_dims) - TP
    TN = torch.sum(~pred_pos & ~true_pos, dim=reduce_dims)
    # Predicted-negative pixels are either true negatives or false negatives.
    FN = torch.sum(~pred_pos, dim=reduce_dims) - TN
    return TP, FP, TN, FN

In [17]:
def compute_metrics(loss, TP, FP, TN, FN):
    """Turn per-class confusion counts into per-class and macro-averaged metrics.

    Args:
        loss: scalar loss - a 0-d / single-element tensor or a plain number.
        TP, FP, TN, FN: 1-D tensors of per-class counts, all the same length.

    Returns:
        dict with "loss" (float), per-class lists ("IoU", "accuracy",
        "precision", "recall", "F1_score") and their macro averages ("MA_*"),
        averaged over the number of classes. A per-class value is NaN when its
        denominator is zero (e.g. precision of a class that is never
        predicted), and that NaN propagates into the matching macro average.

    Note:
        "MA_F1_score" is the harmonic mean of macro precision and macro recall,
        not the mean of the per-class F1 scores.
    """
    n_classes = TP.numel()
    IoU = TP / (TP + FP + FN)
    MA_IoU = torch.sum(IoU) / n_classes
    accuracy = (TP + TN) / (TP + FP + TN + FN)
    MA_accuracy = torch.sum(accuracy) / n_classes
    precision = TP / (TP + FP)
    MA_precision = torch.sum(precision) / n_classes
    recall = TP / (TP + FN)
    MA_recall = torch.sum(recall) / n_classes
    F1_score = 2 * (precision * recall) / (precision + recall)
    MA_F1_score = 2 * (MA_precision * MA_recall) / (MA_precision + MA_recall)

    metrics = {
        # float() accepts both 0-d tensors and plain Python numbers.
        "loss": float(loss),
        "IoU": IoU.tolist(),
        "MA_IoU": MA_IoU.item(),
        "accuracy": accuracy.tolist(),
        "MA_accuracy": MA_accuracy.item(),
        "precision": precision.tolist(),
        "MA_precision": MA_precision.item(),
        "recall": recall.tolist(),
        "MA_recall": MA_recall.item(),
        "F1_score": F1_score.tolist(),
        "MA_F1_score": MA_F1_score.item()
    }
    return metrics

In [18]:
def train(model, iterator, criterion, optimizer, device, desc="Train"):
    """Run one mixed-precision training epoch and return aggregated metrics.

    Relies on the module-level `scaler` (GradScaler) and `autocast`.

    Args:
        model: network being trained (switched to train mode here).
        iterator: DataLoader yielding (image batch, class-index mask batch).
        criterion: loss taking (logits, long targets).
        optimizer: optimizer stepped once per batch through `scaler`.
        device: device the batches are moved to.
        desc: tqdm progress-bar label.

    Returns:
        dict from compute_metrics: mean of the per-batch losses plus per-class
        and macro metrics from confusion counts summed over the epoch.
    """
    epoch_loss = 0
    epoch_TP = 0
    epoch_FP = 0
    epoch_TN = 0
    epoch_FN = 0
    n_batches = len(iterator)

    model.train()
    for x, y in tqdm(iterator, desc=desc):
        optimizer.zero_grad()
        x = x.to(device)
        y = y.to(device).long()
        with autocast():
            y_logits = model(x)
            loss = criterion(y_logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Bookkeeping only: detach so the running totals do not keep each
        # batch's autograd graph alive for the whole epoch.
        TP, FP, TN, FN = compute_CM_elements(y_logits.detach(), y)
        epoch_loss += loss.detach() / n_batches
        epoch_TP += TP
        epoch_FP += FP
        epoch_TN += TN
        epoch_FN += FN
    epoch_metrics = compute_metrics(epoch_loss, epoch_TP, epoch_FP, epoch_TN, epoch_FN)
    return epoch_metrics

In [19]:
def evaluate(model, iterator, criterion, device, desc="Valid"):
    """Evaluate `model` on `iterator` without gradients and return aggregated metrics.

    Args:
        model: network to evaluate (switched to eval mode here).
        iterator: DataLoader yielding (image batch, class-index mask batch).
        criterion: loss taking (logits, long targets).
        device: device the batches are moved to.
        desc: tqdm progress-bar label.

    Returns:
        dict from compute_metrics: mean of the per-batch losses plus per-class
        and macro metrics from confusion counts summed over all batches.
    """
    epoch_loss = 0
    epoch_TP = 0
    epoch_FP = 0
    epoch_TN = 0
    epoch_FN = 0
    n_batches = len(iterator)

    model.eval()
    with torch.no_grad():
        for x, y in tqdm(iterator, desc=desc):
            x = x.to(device)
            y = y.to(device).long()
            y_logits = model(x)
            loss = criterion(y_logits, y)
            TP, FP, TN, FN = compute_CM_elements(y_logits, y)
            epoch_loss += loss / n_batches
            epoch_TP += TP
            epoch_FP += FP
            epoch_TN += TN
            epoch_FN += FN
    epoch_metrics = compute_metrics(epoch_loss, epoch_TP, epoch_FP, epoch_TN, epoch_FN)
    return epoch_metrics

In [20]:
def _epoch_summary(metrics):
    """Format one split's headline metrics as the strings printed after each epoch."""
    return [
        f"Loss = {metrics['loss']:.4f},",
        f"MA_IoU = {metrics['MA_IoU']*100:.2f} %,",
        f"MA_acc = {metrics['MA_accuracy']*100:.2f} %,",
        f"MA_prec = {metrics['MA_precision']*100:.2f} %,",
        f"MA_rec = {metrics['MA_recall']*100:.2f} %,",
        f"MA_F1_score = {metrics['MA_F1_score']*100:.2f} %",
    ]


def model_training(n_epochs, model, train_iterator, valid_iterator, criterion, optimizer, device, checkpoint_name="checkpoint_deeplabv3plus.pt", results_name="results_deeplabv3plus.csv"):
    """Train for `n_epochs`, checkpointing the epoch with the lowest validation loss.

    After every epoch, the per-epoch metrics so far are rewritten to
    ./results/train_<results_name> and ./results/valid_<results_name>. Whenever
    the validation loss improves, the model/criterion/optimizer state dicts and
    the 1-based epoch number are saved to ./checkpoints/<checkpoint_name>.

    Returns:
        (train_metrics_log, valid_metrics_log): dicts mapping each metric name
        to its list of per-epoch values.
    """
    # torch.save and DataFrame.to_csv do not create missing directories.
    os.makedirs("./checkpoints", exist_ok=True)
    os.makedirs("./results", exist_ok=True)

    best_valid_loss = float('inf')
    train_metrics_log = {}
    valid_metrics_log = {}
    print("----------------------------------------------------------")
    for epoch in range(n_epochs):
        print(f"\nEpoch: {epoch + 1}/{n_epochs}\n")
        start_time = time.time()
        train_metrics = train(model, train_iterator, criterion, optimizer, device, desc="Train")
        valid_metrics = evaluate(model, valid_iterator, criterion, device, desc="Valid")
        print("\nTrain:", *_epoch_summary(train_metrics))
        print("Valid:", *_epoch_summary(valid_metrics))
        end_time = time.time()
        print(f"\nEpoch Time: {end_time-start_time:.2f} s\n")
        print("----------------------------------------------------------")
        if valid_metrics["loss"] < best_valid_loss:
            best_valid_loss = valid_metrics["loss"]
            torch.save({"epoch": epoch + 1, "model_state_dict": model.state_dict(), "criterion_state_dict": criterion.state_dict(), "optimizer_state_dict": optimizer.state_dict()}, os.path.join('./checkpoints', checkpoint_name))
        for key, value in train_metrics.items():
            train_metrics_log.setdefault(key, []).append(value)
        pd.DataFrame.from_dict(train_metrics_log).to_csv("./results/train_" + results_name, index=False)
        for key, value in valid_metrics.items():
            valid_metrics_log.setdefault(key, []).append(value)
        pd.DataFrame.from_dict(valid_metrics_log).to_csv("./results/valid_" + results_name, index=False)
    return train_metrics_log, valid_metrics_log


In [21]:
def test(model, iterator, criterion, device, desc="Test"):
    """Evaluate `model` on `iterator` and return (metrics, confusion matrix).

    Returns:
        metrics: dict from compute_metrics, with "loss" the mean of the
            per-batch losses.
        CM: numpy confusion matrix of shape (C, C), where C is the number of
            model output channels; built with sklearn's confusion_matrix(y_true,
            y_pred), so rows are true classes and columns predicted classes.
    """
    total_loss = 0
    CM = 0
    n_batches = len(iterator)

    model.eval()
    with torch.no_grad():
        for x, y in tqdm(iterator, desc=desc):
            x = x.to(device)
            y = y.to(device).long()
            y_logits = model(x)
            # Keep the batch loss separate from the running mean so that every
            # batch, not just the last one, contributes to the reported loss.
            batch_loss = criterion(y_logits, y)
            total_loss += batch_loss / n_batches
            y_pred = F.softmax(y_logits, dim=1).argmax(dim=1)
            n_classes = y_logits.shape[1]
            # numpy arrays instead of Python lists: far cheaper for large masks.
            CM += confusion_matrix(
                y.reshape(-1).cpu().numpy(),
                y_pred.reshape(-1).cpu().numpy(),
                labels=list(range(n_classes)),
            )

    loss = total_loss.cpu()
    TP = CM.diagonal().copy()
    CM_no_diag = CM - np.diag(TP)
    FP = CM_no_diag.sum(axis=0)  # predicted as class c but truly another class
    FN = CM_no_diag.sum(axis=1)  # truly class c but predicted as another class
    TN = CM.sum() - (FP + FN + TP)

    TP = torch.from_numpy(TP)
    FP = torch.from_numpy(FP)
    TN = torch.from_numpy(TN)
    FN = torch.from_numpy(FN)

    metrics = compute_metrics(loss, TP, FP, TN, FN)

    return metrics, CM

In [22]:
def model_testing(model, test_iterator, criterion, device, checkpoint_name="checkpoint_deeplabv3plus.pt", results_name="results_deeplabv3plus.csv", CM_name="confusion_matrix_deeplabv3plus.csv"):
    """Load the saved checkpoint, evaluate on the test split and save the results.

    Writes ./results/test_<results_name> (a one-row metrics table) and
    ./results/<CM_name> (the raw confusion matrix).

    Returns:
        (metrics, CM) as produced by `test`.
    """
    # map_location lets a checkpoint saved on GPU be loaded on the current device.
    checkpoint = torch.load(os.path.join("./checkpoints", checkpoint_name), map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    criterion.load_state_dict(checkpoint["criterion_state_dict"])
    metrics, CM = test(model, test_iterator, criterion, device, desc="Test")

    os.makedirs("./results", exist_ok=True)
    # One row per run: wrap every value (scalars and per-class lists) in a list.
    metrics_log = {key: [value] for key, value in metrics.items()}
    pd.DataFrame.from_dict(metrics_log).to_csv("./results/test_" + results_name, index=False)
    pd.DataFrame(CM).to_csv("./results/" + CM_name, index=False)
    return metrics, CM

In [23]:
def plot_CM(class_names, CM_name="confusion_matrix.csv"):
    """Plot the row-normalised confusion matrix saved in ./results/<CM_name>.

    Each row (true class) is L1-normalised, so a cell shows the fraction of
    that class's pixels assigned to each predicted class. Tick labels use the
    part of each class name before the first underscore.
    """
    CM = pd.read_csv("./results/" + CM_name).values
    CM = normalize(CM, axis=1, norm="l1")
    tick_labels = [s.split("_")[0] for s in class_names]

    fig, ax = plt.subplots(figsize=(9, 7))
    sn.heatmap(CM, xticklabels=tick_labels, yticklabels=tick_labels, vmin=0, vmax=1, cmap="Blues", fmt=".2f", annot=True, ax=ax)
    ax.collections[0].colorbar.ax.tick_params(labelsize=16)
    ax.set_xlabel("Predicted class", fontsize=18)
    ax.set_ylabel("True class", fontsize=18)
    plt.xticks(fontsize=18, rotation=90)
    plt.yticks(fontsize=18, rotation=0)
    plt.tight_layout()
    plt.show()

In [24]:
N_EPOCHS = 5
TRAIN = True  # set False to skip training and reuse the saved checkpoint

if TRAIN:
    train_metrics_log, valid_metrics_log = model_training(
        N_EPOCHS, model, train_iterator, valid_iterator, criterion, optimizer, device,
        checkpoint_name="checkpoint_deeplabv3plus.pt",
        results_name="results_deeplabv3plus.csv",
    )


----------------------------------------------------------

Epoch: 1/5



Train:   0%|          | 0/6945 [00:00<?, ?it/s]


ValueError: The number of channels must be 3

In [None]:
def plot_results(checkpoint_name="checkpoint.pt", results_name="results.csv"):
    """Plot train/valid learning curves up to and including the checkpointed epoch.

    The epoch number is read from ./checkpoints/<checkpoint_name>; logged
    epochs after it are dropped from the plot.
    """
    # Only the epoch number is needed: map to CPU so a GPU-saved checkpoint can
    # be read on any machine without allocating GPU memory for its tensors.
    best_epoch = torch.load(os.path.join("./checkpoints", checkpoint_name), map_location="cpu")["epoch"]
    train_results_df = pd.read_csv("./results/train_" + results_name).iloc[:best_epoch]
    valid_results_df = pd.read_csv("./results/valid_" + results_name).iloc[:best_epoch]
    assert len(train_results_df) == len(valid_results_df)
    epochs = np.arange(len(train_results_df)) + 1

    plt.figure(figsize=(21, 14))

    n_rows = 3
    n_cols = 3

    # Fixed y-axis limits for the first three metrics (loss, MA_IoU, MA_F1_score).
    y_ranges = [(0.4, 1.1), (0.25, 0.65), (0.4, 0.8)]
    metric_names = ["loss", "MA_IoU", "MA_F1_score", "MA_precision", "MA_recall", "MA_accuracy"]

    for i, metric in enumerate(metric_names):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.plot(epochs, train_results_df[metric], linewidth=3, color="tab:blue", label="Train")
        plt.plot(epochs, valid_results_df[metric], linewidth=3, color="tab:orange", label="Valid")
        if i < len(y_ranges):
            plt.ylim(bottom=y_ranges[i][0], top=y_ranges[i][1])
        plt.xlabel("epoch", fontsize=14)
        plt.ylabel(metric, fontsize=14)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.legend(fontsize=12)
        plt.grid()

    plt.show()

plot_results("checkpoint_deeplabv3plus.pt", "results_deeplabv3plus.csv")

In [None]:
TEST = True  # set False to skip evaluation on the test split

if TEST:
    model_testing(
        model, test_iterator, criterion, device,
        checkpoint_name="checkpoint_deeplabv3plus.pt",
        results_name="results_deeplabv3plus.csv",
        CM_name="confusion_matrix_deeplabv3plus.csv",
    )


In [None]:
plot_CM(class_names, CM_name="confusion_matrix_deeplabv3plus.csv")