In [None]:
import os
import numpy as np
import polars as pl
from sklearn.preprocessing import OrdinalEncoder

In [None]:
# Directory holding the input parquet files and receiving the output splits.
data_dir = "../data"

# Per-user split sizes: the most recent `num_validation` clicks of each user are held
# out, and a user needs at least `min_num_training + num_validation` clicks to be kept.
num_validation = 1
min_num_training = 4

# Feature selection; the include_* flags control whether raw ids are emitted as features.
include_user_ids, include_ad_ids = True, True
user_features = "gender age shopping occupation".split()
ad_features = "cate brand customer campaign".split()

In [3]:
# Load raw_sample.parquet, keep click rows, attach the selected user / ad attributes,
# and keep only users with at least `min_num_training + num_validation` rows.
min_clicks_per_user = min_num_training + num_validation
raw_sample = (pl
    .scan_parquet(os.path.join(data_dir, "raw_sample.parquet"))
    # With a positive threshold only rows with clk == True survive; with a threshold of
    # 0 every row with a non-null `clk` is kept (same as the previous two-sided test).
    .filter(pl.col("clk") if min_clicks_per_user > 0 else pl.col("clk").is_not_null())
)
if user_features:
    raw_sample = raw_sample.join(
        other=pl.scan_parquet(os.path.join(data_dir, "user_profile.parquet")).select(["user"] + user_features),
        on="user", how="inner",
    )
if ad_features:
    raw_sample = raw_sample.join(
        other=pl.scan_parquet(os.path.join(data_dir, "ad_feature.parquet")).select(["adgroup"] + ad_features),
        on="adgroup", how="inner",
    )
# Count rows per user only AFTER the inner joins: rows dropped by a join (e.g. an
# adgroup missing from ad_feature.parquet) must not count towards the per-user minimum.
raw_sample = (
    raw_sample
    .filter(pl.len().over("user") >= min_clicks_per_user)
    .collect()
)

In [4]:
# One-column frame of the distinct users that survived the filters above.
valid_users = raw_sample.select(pl.col("user").unique())

In [5]:
# Columns emitted for each user / ad; the raw ids are optional features.
user_feats = (["user"] if include_user_ids else []) + user_features
ad_feats = (["adgroup"] if include_ad_ids else []) + ad_features
# Ad attributes that also appear in the behavior log (cate, brand) and can therefore be
# used for pretraining rows. Kept as a list ordered like `ad_feats`, not a set: set
# iteration order of strings changes between kernel runs, which made the column order
# of downstream selects nondeterministic.
pretraining_ad_feats = [feat for feat in ad_feats if feat in ("cate", "brand")]

In [6]:
# Distinct ad-attribute rows seen in the clicks, extended (when cate/brand are selected)
# with the attribute values found in the behavior log of the kept users.
ad_feature = raw_sample.select(ad_feats).unique()
if pretraining_ad_feats:
    # Loading takes ~30s for pretraining dataset from behavior log
    behavior_log = (pl
        .scan_parquet(os.path.join(data_dir, "behavior_log.parquet"))
        # Pass the ids as an explicit Series rather than relying on is_in()
        # coercing the one-column `valid_users` DataFrame.
        .filter(pl.col("user").is_in(valid_users.get_column("user")))
    )
    if user_features:
        behavior_log = behavior_log.join(
            other=pl.scan_parquet(os.path.join(data_dir, "user_profile.parquet")).select(["user"] + user_features),
            on="user", how="inner",
        )
    behavior_log = behavior_log.collect()
    # NOTE(review): how="align" aligns rows on the shared columns (an outer join on the
    # pretraining ad features) instead of stacking the frames; confirm this (rather
    # than how="diagonal") is the intended union.
    ad_feature = pl.concat([ad_feature, behavior_log.select(*pretraining_ad_feats).unique()], how="align")

In [7]:
# Inspect the filtered click sample joined with the selected user and ad attributes.
raw_sample

user,adgroup,clk,timestamp,gender,age,shopping,occupation,cate,brand,customer,campaign
u32,u32,bool,u32,u8,u8,u8,u8,u16,u32,u32,u32
642854,102,true,1494264162,2,2,3,0,126,102457,20107,138148
443793,102,true,1494155701,2,4,3,0,126,102457,20107,138148
355080,102,true,1494492568,2,3,3,0,126,102457,20107,138148
843732,102,true,1494420431,2,3,3,0,126,102457,20107,138148
1076956,102,true,1494334701,2,3,3,0,126,102457,20107,138148
…,…,…,…,…,…,…,…,…,…,…,…
102161,846745,true,1494471716,2,3,3,0,6939,416477,52471,397861
1098194,846745,true,1494311997,2,3,3,0,6939,416477,52471,397861
1089310,846781,true,1494071559,2,5,3,0,6361,459870,47122,392860
417228,846781,true,1494528900,2,3,3,0,6361,459870,47122,392860


