In [1]:
import torch
from torch import nn
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from lightning.pytorch.loggers import WandbLogger
torch.cuda.device_count()

1

In [2]:
from rl4co.envs import SSPEnv
from rl4co.models.zoo.am import AttentionModelPolicy, AttentionModel
from rl4co.utils.trainer import RL4COTrainer


In [3]:
from rl4co.utils.decoding import random_policy, rollout
from rl4co.utils.ops import gather_by_index

# RL4CO env based on TorchRL
env = SSPEnv(generator_params=dict(num_loc=10, fixed_len=15))
td = env.reset(batch_size=[3])
# env.render(td)

reward, td, actions = rollout(env, env.reset(batch_size=[3]), random_policy)
reward, env.render(td, actions)

SSP codes:
 tensor([[0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 1., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0.],
        [0., 0., 0., 0., 1., 1., 0., 1., 1., 1., 0., 1., 0., 0., 0.],
        [1., 1., 0., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1.],
        [1., 1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0.],
        [0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 0., 1., 0., 0.],
        [0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 1., 0.],
        [0., 1., 0., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1., 0., 0.],
        [1., 0., 0., 1., 1., 1., 0., 1., 1., 1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0.]]) 

 Suggested order: tensor([5, 1, 4, 0, 2, 9, 8, 6, 7, 3]) 

 Sorted codes according to the order:
 tensor([[0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0.],
        [1., 1., 1., 1., 0., 0., 0., 1., 1., 1.



(tensor([-14.0000, -13.9000, -13.8000]), None)

In [4]:
class SSPInitEmbedding(nn.Module):

    def __init__(self, embedding_dim, fixed_len, linear_bias=True):
        super(SSPInitEmbedding, self).__init__()
        node_dim = fixed_len  # x, y
        self.init_embed = nn.Linear(node_dim, embedding_dim, linear_bias)

    def forward(self, td):
        out = self.init_embed(td["codes"])
        return out

class SSPContext(nn.Module):
    """Context embedding for the Traveling Salesman Problem (TSP).
    Project the following to the embedding space:
        - first node embedding
        - current node embedding
    """

    def __init__(self, embedding_dim,  linear_bias=True):
        super(SSPContext, self).__init__()
        self.W_placeholder = nn.Parameter(
            torch.Tensor(embedding_dim).uniform_(-1, 1)
        )
        self.project_context = nn.Linear(
            embedding_dim, embedding_dim, bias=linear_bias
        )

    def forward(self, embeddings, td):
        batch_size = embeddings.size(0)
        # By default, node_dim = -1 (we only have one node embedding per node)
        node_dim = (
            (-1,) if td["current_node"].dim() == 1 else (td["current_node"].size(-1), -1)
        )
        if td["i"][(0,) * td["i"].dim()].item() < 1:  # get first item fast
            context_embedding = self.W_placeholder[None, :].expand(
                batch_size, self.W_placeholder.size(-1)
            )
        else:
            context_embedding = gather_by_index(
                embeddings,
                torch.stack([td["current_node"]], -1).view(
                    batch_size, -1
                ),
            ).view(batch_size, *node_dim)
        return self.project_context(context_embedding)
        
class StaticEmbedding(nn.Module):
    def __init__(self, *args, **kwargs):
        super(StaticEmbedding, self).__init__()

    def forward(self, td):
        return 0, 0, 0

num_loc = 100
fixed_len = 15
emb_dim = 128

env = SSPEnv(generator_params={"num_loc":num_loc,
                              "fixed_len":fixed_len})

policy = AttentionModelPolicy(env_name = env.name,
                              embed_dim=emb_dim,
                              num_encoder_layers=6,
                              num_heads=8,
                              normalization="instance",
                              init_embedding=SSPInitEmbedding(emb_dim, fixed_len),
                              context_embedding=SSPContext(emb_dim),
                              dynamic_embedding=StaticEmbedding(emb_dim),
                              use_graph_context=False
                             )

# Model: default is AM with REINFORCE and greedy rollout baseline
model = AttentionModel(env, 
            policy=policy,
            batch_size=512,
            train_data_size=100000,  # each epoch,
            val_batch_size=1000,
            val_data_size=1000,
            test_batch_size=1000,
            test_data_size=1000,
            optimizer="Adam",
            optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-6},
            lr_scheduler="MultiStepLR",
            lr_scheduler_kwargs={"milestones": [901, ], "gamma": 0.1},
        )



/home/shora/Research/rl4co/.venv/lib/python3.13/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/home/shora/Research/rl4co/.venv/lib/python3.13/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.


In [5]:
# Checkpointing callback: save models when validation reward improves
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints_ssp", # save to checkpoints/
                                    filename="epoch_{epoch:03d}",  # save as epoch_XXX.ckpt
                                    save_top_k=5, # save only the best model
                                    save_last=True, # save the last model
                                    monitor="val/reward", # monitor validation reward
                                    mode="max") # maximize validation reward

rich_model_summary = RichModelSummary(max_depth=3)  # model summary callback
callbacks = [checkpoint_callback, rich_model_summary]

# Logger
# logger = WandbLogger(project="rl4co", name=f"{env.name}_{num_loc}")
logger = None # uncomment this line if you don't want logging



# We use our own wrapper around Lightning's `Trainer` to make it easier to use
trainer = RL4COTrainer(max_epochs=3, 
                       accelerator = 'gpu', 
                       devices=1,   
                       # logger=logger,
                       callbacks=callbacks,
                      )

trainer.test(model)
trainer.fit(model)

Using 16bit Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/shora/Research/rl4co/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - C

Testing: |          | 0/? [00:00<?, ?it/s]

/home/shora/Research/rl4co/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/shora/Research/rl4co/checkpoints_ssp exists and is not empty.
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/shora/Research/rl4co/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/shora/Research/rl4co/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td = env.reset(batch_size=[3]).to(device)  # Move td to the same device as the environment
# env.render(td)

reward, td, actions = rollout(env, env.reset(batch_size=[3]).to(device), policy)  # Ensure td is on the correct device
reward, env.render(td, actions)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)