In [1]:
# --- Experiment configuration -------------------------------------------------
# Each section is a plain dict; they are bundled into `config` at the bottom.

# Data source settings.
# NOTE(review): `data_path` is an absolute, machine-specific path; it will need
# adjusting on any other host.
data_config = dict(
    data_name="chengdu",
    data_path="/data/hetianran/didi/chengdu/gps_20161101",
    data_size=100_000,
    data_form="grid",
    grid_step=100,
    vocab_size=None,  # unset here; presumably filled in downstream — TODO confirm
)

# Task definition: name plus input/output lengths.
task_config = dict(
    task_name="prediction",
    input_len=10,
    output_len=1,
)

# Validation / test proportions.
dataset_config = dict(
    val_prop=0.1,
    test_prop=0.1,
)

# Embedding settings. Kept as a literal because the "pre-trained" key contains
# a hyphen and therefore cannot be written as a keyword argument.
embedding_config = {
    "emb_name": "gcn",
    "emb_dim": 256,
    "pre-trained": True,
    "embs_path": "./resource/model/embedding/gcn.pkl",
}

# Encoder settings. `d_model` is read from the embedding dimension so the two
# values cannot drift apart when `emb_dim` is edited.
encoder_config = dict(
    encoder_name="transformer",
    num_layers=6,
    d_model=embedding_config["emb_dim"],
    num_heads=8,
    d_ff=2048,
    dropout=0.1,
)

# Optimisation settings.
trainer_config = dict(
    batch_size=64,
    learning_rate=1e-4,
    num_epochs=10,
    optimizer="adam",
    loss_function="cross_entropy",
    lr_scheduler="step_lr",
)

# Top-level bundle: each key holds a reference to (not a copy of) the section
# dict defined above.
config = dict(
    data_config=data_config,
    task_config=task_config,
    dataset_config=dataset_config,
    embedding_config=embedding_config,
    encoder_config=encoder_config,
    trainer_config=trainer_config,
)

In [None]:
from trajlib.model.embedding.embedding_trainer import EmbeddingTrainer
from trajlib.model.embedding.gnn import GAETrainer
from trajlib.model.embedding.node2vec import Node2VecTrainer
from trajlib.data.data_factory import create_data

# Lookup table from the configured "emb_name" to the trainer class that handles
# it. Both "gat" and "gcn" are served by GAETrainer.
mapper: dict[str, type[EmbeddingTrainer]] = dict(
    node2vec=Node2VecTrainer,
    gat=GAETrainer,
    gcn=GAETrainer,
)

# Only runs when the "pre-trained" flag is truthy: build the data, pick the
# trainer class for the configured embedding name, and train it.
if embedding_config["pre-trained"]:
    # create_data returns a pair; only the second element is used here.
    _, graph_data = create_data(config)
    # An "emb_name" missing from `mapper` raises KeyError at this lookup.
    trainer_cls = mapper[embedding_config["emb_name"]]
    trainer = trainer_cls(embedding_config, graph_data)
    trainer.train()

In [2]:
from accelerate import notebook_launcher

from trajlib.runner.base_runner import BaseRunner


def accelerate_run(config):
    """Build a ``BaseRunner`` from ``config`` and run it.

    This is the function handed to ``notebook_launcher`` below.

    Args:
        config: The nested configuration dict passed straight through to
            ``BaseRunner``.
    """
    BaseRunner(config).run()


# Launch accelerate_run(config) across 4 processes from inside the notebook,
# passing "29501" as the port to use.
notebook_launcher(
    accelerate_run,
    args=(config,),
    num_processes=4,
    use_port="29501",
)

# TODO(debug): embedding pre-training and encoder training cannot be run back-to-back.

Launching training on 4 GPUs.


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
Epoch: 1: 100%|██████████| 295/295 [00:13<00:00, 21.24it/s]


Epoch: 1, Train Loss: 6.6384102595054495, Val Loss: 4.650101700344601, Test Loss: 4.601358890533447, Test Acc: 0.5222187042236328


Epoch: 2: 100%|██████████| 295/295 [00:12<00:00, 22.80it/s]


Epoch: 2, Train Loss: 3.6746511103743216, Val Loss: 2.823658311689222, Test Loss: 2.758627414703369, Test Acc: 0.7236186265945435


Epoch: 3: 100%|██████████| 295/295 [00:12<00:00, 23.10it/s]


Epoch: 3, Train Loss: 2.3886157920805076, Val Loss: 2.0130702128281466, Test Loss: 1.9515337944030762, Test Acc: 0.809417724609375


Epoch: 4: 100%|██████████| 295/295 [00:12<00:00, 23.18it/s]


Epoch: 4, Train Loss: 1.771884443800328, Val Loss: 1.6250996654098098, Test Loss: 1.5709587335586548, Test Acc: 0.8551278114318848


Epoch: 5: 100%|██████████| 295/295 [00:12<00:00, 23.07it/s]


Epoch: 5, Train Loss: 1.4450899400953519, Val Loss: 1.400683721980533, Test Loss: 1.3554611206054688, Test Acc: 0.8848233819007874


Epoch: 6: 100%|██████████| 295/295 [00:12<00:00, 23.21it/s]


Epoch: 6, Train Loss: 1.251190911713293, Val Loss: 1.2722050338178068, Test Loss: 1.2325098514556885, Test Acc: 0.9034892320632935


Epoch: 7: 100%|██████████| 295/295 [00:12<00:00, 23.04it/s]


Epoch: 7, Train Loss: 1.1217485207622333, Val Loss: 1.1857797584018193, Test Loss: 1.1471209526062012, Test Acc: 0.9140948057174683


Epoch: 8: 100%|██████████| 295/295 [00:13<00:00, 22.66it/s]


Epoch: 8, Train Loss: 1.0310354232788086, Val Loss: 1.1251483172983736, Test Loss: 1.0858763456344604, Test Acc: 0.9206702709197998


Epoch: 9: 100%|██████████| 295/295 [00:12<00:00, 23.14it/s]


Epoch: 9, Train Loss: 0.9633079577300508, Val Loss: 1.0722939597593772, Test Loss: 1.0307695865631104, Test Acc: 0.9287304878234863


Epoch: 10: 100%|██████████| 295/295 [00:12<00:00, 22.85it/s]


Epoch: 10, Train Loss: 0.9103893556837308, Val Loss: 1.0318831817523852, Test Loss: 0.990290641784668, Test Acc: 0.9356241226196289
Final Test Loss: 0.990290641784668, Final Test Accuracy: 0.9356241226196289