In [8]:
# Chronological per-user split: each user's last `num_validation` clicks form the
# validation set and all earlier clicks the training set.
#  - The same `> num_validation` user filter now applies to BOTH splits, so every
#    validation user also has at least one training click.
#  - `adgroup` breaks timestamp ties, so the split is deterministic when a user has
#    several clicks in the same second.
#  - A per-user row-position mask replaces group_by/head/tail/explode, which emitted
#    all-null rows (explode of empty lists) when num_validation == 0.
clicks_ranked = (
    raw_sample
    .filter(pl.len().over("user") > num_validation)
    .sort("user", "timestamp", "adgroup", nulls_last=True)
    .with_columns(
        is_validation=(
            pl.int_range(0, pl.len()) >= pl.len().cast(pl.Int64) - num_validation
        ).over("user")
    )
)
# Every click row gets btag = 1.
split_columns = [*user_feats, *ad_feats, pl.lit(1, dtype=pl.UInt8).alias("btag"), "timestamp"]
training_data = clicks_ranked.filter(~pl.col("is_validation")).select(split_columns)
validation_data = clicks_ranked.filter(pl.col("is_validation")).select(split_columns)

In [9]:
# Held-out rows: each user's most recent click(s) by timestamp.
validation_data

user,gender,age,shopping,occupation,adgroup,cate,brand,customer,campaign,btag,timestamp
u32,u8,u8,u8,u8,u32,u16,u32,u32,u32,u8,u32
4,1,5,2,0,207109,562,200377,136711,244872,1,1494653875
7,1,2,3,0,30074,7266,296265,148537,71832,1,1494674441
14,2,2,3,1,711096,6423,452022,34668,157119,1,1494512118
24,2,1,2,0,656548,6261,,60230,3487,1,1494495452
26,2,4,2,0,645506,6261,269352,54588,404023,1,1494426628
…,…,…,…,…,…,…,…,…,…,…,…
1141718,1,5,3,0,623707,6261,,1734,82320,1,1494392748
1141723,2,1,3,0,701127,6300,82527,28529,266843,1,1494251900
1141725,2,2,3,0,183541,6255,231313,14833,191535,1,1494496336
1141726,2,5,3,0,336202,6342,90847,199245,315038,1,1494473055


In [10]:
# Click-derived training rows (behavior-log rows are appended in a later cell).
training_data

user,gender,age,shopping,occupation,adgroup,cate,brand,customer,campaign,btag,timestamp
u32,u8,u8,u8,u8,u32,u16,u32,u32,u32,u8,u32
4,1,5,2,0,144004,6300,,187000,252559,1,1494124371
4,1,5,2,0,388902,562,,116625,159335,1,1494129933
4,1,5,2,0,661336,562,,75325,195718,1,1494129933
4,1,5,2,0,438808,562,,116649,51610,1,1494136895
7,1,2,3,0,30074,7266,296265,148537,71832,1,1494598158
…,…,…,…,…,…,…,…,…,…,…,…
1141725,2,2,3,0,667859,7867,73436,60061,250629,1,1494138804
1141725,2,2,3,0,610751,4283,243097,12747,263497,1,1494472604
1141726,2,5,3,0,119854,4756,378130,106614,88138,1,1494216892
1141726,2,5,3,0,627130,6261,,164289,96840,1,1494406812


In [11]:
# Behavior-log events of the kept users (plus user attributes when user_features is non-empty).
behavior_log

user,cate,brand,btag,timestamp,gender,age,shopping,occupation
u32,u16,u32,u8,u32,u8,u8,u8,u8
558157,6250,91286,0,1493741625,2,5,3,0
558157,6250,91286,0,1493741626,2,5,3,0
558157,6250,91286,0,1493741627,2,5,3,0
332634,1101,365477,0,1493809895,1,5,3,0
619381,385,428950,0,1493774638,2,4,3,0
…,…,…,…,…,…,…,…,…
1035186,1101,20348,0,1493549440,2,3,3,0
1035186,1101,20348,0,1493549405,2,3,3,0
1035186,1101,20348,0,1493549595,2,3,3,0
1035186,1101,20348,0,1493549516,2,3,3,0


In [16]:
# Inspection only: rows that carry an ad id. Behavior-log rows appended later have a
# null `adgroup`, so after that cell this shows just the click-derived rows.
training_data.filter(pl.col("adgroup").is_not_null())

