# POMO Lightning

In [10]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append('../../')

import math
from typing import List, Tuple, Optional, NamedTuple, Dict, Union, Any
from einops import rearrange, repeat
from hydra.utils import instantiate

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.nn import DataParallel
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import lightning as L

from torchrl.envs import EnvBase
from torchrl.envs.utils import step_mdp
from tensordict import TensorDict

from ncobench.data.dataset import TorchDictDataset

from ncobench.envs.tsp import TSPEnv
from ncobench.models.rl.reinforce import *
from ncobench.models.components.am.context import env_context
from ncobench.models.components.am.embeddings import env_init_embedding, env_dynamic_embedding
from ncobench.models.components.am.encoder import GraphAttentionEncoder
from ncobench.models.components.am.decoder import Decoder, decode_probs, PrecomputedCache, LogitAttention
from ncobench.models.components.am.base import get_log_likelihood
from ncobench.models.nn.attention import NativeFlashMHA, flash_attn_wrapper
from ncobench.utils.lightning import get_lightning_device

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Novelty compared to `AttentionModel`

## Pseudo-code of differences for training

### Attention Model
```python
def train(policy network p_θ, training set S, batch size B, significance α):
    for i in 1...B:
        s_i = RandomInstance(S)
        π_i = SampleRollout(s_i, p_θ)
        vb = UpdateBaseline(s, π)
        ∇𝐿 = (1/B) * Σ(L(π_i|s_i) - b_i) * ∇_θ log(p_θ(π_i|s_i))
        θ = GradientDescent(θ, ∇𝐿)
        if OneSidedPairedTest(p_θ, p_θ^BL) < α: # p_θ is better than p_θ^BL
            θ^BL = θ
```

### POMO
```python
def train(policy network p_θ, training set S, batch size B, number of start nodes N):
    for i in 1...B:
        s_i = RandomInstance(S)
        # New: select N starting nodes and roll out one trajectory from each
        α_i1,...,α_iN = SelectStartNodes(s_i)
        π_i1,...,π_iN = SampleRollout(s_i, p_θ, {α_i,j})
        vb = UpdateBaseline(s, π)
        # New: the shared baseline b_i is simply the average cost over the N rollouts of instance s_i
        ∇𝐿 = (1/NB) * Σ(L(π_ij|s_i) - b_i) * ∇_θ log(p_θ(π_ij|s_i))
        θ = GradientDescent(θ, ∇𝐿)
```

So the novelty is:
1. We select a set of starting nodes for each instance and rollout with them
2. Replace the rollout baseline with a shared baseline: the average cost over the N rollouts of the same instance

---

## Other novelty
1. Use `multi-greedy` decoding at inference, i.e. decode greedily from every starting node and simply keep the best of the resulting solutions

In [11]:

