In [1]:
import polars as pl
import polars.selectors as cs

from predictables.encoding.src.lagged_mean_encoding import (
    CredWtdMean,
    DynamicRollingSum,
)

In [2]:
# One-off conversion: read the "Values" sheet of the Excel workbook and
# persist it as Parquet so the next cell can scan it lazily.
#
# `DataFrame.write_parquet` returns None, so the result is deliberately not
# bound to a name (the previous `lf = ...` only ever held None). The output is
# written to the same absolute path that the following cell scans, so the
# notebook no longer depends on the kernel's working directory being /app.
pl.read_excel(
    "/app/ts_testing.xlsx", sheet_name="Values", engine="calamine"
).write_parquet("/app/ts_testing.parquet")

In [3]:
# Lazily scan the Parquet extract, keeping the raw inputs (index, date,
# product keys, hit) plus the precomputed 30/30 reference columns that the
# results below are compared against.
lf = pl.scan_parquet("/app/ts_testing.parquet").select(
    [
        "index",
        "date",
        "product_cat",
        "product_subcat",
        "product_code",
        "hit",
        "total_30_30_average_laplace(1)_smoothed",
        "code_30_30_average_laplace(1)_smoothed",
        "cred_wtd_30_30_average",
        "cred_wtd_30_30_Z",
        "cred_wtd_30_30_n",
    ]
)

# Cast every string column to Categorical. Applying the cast to the selector
# itself keeps the original column names and avoids resolving the lazy schema
# (`lf.select(...).columns`) just to build a list of names.
lf = lf.with_columns(cs.string().cast(pl.Categorical))

# Rolling sum of `hit` per `product_code` with offset=30 / window=30, rejoined
# onto the frame; the displayed result names this column
# "ROLLING_SUM(hit[product_code])[lag:30/win:30]".
# NOTE(review): the units of offset/window (presumably days) are defined by
# DynamicRollingSum, not visible here - confirm.
lf = (
    DynamicRollingSum()
    .lf(lf)
    .date_col("date")
    .x_col("hit")
    .cat_col("product_code")
    .index_col("index")
    .offset(30)
    .window(30)
    .rejoin(True)
    .run()
)

# Credibility-weighted mean of hit / count per `product_code`, with the same
# offset/window and laplace_alpha=1. The displayed result gains `individual`,
# `collective`, `n`, `Z` and "CRED_WTD(hit[product_code])[lag:30/win:30]".
# NOTE(review): "count" is not among the columns selected above; it presumably
# comes from the DynamicRollingSum step - confirm.
lf = (
    CredWtdMean()
    .lf(lf)
    .date_col("date")
    .numerator_col("hit")
    .denominator_col("count")
    .cat_col("product_code")
    .index_col("index")
    .offset(30)
    .window(30)
    .rejoin(True)
    .laplace_alpha(1)
    .run()
)

# Materialize only the last 10 rows for inspection.
lf.tail(10).collect()

index,date,product_cat,product_subcat,product_code,hit,total_30_30_average_laplace(1)_smoothed,code_30_30_average_laplace(1)_smoothed,cred_wtd_30_30_average,cred_wtd_30_30_Z,cred_wtd_30_30_n,ROLLING_SUM(hit[product_code])[lag:30/win:30],count,individual,collective,n,Z,CRED_WTD(hit[product_code])[lag:30/win:30]
i64,date,cat,cat,cat,i64,f64,f64,f64,f64,i64,f64,i32,f64,f64,f64,f64,f64
2549,2023-08-25,"""D""","""e""","""D-e""",1,0.2459,1.0,0.33891,0.28571,2,0.0,1,1.0,0.245902,2.0,0.285714,0.461358
2550,2023-08-25,"""C""","""e""","""C-e""",1,0.2459,0.25,0.27088,0.375,3,0.0,1,0.25,0.245902,3.0,0.375,0.247439
2551,2023-08-26,"""A""","""f""","""A-f""",1,0.23333,0.33333,0.18333,0.5,5,0.0,1,0.333333,0.233333,5.0,0.5,0.283333
2552,2023-08-26,"""D""","""d""","""D-d""",0,0.23333,1.0,0.28419,0.16667,1,0.0,1,1.0,0.233333,1.0,0.166667,0.361111
2553,2023-08-27,"""C""","""e""","""C-e""",0,0.25,0.25,0.2886,0.375,3,0.0,1,0.25,0.25,3.0,0.375,0.25
2554,2023-08-27,"""A""","""c""","""A-c""",0,0.25,0.2,0.20238,0.44444,4,0.0,1,0.2,0.25,4.0,0.444444,0.227778
2555,2023-08-28,"""B""","""c""","""B-c""",0,0.25424,0.16667,0.21046,0.5,5,0.0,1,0.166667,0.254237,5.0,0.5,0.210452
2556,2023-08-28,"""C""","""d""","""C-d""",0,0.25424,0.6,0.28013,0.44444,4,0.0,1,0.6,0.254237,4.0,0.444444,0.40791
2557,2023-08-28,"""C""","""c""","""C-c""",0,0.25424,0.5,0.26395,0.16667,1,0.0,1,0.5,0.254237,1.0,0.166667,0.295198
2558,2023-08-28,"""C""","""e""","""C-e""",0,0.25424,0.25,0.27609,0.375,3,0.0,1,0.25,0.254237,3.0,0.375,0.252648


