In [1]:
import pathlib
import polars as pl

# Print floats unabbreviated in displayed frames. The trailing ';' suppresses the
# Config object's repr, which would otherwise be echoed as this cell's output.
pl.Config(set_fmt_float="full");

<polars.config.Config at 0x78005054b710>

In [2]:
def get_group_by_transforms(
    dataframe: pl.LazyFrame,
    group_by_cols: list[str],
    exclude_cols: tuple[str, ...] = ("num_group1", "num_group2"),
) -> list[pl.Expr]:
    """Build per-column aggregation expressions for ``group_by(...).agg(...)``.

    The aggregation is chosen from each column's dtype:

    * ``UInt8`` (e.g. Boolean columns cast by ``convert_boolean_to_int``): median.
    * any other numeric dtype: median and mean.
    * ``String``: most frequent value; ties resolve to the first value in sort order.
    * ``Date``: first value in the group, keeping the original column name.

    Columns of any other dtype produce no expression and are therefore dropped
    from the aggregated result.

    Args:
        dataframe: Frame whose schema drives the expressions. It is not collected.
        group_by_cols: Grouping key column(s). A single name may be passed as a
            plain string.
        exclude_cols: Extra columns to skip in addition to the grouping keys.

    Returns:
        A list of ``pl.Expr`` to splat into ``agg``.
    """
    # A bare string must not be used for membership tests: ``"id" in "case_id"``
    # is a substring match and would silently skip unrelated columns.
    if isinstance(group_by_cols, str):
        group_by_cols = [group_by_cols]
    skipped = set(group_by_cols) | set(exclude_cols)
    numeric_types = tuple(pl.NUMERIC_DTYPES)

    transforms: list[pl.Expr] = []
    # ``schema`` resolves the lazy plan's schema once (``columns`` + ``dtypes`` did it twice).
    for col, col_dtype in dataframe.schema.items():
        if col in skipped:
            continue

        # UInt8 is itself numeric, so it must be tested before the generic numeric branch.
        if isinstance(col_dtype, pl.UInt8):
            transforms.append(pl.col(col).median().alias(f"{col}_median"))

        elif isinstance(col_dtype, numeric_types):
            transforms.extend([
                pl.col(col).median().alias(f"{col}_median"),
                pl.col(col).mean().alias(f"{col}_mean"),
            ])

        elif isinstance(col_dtype, pl.String):
            # ``mode`` can return several values in no guaranteed order; sorting
            # before ``first`` makes the tie-break reproducible across runs.
            transforms.append(pl.col(col).mode().sort().first().alias(f"{col}_mode"))

        elif isinstance(col_dtype, pl.Date):
            transforms.append(pl.col(col).first())

    return transforms

def convert_boolean_to_int(dataframe: pl.LazyFrame) -> pl.LazyFrame:
    """Return ``dataframe`` with every Boolean column cast to ``pl.UInt8``.

    Non-Boolean columns are left unchanged and the result stays lazy.
    """
    # Resolve the schema once instead of separately for ``columns`` and ``dtypes``.
    transforms = [
        pl.col(col).cast(pl.UInt8)
        for col, col_dtype in dataframe.schema.items()
        if isinstance(col_dtype, pl.Boolean)
    ]
    return dataframe.with_columns(*transforms)

In [3]:
# Every input and output file of this notebook lives under BASE_DIR.
BASE_DIR = pathlib.Path("../data/datalake/silver")

# One input file per depth level.
DEPTH_0_PATH, DEPTH_1_PATH, DEPTH_2_PATH = (
    BASE_DIR.joinpath(file_name)
    for file_name in ("depth_0.parquet", "depth_1.parquet", "depth_2.parquet")
)
# Destination of the merged result written by ``save_dataframe_to_disk``.
OUTPUT_PATH = BASE_DIR.joinpath("preprocessed.parquet")

In [4]:
# Keys used to join depth 1 with depth 2; the first entry is the per-case id.
JOIN_COLUMNS = ["case_id", "num_group1"]

# Lazy scans: nothing is read from disk until a query on these frames is collected.
data_depth0: pl.LazyFrame
data_depth1: pl.LazyFrame
data_depth2: pl.LazyFrame
data_depth0, data_depth1, data_depth2 = (
    pl.scan_parquet(source_path)
    for source_path in (DEPTH_0_PATH, DEPTH_1_PATH, DEPTH_2_PATH)
)

In [5]:
# Left-join depth-2 rows onto depth-1 rows by (case_id, num_group1), then cast
# Boolean columns to UInt8 so they can be aggregated numerically.
data = (
    data_depth1
    .join(data_depth2, on=JOIN_COLUMNS, how="left")
    .pipe(convert_boolean_to_int)
)

In [6]:
# Per-case age: the most frequent non-null value. ``mode`` may return several
# values in no guaranteed order, so sort before ``first`` to make ties resolve
# to the smallest age reproducibly.
# NOTE(review): ``data`` here is depth 1 left-joined with depth 2, so a depth-1 row
# appears once per matching depth-2 row; confirm that weighting of the mode is intended.
age: pl.LazyFrame = (
    data.select("case_id", "age")
    .group_by("case_id")
    .agg(pl.col("age").drop_nulls().mode().sort().first().alias("age"))
)

