In [1]:
import os
import polars as pl
import numpy as np
from sklearn.preprocessing import OrdinalEncoder

In [2]:
# --- Dataset configuration ---------------------------------------------------------------
data_dir = "../data"  # directory the input parquet files are read from and outputs written to

# Per-user click budget: the latest `num_validation` clicks form the test split; together with
# `min_num_training` they set the minimum number of clicks a user needs to be kept.
min_num_training, num_validation = 4, 1

# Optional side features. Each one is attached with an inner join, so enabling it also drops
# users / ads that are missing from the corresponding profile table.
include_user_features = False  # gender, age, shopping, occupation
include_ad_features = True     # category, brand, customer, campaign
# True: append non-clicked impressions to training as btag = -1 negatives.
# False: (when cate/brand are available) append behavior-log events to training instead.
include_ad_non_clks = True

In [3]:
# Tag used in output file names: the click budget, then one suffix per enabled option.
dataset_params = "-".join(
    [f"{min_num_training}_min_train_clks", f"{num_validation}_test_clks"]
    + [
        suffix
        for suffix, enabled in (
            ("usr_fts", include_user_features),
            ("ad_fts", include_ad_features),
            ("non_clks", include_ad_non_clks),
        )
        if enabled
    ]
)

In [4]:
# Column lists that drive the `select` calls below; optional side features are only listed
# when their flag is enabled.
user_feats = ["user"] + (["gender", "age", "shopping", "occupation"] if include_user_features else [])
ad_feats = ["adgroup"] + (["cate", "brand", "campaign", "customer"] if include_ad_features else [])
# Ad features also taken from the behavior log for the optional pretraining rows.
# Kept as a list in `ad_feats` order rather than a set: set iteration order over strings
# changes between interpreter runs (hash randomisation), which made column order unstable.
pretraining_ad_feats = [feat for feat in ad_feats if feat in ("cate", "brand")]

In [5]:
# Clicked impressions, optionally enriched with user / ad side features, restricted to users
# with at least `min_num_training + num_validation` distinct usable clicks.
min_clicks_per_user = min_num_training + num_validation
raw_sample = (pl
    .scan_parquet(os.path.join(data_dir, "raw_sample.parquet"))
    # With a positive threshold only clicked impressions are kept; with a threshold of 0 every
    # impression with a non-null `clk` is kept. NOTE(review): confirm the threshold-0 case is intended.
    .filter(pl.col("clk") if min_clicks_per_user > 0 else pl.col("clk").is_not_null())
)
if include_user_features:
    raw_sample = raw_sample.join(
        other=pl.scan_parquet(os.path.join(data_dir, "user_profile.parquet")).select(user_feats),
        on="user", how="inner",
    )
if include_ad_features:
    raw_sample = raw_sample.join(
        other=pl.scan_parquet(os.path.join(data_dir, "ad_feature.parquet")).select(ad_feats),
        on="adgroup", how="inner",
    )
# De-duplicate and apply the per-user threshold only after the inner joins. Both steps can
# drop rows; counting earlier let users through with fewer clicks than the split requires.
raw_sample = (raw_sample
    .unique(["user", "adgroup", "timestamp"])
    .filter(pl.len().over("user") >= min_clicks_per_user)
    .collect()
)
raw_sample

user,adgroup,clk,timestamp,cate,brand,campaign,customer
u32,u32,bool,u32,u16,u32,u32,u32
277625,313401,true,1494024788,6406,87331,83237,1
1124195,248909,true,1494304882,392,32233,83237,1
500571,375706,true,1494226042,4520,,387991,6
43176,23236,true,1494523527,5953,,395195,13
727587,23236,true,1494455217,5953,,395195,13
…,…,…,…,…,…,…,…
326033,799254,true,1494578152,1244,,376455,255837
596729,790628,true,1494307023,1244,,376455,255837
88256,837778,true,1494420815,6432,387520,383275,255841
657116,833539,true,1494558949,6432,293023,377787,255841


