---

<!-- <a href="https://github.com/rraadd88/roux/blob/master/examples/roux_stat_classify.ipynb"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
 -->
 
## 🏷 Classification.

In [1]:
# install extra requirements
# !pip install roux[stat]
# loading non-roux requirements
import pandas as pd

## Demo data

In [2]:
## random state
import numpy as np
np.random.seed(1)
## demo dataframe
data=pd._testing.makeDataFrame()
data.head(1)

Unnamed: 0,A,B,C,D
LRmijlfpaq,-0.650355,-0.541145,-0.089155,0.415311


## Split table into overlapping subsets/splits
Uses [`sklearn.model_selection.KFold`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html)

In [3]:
%time
## import the convenience function from roux
from roux.stat.classify import get_cvsplits
cvs=get_cvsplits(
    X=data.loc[:,['A','B']],
    y=data['C'],
    cv=5,
    random_state=1,
    outtest=False,
    )

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 4.53 µs


In [4]:
## Example splits
cvs[0]['X'].head(1)

Unnamed: 0,A,B
0,-0.650355,-0.541145


In [5]:
cvs[0]['y'].head(1)

0   -0.089155
Name: C, dtype: float64

In [6]:
## testing the fraction of the data per fold 
for i in cvs:
    data_mapped=data.copy()
    for k in ['X','y']:
        df_=cvs[i][k].to_frame() if isinstance(cvs[i][k],pd.Series) else cvs[i][k]
        data_mapped=(data_mapped
        .merge(
            right=df_,
            on=df_.columns.tolist(),
            how='left',
            validate="1:1",
            indicator=True,
        )
        .rename(columns={'_merge':k},errors='raise')
        )
    assert (data_mapped['X']==data_mapped['y']).all()
    print(f'CV#{i}:',data_mapped['X'].map({'both':'subset','left_only':'left out'}).value_counts(normalize=True).to_dict())
        # break

CV#0: {'subset': 0.8, 'left out': 0.2}
CV#1: {'subset': 0.8, 'left out': 0.2}
CV#2: {'subset': 0.8, 'left out': 0.2}
CV#3: {'subset': 0.8, 'left out': 0.2}
CV#4: {'subset': 0.8, 'left out': 0.2}


In [7]:
## validate that all the data is covered by the folds
assert len(set(np.array([cvs[i]['X']['A'].tolist() for i in cvs]).ravel()) - set(data['A'].tolist())) ==0

#### Documentation
[`roux.stat.classify`](https://github.com/rraadd88/roux#module-rouxstatclassify)