In [7]:
# Remove the raw age column before the generic aggregation; the per-case mode
# held in ``age`` is joined back when the output is written.
data = data.drop("age")

In [8]:
# Collapse to one row per case. The key is passed as a list: a bare string would
# turn ``col in group_by_cols`` into a substring test and skip unrelated columns.
data = data.group_by(JOIN_COLUMNS[0]).agg(
    *get_group_by_transforms(data, group_by_cols=[JOIN_COLUMNS[0]])
)

In [9]:
# Attach the per-case aggregates to depth 0; the left join keeps every depth-0 row.
data = data_depth0.join(data, on=JOIN_COLUMNS[0], how="left")

In [10]:
# Sanity check that the target column(s) survived the joins. Limiting to the first
# rows lets polars avoid sorting and materialising the whole frame just for display.
target_cols = [col for col in data.columns if "target" in col]
data.select("case_id", *target_cols).sort("case_id").head(10).collect()

case_id,target
i64,i64
0,0
1,0
2,0
3,0
4,1
5,0
6,0
7,0
8,0
9,0


In [11]:
def _percentile_boundaries(frame: pl.LazyFrame, column: str, num_percentiles: int) -> list[float]:
    """Return the sorted k/n percentiles (k = 0..n-1) of ``column`` followed by its max."""
    percentiles = [k / num_percentiles for k in range(num_percentiles)]
    stats = (
        frame.select(column)
        .describe(percentiles=percentiles)
        .filter(pl.col("statistic").str.contains("%") | pl.col("statistic").str.contains("max"))
        .to_dict(as_series=False)
    )
    return sorted(stats[column])


def save_dataframe_to_disk(
    filter_col: str,
    num_percentiles: int = 10,
    depth: int = 1,
    dataframe: "pl.LazyFrame | None" = None,
    age_frame: "pl.LazyFrame | None" = None,
) -> None:
    """Write a lazy frame to ``OUTPUT_PATH`` slice by slice.

    The frame is split into value ranges of ``filter_col`` bounded by its
    percentiles. Each range is left-joined with the per-case age table on
    ``case_id``, collected, and written to its own parquet file, so each
    ``collect`` only materialises one slice. The slice files are then merged
    into ``OUTPUT_PATH`` and deleted.

    Rows whose ``filter_col`` is null match no range and are not written.

    Args:
        filter_col: Numeric column whose percentiles define the slice ranges.
        num_percentiles: Number of slices to write. Must be at least 1.
        depth: Unused; kept for backward compatibility.
        dataframe: Frame to write. Defaults to the notebook-level ``data``.
        age_frame: Table joined on ``case_id``. Defaults to the notebook-level ``age``.

    Raises:
        ValueError: If ``num_percentiles`` is less than 1.
    """
    if num_percentiles < 1:
        raise ValueError(f"num_percentiles must be >= 1, got {num_percentiles}")

    source = data if dataframe is None else dataframe
    ages = age if age_frame is None else age_frame

    slice_dir = BASE_DIR / "all_depth"
    slice_dir.mkdir(parents=True, exist_ok=True)
    # Slices left over from an interrupted run would otherwise be merged into the output.
    for stale in slice_dir.glob("slice_*.parquet"):
        stale.unlink(missing_ok=True)

    boundaries = _percentile_boundaries(source, filter_col, num_percentiles)
    print(boundaries)

    last_index = len(boundaries) - 2
    for i, (lower, upper) in enumerate(zip(boundaries, boundaries[1:])):
        print(i)
        column = pl.col(filter_col)
        # Ranges are half-open [lower, upper) except the last, which must be closed
        # so rows equal to the maximum are not silently dropped.
        upper_ok = (column <= upper) if i == last_index else (column < upper)
        (
            source.filter((column >= lower) & upper_ok)
            .join(ages, on="case_id", how="left")
            .collect()
            .write_parquet(slice_dir / f"slice_{i}.parquet")
        )

    # Merge the slices, then delete them and the now-empty directory.
    pl.scan_parquet(slice_dir / "slice_*.parquet").sink_parquet(OUTPUT_PATH)
    for path in slice_dir.glob("slice_*.parquet"):
        path.unlink(missing_ok=True)
    slice_dir.rmdir()

# Set to False to reuse an existing OUTPUT_PATH file instead of rewriting it.
REBUILD_OUTPUT = True

if REBUILD_OUTPUT or not OUTPUT_PATH.exists():
    save_dataframe_to_disk(filter_col="case_id")
data = pl.scan_parquet(OUTPUT_PATH)

# Kept at top level: Jupyter only renders a cell's last top-level expression, so a
# display nested inside the ``if`` above would never be shown.
data.sort(JOIN_COLUMNS[0]).head().collect()

[0.0, 198100.0, 689865.0, 842530.0, 995196.0, 1357358.0, 1510024.0, 1662690.0, 1815355.0, 2550747.0, 2703454.0]
0
1
2
3
4
5
6
7
8
9


In [12]:
# Spot-check the re-attached per-case age on the first few cases.
data.select("case_id", "age").sort("case_id").head(5).collect()

case_id,age
i64,f32
0,32.5315055847168
1,61.46575164794922
2,44.12328720092773
3,25.44109535217285
4,25.024658203125
