In [1]:
#IMPORTS

#File IO
import os
import glob

#Data manipulation
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

#Pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.nn.functional as F
from segmentation_models_pytorch import Unet

#Scikit learn
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import mean_squared_error, r2_score

#Misc
from tqdm import tqdm


In [2]:
#HYPERPARMETERS

train_proportion = .8
val_proportion = .1

batch_size = 256
learning_rate = .00001
num_epochs = 50

In [3]:
#CREATE DATASET

dataset = torch.load("mask_unet_dataset.pt", weights_only=False)

In [4]:
#CREATE DATALOADERS

train_size = int(train_proportion * len(dataset))
val_size = int(val_proportion * len(dataset))
test_size = len(dataset) - train_size - val_size

generator = torch.Generator().manual_seed(1)
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
#CREATE MODEL

num_classes = 2
model = Unet(
    encoder_name="resnet34",
    in_channels=16,
    classes=num_classes
)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
model = model.to(device)

train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

In [6]:
#TRAIN and EVALUATE FUNCTIONS

def train(model, train_loader):
    model.train()
    train_loss = correct = total = 0

    for inputs, labels in tqdm(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        preds = torch.argmax(outputs, dim=1)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        correct += (preds == labels).sum().item()
        total += labels.numel()

    #total is number of train instances
    avg_loss = train_loss/len(train_loader)
    accuracy = correct/total
    return avg_loss, accuracy

def eval(model, val_loader):
    model.eval()
    val_loss = correct = total = 0

    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        preds = torch.argmax(outputs, dim=1)

        loss = criterion(outputs, labels)

        val_loss += loss.item()
        correct += (preds == labels).sum().item()
        total += labels.numel()

    avg_loss = val_loss/len(val_loader)
    accuracy = correct/total
    return avg_loss, accuracy

In [None]:
#TRAIN MODEL

for e in range(1, num_epochs+1):
    train_loss, train_acc = train(model, train_loader)
    val_loss, val_acc = eval(model, val_loader)
    
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)

    print(f"Epoch: {e} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

100%|██████████| 47/47 [00:14<00:00,  3.30it/s]


Epoch: 1 | Train Loss: 0.8147 | Train Acc: 0.4325 | Val Loss: 0.7326 | Val Acc: 0.4978


100%|██████████| 47/47 [00:12<00:00,  3.89it/s]


Epoch: 2 | Train Loss: 0.6393 | Train Acc: 0.6216 | Val Loss: 0.5904 | Val Acc: 0.6810


100%|██████████| 47/47 [00:15<00:00,  2.97it/s]


Epoch: 3 | Train Loss: 0.5490 | Train Acc: 0.7562 | Val Loss: 0.5161 | Val Acc: 0.7788


100%|██████████| 47/47 [00:14<00:00,  3.25it/s]


Epoch: 4 | Train Loss: 0.4861 | Train Acc: 0.8321 | Val Loss: 0.4505 | Val Acc: 0.8504


100%|██████████| 47/47 [00:15<00:00,  3.11it/s]


Epoch: 5 | Train Loss: 0.4414 | Train Acc: 0.8711 | Val Loss: 0.4071 | Val Acc: 0.8892


100%|██████████| 47/47 [00:15<00:00,  3.00it/s]


Epoch: 6 | Train Loss: 0.4096 | Train Acc: 0.8933 | Val Loss: 0.3927 | Val Acc: 0.8963


100%|██████████| 47/47 [00:15<00:00,  3.05it/s]


Epoch: 7 | Train Loss: 0.3858 | Train Acc: 0.9070 | Val Loss: 0.4014 | Val Acc: 0.8949


100%|██████████| 47/47 [00:15<00:00,  3.10it/s]


Epoch: 8 | Train Loss: 0.3663 | Train Acc: 0.9153 | Val Loss: 0.3877 | Val Acc: 0.9056


100%|██████████| 47/47 [00:15<00:00,  3.12it/s]


Epoch: 9 | Train Loss: 0.3494 | Train Acc: 0.9215 | Val Loss: 0.3494 | Val Acc: 0.9134


100%|██████████| 47/47 [00:15<00:00,  3.13it/s]


Epoch: 10 | Train Loss: 0.3348 | Train Acc: 0.9254 | Val Loss: 0.3300 | Val Acc: 0.9228


100%|██████████| 47/47 [00:15<00:00,  2.98it/s]


Epoch: 11 | Train Loss: 0.3218 | Train Acc: 0.9286 | Val Loss: 0.3307 | Val Acc: 0.9206


100%|██████████| 47/47 [00:12<00:00,  3.80it/s]


Epoch: 12 | Train Loss: 0.3100 | Train Acc: 0.9315 | Val Loss: 0.3593 | Val Acc: 0.9032


100%|██████████| 47/47 [00:49<00:00,  1.05s/it]


Epoch: 13 | Train Loss: 0.2999 | Train Acc: 0.9331 | Val Loss: 0.3099 | Val Acc: 0.9230


100%|██████████| 47/47 [00:15<00:00,  2.95it/s]


In [None]:
#PLOT LOSS and ACCURACY

plt.figure(figsize=(12,5))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.title('Cloud Mask Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Train Acc')
plt.plot(val_accuracies, label='Val Acc')
plt.title('Cloud Mask Accuracy')
plt.legend()

plt.savefig("./graphs/unet_cloud_mask.png")  

plt.show()

In [None]:
#MODEL EVALUATION

all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)

        all_preds.append(preds.cpu().numpy())
        all_labels.append(labels.cpu().numpy())
        
all_preds = np.concatenate([p.flatten() for p in all_preds])
all_labels = np.concatenate([l.flatten() for l in all_labels])


report = classification_report(all_labels, all_preds, digits=3, output_dict=True)
f1_scores = np.array([report[str(i)]['f1-score'] for i in range(num_classes)])
supports = np.array([report[str(i)]['support'] for i in range(num_classes)])
iou = f1_scores / (2 - f1_scores)

print("REPORT:\n", classification_report(all_labels, all_preds, digits=3))
print("CONFUSION MATRIX:\n", confusion_matrix(all_labels, all_preds))
print("\nIOU:", iou)
print("Unweighted:", np.mean(iou))
print("Weighted:", np.average(iou, weights=supports))