<a href="https://colab.research.google.com/github/rshwndsz/templates/blob/master/pl-colab/template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PL_TEMPLATE

## 0a. Setup

In [None]:
%%shell
# Install the third-party packages this notebook needs.
# stdout and stderr are both discarded, so installation failures are silent.
pip install pytorch-lightning > /dev/null 2>&1
pip install neptune-client    > /dev/null 2>&1
pip install torchmetrics      > /dev/null 2>&1

In [None]:
# STL
import getpass
import glob
import logging
import math
import os
import random
import shutil
import tarfile
from pathlib import Path
from zipfile import ZipFile

# Numerical Python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Image processing
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Deep Learning
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data as D
import torchvision as tv
import pytorch_lightning as pl

# Bells & Whistles
from sklearn.model_selection import train_test_split
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.callbacks import (ModelCheckpoint,
                                         EarlyStopping)

# Misc
from tqdm.notebook import tqdm
import gdown
import requests

In [None]:
# Set up logging to file
# https://stackoverflow.com/a/23681578
# NOTE: `basicConfig` does nothing when the root logger already has handlers,
# so re-running this cell never adds a second file handler.
logging.basicConfig(
    filename='LOG.log',
    level=logging.INFO,
    format='[%(asctime)s] %(levelname)8s - %(funcName)8s() - %(message)s',
    datefmt='%H:%M:%S'
)

# Set up logging to console
root_logger = logging.getLogger('')

# Notebook cells get re-run: drop the console handler installed by a previous
# run of this cell, otherwise every record is printed once per run.
for _handler in list(root_logger.handlers):
    if _handler.get_name() == 'pl_template_console':
        root_logger.removeHandler(_handler)

console = logging.StreamHandler()
console.set_name('pl_template_console')
console.setLevel(logging.DEBUG)

# Set a format which is simpler for console use
formatter = logging.Formatter('%(levelname)8s - %(funcName)14s() - %(message)s')
console.setFormatter(formatter)

# Add the handler to the root logger
root_logger.addHandler(console)
logger = logging.getLogger(__name__)

# Test drive the logger
logger.info(f"""
            Torch: {torch.__version__}
            Torchvision: {tv.__version__}
            Pytorch Lightning: {pl.__version__} 
            Albumentations: {A.__version__}
            """)

## 0b. Utils

