In [1]:
# Select the GPU for this kernel. Both variables are set before torch is imported;
# keep comments off the %env lines themselves so they do not become part of the value.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1
import torch
from torch import nn
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from lightning.pytorch.loggers import WandbLogger
# Sanity check: number of CUDA devices torch can see (the recorded output is 1).
torch.cuda.device_count()

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1


1

In [2]:
from rl4co.envs import SSPEnv
from rl4co.models.zoo.am import AttentionModelPolicy, AttentionModel
from rl4co.utils.trainer import RL4COTrainer
from rl4co.utils.decoding import random_policy, rollout
from rl4co.utils.ops import gather_by_index

In [3]:
class SSPInitEmbedding(nn.Module):
    """Initial embedding: a single linear layer applied to ``td["codes"]``.

    Args:
        embedding_dim: size of the produced embedding.
        fixed_len: size of the last dimension of ``td["codes"]`` (input features).
        linear_bias: whether the linear layer carries a bias term.
    """

    def __init__(self, embedding_dim, fixed_len, linear_bias=True):
        super().__init__()
        # The input feature size is `fixed_len`, not a 2-D coordinate.
        self.init_embed = nn.Linear(
            in_features=fixed_len,
            out_features=embedding_dim,
            bias=linear_bias,
        )

    def forward(self, td):
        """Project ``td["codes"]`` into the embedding space and return the result."""
        return self.init_embed(td["codes"])

class SSPContext(nn.Module):
    """Context embedding used while decoding an SSP instance.

    Projects one of the following into the embedding space:
        - a learned placeholder vector while ``td["i"]`` is below 1
        - the embedding gathered at ``td["current_node"]`` otherwise
    """

    def __init__(self, embedding_dim, linear_bias=True):
        super().__init__()
        # Learned stand-in used before a current node embedding is gathered.
        placeholder = torch.empty(embedding_dim).uniform_(-1, 1)
        self.W_placeholder = nn.Parameter(placeholder)
        self.project_context = nn.Linear(embedding_dim, embedding_dim, bias=linear_bias)

    def forward(self, embeddings, td):
        """Return the projected context for the current decoding state.

        Args:
            embeddings: node embeddings; the first dimension is the batch.
            td: mapping providing ``"current_node"`` and the step counter ``"i"``.
        """
        n_batch = embeddings.size(0)
        current = td["current_node"]
        step = td["i"]

        # Trailing shape for the gathered context: flat when `current` is 1-D,
        # otherwise one row per entry of its last dimension.
        if current.dim() == 1:
            tail_shape = (-1,)
        else:
            tail_shape = (current.size(-1), -1)

        # Only the very first element of the step counter is inspected.
        first_step = step[(0,) * step.dim()].item()
        if first_step < 1:
            hidden = self.W_placeholder.size(-1)
            context = self.W_placeholder.unsqueeze(0).expand(n_batch, hidden)
        else:
            picked = gather_by_index(embeddings, current.reshape(n_batch, -1))
            context = picked.view(n_batch, *tail_shape)
        return self.project_context(context)
        
class StaticEmbedding(nn.Module):
    """Parameter-free embedding that ignores its input and yields three zeros."""

    def __init__(self, *args, **kwargs):
        # Any constructor arguments are accepted and discarded.
        super().__init__()

    def forward(self, td):
        """Return ``(0, 0, 0)`` regardless of ``td``."""
        zeros = (0,) * 3
        return zeros

In [16]:
# Problem configuration handed to the environment's generator.
# NOTE(review): `emb_dim` is not referenced in any of the visible cells.
num_loc = 100
fixed_len = 15
emb_dim = 128

# Environment used for evaluation.
# NOTE(review): `test_file` is passed here, yet the captured output of this cell
# still reports "test_file not set. Generating dataset instead" - confirm that the
# file is actually picked up (the output may be stale).
env = SSPEnv(generator_params={"num_loc":num_loc,
                              "fixed_len":fixed_len},
            test_file = "data_ssp.npz")

# NOTE(review): absolute, machine-specific path; this cell only runs where that
# directory exists. Consider a configurable checkpoint directory.
checkpoint_path = "/home/yining/ssp/rl4co/checkpoints_ssp_old/epoch_epoch=813.ckpt"

# Restore the trained AttentionModel from the checkpoint, attach the environment
# built above, and set both the test set size and the test batch size to 10000.
model = AttentionModel.load_from_checkpoint(checkpoint_path)
model.env = env
model.data_cfg["test_batch_size"] = 10000
model.data_cfg["test_data_size"] = 10000

# NOTE(review): RL4COTrainer is already imported in the imports cell near the top
# of the notebook; this re-import is redundant.
from rl4co.utils.trainer import RL4COTrainer

# RL4CO's wrapper around Lightning's `Trainer`.
# NOTE(review): only `trainer.test` is called in this cell, so `max_epochs` looks
# like a leftover from training - confirm before relying on it.
trainer = RL4COTrainer(max_epochs=1000, 
                       accelerator = 'gpu', 
                       devices=1,   
                       # logger=logger,
                       # callbacks=callbacks,
                      )

