In [None]:
import os
import polars as pl

In [None]:
# --- Run configuration -------------------------------------------------------
# Directory holding the input parquet files read by the cells below.
data_dir: str = os.path.join("..", "raw_data")

# Number of most recent clicks per user that are held out as validation rows.
num_validation: int = 1
# Added to num_validation to form the minimum per-user click count required
# when the click table is built.
min_num_training: int = 0

# When True, non-clicked ad impressions are appended to the training data
# (tagged with btag = -1).
include_ad_non_clks: bool = True
# When True, rows from behavior_log.parquet are appended to the training data.
include_behavior_log: bool = True

In [None]:
# Columns kept from the user profile table ("user" is the join key).
user_feats = [
    "user",
    "gender",
    "age",
    "shopping",
    "occupation",
]
# Columns kept from the ad feature table ("adgroup" is the join key).
ad_feats = [
    "adgroup",
    "cate",
    "brand",
    "campaign",
    "customer",
]
# Item columns selected from the behavior log rows when they are appended
# to the training data.
pretraining_ad_feats = [
    "cate",
    "brand",
]

In [None]:
# Build the table of clicked ad impressions, one row per distinct
# (user, adgroup, timestamp), enriched with the selected ad features.
raw_sample = (
    pl.scan_parquet(os.path.join(data_dir, "raw_sample.parquet"))
    # Keep clicked impressions only.
    .filter(pl.col("clk") == True)
    # De-duplicate BEFORE counting clicks per user: counting first would let
    # duplicated rows inflate the count, admitting users who end up with
    # fewer than the required number of distinct clicks.
    .unique(["user", "adgroup", "timestamp"])
    # Require enough clicks per user to cover the validation hold-out plus
    # the configured minimum number of training clicks.
    .filter(pl.len().over("user") >= min_num_training + num_validation)
    # Difference to the previous click timestamp of the same user on the same
    # adgroup (null for the first such click).
    .with_columns(
        timediff=pl.col("timestamp").diff().over("user", "adgroup", order_by="timestamp")
    )
    # Inner join: clicks whose adgroup has no row in ad_feature are dropped.
    .join(
        other=pl.scan_parquet(os.path.join(data_dir, "ad_feature.parquet")).select(ad_feats),
        on="adgroup",
        how="inner",
    )
    .collect()
)
raw_sample

In [None]:
# Training split: for every user with more than `num_validation` clicks, keep
# all clicks except the last `num_validation` ones (in timestamp order).
training_data = (
    raw_sample
    .filter(pl.len().over("user") > num_validation)
    .sort(["user", "timestamp"], nulls_last=True)
    .group_by("user", maintain_order=True)
    # Per user: every column becomes a list truncated to drop the held-out tail.
    .agg(pl.all().head(pl.len() - num_validation))
    # Back to one row per click.
    .explode(pl.exclude("user"))
    .select([
        "user",
        *ad_feats,
        # Ad clicks are tagged with btag = 1.
        pl.lit(1).cast(pl.Int32).alias("btag"),
        "timestamp",
        "timediff",
    ])
)
training_data

In [None]:
# Validation split: the last `num_validation` clicks of every user
# (in timestamp order).
validation_data = (
    raw_sample
    .sort(["user", "timestamp"], nulls_last=True)
    .group_by("user", maintain_order=True)
    # Per user: every column becomes a list holding only the trailing clicks.
    .agg(pl.all().tail(num_validation))
    # Back to one row per click.
    .explode(pl.exclude("user"))
    .select([
        "user",
        *ad_feats,
        # Ad clicks are tagged with btag = 1.
        pl.lit(1).cast(pl.Int32).alias("btag"),
        "timestamp",
        "timediff",
    ])
)
validation_data

In [None]:
# Drop the name bound to the collected click table; the cells that follow in
# this section work from training_data / validation_data and do not reference
# the raw_sample variable again.
del raw_sample

In [None]:
# Sanity check (display only): pairs where a user's validation click on an
# adgroup is not later than a training click by the same user on the same
# adgroup. `timestamp_right` is the training-side timestamp.
(
    validation_data
    .join(training_data, on=["user", "adgroup"], how="inner")
    .filter(pl.col("timestamp") <= pl.col("timestamp_right"))
)

In [None]:
# One row per user: the earliest timestamp among that user's validation clicks.
first_validation_click = (
    validation_data
    .select("user", first_validation_ad_click_time=pl.col("timestamp"))
    .sort(["user", "first_validation_ad_click_time"], nulls_last=True)
    .group_by("user", maintain_order=True)
    # Rows are sorted ascending within each user, so the first one is the earliest.
    .head(1)
)
first_validation_click

In [None]:
# Keep only training clicks that happened strictly before the user's first
# validation click, then remove the helper column again.
training_data = (
    training_data
    .join(first_validation_click, on="user", how="inner")
    .filter(pl.col("timestamp") < pl.col("first_validation_ad_click_time"))
    .drop("first_validation_ad_click_time")
)
training_data

