# PL_TEMPLATE

## Setup

### Packages

In [None]:
%%shell
# Install the libraries this template targets, one pip call per package;
# pip's output (stdout and stderr) is discarded to keep the notebook tidy.
for pkg in \
    git+https://github.com/PytorchLightning/pytorch-lightning.git@master \
    git+https://github.com/albumentations-team/albumentations \
    neptune-client
do
    pip install "$pkg" > /dev/null 2>&1
done

In [None]:
# STL
import glob
import logging
import math
import os
import tarfile
from getpass import getpass
from pathlib import Path
from urllib.parse import urlparse
from zipfile import ZipFile

# Numerical Python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from PIL import Image

# Deep Learning
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data as D
import torchvision as tv
import pytorch_lightning as pl
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split

### Config

In [None]:
CONSTANTS = {
    # Seed passed to `pl.seed_everything` before building the model/trainer.
    'SEED': 81,
    # When True the notebook runs its quick smoke-test cells and logs to
    # TensorBoard; when False it logs to Neptune.
    'TEST_DRIVE': True,

    # Neptune.ai experiment-tracking settings.
    'NEPTUNE': {
        'USERNAME': 'rshwndsz',
        'PROJECT': 'template',
        'EXPERIMENT_NAME': '',
        # Prefer the NEPTUNE_API_TOKEN environment variable so the notebook can
        # run non-interactively; only prompt when it is unset or empty.
        'API_TOKEN': (os.environ.get('NEPTUNE_API_TOKEN')
                      or getpass('Enter your private Neptune API token: ')),
    },
}

### Logging

In [None]:
import logging
import os
from logging.config import dictConfig

LOGGING_CONFIG = {
    'version': 1,
    # Keep loggers that already-imported libraries (e.g. PyTorch Lightning)
    # created; `True` would silently disable all of them.
    'disable_existing_loggers': False,
    'formatters': {
        'standard': {
            'format': '%(asctime)s %(filename)9s: %(levelname)8s %(message)s'
        },
    },
    'handlers': {
        'stdout': {
            'level': 'DEBUG',
            'formatter': 'standard',
            'class': 'logging.StreamHandler',
            'stream': 'ext://sys.stdout',  # Default is stderr
        },
        'file': {
            'class': 'logging.handlers.RotatingFileHandler',
            'level': 'DEBUG',
            'formatter': 'standard',
            'filename': '.logs/LOG.log',
            'mode': 'a',
            'maxBytes': 10485760,  # 10 MiB per file
            'backupCount': 5,
        }
    },
    'loggers': {
        'sarcd': {
            'handlers': ['stdout', 'file'],
            'level': 'DEBUG',
            'propagate': True
        },
    }
}

# RotatingFileHandler does not create missing parent directories.
os.makedirs(os.path.dirname(LOGGING_CONFIG['handlers']['file']['filename']), exist_ok=True)
# Install the handlers/levels above. Without this call the dict is inert and
# `logger` inherits the root logger's default WARNING threshold.
dictConfig(LOGGING_CONFIG)

logger = logging.getLogger('sarcd')

# Test drive the logger
if CONSTANTS['TEST_DRIVE']:
    logger.debug(f"Torch: {torch.__version__}, "
                 f"Torchvision: {tv.__version__}, "
                 f"Pytorch Lightning: {pl.__version__}, "
                 f"albumentations: {A.__version__}")

### Utils

