In [27]:
from dataclasses import dataclass, field

import numpy as np
import polars as pl
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import OneHotEncoder


@dataclass
class Model:
    """Linear regression of daily ``total_sales`` on one-hot encoded features.

    The features are the ISO weekday of ``date`` plus every column listed in
    ``columns``. Training and evaluation data are always derived from ``df``
    by aggregating ``total_sales`` per ``date`` and per combination of
    ``columns`` inside a date window.

    Attributes:
        df: Lazy frame that must provide ``date``, ``total_sales`` and every
            name in ``columns``. The latter are read through
            ``.cat.get_categories()``, so they must be categorical.
        columns: Names of the categorical feature columns.
        zeros: Default for the ``zeros`` argument of ``fit``/``predict``/
            ``score``: when true, every (date, category combination) in the
            window gets a row, with ``total_sales`` set to 0.0 where the
            source has none.
        pipeline: The scikit-learn pipeline, built once in ``__post_init__``.
    """

    df: pl.LazyFrame
    columns: list[str]
    zeros: bool = True
    pipeline: Pipeline = field(init=False)

    def __post_init__(self):
        # Building the pipeline collects the category lists from ``df``.
        self.pipeline = self._get_model()

    def fit(self, start_date, end_date, *, zeros: bool | None = None):
        """Fit the pipeline on the window ``[start_date, end_date]``.

        Args:
            start_date: Lower bound of the window (inclusive).
            end_date: Upper bound of the window (inclusive).
            zeros: Overrides ``self.zeros`` for this call when not ``None``.

        Returns:
            ``self``, so calls can be chained.
        """
        X, y = self._get_Xy(start_date, end_date, zeros=zeros)
        self.pipeline.fit(X, y)
        return self

    def predict(self, start_date, end_date, *, zeros: bool | None = None):
        """Return the pipeline's predictions for the rows of the window.

        The rows are built from ``self.df`` exactly as in ``fit``; the target
        column is ignored.
        """
        X, _ = self._get_Xy(start_date, end_date, zeros=zeros)
        return self.pipeline.predict(X)

    def score(self, start_date, end_date, *, zeros: bool | None = None):
        """Return ``pipeline.score`` evaluated on the rows of the window."""
        X, y = self._get_Xy(start_date, end_date, zeros=zeros)
        return self.pipeline.score(X, y)

    def _get_Xy(self, start_date, end_date, *, zeros: bool | None):
        """Load the window and split it into features and target."""
        data = self._get_data(start_date, end_date, zeros=zeros)
        return self._split_Xy(data)

    def _get_categories(self, name: str):
        """Return a lazy one-column frame holding the categories of ``name``."""
        return self.df.select(pl.col(name).cat.get_categories())

    def _get_zeros(self, date_range):
        """Build the lazy grid of every date crossed with every category.

        Args:
            date_range: The dates that form the ``date`` column of the grid.

        Returns:
            A lazy frame with ``date`` and one column per entry of
            ``self.columns``, containing their full cross product.
        """
        dates = pl.DataFrame({"date": date_range}).lazy()
        for name in self.columns:
            # The categories are cast back to Categorical so the grid can be
            # joined with the categorical columns of the aggregated data.
            cats = self._get_categories(name).cast({name: pl.Categorical})
            dates = dates.join(cats, how="cross")
        return dates

    def _get_data(self, start_date, end_date, *, zeros: bool | None):
        """Aggregate ``total_sales`` per date and category inside the window.

        Args:
            start_date: Lower bound of the window (inclusive).
            end_date: Upper bound of the window (inclusive).
            zeros: Whether to add a row for every (date, category
                combination) lacking sales; ``None`` falls back to
                ``self.zeros``.

        Returns:
            A collected frame with ``date``, the feature columns,
            ``total_sales`` (nulls replaced by 0.0) and ``weekday``.
        """
        if zeros is None:
            zeros = self.zeros
        # Restrict to the window before aggregating; ``date`` is a grouping
        # key, so the result is the same as filtering afterwards.
        data = (
            self.df.filter(pl.col("date").is_between(start_date, end_date))
            .group_by("date", *self.columns)
            .agg(pl.col("total_sales").sum())
            .sort("date")
        )
        if zeros:
            # Left-join the sales onto the full grid so that combinations
            # without sales show up as nulls, which are filled below.
            data = self._get_zeros(
                pl.date_range(start_date, end_date, eager=True)
            ).join(data, on=["date", *self.columns], how="left")
        return data.with_columns(
            pl.col("total_sales").fill_null(0.0),
            weekday=pl.col("date").dt.weekday(),
        ).collect()

    def _split_Xy(self, df: pl.DataFrame):
        """Split a collected frame into pandas features ``X`` and target ``y``.

        ``y`` is a single-column DataFrame, not a Series.
        """
        X = df.select("date", "weekday", *self.columns).to_pandas()
        y = df.select("total_sales").to_pandas()
        return X, y

    def _get_model(self):
        """Build the one-hot encoder + linear regression pipeline.

        The encoder categories are fixed up front: weekdays 1..7 followed by
        the categories of each feature column as found in ``self.df``.
        """
        column_categories = [
            self._get_categories(c).collect().to_series().to_list()
            for c in self.columns
        ]
        pipeline = make_pipeline(
            ColumnTransformer(
                [
                    (
                        "encode",
                        OneHotEncoder(
                            categories=[np.arange(1, 8)] + column_categories
                        ),
                        ["weekday", *self.columns],
                    ),
                ]
            ),
            LinearRegression(fit_intercept=True),
        )
        return pipeline


