# Attention Model Lightning

In [1]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append('../../')

from typing import Any

import torch
from torch.utils.data import DataLoader
import lightning as L

from ncobench.models.am import AttentionModel
from ncobench.models.common.am_base import AttentionModelBase
from ncobench.models.common.reinforce_baselines import *
from ncobench.envs.tsp import TSPEnv
from ncobench.data.dataset import TorchDictDataset

  warn(


## Test `AttentionModelBase`

Here we test the `AttentionModelBase` class: simple forward pass through the model

The `AttentionModelBase` includes only the single forward pass through an environment: given initial conditions, find policy.
We define the REINFORCE baseline and loss functions in final `AttentionModel`

In [2]:
device = 'cuda'

env = TSPEnv(n_loc=20)
env = env.transform()

data = env.gen_params(batch_size=[10000]) # NOTE: need to put batch_size in a list!!
init_td = env.reset(data)
dataset = TorchDictDataset(init_td)

dataloader = DataLoader(
                dataset,
                batch_size=128,
                shuffle=False, # no need to shuffle, we're resampling every epoch
                num_workers=0,
                collate_fn=torch.stack, # we need this to stack the batches in the dataset
            )


model = AttentionModelBase(
    env,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
).to(device)

# model = torch.compile(model, backend="cuda") # Torch 2.x

x = next(iter(dataloader)).to(device)

out = model(x, decode_type="sampling")

import tqdm.auto as tqdm

res = []
for x in dataloader:
    x = x.to("cuda")
    res.append(- model(x, decode_type="sampling")['reward'])


print(torch.cat(res).mean())

  from .autonotebook import tqdm as notebook_tqdm


tensor(9.0525, device='cuda:0')


## `AttentionModel` class

Here we include the REINFORCE baseline and loss functions

In [3]:
env = TSPEnv(n_loc=20).transform() # we transform to get easy observations
policy = AttentionModelBase(
    env,
    embedding_dim=128,
    hidden_dim=128,
    n_encode_layers=3,
)
baseline = WarmupBaseline(baseline=RolloutBaseline())

model = AttentionModel(env, policy, baseline)

## Lightning Module

This training loop deals with the training of the model as well as many other goodies - such as logging, checkpointing, device management, etc.

Note that the following will be done automatically with Hydra+Lightning

In [4]:
class NCOLitModule(L.LightningModule):
    def __init__(self, env, model, lr=1e-4, batch_size=128, train_size=1000, val_size=10000):
        super().__init__()

        # TODO: hydra instantiation
        self.env = env
        self.model = model
        self.lr = lr
        self.batch_size = batch_size
        self.train_size = train_size
        self.val_size = val_size
        self.setup()

    def setup(self, stage="fit"):
        self.train_dataset = self.get_observation_dataset(self.train_size)
        self.val_dataset = self.get_observation_dataset(self.val_size)
        if hasattr(self.model, "setup"):
            self.model.setup(self)

    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        td = self.env.reset(init_observation=batch)
        output = self.model(td, phase)
        
        # output = self.model(batch, phase)
        self.log(f"{phase}/cost", output["cost"].mean(), prog_bar=True)
        return {"loss": output['loss']}

    def training_step(self, batch: Any, batch_idx: int):    
        return self.shared_step(batch, batch_idx, phase='train')

    def validation_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='val')

    def test_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='test')

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-5)
        # TODO: scheduler
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, total_steps)
        return [optim] #, [scheduler]
    
    def train_dataloader(self):
        return self._dataloader(self.train_dataset)
    
    def val_dataloader(self):
        return self._dataloader(self.val_dataset)
    
    def on_train_epoch_end(self):
        if hasattr(self.model, "on_train_epoch_end"):
            self.model.on_train_epoch_end(self)
        self.train_dataset = self.get_observation_dataset(self.train_size) 

    def get_observation_dataset(self, size):
        # online data generation: we generate a new batch online
        data = self.env.gen_params(batch_size=size)
        return TorchDictDataset(self.env.reset(data)['observation'])
       
    def _dataloader(self, dataset):
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False, # no need to shuffle, we're resampling every epoch
            num_workers=0,
            collate_fn=torch.stack, # we need this to stack the batches in the dataset
        )

## Main training loop

Here we define the main training loop


In [5]:
epochs = 100
batch_size = 512
train_size = 1280000
lr = 1e-4

task = NCOLitModule(env, model, batch_size=batch_size, train_size=train_size, lr=lr)

# Trick to make calculations faster
torch.set_float32_matmul_precision("medium")

# Wandb Logger - we can use others as well as simply `None`
# logger = pl.loggers.WandbLogger(project="tsp", name="am")
# logger = L.loggers.CSVLogger("logs", name="tsp")
logger = None # comment to insert logger


# Trainer
trainer = L.Trainer(
    max_epochs=epochs,
    accelerator="gpu",
    devices=1,
    logger=logger, 
    log_every_n_steps=1,   
    gradient_clip_val=1.0, # clip gradients to avoid exploding gradients
)

# Fit the model
trainer.fit(task)

Evaluating baseline model on evaluation dataset


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Evaluating baseline model on evaluation dataset


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type           | Params
-----------------------------------------
0 | env   | TransformedEnv | 0     
1 | model | AttentionModel | 1.4 M 
-----------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.681     Total estimated model params size (MB)


                                                                           

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   9%|▊         | 216/2500 [00:17<03:03, 12.44it/s, v_num=31, train/cost=4.190]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
