##### This code is adapted from the online course 'Geospatial Deep Learning' provided by West Virginia View (http://www.wvview.org/).

In [None]:
from typing import Optional, List
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt 
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataset import Dataset
import albumentations as A
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils
import rasterio
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchsummary import summary
import torchmetrics as tm
from kornia import losses
import os
# Show the working directory: the relative "../Data/..." paths used in the cells below resolve against it.
print(os.getcwd())

In [None]:
# Probe CUDA once and derive the compute device from that single answer.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0") if cuda_available else torch.device("cpu")
print(device)

print("CUDA Available:", cuda_available)
if cuda_available:
    # Report the name of the first GPU (index 0).
    gpu_name = torch.cuda.get_device_name(0)
    print("GPU Name:", gpu_name)

In [None]:
# Chip index tables for the training and validation splits.
# MultiClassSegDataset reads column 1 of each table as the image path and column 2 as the mask path.
train = pd.read_csv("../Data/train_chips.csv")
val = pd.read_csv("../Data/val_chips.csv")

In [None]:
# Loss-mode name.
# NOTE(review): not referenced by the other cells shown here; the loss is built with the literal "multiclass".
MULTICLASS_MODE: str = "multiclass"
# Encoder backbone and pretrained-weight set passed to smp.UnetPlusPlus.
ENCODER = "resnet152"
ENCODER_WEIGHTS = 'imagenet'
# Class names; len(CLASSES) sets the number of model output channels.
# Index 0 ('background') is the label the loss ignores (ignore_index=0) and the
# channel the evaluation loops slice away before argmax.
CLASSES = ['background', 'water', 'builtup', 'bareland', 'thicket', 'agriculture', 'grass']
# No output activation is passed to the model; the Dice loss is created with from_logits=True.
ACTIVATION = None
# NOTE(review): not referenced by the other cells shown here, which use `device` instead.
DEVICE = 'cuda'

