In [2]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1
import torch
from torch import nn
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from lightning.pytorch.loggers import WandbLogger
# Display the number of CUDA devices torch reports in this kernel.
# NOTE(review): the recorded output is 4 even though CUDA_VISIBLE_DEVICES=1 is set in this
# cell; presumably the variable was set too late (or the count is not affected the way one
# would expect here) — confirm on a freshly restarted kernel.
torch.cuda.device_count()

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1


4

In [3]:
from rl4co.envs import SSPEnv
from rl4co.models.zoo.am import AttentionModelPolicy, AttentionModel
from rl4co.utils.trainer import RL4COTrainer

In [4]:
from rl4co.utils.decoding import random_policy, rollout
from rl4co.utils.ops import gather_by_index

# RL4CO env based on TorchRL: a small SSP instance generator (num_loc=10, fixed_len=15)
env = SSPEnv(generator_params=dict(num_loc=10, fixed_len=15))

# Sample one batch of 3 instances and roll it out with a uniformly random policy.
# The batch created by `reset` is the one that is actually rolled out; previously a
# first batch was generated, never used, and a second one was created for the rollout.
td = env.reset(batch_size=[3])
reward, td, actions = rollout(env, td, random_policy)

# `render` prints the codes and the chosen order (its return value is None), so call it
# as a statement and leave the rewards as the cell's displayed value.
env.render(td, actions)
reward

