In [33]:
import os
import pandas as pd
import torch
import random
from tqdm import tqdm
import dgl

MIN_NUM_NODES = 10
MAX_NUM_NODES = 20
NUM_EACH_LENGTH = 9000
NUM_SESSION_PER_SAMPLE = 10

SAMPLE_RATIO = 0.1

df = pd.read_csv("../../data/Diginetica/raw/yoochoose-clicks.dat")
df = df[["session_id", "timestamp", "item_id"]]

session_counts = df["session_id"].value_counts()
sessions = []
for i in range(MIN_NUM_NODES, MAX_NUM_NODES + 1, 1):
    valid_session_ids = session_counts[session_counts == i].index.tolist()
    assert len(valid_session_ids) >= NUM_EACH_LENGTH
    sessions.extend(random.sample(valid_session_ids, NUM_EACH_LENGTH))

df = df[df["session_id"].isin(sessions)]
# transform timestamp to Unix style
# df["timestamp"] = df["timestamp"].apply(lambda x: rfc3339_to_unix_milliseconds(x))
print("Transform timestamp to Unix style")

  df = pd.read_csv("../../data/Diginetica/raw/yoochoose-clicks.dat")


Transform timestamp to Unix style


In [34]:
# reindex item id
print("Reindex item ids")
old_item_ids = df["item_id"].drop_duplicates().tolist()
item_id_mapping = {old_id: new_id for new_id, old_id in enumerate(old_item_ids)}
df["item_id"] = df["item_id"].apply(lambda x: item_id_mapping[x])
num_nodes = len(item_id_mapping)

Reindex item ids


In [35]:
from collections import Counter


counter = Counter(df["item_id"].tolist())

In [56]:
sorted_counter = sorted(counter.items(), key=lambda x: x[1], reverse=True)
filtered_node_ids = sorted_counter[: int(len(sorted_counter) * 0.005)]
# filtered_counter = filter(lambda x: x[1] >= 1000, counter.items())
# filtered_counter = list(filtered_counter)
# filtered_node_ids, _ = zip(*filtered_counter)
filtered_node_ids


[(12975, 4210),
 (16209, 4202),
 (12136, 4192),
 (13027, 4065),
 (522, 3690),
 (25383, 2730),
 (3151, 2531),
 (24897, 2356),
 (365, 2314),
 (17159, 2300),
 (12987, 2273),
 (25841, 2241),
 (20408, 2240),
 (17153, 2192),
 (108, 2156),
 (17166, 2107),
 (12988, 2102),
 (18995, 2084),
 (3153, 2036),
 (186, 1977),
 (12976, 1975),
 (24072, 1935),
 (20419, 1902),
 (9772, 1901),
 (17168, 1901),
 (2528, 1891),
 (1292, 1865),
 (26450, 1840),
 (16192, 1818),
 (26444, 1804),
 (1467, 1793),
 (18987, 1785),
 (25840, 1740),
 (318, 1735),
 (18993, 1733),
 (16214, 1732),
 (145, 1708),
 (25851, 1705),
 (9771, 1697),
 (26437, 1693),
 (4, 1690),
 (19010, 1679),
 (26456, 1654),
 (27000, 1651),
 (1291, 1633),
 (370, 1627),
 (543, 1608),
 (563, 1588),
 (25843, 1570),
 (305, 1558),
 (19020, 1557),
 (8, 1545),
 (14815, 1529),
 (25955, 1526),
 (15345, 1524),
 (2450, 1515),
 (24082, 1513),
 (507, 1505),
 (112, 1498),
 (443, 1488),
 (26443, 1484),
 (21534, 1482),
 (438, 1471),
 (24498, 1462),
 (18994, 1460),
 (210

In [38]:
sorted(counter.items(), key=lambda x: x[1], reverse=True)

[(12975, 4210),
 (16209, 4202),
 (12136, 4192),
 (13027, 4065),
 (522, 3690),
 (25383, 2730),
 (3151, 2531),
 (24897, 2356),
 (365, 2314),
 (17159, 2300),
 (12987, 2273),
 (25841, 2241),
 (20408, 2240),
 (17153, 2192),
 (108, 2156),
 (17166, 2107),
 (12988, 2102),
 (18995, 2084),
 (3153, 2036),
 (186, 1977),
 (12976, 1975),
 (24072, 1935),
 (20419, 1902),
 (9772, 1901),
 (17168, 1901),
 (2528, 1891),
 (1292, 1865),
 (26450, 1840),
 (16192, 1818),
 (26444, 1804),
 (1467, 1793),
 (18987, 1785),
 (25840, 1740),
 (318, 1735),
 (18993, 1733),
 (16214, 1732),
 (145, 1708),
 (25851, 1705),
 (9771, 1697),
 (26437, 1693),
 (4, 1690),
 (19010, 1679),
 (26456, 1654),
 (27000, 1651),
 (1291, 1633),
 (370, 1627),
 (543, 1608),
 (563, 1588),
 (25843, 1570),
 (305, 1558),
 (19020, 1557),
 (8, 1545),
 (14815, 1529),
 (25955, 1526),
 (15345, 1524),
 (2450, 1515),
 (24082, 1513),
 (507, 1505),
 (112, 1498),
 (443, 1488),
 (26443, 1484),
 (21534, 1482),
 (438, 1471),
 (24498, 1462),
 (18994, 1460),
 (210