In [1]:
import os
import polars as pl
import numpy as np
from sklearn.preprocessing import OrdinalEncoder

In [2]:
data_dir = "../data"
min_num_training = 1
num_validation = 1
include_user_features = True # add gender, age, shopping, occupation to the dataset (filters to users with this data available)
include_ad_features = True # add category, brand, customer, campaign to the dataset (filters to ads with this data available)
include_ad_non_clks = False

In [3]:
dataset_params = f"{min_num_training}_min_train_clks-{num_validation}_test_clks"
if include_user_features:
    dataset_params += "-usr_fts"
if include_ad_features:
    dataset_params += "-ad_fts"
if include_ad_non_clks:
    dataset_params += "-non_clks"

In [4]:
user_feats = ["user"] + (["gender", "age", "shopping", "occupation"] if include_user_features else [])
ad_feats = ["adgroup"] + (["cate", "brand", "campaign", "customer"] if include_ad_features else [])
pretraining_ad_feats = set(["cate", "brand"]).intersection(set(ad_feats))

In [5]:
raw_sample = (pl
    .scan_parquet(os.path.join(data_dir, "raw_sample.parquet"))
    .filter((pl.col("clk") == True) | (pl.col("clk") == (min_num_training + num_validation > 0)))
    .filter(pl.len().over("user") >= min_num_training + num_validation)
    .unique(["user", "adgroup", "timestamp"])
)
if include_user_features:
    raw_sample = raw_sample.join(
        other=pl.scan_parquet(os.path.join(data_dir, "user_profile.parquet")).select(user_feats),
        on="user", how="inner",
    )
if include_ad_features:
    raw_sample = raw_sample.join(
        other=pl.scan_parquet(os.path.join(data_dir, "ad_feature.parquet")).select(ad_feats),
        on="adgroup", how="inner",
    )
raw_sample = raw_sample.collect()
raw_sample

user,adgroup,clk,timestamp,gender,age,shopping,occupation,cate,brand,campaign,customer
u32,u32,bool,u32,u8,u8,u8,u8,u16,u32,u32,u32
208939,718992,true,1494050883,2,4,3,0,6261,,331406,136644
1030207,541421,true,1494236444,2,4,3,0,4281,,104811,74568
924688,570814,true,1494506529,2,5,2,0,6261,,181223,42989
390494,516398,true,1494389633,2,4,3,0,366,,410911,65673
26049,320718,true,1494388132,2,3,3,0,4384,,420608,46450
…,…,…,…,…,…,…,…,…,…,…,…
675689,100854,true,1494142466,1,4,3,0,4305,353787,205794,206358
577344,339124,true,1494071379,1,4,3,0,4521,246138,9015,83578
548053,624690,true,1494461782,2,5,3,0,6261,,232938,164602
106483,628921,true,1494230346,2,5,3,0,6554,206372,246702,32597


In [6]:
training_data = (
    raw_sample
    .filter(pl.len().over("user") > num_validation)
    .sort("user", "timestamp", nulls_last=True)
    .group_by("user", maintain_order=True)
    .agg(pl.all().head(pl.len() - num_validation))
    .explode(pl.all().exclude("user"))
    .select(*user_feats, *ad_feats, pl.lit(1).alias("btag").cast(pl.Int8), "timestamp")
)
training_data

user,gender,age,shopping,occupation,adgroup,cate,brand,campaign,customer,btag,timestamp
u32,u8,u8,u8,u8,u32,u16,u32,u32,u32,i8,u32
4,1,5,2,0,144004,6300,,252559,187000,1,1494124371
4,1,5,2,0,661336,562,,195718,75325,1,1494129933
4,1,5,2,0,388902,562,,159335,116625,1,1494129933
4,1,5,2,0,438808,562,,51610,116649,1,1494136895
7,1,2,3,0,30074,7266,296265,71832,148537,1,1494598158
…,…,…,…,…,…,…,…,…,…,…,…
1141725,2,2,3,0,667859,7867,73436,250629,60061,1,1494138804
1141725,2,2,3,0,610751,4283,243097,263497,12747,1,1494472604
1141726,2,5,3,0,119854,4756,378130,88138,106614,1,1494216892
1141726,2,5,3,0,627130,6261,,96840,164289,1,1494406812