In [4]:
# Sanity check on the last 10 rows: rebuild the credibility-weighted value by
# hand from the reference columns as
#     code_average * Z + total_average * (1 - Z)
# and show it as "cw" next to the CRED_WTD column produced by CredWtdMean.
(
    lf.tail(10)
    .select(
        "total_30_30_average_laplace(1)_smoothed",
        "code_30_30_average_laplace(1)_smoothed",
        "cred_wtd_30_30_average",
        "cred_wtd_30_30_Z",
        "cred_wtd_30_30_n",
        "CRED_WTD(hit[product_code])[lag:30/win:30]",
    )
    .with_columns(
        cw=(
            pl.col("code_30_30_average_laplace(1)_smoothed")
            * pl.col("cred_wtd_30_30_Z")
            + pl.col("total_30_30_average_laplace(1)_smoothed")
            * (1 - pl.col("cred_wtd_30_30_Z"))
        )
    )
    .collect()
)

total_30_30_average_laplace(1)_smoothed,code_30_30_average_laplace(1)_smoothed,cred_wtd_30_30_average,cred_wtd_30_30_Z,cred_wtd_30_30_n,CRED_WTD(hit[product_code])[lag:30/win:30],cw
f64,f64,f64,f64,i64,f64,f64
0.2459,1.0,0.33891,0.28571,2,0.461358,0.461354
0.2459,0.25,0.27088,0.375,3,0.247439,0.2474375
0.23333,0.33333,0.18333,0.5,5,0.283333,0.28333
0.23333,1.0,0.28419,0.16667,1,0.361111,0.361111
0.25,0.25,0.2886,0.375,3,0.25,0.25
0.25,0.2,0.20238,0.44444,4,0.227778,0.227778
0.25424,0.16667,0.21046,0.5,5,0.210452,0.210455
0.25424,0.6,0.28013,0.44444,4,0.40791,0.40791
0.25424,0.5,0.26395,0.16667,1,0.295198,0.295201
0.25424,0.25,0.27609,0.375,3,0.252648,0.25265


In [5]:
# NOTE(review): this cell is entirely commented out and does nothing when run.
# It sketches two LaplaceSmoothedMean runs over hit / count with offset 30,
# window 30 and laplace_alpha 1: one grouped by "product_code" and renamed
# "individual", one without a cat_col and renamed "complement".
# LaplaceSmoothedMean is not imported by the notebook's import cell, so it
# would have to be imported before re-enabling this. Consider deleting the
# cell if it is no longer needed.
# lf = (
#     LaplaceSmoothedMean()
#     .lf(lf)
#     .date_col("date")
#     .numerator_col("hit")
#     .denominator_col("count")
#     .cat_col("product_code")
#     .index_col("index")
#     .offset(30)
#     .window(30)
#     .rejoin(True)
#     .laplace_alpha(1)
#     .rename("individual")
#     .run()
# )

# lf = (
#     LaplaceSmoothedMean()
#     .lf(lf)
#     .date_col("date")
#     .numerator_col("hit")
#     .denominator_col("count")
#     .index_col("index")
#     .offset(30)
#     .window(30)
#     .rejoin(True)
#     .laplace_alpha(1)
#     .rename("complement")
#     .run()
# )

# # lf.tail().collect()

In [6]:
# NOTE(review): this cell is entirely commented out and does nothing when run.
# It sketches a manual credibility weighting: a DynamicRollingCount of `hit`
# per "product_code" renamed "n", a constant K = 5.0, Z = n / (n + K), and
# cred_wtd = individual * Z + complement * (1 - Z) rounded to 5 decimals.
# DynamicRollingCount is not imported by the notebook's import cell, and the
# "individual" / "complement" names match the renames in the disabled
# LaplaceSmoothedMean cell, so this presumably depends on that cell - confirm
# before re-enabling, or delete if superseded.
# lf = (
#     (
#         DynamicRollingCount()
#         .lf(lf)
#         .date_col("date")
#         .x_col("hit")
#         .cat_col("product_code")
#         .index_col("index")
#         .offset(30)
#         .window(30)
#         .rejoin(True)
#         .op("ROLLING_COUNT")
#         .rename("n")
#         .run()
#     )
#     .with_columns([pl.lit(5).cast(pl.Float64).alias("K")])
#     .with_columns([pl.col("n").truediv(pl.col("n") + pl.col("K")).alias("Z")])
# ).with_columns(
#     [
#         (pl.col("individual") * pl.col("Z") + pl.col("complement") * (1 - pl.col("Z")))
#         .round(5)
#         .alias("cred_wtd")
#     ]
# )

# lf.tail().collect()

In [7]:
# NOTE(review): this cell is entirely commented out and does nothing when run.
# It repeats the CredWtdMean builder chain used in the In [3] cell (same
# columns, offset 30, window 30, laplace_alpha 1) and then drops the "count",
# "complement" and "K" columns. "complement" and "K" are only created by other
# commented-out cells, so the drop would presumably fail against the frame
# built by the active cells - confirm before re-enabling, or delete the cell.
# lf = (
#     CredWtdMean()
#     .lf(lf)
#     .date_col("date")
#     .numerator_col("hit")
#     .denominator_col("count")
#     .cat_col("product_code")
#     .index_col("index")
#     .offset(30)
#     .window(30)
#     .rejoin(True)
#     .laplace_alpha(1)
#     # .rename("cred_wtd")
#     .run()
# ).drop(["count", "complement", "K"])

# lf.tail().collect()