<a href="https://colab.research.google.com/github/rshwndsz/templates/blob/master/pl-colab/template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PL_TEMPLATE

## 0a. Setup

In [None]:
%%shell
# Install the notebook's extra dependencies in a single pip call.
# --quiet hides the progress output but, unlike redirecting everything to
# /dev/null, still shows errors when an install fails.
# NOTE(review): versions are unpinned while the code below uses older
# PyTorch Lightning / Neptune APIs (e.g. `gpus=`, `training_epoch_end`,
# `period=`) - consider pinning compatible versions; confirm which ones.
pip install --quiet pytorch-lightning neptune-client torchmetrics

In [None]:
# STL
import math
import os
import glob
import logging
import getpass
from pathlib import Path

# Numerical Python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Image processing
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Deep Learning
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data as D
import torchvision as tv
import pytorch_lightning as pl

# Bells & Whistles
from sklearn.model_selection import train_test_split
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.callbacks import (ModelCheckpoint,
                                         EarlyStopping)

# Misc
from tqdm.notebook import tqdm
import gdown

In [None]:
# Set up logging to file
# https://stackoverflow.com/a/23681578
logging.basicConfig(
    filename='LOG.log',
    level=logging.INFO,
    format='[%(asctime)s] %(levelname)8s - %(funcName)8s() - %(message)s',
    datefmt='%H:%M:%S',
)

# Set up logging to console.
# The handler is named and looked up before being created, so re-running this
# cell reuses it instead of stacking another handler on the root logger
# (which would print every record once per run of the cell).
root_logger = logging.getLogger('')
console = next(
    (handler for handler in root_logger.handlers
     if handler.get_name() == 'console'),
    None,
)
if console is None:
    console = logging.StreamHandler()
    console.set_name('console')
    # Add the handler to the root logger
    root_logger.addHandler(console)
console.setLevel(logging.DEBUG)

# Set a format which is simpler for console use
formatter = logging.Formatter('%(levelname)8s - %(funcName)14s() - %(message)s')
console.setFormatter(formatter)

logger = logging.getLogger(__name__)

# Test drive the logger
logger.info(f"""
            Torch: {torch.__version__}
            Torchvision: {tv.__version__}
            Pytorch Lightning: {pl.__version__} 
            Albumentations: {A.__version__}
            """)

## 0b. Utils

## 1. Data

In [None]:
# Sample dataset API
class GenericImageDS(D.Dataset):
    """Dataset of the images under ``root`` that match ``image_glob``.

    Each item is a single transformed image; no label is returned.

    Args:
        root: Directory that is searched for images.
        image_glob: Glob pattern, relative to ``root``, selecting the images.
        train: Stored on the instance; not otherwise used by this class.
        transform: Callable invoked as ``transform(image=array)["image"]``.
            When ``None``, a default resize / normalize / to-tensor pipeline
            is used.
        min_image_dim: Height and width used by the default resize.

    Raises:
        ValueError: If no file matches the glob pattern.
    """

    def __init__(self,
                 root,
                 image_glob="*.jpg",
                 train=True,
                 transform=None,
                 min_image_dim=256):
        self.root = root
        self.image_glob = image_glob
        self.train = train
        self.min_image_dim = min_image_dim

        image_regex = os.path.join(self.root, self.image_glob)
        # glob returns paths in arbitrary order; sort so that a given index
        # always refers to the same file.
        self.image_paths = sorted(glob.glob(image_regex))
        if not self.image_paths:
            raise ValueError(f"No image found using {image_regex}")

        # Default set of transforms if none are provided
        if transform is None:
            transform = A.Compose([
                # NOTE(review): the positional arguments are presumably
                # interpolation=4, always_apply=True, p=1 - confirm against
                # the installed albumentations version.
                A.Resize(self.min_image_dim, self.min_image_dim, 4, True, 1),
                A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), p=1),
                ToTensorV2()
            ])
        self.transform = transform
        logger.info(f"Total samples: {len(self.image_paths)}")

    @staticmethod
    def download(urls, destination_dir, force=False):
        """Download and extract every URL into ``destination_dir``.

        Args:
            urls: A single URL string or an iterable of URL strings.
            destination_dir: Existing directory that receives the files.
            force: Accepted for API compatibility; currently unused.

        Raises:
            ValueError: If ``destination_dir`` is missing or not a directory,
                or if ``urls`` is ``None``.
        """
        # Check validity of arguments
        if destination_dir is None or not Path(destination_dir).is_dir():
            raise ValueError("Provide destination_dir")
        if urls is None:
            raise ValueError("Provide URL(s)")
        destination_dir = Path(destination_dir)

        # A bare string would otherwise be iterated character by character.
        if isinstance(urls, str):
            urls = [urls]

        # Download & Extract
        # NOTE(review): download_file / extract_file are not defined in the
        # visible part of this notebook - they must be provided elsewhere.
        for url in urls:
            fname = download_file(url, destination_dir)
            extract_file(fname, destination_dir)

    def __getitem__(self, index):
        """Return the transformed image stored at position ``index``."""
        # The context manager closes the file handle once pixels are copied.
        with Image.open(self.image_paths[index]) as img:
            # Force three channels so grayscale / RGBA / palette files agree
            # with the 3-channel default Normalize and can be batched.
            image = np.asarray(img.convert("RGB"))
        return self.transform(image=image)["image"]

    def __len__(self):
        """Return the number of images found at construction time."""
        return len(self.image_paths)