In [None]:
def download_file(url, 
                  destination_dir='./', 
                  desc=None, 
                  force=False):
    """Download a file from any url using requests.

    The local file name is the last ``/``-separated component of `url`.

    Args:
        url: Address of the file to download.
        destination_dir: Directory to save into; created if it is missing.
        desc: Progress-bar label. Defaults to ``"Downloading <file name>"``.
        force: Download again even if the file is already present locally.

    Returns:
        `pathlib.Path` of the downloaded (or already present) file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    destination_dir = Path(destination_dir)
    # Get filename from url
    fname = url.split('/')[-1]
    # Construct path to file in local machine
    local_filepath = destination_dir / fname

    if local_filepath.is_file() and not force:
        logger.info("File(s) already downloaded. Use force=True to download again.")
        return local_filepath

    # Safely create nested directory - https://stackoverflow.com/a/273227
    destination_dir.mkdir(parents=True, exist_ok=True)

    if desc is None:
        desc = f"Downloading {fname}"

    # Stream into a temporary sibling file and rename only on success, so an
    # interrupted download is never mistaken for a complete one by the
    # "already downloaded" check above.
    partial_filepath = local_filepath.with_name(fname + '.part')

    # Download large file with requests - https://stackoverflow.com/a/16696317
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        total_size_in_bytes = int(r.headers.get('content-length', 0))
        block_size          = 1024
        # Progress bar for downloading file - https://stackoverflow.com/a/37573701
        # Context managers close both the bar and the file even on errors.
        with tqdm(total=total_size_in_bytes,
                  unit='iB',
                  unit_scale=True,
                  desc=desc) as pbar, open(partial_filepath, 'wb') as f:
            for data in r.iter_content(block_size):
                pbar.update(len(data))
                f.write(data)

    partial_filepath.replace(local_filepath)
    return local_filepath

In [None]:
def extract_file(fname, 
                 ftype=None, 
                 destination_dir="./", 
                 desc=None, 
                 remove_extract=False):
    """Extract a tar or zip archive into `destination_dir`.

    Args:
        fname: Path to the archive.
        ftype: Archive type given as a suffix (``'.tar'``, ``'.zip'``, or a
            compressed-tar suffix such as ``'.gz'`` / ``'.tgz'``). Defaults to
            the suffix of `fname`.
        destination_dir: Directory to extract into; created if it is missing.
        desc: Progress-bar label. Defaults to ``"Extracting <file name>"``.
        remove_extract: Delete the archive after a successful extraction.

    Raises:
        IOError: If `fname` is not a file or the archive type is unsupported.

    NOTE: members are extracted as stored in the archive; only use this on
    archives from trusted sources.
    """
    # Convert to pathlib objects
    fname = Path(fname)
    destination_dir = Path(destination_dir)

    # Check arguments
    if not fname.is_file():
        raise IOError(f"The file {str(fname)} does not exist.")

    # Safely create nested directory - https://stackoverflow.com/a/273227
    destination_dir.mkdir(parents=True, exist_ok=True)

    if desc is None:
        desc = f"Extracting {str(fname.name)}"

    # Get type of extract
    if ftype is None:
        ftype = fname.suffix

    # `tarfile.open` detects compression transparently, so compressed tarballs
    # (e.g. data.tar.gz, whose suffix is '.gz') are accepted as long as the
    # content really is a tar archive.
    is_compressed_tar = (ftype in ('.tgz', '.gz', '.bz2', '.xz')
                         and tarfile.is_tarfile(fname))

    # Extract the dataset into `destination_dir`
    if ftype == '.tar' or is_compressed_tar:
        with tarfile.open(fname) as tar:
            members = tar.getmembers()
            # Extract files with progress bar - https://stackoverflow.com/a/53405055
            for member in tqdm(iterable=members, total=len(members), desc=desc):
                tar.extract(member=member, path=destination_dir)

    elif ftype == '.zip':
        # https://stackoverflow.com/a/56970565
        with ZipFile(fname, 'r') as archive:
            for member in tqdm(archive.infolist(), desc=desc):
                archive.extract(member, destination_dir)

    else:
        raise IOError(f"The suffix: {ftype} is not supported.")

    if remove_extract:
        # Delete the compressed dataset
        os.remove(fname)

In [None]:
def make_grid(tensors, 
              nrow=2, 
              padding=2, 
              isNormalized=True):
    """Convert a batch tensor of images into a numpy image grid.

    Args:
        tensors: Tensor of images (presumably shaped (B, C, H, W) - confirm
            against callers). It is detached and moved to the CPU first.
        nrow: Forwarded to `torchvision.utils.make_grid`.
        padding: Forwarded to `torchvision.utils.make_grid`.
        isNormalized: True when `tensors` already lies in [0, 1]. When False,
            torchvision min-max rescales the grid to [0, 1].

    Returns:
        Channel-last `np.uint16` array with values in [0, 255].
    """
    grid = tv.utils.make_grid(tensor=tensors.detach().cpu(), 
                              nrow=nrow, 
                              padding=padding, 
                              normalize=(not isNormalized))
    # In both cases the grid is now expected to lie in [0, 1], so both need the
    # same scaling to [0, 255]. Casting the [0, 1] grid straight to an integer
    # type would give an (almost) all-zero image.
    ndgrid = grid.mul(255) \
                 .add_(0.5) \
                 .clamp_(0, 255) \
                 .permute(1, 2, 0) \
                 .numpy() \
                 .astype(np.uint16)

    return ndgrid

## 1. Data

In [None]:
# Sample datasets
class GenericImageDS(D.Dataset):
    """Dataset of the images in `root` that match `image_glob`.

    Each item is the transformed image only; no label is returned.

    Args:
        root: Directory that holds the images.
        image_glob: Glob pattern, relative to `root`, selecting the images.
        train: Stored on the instance; not used by this class itself.
        transform: Callable invoked as ``transform(image=array)["image"]``.
            Defaults to resize + normalize + tensor conversion.
        min_image_dim: Side length used by the default resize.

    Raises:
        ValueError: If no file matches the glob.
    """
    def __init__(self,
                 root,
                 image_glob="*.jpg",
                 train=True,
                 transform=None,
                 min_image_dim=256):
        super().__init__()
        self.root = root
        self.image_glob = image_glob
        self.train = train
        self.min_image_dim = min_image_dim

        image_regex = os.path.join(self.root, self.image_glob)
        # `glob` returns files in arbitrary order; sort so that an index
        # refers to the same image on every run and machine.
        self.image_paths = sorted(glob.glob(image_regex))
        if not self.image_paths:
            raise ValueError(f"No image found using {image_regex}")

        self.transform = transform
        # Default set of transforms if none are provided
        if self.transform is None:
            self.transform = A.Compose([
                # interpolation=4 is cv2.INTER_LANCZOS4; keyword arguments
                # keep the call independent of the positional signature.
                A.Resize(height=self.min_image_dim,
                         width=self.min_image_dim,
                         interpolation=4,
                         p=1),
                A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), p=1),
                ToTensorV2()
            ])
        logger.info(f"Total samples: {len(self.image_paths)}")

    @staticmethod
    def download(urls, destination_dir, force=False):
        """Download every archive in `urls` and extract it into `destination_dir`.

        Args:
            urls: A URL or an iterable of URLs pointing at tar/zip archives.
            destination_dir: Existing directory to download and extract into.
            force: Download again even if the archive is already present.

        Raises:
            ValueError: If `destination_dir` is not a directory or `urls` is None.
        """
        destination_dir = Path(destination_dir)

        # Check validity of arguments
        if not destination_dir.is_dir():
            raise ValueError("Provide destination_dir")
        if urls is None:
            raise ValueError("Provide URL(s)")
        # A bare string would otherwise be iterated character by character
        if isinstance(urls, str):
            urls = [urls]

        # Download & Extract
        for url in urls:
            fname = download_file(url, destination_dir, force=force)
            # Keyword is required: the second positional parameter of
            # `extract_file` is `ftype`, not the destination.
            extract_file(fname, destination_dir=destination_dir)

    def __getitem__(self, index):
        """Load the image at `index` and return it after `self.transform`."""
        image_path = self.image_paths[index]
        # Close the file handle as soon as the pixels are read into the array
        with Image.open(image_path) as img:
            image = np.asarray(img)
        image = self.transform(image=image)["image"]
        return image

    def __len__(self):
        """Return the number of images found."""
        return len(self.image_paths)


# Test
# Raise this notebook's logger to DEBUG while trying out the dataset
logger.setLevel(logging.DEBUG)
# Write testing code here
# Restore the INFO level afterwards
logger.setLevel(logging.INFO)

## 2. Model Blocks

## 3. Lightning Modules

In [None]:
# Sample data module
class Houston18DataModule(pl.LightningDataModule):
    """Lightning data module for the Houston 2018 hyperspectral dataset.

    https://hyperspectral.ee.uh.edu/?page_id=1075

    Downloads the dataset file, wraps it in a patch dataset and serves the
    train / validation / test loaders from that single dataset through
    index samplers.

    NOTE(review): relies on `open_file`, `GenericHSI`, `GenericHSIInference`,
    `customClasswiseSplit`, `display_dataset`, `convert_to_color_`,
    `plot_colortable` and `CONSTANTS`, none of which are defined in the part
    of the notebook visible here - confirm they are provided before use.

    Args:
        batch_size: Batch size of the train / validation / test loaders.
        patch_size: Patch size forwarded to the underlying datasets.
    """

    # Remote location and local layout of the dataset file
    DATA_URL  = "https://drive.google.com/u/0/uc?id=1Mf1nVX1SzaJUOwedJi9w5v2GRYGpxRiF"
    DATA_DIR  = Path("./data/houston18/")
    DATA_FILE = "H18data2_mat.mat"

    def __init__(self, batch_size=4, patch_size=11):
        super().__init__()

        self.patch_size   = patch_size
        self.batch_size   = batch_size
        # Band indices handed to `display_dataset` when rendering previews
        self.rgb_bands    = (47, 31, 15)
        self.label_values = [
            "Unclassified",
            "Healthy grass",
            "Stressed grass",
            "Artificial turf",
            "Evergreen trees",
            "Deciduous trees",
            "Bare earth",
            "Water",
            "Residential buildings",
            "Non-residential buildings",
            "Roads",
            "Sidewalks",
            "Crosswalks",
            "Major thoroughfares",
            "Highways",
            "Railways",
            "Paved parking lots",
            "Unpaved parking lots",
            "Cars",
            "Trains",
            "Stadium seats",
        ]
        self.ignored_labels = [0]
        # Label index -> RGB colour
        self.palette = {
            0:  (0, 0, 0),
            1:  (52, 209, 0),
            2:  (143, 255, 0),
            3:  (55, 153, 86),
            4:  (34, 140, 0),
            5:  (18, 70, 0),
            6:  (155, 70, 32),
            7:  (51, 254, 254),
            8:  (255, 255, 255),
            9:  (209, 185, 212),
            10: (244, 0, 0),
            11: (160, 147, 138),
            12: (112, 110, 111),
            13: (173, 0, 0),
            14: (67, 0, 0),
            15: (234, 158, 0),
            16: (255, 255, 0),
            17: (255, 216, 0),
            18: (209, 0, 227),
            19: (0, 0, 227),
            20: (176, 194, 220),
        }

    def prepare_data(self):
        """
        For operations that might write to disk or 
        that need to be done only from a single process in distributed settings.
        DO NOT use to assign state as it is called from a single process.
        """
        # Safely create nested directory
        self.DATA_DIR.mkdir(parents=True, exist_ok=True)

        # Download dataset
        target = self.DATA_DIR / self.DATA_FILE
        if not target.exists():
            gdown.download(self.DATA_URL, str(target), quiet=False)

    def setup(self, stage=None):
        """For data operations on every GPU.

        Args:
            stage: "fit", "test" or None build the labelled dataset and the
                three samplers; "inference" builds the whole-image dataset.
        """
        data = open_file(self.DATA_DIR / self.DATA_FILE)

        # `stage is None` must be tested explicitly: a bare `or None` is
        # always falsy and would silently skip the set-up.
        if stage in ("fit", "test", None):
            self.ds = GenericHSI(
                img            = data['hsi_sub_zoom'],
                gt             = data['gt'],
                rgb_bands      = self.rgb_bands,
                label_values   = self.label_values,
                ignored_labels = self.ignored_labels,
                palette        = self.palette,
                patch_size     = self.patch_size
            )
            num_classes = len(self.label_values) - len(self.ignored_labels)

            trainval_idx, _ = customClasswiseSplit(
                self.ds.labels, 
                num_classes=num_classes,
                num_samples=100,
                ignore_indices=None
            )

            # The stratification labels must be in the same order as the
            # indices being split, so index the labels with the sorted list.
            sorted_trainval_idx = sorted(trainval_idx)
            train_idx, valid_idx = train_test_split(
                sorted_trainval_idx,
                test_size=0.1,
                stratify=np.array(self.ds.labels)[sorted_trainval_idx],
                random_state=CONSTANTS['SEED']
            )

            test_idx, _  = customClasswiseSplit(
                self.ds.labels,
                num_classes=num_classes,
                num_samples=2000,
                ignore_indices=trainval_idx
            )

            self.train_sampler = D.SubsetRandomSampler(train_idx)
            self.valid_sampler = D.SubsetRandomSampler(valid_idx)
            self.test_sampler  = D.SubsetRandomSampler(test_idx)

        if stage == "inference":
            # If you need patches from the whole image
            self.inference_ds = GenericHSIInference(
                img            = data["hsi_sub_zoom"],
                patch_size     = self.patch_size,
            )

    def _sampled_loader(self, sampler):
        """Build a loader over `self.ds` that draws its indices from `sampler`."""
        return D.DataLoader(
            self.ds,
            batch_size=self.batch_size,
            num_workers=2,
            pin_memory=True,
            sampler=sampler
        )

    def train_dataloader(self):
        """Loader over the training indices."""
        return self._sampled_loader(self.train_sampler)

    def val_dataloader(self):
        """Loader over the validation indices."""
        return self._sampled_loader(self.valid_sampler)

    def test_dataloader(self):
        """Loader over the test indices."""
        return self._sampled_loader(self.test_sampler)

    def inference_dataloader(self, batch_size=4):
        """Sequential loader over the whole-image dataset (needs setup("inference"))."""
        return D.DataLoader(
            self.inference_ds,
            batch_size=batch_size,
            num_workers=2,
            pin_memory=True,
            shuffle=False
        )

    def teardown(self, stage=None):
        """Used to clean-up when run is finished"""
        # Teardown can run more than once (e.g. after fit and again after
        # test); a directory that is already gone must not raise.
        shutil.rmtree(self.DATA_DIR, ignore_errors=True)

    def visualize(self):
        """Log statistics of a random sample and plot it next to the full scene."""
        ds = self.ds
        # `randrange` excludes `len(ds)`; `randint(0, len(ds))` could index
        # one past the end.
        data, label = ds[random.randrange(len(ds))]
        logger.debug(f"""data:  {data.min():.3f} to {data.max():.3f} 
                         with shape={data.shape}, dtype={data.dtype}""")
        logger.debug(f"""label: {label.min()} to {label.max()} 
                         with shape={label.shape}, dtype={label.dtype}""")

        # A single patch
        img = display_dataset(data.squeeze().permute(1, 2, 0).numpy(), ds.rgb_bands)
        plt.imshow(img)
        plt.show()

        # Full Image & Train GT
        img = display_dataset(ds.data, ds.rgb_bands)
        # NOTE(review): `setup` reads `ds.labels` but this reads `ds.label` -
        # confirm both attributes exist on the dataset class.
        gt  = convert_to_color_(ds.label, ds.palette)
        fig, ax = plt.subplots(1, 2, figsize=(20, 20))
        for a in ax:
            a.axis('off')
        ax[0].imshow(img)
        ax[1].imshow(gt)
        fig.tight_layout()
        plt.show()
        plot_colortable(ds.palette, "Color table")


# Test
# Smoke-test the data module with DEBUG logging enabled
logger.setLevel(logging.DEBUG)

d = Houston18DataModule()
d.prepare_data()
d.setup(stage="fit")
d.visualize()
# teardown() removes the downloaded dataset directory again
d.teardown()

del d
logger.setLevel(logging.INFO)

In [None]:
class FinalNet(pl.LightningModule):
    """Skeleton LightningModule: fill in the TODO hooks before training.

    The training / validation steps below expect batches of
    ``(inputs, targets)`` and a logger whose ``experiment`` offers
    ``log_metric``.

    Args:
        hparams: Hyper-parameters of the run.
    """
    def __init__(self, hparams):
        super().__init__()
        # Keep the hyper-parameters on the module (`self.hparams`) and inside
        # every checkpoint; `load_from_checkpoint` needs them to call
        # `__init__` again.
        self.save_hyperparameters(hparams)

    def forward(self, x):
        # TODO
        raise NotImplementedError

    def configure_optimizers(self):
        # TODO
        raise NotImplementedError

    def loss_function(self, preds, targets):
        # TODO
        raise NotImplementedError

    # NOTE(review): the three data hooks below raise until implemented.
    # Lightning may still call `prepare_data` on the module when a
    # DataModule is passed to `fit` - confirm, and delete these overrides
    # if the data comes from a DataModule.
    def prepare_data(self):
        # TODO
        raise NotImplementedError

    def train_dataloader(self):
        # TODO
        raise NotImplementedError

    def val_dataloader(self):
        # TODO
        raise NotImplementedError

    def training_step(self, batch, batch_idx):
        """Compute and log the loss of one training batch."""
        inputs, targets = batch
        preds = self(inputs)
        loss  = self.loss_function(preds, targets)

        self.logger.experiment.log_metric('step_train_loss', loss)
        return { 'loss': loss }

    def training_epoch_end(self, outputs):
        """Log the mean of the per-step training losses of the epoch."""
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()

        self.logger.experiment.log_metric('epoch_train_loss', avg_loss)

    def validation_step(self, batch, batch_idx):
        """Compute and log the loss of one validation batch."""
        inputs, targets = batch
        preds = self(inputs)
        loss = self.loss_function(preds, targets)

        self.logger.experiment.log_metric('step_val_loss', loss)
        return { 'val_loss': loss }

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss of the epoch."""
        avg_val_loss = torch.stack([output['val_loss'] for output in outputs]).mean()

        # 'avg_val_loss' is the name monitored by the checkpoint and
        # early-stopping callbacks of this notebook.
        self.log('avg_val_loss', avg_val_loss)
        self.logger.experiment.log_metric('epoch_val_loss', avg_val_loss)


