In [7]:
import pandas as pd
from tqdm import tqdm

In [8]:
def generate_user_edges(csv_path, output_path):
    """Build per-user POI transition edges from trajectory check-ins and save them as CSV.

    Check-ins are grouped by (user_id, trajectory_id) and ordered by UTC_time
    within each trajectory. Every pair of consecutive POI visits (A, B), (B, C), ...
    becomes one directed edge. A user's edges from all of their trajectories are
    concatenated in trajectory_id order.

    Args:
        csv_path: Path of the input CSV. It must contain the columns ``user_id``,
            ``trajectory_id``, ``UTC_time`` and ``POI_id``.
        output_path: Destination CSV path. It is written with the columns ``user_id``,
            ``item_edges_a`` (edge sources) and ``item_edges_b`` (edge targets).
            Each edge list is a comma-separated string, and the i-th entries of
            both strings form edge i. Users appear in sorted user_id order.

    Raises:
        ValueError: (a pandas parser error subclass) if ``UTC_time`` values cannot
            be parsed as timestamps.

    Notes:
        A user whose trajectories each contain a single check-in has no edges and is
        written with empty edge fields.
    """
    train_df = pd.read_csv(csv_path)

    # Order check-ins by user, trajectory and *parsed* time. Sorting the raw
    # UTC_time strings is only chronological for zero-padded ISO-8601 values that
    # share one offset. Other formats (e.g. "Tue Apr 03 18:00:09 +0000 2012") or
    # mixed offsets would silently mis-order visits and yield wrong edges.
    train_df = train_df.assign(
        _visit_time=pd.to_datetime(train_df["UTC_time"], utc=True)
    ).sort_values(["user_id", "trajectory_id", "_visit_time"])

    # user_id -> (edge sources, edge targets). groupby iterates its keys in sorted
    # order, so users and each user's trajectories are accumulated in sorted order.
    edges_by_user = {}
    grouped = train_df.groupby(["user_id", "trajectory_id"])
    for (user_id, _traj_id), traj_group in tqdm(grouped, desc="Processing user trajectories"):
        pois = traj_group["POI_id"].tolist()
        sources, targets = edges_by_user.setdefault(user_id, ([], []))
        # Consecutive visits (A, B), (B, C), ... -> sources [A, B, ...], targets [B, C, ...].
        # A single-visit trajectory adds no edges but still registers the user.
        sources.extend(pois[:-1])
        targets.extend(pois[1:])

    # Store the lists as comma-separated strings so they fit in a single CSV cell.
    user_edges_agg = pd.DataFrame({
        "user_id": list(edges_by_user),
        "item_edges_a": [",".join(map(str, src)) for src, _ in edges_by_user.values()],
        "item_edges_b": [",".join(map(str, dst)) for _, dst in edges_by_user.values()],
    })
    user_edges_agg.to_csv(output_path, index=False)

In [7]:
# Gowalla-CA: training check-ins -> per-user transition edges.
generate_user_edges(
    csv_path='./gowalla/gowalla-ca_train.csv',
    output_path='./gowalla/user_edges.csv',
)

Processing user trajectories: 100%|██████████| 65651/65651 [00:02<00:00, 25702.34it/s]


In [8]:
# Foursquare NYC: training check-ins -> per-user transition edges.
generate_user_edges(
    csv_path='./foursquare/nyc/nyc_train.csv',
    output_path='./foursquare/nyc/user_edges.csv',
)
# Foursquare TKY: training check-ins -> per-user transition edges.
generate_user_edges(
    csv_path='./foursquare/tky/tky_train.csv',
    output_path='./foursquare/tky/user_edges.csv',
)

Processing user trajectories: 100%|██████████| 26589/26589 [00:01<00:00, 24846.84it/s]
Processing user trajectories: 100%|██████████| 72584/72584 [00:02<00:00, 25774.71it/s]


In [6]:
# Long-tail Foursquare NYC split -> per-user transition edges.
generate_user_edges(
    csv_path='./long-tail/foursquare/nyc/nyc_train_longtail.csv',
    output_path='./long-tail/foursquare/nyc/user_edges_longtail.csv',
)
# Long-tail Foursquare TKY split -> per-user transition edges.
generate_user_edges(
    csv_path='./long-tail/foursquare/tky/tky_train_longtail.csv',
    output_path='./long-tail/foursquare/tky/user_edges_longtail.csv',
)

Processing user trajectories: 100%|██████████| 4719/4719 [00:00<00:00, 5787.60it/s]
Processing user trajectories: 100%|██████████| 8367/8367 [00:00<00:00, 13617.61it/s]
Processing user trajectories: 100%|██████████| 11011/11011 [00:00<00:00, 12636.74it/s]


In [9]:
# Long-tail Gowalla-CA split -> per-user transition edges.
generate_user_edges(
    csv_path='./long-tail/gowalla/gowalla-ca_train_longtail.csv',
    output_path='./long-tail/gowalla/user_edges_longtail.csv',
)

Processing user trajectories: 100%|██████████| 11011/11011 [00:00<00:00, 13435.74it/s]
