# New Project

## Setup

In [None]:
# Install the third-party packages used by this notebook.
# Every command redirects stdout and stderr to /dev/null, so installation
# failures are silent as well; drop the redirection when debugging a failed install.
# NOTE(review): none of the packages below are version-pinned.

# TPU (optional, disabled by default)
# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl > /dev/null 2>&1

!pip install pytorch-lightning --upgrade       > /dev/null 2>&1
!pip install torchmetrics                      > /dev/null 2>&1
!pip install lightning-bolts                   > /dev/null 2>&1
!pip install thop                              > /dev/null 2>&1
!pip install optuna                            > /dev/null 2>&1

# Installed from a fork branch for now.
# Uncomment after https://github.com/neptune-ai/neptune-pytorch-lightning/issues/3 is resolved
# !pip install neptune-client[pytorch-lightning] > /dev/null 2>&1
!python -m pip install git+https://github.com/rshwndsz/neptune-pytorch-lightning.git@rsd/NPT-PL-3-use-existing-run > /dev/null 2>&1

# Installed from a fork branch for now.
# Uncomment after https://github.com/neptune-ai/neptune-optuna/issues/6 is resolved
# !pip install neptune-client[optuna]            > /dev/null 2>&1
!python -m pip install git+https://github.com/rshwndsz/neptune-optuna.git@rsd/NPT-OPT-2-multi-objective-support > /dev/null 2>&1

In [None]:
# STL
import math
import os
import sys
import glob
import logging
import getpass
import shutil
import random
import joblib
import json
from pathlib import Path
from functools import partial
from collections import OrderedDict
from argparse import Namespace

# Numerical Python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Image processing
from PIL import Image

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D
import torchvision as tv
import torchvision.transforms as tf
import torchvision.transforms.functional as tff
import torchmetrics as M
import pytorch_lightning as pl
import neptune.new as neptune
import optuna
import thop
from sklearn.model_selection import train_test_split

# Bells & Whistles
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from neptune.new.types import File
from neptune.new.integrations.pytorch_lightning import NeptuneLogger
from neptune.new.integrations.optuna import NeptuneCallback
from optuna.integration import PyTorchLightningPruningCallback

# Misc
import gdown
from tqdm.notebook import tqdm

In [None]:
# Notebook-wide configuration.
# `getpass` is already imported together with the other standard-library
# modules in the imports cell, so it is not imported a second time here.
C = Namespace(
    NEPTUNE = Namespace(
        USERNAME  = "rshwndsz",
        PROJECT   = "", # TODO
        # Reuse a token that is already present in the environment so that
        # re-running this cell (or a non-interactive "Run All") does not block
        # on the prompt. To force a new prompt, run
        # `os.environ.pop("NEPTUNE_API_TOKEN", None)` first.
        API_TOKEN = os.environ.get("NEPTUNE_API_TOKEN") or getpass.getpass(prompt="Neptune API Token: "),
    ),
    SEED = 1337
)

# Seed everything through Lightning for reproducible runs.
pl.seed_everything(C.SEED)
# Export the token through the environment as well.
os.environ["NEPTUNE_API_TOKEN"] = C.NEPTUNE.API_TOKEN

## Data

In [None]:
class DataModule(pl.LightningDataModule):
    """Skeleton data module: stores loader settings and fetches the dataset file.

    The dataloader, setup, teardown and visualisation hooks are placeholders
    that currently return None.
    """

    def __init__(self,
                batch_size  = 4,
                num_workers = 2,
                pin_memory  = True,
                shuffle     = True):
        super().__init__()

        # Loader settings, stored as attributes for the dataloader hooks.
        self.batch_size, self.num_workers = batch_size, num_workers
        self.pin_memory, self.shuffle = pin_memory, shuffle

    def prepare_data(self):
        """
        Create ./data/ and download dataset.mat into it unless it already exists.

        Meant for operations that might write to disk or that must run from a
        single process in distributed settings.
        DO NOT assign state here as it is called from a single process.
        """
        source_url = ""
        data_dir = Path("./data/")

        # Safe for nested and already-existing directories
        data_dir.mkdir(parents=True, exist_ok=True)

        target = data_dir / "dataset.mat"
        if target.exists():
            # Already downloaded, nothing left to do
            return

        gdown.download(source_url, str(target), quiet=False)

    def setup(self, stage=None):
        """Placeholder for data operations that run on every GPU; does nothing yet."""

    def train_dataloader(self):
        """Placeholder; returns None until implemented."""

    def val_dataloader(self):
        """Placeholder; returns None until implemented."""

    def test_dataloader(self):
        """Placeholder; returns None until implemented."""

    def teardown(self, stage=None):
        """Placeholder for clean-up once a run is finished; does nothing yet."""

    def visualize(self):
        """Placeholder; returns None until implemented."""