class Decoder(nn.Module):
    """Auto-regressive attention decoder whose first action is a forced POMO start node.

    On every step after the first, a query built from the projected graph
    embedding plus the projected step context is passed to ``LogitAttention``
    together with cached node projections to obtain log-probabilities for the
    next node. The first action is not decoded from the policy; it is taken
    from :meth:`select_start_nodes`.

    Args:
        env: environment object; ``name``, ``step`` and ``get_reward`` are used.
        embedding_dim: node embedding size; must be divisible by ``num_heads``.
        num_heads: number of attention heads handed to ``LogitAttention``.
        num_pomo: number of start nodes produced by :meth:`select_start_nodes`.
        **logit_attn_kwargs: forwarded unchanged to ``LogitAttention``.
    """

    def __init__(self, env, embedding_dim, num_heads, num_pomo=20, **logit_attn_kwargs):
        super(Decoder, self).__init__()

        self.env = env
        self.embedding_dim = embedding_dim
        self.n_heads = num_heads

        assert embedding_dim % num_heads == 0

        step_context_dim = 2 * embedding_dim  # Embedding of first and last node
        self.context = env_context(self.env.name, {"context_dim": step_context_dim})
        self.dynamic_embedding = env_dynamic_embedding(
            self.env.name, {"embedding_dim": embedding_dim}
        )

        # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim
        self.project_node_embeddings = nn.Linear(
            embedding_dim, 3 * embedding_dim, bias=False
        )
        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.project_step_context = nn.Linear(
            step_context_dim, embedding_dim, bias=False
        )

        # MHA
        self.logit_attention = LogitAttention(
            embedding_dim, num_heads, **logit_attn_kwargs
        )

        # POMO
        self.num_pomo = num_pomo

    def forward(self, td, embeddings, decode_type="sampling"):
        """Decode a full episode.

        Args:
            td: environment state; the keys ``"done"``, ``"i"``, ``"observation"``
                and ``"action_mask"`` are read, ``"action"`` and ``"reward"`` are written.
            embeddings: encoder output, indexed as ``[batch, node, embed]`` by
                :meth:`_precompute`.
            decode_type: forwarded to ``decode_probs``.

        Returns:
            ``(outputs, actions, td)`` where ``outputs`` and ``actions`` are the
            per-step log-probabilities / actions stacked along dim 1 and ``td``
            is the final state with ``"reward"`` set.
        """
        outputs = []
        actions = []

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        cached_embeds = self._precompute(embeddings)

        # Here we suppose all the batch is done at the same time
        while not td["done"].any():
            # POMO: first action is decided via select_start_nodes.
            # `td["i"]` is not an iterator, so `next(td["i"])` cannot be used here.
            # NOTE(review): assumes td["i"] is a tensor step counter that is 0 right after reset — confirm.
            if bool((td["i"] == 0).all()):
                action = self.select_start_nodes(td)
                log_p = torch.zeros_like(action)  # first log_p is 0
                # NOTE(review): `action` / `log_p` here have shape (batch, num_pomo) while later
                # steps produce one action per batch row; the stacking below only lines up once the
                # batch is expanded per start node — TODO confirm intended POMO batching.
            else:
                # Compute the logits for the next node
                log_p, mask = self._get_log_p(cached_embeds, td)

                # Select the indices of the next nodes in the sequences, result (batch_size) long
                action = decode_probs(
                    log_p.exp().squeeze(1), mask.squeeze(1), decode_type=decode_type
                )

            # Step the environment exactly once per decoded action (both branches share this)
            td.set("action", action[:, None])
            td = self.env.step(td)["next"]

            # Collect output of step
            outputs.append(log_p.squeeze(1))
            actions.append(action)

        outputs, actions = torch.stack(outputs, 1), torch.stack(actions, 1)
        td.set("reward", self.env.get_reward(td["observation"], actions))
        return outputs, actions, td

    def select_start_nodes(self, td):
        """Return start-node indices ``0..num_pomo-1`` for every batch element.

        The result has shape ``(batch_size, num_pomo)`` (an expanded view) and is
        created on the same device as ``td["observation"]``.
        """
        observation = td["observation"]
        batch_size = observation.shape[0]
        num_pomo = self.num_pomo
        selected = torch.arange(num_pomo, device=observation.device)[None, :].expand(
            batch_size, num_pomo
        )
        return selected

    def _precompute(self, embeddings):
        """Compute the step-independent projections once and bundle them in a ``PrecomputedCache``."""
        # The fixed context projection of the graph embedding is calculated only once for efficiency
        graph_embed = embeddings.mean(1)

        # The projection of the node embeddings for the attention is calculated once up front
        (
            glimpse_key_fixed,
            glimpse_val_fixed,
            logit_key_fixed,
        ) = self.project_node_embeddings(embeddings[:, None, :, :]).chunk(3, dim=-1)

        # Organize in a single cache object for easy access
        cached_embeds = PrecomputedCache(
            node_embeddings=embeddings,
            graph_context=self.project_fixed_context(graph_embed)[:, None, :],
            glimpse_key=self.logit_attention._make_heads(glimpse_key_fixed),
            glimpse_val=self.logit_attention._make_heads(glimpse_val_fixed),
            logit_key=logit_key_fixed,
        )

        return cached_embeds

    def _get_log_p(self, cached, td):
        """Compute next-node log-probabilities for the current state.

        Returns:
            ``(log_p, mask)`` where ``mask`` is the negation of ``td["action_mask"]``
            and ``log_p`` is whatever ``LogitAttention`` returns for the query.
        """
        context = self.context(cached.node_embeddings, td)
        step_context = self.project_step_context(context)  # [batch, 1, embed_dim]

        query = cached.graph_context + step_context  # [batch, 1, embed_dim]

        # Add the state-dependent part of keys / values to the cached static part
        (
            glimpse_key_dynamic,
            glimpse_val_dynamic,
            logit_key_dynamic,
        ) = self.dynamic_embedding(td["observation"])
        glimpse_key = cached.glimpse_key + glimpse_key_dynamic
        # Keep the value sum in its own variable so it does not clobber the key
        glimpse_val = cached.glimpse_val + glimpse_val_dynamic
        logit_key = cached.logit_key + logit_key_dynamic

        # Get the mask
        mask = ~td["action_mask"]

        # Compute logits
        log_p = self.logit_attention(query, glimpse_key, glimpse_val, logit_key, mask)

        return log_p, mask