user,gender,age,shopping,occupation,adgroup,cate,brand,customer,campaign,btag,timestamp
u32,u8,u8,u8,u8,u32,u16,u32,u32,u32,u8,u32
4,1,5,2,0,144004,6300,,187000,252559,1,1494124371
4,1,5,2,0,388902,562,,116625,159335,1,1494129933
4,1,5,2,0,661336,562,,75325,195718,1,1494129933
4,1,5,2,0,438808,562,,116649,51610,1,1494136895
7,1,2,3,0,30074,7266,296265,148537,71832,1,1494598158
…,…,…,…,…,…,…,…,…,…,…,…
1141725,2,2,3,0,667859,7867,73436,60061,250629,1,1494138804
1141725,2,2,3,0,610751,4283,243097,12747,263497,1,1494472604
1141726,2,5,3,0,119854,4756,378130,106614,88138,1,1494216892
1141726,2,5,3,0,627130,6261,,164289,96840,1,1494406812


In [12]:
# Augment the click training rows with behavior-log events (cate/brand level) of the
# same users. Only events STRICTLY before a user's first validation click are kept, so
# activity at the exact moment of the held-out click cannot leak into training.
# Skipped when no pretraining ad features are selected (behavior_log is not built then).
# NOTE: re-running this cell without re-running the split cell appends the rows again.
# NOTE(review): click rows use btag = 1; confirm the behavior log's btag coding does not
# reuse 1 for a different event type.
if pretraining_ad_feats:
    first_validation_click = (
        validation_data
        .group_by("user")
        .agg(pl.col("timestamp").min().alias("first_validation_ad_click_time"))
    )
    pretraining_data = (
        behavior_log
        .join(first_validation_click, on="user", how="inner")
        .filter(pl.col("timestamp") < pl.col("first_validation_ad_click_time"))
        .select(*user_feats, *pretraining_ad_feats, "btag", "timestamp")
    )
    # Diagonal concat fills the ad columns absent from behavior rows with nulls.
    training_data = pl.concat([training_data, pretraining_data], how="diagonal")

In [13]:
# Training set after the behavior-log augmentation; appended rows have null adgroup/customer/campaign.
training_data

user,gender,age,shopping,occupation,adgroup,cate,brand,customer,campaign,btag,timestamp
u32,u8,u8,u8,u8,u32,u16,u32,u32,u32,u8,u32
4,1,5,2,0,144004,6300,,187000,252559,1,1494124371
4,1,5,2,0,388902,562,,116625,159335,1,1494129933
4,1,5,2,0,661336,562,,75325,195718,1,1494129933
4,1,5,2,0,438808,562,,116649,51610,1,1494136895
7,1,2,3,0,30074,7266,296265,148537,71832,1,1494598158
…,…,…,…,…,…,…,…,…,…,…,…
1035186,2,3,3,0,,1101,20348,,,0,1493549440
1035186,2,3,3,0,,1101,20348,,,0,1493549405
1035186,2,3,3,0,,1101,20348,,,0,1493549595
1035186,2,3,3,0,,1101,20348,,,0,1493549516


In [14]:
# Validation split (not modified by the behavior-log augmentation).
validation_data

user,gender,age,shopping,occupation,adgroup,cate,brand,customer,campaign,btag,timestamp
u32,u8,u8,u8,u8,u32,u16,u32,u32,u32,u8,u32
4,1,5,2,0,207109,562,200377,136711,244872,1,1494653875
7,1,2,3,0,30074,7266,296265,148537,71832,1,1494674441
14,2,2,3,1,711096,6423,452022,34668,157119,1,1494512118
24,2,1,2,0,656548,6261,,60230,3487,1,1494495452
26,2,4,2,0,645506,6261,269352,54588,404023,1,1494426628
…,…,…,…,…,…,…,…,…,…,…,…
1141718,1,5,3,0,623707,6261,,1734,82320,1,1494392748
1141723,2,1,3,0,701127,6300,82527,28529,266843,1,1494251900
1141725,2,2,3,0,183541,6255,231313,14833,191535,1,1494496336
1141726,2,5,3,0,336202,6342,90847,199245,315038,1,1494473055


In [None]:
# Persist both splits; the file name records the per-user click threshold that was used.
click_threshold = min_num_training + num_validation
training_data.write_parquet(os.path.join(data_dir, f"train_min_{click_threshold}_click.parquet"))
validation_data.write_parquet(os.path.join(data_dir, f"test_min_{click_threshold}_click.parquet"))