In [None]:
import os
import polars as pl
import numpy as np
from sklearn.preprocessing import OrdinalEncoder

In [None]:
# --- Paths ---------------------------------------------------------------
# Directory the input parquet files are read from; outputs are written here too.
data_dir = "../data"

# --- Train / validation split ----------------------------------------------
# Number of most recent ad clicks per user that are held out for validation.
num_validation = 1
# Minimum number of ad clicks a user needs on top of the held-out ones.
min_num_training = 0

# --- Optional feature groups and data sources --------------------------------
# add category, brand, customer, campaign to the dataset
# (filters to ads with this data available)
include_ad_features = True
# add gender, age, shopping, occupation to the dataset
# (filters to users with this data available)
include_user_features = False
# Append the behavior-log events to the training data.
include_behavior_log = True
# Append non-clicked ad impressions (btag = -1) to the training data.
include_ad_non_clks = True

# --- Sequence settings -------------------------------------------------------
# NOTE(review): neither value is referenced by the cells visible in this
# notebook (padding uses the observed maximum length) - confirm whether they
# are still needed.
chunk_every = 100
max_sequence_len = 100

In [None]:
# Tag used in the output file names: the split sizes followed by one suffix per
# enabled option. NOTE(review): `include_behavior_log` is not part of the tag,
# so runs that differ only in that flag write to the same file names.
dataset_params = "".join([
    f"{min_num_training}_min_train_clks-{num_validation}_test_clks",
    "-usr_fts" if include_user_features else "",
    "-ad_fts" if include_ad_features else "",
    "-non_clks" if include_ad_non_clks else "",
])

In [None]:
# Columns describing a user: the id, plus the profile features when enabled.
user_feats = ["user"]
if include_user_features:
    user_feats += ["gender", "age", "shopping", "occupation"]

# Columns describing an ad: the id, plus the ad features when enabled.
ad_feats = ["adgroup"]
if include_ad_features:
    ad_feats += ["cate", "brand", "campaign", "customer"]

# Subset of the enabled ad features that is also selected from the behavior log.
pretraining_ad_feats = {"cate", "brand"} & set(ad_feats)

In [None]:
# Clicked ad impressions, one row per distinct (user, adgroup, timestamp) event.
raw_sample = (pl
    .scan_parquet(os.path.join(data_dir, "raw_sample.parquet"))
    .filter(pl.col("clk") == True)
    # De-duplicate BEFORE counting clicks per user, otherwise repeated rows of
    # the same event inflate the count and let users through the threshold.
    .unique(["user", "adgroup", "timestamp"])
)
if include_user_features:
    # Inner join: keeps only users that have profile data.
    raw_sample = raw_sample.join(
        other=pl.scan_parquet(os.path.join(data_dir, "user_profile.parquet")).select(user_feats),
        on="user", how="inner",
    )
if include_ad_features:
    # Inner join: keeps only ads that have feature data.
    raw_sample = raw_sample.join(
        other=pl.scan_parquet(os.path.join(data_dir, "ad_feature.parquet")).select(ad_feats),
        on="adgroup", how="inner",
    )
raw_sample = (raw_sample
    # Apply the minimum-click threshold last, so it is evaluated on the rows
    # that actually survive de-duplication and the inner joins above.
    .filter(pl.len().over("user") >= min_num_training + num_validation)
    .collect()
)
raw_sample

In [None]:
# Training clicks: for every user, all ad clicks in chronological order except
# the last `num_validation` ones. Users with no more than `num_validation`
# clicks would contribute nothing and are removed up front.
training_data = (
    raw_sample
    .filter(pl.len().over("user") > num_validation)
    .sort(["user", "timestamp"], nulls_last=True)
    .group_by("user", maintain_order=True)
    .agg(pl.all().head(pl.len() - num_validation))
    # The aggregation yields one list per column and user; flatten back to rows.
    .explode(pl.all().exclude("user"))
    .select(
        *user_feats,
        *ad_feats,
        pl.lit(1, dtype=pl.Int8).alias("btag"),  # ad clicks are tagged with btag = 1
        "timestamp",
    )
)
training_data

In [None]:
# Validation clicks: the last `num_validation` ad clicks of every user, taken
# after ordering each user's clicks by timestamp.
validation_data = (
    raw_sample
    .sort(["user", "timestamp"], nulls_last=True)
    .group_by("user", maintain_order=True)
    .agg(pl.all().tail(num_validation))
    # The aggregation yields one list per column and user; flatten back to rows.
    .explode(pl.all().exclude("user"))
    .select(
        *user_feats,
        *ad_feats,
        pl.lit(1, dtype=pl.Int8).alias("btag"),  # ad clicks are tagged with btag = 1
        "timestamp",
    )
)
validation_data

In [None]:
# Sanity check: show (user, adgroup) pairs whose held-out click is not strictly
# later than a training click on the same ad. `timestamp_right` is the training
# click's timestamp. Only clicks on the same ad are compared here.
(
    validation_data
    .join(training_data, on=["user", "adgroup"], how="inner", suffix="_right")
    .filter(pl.col("timestamp") <= pl.col("timestamp_right"))
)