In [None]:
def download_file(url, destination_dir='./', desc=None, force=False):
    """
    Download a file from `url` into `destination_dir` using `requests`.

    The local file name is the last component of the URL's path (any query
    string or fragment is ignored). Data is streamed into ``<name>.part`` and
    only moved to ``<name>`` once complete, so an interrupted download is never
    mistaken for a finished one on the next call.

    Args:
        url: URL to download.
        destination_dir: Directory to save into; created if missing.
        desc: Progress-bar label. Defaults to ``"Downloading <name>"``.
        force: Download again even if ``<destination_dir>/<name>`` exists.

    Returns:
        pathlib.Path of the local file.

    Raises:
        ValueError: If no file name can be derived from `url`.
        requests.HTTPError: If the server responds with an error status.
    """
    # Convert path to pathlib object if not already
    destination_dir = Path(destination_dir)
    # File name from the URL path only, e.g. '.../data.zip?dl=1' -> 'data.zip'
    fname = Path(urlparse(url).path).name
    if not fname:
        raise ValueError(f"Cannot infer a file name from URL: {url}")
    # Construct path to file in local machine
    local_filepath = destination_dir / fname

    if local_filepath.is_file() and not force:
        logger.info("File(s) already downloaded. Use force=True to download again.")
        return local_filepath

    # Safely create nested directory - https://stackoverflow.com/a/273227
    destination_dir.mkdir(parents=True, exist_ok=True)

    if desc is None:
        desc = f"Downloading {fname}"

    # Third-party helpers this function has always relied on; imported here so
    # the cached-file path above does not need them.
    import requests
    from tqdm.auto import tqdm

    block_size = 1024
    partial_filepath = local_filepath.with_name(fname + '.part')
    try:
        # Download large file with requests - https://stackoverflow.com/a/16696317
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            total_size_in_bytes = int(r.headers.get('content-length', 0))
            # Progress bar for downloading file - https://stackoverflow.com/a/37573701
            with tqdm(total=total_size_in_bytes,
                      unit='iB',
                      unit_scale=True,
                      desc=desc) as pbar, open(partial_filepath, 'wb') as f:
                for data in r.iter_content(block_size):
                    pbar.update(len(data))
                    f.write(data)
    except BaseException:
        # Don't leave a truncated file behind
        if partial_filepath.exists():
            partial_filepath.unlink()
        raise
    # Move the completed file into place under its final name
    partial_filepath.replace(local_filepath)
    return local_filepath


def extract_file(fname, ftype=None, destination_dir="./", desc=None, remove_extract=False):
    """
    Extract a tar archive (plain or compressed) or a zip archive.

    Args:
        fname: Path to the archive.
        ftype: Archive type: '.tar', '.tgz' or '.zip'. When None it is
            inferred from the file name, with e.g. 'data.tar.gz' treated as
            '.tar' (tarfile detects the compression by itself).
        destination_dir: Directory to extract into; created if missing.
        desc: Progress-bar label. Defaults to "Extracting <name>".
        remove_extract: Delete the archive after extracting it.

    Raises:
        IOError: If `fname` does not exist, its type is unsupported, or a tar
            member would be written outside `destination_dir`.
    """
    # Convert to pathlib objects
    fname = Path(fname)
    destination_dir = Path(destination_dir)

    # Check arguments
    if not fname.is_file():
        raise IOError(f"The file {str(fname)} does not exist.")

    # Get type of extract; multi-suffix tarballs ('x.tar.gz') count as '.tar'
    if ftype is None:
        suffixes = fname.suffixes
        if len(suffixes) >= 2 and suffixes[-2] == '.tar':
            ftype = '.tar'
        else:
            ftype = fname.suffix

    # Validate before touching the filesystem
    if ftype not in ('.tar', '.tgz', '.zip'):
        raise IOError(f"The suffix: {ftype} is not supported.")

    # Safely create nested directory - https://stackoverflow.com/a/273227
    destination_dir.mkdir(parents=True, exist_ok=True)

    if desc is None:
        desc = f"Extracting {str(fname.name)}"

    # Progress bars are only needed once extraction actually starts
    from tqdm.auto import tqdm

    if ftype == '.zip':
        # ZipFile.extract already strips absolute paths and '..' components
        # https://stackoverflow.com/a/56970565
        with ZipFile(fname, 'r') as zf:
            for member in tqdm(zf.infolist(), desc=desc):
                zf.extract(member, destination_dir)
    else:
        root = destination_dir.resolve()
        # Extract files with progress bar - https://stackoverflow.com/a/53405055
        with tarfile.open(fname) as tar:
            members = tar.getmembers()
            for member in tqdm(members, total=len(members), desc=desc):
                # tarfile.extract does not guard against path traversal; refuse
                # members such as '../x' or '/etc/x' that would escape `root`.
                target = (root / member.name).resolve()
                if target != root and root not in target.parents:
                    raise IOError(f"Refusing to extract unsafe path: {member.name}")
                tar.extract(member=member, path=destination_dir)

    if remove_extract:
        # Delete the compressed archive
        os.remove(fname)