## Test `POMOBase`

In [5]:
# Smoke test: run one sampled forward pass of the base policy on a single batch.
env = TSPEnv(num_loc=10).transform()

# NOTE: batch_size has to be passed as a list
init_td = env.reset(batch_size=[10000])
dataset = TorchDictDataset(init_td)

dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=False,  # order is irrelevant here, the data is freshly generated
    num_workers=0,
    collate_fn=torch.stack,  # stacks the individual dataset items into one batch
)

model = AttentionModelBase(
    env,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
).to("cuda")

# Pull the first batch, move it next to the model and decode by sampling
x = next(iter(dataloader)).to("cuda")
out = model(x, decode_type="sampling")

In [6]:
def get_lightning_device(lit_module: L.LightningModule) -> torch.device:
    """Return the device a Lightning module should place its tensors on.

    The strategy's ``root_device`` wins whenever it disagrees with
    ``lit_module.device``; otherwise the module's own device is returned.
    Background on the mismatch: https://github.com/Lightning-AI/lightning/issues/2638
    """
    root_device = lit_module.trainer.strategy.root_device
    if root_device == lit_module.device:
        return lit_module.device
    return root_device


class AttentionModel(nn.Module):
    """REINFORCE wrapper that pairs a policy with a baseline.

    The baseline is not created in ``__init__``; it is built in :meth:`setup`,
    which must therefore run before :meth:`forward`.

    Args:
        env: environment, handed to the baseline in ``setup`` / ``on_train_epoch_end``.
        policy: callable policy; its output must provide ``'reward'`` and
            ``'log_likelihood'`` entries.
    """

    def __init__(self, env, policy):
        super().__init__()
        self.env = env
        self.policy = policy

        # TODO: hydra instantiation
        # self.policy = instantiate(cfg.policy)
        # self.baseline = instantiate(cfg.baseline) TODO

    def forward(self, td: TensorDict, phase: str = "train", decode_type: Optional[str] = None) -> dict:
        """Evaluate the policy, compare its cost with the baseline and build the loss.

        Args:
            td: input state passed to the policy and to the baseline.
            phase: name of the current phase; currently not used in the computation.
            decode_type: when given, forwarded to the policy as ``decode_type=...``;
                when ``None`` the policy is called with ``td`` only.

        Returns:
            A dict with ``loss``, ``reinforce_loss``, ``bl_loss``, ``bl_val`` plus
            every entry of the policy output.
        """
        # Only pass decode_type through when the caller asked for one, so the
        # default call stays exactly `self.policy(td)`.
        policy_kwargs = {} if decode_type is None else {"decode_type": decode_type}
        out_policy = self.policy(td, **policy_kwargs)

        # Cost is the negated reward
        cost = -out_policy['reward']
        bl_val, bl_loss = self.baseline.eval(td, cost)

        # Calculate loss
        advantage = cost - bl_val
        reinforce_loss = (advantage * out_policy['log_likelihood']).mean()
        loss = reinforce_loss + bl_loss

        return {'loss': loss, 'reinforce_loss': reinforce_loss, 'bl_loss': bl_loss, 'bl_val': bl_val, **out_policy}

    def setup(self, lit_module):
        """Instantiate the baseline and initialise it with the module's validation dataloader."""
        # Make baseline taking model itself and train_dataloader from model as input
        # TODO make this as taken from config
        self.baseline = instantiate({"_target_": "__main__.WarmupBaseline",
                                    "baseline": {"_target_": "__main__.RolloutBaseline",                                             }
                                    })

        self.baseline.setup(self.policy, lit_module.val_dataloader(), self.env, device=get_lightning_device(lit_module))
        # self.baseline = NoBaseline()

    def on_train_epoch_end(self, lit_module):
        """Forward the end-of-epoch event to the baseline's ``epoch_callback``."""
        self.baseline.epoch_callback(self.policy, lit_module.val_dataloader(), lit_module.current_epoch, self.env, device=get_lightning_device(lit_module))