In [6]:
# Training split: every click of a user except the latest `num_validation`, tagged btag = 1.
# `adgroup` breaks timestamp ties. Rows are unique on (user, adgroup, timestamp), so this is a
# total order, and clicks that share a timestamp are split deterministically. The validation
# split must use the same ordering for the two splits to be complementary.
training_data = (
    raw_sample
    .filter(pl.len().over("user") > num_validation)  # keep users with at least one training click
    .sort("user", "timestamp", "adgroup", nulls_last=True)
    .group_by("user", maintain_order=True)
    .agg(pl.all().head(pl.len() - num_validation))
    .explode(pl.all().exclude("user"))
    .select(*user_feats, *ad_feats, pl.lit(1).alias("btag").cast(pl.Int8), "timestamp")
)
training_data

user,adgroup,cate,brand,campaign,customer,btag,timestamp
u32,u32,u16,u32,u32,u32,i8,u32
4,144004,6300,,252559,187000,1,1494124371
4,388902,562,,159335,116625,1,1494129933
4,661336,562,,195718,75325,1,1494129933
4,438808,562,,51610,116649,1,1494136895
7,30074,7266,296265,71832,148537,1,1494598158
…,…,…,…,…,…,…,…
1141723,701127,6300,82527,266843,28529,1,1494251900
1141725,650277,4283,,328716,25728,1,1494032450
1141725,171192,134,329043,263023,112503,1,1494032450
1141725,667859,7867,73436,250629,60061,1,1494138804


In [7]:
# Validation split: the latest `num_validation` clicks of each user, tagged btag = 1.
# `adgroup` breaks timestamp ties. Rows are unique on (user, adgroup, timestamp), so this is a
# total order, and clicks that share a timestamp are split deterministically. The training
# split must use the same ordering for the two splits to be complementary.
validation_data = (
    raw_sample
    .sort("user", "timestamp", "adgroup", nulls_last=True)
    .group_by("user", maintain_order=True)
    .agg(pl.all().tail(num_validation))
    .explode(pl.all().exclude("user"))
    .select(*user_feats, *ad_feats, pl.lit(1).alias("btag").cast(pl.Int8), "timestamp")
)
validation_data

user,adgroup,cate,brand,campaign,customer,btag,timestamp
u32,u32,u16,u32,u32,u32,i8,u32
4,207109,562,200377,244872,136711,1,1494653875
7,30074,7266,296265,71832,148537,1,1494674441
33,660050,6519,,419188,73759,1,1494683589
51,737479,1665,234846,409028,63559,1,1494466444
62,796821,1665,,15197,92329,1,1494681792
…,…,…,…,…,…,…,…
1141672,261874,6261,164153,212019,170245,1,1494502692
1141708,346723,6252,240323,215009,178726,1,1494662035
1141714,840511,4281,82527,118601,28529,1,1494633104
1141723,510595,6261,146115,162480,105620,1,1494251900


In [8]:
# Earliest validation-click timestamp per user (one row per user, ordered by user). Later cells
# use it as a cutoff: rows appended to training must not be later than this time.
first_validation_click = (
    validation_data
    .group_by("user")
    .agg(pl.col("timestamp").min().alias("first_validation_ad_click_time"))
    .sort("user", nulls_last=True)
)