# Test
# Raise this notebook's logger to DEBUG while trying out the model
logger.setLevel(logging.DEBUG)
# Write testing code here
# Restore the INFO level afterwards
logger.setLevel(logging.INFO)

## 4. Training

In [None]:
# Hyper-parameters read by the data module, the model, the Trainer and the
# logger in the cells below.
hparams = {
    'lr': 0.0001, 
    'batch_size': 4,
    'patch_size': 11,         # Same value as the Houston18DataModule default
    'max_epochs': 200,
    'min_epochs': 10,
    'check_val_every_n_epoch': 4,
    'gradient_clip_val': 0,   # 0 leaves gradient clipping switched off
    'precision': 32,     # https://pytorch-lightning.readthedocs.io/en/latest/amp.html
    'benchmark': True,
    'deterministic': False,
    'use_gpu': torch.cuda.is_available(),
}

In [None]:
# Keep the 5 checkpoints with the lowest 'avg_val_loss'
# https://pytorch-lightning.readthedocs.io/en/latest/weights_loading.html?highlight=ModelCheckpoint
model_checkpoint = ModelCheckpoint(
    dirpath    = "./checkpoints/",
    filename   = '{epoch:03d}__{avg_val_loss:.5f}',
    save_top_k = 5,
    monitor    = 'avg_val_loss',
    mode       = 'min',
    period     = 5
)