# `out` keeps the test metrics; the results cell reads out[0]["test/reward"].
out = trainer.test(model)

val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
val_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]


Testing: |                                                                                        | 0/? [00:00…

In [21]:
import numpy as np

# k values of the k-mers greedy baselines (largest k first), paired position by
# position with their recorded running times.
# NOTE(review): the name `time` would shadow the stdlib module if it were ever
# imported in this notebook; it is kept so other cells see the same global.
ks = [15, 14, 10, 5, 4, 3, 2, 1]
time = ['2s', '3s', '34s', '1m59s', '2m54s', '3m58s', '6m47s', '11m57s']

# Test reward from trainer.test(), negated and scaled by the number of nodes so
# it is on the same scale as the baseline lengths below.
RL = out[0]["test/reward"] * env.generator.num_loc * -1
print(f"RL constructive length:\t {RL:.2f},\tgap: {0:.2f}%,\t time: <1s \t(GPU in parallel)")
print('-' * 90)

for k, t in zip(ks, time):
    # np.load on an .npz returns an archive object that holds the file open;
    # the context manager closes it as soon as the mean has been computed.
    with np.load(f'greedy_{k}-mers_output.npz') as archive:
        greedy_baseline = archive["arr_0"].mean()
    # Relative gap w.r.t. the RL result: positive means the baseline is longer.
    gap = (greedy_baseline - RL) / RL * 100
    print(f"{k}-mers-greedy length:\t {greedy_baseline:.2f},\tgap: {gap:.2f}%,\t time: {t} \t(CPU in series)")

RL constructive length:	 923.03,	gap: 0.00%,	 time: <1s 	(GPU in parallel)
------------------------------------------------------------------------------------------
15-mers-greedy length:	 1497.43,	gap: 62.23%,	 time: 2s 	(CPU in series)
14-mers-greedy length:	 1488.60,	gap: 61.27%,	 time: 3s 	(CPU in series)
10-mers-greedy length:	 1322.88,	gap: 43.32%,	 time: 34s 	(CPU in series)
5-mers-greedy length:	 849.76,	gap: -7.94%,	 time: 1m59s 	(CPU in series)
4-mers-greedy length:	 824.67,	gap: -10.66%,	 time: 2m54s 	(CPU in series)
3-mers-greedy length:	 812.39,	gap: -11.99%,	 time: 3m58s 	(CPU in series)
2-mers-greedy length:	 806.99,	gap: -12.57%,	 time: 6m47s 	(CPU in series)
1-mers-greedy length:	 805.01,	gap: -12.79%,	 time: 11m57s 	(CPU in series)


In [13]:
from tensordict.tensordict import TensorDict
from tqdm import tqdm
# Read the instances from disk; the .npz archive is closed as soon as "codes"
# has been copied into a tensor.
with np.load("data/data_ssp.npz") as archive:
    codes = torch.tensor(archive["codes"])
# The batch size follows the number of instances in the file instead of a
# hard-coded count that must be kept in sync with the data.
td = TensorDict({"codes": codes}, batch_size=codes.shape[0])

In [14]:
# Sampling rollouts over trained model: keep, per instance, the best reward seen
# across repeated sampling passes and report the running average periodically.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td_init = model.env.reset(td).to(device)
model = model.to(device)

num_samples = 400   # number of sampling passes over the whole batch
report_every = 100  # print the running best every this many passes

rewards_best = None
with torch.no_grad():
    for i in tqdm(range(num_samples)):
        # Use a dedicated name: `out` holds the trainer.test() metrics that the
        # results cell reads (out[0]["test/reward"]) and must not be overwritten,
        # otherwise re-running that cell after this one fails.
        sample_out = model(td_init.clone(), phase="test", decode_type="sampling", return_actions=False)
        rewards_now = sample_out['reward'].cpu().detach()
        if rewards_best is None:
            rewards_best = rewards_now
        else:
            # Element-wise best reward per instance so far.
            rewards_best = torch.max(rewards_now, rewards_best)

        if (i + 1) % report_every == 0:
            # Negated mean reward scaled by the number of nodes, compared with
            # the RL result computed in the results cell.
            obj = -rewards_best.mean().numpy() * env.generator.num_loc
            print('Sampling:', i+1, 'length:', obj,  'gap:', (obj - RL) / RL * 100, '%')

 25%|████████████████████▎                                                            | 100/400 [00:38<01:56,  2.59it/s]

Sampling: 100 length: 899.8259544372559 gap: -2.513375100311315 %


 50%|████████████████████████████████████████▌                                        | 200/400 [01:18<01:23,  2.41it/s]

Sampling: 200 length: 897.3359107971191 gap: -2.7831449920696327 %


 75%|████████████████████████████████████████████████████████████▊                    | 300/400 [02:00<00:42,  2.35it/s]

Sampling: 300 length: 896.1589813232422 gap: -2.910652852436902 %


100%|█████████████████████████████████████████████████████████████████████████████████| 400/400 [02:43<00:00,  2.45it/s]

Sampling: 400 length: 895.3390121459961 gap: -2.9994878401557084 %



