In [1]:
# Dataset settings.
data_config = dict(
    data_name="chengdu",
    data_path="/data/hetianran/didi/chengdu/gps_20161101",
    data_size=100000,
    data_form="grid",
    grid_step=100,
    vocab_size=None,  # no value supplied here
)

# Task settings for the "prediction" task.
prediction_task_config = dict(
    task_name="prediction",
    train_mode="pre-train",  # one of: pre-train, fine-tune, test-only
    dataset_prop=(0.8, 0.1, 0.1),
    input_len=10,
    output_len=1,
)

# Task settings for the "similarity" task.
# Kept as a literal because "sub-task" is not a valid keyword-argument name.
similarity_task_config = {
    "task_name": "similarity",
    "train_mode": "test-only",
    "dataset_prop": (0, 0, 1),
    "sub-task": "kNN",  # one of: kNN, CDD
}

# Embedding settings.
# Kept as a literal because "pre-trained" is not a valid keyword-argument name.
embedding_config = {
    "emb_name": "normal",
    "emb_dim": 256,
    "pre-trained": False,
    "embs_path": "",
}

# Encoder settings; d_model is read from the embedding dimension above so the
# two values stay equal.
encoder_config = dict(
    encoder_name="transformer",
    num_layers=6,
    d_model=embedding_config["emb_dim"],
    num_heads=8,
    d_ff=2048,
    dropout=0.1,
)

# Training / checkpoint settings.
trainer_config = dict(
    model_path="./resource/model/backbone/backbone.pth",
    batch_size=64,
    learning_rate=1e-4,
    num_epochs=10,
    optimizer="adam",
    loss_function="cross_entropy",
    lr_scheduler="step_lr",
)

# Top-level configuration bundle. The active task is the similarity task;
# prediction_task_config is defined above but not selected here.
config = dict(
    data_config=data_config,
    task_config=similarity_task_config,
    embedding_config=embedding_config,
    encoder_config=encoder_config,
    trainer_config=trainer_config,
)

In [None]:
from trajlib.model.embedding.embedding_trainer import EmbeddingTrainer
from trajlib.model.embedding.gnn import GAETrainer
from trajlib.model.embedding.node2vec import Node2VecTrainer
from trajlib.data.data_factory import create_data

# Registry of embedding names that have a pre-training routine, mapped to the
# trainer class that implements it. "gat" and "gcn" share GAETrainer.
mapper: dict[str, type[EmbeddingTrainer]] = {
    "node2vec": Node2VecTrainer,
    "gat": GAETrainer,
    "gcn": GAETrainer,
}

# Pre-train the embedding only when the configuration asks for it.
if embedding_config["pre-trained"]:
    emb_name = embedding_config["emb_name"]
    # Validate before loading any data, so an unsupported name fails
    # immediately with a readable message instead of a bare KeyError raised
    # after create_data() has already run.
    if emb_name not in mapper:
        raise ValueError(
            f"No embedding pre-trainer registered for emb_name={emb_name!r}; "
            f"supported names: {sorted(mapper)}"
        )
    # Only the second element of create_data()'s result is used here.
    _, graph_data = create_data(config)
    trainer = mapper[emb_name](embedding_config, graph_data)
    trainer.train()

In [2]:
from accelerate import notebook_launcher

from trajlib.runner.base_runner import BaseRunner


def accelerate_run(config):
    """Build a ``BaseRunner`` from ``config`` and execute it.

    Takes the configuration as its only argument so that it can be handed to
    a launcher as the target function together with ``args=(config,)``.

    Args:
        config: Experiment configuration, passed unchanged to ``BaseRunner``.
    """
    BaseRunner(config).run()

# Direct in-process call that bypasses the launcher (kept disabled):
# accelerate_run(config)

# Start accelerate_run in 4 processes; the port is passed explicitly
# through use_port.
notebook_launcher(
    accelerate_run,
    args=(config,),
    num_processes=4,
    use_port="29504",
)

# TODO(debug): embedding pre-training and encoder training cannot be run
# consecutively.
# NOTE(review): possibly because the pre-training step initializes CUDA in the
# notebook process before notebook_launcher starts its workers — confirm.

Launching training on 4 GPUs.


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Load model successfully


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
Testing: 100%|██████████| 3/3 [00:02<00:00,  1.30it/s]


Final Precision@2: 0.6140, Final Recall@2: 0.6140
