In [3]:
import os

# Root directory for machine-local datasets. Override it with the
# TRAJLIB_DATA_ROOT environment variable instead of editing the paths below;
# the default reproduces the original hard-coded location.
DATA_ROOT = os.environ.get("TRAJLIB_DATA_ROOT", "/data/hetianran")

# ---------------------------------------------------------------------------
# Data configs: one per dataset; exactly one is selected in `config` below.
# NOTE(review): the key "varients" is misspelled, but presumably trajlib reads
# this exact key, so rename it only together with the library.
# NOTE(review): "window" is assumed to be (min_len, max_len, step) — TODO
# confirm against the data factory.
# ---------------------------------------------------------------------------
chengdu_data_config = {
    "data_name": "chengdu",
    "data_path": f"{DATA_ROOT}/didi/chengdu/gps_20161101",
    "data_size": 100000,  # None for all data
    "data_form": "grid",  # gps, grid, roadnet
    "grid_step": 100,  # 100 for 0.1 gps range, 200 for 0.2 gps range
    "unique": False,
    "window": (5, 128, 1),
    "varients": [],  # ["cropped", "distorted"]
    "vocab_size": 0,  # set by data factory
}

geolife_data_config = {
    "data_name": "geolife",
    "data_path": "./resource/dataset/Geolife/geolife_small.csv",
    "data_size": None,  # None for all data
    "data_form": "grid",
    "grid_step": 200,
    "unique": False,
    "window": (5, 256, 256),
    "varients": [],
    "vocab_size": 0,  # set by data factory
}

bj_data_config = {
    "data_name": "bj",
    "data_path": f"{DATA_ROOT}/BJ/BJ_shuffled.csv",
    "data_size": 1000000,
    "data_form": "grid",
    "grid_step": 250,
    "unique": False,
    "window": (0, 128, 1),
    "varients": [],
    "vocab_size": 0,  # set by data factory
}

# ---------------------------------------------------------------------------
# Task configs: one per downstream task; exactly one is selected in `config`.
# "dataset_prop" is a (train, valid, test) proportion triple.
# ---------------------------------------------------------------------------
prediction_task_config = {
    "task_name": "prediction",
    "train_mode": "fine-tune",  # pre-train, fine-tune, test-only
    "dataset_prop": (0.8, 0.1, 0.1),
    "input_len": 10,
    "output_len": 1,  # only 1 is supported
}

similarity_task_config = {
    "task_name": "similarity",
    "train_mode": "test-only",
    "dataset_prop": (0, 0, 1),
    "variant": "original",  # cropped, distorted, original
    "sub-task": "MSS",  # MSS, CDD, kNN
}

filling_task_config = {
    "task_name": "filling",
    "train_mode": "pre-train",
    "dataset_prop": (0.9, 0.1, 0),
    "sub-task": "mlm",  # mlm, autoregressive
}

classification_task_config = {
    "task_name": "classification",
    "train_mode": "fine-tune",
    "dataset_prop": (0.8, 0.1, 0.1),
    "class_attr": "vflag",
    "num_classes": 0,  # set by dataset factory
}

# ---------------------------------------------------------------------------
# Model / training configs.
# ---------------------------------------------------------------------------
embedding_config = {
    "emb_name": "normal",  # "normal", or a pre-trainable one: node2vec, gat, gcn
    "emb_dim": 256,
    "pre-trained": False,  # True runs the embedding pre-training cell
    "embs_path": "",
}

encoder_config = {
    "encoder_name": "transformer",
    "num_layers": 6,
    # Tied to the embedding width so encoder input and embedding output match.
    "d_model": embedding_config["emb_dim"],
    "num_heads": 8,
    "d_ff": 2048,
    "dropout": 0.1,
}

trainer_config = {
    "model_path": "./resource/model/backbone/backbone.pth",
    "batch_size": 64,
    "learning_rate": 1e-4,
    "num_epochs": 10,
    "optimizer": "adam",
    "loss_function": "cross_entropy",
    "lr_scheduler": "step_lr",
}

