# Custom Training Logic with Lightning Integration and Lightning Hooks

In this example, we showcase how users can define their own training logic and integrate it into the Lightning workflow using a variety of Lightning hooks.
A reference to these hooks is provided here: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks. 

The first part of this notebook is equivalent to the basics tutorial, so we will move through it quickly:



## NeuroMANCER and Dependencies

### Install (Colab only)
Skip this step when running locally.

In [None]:
!pip install neuromancer

### Import

In [1]:
import torch
import torch.nn as nn
import numpy as np
import neuromancer.slim as slim
import matplotlib.pyplot as plt
import lightning.pytorch as pl 

from neuromancer.trainer import LitTrainer
from neuromancer.problem import Problem
from neuromancer.constraint import variable
from neuromancer.dataset import DictDataset
from neuromancer.loss import PenaltyLoss
from neuromancer.modules import blocks
from neuromancer.system import Node


# Problem formulation

In this example we will solve a parametric constrained [Rosenbrock problem](https://en.wikipedia.org/wiki/Rosenbrock_function):

$$
\begin{align}
&\text{minimize } &&  (1-x)^2 + a(y-x^2)^2\\
&\text{subject to} && \left(\frac{p}{2}\right)^2 \le x^2 + y^2 \le p^2\\
& && x \ge y
\end{align}
$$

with parameters $p, a$ and decision variables $x, y$.
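
To make the formulation concrete, here is a minimal plain-Python sketch that evaluates the objective and checks the three constraints at a single point. The function names `rosenbrock` and `is_feasible` are illustrative only and are not part of NeuroMANCER.

```python
def rosenbrock(x, y, a):
    """Objective (1 - x)^2 + a * (y - x^2)^2."""
    return (1 - x) ** 2 + a * (y - x ** 2) ** 2


def is_feasible(x, y, p):
    """Check (p/2)^2 <= x^2 + y^2 <= p^2 and x >= y."""
    r2 = x ** 2 + y ** 2
    return (p / 2) ** 2 <= r2 <= p ** 2 and x >= y


# The unconstrained minimizer (1, 1) has zero cost for any a.
print(rosenbrock(1.0, 1.0, a=0.7))   # 0.0
# For p = 2 it lies inside the annulus and satisfies x >= y.
print(is_feasible(1.0, 1.0, p=2.0))  # True
# For p = 1 the outer radius excludes it: 1^2 + 1^2 = 2 > 1.
print(is_feasible(1.0, 1.0, p=1.0))  # False
```

The last case shows why the solution depends on the parameter $p$: the unconstrained minimizer is only admissible for some parameter values.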


### Lightning Dataset

We construct the dataset by sampling the parametric space.

In [2]:
data_seed = 408  # random seed used for simulated data
np.random.seed(data_seed)
torch.manual_seed(data_seed)
nsim = 5000  # number of datapoints: increase sample density for more robust results

# create dictionaries with sampled datapoints with uniform distribution
a_low, a_high, p_low, p_high = 0.2, 1.2, 0.5, 2.0

In [3]:

def data_setup_function(nsim, a_low, a_high, p_low, p_high): 

    
    samples_train = {"a": torch.FloatTensor(nsim, 1).uniform_(a_low, a_high),
                    "p": torch.FloatTensor(nsim, 1).uniform_(p_low, p_high)}
    samples_dev = {"a": torch.FloatTensor(nsim, 1).uniform_(a_low, a_high),
                "p": torch.FloatTensor(nsim, 1).uniform_(p_low, p_high)}
    samples_test = {"a": torch.FloatTensor(nsim, 1).uniform_(a_low, a_high),
                "p": torch.FloatTensor(nsim, 1).uniform_(p_low, p_high)}
    # create named dictionary datasets
    train_data = DictDataset(samples_train, name='train')
    dev_data = DictDataset(samples_dev, name='dev')
    test_data = DictDataset(samples_test, name='test')

    batch_size = 64

    # Return the dict datasets in train, dev, test order, followed by batch_size 
    return train_data, dev_data, test_data, batch_size 



We now define the **Problem()**

In [4]:
# define neural architecture for the trainable solution map
func = blocks.MLP(insize=2, outsize=2,
                bias=True,
                linear_map=slim.maps['linear'],
                nonlin=nn.ReLU,
                hsizes=[80] * 4)
# wrap neural net into symbolic representation of the solution map via the Node class: sol_map(xi) -> x
sol_map = Node(func, ['a', 'p'], ['x'], name='map')
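
Conceptually, a `Node` reads named entries from a data dictionary, passes them to the wrapped callable, and stores the result under new names. The plain-Python sketch below illustrates that idea only: `TinyNode` is a hypothetical stand-in that works on floats and lists, and it does not reproduce NeuroMANCER's implementation.

```python
class TinyNode:
    """Hypothetical stand-in: wraps a callable with named inputs and outputs."""

    def __init__(self, func, input_keys, output_keys, name=None):
        self.func = func
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.name = name

    def __call__(self, data):
        # Gather the inputs by name, in the declared order.
        inputs = [data[key] for key in self.input_keys]
        result = self.func(*inputs)
        # Normalize a single return value to a tuple so it can be zipped.
        if not isinstance(result, tuple):
            result = (result,)
        return dict(zip(self.output_keys, result))


# Toy "solution map": returns a two-element decision vector under the key 'x'.
toy_map = TinyNode(lambda a, p: [p / 2, a * p / 2], ["a", "p"], ["x"], name="map")
print(toy_map({"a": 0.5, "p": 2.0}))  # {'x': [1.0, 0.5]}
```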

## Objective and Constraints in NeuroMANCER

In [5]:
# define decision variables
x1 = variable("x")[:, [0]]
x2 = variable("x")[:, [1]]
# problem parameters sampled in the dataset
p = variable('p')
a = variable('a')

# objective function
f = (1-x1)**2 + a*(x2-x1**2)**2
obj = f.minimize(weight=1.0, name='obj')

# constraints
Q_con = 100.  # constraint penalty weights
con_1 = Q_con*(x1 >= x2)
con_2 = Q_con*((p/2)**2 <= x1**2+x2**2)
con_3 = Q_con*(x1**2+x2**2 <= p**2)
con_1.name = 'c1'
con_2.name = 'c2'
con_3.name = 'c3'

In [6]:
# constrained optimization problem construction
objectives = [obj]
constraints = [con_1, con_2, con_3]
components = [sol_map]

# create penalty method loss function
loss = PenaltyLoss(objectives, constraints)
# construct constrained optimization problem
problem = Problem(components, loss)
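
The penalty method replaces hard constraints with weighted violation terms that are added to the objective. The sketch below computes such a loss for a single point in plain Python, using `max(0, ·)` as the violation measure and the same weight of 100 as `Q_con` above. This is an assumption made for illustration: the exact norm and aggregation used by `PenaltyLoss` may differ.

```python
def penalty_loss(x, y, a, p, q_con=100.0):
    """Objective plus weighted constraint violations (illustrative only)."""
    objective = (1 - x) ** 2 + a * (y - x ** 2) ** 2
    r2 = x ** 2 + y ** 2
    violations = [
        max(0.0, y - x),              # c1: x >= y
        max(0.0, (p / 2) ** 2 - r2),  # c2: (p/2)^2 <= x^2 + y^2
        max(0.0, r2 - p ** 2),        # c3: x^2 + y^2 <= p^2
    ]
    return objective + q_con * sum(violations)


# Feasible point: no penalty, so the loss equals the objective.
print(penalty_loss(1.0, 1.0, a=1.0, p=2.0))  # 0.0
# Infeasible point (outside the outer circle for p = 1): c3 is violated by 1.0.
print(penalty_loss(1.0, 1.0, a=1.0, p=1.0))  # 100.0
```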

# Lightning Hooks

Lightning hooks are modular "lego" blocks that define the training process of the `LightningModule`. Recall that the Lightning trainer fits a `LightningModule` to a `LightningDataModule`. For simplicity, the `LightningModule` and `LightningDataModule` are abstracted away; the user only needs to interact with the `LitTrainer`. However, we also provide more fine-grained control over the training process through these hooks.

Let's begin with the simplest hook: the `training_step`

## Custom Training Logic
Training within the PyTorch Lightning framework is defined by a `training_step` function, which specifies the logic going from a data batch to a loss. For example, the default `training_step` is shown below (extraneous details removed for simplicity). Here, we get the problem output for the given batch and return the loss associated with that output.

```
def training_step(self, batch):
    output = self.problem(batch)
    loss = output[self.train_metric]
    return loss
```

Notice how simple this is: there is no need to call `optimizer.zero_grad()`, `loss.backward()`, or `optimizer.step()`, and there is no other PyTorch boilerplate. All the user needs to do is define how the loss should be generated during training.

While rare, there may be cases where the user wants to define their own training logic. Potential cases include test-time data augmentation (e.g. operations on or with respect to the data rollout), other domain augmentations, or modifications to how the output and/or loss is handled.

The user can pass in their own `training_step` by supplying an equivalent function handle to the `custom_training_step` keyword of `LitTrainer`, for example:

```
def custom_training_step(model, batch): 
    output = model.problem(batch)
    Q_con = 1
    if model.current_epoch > 1: 
        Q_con = 1/10000
    loss = Q_con*(output[model.train_metric])
    return loss
```

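The signature of this function should be `custom_training_step(model, batch)`, where `model` is the Lightning module that wraps the NeuroMANCER `Problem` (available as `model.problem`) and `batch` is the current data batch.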
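
One way to picture how a user-supplied step replaces the default is a module that checks for a custom function and otherwise falls back to the built-in logic. The sketch below is a hypothetical illustration in plain Python: `TinyModule`, `toy_problem`, `scaled_step`, and the metric name `"train_loss"` are invented for this example, floats stand in for tensors, and nothing here is taken from the NeuroMANCER or Lightning source.

```python
class TinyModule:
    """Hypothetical stand-in for the module that a trainer builds internally."""

    def __init__(self, problem, train_metric="train_loss", custom_training_step=None):
        self.problem = problem
        self.train_metric = train_metric  # assumed metric name, for illustration
        self.custom_training_step = custom_training_step
        self.current_epoch = 0

    def training_step(self, batch):
        if self.custom_training_step is not None:
            # The module passes itself as `model`.
            return self.custom_training_step(self, batch)
        output = self.problem(batch)
        return output[self.train_metric]


def toy_problem(batch):
    # Toy "problem": the loss is the mean of the batch values.
    return {"train_loss": sum(batch) / len(batch)}


def scaled_step(model, batch):
    # Same schedule as in the text: full weight for epochs 0 and 1, then 1/10000.
    output = model.problem(batch)
    scale = 1.0 if model.current_epoch <= 1 else 1 / 10000
    return scale * output[model.train_metric]


module = TinyModule(toy_problem, custom_training_step=scaled_step)
print(module.training_step([2.0, 4.0]))  # 3.0
module.current_epoch = 2
print(module.training_step([2.0, 4.0]))  # scaled down by a factor of 10000
```

Because the function receives the module itself, it can read any of its attributes, such as `current_epoch`, to make the loss depend on training progress.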

In [None]:
def custom_training_step(model, batch): 
    output = model.problem(batch)
    Q_con = 1
    if model.current_epoch > 1: 
        Q_con = 1/10000    
    loss = Q_con*(output[model.train_metric])
    return loss

lit_trainer = LitTrainer(epochs=10, accelerator='cpu', patience=3, custom_training_step=custom_training_step)
lit_trainer.fit(problem=problem, data_setup_function=data_setup_function, nsim=nsim,a_low=0.2, a_high=1.2, p_low=0.5, p_high=2.0)


Below is another example of a dummy `custom_training_step`. Here we add a fraction of the previous batch's loss to the current loss. (Again, this is a dummy example and not necessarily a proper ML technique.) Any state, such as `past_loss`, can be kept by setting it as an attribute of `model`.

In [None]:
def custom_training_step(model, batch): 
    # Initialize the carried-over loss once, on the first call
    if not hasattr(model, 'past_loss'):
        model.past_loss = 0.0
    
    output = model.problem(batch)
    loss = (output[model.train_metric]) + 0.5*model.past_loss
    model.past_loss = loss.item()
    return loss

lit_trainer = LitTrainer(epochs=100, accelerator='cpu', patience=3, custom_training_step=custom_training_step)
lit_trainer.fit(problem=problem, data_setup_function=data_setup_function, nsim=nsim,a_low=0.2, a_high=1.2, p_low=0.5, p_high=2.0)
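
The step above follows the recurrence `loss_t = raw_t + 0.5 * loss_{t-1}`. Since the carried-over term is stored via `loss.item()`, it is a plain float and contributes no gradient. A short plain-Python sketch with made-up raw losses shows how the returned value evolves; `accumulate` is an illustrative helper, not a library function.

```python
def accumulate(raw_losses, carry=0.5):
    """Return the running values loss_t = raw_t + carry * loss_{t-1}."""
    past_loss = 0.0
    history = []
    for raw in raw_losses:
        loss = raw + carry * past_loss
        past_loss = loss
        history.append(loss)
    return history


print(accumulate([1.0, 1.0, 1.0]))  # [1.0, 1.5, 1.75]
```

For a constant raw loss `r`, the returned value approaches `r / (1 - carry)`, which is `2r` for a carry of 0.5.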


## More Custom Hooks Via PyTorch Lightning

The `training_step` customized above is one example of the many hooks provided by PyTorch Lightning. Hooks are special methods of the `LightningModule` class that are invoked at specific points of the training lifecycle, so users can inject custom logic into training, validation, and testing. For instance, `training_step`, `validation_step`, and `test_step` handle the core logic for processing batches during the different stages. Hooks like `on_train_epoch_start`, `on_train_epoch_end`, `on_train_batch_start`, and `on_train_batch_end` allow for actions at the start and end of training epochs and batches. Other hooks such as `on_train_start`, `on_train_end`, `on_validation_start`, and `on_validation_end` provide entry and exit points for the training and validation phases, allowing for setup, teardown, logging, and other custom operations. Together, these hooks provide a structured way to extend and customize the training workflow.

For a list of all available hooks please refer to https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks
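
As a rough mental model, the training-side hooks fire in a nested order: once per run, once per epoch, and once per batch. The plain-Python sketch below mimics that nesting with a hypothetical `run_fit` loop. It covers only a handful of training hooks, omits validation, passes only the batch to each per-batch hook, and is far simpler than Lightning's actual loop.

```python
def run_fit(hooks, num_epochs, batches):
    """Call hooks in a simplified training order and record the call sequence."""
    calls = []

    def fire(name, *args):
        calls.append(name)
        if name in hooks:
            hooks[name](*args)

    fire("on_train_start")
    for _ in range(num_epochs):
        fire("on_train_epoch_start")
        for batch in batches:
            fire("on_train_batch_start", batch)
            fire("training_step", batch)
            fire("on_train_batch_end", batch)
        fire("on_train_epoch_end")
    fire("on_train_end")
    return calls


seen = []
order = run_fit({"training_step": seen.append}, num_epochs=1, batches=[[1, 2], [3, 4]])
print(order[:3])  # ['on_train_start', 'on_train_epoch_start', 'on_train_batch_start']
print(seen)       # [[1, 2], [3, 4]]
```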

We can visualize this hook interface as follows: 

<img src="../../figs/hooks.png" width="600">  


### Integrating More Hooks into NeuroMANCER

We implement each hook according to its signature (see the API reference linked above) and pass a dictionary of hooks to the `LitTrainer`.
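
In general Python terms, a dictionary of functions can be attached to an object as bound methods with `types.MethodType`, which is one reason a hook written as a free function still takes `self` as its first argument. The sketch below shows that mechanism on a hypothetical `HookHost` class; whether `LitTrainer` uses exactly this approach internally is an assumption.

```python
import types


class HookHost:
    """Hypothetical object with a default hook that does nothing."""

    def on_train_start(self):
        pass


def attach_hooks(obj, hooks):
    """Bind each function in `hooks` to `obj` under its dictionary key."""
    for name, func in hooks.items():
        setattr(obj, name, types.MethodType(func, obj))
    return obj


def greet(self):
    # `self` is the object the hook was bound to.
    self.greeted = True


host = attach_hooks(HookHost(), {"on_train_start": greet})
host.on_train_start()
print(host.greeted)  # True
```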


In [None]:
# Define custom hooks
def custom_train_epoch_end(self): 
    # Create a list to store the per-epoch training loss on first use
    if not hasattr(self, 'train_loss_epoch_history') or self.train_loss_epoch_history is None:
        self.train_loss_epoch_history = []
    
    # Get the epoch-average training loss. `training_step_outputs` is a list that already stores the loss per batch within the training epoch
    epoch_average = torch.stack(self.training_step_outputs).mean()
    self.train_loss_epoch_history.append(epoch_average)

# Do the same for the validation loss
def custom_validation_epoch_end(self):
    if not hasattr(self, 'val_loss_epoch_history') or self.val_loss_epoch_history is None:
        self.val_loss_epoch_history = []

    epoch_average = torch.stack(self.validation_step_outputs).mean()
    self.val_loss_epoch_history.append(epoch_average)

# Do something when training starts
def on_train_start(self): 
    print("HELLO WORLD")

# optimizer with scheduler
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    return [optimizer], [scheduler]

# Create a custom hooks dictionary
custom_hooks = {
    'on_train_epoch_end': custom_train_epoch_end,
    'on_validation_epoch_end': custom_validation_epoch_end, 
    'on_train_start': on_train_start, 
    'configure_optimizers': configure_optimizers
}

# Initialize the trainer with custom hooks
trainer = LitTrainer(epochs=1, accelerator='cpu', patience=3, custom_training_step=custom_training_step, custom_hooks=custom_hooks)

# Assuming `problem` and `data_setup_function` are defined
trainer.fit(problem, data_setup_function, nsim=nsim,a_low=0.2, a_high=1.2, p_low=0.5, p_high=2.0)
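
With `step_size=10` and `gamma=0.1`, `StepLR` multiplies the learning rate by 0.1 after every 10 scheduler steps, and Lightning steps the scheduler once per epoch by default. The resulting schedule can be checked in plain Python without PyTorch; `step_lr` below is an illustrative helper, not a library function.

```python
def step_lr(base_lr, epoch, step_size=10, gamma=0.1):
    """Closed form of a step decay: base_lr * gamma ** (epoch // step_size)."""
    return base_lr * gamma ** (epoch // step_size)


# Learning rate at a few epochs for base_lr = 1e-3.
for epoch in (0, 9, 10, 25):
    print(epoch, f"{step_lr(1e-3, epoch):.0e}")
# 0 1e-03
# 9 1e-03
# 10 1e-04
# 25 1e-05
```

Note that with `epochs=1`, as in the cell above, the decay never takes effect; the scheduler only matters for runs longer than 10 epochs.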