### Recreating the ResNet-110 experiment described in the [original paper](https://www.researchgate.net/publication/335844699_SEN12MS_-_A_Curated_D_of_Georeferenced_Multi-Spectral_Sentinel-12_Imagery_for_Deep_Learning_and_Data_Fusion)


#### Input
- 64 × 64 Sentinel-2 images from the summer subset
- 10 bands: B2, B3, B4, B8, B5, B6, B7, B8a, B11, and B12.


#### Label
- Majority LCCS land use class from each of the 64 × 64 patches
- 8 classes instead of 11: 20 and 25 combined; 30, 35, and 36 combined

#### Training parameters
- Categorical cross-entropy loss
- Adam optimizer with 0.0005 starting learning rate
- ReduceLROnPlateau learning rate scheduler
- Batch size: 16


In [1]:
import pathlib
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchmetrics
import torch.nn as nn
import torchvision.models as models

import pytorch_lightning as pl

import sen12ms_dataLoader as sen12ms
import constants

In [2]:
# Root directory of the SEN12MS dataset on the local machine.
DATASET_PATH = pathlib.Path('/home/dubrovin/Projects/Data/SEN12MS/')
# Season subset used for the experiment (Sentinel-2 summer images).
SEASON = sen12ms.Seasons.SUMMER

# Fail fast with an explicit exception instead of `assert`, which is
# silently stripped when Python runs with the -O flag.
if not DATASET_PATH.is_dir():
    raise NotADirectoryError(f'Incorrect location for the dataset: {DATASET_PATH}')