# The run configuration consumed by every later cell: pick one data config
# and one task config above.
config = {
    "data_config": bj_data_config,
    "task_config": classification_task_config,
    "embedding_config": embedding_config,
    "encoder_config": encoder_config,
    "trainer_config": trainer_config,
}

In [None]:
from trajlib.data.data_factory import create_data
from trajlib.dataset.dataset_factory import create_dataset

# Build the data object described by config["data_config"]. The second return
# value (used as graph data by the embedding cell) is not needed here.
data, _ = create_data(config)

# NOTE(review): presumably the number of grid cells and of trajectories — TODO confirm.
print(len(data.grid))
print(len(data))

# Trajectory length statistics (min, max, mean) over data.original.
# Guarded so an empty dataset reports that fact instead of raising
# ValueError (min/max) or ZeroDivisionError (mean).
lens = [len(traj) for traj in data.original]
if lens:
    print(min(lens), max(lens), sum(lens) / len(lens))
else:
    print("data.original is empty; no trajectory length statistics")

# NOTE(review): the three return values are presumably the train/valid/test
# splits per config["task_config"]["dataset_prop"] — TODO confirm.
dataset, _, _ = create_dataset(config, data)

In [None]:
from trajlib.model.embedding.embedding_trainer import EmbeddingTrainer
from trajlib.model.embedding.gnn import GAETrainer
from trajlib.model.embedding.node2vec import Node2VecTrainer
from trajlib.data.data_factory import create_data

# Embedding name -> trainer class used to pre-train node embeddings.
# "gat" and "gcn" share GAETrainer; presumably it reads emb_name from the
# embedding config to choose the graph layer — TODO confirm.
mapper: dict[str, type[EmbeddingTrainer]] = {
    "node2vec": Node2VecTrainer,
    "gat": GAETrainer,
    "gcn": GAETrainer,
}

if embedding_config["pre-trained"]:
    emb_name = embedding_config["emb_name"]
    # Validate before the (expensive) data load, and fail with a readable
    # message instead of a bare KeyError (e.g. for the default "normal").
    if emb_name not in mapper:
        raise ValueError(
            f"No embedding trainer for emb_name={emb_name!r}; "
            f"expected one of {sorted(mapper)}"
        )
    # Only the second return value (graph data) is needed for embedding training.
    _, graph_data = create_data(config)
    trainer = mapper[emb_name](embedding_config, graph_data)
    trainer.train()

In [4]:
import os

# GPUs to train on. Set before importing accelerate/trajlib. The variable only
# takes effect if CUDA has not yet been initialized in this kernel, and
# notebook_launcher requires that too. The process count is derived from the
# same list, so the two cannot drift apart.
GPU_IDS = "0,1"
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_IDS
NUM_PROCESSES = len(GPU_IDS.split(","))

from accelerate import notebook_launcher

from trajlib.runner.base_runner import BaseRunner


def accelerate_run(config):
    """Build a BaseRunner from ``config`` and run it.

    Launched once per process by ``notebook_launcher``; for single-process
    debugging, call ``accelerate_run(config)`` directly instead.
    """
    runner = BaseRunner(config)
    runner.run()


notebook_launcher(accelerate_run, args=(config,), num_processes=NUM_PROCESSES, use_port="29502")

# TODO: debug — embedding pre-training and encoder training cannot currently be run back-to-back.

Launching training on 2 GPUs.


Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Counter({1: 23008, 0: 9590})
Counter({1: 23008, 0: 9590})
Load model successfully


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
Epoch 1 Train: 100%|██████████| 204/204 [00:06<00:00, 30.45it/s]
Epoch 1 Valid: 100%|██████████| 26/26 [00:00<00:00, 41.10it/s]
Epoch 1  Test: 100%|██████████| 26/26 [00:00<00:00, 37.95it/s]


Train Loss: 0.6453, Val Loss: 0.5805, Test Loss: 0.5745, Test Accuracy: 0.7122


