In [1]:
# --- Experiment configuration -------------------------------------------------
# Each section is a plain dict; all sections are bundled into `config` at the end.

# Data source settings.
data_config = dict(
    data_name="chengdu",
    data_path="/data/hetianran/didi/chengdu/gps_20161101",
    data_size=100000,
    data_form="grid",
    grid_step=100,
    vocab_size=None,  # left as None here; presumably filled in downstream — TODO confirm
)

# Task definition: name plus input/output lengths.
task_config = dict(task_name="prediction", input_len=10, output_len=1)

# Validation / test proportions.
dataset_config = dict(val_prop=0.1, test_prop=0.1)

# Embedding settings. Kept as a literal because the "pre-trained" key contains a
# hyphen and therefore cannot be written as a keyword argument.
embedding_config = {
    "emb_name": "node2vec",
    "emb_dim": 256,
    "pre-trained": True,
    "embs_path": "./resource/model/embedding/node2vec.pkl",
}

# Encoder settings; d_model is tied to the embedding dimension above.
encoder_config = dict(
    encoder_name="transformer",
    num_layers=6,
    d_model=embedding_config["emb_dim"],
    num_heads=8,
    d_ff=2048,
    dropout=0.1,
)

# Optimisation settings.
trainer_config = dict(
    batch_size=64,
    learning_rate=1e-4,
    num_epochs=10,
    optimizer="adam",
    loss_function="cross_entropy",
    lr_scheduler="step_lr",
)

# Top-level bundle; it references the section dicts above (no copies are made).
config = dict(
    data_config=data_config,
    task_config=task_config,
    dataset_config=dataset_config,
    embedding_config=embedding_config,
    encoder_config=encoder_config,
    trainer_config=trainer_config,
)

In [None]:
from trajlib.model.embedding.embedding_trainer import EmbeddingTrainer
from trajlib.model.embedding.gat import GATTrainer
from trajlib.model.embedding.node2vec import Node2VecTrainer
from trajlib.data.data_factory import create_data

# Registry of embedding trainer classes, keyed by the "emb_name" config value.
mapper: dict[str, type[EmbeddingTrainer]] = dict(
    node2vec=Node2VecTrainer,
    gat=GATTrainer,
)

# When the "pre-trained" flag is set, train the selected embedding model here.
# Only the second element returned by create_data is used; the first is discarded.
if embedding_config["pre-trained"]:
    _, graph_data = create_data(config)
    trainer_cls = mapper[embedding_config["emb_name"]]
    trainer = trainer_cls(embedding_config, graph_data)
    trainer.train()

In [2]:
from accelerate import notebook_launcher

from trajlib.runner.base_runner import BaseRunner


def accelerate_run(config):
    """Build a ``BaseRunner`` from ``config`` and execute its ``run()`` method.

    This is the function handed to ``notebook_launcher`` below; it returns nothing.
    """
    BaseRunner(config).run()


# Launch accelerate_run through accelerate's notebook_launcher, requesting 4 processes
# and passing the full config as the function's single positional argument.
notebook_launcher(accelerate_run, args=(config,), num_processes=4)

# TODO: debug — embedding pre-training and encoder training cannot be run back-to-back.
# NOTE(review): possibly because notebook_launcher requires that CUDA has not been
# initialised in the notebook process before a multi-GPU launch, and the embedding
# pre-training cell may initialise it — confirm.

Launching training on 4 GPUs.


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
Epoch: 1: 100%|██████████| 295/295 [00:12<00:00, 23.62it/s]


Epoch: 1, Train Loss: 6.2429732581316415, Val Loss: 4.282995223999023, Test Loss: 4.201303958892822, Test Acc: 0.6198960542678833


Epoch: 2: 100%|██████████| 295/295 [00:12<00:00, 23.70it/s]


Epoch: 2, Train Loss: 3.43557003150552, Val Loss: 2.7674761076231262, Test Loss: 2.6894290447235107, Test Acc: 0.7771767973899841


Epoch: 3: 100%|██████████| 295/295 [00:11<00:00, 24.60it/s]


Epoch: 3, Train Loss: 2.332533046350641, Val Loss: 2.0706423134417147, Test Loss: 2.0040957927703857, Test Acc: 0.841446578502655


Epoch: 4: 100%|██████████| 295/295 [00:12<00:00, 24.52it/s]


Epoch: 4, Train Loss: 1.7650996321338719, Val Loss: 1.696498468115523, Test Loss: 1.6400519609451294, Test Acc: 0.878035843372345


Epoch: 5: 100%|██████████| 295/295 [00:12<00:00, 24.24it/s]


Epoch: 5, Train Loss: 1.4402742145425182, Val Loss: 1.4834094305296202, Test Loss: 1.4322717189788818, Test Acc: 0.897231936454773


Epoch: 6: 100%|██████████| 295/295 [00:12<00:00, 23.94it/s]


Epoch: 6, Train Loss: 1.2373223557310589, Val Loss: 1.351755651267799, Test Loss: 1.3029507398605347, Test Acc: 0.910064697265625


Epoch: 7: 100%|██████████| 295/295 [00:12<00:00, 24.51it/s]


Epoch: 7, Train Loss: 1.100642100633201, Val Loss: 1.2649983908679034, Test Loss: 1.2171436548233032, Test Acc: 0.9200339317321777


Epoch: 8: 100%|██████████| 295/295 [00:12<00:00, 24.49it/s]


Epoch: 8, Train Loss: 1.0008506724389934, Val Loss: 1.194121716795741, Test Loss: 1.144037127494812, Test Acc: 0.9287304878234863


Epoch: 9: 100%|██████████| 295/295 [00:12<00:00, 24.31it/s]


Epoch: 9, Train Loss: 0.9245620486089738, Val Loss: 1.1422462882222355, Test Loss: 1.0923051834106445, Test Acc: 0.9327605962753296


Epoch: 10: 100%|██████████| 295/295 [00:12<00:00, 24.19it/s]


Epoch: 10, Train Loss: 0.8618332265797308, Val Loss: 1.1043674027597583, Test Loss: 1.0541352033615112, Test Acc: 0.9359422922134399
Final Test Loss: 1.0541352033615112, Final Test Accuracy: 0.9359422922134399
