### Recreating the DenseNet experiment described in the [original paper](https://www.researchgate.net/publication/335844699_SEN12MS_-_A_Curated_Dataset_of_Georeferenced_Multi-Spectral_Sentinel-12_Imagery_for_Deep_Learning_and_Data_Fusion)


#### Input
- Full-sized 256x256 Sentinel-2 images from the summer subset
- 10 bands: B2, B3, B4, B8, B5, B6, B7, B8a, B11, and B12.


#### Label
- LCCS land-use maps
- 8 classes instead of 11: classes 20 and 25 are merged, and classes 30, 35, and 36 are merged

#### Training parameters
- Categorical cross-entropy loss
- Adam optimizer with a starting learning rate of 0.0001
- `ReduceLROnPlateau` learning rate scheduler
- Batch size: 4



In [1]:
import pathlib
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchmetrics
import torch.nn as nn
import torchvision.models as models

import pytorch_lightning as pl

import models
import sen12ms_dataLoader as sen12ms

In [2]:
import os

# Root directory of the SEN12MS dataset. Set the SEN12MS_DIR environment
# variable to point at a different copy without editing the notebook.
DATASET_PATH = pathlib.Path(
    os.environ.get('SEN12MS_DIR', '/home/dubrovin/Projects/Data/SEN12MS/')
)
SEASON = sen12ms.Seasons.SUMMER

# Fail fast with a readable error. An `assert` would be silently skipped
# when Python runs with -O.
if not DATASET_PATH.is_dir():
    raise NotADirectoryError(f'Incorrect location for the dataset: {DATASET_PATH}')

In [3]:
class Dataset(torch.utils.data.Dataset):
    """PyTorch wrapper for the dataloader provided by the dataset authors.

    Every item is one patch of the selected season, returned as
    ``(image, label)``:

    * ``image`` -- float32 tensor with the ten Sentinel-2 bands listed in
      ``s2_bands``, band axis first. Each band is standardized to zero mean
      and unit variance within the patch.
    * ``label`` -- int64 tensor of training targets in ``[0, 8)``, built
      from the LCCS land-use layer via ``class_to_target_map``.

    Args:
        base_dir: dataset root; defaults to the module-level ``DATASET_PATH``.
        season: season to load; defaults to the module-level ``SEASON``.
    """

    # LCCS land-use code -> training target. The 11 source classes are
    # reduced to 8: 20 and 25 are merged, and 30, 35 and 36 are merged.
    _LCCS_LANDUSE_TO_TARGET = {
        1: 0,
        2: 1,
        3: 2,
        9: 3,
        10: 4,
        20: 5,
        25: 5,
        30: 6,
        35: 6,
        36: 6,
        40: 7,
    }

    def __init__(self, base_dir=None, season=None):
        # Globals are resolved at call time, so reassigning them later still works.
        self.base_dir = DATASET_PATH if base_dir is None else base_dir
        self.season = SEASON if season is None else season
        self._dataset = sen12ms.SEN12MSDataset(base_dir=self.base_dir)

        # get_season_ids gives {scene_id: patch_ids}. Flatten it into one
        # (scene_id, patch_id) tuple per patch.
        season_ids = self._dataset.get_season_ids(season=self.season)
        self.patch_unique_ids = [
            (scene_id, patch_id)
            for scene_id, patch_ids in season_ids.items()
            for patch_id in patch_ids
        ]

        self.lc_bands = sen12ms.LCBands.landuse
        self.s2_bands = [
            sen12ms.S2Bands.B02,   # blue
            sen12ms.S2Bands.B03,   # green
            sen12ms.S2Bands.B04,   # red
            sen12ms.S2Bands.B08,   # near-infrared
            sen12ms.S2Bands.B05,   # red edge 1
            sen12ms.S2Bands.B06,   # red edge 2
            sen12ms.S2Bands.B07,   # red edge 3
            sen12ms.S2Bands.B08A,  # red edge 4
            sen12ms.S2Bands.B11,   # short-wave infrared 1
            sen12ms.S2Bands.B12    # short-wave infrared 2
        ]

        # Single source of truth for the label remapping used in __getitem__.
        # Each instance gets its own copy so edits do not leak into other datasets.
        self.class_to_target_map = dict(self._LCCS_LANDUSE_TO_TARGET)

    def __len__(self):
        return len(self.patch_unique_ids)

    def __getitem__(self, idx):
        scene, patch = self.patch_unique_ids[idx]
        s2, _ = self._dataset.get_patch(self.season, scene, patch, self.s2_bands)
        lc, _ = self._dataset.get_patch(self.season, scene, patch, self.lc_bands)

        image = self._standardize(s2)
        label = self._encode_labels(lc, self.class_to_target_map)
        return torch.tensor(image), label

    @staticmethod
    def _standardize(s2):
        """Return ``s2`` as float32, with each band (axis 0) scaled to zero mean / unit std.

        Statistics are computed over axes 1 and 2. A constant band has zero
        std; it is divided by 1 instead, giving zeros rather than NaNs that
        would poison the loss.
        """
        mean = s2.mean(axis=(1, 2), keepdims=True)
        std = s2.std(axis=(1, 2), keepdims=True)
        std[std == 0] = 1
        return ((s2 - mean) / std).astype(np.float32)

    @staticmethod
    def _encode_labels(lc, class_to_target):
        """Map a land-use patch to an int64 tensor of training targets.

        ``lc`` may be ``(H, W)`` or carry a leading singleton band axis
        ``(1, H, W)`` (presumably what ``get_patch`` returns for a single
        band; TODO confirm). Only the last two axes are kept.

        NOTE(review): codes not in ``class_to_target`` fall back to target 0,
        which is the same target as LCCS code 1. Consider an ignore index if
        the data contains no-data values.
        """
        codes = np.asarray(lc)
        codes = codes.reshape(codes.shape[-2:])
        target = np.zeros(codes.shape, dtype=np.int64)
        for code, index in class_to_target.items():
            target[codes == code] = index
        return torch.tensor(target)

