In [1]:
from trajlib.model.model_factory import create_model


# Hyperparameters handed to the model factory.
model_config = dict(
    model_name="transformer",
    num_layers=6,
    d_model=512,
    num_heads=8,
    d_ff=2048,
    dropout=0.1,
)

# Build the model from the config; the bare name on the last line makes the
# notebook display the module's repr.
model = create_model(model_config)
model

Transformer(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (attention): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True)
          (W_K): Linear(in_features=512, out_features=512, bias=True)
          (W_V): Linear(in_features=512, out_features=512, bias=True)
          (W_O): Linear(in_features=512, out_features=512, bias=True)
          (attention): ScaledDotProductAttention()
        )
        (ffn): PositionWiseFeedForward(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
)

In [2]:
import pandas as pd

# Read the first 1000 rows of the GPS file; the column names are supplied
# explicitly through `names`.
raw_data = pd.read_csv(
    "/data/hetianran/didi/chengdu/gps_20161101",
    nrows=1000,
    names=["driver_id", "order_id", "timestamp", "lon", "lat"],
)

# Encode every distinct order_id as an integer code and store it as traj_id;
# unique_ids keeps the distinct order_id values returned by factorize.
traj_codes, unique_ids = pd.factorize(raw_data["order_id"])
raw_data["traj_id"] = traj_codes

from trajlib.data_processing.utils.data_definition import TrajectoryData
from tqdm import tqdm


def load_data_chengdu(raw_data) -> TrajectoryData:
    """Convert the raw GPS DataFrame into a TrajectoryData object.

    The trajectory table receives one row per distinct
    (traj_id, driver_id, order_id) combination. Every row of ``raw_data``
    becomes one point, numbered from 1 in row order, with lon/lat passed
    along as extra attributes.

    Args:
        raw_data: DataFrame providing the columns traj_id, driver_id,
            order_id, timestamp, lon and lat.

    Returns:
        The populated TrajectoryData.
    """
    traj_data = TrajectoryData()

    # Trajectory table: distinct id triples with a fresh 0..n-1 index.
    traj_data.traj_table = (
        raw_data.loc[:, ["traj_id", "driver_id", "order_id"]]
        .drop_duplicates()
        .reset_index(drop=True)
    )

    point_records = []
    coord_records = []
    rows = zip(
        raw_data["traj_id"],
        raw_data["timestamp"],
        raw_data["lon"],
        raw_data["lat"],
    )
    # enumerate wraps tqdm (not the other way round) so the progress counter
    # still tracks the zipped rows; point ids start at 1.
    for point_id, (traj_id, timestamp, lon, lat) in enumerate(tqdm(rows), start=1):
        point_records.append(
            {"point_id": point_id, "timestamp": timestamp, "traj_id": traj_id}
        )
        coord_records.append({"lon": lon, "lat": lat})

    traj_data.batch_append_point_data(
        new_point_data_list=point_records, extra_attr_list=coord_records
    )
    return traj_data


# Convert the loaded GPS rows into a TrajectoryData object.
traj_data = load_data_chengdu(raw_data)

1000it [00:00, 627795.84it/s]
  pd.concat([self.point_table, new_row_df], ignore_index=True)


In [5]:
import torch
from torch.utils.data import DataLoader
from trajlib.dataset.datasets import TrajectoryDataset


# Create the TrajectoryDataset, setting the input and output lengths.
input_len, output_len = 10, 10
dataset = TrajectoryDataset(traj_data, input_len, output_len)

# Load the dataset through a DataLoader (fixed order, no shuffling).
dataloader = DataLoader(dataset, shuffle=False, batch_size=800)

# Iterate over the batches and print the shape of each tensor.
for i, batch in enumerate(dataloader):
    # Each batch is expected to hold exactly four tensors.
    x_tensor, x_mark_tensor, y_tensor, y_mark_tensor = batch
    print(f"Batch {i}:")
    print("x_tensor:", x_tensor.shape)
    print("x_mark_tensor:", x_mark_tensor.shape)
    print("y_tensor:", y_tensor.shape)
    print("y_mark_tensor:", y_mark_tensor.shape)

Batch 0:
x_tensor: torch.Size([800, 10, 2])
x_mark_tensor: torch.Size([800, 10])
y_tensor: torch.Size([800, 10, 2])
y_mark_tensor: torch.Size([800, 10])
Batch 1:
x_tensor: torch.Size([105, 10, 2])
x_mark_tensor: torch.Size([105, 10])
y_tensor: torch.Size([105, 10, 2])
y_mark_tensor: torch.Size([105, 10])
