In [None]:
import os
import polars as pl
import numpy as np

In [None]:
# Input locations, given relative to the current working directory.
data_dir = os.path.join("..", "data")
train_parquet, test_parquet = (
    os.path.join(data_dir, parquet_name)
    for parquet_name in ("train.parquet", "test.parquet")
)

# Minimum gap, in seconds, between identical interactions -- keyed on
# (user, adgroup, btag) or (user, cate, brand, btag). Interactions that
# repeat sooner than this are treated as duplicates.
min_timediff_unique = 30

# Minimum number of interactions (non-ad-click, browse, ad-click, favorite,
# add-to-cart, or purchase) that a training sequence must contain.
min_training_interactions = 1

# Whether behavior log interaction data is included.
augmented = True

In [None]:
# Tag that encodes the preprocessing settings; it is embedded in the names of
# the exported archives so differently-configured runs do not collide.
dataset_params = f"timediff{min_timediff_unique}_mintrain{min_training_interactions}"
if augmented:
    dataset_params += "_aug"

# Columns exported as the per-row user feature matrix.
user_feats = [
    "user",
    "gender",
    "age",
    "shopping",
    "occupation",
]

# Columns exported as the per-row ad feature matrix.
ad_feats = [
    "adgroup",
    "cate",
    "brand",
    "campaign",
    "customer",
]

In [None]:
# Row-level predicate for the training log:
#   * discard rows whose btag is -1;
#   * keep rows with no timediff, or whose timediff is at least
#     min_timediff_unique seconds;
#   * when `augmented` is off, additionally require btag == 1.
keep_row = pl.col("btag") != -1
keep_row = keep_row & (
    pl.col("timediff").is_null() | (pl.col("timediff") >= min_timediff_unique)
)
if not augmented:
    keep_row = keep_row & (pl.col("btag") == 1)

training_data = (
    pl.scan_parquet(train_parquet)
    .filter(keep_row)
    # Applied as a second filter so the per-user row count only sees rows that
    # survived the predicate above; users with too few rows are dropped whole.
    .filter(pl.len().over("user") >= min_training_interactions)
    .collect()
)
training_data

In [None]:
# Distinct users present in the filtered training set. `is_in` is documented to
# take a Series / sequence / expression, so pass the column as a Series rather
# than a one-column DataFrame (which only works through implicit coercion).
train_users = training_data.get_column("user").unique()

# Keep only test rows whose user also appears in the training data, so every
# validation row belongs to a user with a training sequence.
validation_data = (
    pl.scan_parquet(test_parquet)
    .filter(pl.col("user").is_in(train_users))
    .collect()
)
validation_data

In [None]:
# Export the training split as one compressed NumPy archive holding four
# row-aligned arrays: user features, ad features, the btag of each
# interaction, and its timestamp. The arrays are built inline so nothing
# beyond `training_data` itself stays referenced after the write.
np.savez_compressed(
    os.path.join(data_dir, f"train_data_{dataset_params}"),
    user_data=training_data.select(user_feats).to_numpy(),
    ads_data=training_data.select(ad_feats).to_numpy(),
    interaction_data=training_data.get_column("btag").to_numpy(),
    timestamps=training_data.get_column("timestamp").to_numpy(),
)

In [None]:
# Export the validation split as one compressed NumPy archive with the same
# four keys as the training archive: user features, ad features, the btag of
# each interaction, and its timestamp. The arrays are built inline so nothing
# beyond `validation_data` itself stays referenced after the write.
np.savez_compressed(
    os.path.join(data_dir, f"test_data_{dataset_params}"),
    user_data=validation_data.select(user_feats).to_numpy(),
    ads_data=validation_data.select(ad_feats).to_numpy(),
    interaction_data=validation_data.get_column("btag").to_numpy(),
    timestamps=validation_data.get_column("timestamp").to_numpy(),
)