# POMO Lightning

In [1]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append('../../')

import math
from typing import List, Tuple, Optional, NamedTuple, Dict, Union, Any
from einops import rearrange, repeat
from hydra.utils import instantiate

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.nn import DataParallel
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import lightning as L

from torchrl.envs import EnvBase
from torchrl.envs.utils import step_mdp
from tensordict import TensorDict

from ncobench.envs.tsp import TSPEnv
from ncobench.models.rl.reinforce import *
from ncobench.models.co.am.context import env_context
from ncobench.models.co.am.embeddings import env_init_embedding, env_dynamic_embedding
from ncobench.models.co.am.encoder import GraphAttentionEncoder
from ncobench.models.co.am.decoder import Decoder, decode_probs, PrecomputedCache, LogitAttention
from ncobench.models.co.am.policy import get_log_likelihood
from ncobench.models.nn.attention import NativeFlashMHA, flash_attn_wrapper
from ncobench.utils.lightning import get_lightning_device

  warn(


## Novelty compared to `AttentionModel`

## Pseudo-code of differences for training

### Attention Model
```python
def train(policy network p_θ, training set S, batch size B, significance α):
    for i in 1...B:
        s_i = RandomInstance(S)
        π_i = SampleRollout(s_i, p_θ)
        vb = UpdateBaseline(s, π)
        ∇𝐿 = (1/B) * Σ(L(π_i|s_i) - b_i) * ∇_θ log(p_θ(π_i|s_i))
        θ = GradientDescent(θ, ∇𝐿)
        if OneSidedPairedTest(p_θ, p_θ^BL) < α: # p_θ is better than p_θ^BL
            θ^BL = θ
```

### POMO
```python
def train(policy network p_θ, training set S, batch size B, number of start nodes N):
    for i in 1...B:
        s_i = RandomInstance(S)
        # New: select starting nodes, and rolout with them
        α_i1,...,α_iN = SelectStartNodes(s_i)
        π_i1,...,π_iN = SampleRollout(s_i, p_θ, {α_i,j})
        vb = UpdateBaseline(s, π)
        # New: baseline is simply the average baseline
        ∇𝐿 = (1/NB) * Σ(L(π_ij|s_i) - b_i) * ∇_θ log(p_θ(π_ij|s_i))
        θ = GradientDescent(θ, ∇𝐿)
```

So the novelty is:
1. We select a set of starting nodes for each instance and rollout with them
2. Replace baseline with average baseline

---

## Other novelty
1. Use `multi-greedy` decoding (e.g.) simply take the best out of the starting points
2. Use Instance Augmentation (e.g. just augment the dataset with)

In [2]:
# For easier debugging

from rich.traceback import install
install()

<bound method InteractiveShell.excepthook of <ipykernel.zmqshell.ZMQInteractiveShell object at 0x7f3a9a0e3310>>

## Utilities: action selection, batching


In [86]:
def select_start_nodes(batch_size, num_nodes, device="cpu"):
    """Node selection strategy for POMO
    Selects different start nodes for each batch element
    """
    # selected = torch.arange(num_nodes, device=device).repeat_interleave(batch_size, dim=0) # TODO: check
    selected = torch.arange(num_nodes, device=device).repeat(batch_size)
    # requires grad
    # selected.requires_grad_ = True # TODO check
    return selected


def repeat_batch(x, repeats):
    """Same as repeat on dim=0 for tensordicts as well
    Same as einops.repeat(x, 'b ... -> (b r) ...', r=repeats) but 50% faster
    """
    expand_dims = [1] * len(x.shape)
    expand_dims[0] = repeats
    return x.repeat(*expand_dims)


def unrepeat_batch(x, repeats):
    """Undoes repeat_batch
    Same as einops.rearrange(x, '(b r) ... -> b r ...', r=repeats) but 3x faster
    """
    s = x.shape
    return x.view(s[0] // repeats, repeats, *s[1:])

In [79]:
from einops import rearrange, repeat

# test_a
a = torch.rand(64, 20)

In [83]:
x = torch.rand(64, 20, 2)
x = rearrange(x, '(b r) ... -> b r ...', r=4)

In [80]:
b = repeat(a, 'b n -> (b r) n', r=7)
c = repeat(b, 'b n -> (b r) n', r=6)
c_orig = c.clone()

# undo
print(c.shape)
b_ = rearrange(c, '(b r) ... -> b r ...', r=6)
print(b_.shape)
a_ = rearrange(b_, '(b r) ... -> b r ...', r=7)
print(a_.shape)
print(torch.allclose(a_[:, 0, 0, :], a))


torch.Size([2688, 20])
torch.Size([448, 6, 20])
torch.Size([64, 7, 6, 20])
True


In [81]:
b = repeat_batch(a, 7)
c = repeat_batch(b, 6)
c_ours = c.clone()

# undo
b_ = undo_repeat_batch(c, 6)
print(b_.shape)
a_ = undo_repeat_batch(b_, 7)
print(a_.shape)
print(torch.allclose(a_[:, 0, 0, :], a))

print(torch.allclose(c, c_ours))

torch.Size([448, 6, 20])
torch.Size([64, 7, 6, 20])
False
True


In [4]:
# x.view(repeats, s[0] // repeats, *s[1:])
# same but with s[i] and [s[k] for k in len(s) if k != i]]]

In [24]:
from dataclasses import dataclass


@dataclass
class PrecomputedCache:
    node_embeddings: torch.Tensor
    glimpse_key: torch.Tensor
    glimpse_val: torch.Tensor
    logit_key: torch.Tensor


class Decoder(nn.Module):
    def __init__(self, env, embedding_dim, num_heads, num_pomo=20, **logit_attn_kwargs):
        super(Decoder, self).__init__()

        self.env = env
        self.embedding_dim = embedding_dim
        self.n_heads = num_heads

        assert embedding_dim % num_heads == 0

        self.context = env_context(self.env.name, {"embedding_dim": embedding_dim})
        self.dynamic_embedding = env_dynamic_embedding(
            self.env.name, {"embedding_dim": embedding_dim}
        )

        # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim
        self.project_node_embeddings = nn.Linear(
            embedding_dim, 3 * embedding_dim, bias=False
        )
        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=False)

        # MHA
        self.logit_attention = LogitAttention(
            embedding_dim, num_heads, **logit_attn_kwargs
        )

        # POMO
        self.num_pomo = max(num_pomo, 1) # POMO = 1 is just normal REINFORCE

    def forward(self, td, embeddings, decode_type="sampling"):
        # Collect outputs
        outputs = []
        actions = []

        if self.num_pomo > 1:
            # POMO: first action is decided via select_start_nodes
            action = select_start_nodes(batch_size=td.shape[0], num_nodes=self.num_pomo, device=td.device)

            # # Expand td to batch_size * num_pomo
            td = repeat_batch(td, self.num_pomo)

            td.set("action", action[:, None])
            td = self.env.step(td)["next"]
            log_p = torch.zeros_like(td['action_mask'], device=td.device) # first log_p is 0, so p = log_p.exp() = 1

            outputs.append(log_p.squeeze(1))
            actions.append(action)
        
        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        cached_embeds = self._precompute(embeddings)        

        # Here we suppose all the batch is done at the same time
        while not td["done"].any():  
            # Compute the logits for the next node
            log_p, mask = self._get_log_p(cached_embeds, td)

            # Select the indices of the next nodes in the sequences, result (batch_size) long
            action = decode_probs(
                log_p.exp().squeeze(1), mask.squeeze(1), decode_type=decode_type
            )

            # Step the environment
            td.set("action", action[:, None])
            td = self.env.step(td)["next"]

            # Collect output of step
            outputs.append(log_p.squeeze(1))
            actions.append(action)

        outputs, actions = torch.stack(outputs, 1), torch.stack(actions, 1)
        td.set("reward", self.env.get_reward(td, actions))
        return outputs, actions, td
    
    def _precompute(self, embeddings):       
        # The projection of the node embeddings for the attention is calculated once up front
        (
            glimpse_key_fixed,
            glimpse_val_fixed,
            logit_key_fixed,
        ) = self.project_node_embeddings(embeddings[:, None, :, :]).chunk(3, dim=-1)

        # Organize in a dataclass for easy access
        cached_embeds = PrecomputedCache(
            node_embeddings=repeat_batch(embeddings, self.num_pomo),
            glimpse_key=repeat_batch(self.logit_attention._make_heads(glimpse_key_fixed), self.num_pomo),
            glimpse_val=repeat_batch(self.logit_attention._make_heads(glimpse_val_fixed), self.num_pomo),
            logit_key=repeat_batch(logit_key_fixed, self.num_pomo)
        )

        return cached_embeds

    def _get_log_p(self, cached, td):
        # Compute the query based on the context (computes automatically the first and last node context)
        step_context = self.context(cached.node_embeddings, td)
        query = step_context # in POMO, no graph context (trick for overfit) # [batch, 1, embed_dim]

        # Compute keys and values for the nodes
        glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = self.dynamic_embedding(td)
        glimpse_key = cached.glimpse_key + glimpse_key_dynamic
        glimpse_key = cached.glimpse_val + glimpse_val_dynamic
        logit_key = cached.logit_key + logit_key_dynamic

        # Get the mask
        mask = ~td["action_mask"]
        mask = mask.unsqueeze(1) if mask.dim() == 2 else mask

        # Compute logits
        log_p = self.logit_attention(query, glimpse_key, glimpse_key, logit_key, mask)

        return log_p, mask

In [14]:
class POMOPolicy(nn.Module):

    def __init__(self,
                 env: EnvBase,
                 embedding_dim: int,
                 hidden_dim: int,
                 encoder: nn.Module = None,
                 decoder: nn.Module = None,
                 num_pomo: int = 10,
                 n_encode_layers: int = 3,
                 normalization: str = 'batch',
                 n_heads: int = 8,
                 checkpoint_encoder: bool = False,
                 mask_inner: bool = True,
                 force_flash_attn: bool = False,
                 **kwargs
                 ):
        super(POMOPolicy, self).__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_encode_layers = n_encode_layers
        self.env = env

        self.n_heads = n_heads
        self.checkpoint_encoder = checkpoint_encoder

        self.init_embedding = env_init_embedding(self.env.name, {"embedding_dim": embedding_dim})

        self.encoder = GraphAttentionEncoder(
            num_heads=n_heads,
            embed_dim=embedding_dim,
            num_layers=self.n_encode_layers,
            normalization=normalization,
            force_flash_attn=force_flash_attn,
        ) if encoder is None else encoder
        
        self.decoder = Decoder(env, embedding_dim, n_heads, num_pomo=num_pomo, mask_inner=mask_inner, force_flash_attn=force_flash_attn) if decoder is None else decoder
        self.num_pomo = num_pomo

    def forward(self, td: TensorDict, phase: str = "train", decode_type: str = "sampling", return_actions: bool = False) -> TensorDict:
        """Given observation, precompute embeddings and rollout"""

        # Set decoding type for policy, can be also greedy
        embedding = self.init_embedding(td)
        encoded_inputs, _ = self.encoder(embedding)

        # Main rollout
        _log_p, actions, td = self.decoder(td, encoded_inputs, decode_type)

        # Log likelyhood is calculated within the model since returning it per action does not work well with
        ll = get_log_likelihood(_log_p, actions, td.get('mask', None))
        out = {"reward": td["reward"], "log_likelihood": ll, "actions": actions if return_actions else None}

        return out

## Test the Policy only

In [15]:
num_loc = 15
env = TSPEnv(num_loc=num_loc).transform()

dataset = env.dataset(batch_size=[10000])

dataloader = DataLoader(
                dataset,
                batch_size=32,
                shuffle=False, # no need to shuffle, we're resampling every epoch
                num_workers=0,
                collate_fn=torch.stack, # we need this to stack the batches in the dataset
            )

model = POMOPolicy(
    env,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
    num_pomo=num_loc,
    # force_flash_attn=True,
).to("cuda")

# model = torch.compile(model)

x = next(iter(dataloader)).to("cuda")
x = env.reset(init_obs=x)

out = model(x, decode_type="sampling")

## Create full model: `env` + `policy` + `baseline`

In [16]:
def augment_xy_data_by_8_fold(xy):
    # [batch, graph, 2]
    x, y = xy.split(1, dim=2)
    # augmnetations [batch, graph, 2]
    z0 = torch.cat((x, y), dim=2)
    z1 = torch.cat((1 - x, y), dim=2)
    z2 = torch.cat((x, 1 - y), dim=2)
    z3 = torch.cat((1 - x, 1 - y), dim=2)
    z4 = torch.cat((y, x), dim=2)
    z5 = torch.cat((1 - y, x), dim=2)
    z6 = torch.cat((y, 1 - x), dim=2)
    z7 = torch.cat((1 - y, 1 - x), dim=2)
    # [batch*8, graph, 2]
    aug_xy = torch.cat((z0, z1, z2, z3, z4, z5, z6, z7), dim=0)
    return aug_xy


def env_aug_feats(env_name: str) -> Tuple[str, ...]:
    return ('observation', 'depot') if env_name == "op" else ('observation',)


class StateAugmentation(nn.Module):
    def __init__(self, env_name, num_augment: int = 8):
        """Augment state by 8 fold for POMO"""
        super(StateAugmentation, self).__init__()
        self.num_augment = num_augment
        assert num_augment == 8, "Only 8 fold augmentation is supported for POMO"
        self.augmentation = augment_xy_data_by_8_fold
        self.feats = env_aug_feats(env_name)

    def forward(self, td: TensorDict) -> TensorDict:
        td_aug = repeat_batch(td, self.num_augment)
        for feat in self.feats:
            aug_feat = self.augmentation(td[feat])
            td_aug[feat] = aug_feat
        return td_aug


# Test above
td = TensorDict({
    "observation": torch.randn(32, 15, 2),
    "depot": torch.randn(32, 1, 2),
}, batch_size=32)
augment = StateAugmentation("op")
td_aug = augment(td)
print(td_aug.shape)


torch.Size([256])


In [17]:
def get_best_actions(actions, max_idxs):
    actions = undo_repeat_batch(actions, max_idxs.shape[0])
    return actions.gather(0, max_idxs[..., None, None])


class POMO(nn.Module):
    def __init__(self, env, policy, baseline, num_augment=8, **kwargs):
        super().__init__()
        self.env = env
        self.policy = policy
        self.baseline = baseline

        # POMO parameters
        self.num_pomo = policy.num_pomo
        self.num_augment = num_augment
        self.augment = StateAugmentation(env.name, num_augment) if num_augment > 1 else None

    def forward(self, td: TensorDict, phase: str="train", decode_type: str="sampling", return_actions: bool=False) -> TensorDict:
        """Evaluate model, get costs and log probabilities and compare with baseline"""

        # Augment data if not in training phase
        if phase != "train" and self.augment is not None:
            td = self.augment(td)

        # Evaluate model, get costs and log probabilities
        out = self.policy(td, decode_type=decode_type, return_actions=return_actions)

        costs = undo_repeat_batch(-out['reward'], self.policy.num_pomo)
        ll = undo_repeat_batch(out['log_likelihood'], self.policy.num_pomo)
        bl_val, bl_loss = self.baseline.eval(td, costs)

        # Calculate REINFORCE loss
        advantage = costs - bl_val
        reinforce_loss = (advantage * ll).mean()
        loss = reinforce_loss + bl_loss

        # Max POMO reward. Decouple augmentation and POMO 
        # [num_pomo, num_augment, batch]
        reward = undo_repeat_batch(undo_repeat_batch(out["reward"], self.num_augment if phase != "train" else 1), self.num_pomo, dim=1)
        max_reward, max_idxs = reward.max(dim=0)
        pomo_retvals = {"max_reward": max_reward, "best_actions": get_best_actions(out["actions"], max_idxs) if return_actions else None}

        # Get augmentation score only during inference
        aug_retvals = {}
        if phase != "train" and self.augment is not None:
            # [num_augment, batch]
            aug_reward = undo_repeat_batch(max_reward, self.num_augment)
            max_aug_reward, max_idxs = aug_reward.max(dim=0)
            aug_retvals = {"max_aug_reward": max_aug_reward, "best_aug_actions": get_best_actions(out["actions"], max_idxs) if return_actions else None}
 
        return {'loss': loss, 'reinforce_loss': reinforce_loss, 'bl_loss': bl_loss, 'bl_val': bl_val, **out, **pomo_retvals, **aug_retvals}
        
    def setup(self, lit_module):
        # Make baseline taking model itself and train_dataloader from model as input
        if hasattr(self.baseline, "setup"):
            self.baseline.setup(self.policy, lit_module.train_dataloader(), self.env, device=get_lightning_device(lit_module))
    
    def on_train_epoch_end(self, lit_module):
        # self.baseline.epoch_callback(self.policy, self.env, pl_module)
        self.baseline.epoch_callback(self.policy, lit_module.val_dataloader(), lit_module.current_epoch, self.env, device=get_lightning_device(lit_module))

## Lightning Module

In [18]:
class NCOLightningModule(L.LightningModule):
    def __init__(self, env, model, lr=1e-4, batch_size=128, train_size=1000, val_size=10000):
        super().__init__()

        # TODO: hydra instantiation
        self.env = env
        self.model = model
        self.lr = lr
        self.batch_size = batch_size
        self.train_size = train_size
        self.val_size = val_size

    def setup(self, stage="fit"):
        self.train_dataset = self.env.dataset(self.train_size)
        self.val_dataset = self.env.dataset(self.val_size)
        if hasattr(self.model, "setup"):
            self.model.setup(self)

    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        td = self.env.reset(init_obs=batch)
        output = self.model(td, phase)
        
        # output = self.model(batch, phase)
        self.log(f"{phase}/cost", -output["reward"].mean(), prog_bar=True)
        self.log(f"{phase}/pomo_cost", -output["max_reward"].mean(), prog_bar=True)
        if phase != "train" and self.model.num_augment > 1:
            self.log(f"{phase}/aug_cost", -output["max_aug_reward"].mean(), prog_bar=True)
        
        return {"loss": output['loss']}

    def training_step(self, batch: Any, batch_idx: int):   
        return self.shared_step(batch, batch_idx, phase='train')

    def validation_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='val')

    def test_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='test')

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-6)
        # optim = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
        # TODO: scheduler
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, total_steps)
        return [optim] #, [scheduler]
    
    def train_dataloader(self):
        return self._dataloader(self.train_dataset)
    
    def val_dataloader(self):
        return self._dataloader(self.val_dataset)
    
    def on_train_epoch_end(self):
        if hasattr(self.model, "on_train_epoch_end"):
            self.model.on_train_epoch_end(self)
        self.train_dataset = self.env.dataset(self.train_size) 
       
    def _dataloader(self, dataset):
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False, # no need to shuffle, we're resampling every epoch
            num_workers=0,
            collate_fn=torch.stack, # we need this to stack the batches in the dataset
            pin_memory=self.on_gpu,
        )