In [None]:
class MultiClassSegDataset(Dataset):
    """Segmentation dataset backed by a table of image/mask file paths.

    Paths are looked up by position: column 1 of ``df`` is the image path and
    column 2 is the mask path.

    Args:
        df: pandas DataFrame with one row per chip.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries
            (the albumentations calling convention).
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __getitem__(self, idx):
        """Load, optionally augment, and tensorize the chip at row ``idx``.

        Returns:
            Tuple ``(image, mask)``: ``image`` is a float tensor in
            channel-first order scaled by 1/255; ``mask`` is a long tensor
            holding the first mask band.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or mask file.
        """
        image_name = self.df.iloc[idx, 1]
        mask_name = self.df.iloc[idx, 2]

        # cv2.imread signals failure by returning None instead of raising, so
        # check explicitly to get an error that names the offending path.
        image = cv2.imread(image_name)
        if image is None:
            raise FileNotFoundError(f"Could not read image chip: {image_name!r}")
        mask = cv2.imread(mask_name, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask chip: {mask_name!r}")

        # OpenCV decodes colour images as BGR; convert to RGB channel order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Multi-band masks: only the first band is used as the label raster.
        if mask.ndim != 2:
            mask = mask[:, :, 0]

        if self.transform is not None:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"]
            mask = transformed["mask"]

        # torch.from_numpy needs a positively-strided array; this is a no-op
        # when the (possibly augmented) image is already contiguous.
        image = np.ascontiguousarray(image)
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255
        mask = torch.from_numpy(mask.astype(np.uint8)).long()

        return image, mask

    def __len__(self):
        """Return the number of rows (chips) in the table."""
        return len(self.df)


In [None]:
# Evaluation pipeline: pad chips smaller than 64x64 by mirroring the border,
# then resize everything to 64x64. No random augmentation.
# cv2.BORDER_REFLECT_101 is the named OpenCV constant for the value 4 used previously.
test_transform = A.Compose(
    [
        A.PadIfNeeded(min_height=64, min_width=64, border_mode=cv2.BORDER_REFLECT_101),
        A.Resize(64, 64),
    ]
)

# Training pipeline: same geometry as above plus random photometric and flip
# augmentations; a light median blur is applied to 10% of the samples.
train_transform = A.Compose(
    [
        A.PadIfNeeded(min_height=64, min_width=64, border_mode=cv2.BORDER_REFLECT_101),
        A.Resize(64, 64),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        # `always_apply=False` was only restating the default and is deprecated
        # in recent albumentations releases, so it is no longer passed.
        A.MedianBlur(blur_limit=3, p=0.1),
    ]
)

In [None]:
# Wrap each split's table in a dataset: augmentation for training, deterministic pad/resize for validation.
trainDS = MultiClassSegDataset(train, transform=train_transform)
valDS = MultiClassSegDataset(val, transform=test_transform)
print(f"Number of Training Samples: {len(trainDS)} Number of validation Samples: {len(valDS)}")

In [None]:
# Training loader: reshuffled every epoch; the final short batch is dropped so
# every training step sees a full batch of 32.
trainDL = DataLoader(trainDS, batch_size=32, shuffle=True, num_workers=0,
                     pin_memory=False, drop_last=True)
# Validation loader: fixed order. The last partial batch is kept
# (drop_last=False) so every validation chip contributes to the epoch metrics
# instead of being silently excluded.
valDL = DataLoader(valDS, batch_size=32, shuffle=False, num_workers=0,
                   pin_memory=False, drop_last=False)

In [None]:
# Draw one training batch and report shapes, Python types and dtypes as a quick sanity check.
batch = next(iter(trainDL))
images, labels = batch
print(
    f"{images.shape} {labels.shape} {type(images)} {type(labels)} "
    f"{images.dtype} {labels.dtype}"
)

In [None]:
def my_metrics(cm):
    """Compute overall accuracy and per-class F1 scores from a confusion matrix.

    Works for any square matrix (the previous version was hard-wired to 6x6).
    When ``cm`` comes from ``sklearn.metrics.confusion_matrix(reference,
    prediction)``, as in this file, rows are reference classes and columns are
    predicted classes, so column sums give precision denominators and row sums
    give recall denominators. F1 is symmetric in the two, so results do not
    depend on that orientation.

    Args:
        cm: square array-like of counts, shape (n_classes, n_classes).

    Returns:
        pandas Series indexed ``"oa", "f_1", ..., "f_n"``. Entries whose
        denominator is zero (e.g. a class absent from both reference and
        prediction) are NaN; callers may ``fillna`` as needed.

    Raises:
        ValueError: if ``cm`` is not a square 2-D matrix.
    """
    cm = np.asarray(cm, dtype=float)
    if cm.ndim != 2 or cm.shape[0] != cm.shape[1]:
        raise ValueError(f"cm must be a square 2-D matrix, got shape {cm.shape}")

    true_pos = np.diagonal(cm)
    # 0/0 is expected for empty classes; keep the NaN result but silence the warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        oa = true_pos.sum() / cm.sum()
        precision = true_pos / cm.sum(axis=0)
        recall = true_pos / cm.sum(axis=1)
        f1 = (2 * precision * recall) / (precision + recall)

    names = ["oa"] + [f"f_{k}" for k in range(1, cm.shape[0] + 1)]
    met_out = pd.Series([oa, *f1], index=names)
    return met_out

In [None]:
# UNet++ with the configured encoder; one output channel per entry in CLASSES.
# The model is placed on the `device` selected earlier instead of a hard-coded
# "cuda:0", so the notebook's CPU fallback is honoured.
model = smp.UnetPlusPlus(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    decoder_use_batchnorm=True,
    in_channels=3,
    classes=len(CLASSES),
    activation=ACTIVATION,
).to(device)


In [None]:
# Training length. OneCycleLR needs the total number of scheduler steps up front;
# here that is one step per training batch over all epochs.
epochs = 400
total_steps = len(trainDL) * epochs

# Multiclass Dice loss on raw logits; label 0 is excluded via ignore_index.
criterion = smp.losses.DiceLoss(mode="multiclass", from_logits=True, ignore_index=0)

# NOTE(review): OneCycleLR drives the learning rate itself, so lr=0.01 is
# presumably only a placeholder — confirm against the PyTorch scheduler docs.
optimizer = optim.SGD(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3e-1,
    total_steps=total_steps,
)


In [None]:
# Two independent, empty per-epoch metric logs (training and validation) with identical columns.
all_metsTrain, all_metsVal = (
    pd.DataFrame(columns=["oa", "f_1", "f_2", "f_3", "f_4", "f_5", "f_6"])
    for _ in range(2)
)

In [None]:
# Number of samples in the training dataset (a sample count, not the number of batches per epoch).
size = len(trainDL.dataset)

In [None]:
# Gradient-accumulation window: the training loop divides each batch loss by this
# value and applies an optimizer step once per `accum_iter` batches.
accum_iter = 30

In [None]:
# Number of evaluated classes (labels 1..6); the loops below slice output channels
# 1..n_classes, so channel 0 never takes part in the argmax.
n_classes = 6 

In [None]:
# Device used by the training/evaluation loops. Falls back to the CPU when CUDA
# is unavailable instead of unconditionally forcing "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Per-epoch routine: train with gradient accumulation, validate, log metrics to
# CSV and checkpoint the weights.
#
# Fixes relative to the previous version:
#   * model.train()/model.eval() are set explicitly, so validation no longer runs
#     with BatchNorm in training mode.
#   * Accumulated gradients are always flushed at the end of the epoch. The old
#     check compared the batch index with the dataset size (samples, not batches),
#     so it never fired and stale gradients leaked into the next epoch.
#   * The logged loss is the mean over the epoch's batches, not the last batch's
#     loss divided by accum_iter.
eval_labels = list(range(1, n_classes + 1))

for t in range(epochs):
    cmTrain = np.zeros([n_classes, n_classes], dtype=int)
    cmVal = np.zeros([n_classes, n_classes], dtype=int)
    train_losses = []
    val_losses = []
    pending = 0  # batches back-propagated since the last optimizer step

    # ---- training ----
    model.train()
    optimizer.zero_grad()
    for x_batch, y_batch in trainDL:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        pred = model(x_batch)
        # Channel 0 is excluded from the argmax; +1 maps channel positions back to labels 1..n_classes.
        pred3 = torch.argmax(pred[:, 1:n_classes + 1, :, :], dim=1)
        predNP = pred3.detach().cpu().numpy().flatten() + 1
        refNP = y_batch.detach().cpu().numpy().flatten()
        try:
            cmTB = confusion_matrix(refNP, predNP, labels=eval_labels)
        except ValueError as e:
            # Raised when none of the evaluated labels occur in the batch; such
            # batches are skipped entirely, as before.
            print("Skipping a batch due to missing labels: ", e)
            continue

        lossT = criterion(pred, y_batch)
        # Scale so the accumulated gradient averages over the accumulation window.
        (lossT / accum_iter).backward()
        train_losses.append(lossT.item())

        # NOTE(review): the scheduler is stepped once per batch (matching
        # total_steps = epochs * len(trainDL)) while the optimizer steps once
        # per accumulation window — confirm this is the intended LR schedule.
        scheduler.step()

        pending += 1
        if pending == accum_iter:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0
        cmTrain += cmTB

    # Apply whatever is left from an incomplete accumulation window.
    if pending:
        optimizer.step()
        optimizer.zero_grad()

    # ---- validation ----
    model.eval()
    with torch.no_grad():
        for x_batch, y_batch in valDL:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            pred = model(x_batch)
            pred3 = torch.argmax(pred[:, 1:n_classes + 1, :, :], dim=1)
            predNP = pred3.detach().cpu().numpy().flatten() + 1
            refNP = y_batch.detach().cpu().numpy().flatten()
            try:
                cmVB = confusion_matrix(refNP, predNP, labels=eval_labels)
            except ValueError as e:
                print("Skipping a batch due to missing labels: ", e)
                continue
            lossV = criterion(pred, y_batch)
            val_losses.append(lossV.item())
            cmVal += cmVB

    # Epoch-mean losses; NaN when every batch of the split was skipped.
    train_loss = float(np.mean(train_losses)) if train_losses else float("nan")
    val_loss = float(np.mean(val_losses)) if val_losses else float("nan")

    # ---- bookkeeping ----
    # Undefined scores (empty classes) are logged as 0; the loss is attached
    # afterwards so a missing loss stays NaN rather than reading as 0.
    metsTrain = my_metrics(cmTrain).fillna(0)
    metsTrain['loss'] = train_loss
    all_metsTrain = pd.concat([all_metsTrain, metsTrain.to_frame().T], ignore_index=True)

    metsVal = my_metrics(cmVal).fillna(0)
    metsVal['loss'] = val_loss
    all_metsVal = pd.concat([all_metsVal, metsVal.to_frame().T], ignore_index=True)

    # Rewrite the full logs every epoch so an interrupted run keeps its history.
    all_metsTrain.to_csv("../Data/Deep_Learning/train_epoch_metrics.csv")
    all_metsVal.to_csv("../Data/Deep_Learning/val_epoch_metrics.csv")

    # Checkpoint names use the 0-based epoch index `t`, matching the CSV row
    # index; the printed epoch number below is 1-based.
    model_name = "../Data/Deep_Learning/model_out_" + str(t) + ".pth"
    torch.save(model.state_dict(), model_name)
    print(f"Epoch {t+1}\nTrain Loss: {train_loss}\nVal Loss: {val_loss}\nTraining Metrics: {metsTrain}\nVal Metrics: {metsVal}")

In [None]:
# Reload the per-epoch validation log and pick out the highest overall accuracy.
val_data = pd.read_csv("../Data/Deep_Learning/val_epoch_metrics.csv")
High_accuracy = max(val_data["oa"].to_numpy())
print(f"Highest accuracy: {High_accuracy}")

In [None]:
# Rows whose overall accuracy equals the maximum; the first such row wins.
epoch_num = val_data.loc[val_data["oa"].eq(High_accuracy)]
# Column 0 is the first column of the saved CSV, i.e. the row index written by to_csv.
x = epoch_num.iloc[0, 0]
print("The epoch with highest oa is: ", x)

In [None]:
# Rebuild the architecture that will receive the best checkpoint's weights.
# encoder_weights=None: the checkpoint's state dict is loaded into this model in
# the following cell and replaces every parameter, so fetching the pretrained
# encoder weights here would be wasted work.
# The model goes to the active `device` rather than a hard-coded "cuda:0".
model = smp.UnetPlusPlus(
    encoder_name=ENCODER,
    encoder_weights=None,
    decoder_use_batchnorm=True,
    in_channels=3,
    classes=len(CLASSES),
    activation=ACTIVATION,
).to(device)

In [None]:
saveFolder = "../Data/Deep_Learning/"
# Load the checkpoint of the best validation epoch found above (`x`) instead of
# a hard-coded epoch number; checkpoints are named by the same 0-based index
# that the metrics CSV uses for its rows.
# map_location keeps the load working when the weights were saved on a
# different device than the one in use now.
best_weights = torch.load(saveFolder + f"model_out_{int(x)}.pth", map_location=device)
model.load_state_dict(best_weights)

In [None]:
# Chip index table for the held-out test split.
testDF = pd.read_csv("../Data/test_chips.csv")

In [None]:
# Preview the first rows of the test table (rendered as the notebook cell's output).
testDF.head()

In [None]:
# Test dataset uses the deterministic pad/resize pipeline (no random augmentation).
testDS = MultiClassSegDataset(testDF, transform=test_transform)

In [None]:
# Test loader: fixed order. The final partial batch is kept (drop_last=False) so
# the reported test metrics cover every test chip rather than silently omitting
# the last few.
testDL = DataLoader(testDS, batch_size=32, shuffle=False, num_workers=0,
                    pin_memory=False, drop_last=False)

In [None]:
# Evaluate the loaded model on the test split and report overall accuracy and per-class F1.
model.eval()

# Labels 1..n_classes are scored; the matrix size follows n_classes instead of a literal 6.
eval_labels = list(range(1, n_classes + 1))
cmTest = np.zeros((n_classes, n_classes), dtype=int)

with torch.no_grad():
    for x_batch, y_batch in testDL:
        x_batch = x_batch.to(device)
        pred = model(x_batch)
        # Channel 0 is excluded from the argmax; +1 maps channel positions back to labels 1..n_classes.
        pred_classes = torch.argmax(pred[:, 1:n_classes + 1, :, :], dim=1)
        predNP = pred_classes.detach().cpu().numpy().flatten() + 1
        refNP = y_batch.detach().cpu().numpy().flatten()
        try:
            cmTB = confusion_matrix(refNP, predNP, labels=eval_labels)
        except ValueError as e:
            # Same handling as the train/validation loops: a batch containing
            # none of the evaluated labels is skipped instead of aborting the run.
            print("Skipping a batch due to missing labels: ", e)
            continue
        cmTest += cmTB

metsTest = my_metrics(cmTest)

print(f"Test Metrics:\n{metsTest}")


In [None]:
# Plot training vs. validation overall accuracy per epoch from the saved metric logs.
train_data = pd.read_csv("../Data/Deep_Learning/train_epoch_metrics.csv")
val_data = pd.read_csv("../Data/Deep_Learning/val_epoch_metrics.csv")

# NOTE: this rebinds `epochs` from the integer epoch count to the array of epoch
# indices read from the first CSV column.
epochs = train_data.iloc[:, 0].values
train_accuracy = train_data["oa"].values
val_accuracy = val_data["oa"].values

# One curve per split on a shared figure.
plt.figure(figsize=(10, 6))
for curve, curve_label, curve_color in (
    (train_accuracy, 'Training Accuracy', 'blue'),
    (val_accuracy, 'Validation Accuracy', '#FF8C00'),
):
    plt.plot(epochs, curve, label=curve_label, color=curve_color)

# Axis range, ticks and labels.
plt.ylim(0.1, 1.0)
plt.yticks([0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel('Epoch', fontsize=14, labelpad=10)
plt.ylabel('Accuracy', fontsize=14, labelpad=10)

plt.legend(loc='lower right', fontsize=14)

# Horizontal grid lines only.
plt.grid(axis='y')
plt.show()
plt.show() 