def make_grid(tensors, nrow=2, padding=2, isNormalized=True):
    """
    Tile a batch of images into a single ``uint8`` numpy image grid.

    Args:
        tensors: A 4D mini-batch tensor ``(B, C, H, W)`` or a sequence of
            same-sized ``(C, H, W)`` image tensors (the inputs accepted by
            ``torchvision.utils.make_grid``).
        nrow: Number of images per grid row.
        padding: Pixels of padding between tiles.
        isNormalized: True when pixel values already lie in [0, 1]. When False,
            torchvision min-max scales the grid into [0, 1] first.

    Returns:
        ``(H, W, C)`` ``numpy.uint8`` array with values in [0, 255], usable with
        ``plt.imshow`` or ``PIL.Image.fromarray``.
    """
    # Accept a list/tuple of tensors as well as an already-batched tensor
    if not torch.is_tensor(tensors):
        tensors = torch.stack(list(tensors))
    grid = tv.utils.make_grid(tensor=tensors.detach().cpu(),
                              nrow=nrow,
                              padding=padding,
                              normalize=(not isNormalized))
    # Either way the grid is now in [0, 1] (given, or produced by
    # `normalize=True`), so both cases need the same x255 conversion; the
    # +0.5 rounds to nearest before the integer cast.
    ndgrid = (grid.mul(255).add_(0.5).clamp_(0, 255)
                  .permute(1, 2, 0).to(torch.uint8).numpy())
    return ndgrid

## Datasets

In [None]:
class GenericImageDS(D.Dataset):
    """
    Unlabelled image dataset: every file matching `image_glob` inside `root`.

    Each item is whatever the albumentations `transform` returns under the
    ``"image"`` key (a tensor with the default transform, via ToTensorV2).
    """

    def __init__(self,
                 root,
                 image_glob="*.jpg",
                 train=True,
                 transform=None,
                 min_image_dim=256,
                 image_mode="RGB"):
        """
        Args:
            root: Directory containing the images.
            image_glob: Glob pattern, relative to `root`, selecting image files.
            train: Stored as ``self.train``; not otherwise used by this class.
            transform: albumentations transform applied to each image. Defaults
                to a Lanczos resize to ``min_image_dim`` x ``min_image_dim``,
                Normalize with mean 0 / std 1, then ToTensorV2.
            min_image_dim: Side length used by the default transform.
            image_mode: PIL mode images are converted to before transforming
                (so grayscale/RGBA files fit the 3-channel Normalize). Pass None
                to keep each file's decoded mode.

        Raises:
            ValueError: If the glob matches no files.
        """
        self.root = root
        self.image_glob = image_glob
        self.train = train
        self.min_image_dim = min_image_dim
        self.image_mode = image_mode

        image_regex = os.path.join(self.root, self.image_glob)
        # Sort so an index always refers to the same file across runs/filesystems
        self.image_paths = sorted(glob.glob(image_regex))
        if not self.image_paths:
            raise ValueError(f"No image found using {image_regex}")

        self.transform = transform
        # Default set of transforms if none are provided
        if self.transform is None:
            self.transform = A.Compose([
                # interpolation=4 is cv2.INTER_LANCZOS4
                A.Resize(self.min_image_dim, self.min_image_dim,
                         interpolation=4, always_apply=True, p=1),
                A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), p=1),
                ToTensorV2()
            ])
        logger.info(f"Total samples: {len(self.image_paths)}")

    @staticmethod
    def download(urls, destination_dir, force=False):
        """
        Download each URL into `destination_dir` and extract it there.

        Args:
            urls: A single URL string or an iterable of URL strings.
            destination_dir: Existing directory to download and extract into.
            force: Re-download files that already exist locally.

        Raises:
            ValueError: If `destination_dir` is not an existing directory or
                `urls` is None.
        """
        destination_dir = Path(destination_dir)

        # Check validity of arguments
        if not destination_dir.is_dir():
            raise ValueError("Provide destination_dir")
        if urls is None:
            raise ValueError("Provide URL(s)")
        # Accept a lone URL as well as a list of them
        if isinstance(urls, str):
            urls = [urls]

        # Download & Extract
        for url in urls:
            fname = download_file(url, destination_dir, force=force)
            # Keyword argument: extract_file's 2nd positional parameter is `ftype`
            extract_file(fname, destination_dir=destination_dir)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        # Close the file handle promptly instead of leaking it per sample
        with Image.open(image_path) as img:
            if self.image_mode is not None:
                img = img.convert(self.image_mode)
            image = np.asarray(img)
        image = self.transform(image=image)["image"]
        return image

    def __len__(self):
        return len(self.image_paths)