In [None]:
# Optionally extend the training data with rows from the behavior log.
# Only events of users that have a validation click, and that happened strictly
# before that user's first validation click, are kept.
if include_behavior_log and ("cate" in ad_feats or "brand" in ad_feats):
    training_data = pl.concat([training_data, pl
        .scan_parquet(os.path.join(data_dir, "behavior_log.parquet"))
        # Pass the user ids as a Series: `is_in` documents Expr / collection /
        # Series arguments, whereas a DataFrame only works through implicit
        # coercion. Filtering before `.collect()` limits what is materialised.
        .filter(pl.col("user").is_in(first_validation_click.get_column("user").unique()))
        .collect()
        .join(first_validation_click, on="user", how="inner")
        .filter(pl.col("timestamp") < pl.col("first_validation_ad_click_time"))
        # Remove fully identical rows.
        .unique()
        # Difference to the previous event timestamp with the same
        # (user, cate, brand, btag); null for the first such event.
        .with_columns(timediff = pl.col("timestamp").diff().over("user", "cate", "brand", "btag", order_by="timestamp"))
        .select("user", *pretraining_ad_feats, pl.col("btag").cast(pl.Int32), "timestamp", "timediff")
    # "diagonal": columns missing from the behavior rows (the remaining ad
    # features) are filled with nulls.
    ], how="diagonal")
training_data

In [None]:
# Users that still have at least one training row; validation rows of all
# other users are discarded.
valid_users = training_data.select("user").unique()
# `valid_users` stays a one-column DataFrame; hand `is_in` the column as a
# Series (a documented argument type) instead of relying on implicit
# DataFrame coercion.
validation_data = validation_data.filter(
    pl.col("user").is_in(valid_users.get_column("user"))
)
validation_data

In [None]:
# Optionally extend the training data with non-clicked ad impressions of the
# valid users that happened strictly before the user's first validation click.
# These rows are tagged with btag = -1.
if include_ad_non_clks:
    training_data = pl.concat([training_data, pl
        .scan_parquet(os.path.join(data_dir, "raw_sample.parquet"))
        # Pass the user ids as a Series: `is_in` documents Expr / collection /
        # Series arguments, whereas a DataFrame only works through implicit
        # coercion.
        .filter((pl.col("clk") == False) & (pl.col("user").is_in(valid_users.get_column("user"))))
        .collect()
        .join(first_validation_click, on="user", how="inner")
        .filter(pl.col("timestamp") < pl.col("first_validation_ad_click_time"))
        # One row per distinct (user, adgroup, timestamp).
        .unique(["user", "adgroup", "timestamp"])
        # Difference to the previous non-click timestamp of the same user on
        # the same adgroup; null for the first such impression.
        .with_columns(timediff = pl.col("timestamp").diff().over("user", "adgroup", order_by="timestamp"))
        # Inner join: impressions whose adgroup has no feature row are dropped.
        .join(
            other=pl.read_parquet(os.path.join(data_dir, "ad_feature.parquet")).select(ad_feats),
            on="adgroup", how="inner",
        )
        .select("user", *ad_feats, pl.lit(-1).alias("btag").cast(pl.Int32), "timestamp", "timediff")
    # "vertical" requires the same columns in the same order as training_data.
    ], how="vertical")
training_data

In [None]:
# Profile table restricted to the users present in the validation data.
# Left join: users without a row in user_profile.parquet are kept, with null
# profile features.
user_profile = (
    validation_data
    .select("user")
    .unique()
    .join(
        pl.read_parquet(os.path.join(data_dir, "user_profile.parquet")),
        on="user",
        how="left",
    )
    .select(user_feats)
    .unique()
)
# Distinct ad feature combinations that occur in either split.
ad_feature: pl.DataFrame = pl.concat(
    [
        split.select(ad_feats).unique()
        for split in (training_data, validation_data)
    ],
    how="vertical",
).unique()

In [None]:
# Attach the user profile features to both splits and fix the final column
# order: user features, ad features, then btag / timestamp / timediff.
training_data, validation_data = [
    split
    .join(user_profile, on="user", how="left")
    .select(*user_feats, *ad_feats, "btag", "timestamp", "timediff")
    for split in (training_data, validation_data)
]

In [None]:
# Persist the lookup tables and both splits as parquet files.
outdir = os.path.join("..", "data")
# Create the output directory if needed so the notebook does not fail on its
# very last step after all the preprocessing has already run.
os.makedirs(outdir, exist_ok=True)
user_profile.write_parquet(os.path.join(outdir, "user_profile.parquet"))
ad_feature.write_parquet(os.path.join(outdir, "ad_feature.parquet"))
training_data.write_parquet(os.path.join(outdir, "train_raw.parquet"))
validation_data.write_parquet(os.path.join(outdir, "test_raw.parquet"))