## Lightning Module

In [7]:
class NCOLightningModule(L.LightningModule):
    """Lightning wrapper that trains ``model`` on observations generated from ``env``.

    Args:
        env: environment used to generate observations and to reset states.
        model: module called as ``model(td, phase)``; its output must contain
            ``"reward"`` and ``"loss"``. Optional ``setup`` /
            ``on_train_epoch_end`` hooks on it are invoked with this module.
        lr: learning rate for Adam.
        batch_size: batch size of every dataloader.
        train_size: number of generated training observations.
        val_size: number of generated validation observations.
    """

    def __init__(self, env, model, lr=1e-4, batch_size=128, train_size=1000, val_size=10000):
        super().__init__()

        # TODO: hydra instantiation
        self.env = env
        self.model = model
        self.lr = lr
        self.batch_size = batch_size
        self.train_size = train_size
        self.val_size = val_size

    # ------------------------------------------------------------------ data

    def get_observation_dataset(self, size):
        """Generate ``size`` fresh observations by resetting the env and wrap them in a dataset."""
        fresh_state = self.env.reset(batch_size=[size])
        return TorchDictDataset(fresh_state['observation'])

    def _dataloader(self, dataset):
        """Build a sequential, single-process loader over ``dataset``."""
        loader_kwargs = dict(
            batch_size=self.batch_size,
            shuffle=False,  # no need to shuffle, we're resampling every epoch
            num_workers=0,
            collate_fn=torch.stack,  # we need this to stack the batches in the dataset
            pin_memory=self.on_gpu,
        )
        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        return self._dataloader(self.train_dataset)

    def val_dataloader(self):
        return self._dataloader(self.val_dataset)

    # ----------------------------------------------------------------- hooks

    def setup(self, stage="fit"):
        """Create the train / val datasets, then let the model run its own setup hook if it has one."""
        self.train_dataset = self.get_observation_dataset(self.train_size)
        self.val_dataset = self.get_observation_dataset(self.val_size)
        if hasattr(self.model, "setup"):
            self.model.setup(self)

    def on_train_epoch_end(self):
        """Notify the model (if it has the hook) and regenerate the training dataset."""
        if hasattr(self.model, "on_train_epoch_end"):
            self.model.on_train_epoch_end(self)
        self.train_dataset = self.get_observation_dataset(self.train_size)

    # ----------------------------------------------------------------- steps

    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        """Reset the env from ``batch``, run the model, log ``{phase}/cost`` and return the loss."""
        td = self.env.reset(init_observation=batch)
        output = self.model(td, phase)

        mean_cost = -output["reward"].mean()
        self.log(f"{phase}/cost", mean_cost, prog_bar=True)
        return {"loss": output['loss']}

    def training_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='train')

    def validation_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='val')

    def test_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='test')

    # ------------------------------------------------------------- optimizer

    def configure_optimizers(self):
        """Adam over the model parameters with a small weight decay; no scheduler yet."""
        # TODO: scheduler
        adam = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-5)
        return [adam]