# Lazily scan the sales table; it is only read when a query built on it is
# collected. Model reads the "date" and "total_sales" columns from it, plus the
# categorical feature columns it is given.
df = pl.scan_parquet("../data/wide.parquet")

In [51]:
# Fit on the 7 days up to 2023-01-01, then score on the 7 days from it.
# NOTE(review): Model filters with `is_between`, which includes both bounds by
# default, so `date` itself is part of both the fit and the score window.
model = Model(df, ["store_id", "subgroup"])
date = pl.date(2023, 1, 1)
model.fit(date - pl.duration(days=7), date)
model.score(date, date + pl.duration(days=7))

0.17494030501152436

In [2]:
# Parse the test ids ("<store_id>_<subgroup>_<date>") into typed feature
# columns and add the weekday feature.
df_test = (
    pl.scan_csv("../data/ids_test.csv")
    .with_columns(
        pl.col("STORE_SUBGROUP_DATE_ID")
        # Two splits give exactly the three fields that are named below.
        .str.split_exact("_", 2)
        .struct.rename_fields(["store_id", "subgroup", "date"])
        .struct.unnest()
    )
    # NOTE(review): presumably maps the test label onto the name used in the
    # training data - confirm against the training categories.
    .with_columns(pl.col("subgroup").str.replace("Basketball", "Baseball"))
    .with_columns(
        pl.col("store_id").cast(pl.Categorical),
        pl.col("subgroup").cast(pl.Categorical),
        pl.col("date").cast(pl.Date),
    )
    .with_columns(weekday=pl.col("date").dt.weekday())
    .collect()
)

df_test

STORE_SUBGROUP_DATE_ID,store_id,subgroup,date,weekday
str,cat,cat,date,i8
"""S00001_Laptops_2024-01-01""","""S00001""","""Laptops""",2024-01-01,1
"""S00001_Laptops_2024-01-02""","""S00001""","""Laptops""",2024-01-02,2
"""S00001_Laptops_2024-01-03""","""S00001""","""Laptops""",2024-01-03,3
"""S00001_Laptops_2024-01-04""","""S00001""","""Laptops""",2024-01-04,4
"""S00001_Laptops_2024-01-05""","""S00001""","""Laptops""",2024-01-05,5
…,…,…,…,…
"""S00157_Puzzles_2024-01-03""","""S00157""","""Puzzles""",2024-01-03,3
"""S00157_Puzzles_2024-01-04""","""S00157""","""Puzzles""",2024-01-04,4
"""S00157_Puzzles_2024-01-05""","""S00157""","""Puzzles""",2024-01-05,5
"""S00157_Puzzles_2024-01-06""","""S00157""","""Puzzles""",2024-01-06,6


In [13]:
# Refit on 2023-12-25 through 2024-01-01 and score on that same window, so the
# number shown is an in-sample score.
# NOTE(review): both bounds are inclusive, so the window reaches 2024-01-01,
# the first date of the test ids. If the source has no sales rows for that
# day, the default zeros=True fills it with 0.0 targets - confirm.
date = pl.date(2023, 12, 25)
model.fit(date, date + pl.duration(days=7))
model.score(date, date + pl.duration(days=7))

0.1638075850980094

In [14]:
# Predict for the test ids. The fitted pipeline is called directly because
# Model.predict builds its rows from the frame the model was created with.
X_test = df_test.to_pandas()
# Despite the name, y_test holds the predictions, not observed targets.
y_test = model.pipeline.predict(X_test)
# squeeze() flattens the predictions into a single column - presumably they
# come back as (n, 1) because the target was fitted as a one-column frame.
df_test = df_test.with_columns(predict=y_test.squeeze())
df_test

STORE_SUBGROUP_DATE_ID,store_id,subgroup,date,weekday,predict
str,cat,cat,date,i8,f64
"""S00001_Laptops_2024-01-01""","""S00001""","""Laptops""",2024-01-01,1,394.960479
"""S00001_Laptops_2024-01-02""","""S00001""","""Laptops""",2024-01-02,2,457.726155
"""S00001_Laptops_2024-01-03""","""S00001""","""Laptops""",2024-01-03,3,460.091836
"""S00001_Laptops_2024-01-04""","""S00001""","""Laptops""",2024-01-04,4,461.10092
"""S00001_Laptops_2024-01-05""","""S00001""","""Laptops""",2024-01-05,5,451.168731
…,…,…,…,…,…
"""S00157_Puzzles_2024-01-03""","""S00157""","""Puzzles""",2024-01-03,3,54.88456
"""S00157_Puzzles_2024-01-04""","""S00157""","""Puzzles""",2024-01-04,4,55.893644
"""S00157_Puzzles_2024-01-05""","""S00157""","""Puzzles""",2024-01-05,5,45.961455
"""S00157_Puzzles_2024-01-06""","""S00157""","""Puzzles""",2024-01-06,6,52.168176
