In [64]:
import torch

from rl4co.envs import ATSPkoptEnv
from rl4co.models import NeuOptPolicy, NeuOpt
from rl4co.utils.trainer import RL4COTrainer


In [65]:
import torch.nn as nn

class CustomizeATSPInitEmbedding(nn.Module):
    """Initial embedding for the Asymmetric Traveling Salesman Problems (ATSP).
    Embed the following node features to the embedding space:
        - dists: distance matrix of the cities
    """

    def __init__(self, embed_dim, num_loc, linear_bias=True):
        super(CustomizeATSPInitEmbedding, self).__init__()
        node_dim = num_loc # number of locations (customers)
        self.init_embed = nn.Sequential(
            nn.Linear(node_dim, embed_dim // 2, linear_bias),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dim // 2, embed_dim, linear_bias),
        )

    def forward(self, td):
        out = self.init_embed(td["cost_matrix"])
        return out

num_loc = 20  # Number of locations (customers)
embed_dim = 128  # Dimension of the embedding space

env = ATSPkoptEnv(generator_params=dict(num_loc=num_loc, init_sol_type="greedy"), k_max=4)
model = NeuOpt(
        env,
        train_data_size=10,
        val_data_size=10,
        test_data_size=10,
        n_step=2,
        T_train=4,
        T_test=4,
        CL_best=True,
        policy_kwargs=dict(
            embed_dim=embed_dim,
            init_embedding=CustomizeATSPInitEmbedding(num_loc=num_loc, embed_dim=embed_dim),
        ),
    )

In [66]:
# Greedy rollouts over untrained policy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

td_init = env.reset(batch_size=[2]).to(device)


policy = model.policy.to(device)
out = None
for i in range(20):
    out = policy(td_init.clone(), env=env, phase="test", decode_type="greedy", return_actions=True)
    td_init["action"] = out['actions']
    env.step(td_init)
    print(f"Cost: {-out['cost_bsf'][0].cpu().detach():.3f}")

Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345
Cost: -2.345


In [67]:
trainer = RL4COTrainer(
    max_epochs=3,
    gradient_clip_val=0.05,
    devices=1,
    accelerator="auto",
)

Using 16bit Automatic Mixed Precision (AMP)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Fit and test the model
trainer.fit(model)
trainer.test(model)

val_file not set. Generating dataset instead
test_file not set. Generating dataset instead


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type          | Params | Mode
------------------------------------------------
0 | env    | ATSPkoptEnv   | 0      | eval
1 | policy | NeuOptPolicy  | 683 K  | eval
2 | critic | CriticNetwork | 140 K  | eval
------------------------------------------------
823 K     Trainable params
0         Non-trainable params
823 K     Total params
3.295     Total estimated model params size (MB)
0         Modules in train mode
94        Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)

In [None]:


policy = model.policy.to(device)
out = None
for i in range(20):
    out = policy(td_init.clone(), env=env, phase="test", decode_type="greedy", return_actions=True)
    td_init["action"] = out['actions']
    env.step(td_init)
    print(f"Cost: {-out['cost_bsf'][0].cpu().detach():.3f}")




Cost: -2.620
Cost: -2.620
Cost: -2.620
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
Cost: -2.542