# Stop once 'avg_val_loss' has not improved by `min_delta` for `patience` checks
# https://pytorch-lightning.readthedocs.io/en/latest/early_stopping.html
early_stop_callback = EarlyStopping(
   monitor   = 'avg_val_loss',
   min_delta = 1e-7,
   patience  = 3,
   verbose   = True,
   mode      = 'min'
)

# NOTE(review): `CONSTANTS` is not defined in the visible part of this
# notebook - confirm it provides 'API_TOKEN' before running this cell.
# https://docs.neptune.ai/api-reference/neptune/experiments/index.html#neptune.experiments.Experiment
pl_logger = NeptuneLogger(
    api_key         = CONSTANTS['API_TOKEN'],
    project_name    = f"", # TODO
    # The experiment must stay open after `fit`: a later cell still logs the
    # model summary and checkpoints to it. Stop it manually with
    # `pl_logger.experiment.stop()` once everything is logged.
    close_after_fit = False,
    experiment_name = '',  # TODO
    params          = hparams,
    offline_mode    = True,  # Comment to log into neptune.ai
)

In [None]:
# Build the data module, the model and the Trainer for this run.
# Reads 'batch_size', 'patch_size', 'use_gpu', 'precision',
# 'gradient_clip_val', 'benchmark', 'deterministic', 'max_epochs',
# 'min_epochs' and 'check_val_every_n_epoch' from `hparams`.
logger.setLevel(logging.INFO)
# Seed the random number generators before anything is built
pl.seed_everything(CONSTANTS['SEED'])