## Main training setup

In [19]:
# Hyperparameters
epochs = 1
batch_size = 64 #1024 #512
num_loc = 20
train_size = 1280000
lr = 1e-4
num_pomo = num_loc # TODO: comment to try out = 1
# num_pomo = 1 # set to 1: similar to simple AM

# Environment
env = TSPEnv(num_loc=num_loc).transform()

# Policy
policy = POMOPolicy(
    env,
    num_pomo=num_pomo,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
    # force_flash_attn=True,
)

# Baseline
# baseline = WarmupBaseline(RolloutBaseline())
baseline = SharedBaseline() # TODO: uncomment

# Create RL model
model = POMO(env, policy, baseline)

# Create Lightning module (for training)
lit_model = NCOLightningModule(env, model, batch_size=batch_size, train_size=train_size, lr=lr)

## Fit model

In [20]:
# Trick to make calculations faster
torch.set_float32_matmul_precision("medium")

# Trainer
trainer = L.Trainer(
    max_epochs=epochs,
    accelerator="gpu",
    devices=[1],
    logger=None, # can replace with WandbLogger, TensorBoardLogger, etc.
    # precision=16, # uncomment to make faster
    log_every_n_steps=100,   
    gradient_clip_val=1.0, # clip gradients to avoid exploding gradients!
)

# Fit the model
trainer.fit(lit_model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type   | Params
---------------------------------
0 | env   | TSPEnv | 0     
1 | model | POMO   | 710 K 
---------------------------------
710 K     Trainable params
0         Non-trainable params
710 K     Total params
2.841     Total estimated model params size (MB)
2023-04-18 23:01:33.494062: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  6.53it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0:   4%|▍         | 843/20000 [00:49<18:45, 17.02it/s, v_num=76, train/cost=4.130, train/pomo_cost=3.890]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [21]:
trainer.validate(lit_model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Validation DataLoader 0: 100%|██████████| 157/157 [00:07<00:00, 19.85it/s]


[{'val/cost': 4.131106376647949,
  'val/pomo_cost': 3.897179365158081,
  'val/aug_cost': 3.8515453338623047}]