# Test
# Lower this logger's threshold to DEBUG while the checks below run.
logger.setLevel(logging.DEBUG)
# Write testing code here
# Set the threshold back to INFO for the rest of the notebook.
logger.setLevel(logging.INFO)

## 2. Model Blocks

## 3. Lightning Modules

In [None]:
# Sample data module
class ImageNet(pl.LightningDataModule):
    """Sample data module skeleton; the dataloader hooks are left to fill in.

    Args:
        batch_size: Stored as ``self.batch_size``.
        patch_size: Stored as ``self.patch_size``.
    """

    # Single location of the dataset on disk, shared by prepare_data()
    # (which creates it) and teardown() (which removes it).
    data_dir = Path("./data/imagenet/")

    def __init__(self, batch_size=4, patch_size=11):
        super().__init__()
        # Keep the arguments so the dataloader hooks can read them.
        self.batch_size = batch_size
        self.patch_size = patch_size

    def prepare_data(self):
        """
        For operations that might write to disk or 
        that need to be done only from a single process in distributed settings.
        DO NOT use to assign state as it is called from a single process.
        """
        URL    = ""  # TODO
        outdir = self.data_dir

        # Safely create nested directory
        outdir.mkdir(parents=True, exist_ok=True)

        # Download dataset only if it is not already on disk
        target = outdir / "dataset.mat"
        if not target.exists():
            gdown.download(URL, str(target), quiet=False)

    def setup(self, stage=None):
        """For data operations on every GPU."""
        pass

    def train_dataloader(self):
        pass

    def val_dataloader(self):
        pass

    def test_dataloader(self):
        pass

    def teardown(self, stage=None):
        """Used to clean-up when run is finished"""
        # Imported here because shutil is not among the notebook's
        # top-level imports.
        import shutil

        # ignore_errors: the directory may already have been removed (or
        # never created), which should not fail the run.
        shutil.rmtree(self.data_dir, ignore_errors=True)

    def visualize(self):
        pass


# Test
# Lower this logger's threshold to DEBUG while the checks below run.
logger.setLevel(logging.DEBUG)
# Write testing code here
# Set the threshold back to INFO for the rest of the notebook.
logger.setLevel(logging.INFO)

