# PointerNet Lightning

In [1]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append('../../')

import math
from typing import List, Tuple, Optional, NamedTuple, Dict, Union, Any
from einops import rearrange, repeat
from hydra.utils import instantiate

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.nn import DataParallel
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import lightning as L

from torchrl.envs import EnvBase
from torchrl.envs.utils import step_mdp
from tensordict import TensorDict

from ncobench.data.dataset import TorchDictDataset

from ncobench.envs.tsp import TSPEnv
from ncobench.models.rl.reinforce import *
from ncobench.models.components.am.context import env_context
from ncobench.models.components.am.embeddings import env_init_embedding, env_dynamic_embedding
from ncobench.models.components.am.encoder import GraphAttentionEncoder
from ncobench.models.components.am.decoder import Decoder, decode_probs, PrecomputedCache, LogitAttention
from ncobench.models.components.am.policy import get_log_likelihood
from ncobench.models.nn.attention import NativeFlashMHA, flash_attn_wrapper
from ncobench.utils.lightning import get_lightning_device

  warn(


In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import math
import numpy as np


class Encoder(nn.Module):
    """Maps a graph represented as an input sequence
    to a hidden vector"""
    def __init__(self, input_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim)
        self.init_hx, self.init_cx = self.init_hidden(hidden_dim)

    def forward(self, x, hidden):
        output, hidden = self.lstm(x, hidden)
        return output, hidden
    
    def init_hidden(self, hidden_dim):
        """Trainable initial hidden state"""
        std = 1. / math.sqrt(hidden_dim)
        enc_init_hx = nn.Parameter(torch.FloatTensor(hidden_dim))
        enc_init_hx.data.uniform_(-std, std)

        enc_init_cx = nn.Parameter(torch.FloatTensor(hidden_dim))
        enc_init_cx.data.uniform_(-std, std)
        return enc_init_hx, enc_init_cx

In [3]:
class Attention(nn.Module):
    """A generic attention module for a decoder in seq2seq"""
    def __init__(self, dim, use_tanh=False, C=10):
        super(Attention, self).__init__()
        self.use_tanh = use_tanh
        self.project_query = nn.Linear(dim, dim)
        self.project_ref = nn.Conv1d(dim, dim, 1, 1)
        self.C = C  # tanh exploration
        self.tanh = nn.Tanh()

        self.v = nn.Parameter(torch.FloatTensor(dim))
        self.v.data.uniform_(-(1. / math.sqrt(dim)), 1. / math.sqrt(dim))
        
    def forward(self, query, ref):
        """
        Args: 
            query: is the hidden state of the decoder at the current
                time step. batch x dim
            ref: the set of hidden states from the encoder. 
                sourceL x batch x hidden_dim
        """
        # ref is now [batch_size x hidden_dim x sourceL]
        ref = ref.permute(1, 2, 0)
        q = self.project_query(query).unsqueeze(2)  # batch x dim x 1
        e = self.project_ref(ref)  # batch_size x hidden_dim x sourceL 
        # expand the query by sourceL
        # batch x dim x sourceL
        expanded_q = q.repeat(1, 1, e.size(2)) 
        # batch x 1 x hidden_dim
        v_view = self.v.unsqueeze(0).expand(
                expanded_q.size(0), len(self.v)).unsqueeze(1)
        # [batch_size x 1 x hidden_dim] * [batch_size x hidden_dim x sourceL]
        u = torch.bmm(v_view, self.tanh(expanded_q + e)).squeeze(1)
        if self.use_tanh:
            logits = self.C * self.tanh(u)
        else:
            logits = u  
        return e, logits

In [4]:
class Decoder(nn.Module):
    def __init__(self, 
            embedding_dim: int = 128,
            hidden_dim: int = 128,
            tanh_exploration: float = 10.0,
            use_tanh: bool = True,
            n_glimpses=1,
            mask_glimpses=True,
            mask_logits=True):
        super(Decoder, self).__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_glimpses = n_glimpses
        self.mask_glimpses = mask_glimpses
        self.mask_logits = mask_logits
        self.use_tanh = use_tanh
        self.tanh_exploration = tanh_exploration

        self.lstm = nn.LSTMCell(embedding_dim, hidden_dim)
        self.pointer = Attention(hidden_dim, use_tanh=use_tanh, C=tanh_exploration)
        self.glimpse = Attention(hidden_dim, use_tanh=False)
        self.sm = nn.Softmax(dim=1)

    def update_mask(self, mask, selected):
        return mask.clone().scatter_(1, selected.unsqueeze(-1), True)

    def recurrence(self, x, h_in, prev_mask, prev_idxs, step, context):

        logit_mask = self.update_mask(prev_mask, prev_idxs) if prev_idxs is not None else prev_mask

        logits, h_out = self.calc_logits(x, h_in, logit_mask, context, self.mask_glimpses, self.mask_logits)

        # Calculate log_softmax for better numerical stability
        log_p = torch.log_softmax(logits, dim=1)
        probs = log_p.exp()

        if not self.mask_logits:
            # If self.mask_logits, this would be redundant, otherwise we must mask to make sure we don't resample
            # Note that as a result the vector of probs may not sum to one (this is OK for .multinomial sampling)
            # But practically by not masking the logits, a model is learned over all sequences (also infeasible)
            # while only during sampling feasibility is enforced (a.k.a. by setting to 0. here)
            probs[logit_mask] = 0.
            # For consistency we should also mask out in log_p, but the values set to 0 will not be sampled and
            # Therefore not be used by the reinforce estimator

        return h_out, log_p, probs, logit_mask

    def calc_logits(self, x, h_in, logit_mask, context, mask_glimpses=None, mask_logits=None):

        if mask_glimpses is None:
            mask_glimpses = self.mask_glimpses

        if mask_logits is None:
            mask_logits = self.mask_logits

        hy, cy = self.lstm(x, h_in)
        g_l, h_out = hy, (hy, cy)

        for i in range(self.n_glimpses):
            ref, logits = self.glimpse(g_l, context)
            # For the glimpses, only mask before softmax so we have always an L1 norm 1 readout vector
            if mask_glimpses:
                logits[logit_mask] = float('-inf')
            # [batch_size x h_dim x sourceL] * [batch_size x sourceL x 1] =
            # [batch_size x h_dim x 1]
            g_l = torch.bmm(ref, self.sm(logits).unsqueeze(2)).squeeze(2)
        _, logits = self.pointer(g_l, context)

        # Masking before softmax makes probs sum to one
        if mask_logits:
            logits[logit_mask] = float('-inf')

        return logits, h_out

    def forward(self, decoder_input, embedded_inputs, hidden, context, decode_type="sampling", eval_tours=None):
        """
        Args:
            decoder_input: The initial input to the decoder
                size is [batch_size x embedding_dim]. Trainable parameter.
            embedded_inputs: [sourceL x batch_size x embedding_dim]
            hidden: the prev hidden state, size is [batch_size x hidden_dim]. 
                Initially this is set to (enc_h[-1], enc_c[-1])
            context: encoder outputs, [sourceL x batch_size x hidden_dim] 
        """

        batch_size = context.size(1)
        outputs = []
        selections = []
        steps = range(embedded_inputs.size(0))
        idxs = None
        mask = Variable(
            embedded_inputs.data.new().byte().new(embedded_inputs.size(1), embedded_inputs.size(0)).zero_(),
            requires_grad=False
        )

        for i in steps:
            hidden, log_p, probs, mask = self.recurrence(decoder_input, hidden, mask, idxs, i, context)
            # select the next inputs for the decoder [batch_size x hidden_dim]
            idxs = decode_probs(
                probs,
                mask,
                decode_type=decode_type
            ) if eval_tours is None else eval_tours[:, i]

            idxs = idxs.detach()  # Otherwise pytorch complains it want's a reward, todo implement this more properly?

            # Gather input embedding of selected
            decoder_input = torch.gather(
                embedded_inputs,
                0,
                idxs.contiguous().view(1, batch_size, 1).expand(1, batch_size, *embedded_inputs.size()[2:])
            ).squeeze(0)

            # use outs to point to next object
            outputs.append(log_p)
            selections.append(idxs)
        return (torch.stack(outputs, 1), torch.stack(selections, 1)), hidden

In [16]:
class PointerNetworkPolicy(nn.Module):

    def __init__(self,
                 env,
                 embedding_dim: int=128,
                 hidden_dim: int=128,
                 tanh_clipping=10.,
                 mask_inner=True,
                 mask_logits=True,
                 **kwargs):
        super(PointerNetworkPolicy, self).__init__()

        self.env = env
        assert self.env.name == "tsp", "Only the Euclidean TSP env supported"

        self.input_dim = 2

        self.encoder = Encoder(
            embedding_dim,
            hidden_dim)

        self.decoder = Decoder(
            embedding_dim,
            hidden_dim,
            tanh_exploration=tanh_clipping,
            use_tanh=tanh_clipping > 0,
            n_glimpses=1,
            mask_glimpses=mask_inner,
            mask_logits=mask_logits
        )

        # Trainable initial hidden states
        std = 1. / math.sqrt(embedding_dim)
        self.decoder_in_0 = nn.Parameter(torch.FloatTensor(embedding_dim))
        self.decoder_in_0.data.uniform_(-std, std)

        self.embedding = nn.Parameter(torch.FloatTensor(self.input_dim, embedding_dim))
        self.embedding.data.uniform_(-std, std)

    def forward(self, td, phase: str = "train", decode_type="sampling", eval_tours=None):
        batch_size, graph_size, input_dim = td['observation'].size()

        embedded_inputs = torch.mm(
            td['observation'].transpose(0, 1).contiguous().view(-1, input_dim),
            self.embedding
        ).view(graph_size, batch_size, -1)

        # query the actor net for the input indices 
        # making up the output, and the pointer attn 
        _log_p, actions = self._inner(embedded_inputs, decode_type, eval_tours)

        reward = self.env.get_reward(td["observation"], actions)

        # Log likelyhood is calculated within the model since returning it per action does not work well with
        # DataParallel since sequences can be of different lengths
        ll = self._calc_log_likelihood(_log_p, actions, td.get("mask", None))

        out = {"reward": reward, "log_likelihood": ll, "actions": actions}
        return out
    
    def _calc_log_likelihood(self, _log_p, a, mask):

        # Get log_p corresponding to selected actions
        log_p = _log_p.gather(2, a.unsqueeze(-1)).squeeze(-1)

        # Optional: mask out actions irrelevant to objective so they do not get reinforced
        if mask is not None:
            log_p[mask] = 0

        assert (log_p > -1000).data.all(), "Logprobs should not be -inf, check sampling procedure!"

        # Calculate log_likelihood
        return log_p.sum(1)

    def _inner(self, inputs, decode_type="sampling", eval_tours=None):

        encoder_hx = encoder_cx = Variable(
            torch.zeros(1, inputs.size(1), self.encoder.hidden_dim, out=inputs.data.new()),
            requires_grad=False
        )

        # encoder forward pass
        enc_h, (enc_h_t, enc_c_t) = self.encoder(inputs, (encoder_hx, encoder_cx))

        dec_init_state = (enc_h_t[-1], enc_c_t[-1])

        # repeat decoder_in_0 across batch
        decoder_input = self.decoder_in_0.unsqueeze(0).repeat(inputs.size(1), 1)

        (pointer_probs, input_idxs), dec_hidden_t = self.decoder(decoder_input,
                                                                 inputs,
                                                                 dec_init_state,
                                                                 enc_h,
                                                                 decode_type,
                                                                 eval_tours)

        return pointer_probs, input_idxs

In [17]:
num_loc = 15
env = TSPEnv(num_loc=num_loc).transform()

# data = env.gen_params(batch_size=[10000]) # NOTE: need to put batch_size in a list!!
# init_td = env.reset(data)
# env.batch_size = [10000]
init_td = env.reset(batch_size=[10000])
dataset = TorchDictDataset(init_td)


dataloader = DataLoader(
                dataset,
                batch_size=32,
                shuffle=False, # no need to shuffle, we're resampling every epoch
                num_workers=0,
                collate_fn=torch.stack, # we need this to stack the batches in the dataset
            )

model = PointerNetworkPolicy(
    env,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
).to("cuda")

# model = torch.compile(model)

x = next(iter(dataloader)).to("cuda")

out = model(x, decode_type="sampling")

  logits[logit_mask] = float('-inf')
  logits[logit_mask] = float('-inf')


In [18]:
# torch.Size([15, 32, 15]) torch.Size([32])
# size of a, and idx

def gather_by_index(source, index, dim=0):
    target = torch.gather(
        source, dim, index[:, None, None].expand(-1, -1, source.size(-1))
    )
    return target

a = torch.randn(15, 32, 15)
idx = torch.randint(0, 15, (32,))
# get idx so that out is 1, 32, 15
out = gather_by_index(a, idx, dim=0)

out.shape
# a.gather(0, idx[None, :, None]).shape

torch.Size([32, 1, 15])

In [22]:
class PointerNetwork(nn.Module):
    def __init__(self, env, policy):
        super().__init__()
        self.env = env
        self.policy = policy

        # TODO: hydra instantiation
        # self.policy = instantiate(cfg.policy)
        # self.baseline = instantiate(cfg.baseline) TODO

    def forward(self, td: TensorDict, phase: str="train", decode_type: str=None) -> TensorDict:
        """Evaluate model, get costs and log probabilities and compare with baseline"""

        # Evaluate modelim=0, get costs and log probabilities
        out_policy = self.policy(td)
        cost = -out_policy['reward']
        ll = out_policy['log_likelihood']

        # Calculate loss
        bl_val, bl_loss = self.baseline.eval(td, cost)

        advantage = cost - bl_val
        reinforce_loss = (advantage * ll).mean()
        loss = reinforce_loss + bl_loss

        return {'loss': loss, 'reinforce_loss': reinforce_loss, 'bl_loss': bl_loss, 'bl_val': bl_val, **out_policy}
    
    # def setup(self, lit_module):
    #     # Make baseline taking model itself and train_dataloader from model as input
    #     # TODO make this as taken from config
    #     self.baseline = instantiate({"_target_": "__main__.SharedBaseline"})

    def setup(self, lit_module):
        # Make baseline taking model itself and train_dataloader from model as input
        # TODO make this as taken from config
        self.baseline = instantiate({"_target_": "__main__.WarmupBaseline",
                                    "baseline": {"_target_": "__main__.RolloutBaseline",                                             }
                                    })  
        # self.baseline = instantiate({"_target_": "__main__.SharedBaseline"})

        self.baseline.setup(self.policy, lit_module.val_dataloader(), self.env, device=get_lightning_device(lit_module))         
        # self.baseline = NoBaseline()

    def on_train_epoch_end(self, lit_module):
        # self.baseline.epoch_callback(self.policy, self.env, pl_module)
        self.baseline.epoch_callback(self.policy, lit_module.val_dataloader(), lit_module.current_epoch, self.env, device=get_lightning_device(lit_module))

## Lightning Module

In [23]:
class NCOLightningModule(L.LightningModule):
    def __init__(self, env, model, lr=1e-4, batch_size=128, train_size=1000, val_size=10000):
        super().__init__()

        # TODO: hydra instantiation
        self.env = env
        self.model = model
        self.lr = lr
        self.batch_size = batch_size
        self.train_size = train_size
        self.val_size = val_size

    def setup(self, stage="fit"):
        self.train_dataset = self.get_dataset(self.train_size)
        self.val_dataset = self.get_dataset(self.val_size)
        if hasattr(self.model, "setup"):
            self.model.setup(self)

    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        td = self.env.reset(batch)
        output = self.model(td, phase)
        
        # output = self.model(batch, phase)
        self.log(f"{phase}/cost", -output["reward"].mean(), prog_bar=True)
        return {"loss": output['loss']}

    def training_step(self, batch: Any, batch_idx: int):   
        return self.shared_step(batch, batch_idx, phase='train')

    def validation_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='val')

    def test_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='test')

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-6)
        # optim = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
        # TODO: scheduler
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, total_steps)
        return [optim] #, [scheduler]
    
    def train_dataloader(self):
        return self._dataloader(self.train_dataset)
    
    def val_dataloader(self):
        return self._dataloader(self.val_dataset)
    
    def on_train_epoch_end(self):
        if hasattr(self.model, "on_train_epoch_end"):
            self.model.on_train_epoch_end(self)
        self.train_dataset = self.get_dataset(self.train_size) 

    def get_dataset(self, size):
        # online data generation: we generate a new batch online
        # data = self.env.gen_params(batch_size=size)
        return TorchDictDataset(self.env.generate_data(size))
       
    def _dataloader(self, dataset):
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False, # no need to shuffle, we're resampling every epoch
            num_workers=0,
            collate_fn=torch.stack, # we need this to stack the batches in the dataset
            pin_memory=self.on_gpu,
        )