SSP codes:
 tensor([[0., 1., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1., 0., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 1.],
        [1., 1., 1., 0., 1., 0., 1., 0., 1., 0., 0., 1., 1., 0., 1.],
        [1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1., 1., 0., 1., 0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 0.]]) 

 Suggested order: tensor([2, 3, 7, 8, 5, 0, 6, 1, 4, 9]) 

 Sorted codes according to the order:
 tensor([[1., 1., 1., 0., 1., 0., 1., 0., 1., 0., 0., 1., 1., 0., 1.],
        [1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 1., 1., 1., 1., 0.

(tensor([-13.3000, -13.5000, -14.5000]), None)

In [4]:
class SSPInitEmbedding(nn.Module):
    """Initial node embedding for the SSP environment.

    Applies a single linear layer to ``td["codes"]``, mapping the last dimension
    from ``fixed_len`` to ``embedding_dim``.

    Args:
        embedding_dim: size of the produced embedding vectors.
        fixed_len: size of the last dimension of ``td["codes"]``.
        linear_bias: whether the linear projection has a bias term.
    """

    def __init__(self, embedding_dim, fixed_len, linear_bias=True):
        super().__init__()
        # The input feature size of the projection is the code length.
        self.init_embed = nn.Linear(
            in_features=fixed_len,
            out_features=embedding_dim,
            bias=linear_bias,
        )

    def forward(self, td):
        """Return the linear projection of ``td["codes"]``."""
        codes = td["codes"]
        return self.init_embed(codes)

class SSPContext(nn.Module):
    """Context embedding for the SSP decoder.

    Produces one context vector per batch element and projects it with a linear layer:
        - on the first decoding step (step counter ``td["i"]`` below 1): a learned
          placeholder vector, since no node has been selected yet
        - on later steps: the embedding of the node given by ``td["current_node"]``

    Args:
        embedding_dim: size of the node embeddings and of the produced context.
        linear_bias: whether the linear projection has a bias term.
    """

    def __init__(self, embedding_dim, linear_bias=True):
        super().__init__()
        # Learned stand-in for the "current node" before anything was selected,
        # initialised uniformly in [-1, 1].
        self.W_placeholder = nn.Parameter(torch.empty(embedding_dim).uniform_(-1, 1))
        self.project_context = nn.Linear(embedding_dim, embedding_dim, bias=linear_bias)

    def forward(self, embeddings, td):
        """Return the projected context for the current decoding step.

        Args:
            embeddings: node embeddings; the first dimension is the batch size.
            td: mapping providing ``"i"`` (step counter) and ``"current_node"``.
        """
        batch_size = embeddings.size(0)
        step = td["i"]

        # Only the first entry of the step counter is inspected (cheap scalar read);
        # this assumes every batch element is at the same decoding step.
        if step[(0,) * step.dim()].item() < 1:
            context_embedding = self.W_placeholder[None, :].expand(
                batch_size, self.W_placeholder.size(-1)
            )
        else:
            current_node = td["current_node"]
            # A 1-D index tensor yields one embedding per batch element; otherwise the
            # trailing index dimension is kept in the output.
            node_dim = (
                (-1,) if current_node.dim() == 1 else (current_node.size(-1), -1)
            )
            context_embedding = gather_by_index(
                embeddings, current_node.reshape(batch_size, -1)
            ).view(batch_size, *node_dim)
        return self.project_context(context_embedding)
        
class StaticEmbedding(nn.Module):
    """Dynamic-embedding stub that contributes nothing.

    Accepts and ignores any constructor arguments and holds no parameters;
    ``forward`` always returns three zeros.
    """

    def __init__(self, *args, **kwargs):
        # Arguments are accepted only for signature compatibility and are discarded.
        super().__init__()

    def forward(self, td):
        """Return the constant triple ``(0, 0, 0)``; ``td`` is not read."""
        zero = 0
        return zero, zero, zero

In [5]:
# Problem size and embedding width used throughout this cell
num_loc = 100
fixed_len = 15
emb_dim = 128

# Environment generating SSP instances with the sizes above
env = SSPEnv(generator_params=dict(num_loc=num_loc, fixed_len=fixed_len))

# Attention-model policy wired with the SSP-specific embedding modules
policy = AttentionModelPolicy(
    env_name=env.name,
    embed_dim=emb_dim,
    num_encoder_layers=6,
    num_heads=8,
    normalization="instance",
    init_embedding=SSPInitEmbedding(emb_dim, fixed_len),
    context_embedding=SSPContext(emb_dim),
    dynamic_embedding=StaticEmbedding(emb_dim),
    use_graph_context=False,
)

# Model: default is AM with REINFORCE and greedy rollout baseline
model = AttentionModel(
    env,
    policy=policy,
    batch_size=512,
    train_data_size=100000,  # instances per epoch
    val_batch_size=1000,
    val_data_size=1000,
    test_batch_size=1000,
    test_data_size=1000,
    optimizer="Adam",
    optimizer_kwargs=dict(lr=1e-4, weight_decay=1e-6),
    lr_scheduler="MultiStepLR",
    lr_scheduler_kwargs=dict(milestones=[901], gamma=0.1),
)

/home/yining/miniconda3/envs/ai4co/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/home/yining/miniconda3/envs/ai4co/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:199: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.


In [6]:
# Checkpointing callback: keep the best checkpoints according to validation reward
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints_ssp",  # checkpoints are written to ./checkpoints_ssp/
    # NOTE(review): ModelCheckpoint inserts metric names into the template by default,
    # so files presumably end up as `epoch_epoch=XXX.ckpt` — confirm, and pass
    # `auto_insert_metric_name=False` if plain `epoch_XXX.ckpt` is wanted.
    filename="epoch_{epoch:03d}",
    save_top_k=5,  # keep the 5 best checkpoints (not only the single best)
    save_last=True,  # additionally keep the most recent checkpoint
    monitor="val/reward",  # metric used to rank checkpoints
    mode="max",  # higher validation reward is better
)

rich_model_summary = RichModelSummary(max_depth=3)  # model summary callback
callbacks = [checkpoint_callback, rich_model_summary]

# Logger. `None` is presumably the trainer's own default for this argument, so passing it
# behaves like omitting it; it does NOT by itself disable logging. To log to Weights &
# Biases instead, replace the assignment with:
#   logger = WandbLogger(project="rl4co", name=f"{env.name}_{num_loc}")
logger = None

# We use our own wrapper around Lightning's `Trainer` to make it easier to use
trainer = RL4COTrainer(
    max_epochs=1000,
    accelerator="gpu",
    devices=1,
    logger=logger,  # previously commented out, which left `logger` above unused
    callbacks=callbacks,
)

# Evaluate the untrained model first (baseline), then train.
trainer.test(model)
trainer.fit(model)

Using 16bit Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
2024-06-03 20:53:56.227090: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-03 20:53:56.263796: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler

Testing: |                                                                                        | 0/? [00:00…

/home/yining/miniconda3/envs/ai4co/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:653: Checkpoint directory /home/yining/ssp/rl4co/checkpoints_ssp exists and is not empty.
val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]


Sanity Checking: |                                                                                | 0/? [00:00…

/home/yining/miniconda3/envs/ai4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.
/home/yining/miniconda3/envs/ai4co/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.


Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

/home/yining/miniconda3/envs/ai4co/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [92]:
import torch
def generate_batch_superstring_data(batch_size, num_str, str_dim, alphabet_size=2, generator=None):
    """Generate batches of random strings in which consecutive strings may overlap.

    Strings are first drawn uniformly from ``{0, ..., alphabet_size - 1}``. Then, for each
    pair of consecutive strings ``(k, k + 1)``, with probability 0.5 a random overlap
    length ``n`` in ``[1, str_dim // 2]`` is applied: the first ``n`` symbols of string
    ``k + 1`` are overwritten with the last ``n`` symbols of string ``k``. Finally the
    strings of every batch element are shuffled independently.

    Because ``n <= str_dim // 2``, the overwritten prefix of a string never touches the
    suffix that the next string copies from, so all overlaps can be applied at once.

    If ``num_str < 2`` or ``str_dim < 2`` no overlap is possible and the (shuffled)
    random strings are returned unchanged.

    Args:
        batch_size: number of independent instances.
        num_str: number of strings per instance.
        str_dim: length of every string.
        alphabet_size: number of distinct symbols (exclusive upper bound of the values).
        generator: optional CPU ``torch.Generator`` for reproducible sampling; the
            global RNG is used when ``None``.

    Returns:
        Float tensor of shape ``(batch_size, num_str, str_dim)``.
    """
    # Random base strings
    data = torch.randint(
        0, alphabet_size, (batch_size, num_str, str_dim), generator=generator
    )

    max_overlap = str_dim // 2
    if num_str > 1 and max_overlap > 0:
        num_pairs = num_str - 1
        # Which consecutive pairs overlap, and by how many symbols
        has_overlap = torch.rand(batch_size, num_pairs, generator=generator) > 0.5
        overlap_len = torch.randint(
            1, max_overlap + 1, (batch_size, num_pairs), generator=generator
        )

        positions = torch.arange(str_dim)
        lengths = overlap_len.unsqueeze(-1)  # (batch, pairs, 1), broadcast over positions
        # Positions of string k+1 that lie inside its overlap prefix
        in_overlap = (positions < lengths) & has_overlap.unsqueeze(-1)
        # Position i of the prefix takes its value from position (str_dim - n + i) of
        # string k; the modulo only keeps indices valid outside the overlap region,
        # where the gathered value is discarded anyway.
        source_idx = (str_dim - lengths + positions) % str_dim
        suffix_values = data[:, :-1].gather(2, source_idx)
        data[:, 1:] = torch.where(in_overlap, suffix_values, data[:, 1:])

    # Shuffle the num_str dimension independently per batch element
    perm = torch.rand(batch_size, num_str, generator=generator).argsort(dim=1)
    data = data[torch.arange(batch_size).unsqueeze(1), perm]

    return data.float()