## Blocks

## Model

In [None]:
class Net(pl.LightningModule):
    """Skeleton LightningModule with per-stage loss and metric logging.

    `forward`, `configure_optimizers` and `loss_function` are placeholders that
    currently return None. Each step logs "<stage>/loss" and returns a dict with
    the keys "loss", "pred" and "target"; metrics are logged under the
    "train/", "val/" and "test/" prefixes.
    """

    def __init__(self, hparams):
        super().__init__()

        # Hyperparameters (read back through self.hparams)
        self.save_hyperparameters(hparams)

        # One metric collection, cloned once per stage with its own prefix
        base_metrics = M.MetricCollection({
            "AverageAccuracy": M.Accuracy(num_classes=self.hparams.num_classes, average="macro"),
        })
        self.train_metrics = base_metrics.clone(prefix="train/")
        self.val_metrics   = base_metrics.clone(prefix="val/")
        self.test_metrics  = base_metrics.clone(prefix="test/")

    def forward(self, x):
        """Placeholder; returns None until implemented."""

    def configure_optimizers(self):
        """Placeholder; returns None until implemented."""

    def loss_function(self, preds, targets):
        """Placeholder; returns None until implemented."""

    def prepare_data(self):
        """Placeholder; returns None until implemented."""

    def train_dataloader(self):
        """Placeholder; returns None until implemented."""

    def val_dataloader(self):
        """Placeholder; returns None until implemented."""

    def _predict_with_loss(self, batch):
        """Unpack an (inputs, target) batch, run the model and compute the loss."""
        inputs, target = batch
        pred = self(inputs)
        return self.loss_function(pred, target), pred, target

    @staticmethod
    def _gather(outputs, key):
        """Concatenate the `key` entry of every step output along dim 0."""
        return torch.cat([step[key] for step in outputs], dim=0)

    def training_step(self, batch, batch_idx):
        """Run one training batch, log "train/loss" and return loss/pred/target."""
        loss, pred, target = self._predict_with_loss(batch)
        self.log("train/loss", loss)
        return {"loss": loss, "pred": pred, "target": target}

    def training_step_end(self, outputs):
        """Log the training metrics for the step outputs and pass them through."""
        self.log_dict(self.train_metrics(outputs["pred"], outputs["target"]))
        return outputs

    def validation_step(self, batch, batch_idx):
        """Run one validation batch, log "val/loss" and return loss/pred/target."""
        loss, pred, target = self._predict_with_loss(batch)
        self.log("val/loss", loss)
        return {"loss": loss, "pred": pred, "target": target}

    def validation_epoch_end(self, outputs):
        """Log validation metrics over all step outputs and pass them through."""
        preds = self._gather(outputs, "pred")
        targets = self._gather(outputs, "target")
        self.log_dict(self.val_metrics(preds, targets))
        return outputs

    def test_step(self, batch, batch_idx):
        """Run one test batch, log "test/loss" and return loss/pred/target."""
        loss, pred, target = self._predict_with_loss(batch)
        self.log("test/loss", loss)
        return {"loss": loss, "pred": pred, "target": target}

    def test_epoch_end(self, outputs):
        """Log test metrics over all step outputs and pass them through."""
        preds = self._gather(outputs, "pred")
        targets = self._gather(outputs, "target")
        self.log_dict(self.test_metrics(preds, targets))
        return outputs

## Sweep

In [None]:
class Objective:
    """Optuna objective that trains one trial and scores it on two objectives.

    Args:
        run: Neptune run that every trial logs into, each under its own
            "sweep/trial_<n>" namespace.
        monitor: Key in `trainer.callback_metrics` whose value is returned as
            the first objective.

    Calling the instance with an Optuna trial returns
    `(monitored metric, number of model parameters)`.

    NOTE: this is still a template. The `tunable` namespace is empty and
    `dataset`, `model` and `_inputs` are only sketched in comments, so the call
    fails until those TODOs are filled in.
    """

    def __init__(self, run, monitor):
        self.run     = run        # Neptune Instance
        self.monitor = monitor    # Metric to be monitored

    def __call__(self, trial):
        which_trial = trial.number
        use_gpu     = torch.cuda.is_available()

        # TODO: sample the hyperparameters from `trial`. The Trainer below reads
        # `precision`, `gradient_clip_val` and `stochastic_weight_avg` from here.
        tunable = Namespace(

        )

        # Keep the shared run open after fit so later trials can log to it.
        nlogger = NeptuneLogger(
            run             = self.run,
            base_namespace  = f"sweep/trial_{which_trial}",
            close_after_fit = False
        )

        # TODO: create the data module
        # dataset =
        # self.run["dataset"].log(dataset.__class__.__name__)

        # TODO: create the model
        # model =
        # self.run["model"].log(model.__class__.__name__)

        trainer = pl.Trainer(
            gpus                      = -1 if use_gpu else 0,
            precision                 = tunable.precision if use_gpu else 32,
            deterministic             = True,
            # cuDNN benchmarking may select different algorithms from run to
            # run, which defeats `deterministic=True`; keep it disabled.
            benchmark                 = False,

            max_epochs                = 33,
            num_sanity_val_steps      = 2,
            check_val_every_n_epoch   = 2,

            weights_summary           = "full",
            progress_bar_refresh_rate = 20,
            gradient_clip_val         = tunable.gradient_clip_val,
            stochastic_weight_avg     = tunable.stochastic_weight_avg,
            logger                    = nlogger,
            checkpoint_callback       = False,
            # callbacks                 = [PyTorchLightningPruningCallback(trial, monitor=self.monitor)],
        )

        trainer.fit(model, dataset)

        # First objective: last logged value of the monitored metric
        accuracy = trainer.callback_metrics[self.monitor].item()

        # TODO: build the example inputs used for profiling
        # _inputs  =
        # thop.profile returns two values; only the second one, the parameter
        # count, is used (second objective).
        _, num_parameters = thop.profile(model, inputs=_inputs, verbose=False)

        return accuracy, num_parameters


