In [None]:

import torch

from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary

from rl4co.envs import SSPkoptEnv, ATSPkoptEnv, TSPkoptEnv
from rl4co.models import NeuOptPolicy, NeuOpt
from rl4co.models.nn.env_embeddings.edge import ATSPEdgeEmbedding
from rl4co.utils.trainer import RL4COTrainer


In [None]:
import torch.nn as nn
from rl4co.envs.routing.atsp.generator import ATSPCoordGenerator

class CustomizeATSPInitEmbedding(nn.Module):
    """Initial node embedding: a two-layer MLP applied to ``td["locs"]``.

    The first linear layer must match the size of the last axis of
    ``td["locs"]``. Previously that layer was sized by ``num_loc``, which
    fails with ``mat1 and mat2 shapes cannot be multiplied (2000x2 and
    20x128)`` when ``locs`` carries 2 features per node. The input width is
    now its own argument, ``node_dim``, defaulting to 2.

    Args:
        embed_dim: Width of the hidden layer and of the returned embedding.
        num_loc: Number of nodes per instance. Kept so existing call sites
            keep working; it is stored but no longer sizes any layer.
        linear_bias: Whether both linear layers learn a bias term.
        node_dim: Size of the last axis of ``td["locs"]``. Pass
            ``node_dim=num_loc`` to reproduce the former layer shape when the
            per-node features really have ``num_loc`` entries.
    """

    def __init__(self, embed_dim, num_loc, linear_bias=True, node_dim=2):
        super().__init__()
        self.num_loc = num_loc
        self.node_dim = node_dim
        self.init_embed = nn.Sequential(
            nn.Linear(node_dim, embed_dim, bias=linear_bias),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim, bias=linear_bias),
        )

    def forward(self, td):
        """Embed ``td["locs"]``; the last axis is mapped from ``node_dim`` to ``embed_dim``.

        Raises:
            RuntimeError: If the last axis of ``td["locs"]`` is not ``node_dim`` wide.
        """
        locs = td["locs"]
        # Fail early with an actionable message instead of an opaque matmul error.
        if locs.shape[-1] != self.node_dim:
            raise RuntimeError(
                f"td['locs'] has {locs.shape[-1]} features per node but the embedding "
                f"was built with node_dim={self.node_dim}"
            )
        return self.init_embed(locs)

# Problem size and network width shared by the generator, the custom
# embedding and the policy.
num_loc = 20  # value handed to the generator and the embedding as `num_loc`
embed_dim = 128  # dimension of the embedding space

# Seeded instance generator feeding the k-opt environment.
gen = ATSPCoordGenerator(num_loc=num_loc, init_sol_type="random", seed=42)
env = ATSPkoptEnv(generator=gen, k_max=4)

# Policy options: same embedding width everywhere, with the custom initial
# embedding supplied as `init_embedding`.
policy_kwargs = {
    "embed_dim": embed_dim,
    "init_embedding": CustomizeATSPInitEmbedding(num_loc=num_loc, embed_dim=embed_dim),
}

model = NeuOpt(
    env,
    batch_size=128,
    train_data_size=1000,
    val_data_size=100,
    test_data_size=100,
    n_step=5,
    T_train=200,
    T_test=1000,
    CL_best=True,
    policy_kwargs=policy_kwargs,
)

/home/shora/Research/rl4co/.venv/lib/python3.13/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'critic' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['critic'])`.


In [None]:
# Checkpointing: keep the single best checkpoint according to the monitored
# validation metric (lower is better, hence mode="min") plus the most recent one.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints_tsp",  # directory the .ckpt files are written to
    filename="epoch_{epoch:03d}",  # rendered as epoch_XXX.ckpt
    # Without this, Lightning prefixes each placeholder with its metric name
    # and the file is written as "epoch_epoch=XXX.ckpt".
    auto_insert_metric_name=False,
    save_top_k=1,  # keep only the best-scoring checkpoint
    save_last=True,  # additionally keep the latest checkpoint
    monitor="val/cost_bsf",  # validation cost metric (presumably best-so-far cost; confirm)
    mode="min",  # a smaller monitored value is better
)

# Print a layer-by-layer summary down to three levels of nesting.
rich_model_summary = RichModelSummary(max_depth=3)

callbacks = [checkpoint_callback, rich_model_summary]

In [None]:
from lightning.pytorch.loggers import WandbLogger

# Weights & Biases experiment logger; files are stored under `wandb_logs`
# and `log_model=True` asks the logger to log model checkpoints as well.
logger = WandbLogger(
    project="rl4co",
    name="tsp",
    log_model=True,
    save_dir="wandb_logs",
)

# To run without experiment tracking, use `logger = None` instead.


In [None]:
# Trainer options gathered in one place so they are easy to tweak.
# NOTE(review): the recorded fit output reports that gradient_clip_val is
# overridden to None for 'automatic_optimization=False' models, so the
# clipping value below appears to have no effect for this model; confirm.
trainer_options = {
    "max_epochs": 20,
    "gradient_clip_val": 0.05,
    "devices": 1,
    "accelerator": "gpu",
    "logger": logger,
    "callbacks": callbacks,
}
trainer = RL4COTrainer(**trainer_options)

Using 16bit Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Run the test loop on the model as constructed above. In notebook order this
# cell comes before the fit cell, so it evaluates the not-yet-fitted model.
trainer.test(model)

[34m[1mwandb[0m: Currently logged in as: [33mshoraaa[0m ([33mshoraaa-vnu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/shora/Research/rl4co/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2000x2 and 20x128)

In [None]:
# Train the model, then run the test loop again on the fitted model.
# The order matters: test() must follow fit() to report post-training metrics.
trainer.fit(model)
trainer.test(model)

Overriding gradient_clip_val to None for 'automatic_optimization=False' models
/home/shora/Research/rl4co/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:658: Checkpoint directory /home/shora/Research/rl4co/train/checkpoints_tsp exists and is not empty.
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/shora/Research/rl4co/.venv/lib/python3.13/site-packages/lightning/pytorch/core/optimizer.py:317: The lr scheduler dict contains the key(s) ['monitor'], but the keys will be ignored. You need to call `lr_scheduler.step()` manually in manual optimization.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/shora/Research/rl4co/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/shora/Research/rl4co/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/shora/Research/rl4co/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]