In [4]:
class FCDenseNet103(pl.LightningModule):
    """Lightning wrapper that trains ``models.FCDenseNet103`` on SEN12MS patches.

    The network takes 10 input bands and predicts 8 classes. It is trained
    with cross-entropy loss and Adam, plus a ReduceLROnPlateau scheduler
    driven by the logged ``system/val_loss``.

    Args:
        lr: initial Adam learning rate.
        batch_size: batch size for both the training and validation loaders.
        num_workers: worker processes per DataLoader.
        val_fraction: fraction of the dataset held out for validation.
        split_seed: seed for the train/validation split, so every ``setup``
            call (and a resumed run) sees the same split. ``None`` gives an
            unseeded split.
    """

    def __init__(self, lr=1e-4, batch_size=4, num_workers=12, val_fraction=0.1, split_seed=0):
        super().__init__()

        self.lr = lr
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_fraction = val_fraction
        self.split_seed = split_seed

        self.densenet = models.FCDenseNet103(in_channels=10, n_classes=8)

        self.criterion = nn.CrossEntropyLoss()
        # Separate metric objects so training and validation never share accumulated state.
        self.accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()

    def forward(self, x):
        x = self.densenet(x)
        return x

    def setup(self, stage=None):
        """Build the dataset and split it into training and validation subsets."""
        dataset = Dataset()
        n_val_examples = int(len(dataset) * self.val_fraction)
        splits = [len(dataset) - n_val_examples, n_val_examples]
        split_kwargs = {}
        if self.split_seed is not None:
            # A fixed generator keeps the split identical across setup calls and resumed runs.
            split_kwargs['generator'] = torch.Generator().manual_seed(self.split_seed)
        self.train_data, self.val_data = torch.utils.data.random_split(dataset, splits, **split_kwargs)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def _shared_step(self, batch, metric):
        """Compute loss and accuracy for one ``(x, y)`` batch, scoring accuracy with ``metric``."""
        x, y = batch
        pred = self(x)
        loss = self.criterion(pred, y)
        # The metric gets probabilities rather than raw logits. Presumably the
        # installed torchmetrics expects float preds in [0, 1]; argmax is unchanged.
        accuracy = metric(pred.softmax(dim=1), y)
        return {'loss': loss, 'accuracy': accuracy}

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, self.accuracy)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, self.val_accuracy)

    def _log_epoch_averages(self, step_outputs, split):
        """Average per-step loss/accuracy, write them to TensorBoard under ``split``, and return them."""
        average_loss = torch.tensor([float(out['loss']) for out in step_outputs]).mean()
        average_accuracy = torch.tensor([float(out['accuracy']) for out in step_outputs]).mean()

        self.logger.experiment.add_scalar(f'Loss/{split}', average_loss, self.current_epoch)
        self.logger.experiment.add_scalar(f'Accuracy/{split}', average_accuracy, self.current_epoch)
        return average_loss, average_accuracy

    def training_epoch_end(self, train_step_outputs):
        self._log_epoch_averages(train_step_outputs, 'train')

    def validation_epoch_end(self, val_step_outputs):
        average_loss, average_accuracy = self._log_epoch_averages(val_step_outputs, 'validation')

        # Logged for ReduceLROnPlateau, EarlyStopping and ModelCheckpoint.
        self.log('system/val_loss', average_loss)
        self.log('system/val_acc', average_accuracy)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        # Lightning reads the scheduler from the 'lr_scheduler' key. The former
        # top-level 'scheduler' key was not recognised, so the LR never decayed.
        return {
            'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'monitor': 'system/val_loss'},
        }

In [5]:
# Stop training once the monitored validation loss stops decreasing
# (patience=4 checks without improvement).
stop_early = pl.callbacks.EarlyStopping(
    mode='min',
    monitor='system/val_loss',
    patience=4,
)

# Save a full checkpoint (not only the weights) whenever validation accuracy
# reaches a new maximum. The metric value is placed in `filename` by hand,
# hence auto_insert_metric_name=False.
checkpoint_acc = pl.callbacks.ModelCheckpoint(
    dirpath='./best_models/',
    filename=r'fcdensenet103_v0_val_acc={system/val_acc:.2f}',
    auto_insert_metric_name=False,
    monitor='system/val_acc',
    mode='max',
    every_n_val_epochs=1,
    save_weights_only=False,
)

In [6]:
model = FCDenseNet103()

# TensorBoard logs are written under save_dir='runs' with the experiment
# name 'fcdensenet'; the default hp_metric placeholder is turned off.
logger = pl.loggers.TensorBoardLogger(
    save_dir='runs',
    name='fcdensenet',
    default_hp_metric=False,
)

trainer = pl.Trainer(
    # NOTE(review): string GPU specs are parsed differently across Lightning
    # versions (device index vs. device count). Confirm which GPU is meant;
    # gpus=1 or gpus=[1] would be unambiguous.
    gpus='1',
    logger=logger,
    callbacks=[stop_early, checkpoint_acc],
    profiler='simple',
    # Skip the validation sanity check Lightning otherwise runs before training.
    num_sanity_val_steps=0,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
trainer.fit(model=model)  # run training + validation with the callbacks configured above