In [8]:
# Disable profiling executor. This reduces memory and increases speed.
# Switch off the JIT profiling executor and profiling mode (reduces memory, increases speed).
# These are private hooks, so a missing attribute is silently tolerated.
try:
    for _jit_toggle in ("_jit_set_profiling_executor", "_jit_set_profiling_mode"):
        getattr(torch._C, _jit_toggle)(False)
except AttributeError:
    pass


In [13]:
env = TSPEnv(num_loc=20).transform()

policy = AttentionModelBase(
    env,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
    force_flash_attn=True,
)

# Build the wrapper through hydra instead of calling the class directly;
# the already-constructed policy is supplied as a keyword override.
model = instantiate(
    {"_target_": "__main__.AttentionModel", "env": env},
    policy=policy,
)



: 

In [6]:
# NOTE(review): scratch cell. `cfg` is not defined by any cell above this one in the visible
# notebook; the only visible definition is in the OmegaConf cell further down, so this
# relies on out-of-order execution.
cfg.strategy

ConfigAttributeError: Missing key strategy
    full_key: strategy
    object_type=dict

In [7]:
import lightning.strategies

ModuleNotFoundError: No module named 'lightning.strategies'

In [3]:
# Create an OmegaConf config holding a single key, "model"
from omegaconf import OmegaConf

cfg = OmegaConf.create({"model": "ciao"})

# NOTE(review): this only reads the attribute `est`, which was not set above; it does not
# add a key. Presumably an assignment (cfg.est = ...) was intended — confirm.
cfg.est


NameError: name 'model' is not defined

In [9]:
# --- Environment, policy and REINFORCE wrapper ---------------------------------
env = TSPEnv(num_loc=20).transform()

policy = AttentionModelBase(
    env,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
    force_flash_attn=True,
)

model_final = AttentionModel(env, policy)

# --- Lightning module -----------------------------------------------------------
# TODO CHANGE THIS (1024 was also tried)
batch_size = 512

model = NCOLightningModule(
    env,
    model_final,
    batch_size=batch_size,
    train_size=1280000,
    lr=1e-4,
)

# Trick to make calculations faster
torch.set_float32_matmul_precision("medium")

# No logger is configured here; a Wandb or CSV logger could be plugged in instead.

epochs = 1

# --- Profiler -------------------------------------------------------------------
from lightning.pytorch.profilers import AdvancedProfiler

# Created for optional use: it is not handed to the Trainer below.
profiler = AdvancedProfiler(dirpath=".", filename="perf_logsv2")

# --- Trainer --------------------------------------------------------------------
trainer = L.Trainer(
    max_epochs=epochs,
    accelerator="gpu",
    devices=[1],
    precision=16,
    log_every_n_steps=100,
    gradient_clip_val=1.0,  # clip gradients to avoid exploding gradients
)

# Fit the model
trainer.fit(model)

  rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Evaluating baseline model on evaluation dataset


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type           | Params
-----------------------------------------
0 | env   | TSPEnv         | 0     
1 | model | AttentionModel | 1.4 M 
-----------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.681     Total estimated model params size (MB)


                                                                           

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:  92%|█████████▏| 2306/2500 [02:23<00:12, 16.06it/s, v_num=177, train/cost=4.020]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
