# This example shows the segmentation of Cloud38 satellite images using the UNet deep learning model.

Find the details of the UNet model in the <a href="https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28">corresponding paper</a>.

Find the details of the dataset <a href="https://www.kaggle.com/datasets/sorour/38cloud-cloud-segmentation-in-satellite-images">here</a>

### Import Modules and Define Parameters

In [None]:
import sys
import os
import numpy as np
import time
import torch
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from geotorchai.models.raster import UNet
from geotorchai.datasets.raster import Cloud38

In [None]:
## Training hyper-parameters
epoch_nums = 100
learning_rate = 0.0002
batch_size = 8

## Fraction of the dataset held out for validation and how the split is drawn
validation_split = 0.2
shuffle_dataset = True
## Seed taken from the wall clock, so the split differs from run to run
random_seed = int(time.time())

## Keyword arguments shared by the DataLoaders; shuffle stays False because
## DataLoader does not accept shuffle=True together with an explicit sampler
params = dict(batch_size=batch_size, shuffle=False)

## Dataset root directory; it must exist relative to the running directory
PATH_TO_DATASET = "data/cloud38"
## Directory for saved model weights, created here if it is missing
MODEL_SAVE_PATH = "model-unet"
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

### Load Data and Add Normalization Transformation

In [None]:
## Load the untransformed data and calculate per-channel mean and std for the
## normalization transform
## Set download=True if dataset is not available in the given path
fullData = Cloud38(root=PATH_TO_DATASET)

full_loader = DataLoader(fullData, batch_size=batch_size)
channels_sum, channels_squared_sum, num_batches, num_samples = 0, 0, 0, 0
for data_temp, _ in full_loader:
    ## Weight each batch by its sample count so a smaller final batch does not
    ## skew the statistics (assumes all images share the same height/width
    ## -- TODO confirm for this dataset)
    samples_in_batch = data_temp.size(0)
    channels_sum += torch.mean(data_temp, dim=[0, 2, 3]) * samples_in_batch
    channels_squared_sum += torch.mean(data_temp ** 2, dim=[0, 2, 3]) * samples_in_batch
    num_batches += 1
    num_samples += samples_in_batch

if num_samples == 0:
    raise ValueError("Cannot compute normalization statistics: dataset is empty")

mean = channels_sum / num_samples
## Var[x] = E[x^2] - E[x]^2; clamp guards against a tiny negative value caused
## by floating-point rounding, which would otherwise give a NaN std
std = (channels_squared_sum / num_samples - mean ** 2).clamp(min=0) ** 0.5

In [None]:
## Per-channel normalization built from the mean and std computed above
sat_transform = transforms.Normalize(mean=mean, std=std)
## Reload the dataset with the normalization transform attached
fullData = Cloud38(root=PATH_TO_DATASET, transform=sat_transform)

### Split Dataset into Train and Validation

In [None]:
## Build the list of sample indices and carve out the validation portion
dataset_size = len(fullData)
split = int(np.floor(validation_split * dataset_size))
indices = list(range(dataset_size))
if shuffle_dataset:
    ## Seed NumPy's global RNG so the shuffle is determined by random_seed
    np.random.seed(random_seed)
    np.random.shuffle(indices)
## The first `split` indices go to validation, the remainder to training
val_indices = indices[:split]
train_indices = indices[split:]

In [None]:
## Each sampler draws only from its own subset of indices, in random order
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

## Both loaders wrap the same dataset; the sampler decides which samples each one sees
train_loader = DataLoader(fullData, sampler=train_sampler, **params)
val_loader = DataLoader(fullData, sampler=valid_sampler, **params)

### Initialize Model and Hyperparameters

In [None]:
## Pick the best available backend: CUDA first, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    device_name = "cuda"
elif torch.backends.mps.is_available():
    device_name = "mps"
else:
    device_name = "cpu"
device = torch.device(device_name)

## Define Model
## presumably 4 input bands and 2 output classes -- confirm against the UNet signature
model = UNet(4, 2)

## Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

## Move the model and the loss module to the selected device
model.to(device)
loss_fn.to(device)

### Method for Returning Validation Accuracy

In [None]:
## Before starting training, define a method to calculate validation accuracy
def get_validation_accuracy(model, data_loader, device):
    """Return the accuracy of ``model`` over ``data_loader`` as a percentage.

    For every batch the predicted class is the argmax of the model output
    along dimension 1; it is compared element-wise with the labels and the
    batch mean is weighted by the number of samples in the batch.

    The model is evaluated in eval mode with gradient tracking disabled, and
    its previous train/eval mode is restored before returning so that a
    surrounding training loop is not silently left in eval mode.

    Args:
        model: ``torch.nn.Module`` to evaluate.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Accuracy in percent (0-100), or ``0.0`` if the loader yields no
        samples.
    """
    was_training = model.training
    model.eval()
    total_sample = 0
    running_acc = 0.0
    try:
        # No gradients are needed for evaluation; skipping autograd
        # bookkeeping saves memory and time.
        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                predicted = outputs.argmax(dim=1)
                batch_len = len(labels)
                running_acc += (predicted == labels).float().mean().item() * batch_len
                total_sample += batch_len
    finally:
        # Put the model back into whichever mode the caller had it in
        model.train(was_training)

    if total_sample == 0:
        # Empty loader: avoid dividing by zero
        return 0.0

    accuracy = 100 * running_acc / total_sample
    return accuracy

### Train and Evaluate Model

In [None]:
## Perform training and validation
max_val_accuracy = None
## torch.save needs a file path, while MODEL_SAVE_PATH was created as a
## directory, so the best weights go into a file inside that directory
best_model_file = os.path.join(MODEL_SAVE_PATH, "model.pth")
for e in range(epoch_nums):
    ## Validation switches the model to eval mode; make sure every epoch
    ## is trained with the model in training mode
    model.train()
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    ## The reported loss is that of the last batch of the epoch
    print('Epoch [{}/{}], Training Loss: {:.4f}'.format(e + 1, epoch_nums, loss.item()))

    ## Perform model validation after finishing each epoch training
    val_accuracy = get_validation_accuracy(model, val_loader, device)
    print("Validation Accuracy: ", val_accuracy, "%")

    ## Keep only the weights with the best validation accuracy seen so far
    if max_val_accuracy is None or val_accuracy > max_val_accuracy:
        max_val_accuracy = val_accuracy
        torch.save(model.state_dict(), best_model_file)
        print('Best model saved!')