In [None]:
import sys
import os
import numpy as np
import time
import torch
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from geotorch.models.raster import DeepSatV2
from geotorch.datasets.raster import EuroSAT

In [None]:
## Define parameters
epoch_nums = 100
learning_rate = 0.0002
batch_size = 16
validation_split = 0.2
shuffle_dataset = True
## Fixed seed: a time-based seed gave a different train/validation split on
## every run, so results could not be reproduced or compared between runs.
random_seed = 42
## Seed torch as well so torch-driven randomness (e.g. the random samplers and
## weight initialisation) is repeatable too.
torch.manual_seed(random_seed)
## 'shuffle' stays False: the data loaders are built with an explicit sampler,
## and DataLoader does not accept shuffle=True together with a sampler.
params = {'batch_size': batch_size, 'shuffle': False}

## make sure that PATH_TO_DATASET exists in the running directory
PATH_TO_DATASET = "data/eurosat"
## MODEL_SAVE_PATH is a directory (created here if missing), not a file name
MODEL_SAVE_PATH = "model-deepsatv2"
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

In [None]:
## load data and calculate mean and std to perform normalization transform
## Set download=True if dataset is not available in the given path
fullData = EuroSAT(root = PATH_TO_DATASET, download=False)

full_loader = DataLoader(fullData, batch_size= batch_size)
## Accumulate per-channel statistics weighted by the number of samples in each
## batch. The last batch can be smaller than batch_size, so averaging the
## per-batch means with equal weight would bias both mean and std.
channels_sum, channels_squared_sum, num_batches, num_samples = 0, 0, 0, 0
for data_temp, _ in full_loader:
    batch_len = data_temp.size(0)
    ## mean over (batch, height, width) leaves one value per channel
    channels_sum += torch.mean(data_temp, dim=[0, 2, 3]) * batch_len
    channels_squared_sum += torch.mean(data_temp**2, dim=[0, 2, 3]) * batch_len
    num_batches += 1
    num_samples += batch_len

## E[x] per channel over the whole dataset
mean = channels_sum / num_samples
## std = sqrt(E[x^2] - E[x]^2)
std = (channels_squared_sum / num_samples - mean ** 2) ** 0.5

In [None]:
## Build the per-channel normalization from the mean/std computed earlier
sat_transform = transforms.Normalize(mean=mean, std=std)
## Re-create the dataset so that every sample is normalized and the
## additional handcrafted features are included
fullData = EuroSAT(
    root=PATH_TO_DATASET,
    transform=sat_transform,
    include_additional_features=True,
)

In [None]:
## Partition the dataset positions: the first `split` entries of the
## (optionally shuffled) index list go to validation, the rest to training
dataset_size = len(fullData)
split = int(np.floor(dataset_size * validation_split))
indices = [*range(dataset_size)]
if shuffle_dataset:
    ## seed numpy first so the permutation is determined by random_seed
    np.random.seed(random_seed)
    np.random.shuffle(indices)
val_indices = indices[:split]
train_indices = indices[split:]

In [None]:
## One random sampler per index subset (training first, then validation)
train_sampler, valid_sampler = (
    SubsetRandomSampler(subset) for subset in (train_indices, val_indices)
)

## Both loaders read from the same dataset object; the sampler decides which
## samples each loader is allowed to draw
train_loader = DataLoader(fullData, sampler=train_sampler, **params)
val_loader = DataLoader(fullData, sampler=valid_sampler, **params)

In [None]:
## Use the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
## Define Model
## NOTE(review): 13, 64, 64, 10 look like channels, height, width and class
## count for EuroSAT — confirm against the DeepSatV2 signature
num_additional_features = len(fullData.ADDITIONAL_FEATURES)
model = DeepSatV2(13, 64, 64, 10, num_additional_features)
## Define loss and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
## Move the model and the loss module to the selected device
model.to(device)
loss_fn.to(device)

In [None]:
## Before starting training, define a method to calculate validation accuracy
def get_validation_accuracy(model, data_loader, device):
    """Return the classification accuracy (in percent) of `model` on `data_loader`.

    Args:
        model: module called as ``model(inputs, features)`` that returns one
            score per class for every sample.
        data_loader: iterable yielding ``(inputs, labels, features)`` batches.
        device: torch device the batches are moved to before the forward pass.

    Returns:
        Accuracy as a float in ``[0, 100]``; ``0.0`` if the loader yields no samples.

    The model is put in eval mode for the duration of the call and its previous
    train/eval mode is restored afterwards, so calling this between training
    epochs does not leave the model stuck in eval mode.
    """
    was_training = model.training
    model.eval()
    total_sample = 0
    correct = 0
    try:
        # No gradients are needed for evaluation; skipping them saves memory and time
        with torch.no_grad():
            for inputs, labels, features in data_loader:
                inputs = inputs.to(device)
                features = features.type(torch.FloatTensor).to(device)
                labels = labels.to(device)

                outputs = model(inputs, features)
                total_sample += len(labels)

                # index of the highest score along dim 1 is the predicted class
                _, predicted = outputs.max(1)
                correct += predicted.eq(labels).sum().item()
    finally:
        # Restore whatever mode the caller had, even if evaluation raised
        model.train(was_training)

    if total_sample == 0:
        # Empty loader: avoid ZeroDivisionError
        return 0.0
    accuracy = 100 * correct / total_sample
    return accuracy

In [None]:
## Perform training and validation
max_val_accuracy = None
## MODEL_SAVE_PATH is a directory, so the checkpoint needs a file name inside it;
## passing the directory itself to torch.save fails when the file is opened
best_model_file = os.path.join(MODEL_SAVE_PATH, "model.best.pth")
for e in range(epoch_nums):
    ## Validation switches the model to eval mode; make sure every epoch
    ## trains with training-mode behaviour
    model.train()
    for i, sample in enumerate(train_loader):
        inputs, labels, features = sample
        inputs = inputs.to(device)
        features = features.type(torch.FloatTensor).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(inputs, features)
        loss = loss_fn(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    ## The reported loss is that of the last batch of the epoch
    print('Epoch [{}/{}], Training Loss: {:.4f}'.format(e + 1, epoch_nums, loss.item()))

    ## Perform model validation after finishing each epoch training
    val_accuracy = get_validation_accuracy(model, val_loader, device)
    print("Validation Accuracy: ", val_accuracy, "%")

    ## Keep only the weights of the best-scoring epoch so far
    if max_val_accuracy is None or val_accuracy > max_val_accuracy:
        max_val_accuracy = val_accuracy
        torch.save(model.state_dict(), best_model_file)
        print('Best model saved!')