dataset = Houston18DataModule(batch_size=hparams['batch_size'], 
                              patch_size=hparams['patch_size'])

model   = FinalNet(hparams=hparams)

trainer = pl.Trainer(
    gpus                    = -1 if hparams['use_gpu'] else 0,
    precision               = hparams['precision'],
    gradient_clip_val       = hparams['gradient_clip_val'],
    benchmark               = hparams['benchmark'],
    deterministic           = hparams['deterministic'],
    max_epochs              = hparams['max_epochs'],
    min_epochs              = hparams['min_epochs'],
    check_val_every_n_epoch = hparams['check_val_every_n_epoch'],
    logger                  = pl_logger,
    checkpoint_callback     = model_checkpoint,
    callbacks               = [early_stop_callback],
) 

In [None]:
# 🐉
# Train `model` on the data served by the `dataset` data module
trainer.fit(model, dataset)

In [None]:
# Log model summary, one line of the printed architecture per entry
for chunk in str(model).split('\n'):
    pl_logger.experiment.log_text('model_summary', chunk)

# Which GPUs were used?
gpu_list = [f'{i}:{torch.cuda.get_device_name(i)}' 
            for i in range(torch.cuda.device_count())] 
pl_logger.experiment.log_text('GPUs used', ', '.join(gpu_list))

