In [None]:
# --- Dataset presets -------------------------------------------------------
# One dict per trajectory source. The preset actually used is the one wired
# into `config["data_config"]` further down in this cell.

chengdu_data_config = {
    "data_name": "chengdu",
    "data_path": "./resource/data/chengdu/gps_20161101",
    # Number of records to load; None for all data.
    "data_size": 100_000,
    # 100 for 0.1 gps range, 200 for 0.2 gps range.
    "grid_step": 100,
    # TODO: how should this be set, and what is the difference between the options?
    "road_type": "drive",
    "window": (5, 128, 10_000),
    # File name mirrors data_size and window; keep it in sync when either changes.
    "cache_path": "./resource/data_cache/chengdu-100K-5_128_10000.data",
}

geolife_data_config = {
    "data_name": "geolife",
    "data_path": "./resource/data/geolife/geolife_small.csv",
    "data_size": 100_000,
    "grid_step": 200,
    "road_type": "all",
    "window": (5, 256, 128),
    # File name mirrors data_size and window; keep it in sync when either changes.
    "cache_path": "./resource/data_cache/geolife-100K-5_256_128.data",
}

bj_data_config = {
    "data_name": "bj",
    "data_path": "./resource/BJ/BJ_shuffled.csv",
    "data_size": 1_000_000,
    "grid_step": 250,
    "road_type": "drive",
    "window": (5, 128, 1),
    # File name mirrors data_size and window; keep it in sync when either changes.
    "cache_path": "./resource/data_cache/BJ-1M-5_128_1.data",
}

# --- Task presets ----------------------------------------------------------
# One dict per downstream task. The preset actually used is the one wired
# into `config["task_config"]` further down in this cell.

prediction_task_config = {
    "task_name": "prediction",
    # One of: pre-train, fine-tune, test-only.
    "train_mode": "pre-train",
    "dataset_prop": (0.8, 0.1, 0.1),
    "input_len": 10,
    "output_len": 1,
    "tokens": ["grid"],
}

similarity_task_config = {
    "task_name": "similarity",
    "train_mode": "test-only",
    "dataset_prop": (0, 0, 1),
    # One of: cropped, distorted, original.
    "variant": "cropped",
    # One of: MSS, CDD, kNN.
    "sub-task": "kNN",
}

filling_task_config = {
    "task_name": "filling",
    "train_mode": "pre-train",
    "dataset_prop": (0.9, 0.1, 0),
    # One of: mlm, autoregressive.
    "sub-task": "mlm",
    "tokens": ["gps", "grid", "roadnet"],
}

classification_task_config = {
    "task_name": "classification",
    "train_mode": "pre-train",
    "dataset_prop": (0.8, 0.1, 0.1),
    # Label attribute: "mode" for geolife, "vflag" for BJ.
    "class_attr": "mode",
    # Placeholder; set by dataset factory.
    "num_classes": 0,
}

# --- Model: token embeddings and encoder -----------------------------------

embedding_config = {
    "emb_dim": 256,
    # Token types to embed; each name has its own settings dict below.
    "tokens": ["gps", "grid", "roadnet"],
    "gps": {"emb_name": "linear"},
    # The grid and roadnet settings are deliberately separate dict objects:
    # their vocab_size placeholders are filled in independently.
    "grid": {
        "emb_name": "embedding",
        # Placeholder; set by data factory.
        "vocab_size": 0,
        "pre-trained": False,
        "embs_path": "",
    },
    "roadnet": {
        "emb_name": "embedding",
        # Placeholder; set by data factory.
        "vocab_size": 0,
        "pre-trained": False,
        "embs_path": "",
    },
}

encoder_config = {
    "encoder_name": "transformer",
    "num_layers": 6,
    # Copied from the embedding size at definition time so the two agree.
    "d_model": embedding_config["emb_dim"],
    "num_heads": 8,
    "d_ff": 2048,
    "dropout": 0.1,
}

# --- Training hyper-parameters ---------------------------------------------

trainer_config = {
    # Location of the backbone checkpoint file.
    "model_path": "./resource/backbone/backbone.pth",
    "batch_size": 64,
    "learning_rate": 1e-4,
    "num_epochs": 10,
    "optimizer": "adam",
    # One of: "mse", "cross_entropy", "multi_token".
    "loss_function": "multi_token",
    "lr_scheduler": "step_lr",
}

# Top-level experiment configuration consumed by the cells below.
# Swap the data/task preset here to run a different dataset or task.
config = dict(
    data_config=chengdu_data_config,
    task_config=similarity_task_config,
    embedding_config=embedding_config,
    encoder_config=encoder_config,
    trainer_config=trainer_config,
)

In [None]:
from trajlib.data.data_factory import create_data

# Build the dataset described by `config`; the result is unpacked into the
# main data object and two graph objects.
# NOTE(review): overwrite=True presumably forces the cache at
# config["data_config"]["cache_path"] to be regenerated on every run —
# confirm in trajlib.data.data_factory and pass False if reuse is intended.
data, grid_graph_data, road_graph_data = create_data(config, overwrite=True)

# Quick sanity check on the sizes of what was built.
print(len(data.grid))  # length of the grid component
print(data.roadnet.edge_num)  # edge count reported by the road network
print(len(data))  # length of the data object itself

In [None]:
from accelerate import notebook_launcher

from trajlib.runner.base_runner import BaseRunner


def accelerate_run(config):
    """Build a ``BaseRunner`` from ``config`` and execute it.

    Kept as a plain single-argument function so it can either be called
    directly or be handed to ``notebook_launcher`` together with
    ``args=(config,)``.

    Args:
        config: Top-level experiment configuration, passed unchanged to
            ``BaseRunner``.
    """
    BaseRunner(config).run()


# Optional: restrict which GPUs are visible before launching (uncomment both lines).
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"

# Default: run once, directly in the current kernel process.
accelerate_run(config)

# Multi-process alternative: replace the direct call above with the launcher call below.
# notebook_launcher(accelerate_run, args=(config,), num_processes=4, use_port="29502")
# Shell helper for finding the process that holds a launcher port. Note the port
# here (29500) differs from the use_port value above; substitute the one in use.
# lsof -i :29500