In [None]:
# Per user, the timestamp of the earliest held-out click. It is used as the
# cut-off below which events may enter the training data.
first_validation_click = (
    validation_data
    .select("user", first_validation_ad_click_time=pl.col("timestamp"))
    .sort(["user", "first_validation_ad_click_time"], nulls_last=True)
    .group_by("user", maintain_order=True)
    .head(1)
)
first_validation_click

In [None]:
# Keep only training clicks that happened strictly before the user's first
# held-out click; the helper cut-off column is removed again afterwards.
training_data = (
    training_data
    .join(first_validation_click, on="user", how="inner")
    .filter(pl.col("first_validation_ad_click_time") > pl.col("timestamp"))
    .drop("first_validation_ad_click_time")
)
training_data

In [None]:
# Optionally extend the training data with behavior-log events. This only makes
# sense when at least one ad feature shared with the log ("cate" / "brand") is
# enabled.
if include_behavior_log and ("cate" in ad_feats or "brand" in ad_feats):
    # Loading takes ~30s for pretraining dataset from behavior log
    behavior_log = (pl
        .scan_parquet(os.path.join(data_dir, "behavior_log.parquet"))
        # `is_in` is documented for Series / sequences, so hand it the user
        # column as a Series rather than a one-column DataFrame.
        .filter(pl.col("user").is_in(raw_sample.get_column("user").unique()))
    )
    if include_user_features:
        # Inner join: keeps only users that have profile data.
        behavior_log = behavior_log.join(
            other=pl.scan_parquet(os.path.join(data_dir, "user_profile.parquet")).select(user_feats),
            on="user", how="inner",
        )
    behavior_log = (behavior_log.collect()
        .join(first_validation_click, on="user", how="inner")
        # Only events strictly before the user's first held-out click may be used.
        .filter(pl.col("timestamp") < pl.col("first_validation_ad_click_time"))
        .unique()
        # `pretraining_ad_feats` is a set; sort it so the column order is
        # the same on every run.
        .select(*user_feats, *sorted(pretraining_ad_feats), pl.col("btag").cast(pl.Int8), "timestamp")
    )
    # Diagonal concat: ad columns missing from the behavior log are filled with nulls.
    training_data = pl.concat([training_data, behavior_log], how="diagonal")
training_data

In [None]:
# Users that have at least one training row. Validation rows of all other
# users are dropped.
valid_users = training_data.select("user").unique()
# `is_in` is documented for Series / sequences, so pass the user column as a
# Series instead of the one-column DataFrame itself.
validation_data = validation_data.filter(pl.col("user").is_in(valid_users.get_column("user")))
validation_data

In [None]:
# Optionally append non-clicked ad impressions (btag = -1) of the valid users
# to the training data.
if include_ad_non_clks:
    negatives = (pl
        .scan_parquet(os.path.join(data_dir, "raw_sample.parquet"))
        # `is_in` is documented for Series / sequences, so pass the user column
        # as a Series instead of the one-column DataFrame itself.
        .filter((pl.col("clk") == False) & (pl.col("user").is_in(valid_users.get_column("user")))).collect()
        .join(first_validation_click, on="user", how="inner")
        # Only impressions strictly before the user's first held-out click may be used.
        .filter(pl.col("timestamp") < pl.col("first_validation_ad_click_time"))
        .unique(["user", "adgroup", "timestamp"])
    )
    if include_user_features:
        # Inner join: keeps only users that have profile data.
        negatives = negatives.join(
            other=pl.read_parquet(os.path.join(data_dir, "user_profile.parquet")).select(user_feats),
            on="user", how="inner",
        )
    if include_ad_features:
        # Inner join: keeps only ads that have feature data.
        negatives = negatives.join(
            other=pl.read_parquet(os.path.join(data_dir, "ad_feature.parquet")).select(ad_feats),
            on="adgroup", how="inner",
        )
    # Same column layout as `training_data`, so the frames can be stacked vertically.
    negatives = (negatives
        .select(*user_feats, *ad_feats, pl.lit(-1).alias("btag").cast(pl.Int8), "timestamp")
    )
    training_data = pl.concat([training_data, negatives])
training_data

In [None]:
# Distinct user rows (id plus enabled profile features) of the validation set.
user_profile = validation_data.select(user_feats).unique()

# Ordinal encoder that maps every user column to int32 codes and returns
# polars frames from `transform`.
user_encoder = OrdinalEncoder(dtype=np.int32)
user_encoder.fit(user_profile)
user_encoder.set_output(transform="polars")

In [None]:
# Distinct ad rows (id plus enabled ad features) seen in either split.
ad_feature: pl.DataFrame = pl.concat(
    [split.select(ad_feats).unique() for split in (training_data, validation_data)]
).unique()