In [None]:
# Neptune run shared by every trial of the sweep
run = neptune.init(
    project=f"{C.NEPTUNE.USERNAME}/{C.NEPTUNE.PROJECT}",
    name="sweep",
    api_token=C.NEPTUNE.API_TOKEN,
    mode="debug",
)

# Multi-objective study; the directions follow the order of the values
# returned by Objective.__call__ (accuracy, num_parameters).
study = optuna.study.create_study(
    directions=["maximize", "minimize"],
    # pruner=optuna.pruners.MedianPruner(),
)

# Resume from here
study.optimize(
    Objective(run, "val/AverageAccuracy"),
    n_trials=100,
    callbacks=[NeptuneCallback(run)],
)
run.stop()

## Train

In [None]:
# Neptune run for the final training.
# NOTE(review): the name "sweep" looks copied from the sweep cell; confirm
# whether a training-specific name was intended.
run = neptune.init(
    project   = f"{C.NEPTUNE.USERNAME}/{C.NEPTUNE.PROJECT}",
    name      = "sweep",
    api_token = C.NEPTUNE.API_TOKEN,
    mode      = "debug",
)

# Logger
nlogger = NeptuneLogger(
    run             = run,
    base_namespace  = "training",
    close_after_fit = False # Keep open for testing
)

# Get parameters from best study
# NOTE(review): presumably "best/params" is written during the sweep; verify
# that it is available on this newly created run before relying on it.
best = Namespace(**run["best/params"].fetch())

# Save best model every 5 val_epochs
model_checkpoint = ModelCheckpoint(
    monitor    = "val/AverageAccuracy",
    mode       = "max",
    verbose    = True,
    period     = 5,
    save_top_k = 5,
    save_last  = True,
    dirpath    = "./checkpoints/",
    # The metric name contains "/", so the file name is spelled out in full
    # and automatic metric-name insertion is turned off.
    filename   = "epoch-{epoch:03d}__val_AverageAccuracy-{val/AverageAccuracy:.4f}",
    auto_insert_metric_name = False,
)

# Stop training if model stops improving
early_stopping = EarlyStopping(
    monitor      = "val/AverageAccuracy",
    mode         = "max",  # accuracy is maximised, matching the checkpoint callback
    patience     = 10,
    min_delta    = 1e-6,
    strict       = True,
    check_finite = True,
    verbose      = True,
)

# Dataset
# TODO

# Model
# TODO

# Trainer
trainer = pl.Trainer(
    gpus                      = -1 if torch.cuda.is_available() else 0,
    precision                 = best.precision if torch.cuda.is_available() else 32,
    deterministic             = True,
    # cuDNN benchmarking may select different algorithms from run to run,
    # which defeats `deterministic=True`; keep it disabled.
    benchmark                 = False,

    min_epochs                = 50,
    max_epochs                = 1000,
    num_sanity_val_steps      = 4,
    check_val_every_n_epoch   = 4,

    gradient_clip_val         = best.gradient_clip_val,
    stochastic_weight_avg     = best.stochastic_weight_avg,
    weights_summary           = "full",
    progress_bar_refresh_rate = 20,

    logger                    = nlogger,
    callbacks                 = [model_checkpoint, early_stopping],
)

In [None]:
# 🐉
# Start training with the trainer configured in the previous cell.
# `model` and `dataset` must be created first; the "Dataset" and "Model"
# sections of the previous cell are still marked TODO.
trainer.fit(model, dataset)