In [3]:
class Dataset(torch.utils.data.Dataset):
    """PyTorch wrapper for the data loader provided by the dataset authors.

    Every patch of the selected season is cut into a 4 x 4 grid of 64x64
    subpatches and each subpatch is one example. An example consists of the
    per-band standardized 10-band Sentinel-2 image and an integer target
    derived from the most frequent land use class inside the subpatch.
    """

    def __init__(self):
        self._dataset = sen12ms.SEN12MSDataset(base_dir=DATASET_PATH)

        # get a dictionary {scene_id: patch_ids} for the whole season
        season_ids = self._dataset.get_season_ids(season=SEASON)

        # flatten it into a list of tuples unique for each 64x64 patch
        # (scene_id, patch_id, subpatch_idx);
        # there are 16 64x64 patches in 256x256 image
        self.subpatch_unique_ids = [
            (scene_id, patch_id, i)
            for scene_id, patch_ids in season_ids.items()
            for patch_id in patch_ids
            for i in range(16)
        ]

        self.lc_bands = sen12ms.LCBands.landuse
        self.s2_bands = [
            sen12ms.S2Bands.B02,   # blue
            sen12ms.S2Bands.B03,   # green
            sen12ms.S2Bands.B04,   # red
            sen12ms.S2Bands.B08,   # near-infrared
            sen12ms.S2Bands.B05,   # red edge 1
            sen12ms.S2Bands.B06,   # red edge 2
            sen12ms.S2Bands.B07,   # red edge 3
            sen12ms.S2Bands.B08A,  # red edge 4
            sen12ms.S2Bands.B11,   # short-wave infrared 1
            sen12ms.S2Bands.B12,   # short-wave infrared 2
        ]

        # land use class value (after merging 25 -> 20 and 35, 36 -> 30)
        # mapped to a contiguous target index in [0, 7]
        self.class_to_target_map = {
            0: 0,  # turns out, some subpatches [i.e. idx=54361] have mode NODATA, 0
            1: 0,
            2: 1,
            3: 2,
            9: 3,
            10: 4,
            20: 5,
            30: 6,
            40: 7,
        }

    def __len__(self):
        """Return the number of 64x64 subpatches in the season."""
        return len(self.subpatch_unique_ids)

    @staticmethod
    def _standardize(pixels):
        """Standardize every band of a (bands, height, width) array.

        Each band is shifted to zero mean and scaled to unit standard
        deviation. A band that is constant over the whole subpatch has zero
        standard deviation; dividing by it would yield NaN (0 / 0), so such
        bands are divided by 1 instead and come out as all zeros.

        Returns:
            A float32 array with the same shape as *pixels*.
        """
        mean = pixels.mean(axis=(1, 2), keepdims=True)
        std = pixels.std(axis=(1, 2), keepdims=True)
        safe_std = np.where(std == 0, 1.0, std)
        return ((pixels - mean) / safe_std).astype(np.float32)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for the subpatch number *idx*.

        Returns:
            image: float tensor with the standardized Sentinel-2 bands of the
                64x64 subpatch.
            label: scalar ``torch.long`` tensor with the target index of the
                most frequent land use class of the subpatch.

        Raises:
            KeyError: if the most frequent class value is not present in
                ``class_to_target_map``.
        """
        scene, patch, subpatch = self.subpatch_unique_ids[idx]
        s2, _ = self._dataset.get_patch(SEASON, scene, patch, self.s2_bands)
        lc, _ = self._dataset.get_patch(SEASON, scene, patch, self.lc_bands)

        i = subpatch // 4  # row number of the 64x64 subpatch of the 256x256 image
        j = subpatch % 4   # column number of the 64x64 subpatch
        rows = slice(i * 64, (i + 1) * 64)
        cols = slice(j * 64, (j + 1) * 64)
        s2_subpatch = s2[:, rows, cols]
        lc_subpatch = lc[:, rows, cols]

        image = self._standardize(s2_subpatch)

        # combine classes 20 and 25; 30, 35, and 36
        lc_subpatch[lc_subpatch == 25] = 20
        lc_subpatch[lc_subpatch == 35] = 30
        lc_subpatch[lc_subpatch == 36] = 30

        # use the most common value as the label
        values, counts = np.unique(lc_subpatch, return_counts=True)
        mode = int(values[np.argmax(counts)])
        label = self.class_to_target_map[mode]

        return torch.tensor(image), torch.tensor(label, dtype=torch.long)

In [4]:
class ResNet110(pl.LightningModule):
    """ResNet-110 classifier for 10-band inputs with 8 output classes.

    The network is a torchvision ResNet-101 whose first convolution takes
    10 input channels, whose final linear layer has 8 outputs, and whose
    third stage is extended with 3 extra bottleneck blocks. The module also
    builds its own train/validation split and dataloaders.
    """

    def __init__(self):
        super().__init__()

        self.resnet = models.resnet101()
        # first convolution takes the 10 selected bands as input channels
        self.resnet.conv1 = nn.Conv2d(10, 64, 7, 2, 3, bias=False)
        # classification head with one output per target class
        self.resnet.fc = nn.Linear(2048, 8)

        # to transform the ResNet-101 to ResNet-110, add 3 extra bottleneck blocks
        # each bottleneck block adds 3 layers, 101 + 3 * 3 = 110
        for i in range(3):
            self.resnet.layer3.add_module(f'extra_{i}', models.resnet.Bottleneck(1024, 256))

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy()

    def forward(self, x):
        """Return the raw class scores (logits) for the batch *x*."""
        return self.resnet(x)

    def setup(self, stage):
        """Create the dataset and randomly hold out 10% of it for validation."""
        dataset = Dataset()
        n_val_examples = int(len(dataset) * 0.1)
        splits = [len(dataset) - n_val_examples, n_val_examples]
        self.train_data, self.val_data = torch.utils.data.random_split(dataset, splits)

    def train_dataloader(self):
        """Return the shuffled training dataloader (batch size 16)."""
        return torch.utils.data.DataLoader(
            self.train_data, batch_size=16, shuffle=True, num_workers=12, pin_memory=True
        )

    def val_dataloader(self):
        """Return the validation dataloader (batch size 16, not shuffled)."""
        return torch.utils.data.DataLoader(
            self.val_data, batch_size=16, num_workers=12, pin_memory=True
        )

    def training_step(self, batch, batch_idx):
        """Compute the loss and accuracy of one batch.

        Returns:
            A dict with the keys ``'loss'`` and ``'accuracy'``.
        """
        x, y = batch
        pred = self(x)
        loss = self.criterion(pred, y)
        accuracy = self.accuracy(pred.softmax(dim=1), y)

        return {
            'loss': loss,
            'accuracy': accuracy,
        }

    @staticmethod
    def _epoch_mean(step_outputs, key):
        """Average the scalar stored under *key* over all step outputs of an epoch."""
        return torch.tensor([step[key] for step in step_outputs]).mean()

    def training_epoch_end(self, train_step_outputs):
        """Log the epoch-average training loss and accuracy to TensorBoard."""
        average_loss = self._epoch_mean(train_step_outputs, 'loss')
        average_accuracy = self._epoch_mean(train_step_outputs, 'accuracy')

        # log to TensorBoard
        self.logger.experiment.add_scalar('Loss/train', average_loss, self.current_epoch)
        self.logger.experiment.add_scalar('Accuracy/train', average_accuracy, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        """Compute the same loss/accuracy dict as ``training_step`` for a validation batch."""
        return self.training_step(batch, batch_idx)

    def validation_epoch_end(self, val_step_outputs):
        """Log the epoch-average validation loss and accuracy."""
        average_loss = self._epoch_mean(val_step_outputs, 'loss')
        average_accuracy = self._epoch_mean(val_step_outputs, 'accuracy')

        # log to TensorBoard
        self.logger.experiment.add_scalar('Loss/validation', average_loss, self.current_epoch)
        self.logger.experiment.add_scalar('Accuracy/validation', average_accuracy, self.current_epoch)

        # log to the system for ReduceLROnPlateau and EarlyStopping / ModelCheckpoint
        self.log('system/val_loss', average_loss)
        self.log('system/val_acc', average_accuracy)

    def configure_optimizers(self):
        """Return Adam (lr=0.0005) together with a ReduceLROnPlateau scheduler.

        The scheduler must be passed under the ``'lr_scheduler'`` key: a
        top-level ``'scheduler'`` key is not recognized by Lightning, which
        would leave the learning rate constant for the whole run.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0005)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                # metric that ReduceLROnPlateau reacts to
                'monitor': 'system/val_loss',
            },
        }

In [5]:
# Early stopping driven by the validation loss logged by the model (lower is better).
stop_early = pl.callbacks.EarlyStopping(
    mode='min',
    patience=4,
    monitor='system/val_loss',
)

# Checkpointing driven by the validation accuracy logged by the model (higher is better).
checkpoint_acc = pl.callbacks.ModelCheckpoint(
    dirpath='./best_models/',
    # the metric name contains a slash, so it is formatted explicitly here
    # and automatic insertion of the metric name is switched off
    filename=r'resnet110_v0_val_acc={system/val_acc:.2f}',
    auto_insert_metric_name=False,
    monitor='system/val_acc',
    mode='max',
    every_n_val_epochs=1,
    save_weights_only=False,
)

In [6]:
# TensorBoard logs are written under runs/resnet110.
logger = pl.loggers.TensorBoardLogger(
    save_dir='runs',
    name='resnet110',
    default_hp_metric=False,
)
model = ResNet110()

# Single-GPU trainer with early stopping and best-accuracy checkpointing.
trainer = pl.Trainer(
    callbacks=[stop_early, checkpoint_acc],
    gpus=1,
    logger=logger,
    num_sanity_val_steps=0,
    profiler='simple',
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
# Run the training loop; the model provides its own train/validation dataloaders.
trainer.fit(model)