# Ordinal encoder that maps every ad column to int32 codes, encodes missing
# values as -1 and returns polars frames from `transform`.
ad_encoder = OrdinalEncoder(dtype=np.int32, encoded_missing_value=-1)
ad_encoder.fit(ad_feature)
ad_encoder.set_output(transform="polars")

In [None]:
# Replace the raw user / ad columns of the training rows by their ordinal codes
# and flag every row as a non-test row.
user_data = user_encoder.transform(training_data.select(user_feats))
ads_data = ad_encoder.transform(training_data.select(ad_feats))
interaction_data = training_data.select(
    pl.col("btag"),
    pl.col("timestamp").cast(pl.Int32),
    pl.lit(False).alias("is_test"),
)
# The three frames are row-aligned, so they can be glued together column-wise.
training_data = pl.concat([user_data, ads_data, interaction_data], how="horizontal")
training_data

In [None]:
# Replace the raw user / ad columns of the validation rows by their ordinal
# codes and flag every row as a test row.
user_data = user_encoder.transform(validation_data.select(user_feats))
ads_data = ad_encoder.transform(validation_data.select(ad_feats))
interaction_data = validation_data.select(
    pl.col("btag"),
    pl.col("timestamp").cast(pl.Int32),
    pl.lit(True).alias("is_test"),
)
# The three frames are row-aligned, so they can be glued together column-wise.
validation_data = pl.concat([user_data, ads_data, interaction_data], how="horizontal")
validation_data

In [None]:
# Persist the four frames as "<stem>-<dataset_params>.parquet" inside `data_dir`.
for _stem, _frame in (
    ("user_profile", user_profile),
    ("ad_feature", ad_feature),
    ("train", training_data),
    ("test", validation_data),
):
    _frame.write_parquet(os.path.join(data_dir, f"{_stem}-{dataset_params}.parquet"))

In [None]:
# Total number of ad clicks required per user; used as suffix of the file names.
min_ad_click = min_num_training + num_validation

# Persist the same four frames again as "<stem>_<min_ad_click>.parquet".
for _stem, _frame in (
    ("user_profile", user_profile),
    ("ad_feature", ad_feature),
    ("train", training_data),
    ("test", validation_data),
):
    _frame.write_parquet(os.path.join(data_dir, f"{_stem}_{min_ad_click}.parquet"))


In [None]:
# Stack the encoded train and validation rows and attach to every row the
# fraction of rows that share its adgroup (rows per adgroup divided by the
# number of non-null adgroup values).
interactions: pl.DataFrame = (
    pl.concat([training_data, validation_data])
    .with_columns(
        (pl.len().over("adgroup") / pl.count("adgroup"))
        .cast(pl.Float32)
        .alias("rel_ad_freq")
    )
)

# Sum of the per-adgroup frequencies, printed for inspection.
# NOTE(review): presumably expected to be ~1 as long as adgroup has no nulls - confirm.
rel_ad_freq_sum = (
    interactions
    .select("adgroup", "rel_ad_freq")
    .unique()
    .select("rel_ad_freq")
    .sum()
    .item()
)
print("Relative Ad Frequency Sanity Check Sum:", rel_ad_freq_sum)

# One row per user: scalar user features plus the user's interactions collected
# into lists ordered by timestamp, and the length of that sequence.
sequences = (
    interactions
    .sort(["user", "timestamp"])
    .group_by("user", maintain_order=True)
    .agg(
        pl.col(user_feats[1:]).first(),
        pl.col([*ad_feats, "rel_ad_freq", "btag", "timestamp", "is_test"]),
        pl.col("btag").len().alias("seq_len"),
    )
)
max_seq_len = sequences.get_column("seq_len").max()
print("Maximum sequence length:", max_seq_len)
sequences

In [None]:
# Right-pad every per-user list to `max_seq_len` and convert it to a fixed-size
# array column. Numeric lists are padded with 0 and `is_test` with False;
# `padded_mask` is True exactly on the padded positions.
interaction_seq = (
    sequences
    .with_columns((max_seq_len - pl.col("seq_len")).alias("pad_len"))
    .select(
        pl.col(user_feats),
        "seq_len",
        "pad_len",
        *[
            pl.col(name)
            .list.concat(pl.lit(fill).repeat_by(pl.col("pad_len")))
            .list.to_array(max_seq_len)
            for name, fill in [
                *((feat, 0) for feat in [*ad_feats, "rel_ad_freq", "btag", "timestamp"]),
                ("is_test", False),
            ]
        ],
        (
            pl.lit(False).repeat_by(pl.col("seq_len"))
            .list.concat(pl.lit(True).repeat_by(pl.col("pad_len")))
            .list.to_array(max_seq_len)
            .alias("padded_mask")
        ),
    )
)

In [None]:
# Persist the flat interaction table and the padded per-user sequences as
# "<stem>_<min_ad_click>.parquet" inside `data_dir`.
for _stem, _frame in (
    ("interactions", interactions),
    ("interaction_seq", interaction_seq),
):
    _frame.write_parquet(os.path.join(data_dir, f"{_stem}_{min_ad_click}.parquet"))