### Recreating the ResNet-110 experiment described in the [original paper](https://www.researchgate.net/publication/335844699_SEN12MS_-_A_Curated_D_of_Georeferenced_Multi-Spectral_Sentinel-12_Imagery_for_Deep_Learning_and_Data_Fusion)

### *In this experiment, we add two denoised bands from Sentinel-1 (DVV and DVH)*

#### Input
- 64 × 64 Sentinel-2 and Sentinel-1 images from the summer subset
- 12 bands: B2, B3, B4, B8, DVV, DVH, B5, B6, B7, B8a, B11, and B12.
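
A minimal sketch of how the 12-channel input could be assembled. It assumes three things not stated above: the Sentinel-2 patch is a 13-band array in the usual SEN12MS order (B1 to B12, with B8a after B8), the two denoised Sentinel-1 bands (DVV, DVH) are already available as a 2-band array, and each 256 × 256 SEN12MS patch is cut into 16 non-overlapping 64 × 64 subpatches. The helpers `stack_bands` and `tile_patch` are hypothetical; `SEN12MSDataset_64x64subpatches` in the repository may do this differently.

```python
import numpy as np

# Assumed SEN12MS Sentinel-2 band order (13 bands)
S2_BANDS = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8a', 'B9', 'B10', 'B11', 'B12']
# Assumed band order of the already-denoised Sentinel-1 array
S1_BANDS = ['DVV', 'DVH']
# Channel order used in this experiment
INPUT_BANDS = ['B2', 'B3', 'B4', 'B8', 'DVV', 'DVH', 'B5', 'B6', 'B7', 'B8a', 'B11', 'B12']


def stack_bands(s2, s1):
    """Hypothetical helper: build a (12, H, W) float32 array in INPUT_BANDS
    order from a (13, H, W) Sentinel-2 array and a (2, H, W) Sentinel-1 array."""
    lookup = {name: s2[i] for i, name in enumerate(S2_BANDS)}
    lookup.update({name: s1[i] for i, name in enumerate(S1_BANDS)})
    return np.stack([lookup[name] for name in INPUT_BANDS]).astype(np.float32)


def tile_patch(patch, size=64):
    """Hypothetical helper: split a (C, H, W) patch into non-overlapping
    (C, size, size) tiles, row by row; returns shape (N, C, size, size)."""
    c, h, w = patch.shape
    tiles = patch.reshape(c, h // size, size, w // size, size)
    # (C, rows, size, cols, size) -> (rows, cols, C, size, size) -> (N, C, size, size)
    return tiles.transpose(1, 3, 0, 2, 4).reshape(-1, c, size, size)


rng = np.random.default_rng(0)
s2 = rng.random((13, 256, 256))
s1 = rng.random((2, 256, 256))
x = stack_bands(s2, s1)    # shape (12, 256, 256)
tiles = tile_patch(x)      # shape (16, 12, 64, 64)
```

Looking bands up by name rather than by hard-coded indices keeps the interleaved S2/S1 channel order above explicit and easy to change.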


#### Label
- Majority LCCS land use class of each 64 × 64 patch
- 8 classes instead of 11: classes 20 and 25 are merged into one, as are classes 30, 35, and 36
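
A minimal sketch of the label construction. It assumes the 11 LCCS land use codes are 1, 2, 3, 9, 10, 20, 25, 30, 35, 36, and 40 (the MODIS MCD12Q1 LCCS land use codes that SEN12MS is derived from) and that any other value, such as 255 for unclassified pixels, is ignored. Merging 20 with 25, and 30 with 35 and 36, leaves 8 classes indexed 0 to 7. `LCCS_TO_INDEX` and `majority_label` are hypothetical names.

```python
import numpy as np

# Assumed LCCS land use codes (11 classes); merged codes share one index
LCCS_TO_INDEX = {
    1: 0, 2: 1, 3: 2, 9: 3, 10: 4,
    20: 5, 25: 5,         # 20 and 25 combined
    30: 6, 35: 6, 36: 6,  # 30, 35 and 36 combined
    40: 7,
}
N_CLASSES = 8

# Lookup table from 8-bit LCCS code to merged index (-1 = not a land use class)
LUT = np.full(256, -1, dtype=np.int64)
for code, index in LCCS_TO_INDEX.items():
    LUT[code] = index


def majority_label(lccs_patch):
    """Hypothetical helper: index (0..7) of the most frequent merged class in a
    2-D patch of 8-bit LCCS codes. Unknown codes (e.g. 255) are ignored; ties go
    to the lower index, and a patch with no valid pixels returns 0."""
    indices = LUT[np.asarray(lccs_patch, dtype=np.uint8).ravel()]
    indices = indices[indices >= 0]
    counts = np.bincount(indices, minlength=N_CLASSES)
    return int(counts.argmax())


# 40 pixels of code 35 vs. 30 + 30 pixels of codes 20 and 25
patch = np.array([35] * 40 + [20] * 30 + [25] * 30).reshape(10, 10)
label = majority_label(patch)  # 5: merged 20/25 (60 px) beats merged 30/35/36 (40 px)
```

The example shows why the merge is applied before the majority vote: without merging, code 35 (40 pixels) would win.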

#### Training parameters
- Categorical cross-entropy loss
- Adam optimizer with an initial learning rate of 0.0005
- ReduceLROnPlateau learning rate scheduler
- Batch size: 16
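
The optimizer and scheduler settings can be checked in isolation with plain PyTorch. This sketch uses `ReduceLROnPlateau` with its default arguments (`mode='min'`, `factor=0.1`, `patience=10`), matching the `configure_optimizers` method below, and feeds it a validation loss that has stopped improving. The dummy parameter and the constant loss value are illustrative only.

```python
import torch

# A single dummy parameter stands in for the network weights
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.0005)
# Default settings: mode='min', factor=0.1, patience=10
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

lrs = []
for epoch in range(12):
    optimizer.step()           # no gradients here; keeps the usual call order
    val_loss = 1.0             # validation loss that has plateaued
    scheduler.step(val_loss)   # called once per epoch with the monitored metric
    lrs.append(optimizer.param_groups[0]['lr'])

# The LR stays at 0.0005 for 11 calls and drops tenfold on the 12th
print(lrs[0], lrs[-2], lrs[-1])
```

With these defaults the learning rate only drops after 11 consecutive epochs without improvement, whereas the `EarlyStopping` callback below uses `patience=4`, so training will usually stop before any reduction happens. Passing a smaller `patience` to `ReduceLROnPlateau` is one option if the schedule is meant to take effect.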


```python
import pathlib
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchmetrics
import torch.nn as nn
import torchvision.models as models

import pytorch_lightning as pl

from utils import sen12ms_dataLoader as sen12ms
from utils import SEN12MSDataset_64x64subpatches
```

```python
DATASET_PATH = pathlib.Path('/home/dubrovin/Projects/Data/SEN12MS/')
SEASON = sen12ms.Seasons.SUMMER

assert DATASET_PATH.is_dir(), 'Incorrect location for the dataset'
```

```python
class ResNet110(pl.LightningModule):

    def __init__(self):
        super().__init__()

        # ResNet-101 backbone, randomly initialised
        self.resnet = models.resnet101()
        # the first convolution takes the 12 input bands instead of 3 RGB channels
        self.resnet.conv1 = nn.Conv2d(12, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # classification head for the 8 merged LCCS land use classes
        self.resnet.fc = nn.Linear(2048, 8)

        # to transform ResNet-101 into ResNet-110, add 3 extra bottleneck blocks to layer3;
        # each bottleneck block adds 3 convolutional layers: 101 + 3 * 3 = 110
        for i in range(3):
            self.resnet.layer3.add_module(f'extra_{i}', models.resnet.Bottleneck(1024, 256))

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy()

    def forward(self, x):
        return self.resnet(x)

    def setup(self, stage):
        # hold out 10% of the patches for validation
        dataset = SEN12MSDataset_64x64subpatches(DATASET_PATH, SEASON)
        n_val_examples = int(len(dataset) * 0.1)
        splits = [len(dataset) - n_val_examples, n_val_examples]
        self.train_data, self.val_data = torch.utils.data.random_split(dataset, splits)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_data, batch_size=16, shuffle=True, num_workers=12, pin_memory=True
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_data, batch_size=16, num_workers=12, pin_memory=True
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        # x = x[:, 2:]
        pred = self(x)
        loss = self.criterion(pred, y)
        accuracy = self.accuracy(pred.softmax(dim=1), y)
        return {'loss': loss, 'accuracy': accuracy}

    def training_epoch_end(self, train_step_outputs):
        average_loss = torch.stack([x['loss'] for x in train_step_outputs]).mean()
        average_accuracy = torch.stack([x['accuracy'] for x in train_step_outputs]).mean()

        # log to TensorBoard
        self.logger.experiment.add_scalar('Loss/train', average_loss, self.current_epoch)
        self.logger.experiment.add_scalar('Accuracy/train', average_accuracy, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        # same computation as a training step; Lightning disables gradients during validation
        return self.training_step(batch, batch_idx)

    def validation_epoch_end(self, val_step_outputs):
        average_loss = torch.stack([x['loss'] for x in val_step_outputs]).mean()
        average_accuracy = torch.stack([x['accuracy'] for x in val_step_outputs]).mean()

        # log to TensorBoard
        self.logger.experiment.add_scalar('Loss/validation', average_loss, self.current_epoch)
        self.logger.experiment.add_scalar('Accuracy/validation', average_accuracy, self.current_epoch)

        # log to Lightning so that ReduceLROnPlateau, EarlyStopping and ModelCheckpoint can monitor them
        self.log('system/val_loss', average_loss)
        self.log('system/val_acc', average_accuracy)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0005)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        # Lightning reads the scheduler from the 'lr_scheduler' key;
        # ReduceLROnPlateau also needs the name of the metric to monitor
        return {
            'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'monitor': 'system/val_loss'},
        }
```

```python
stop_early = pl.callbacks.EarlyStopping(
    monitor='system/val_loss',
    patience=4,
    mode='min',
)

checkpoint_acc = pl.callbacks.ModelCheckpoint(
    monitor='system/val_acc',
    mode='max',
    every_n_val_epochs=1,
    dirpath='./best_models/',
    filename=r'resnet110_den_v0_val_acc={system/val_acc:.2f}',
    auto_insert_metric_name=False,
    save_weights_only=False,
)
```
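
The early-stopping rule configured above can be mimicked in a few lines. This is a simplified sketch of what `EarlyStopping(monitor='system/val_loss', patience=4, mode='min')` does with its default `min_delta=0.0`: count consecutive validation checks without a new minimum and stop once that count reaches `patience`. The function `stop_epoch` and the loss values are hypothetical.

```python
from typing import Optional, Sequence


def stop_epoch(val_losses: Sequence[float], patience: int = 4,
               min_delta: float = 0.0) -> Optional[int]:
    """Return the index of the epoch after which training would stop,
    or None if the patience is never exhausted (mode='min')."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss - min_delta < best:   # strict improvement resets the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


losses = [1.00, 0.90, 0.95, 0.92, 0.91, 0.93, 0.80]
print(stop_epoch(losses))  # 5: epochs 2 to 5 fail to beat 0.90
```

In this example the improvement to 0.80 at epoch 6 is never seen, which is the trade-off a small `patience` makes.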

```python
model = ResNet110()
logger = pl.loggers.TensorBoardLogger('runs', 'resnet110dn', default_hp_metric=False)

trainer = pl.Trainer(
    logger=logger,
    gpus=1,
    callbacks=[stop_early, checkpoint_acc],
    num_sanity_val_steps=0,
    profiler='simple',
    fast_dev_run=True,
)
```

```
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Running in fast_dev_run mode: will run a full train, val and test loop using 1 batch(es).
```


```python
trainer.fit(model)
```

```
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type             | Params
-----------------------------------------------
0 | resnet    | ResNet           | 45.9 M
1 | criterion | CrossEntropyLoss | 0     
2 | accuracy  | Accuracy         | 0     
-----------------------------------------------
45.9 M    Trainable params
0         Non-trainable params
45.9 M    Total params
183.560   Total estimated model params size (MB)

Epoch 0:  50%|█████     | 1/2 [00:01<00:01,  1.45s/it, loss=2.34, v_num=]
Validating: 0it [00:00, ?it/s]
Validating:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 0: 100%|██████████| 2/2 [00:02<00:00,  1.45s/it, loss=2.34, v_num=]

FIT Profiler Report

Action                          |  Mean duration (s)  |  Num calls  |  Total time (s)  |  Percentage %  |
---------------------------------------------------------------------------------------------------------
Total                           |  -                  |  _          |  4.7329          |  100 %         |
---------------------------------------------------------------------------------------------------------
run_training_epoch              |  2.9059             |  1          |  2.9059          |  61.399        |
get_train_batch                 |  0.84596            |  1          |  0.84596         |  17.874        |
run_training_batch              |  0.60089            |  1          |  0.60089         |  12.696        |
optimizer_step_and_closure_0    |  0.58728            |  1          |  0.58728         |  12.409        |
training_step_and_backward
```


