In [1]:
from rl4co.envs.routing import TSPEnv, TSPGenerator

from rl4co.models.zoo.neufaco.model import DeepACO
from rl4co.models.zoo.neufaco.policy import DeepACOPolicy
from rl4co.utils import RL4COTrainer


from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from lightning.pytorch.loggers import WandbLogger

from rl4co.models.zoo.neufaco.antsystem import AntSystem
from rl4co.models.zoo.hdgaco.FocusedACO import FocusedACO

In [2]:
# Problem size handed to the instance generator as `num_loc`.
num_loc = 50

# Instance generator using the "uniform" location distribution setting.
generator = TSPGenerator(
    loc_distribution="uniform",
    num_loc=num_loc,
)

# TSP environment built on top of the generator above.
env = TSPEnv(generator)

In [3]:


# Policy: DeepACO policy driving the AntSystem ACO class imported above, with
# local search enabled both via `train_with_local_search` and via `aco_kwargs`.
# NOTE(review): the saved output of `trainer.fit(model)` in this notebook ends with
#   TypeError: AntSystem.__init__() got an unexpected keyword argument 'best_ant_only'
# Presumably the policy forwards that keyword to `aco_class`; confirm that
# AntSystem accepts it (or that a different ACO class was intended) before re-running.
policy = DeepACOPolicy(
    env_name=env.name,
    aco_class=AntSystem,
    k_sparse=20,
    train_with_local_search=True,
    aco_kwargs={"use_local_search": True},
)

# Lightning module wrapping the environment and the policy, together with the
# dataset sizes / batch sizes and the optimizer learning rate.
model = DeepACO(
    env,
    policy,
    batch_size=512,
    train_data_size=10_000,
    val_data_size=1_000,
    val_batch_size=512,
    test_data_size=1_000,
    optimizer_kwargs={"lr": 1e-4},
)

# The run name must be an f-string: without the `f` prefix the literal text
# "{num_loc}" ends up in the W&B run name instead of the problem size
# (with num_loc = 50 this now yields "tsp_50_deepaco_faco").
logger = WandbLogger(project="hdgaco", name=f"tsp_{num_loc}_deepaco_faco")

/home/shora/Research/rl4co/.venv/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/home/shora/Research/rl4co/.venv/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.


In [4]:
# Checkpointing: keep the single best model by validation reward plus the last one.
checkpoint_callback = ModelCheckpoint(
    # f-string so the problem size is interpolated into the directory name.
    # Previously this was a plain string, and the recorded Lightning warning shows
    # checkpoints landing in a directory literally called "gfacs_mmas_{num_loc}".
    dirpath=f"checkpoints/gfacs_mmas_{num_loc}",
    # Deliberately NOT an f-string: "{epoch:03d}" is a template that the callback
    # fills in itself when it writes a checkpoint.
    # NOTE(review): with Lightning's default `auto_insert_metric_name`, the resulting
    # file may be named "epoch_epoch=XXX.ckpt" rather than "epoch_XXX.ckpt" - confirm.
    filename="epoch_{epoch:03d}",
    save_top_k=1,  # save only the best model
    save_last=True,  # also save the last model
    monitor="val/reward",  # quantity used to rank checkpoints
    mode="max",  # higher validation reward is better
)

# Model summary printed with submodules expanded up to depth 3.
rich_model_summary = RichModelSummary(max_depth=3)

callbacks = [checkpoint_callback, rich_model_summary]


In [5]:
# Trainer: 20 epochs on a single GPU, reporting to the W&B logger and using
# the checkpoint / model-summary callbacks assembled earlier in the notebook.
trainer = RL4COTrainer(
    # hardware
    accelerator="gpu",
    devices=1,
    # schedule
    max_epochs=20,
    # reporting and checkpointing
    callbacks=callbacks,
    logger=logger,
)

Using 16bit Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [6]:
# Run training (including the validation sanity check Lightning performs first).
# NOTE(review): the output saved with this cell stops during sanity checking with
# "TypeError: AntSystem.__init__() got an unexpected keyword argument 'best_ant_only'";
# that needs resolving before this cell can complete.
trainer.fit(model=model)


[34m[1mwandb[0m: Currently logged in as: [33mshoraaa[0m ([33mshoraaa-vnu[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


/home/shora/Research/rl4co/.venv/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:658: Checkpoint directory /home/shora/Research/rl4co/checkpoints/gfacs_mmas_{num_loc} exists and is not empty.
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/shora/Research/rl4co/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


TypeError: AntSystem.__init__() got an unexpected keyword argument 'best_ant_only'

In [None]:
# Evaluate the model on the test data and report the test metrics.
# NOTE(review): this cell has no execution count while the preceding fit cell's
# saved output ends in an error - the stored result is presumably from an
# earlier session and should be regenerated after a successful fit.
trainer.test(model=model)

val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/shora/Research/rl4co/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 2/2 [00:44<00:00,  0.04it/s]


[{'test/reward': -7.7954230308532715}]