In [12]:
from rich.traceback import install
install()

<bound method InteractiveShell.excepthook of <ipykernel.zmqshell.ZMQInteractiveShell object at 0x7fd5a8d59e70>>

In [24]:
num_loc = 20
env = TSPEnv(num_loc=num_loc).transform()

# env = env.transform()
policy = PointerNetworkPolicy(
    env,
    num_pomo=num_loc,
    # num_pomo=1,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
    # force_flash_attn=True,
)

model_final = PointerNetwork(env, policy)

# # TODO CHANGE THIS
batch_size = 64 #1024 #512

model = NCOLightningModule(env, model_final, batch_size=batch_size, train_size=1280000, lr=1e-4)

# Trick to make calculations faster
torch.set_float32_matmul_precision("medium")

# Wandb Logger - we can use others as well as simply `None`
# logger = pl.loggers.WandbLogger(project="torchrl", name="pendulum")
# logger = L.loggers.CSVLogger("logs", name="tsp")

epochs = 1


# from lightning.pytorch.callbacks import DeviceStatsMonitor
# callbacks = [DeviceStatsMonitor()]

from lightning.pytorch.profilers import AdvancedProfiler

profiler = AdvancedProfiler(dirpath=".", filename="perf_logsv2")

# Trainer
trainer = L.Trainer(
    max_epochs=epochs,
    accelerator="gpu",
    devices=[1],
    # callbacks=callbacks,
    # profiler=profiler,
    # strategy="deepspeed_stage_3_offload",
    # precision=16,
    log_every_n_steps=100,   
    gradient_clip_val=1.0, # clip gradients to avoid exploding gradients
)

# Fit the model
trainer.fit(model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  logits[logit_mask] = float('-inf')
  logits[logit_mask] = float('-inf')
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type           | Params
-----------------------------------------
0 | env   | TSPEnv         | 0     
1 | model | PointerNetwork | 662 K 
-----------------------------------------
662 K     Trainable params
0         Non-trainable params
662 K     Total params
2.649     Total estimated model params size (MB)


Epoch 0:   3%|▎         | 588/20000 [00:39<21:45, 14.87it/s, v_num=3, train/cost=7.650]