In [None]:
# Test drive the dataset API
if CONSTANTS['TEST_DRIVE']:
    train_transform = A.Compose([
        A.VerticalFlip(p=0.1),
        A.HorizontalFlip(p=0.6),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=1),
        # 4 == cv2.INTER_LANCZOS4; the bare name INTER_LANCZOS4 is not defined here
        A.Resize(256, 256, interpolation=4, p=1),
        A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=1),
        A.MotionBlur(p=1),
        A.RandomBrightnessContrast(p=1),
        A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), p=1),
        ToTensorV2(),
    ])
    ds = GenericImageDS("data/train_data/", "*.jpg", train=True, transform=train_transform)
    dl = D.DataLoader(ds, batch_size=8, pin_memory=False, shuffle=False)

    batch = next(iter(dl))
    # Show the 4th image, or the last one if the batch is smaller than that
    sample = batch[min(3, len(batch) - 1)]
    logger.debug(f"{sample.max(), sample.min()}")

    plt.imshow(sample.permute(1, 2, 0))
    plt.show()

## Losses

## Models

## Lightning Module 🪄

In [None]:
class FinalNet(pl.LightningModule):
    """
    Template LightningModule: fill in the network, `forward` and `loss_function`.

    `hparams` must provide 'lr', 'weight_decay', 'image_size' and
    'batch_size' -> {'train', 'val', 'test'}; the optional 'num_workers'
    (default 4) sets the DataLoader worker count.
    """

    def __init__(self, hparams):
        super().__init__()
        # Exposes `hparams` as `self.hparams` and stores it in checkpoints, which
        # `configure_optimizers`, the dataloaders and `load_from_checkpoint` need.
        self.save_hyperparameters(hparams)

    def forward(self, x):
        # TODO
        return x

    def configure_optimizers(self):
        # NOTE: the template defines no layers yet, and Adam raises on an empty
        # parameter list until real modules are added in __init__.
        return torch.optim.Adam(self.parameters(),
                                lr=self.hparams['lr'],
                                weight_decay=self.hparams['weight_decay'])

    def loss_function(self, inputs, outputs):
        # TODO: must return a scalar tensor for logging/backprop to work
        return outputs

    def prepare_data(self):
        # Lightning calls this on a single process only, so it must not assign
        # state; downloads would go here. Datasets are built in `setup`.
        pass

    def setup(self, stage=None):
        """Build the train/val/test datasets (called by Lightning on every process)."""
        size = self.hparams['image_size']
        train_transform = A.Compose([
            A.VerticalFlip(p=0.1),
            A.HorizontalFlip(p=0.6),
            A.Resize(size, size, interpolation=4, p=1),  # 4 == cv2.INTER_LANCZOS4
            A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), p=1),
            ToTensorV2(),
        ])
        test_transform = A.Compose([
            A.Resize(size, size, interpolation=4, p=1),
            A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), p=1),
            ToTensorV2(),
        ])

        self.train_ds = GenericImageDS("data/train_data/",    "*.jpg", train=True,  transform=train_transform)
        self.val_ds   = GenericImageDS("data/test_data/DICM", "*.jpg", train=False, transform=test_transform)
        self.test_ds  = GenericImageDS("data/test_data/LIME", "*.bmp", train=False, transform=test_transform)

    def _make_dataloader(self, dataset, split, shuffle):
        """Build a DataLoader for `split` ('train'/'val'/'test') from hparams."""
        return D.DataLoader(dataset,
                            batch_size=self.hparams['batch_size'][split],
                            num_workers=self.hparams.get('num_workers', 4),
                            pin_memory=True,
                            shuffle=shuffle)

    def train_dataloader(self):
        return self._make_dataloader(self.train_ds, 'train', shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.val_ds, 'val', shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader(self.test_ds, 'test', shuffle=False)

    def training_step(self, batch, batch_idx):
        images = batch
        outputs = self(images)
        loss = self.loss_function(images, outputs)

        # `self.log` works with any Lightning logger (Neptune or TensorBoard);
        # `logger.experiment.log_metric` only exists on Neptune experiments.
        self.log('step_train_loss', loss, on_step=True, on_epoch=False)
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('epoch_train_loss', avg_loss)

    def validation_step(self, batch, batch_idx):
        images = batch
        outputs = self(images)
        loss = self.loss_function(images, outputs)

        self.log('step_val_loss', loss, on_step=True, on_epoch=False)
        return {'val_loss': loss, 'outputs': outputs}

    def validation_epoch_end(self, outputs):
        avg_val_loss = torch.stack([output['val_loss'] for output in outputs]).mean()
        # Metrics logged via `self.log` are the ones callbacks can monitor
        self.log('avg_val_loss', avg_val_loss)
        self.log('epoch_val_loss', avg_val_loss)

    def test_step(self, batch, batch_idx):
        inputs = batch
        outputs = self(inputs)
        loss = self.loss_function(inputs, outputs)
        self.log('test_loss', loss)
        return {'test_loss': loss}


## Training

### Parameters

In [None]:
# Training hyper-parameters; passed to FinalNet and, in Neptune mode, logged
# as the experiment's `params`.
hparams = dict(
    lr=1e-4,
    weight_decay=1e-4,
    # Per-split DataLoader batch sizes
    batch_size=dict(train=8, val=4, test=4),
    image_size=256,
    gradient_clip_val=0.1,
    max_epochs=200,
    min_epochs=10,
    check_val_every_n_epoch=4,
    precision=32,  # https://pytorch-lightning.readthedocs.io/en/latest/amp.html
    benchmark=True,
    deterministic=False,
    use_gpu=torch.cuda.is_available(),
)

### Bells & whistles

In [None]:
# https://pytorch-lightning.readthedocs.io/en/latest/weights_loading.html?highlight=ModelCheckpoint
from pytorch_lightning.callbacks import ModelCheckpoint
# Callbacks can only monitor metrics reported through `self.log`; FinalNet logs
# its epoch-mean validation loss as 'avg_val_loss'. Lightning appends the
# '.ckpt' extension to the formatted file name itself.
model_checkpoint = ModelCheckpoint(filepath   = 'checkpoints/{epoch:04d}-{avg_val_loss:.3f}',
                                   save_top_k = 5,
                                   monitor    = 'avg_val_loss',
                                   mode       = 'min',
                                   period     = 5)

# https://pytorch-lightning.readthedocs.io/en/latest/early_stopping.html
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
early_stop_callback = EarlyStopping(
   monitor   = 'avg_val_loss',
   min_delta = 0.00,
   patience  = 3,
   verbose   = True,
   mode      = 'min'
)

if not CONSTANTS['TEST_DRIVE']:
    # https://docs.neptune.ai/api-reference/neptune/experiments/index.html#neptune.experiments.Experiment
    from pytorch_lightning.loggers.neptune import NeptuneLogger
    pl_logger = NeptuneLogger(
        api_key         = CONSTANTS['NEPTUNE']['API_TOKEN'],
        project_name    = f"{CONSTANTS['NEPTUNE']['USERNAME']}/{CONSTANTS['NEPTUNE']['PROJECT']}",
        # Kept open so later cells can upload artifacts; stopped after testing
        close_after_fit = False,
        experiment_name = CONSTANTS['NEPTUNE']['EXPERIMENT_NAME'],
        params          = hparams,
    )
else:
    # https://pytorch-lightning.readthedocs.io/en/1.0.8/logging.html
    from pytorch_lightning.loggers import TensorBoardLogger
    pl_logger = TensorBoardLogger('tensorboard-logs/')

### Instantiations

In [None]:
logger.setLevel(logging.INFO)
pl.seed_everything(CONSTANTS['SEED'])

model = FinalNet(hparams=hparams)

trainer = pl.Trainer(
    gpus                    = -1 if hparams['use_gpu'] else 0,
    precision               = hparams['precision'],
    gradient_clip_val       = hparams['gradient_clip_val'],
    benchmark               = hparams['benchmark'],
    deterministic           = hparams['deterministic'],
    max_epochs              = hparams['max_epochs'],
    min_epochs              = hparams['min_epochs'],
    check_val_every_n_epoch = hparams['check_val_every_n_epoch'],
    logger                  = pl_logger,
    # The ModelCheckpoint instance goes in `callbacks`; `checkpoint_callback`
    # expects a bool (passing an instance is deprecated since PL 1.1).
    checkpoint_callback     = True,
    callbacks               = [early_stop_callback, model_checkpoint],
)

### Train 🐉🐉

In [None]:
trainer.fit(model=model)  # dataloaders come from FinalNet's own *_dataloader hooks

In [None]:
if not CONSTANTS['TEST_DRIVE']:
    # In this mode `pl_logger` is the NeptuneLogger; keep the `neptune_logger`
    # name the rest of the notebook refers to.
    neptune_logger = pl_logger

    # Log model summary, one line per entry
    for line in str(model).split('\n'):
        neptune_logger.experiment.log_text('model_summary', line)

    # Which GPUs were used?
    gpu_list = [f'{i}:{torch.cuda.get_device_name(i)}' for i in range(torch.cuda.device_count())]
    neptune_logger.experiment.log_text('GPUs used', ', '.join(gpu_list))

    # Upload the best model checkpoints (up to `save_top_k`) to Neptune
    for k in model_checkpoint.best_k_models:
        model_name = 'checkpoints/' + k.split('/')[-1]
        neptune_logger.experiment.log_artifact(k, model_name)

    # Save and upload the checkpoint from the final epoch
    last_model_path = f"checkpoints/last_model--epoch={trainer.current_epoch}.ckpt"
    trainer.save_checkpoint(last_model_path)
    neptune_logger.experiment.log_artifact(last_model_path, 'checkpoints/' + last_model_path.split('/')[-1])

    # Log score of the best model checkpoint (None when nothing was saved)
    if model_checkpoint.best_model_score is not None:
        neptune_logger.experiment.set_property('best_model_score', model_checkpoint.best_model_score.tolist())

### Test

In [None]:
trainer.test()

if not CONSTANTS['TEST_DRIVE']:
    # The NeptuneLogger (`pl_logger` in this mode) was created with
    # close_after_fit=False, so stop its experiment explicitly.
    pl_logger.experiment.stop()

## Inference

### Get weights

In [None]:
# Get Neptune API token: reuse the one from the config cell, prompt only if empty
from getpass import getpass
api_token = CONSTANTS['NEPTUNE'].get('API_TOKEN') or getpass("Enter Neptune.ai API token: ")

In [None]:
# Connect to Neptune and fetch the experiment whose artifacts we want
import neptune
from neptune import Session

# NOTE(review): this project/experiment look carried over from another project
# ("nightsight") rather than CONSTANTS['NEPTUNE'] — confirm before reuse.
NEPTUNE_PROJECT = 'rshwndsz/nightsight'
NEPTUNE_EXPERIMENT_ID = 'NIG-11'

session = Session.with_default_backend(api_token=api_token)
project = session.get_project(NEPTUNE_PROJECT)
matching_experiments = project.get_experiments(id=NEPTUNE_EXPERIMENT_ID)
experiment = matching_experiments[0]
experiment

In [None]:
# Fetch a checkpoint artifact from the Neptune experiment
artifact_path = 'epoch=133-avg_val_loss=1.06.ckpt'
checkpoint_dir = os.path.join('checkpoints', 'downloads')
# Expected local location: the artifact's file name (no folders) inside checkpoint_dir
artifact_name = artifact_path.rsplit('/', 1)[-1]
checkpoint_path = os.path.join(checkpoint_dir, artifact_name)

experiment.download_artifact(destination_dir=checkpoint_dir, path=artifact_path)

### Load weights

In [None]:
# Rebuild FinalNet from the downloaded checkpoint and switch it to eval mode
testing_model = FinalNet.load_from_checkpoint(checkpoint_path=checkpoint_path).eval()
testing_model

### Test

In [None]:
test_transform = A.Compose(
    [
        # 4 == cv2.INTER_LANCZOS4; the bare name INTER_LANCZOS4 is not defined here
        A.Resize(hparams['image_size'], hparams['image_size'], interpolation=4, p=1),
        A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), p=1),
        ToTensorV2(),
    ]
)
ds = GenericImageDS("data/test_data/Adobe-5k", "*.jpg", train=False, transform=test_transform)
dl = D.DataLoader(ds, batch_size=5, pin_memory=False, shuffle=True)
batch = next(iter(dl))

# Run the restored model on the batch without tracking gradients
with torch.no_grad():
    outputs = testing_model(batch.to(testing_model.device))

# Assumes forward returns (B, C, H, W) images like its input — adjust otherwise
fig, (ax_in, ax_out) = plt.subplots(1, 2)
ax_in.imshow(batch[0].permute(1, 2, 0))
ax_in.set_title('input')
ax_out.imshow(outputs[0].cpu().permute(1, 2, 0).clamp(0, 1))
ax_out.set_title('output')
plt.show()