In [7]:
validation_data = (
    raw_sample
    .sort("user", "timestamp", nulls_last=True)
    .group_by("user", maintain_order=True)
    .agg(pl.all().tail(num_validation))
    .explode(pl.all().exclude("user"))
    .select(*user_feats, *ad_feats, pl.lit(1).alias("btag").cast(pl.Int8), "timestamp")
)
validation_data

user,gender,age,shopping,occupation,adgroup,cate,brand,campaign,customer,btag,timestamp
u32,u8,u8,u8,u8,u32,u16,u32,u32,u32,i8,u32
4,1,5,2,0,207109,562,200377,244872,136711,1,1494653875
7,1,2,3,0,30074,7266,296265,71832,148537,1,1494674441
14,2,2,3,1,711096,6423,452022,157119,34668,1,1494512118
24,2,1,2,0,656548,6261,,3487,60230,1,1494495452
26,2,4,2,0,645506,6261,269352,404023,54588,1,1494426628
…,…,…,…,…,…,…,…,…,…,…,…
1141718,1,5,3,0,623707,6261,,82320,1734,1,1494392748
1141723,2,1,3,0,28584,2239,215442,56502,139737,1,1494251900
1141725,2,2,3,0,183541,6255,231313,191535,14833,1,1494496336
1141726,2,5,3,0,336202,6342,90847,315038,199245,1,1494473055


In [8]:
first_validation_click = (
    validation_data.select("user", pl.col("timestamp").alias("first_validation_ad_click_time"))
    .sort("user", "first_validation_ad_click_time", nulls_last=True)
    .group_by("user", maintain_order=True).head(1)
)

In [9]:
if "cate" in ad_feats or "brand" in ad_feats:
    # Loading takes ~30s for pretraining dataset from behavior log
    behavior_log = (pl
        .scan_parquet(os.path.join(data_dir, "behavior_log.parquet"))
        .filter(pl.col("user").is_in(raw_sample.select("user").unique()))
    )
    if include_user_features:
        behavior_log = behavior_log.join(
            other=pl.scan_parquet(os.path.join(data_dir, "user_profile.parquet")).select(user_feats),
            on="user", how="inner",
        )
    behavior_log = (behavior_log.collect()
        .join(first_validation_click, on="user", how="inner")
        .filter(pl.col("timestamp") <= pl.col("first_validation_ad_click_time"))
        .unique()
        .select(*user_feats, *pretraining_ad_feats, pl.col("btag").cast(pl.Int8), "timestamp")
    )
    training_data = pl.concat([training_data, behavior_log], how="diagonal")
training_data

user,gender,age,shopping,occupation,adgroup,cate,brand,campaign,customer,btag,timestamp
u32,u8,u8,u8,u8,u32,u16,u32,u32,u32,i8,u32
4,1,5,2,0,144004,6300,,252559,187000,1,1494124371
4,1,5,2,0,661336,562,,195718,75325,1,1494129933
4,1,5,2,0,388902,562,,159335,116625,1,1494129933
4,1,5,2,0,438808,562,,51610,116649,1,1494136895
7,1,2,3,0,30074,7266,296265,71832,148537,1,1494598158
…,…,…,…,…,…,…,…,…,…,…,…
128004,1,2,3,0,,6247,67540,,,0,1492932678
658497,2,4,3,0,,4520,342760,,,0,1493693879
656554,2,5,3,0,,6261,146115,,,0,1494054212
856271,2,2,3,1,,4384,31710,,,0,1492944122


In [10]:
valid_users = training_data.select("user").unique()
validation_data = validation_data.filter(pl.col("user").is_in(valid_users))
validation_data