In [9]:
# When negatives are not used, augment training with behavior-log events (cate/brand only) for
# the same users, up to each user's first validation click.
if pretraining_ad_feats and not include_ad_non_clks:
    # Loading takes ~30s for pretraining dataset from behavior log
    behavior_log = (pl
        .scan_parquet(os.path.join(data_dir, "behavior_log.parquet"))
        # Pass a Series (not a one-column DataFrame) to `is_in`, which is the documented input.
        .filter(pl.col("user").is_in(raw_sample.get_column("user").unique()))
    )
    if include_user_features:
        behavior_log = behavior_log.join(
            other=pl.scan_parquet(os.path.join(data_dir, "user_profile.parquet")).select(user_feats),
            on="user", how="inner",
        )
    # Stay lazy through the join/filter so only the needed rows and columns are materialised.
    behavior_log = (behavior_log
        .join(first_validation_click.lazy(), on="user", how="inner")
        .filter(pl.col("timestamp") <= pl.col("first_validation_ad_click_time"))
        .unique()
        # NOTE(review): assumes `btag` in behavior_log.parquet is already an integer code; a
        # string-valued column would fail this cast — confirm.
        .select(*user_feats, *pretraining_ad_feats, pl.col("btag").cast(pl.Int8), "timestamp")
        .collect()
    )
    # Diagonal concat: behavior rows lack adgroup/campaign/customer, which become null.
    training_data = pl.concat([training_data, behavior_log], how="diagonal")
training_data

user,adgroup,cate,brand,campaign,customer,btag,timestamp
u32,u32,u16,u32,u32,u32,i8,u32
4,144004,6300,,252559,187000,1,1494124371
4,388902,562,,159335,116625,1,1494129933
4,661336,562,,195718,75325,1,1494129933
4,438808,562,,51610,116649,1,1494136895
7,30074,7266,296265,71832,148537,1,1494598158
…,…,…,…,…,…,…,…
1141723,701127,6300,82527,266843,28529,1,1494251900
1141725,650277,4283,,328716,25728,1,1494032450
1141725,171192,134,329043,263023,112503,1,1494032450
1141725,667859,7867,73436,250629,60061,1,1494138804


In [10]:
# Users with at least one training row; validation rows for any other user are dropped.
# A Series is what `Expr.is_in` documents as its input (a one-column DataFrame only worked by accident).
valid_users = training_data.get_column("user").unique()
validation_data = validation_data.filter(pl.col("user").is_in(valid_users))
validation_data

user,adgroup,cate,brand,campaign,customer,btag,timestamp
u32,u32,u16,u32,u32,u32,i8,u32
4,207109,562,200377,244872,136711,1,1494653875
7,30074,7266,296265,71832,148537,1,1494674441
33,660050,6519,,419188,73759,1,1494683589
51,737479,1665,234846,409028,63559,1,1494466444
62,796821,1665,,15197,92329,1,1494681792
…,…,…,…,…,…,…,…
1141672,261874,6261,164153,212019,170245,1,1494502692
1141708,346723,6252,240323,215009,178726,1,1494662035
1141714,840511,4281,82527,118601,28529,1,1494633104
1141723,510595,6261,146115,162480,105620,1,1494251900


In [11]:
# Optionally append non-clicked impressions as btag = -1 negatives. Only valid users are kept,
# and only impressions no later than that user's first validation click.
if include_ad_non_clks:
    # Everything stays lazy (side tables are scanned, as in the raw_sample cell) and is collected
    # once, so filters and column selection can be pushed down to the parquet scans.
    negatives = (pl
        .scan_parquet(os.path.join(data_dir, "raw_sample.parquet"))
        .filter(~pl.col("clk") & pl.col("user").is_in(valid_users))
        .join(first_validation_click.lazy(), on="user", how="inner")
        .filter(pl.col("timestamp") <= pl.col("first_validation_ad_click_time"))
        .unique(["user", "adgroup", "timestamp"])
    )
    if include_user_features:
        negatives = negatives.join(
            other=pl.scan_parquet(os.path.join(data_dir, "user_profile.parquet")).select(user_feats),
            on="user", how="inner",
        )
    if include_ad_features:
        negatives = negatives.join(
            other=pl.scan_parquet(os.path.join(data_dir, "ad_feature.parquet")).select(ad_feats),
            on="adgroup", how="inner",
        )
    negatives = (negatives
        .select(*user_feats, *ad_feats, pl.lit(-1).alias("btag").cast(pl.Int8), "timestamp")
        .collect()
    )
    training_data = pl.concat([training_data, negatives])
training_data