Epoch 2 Train: 100%|██████████| 204/204 [00:05<00:00, 36.20it/s]
Epoch 2 Valid: 100%|██████████| 26/26 [00:00<00:00, 39.66it/s]
Epoch 2  Test: 100%|██████████| 26/26 [00:01<00:00, 17.05it/s]


Train Loss: 0.5744, Val Loss: 0.5695, Test Loss: 0.5621, Test Accuracy: 0.7125


Epoch 3 Train: 100%|██████████| 204/204 [00:05<00:00, 36.66it/s]
Epoch 3 Valid: 100%|██████████| 26/26 [00:00<00:00, 37.70it/s]
Epoch 3  Test: 100%|██████████| 26/26 [00:00<00:00, 38.53it/s]


Train Loss: 0.5667, Val Loss: 0.5661, Test Loss: 0.5579, Test Accuracy: 0.7119


Epoch 4 Train: 100%|██████████| 204/204 [00:05<00:00, 37.54it/s]
Epoch 4 Valid: 100%|██████████| 26/26 [00:00<00:00, 41.46it/s]
Epoch 4  Test: 100%|██████████| 26/26 [00:00<00:00, 38.67it/s]


Train Loss: 0.5638, Val Loss: 0.5646, Test Loss: 0.5560, Test Accuracy: 0.7140


Epoch 5 Train: 100%|██████████| 204/204 [00:05<00:00, 36.50it/s]
Epoch 5 Valid: 100%|██████████| 26/26 [00:00<00:00, 40.14it/s]
Epoch 5  Test: 100%|██████████| 26/26 [00:00<00:00, 36.27it/s]


Train Loss: 0.5617, Val Loss: 0.5636, Test Loss: 0.5547, Test Accuracy: 0.7186


Epoch 6 Train: 100%|██████████| 204/204 [00:05<00:00, 36.61it/s]
Epoch 6 Valid: 100%|██████████| 26/26 [00:00<00:00, 39.63it/s]
Epoch 6  Test: 100%|██████████| 26/26 [00:00<00:00, 38.00it/s]


Train Loss: 0.5608, Val Loss: 0.5626, Test Loss: 0.5534, Test Accuracy: 0.7195


Epoch 7 Train: 100%|██████████| 204/204 [00:05<00:00, 36.04it/s]
Epoch 7 Valid: 100%|██████████| 26/26 [00:00<00:00, 38.56it/s]
Epoch 7  Test: 100%|██████████| 26/26 [00:00<00:00, 38.79it/s]


Train Loss: 0.5597, Val Loss: 0.5617, Test Loss: 0.5523, Test Accuracy: 0.7199


Epoch 8 Train: 100%|██████████| 204/204 [00:05<00:00, 37.33it/s]
Epoch 8 Valid: 100%|██████████| 26/26 [00:00<00:00, 41.00it/s]
Epoch 8  Test: 100%|██████████| 26/26 [00:00<00:00, 40.75it/s]


Train Loss: 0.5581, Val Loss: 0.5609, Test Loss: 0.5512, Test Accuracy: 0.7208


Epoch 9 Train: 100%|██████████| 204/204 [00:05<00:00, 36.51it/s]
Epoch 9 Valid: 100%|██████████| 26/26 [00:00<00:00, 38.63it/s]
Epoch 9  Test: 100%|██████████| 26/26 [00:00<00:00, 36.94it/s]


Train Loss: 0.5568, Val Loss: 0.5602, Test Loss: 0.5503, Test Accuracy: 0.7214


Epoch 10 Train: 100%|██████████| 204/204 [00:05<00:00, 36.41it/s]
Epoch 10 Valid: 100%|██████████| 26/26 [00:00<00:00, 36.87it/s]
Epoch 10  Test: 100%|██████████| 26/26 [00:00<00:00, 36.97it/s]


Train Loss: 0.5566, Val Loss: 0.5595, Test Loss: 0.5494, Test Accuracy: 0.7208


Final Test: 100%|██████████| 26/26 [00:00<00:00, 40.14it/s]


Final Loss: 0.5494, Final Accuracy: 0.7208