# Log the best checkpoints kept by `model_checkpoint` to Neptune
for k in model_checkpoint.best_k_models:
    model_name = 'checkpoints/' + k.split('/')[-1]
    pl_logger.experiment.log_artifact(k, model_name)

# Save last path
last_model_path = f"checkpoints/last_model--epoch={trainer.current_epoch}.ckpt"
trainer.save_checkpoint(last_model_path)
pl_logger.experiment.log_artifact(
    last_model_path, 
    'checkpoints/' + last_model_path.split('/')[-1]
)

# Log score of the best model checkpoint
# (stays None when no checkpoint was saved during the run)
if model_checkpoint.best_model_score is not None:
    pl_logger.experiment.set_property(
        'best_model_score', 
        model_checkpoint.best_model_score.tolist()
    )

## 5. Inference

### Get weights

In [None]:
# Get Neptune API token
# NOTE: this import rebinds the name `getpass` from the module (imported in
# the setup cell) to the function of the same name.
from getpass import getpass
# Ask for the token interactively instead of storing it in the notebook
api_token = getpass("Enter Neptune.ai API token: ")

In [None]:
# Initialize Neptune project
import neptune
from neptune import Session

# Open a session with the token entered in the previous cell
session = Session.with_default_backend(api_token=api_token)
project = session.get_project(f"") # TODO
# Keep the first experiment returned for the given id
# (presumably a list; an empty result would fail on `[0]` - fill in the id)
experiment = project.get_experiments(id='')[0] # TODO
# Bare expression: shows the experiment as the cell output
experiment

In [None]:
# Download checkpoint from Neptune
# Path of the artifact inside the experiment (example value - adjust as needed)
artifact_path   = 'epoch=133-avg_val_loss=1.06.ckpt'
# File name = last path component of the artifact path
artifact_name   = artifact_path.split('/')[-1]
checkpoint_dir  = os.path.join('checkpoints', 'downloads')
# Local path of the checkpoint, used by the loading cell below
checkpoint_path = os.path.join(checkpoint_dir, artifact_name)

experiment.download_artifact(path=artifact_path, destination_dir=checkpoint_dir)

### Load weights

In [None]:
# Rebuild the model from the downloaded checkpoint file
model = FinalNet.load_from_checkpoint(checkpoint_path=checkpoint_path)
# Switch the module to evaluation mode
model.eval()

### Test

In [None]:
# TODO