# Simple stratification in python

In [10]:
import pandas as pd
import polars as pl
import numpy as np

In [12]:
train_frac = 0.6
rng = np.random.default_rng()

In [15]:
dfs = pd.read_csv("data/simple_data.csv")

In [38]:
(
    dfs
    .groupby("Nationality", group_keys=False)
    .transform(
        lambda d: np.where(rng.permutation(len(d)) < len(d) * train_frac, "Train", "Val")
    )
)

Unnamed: 0,Name,Sex,Age,FavoriteSeason
0,Train,Train,Val,Val
1,Val,Train,Train,Val
2,Val,Train,Train,Val
3,Train,Val,Val,Train
4,Val,Val,Train,Train
5,Train,Val,Train,Val
6,Train,Train,Train,Train
7,Val,Train,Val,Train
8,Train,Train,Train,Train
9,Train,Val,Val,Train


In [7]:
(dfs
    .groupby("Nationality", group_keys=False)
    .apply(lambda d: d.sample(int(train_frac * d.shape[0]), random_state=42))
)

Unnamed: 0,Name,Sex,Age,Nationality,FavoriteSeason,Split
3,Dan,M,37,BE,Fall,Train
13,Eva,F,50,BE,Winter,Train
6,George,M,23,BE,Summer,Train
7,Jan,M,47,PL,Spring,Train
14,Ola,F,48,PL,Summer,Train
8,Małgosia,F,28,PL,Fall,Train
1,Bob,M,50,UK,Winter,Train
12,Ann,F,49,UK,Fall,Train
4,Eve,F,37,UK,Summer,Train


In [None]:
def add_split_column_pandas(df):
    train_idx = (
        df
        .groupby("Nationality", group_keys=False)
        .apply(lambda d: d.sample(int(train_frac * d.shape[0]), random_state=42))
        .index
    )
    df["Split"] = "Val"
    df.loc[train_idx, "Split"] = "Train"
    return df

In [29]:
def add_split_column_pandas(df):
    df["Split"] = (
        df
        .groupby("Nationality", group_keys=False)["Age"]
        .transform(
            lambda d: np.where(rng.permutation(len(d)) < len(d) * train_frac, "Train", "Val")
        )
    )
    return df

In [5]:
add_split_column_pandas(dfs)

Unnamed: 0,Name,Sex,Age,Nationality,FavoriteSeason,Split
0,Alice,F,23,UK,Fall,Val
1,Bob,M,50,UK,Winter,Train
2,Cecily,F,46,BE,Spring,Val
3,Dan,M,37,BE,Fall,Train
4,Eve,F,37,UK,Summer,Train
5,Felix,M,54,PL,Winter,Val
6,George,M,23,BE,Summer,Train
7,Jan,M,47,PL,Spring,Train
8,Małgosia,F,28,PL,Fall,Train
9,John,M,23,UK,Fall,Val


### Polars

In [34]:
def add_split_column_polars(df):
    dp = (
        df
        .with_column(
            pl.when(
                (pl.arange(0, pl.count()) < train_frac * pl.count())
                )
            .then("Train")
            .otherwise("Val")
            .shuffle(seed=42)
            .over("Nationality")
            .alias("Split")
        )
    )
    return dp

In [8]:
dl = pl.read_csv("data/simple_data.csv")
add_split_column_polars(dl)

Name,Sex,Age,Nationality,FavoriteSeason,Split
str,str,i64,str,str,str
"""Alice""","""F""",23,"""UK""","""Fall""","""Train"""
"""Bob""","""M""",50,"""UK""","""Winter""","""Val"""
"""Cecily""","""F""",46,"""BE""","""Spring""","""Train"""
"""Dan""","""M""",37,"""BE""","""Fall""","""Val"""
"""Eve""","""F""",37,"""UK""","""Summer""","""Train"""
"""Felix""","""M""",54,"""PL""","""Winter""","""Train"""
"""George""","""M""",23,"""BE""","""Summer""","""Val"""
"""Jan""","""M""",47,"""PL""","""Spring""","""Train"""
"""Małgosia""","""F""",28,"""PL""","""Fall""","""Val"""
"""John""","""M""",23,"""UK""","""Fall""","""Train"""


## Benchmarking

Benchmark on a similar dataframe but with 1.5mln rows.

In [114]:
pl.concat([dl] * 100_000).write_csv("data/simple_data_copied.csv")

In [32]:
%%time
df = pd.read_csv("data/simple_data_copied.csv")
add_split_column_pandas(df)

CPU times: user 1.03 s, sys: 220 ms, total: 1.25 s
Wall time: 1.26 s


Unnamed: 0,Name,Sex,Age,Nationality,FavoriteSeason,Split
0,Alice,F,23,UK,Fall,Train
1,Bob,M,50,UK,Winter,Val
2,Cecily,F,46,BE,Spring,Val
3,Dan,M,37,BE,Fall,Train
4,Eve,F,37,UK,Summer,Train
...,...,...,...,...,...,...
1499995,Bart,M,41,BE,Winter,Val
1499996,Fryderyk,M,59,PL,Summer,Val
1499997,Ann,F,49,UK,Fall,Train
1499998,Eva,F,50,BE,Winter,Train


In [35]:
%%time
dp = pl.read_csv("data/simple_data_copied.csv")
add_split_column_polars(dp)

CPU times: user 865 ms, sys: 233 ms, total: 1.1 s
Wall time: 683 ms


Name,Sex,Age,Nationality,FavoriteSeason,Split
str,str,i64,str,str,str
"""Alice""","""F""",23,"""UK""","""Fall""","""Train"""
"""Bob""","""M""",50,"""UK""","""Winter""","""Train"""
"""Cecily""","""F""",46,"""BE""","""Spring""","""Train"""
"""Dan""","""M""",37,"""BE""","""Fall""","""Train"""
"""Eve""","""F""",37,"""UK""","""Summer""","""Train"""
"""Felix""","""M""",54,"""PL""","""Winter""","""Train"""
"""George""","""M""",23,"""BE""","""Summer""","""Train"""
"""Jan""","""M""",47,"""PL""","""Spring""","""Train"""
"""Małgosia""","""F""",28,"""PL""","""Fall""","""Train"""
"""John""","""M""",23,"""UK""","""Fall""","""Val"""