user,gender,age,shopping,occupation,adgroup,cate,brand,campaign,customer,btag,timestamp
u32,u8,u8,u8,u8,u32,u16,u32,u32,u32,i8,u32
4,1,5,2,0,207109,562,200377,244872,136711,1,1494653875
7,1,2,3,0,30074,7266,296265,71832,148537,1,1494674441
14,2,2,3,1,711096,6423,452022,157119,34668,1,1494512118
24,2,1,2,0,656548,6261,,3487,60230,1,1494495452
26,2,4,2,0,645506,6261,269352,404023,54588,1,1494426628
…,…,…,…,…,…,…,…,…,…,…,…
1141718,1,5,3,0,623707,6261,,82320,1734,1,1494392748
1141723,2,1,3,0,28584,2239,215442,56502,139737,1,1494251900
1141725,2,2,3,0,183541,6255,231313,191535,14833,1,1494496336
1141726,2,5,3,0,336202,6342,90847,315038,199245,1,1494473055


In [11]:
if include_ad_non_clks:
    negatives = (pl
        .scan_parquet(os.path.join(data_dir, "raw_sample.parquet"))
        .filter((pl.col("clk") == False) & (pl.col("user").is_in(valid_users))).collect()
        .join(first_validation_click, on="user", how="inner")
        .filter(pl.col("timestamp") <= pl.col("first_validation_ad_click_time"))
        .unique(["user", "adgroup", "timestamp"])
    )
    if include_user_features:
        negatives = negatives.join(
            other=pl.read_parquet(os.path.join(data_dir, "user_profile.parquet")).select(user_feats),
            on="user", how="inner",
        )
    if include_ad_features:
        negatives = negatives.join(
            other=pl.read_parquet(os.path.join(data_dir, "ad_feature.parquet")).select(ad_feats),
            on="adgroup", how="inner",
        )
    negatives = (negatives
        .select(*user_feats, *ad_feats, pl.lit(-1).alias("btag").cast(pl.Int8), "timestamp")
    )
    training_data = pl.concat([training_data, negatives])
training_data

user,gender,age,shopping,occupation,adgroup,cate,brand,campaign,customer,btag,timestamp
u32,u8,u8,u8,u8,u32,u16,u32,u32,u32,i8,u32
4,1,5,2,0,144004,6300,,252559,187000,1,1494124371
4,1,5,2,0,661336,562,,195718,75325,1,1494129933
4,1,5,2,0,388902,562,,159335,116625,1,1494129933
4,1,5,2,0,438808,562,,51610,116649,1,1494136895
7,1,2,3,0,30074,7266,296265,71832,148537,1,1494598158
…,…,…,…,…,…,…,…,…,…,…,…
128004,1,2,3,0,,6247,67540,,,0,1492932678
658497,2,4,3,0,,4520,342760,,,0,1493693879
656554,2,5,3,0,,6261,146115,,,0,1494054212
856271,2,2,3,1,,4384,31710,,,0,1492944122


In [12]:
user_profile = validation_data.select(user_feats).unique()
user_encoder = OrdinalEncoder(dtype=np.int32).fit(user_profile)
user_encoder.set_output(transform="polars")

In [13]:
ad_feature = pl.concat([
    training_data.select(ad_feats).unique(),
    validation_data.select(ad_feats).unique(),
]).unique()
ad_encoder = OrdinalEncoder(dtype=np.int32, encoded_missing_value=-1).fit(ad_feature)
ad_encoder.set_output(transform="polars")

In [None]:
user_data = user_encoder.transform(training_data.select(user_feats))
ads_data = ad_encoder.transform(training_data.select(ad_feats))
interaction_data = training_data.select("btag", pl.col("timestamp").cast(pl.Int32))
training_data = pl.concat([user_data, ads_data, interaction_data], how="horizontal")

In [None]:
user_data = user_encoder.transform(validation_data.select(user_feats))
ads_data = ad_encoder.transform(validation_data.select(ad_feats))
interaction_data = validation_data.select("btag", pl.col("timestamp").cast(pl.Int32))
validation_data = pl.concat([user_data, ads_data, interaction_data], how="horizontal")

In [16]:
user_profile.write_parquet(os.path.join(data_dir, f"user_profile-{dataset_params}.parquet"))
ad_feature.write_parquet(os.path.join(data_dir, f"ad_feature-{dataset_params}.parquet"))
training_data.write_parquet(os.path.join(data_dir, f"train-{dataset_params}.parquet"))
validation_data.write_parquet(os.path.join(data_dir, f"test-{dataset_params}.parquet"))