In [None]:
class FinalNet(pl.LightningModule):
    """Skeleton LightningModule; the methods marked TODO must be implemented.

    Args:
        hparams: Mapping of hyperparameters. It is saved with the module so
            that checkpoints carry it and ``load_from_checkpoint`` can
            rebuild the model without the caller supplying it again.
    """

    def __init__(self, hparams):
        super().__init__()
        # Without this, checkpoints hold no constructor arguments and
        # FinalNet.load_from_checkpoint(path) cannot call __init__(hparams).
        self.save_hyperparameters(hparams)

    def forward(self, x):
        # TODO
        raise NotImplementedError

    def configure_optimizers(self):
        # TODO
        raise NotImplementedError

    def loss_function(self, preds, targets):
        # TODO
        raise NotImplementedError

    # NOTE(review): the three data hooks below raise until implemented;
    # presumably they should be removed when a LightningDataModule is passed
    # to trainer.fit() - confirm against the Lightning version in use.
    def prepare_data(self):
        # TODO
        raise NotImplementedError

    def train_dataloader(self):
        # TODO
        raise NotImplementedError

    def val_dataloader(self):
        # TODO
        raise NotImplementedError

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch and log it per step."""
        inputs, targets = batch
        preds = self(inputs)
        loss = self.loss_function(preds, targets)

        # Log a plain float; the tensor itself is returned for backprop.
        self.logger.experiment.log_metric('step_train_loss', loss.item())
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        """Log the mean of the per-step training losses for the epoch."""
        avg_loss = torch.stack([step['loss'] for step in outputs]).mean()

        self.logger.experiment.log_metric('epoch_train_loss', avg_loss.item())

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and log it per step."""
        inputs, targets = batch
        preds = self(inputs)
        loss = self.loss_function(preds, targets)

        self.logger.experiment.log_metric('step_val_loss', loss.item())
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss for the epoch."""
        avg_val_loss = torch.stack([step['val_loss'] for step in outputs]).mean()

        # 'avg_val_loss' is the name under which the value is exposed to
        # Lightning callbacks that monitor a metric.
        self.log('avg_val_loss', avg_val_loss)
        self.logger.experiment.log_metric('epoch_val_loss', avg_val_loss.item())


# Test
# Lower this logger's threshold to DEBUG while the checks below run.
logger.setLevel(logging.DEBUG)
# Write testing code here
# Set the threshold back to INFO for the rest of the notebook.
logger.setLevel(logging.INFO)

## 4. Training

In [None]:
# Run-level constants read by the Neptune logger and by the seeding call.
CONSTANTS = {
    'SEED': 42,  # arbitrary fixed seed for reproducible runs
    # Read the token from the environment instead of hard-coding a secret;
    # the value is None when the variable is not set.
    'API_TOKEN': os.environ.get('NEPTUNE_API_TOKEN'),
}

# Hyperparameters shared by the data module, the model and the Trainer.
hparams = {
    'lr': 0.0001,
    'batch_size': 4,
    'patch_size': 11,          # same value as the data module's default
    'max_epochs': 200,
    'min_epochs': 10,
    'check_val_every_n_epoch': 4,
    'gradient_clip_val': 0,    # 0 means no gradient clipping
    'precision': 32,     # https://pytorch-lightning.readthedocs.io/en/latest/amp.html
    'benchmark': True,
    'deterministic': False,
    'use_gpu': torch.cuda.is_available(),
}

In [None]:
# https://pytorch-lightning.readthedocs.io/en/latest/weights_loading.html?highlight=ModelCheckpoint
# Keep the 5 checkpoints with the lowest 'avg_val_loss'.
model_checkpoint = ModelCheckpoint(
    dirpath    = "./checkpoints/",
    filename   = '{epoch:03d}__{avg_val_loss:.5f}',
    save_top_k = 5,
    monitor    = 'avg_val_loss',
    mode       = 'min',
    period     = 5
)

# https://pytorch-lightning.readthedocs.io/en/latest/early_stopping.html
# Stop once 'avg_val_loss' has not improved by at least min_delta for
# `patience` consecutive validation checks.
early_stop_callback = EarlyStopping(
   monitor   = 'avg_val_loss',
   min_delta = 1e-7,
   patience  = 3,
   verbose   = True,
   mode      = 'min'
)

# https://docs.neptune.ai/api-reference/neptune/experiments/index.html#neptune.experiments.Experiment
pl_logger = NeptuneLogger(
    api_key         = CONSTANTS['API_TOKEN'],
    project_name    = "",  # TODO
    # Keep the experiment open after trainer.fit() so that the
    # post-training cell can still log text, artifacts and properties.
    close_after_fit = False,
    experiment_name = '',  # TODO
    params          = hparams,
    offline_mode    = True,  # Comment to log into neptune.ai
)

In [None]:
logger.setLevel(logging.INFO)
pl.seed_everything(CONSTANTS['SEED'])

# ImageNet is the data module defined in section 3. The .get() fallbacks
# cover hparams dicts that do not define these optional keys.
dataset = ImageNet(batch_size=hparams['batch_size'],
                   patch_size=hparams.get('patch_size', 11))

model   = FinalNet(hparams=hparams)

# NOTE(review): passing a ModelCheckpoint instance through
# `checkpoint_callback` is presumably deprecated in newer Lightning releases
# in favour of `callbacks` - confirm against the installed version.
trainer = pl.Trainer(
    gpus                    = -1 if hparams['use_gpu'] else 0,
    precision               = hparams['precision'],
    gradient_clip_val       = hparams.get('gradient_clip_val', 0),
    benchmark               = hparams['benchmark'],
    deterministic           = hparams['deterministic'],
    max_epochs              = hparams['max_epochs'],
    min_epochs              = hparams['min_epochs'],
    check_val_every_n_epoch = hparams['check_val_every_n_epoch'],
    logger                  = pl_logger,
    checkpoint_callback     = model_checkpoint,
    callbacks               = [early_stop_callback],
) 

In [None]:
# 🐉
# Train `model` using the `dataset` data module created in the previous cell.
trainer.fit(model, dataset)

In [None]:
# Everything in this cell logs to the Neptune experiment after training has
# finished, which requires the logger to have been created with
# close_after_fit=False.
neptune_experiment = pl_logger.experiment

# Log model summary, one line of the module's string form at a time
for summary_line in str(model).split('\n'):
    neptune_experiment.log_text('model_summary', summary_line)

# Which GPUs were used?
gpu_list = [f'{i}:{torch.cuda.get_device_name(i)}'
            for i in range(torch.cuda.device_count())]
neptune_experiment.log_text('GPUs used', ', '.join(gpu_list))

# Log every checkpoint kept by ModelCheckpoint (at most save_top_k) to Neptune
for ckpt_path in model_checkpoint.best_k_models:
    model_name = 'checkpoints/' + os.path.basename(ckpt_path)
    neptune_experiment.log_artifact(ckpt_path, model_name)

# Save the final weights as well, even if they are not among the best
last_model_path = f"checkpoints/last_model--epoch={trainer.current_epoch}.ckpt"
trainer.save_checkpoint(last_model_path)
neptune_experiment.log_artifact(
    last_model_path,
    'checkpoints/' + os.path.basename(last_model_path)
)

# Log score of the best model checkpoint; it is None when no checkpoint
# was saved during training.
best_model_score = model_checkpoint.best_model_score
if best_model_score is not None:
    neptune_experiment.set_property(
        'best_model_score',
        best_model_score.tolist()
    )

# Nothing else is logged after this point, so close the experiment.
neptune_experiment.stop()

## 5. Inference

### Get weights

In [None]:
# Get Neptune API token
# Prompt without echoing the token. The function is called through the
# module so the name `getpass` keeps referring to the module imported at the
# top of the notebook instead of being rebound to the function.
import getpass
api_token = getpass.getpass("Enter Neptune.ai API token: ")

In [None]:
# Initialize Neptune project
import neptune
from neptune import Session

session = Session.with_default_backend(api_token=api_token)
project = session.get_project("")  # TODO

# Look the experiment up by id and fail with a readable message (same
# exception type as a bare `[0]`) when nothing matches.
matching_experiments = project.get_experiments(id='')  # TODO
if not matching_experiments:
    raise IndexError("No Neptune experiment matches the requested id")
experiment = matching_experiments[0]
experiment

In [None]:
# Download checkpoint from Neptune
# NOTE(review): example artifact name - the training section logs artifacts
# under 'checkpoints/<file name>'; confirm the actual path of the artifact.
artifact_path   = 'epoch=133-avg_val_loss=1.06.ckpt'
artifact_name   = artifact_path.split('/')[-1]
checkpoint_dir  = os.path.join('checkpoints', 'downloads')
checkpoint_path = os.path.join(checkpoint_dir, artifact_name)

# Create the destination up front rather than relying on the client to do so.
os.makedirs(checkpoint_dir, exist_ok=True)
experiment.download_artifact(path=artifact_path, destination_dir=checkpoint_dir)

### Load weights

In [None]:
# Rebuild the network from the checkpoint file downloaded above.
# NOTE(review): FinalNet.__init__ requires `hparams`; presumably the
# checkpoint must contain saved hyperparameters for this call to succeed -
# confirm.
model = FinalNet.load_from_checkpoint(checkpoint_path=checkpoint_path)
# Switch the module to evaluation mode for inference.
model.eval()

### Test

In [None]:
# TODO