user,adgroup,cate,brand,campaign,customer,btag,timestamp
u32,u32,u16,u32,u32,u32,i8,u32
4,144004,6300,,252559,187000,1,1494124371
4,388902,562,,159335,116625,1,1494129933
4,661336,562,,195718,75325,1,1494129933
4,438808,562,,51610,116649,1,1494136895
7,30074,7266,296265,71832,148537,1,1494598158
…,…,…,…,…,…,…,…
408327,547640,1665,,422088,73020,-1,1494393010
917116,153483,8878,293548,251646,65269,-1,1494582869
743041,446853,6181,31352,32399,47187,-1,1494472016
577795,556055,6519,,223606,213679,-1,1494471608


# Snapshot of the splits before ordinal encoding, named after the total per-user click
# threshold (derived from the config rather than hardcoded, so changing it renames the files).
# NOTE(review): overlaps with the encoded outputs written at the end — confirm it is still needed.
training_data.write_parquet(os.path.join(data_dir, f"train_min_{min_num_training + num_validation}_click.parquet"))
validation_data.write_parquet(os.path.join(data_dir, f"test_min_{min_num_training + num_validation}_click.parquet"))

In [12]:
# User vocabulary: every distinct user-feature row seen in either split. This mirrors how the
# ad encoder is fit, so transforming either split cannot hit an unknown user.
user_profile = pl.concat([
    training_data.select(user_feats).unique(),
    validation_data.select(user_feats).unique(),
]).unique()
# Missing feature values are encoded as -1 (same convention as the ad encoder). Without this,
# an int32 OrdinalEncoder cannot represent nulls in the optional user features.
user_encoder = OrdinalEncoder(dtype=np.int32, encoded_missing_value=-1).fit(user_profile)
user_encoder.set_output(transform="polars")

In [13]:
# Ad vocabulary: every distinct ad-feature row seen in either split; missing values map to -1.
ad_feature = pl.concat(
    [frame.select(ad_feats).unique() for frame in (training_data, validation_data)]
).unique()
ad_encoder = OrdinalEncoder(encoded_missing_value=-1, dtype=np.int32).fit(ad_feature)
ad_encoder.set_output(transform="polars")

In [14]:
# Ordinal-encode the training split. User and ad columns go through the fitted encoders; btag
# is passed through and timestamp is cast to Int32.
user_data, ads_data = (
    user_encoder.transform(training_data.select(user_feats)),
    ad_encoder.transform(training_data.select(ad_feats)),
)
interaction_data = training_data.select(
    "btag",
    pl.col("timestamp").cast(pl.Int32),
)
training_data = pl.concat(
    [user_data, ads_data, interaction_data],
    how="horizontal",
)

In [15]:
# Ordinal-encode the validation split exactly like the training split.
user_data, ads_data = (
    user_encoder.transform(validation_data.select(user_feats)),
    ad_encoder.transform(validation_data.select(ad_feats)),
)
interaction_data = validation_data.select(
    "btag",
    pl.col("timestamp").cast(pl.Int32),
)
validation_data = pl.concat(
    [user_data, ads_data, interaction_data],
    how="horizontal",
)

In [16]:
# Write the four artefacts under names keyed only by the total per-user click threshold.
# NOTE(review): these names ignore the feature flags, so runs with different flags overwrite
# each other; the cell below also writes parameter-tagged copies.
min_ad_click = min_num_training + num_validation
for prefix, frame in (
    ("user_profile", user_profile),
    ("ad_feature", ad_feature),
    ("train", training_data),
    ("test", validation_data),
):
    frame.write_parquet(os.path.join(data_dir, f"{prefix}_{min_ad_click}.parquet"))

In [None]:
# Write the four artefacts under names tagged with every dataset option (`dataset_params`).
for prefix, frame in (
    ("user_profile", user_profile),
    ("ad_feature", ad_feature),
    ("train", training_data),
    ("test", validation_data),
):
    frame.write_parquet(os.path.join(data_dir, f"{prefix}-{dataset_params}.parquet"))