# In [69]:
import pandas as pd
import numpy as np

import os



# In [123]:
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import train_test_split

class TrainTestSplitPipeline:
    """
    Split data into training and testing sets. If stratification is
    requested, perform a stratified shuffle split based on the selected
    column; otherwise perform a plain random train/test split.

    If a column is selected as the stratification category and ``bins``
    is provided, the column is first cut into categories and the split is
    then performed on those categories.

    Parameters
    ----------
    test_size: float, int, default=0.2
        If float, should be between 0.0 and 1.0 and represents the
        proportion of samples to include in the test split (rounded up).
        If int, represents the absolute number of test samples.
        Used by both the stratified and the non-stratified split.

    Examples
    --------
    >>> split = TrainTestSplitPipeline()
    >>> split.fit(housing)
    >>> split.stratify(column='some_column', bins=[0, 1, 2, np.inf])
    >>> split.split()
    >>> train_set = split.get_train_set()
    >>> test_set = split.get_test_set()
    """

    def __init__(self, test_size=0.2):
        self.test_size = test_size

        # For practice purposes, the parameters below are fixed values.
        self.random_state = 42
        self.n_splits = 1

        # No stratification until stratify() is called; split() checks this
        # and falls back to a plain random split when it is None.
        self.stratify_cat = None

    def fit(self, X):
        """
        Store the DataFrame that will be split.

        Any stratification computed for previously fitted data is cleared,
        because it would not line up with the rows of the new data.
        """
        self.data = X
        self.stratify_cat = None

    def stratify(self, column=None, bins=None):
        """
        Set the stratification categories used when performing the split.

        Must be called after ``fit``.

        Parameters
        ----------
        column: string
            Must be one of the columns in the fitted data.
        bins: array
            If provided, ``pd.cut()`` is applied to the selected column
            and the resulting bins are used as the stratification
            category. The bins are labelled with 1-based integers
            (1, 2, ..., len(bins) - 1).
        """
        if bins is not None:
            self.stratify_cat = pd.cut(
                self.data[column],
                bins=bins,
                labels=list(range(1, len(bins))),
            )
        else:
            self.stratify_cat = self.data[column]

    def split(self):
        """
        Split the fitted data into ``train_set`` and ``test_set``.

        Uses ``StratifiedShuffleSplit`` on ``stratify_cat`` when
        ``stratify()`` has been called, otherwise ``train_test_split``.
        Both honour ``test_size`` and ``random_state``.
        """
        if self.stratify_cat is not None:
            splitter = StratifiedShuffleSplit(
                n_splits=self.n_splits,
                test_size=self.test_size,
                random_state=self.random_state,
            )

            for train_idx, test_idx in splitter.split(self.data, self.stratify_cat):
                # split() yields positional indices, so select with iloc;
                # loc would pick wrong rows (or raise) when the DataFrame
                # does not have a default 0..n-1 index.
                self.train_set = self.data.iloc[train_idx]
                self.test_set = self.data.iloc[test_idx]
        else:
            self.train_set, self.test_set = train_test_split(
                self.data,
                test_size=self.test_size,
                random_state=self.random_state,
            )

    def get_train_set(self):
        """Return a copy of the training set produced by ``split()``."""
        return self.train_set.copy()

    def get_test_set(self):
        """Return a copy of the test set produced by ``split()``."""
        return self.test_set.copy()

# In [70]:
# Directory holding the housing dataset; a relative path, so it is resolved
# against the current working directory when files are opened.
HOUSING_PATH = os.path.join("datasets", "housing")

# In [71]:
def load_housing_data(housing_path=HOUSING_PATH):
    """
    Load the housing dataset from ``housing.csv``.

    Parameters
    ----------
    housing_path: str
        Directory containing ``housing.csv``. Defaults to ``HOUSING_PATH``
        as it was when this function was defined.

    Returns
    -------
    pandas.DataFrame
        The contents of ``<housing_path>/housing.csv`` parsed by
        ``pd.read_csv``.
    """
    csv_path = os.path.join(housing_path, "housing.csv")
    return pd.read_csv(csv_path)

# In [120]:
# Load the raw housing DataFrame from the default location (the
# load_housing_data default, i.e. <HOUSING_PATH>/housing.csv).
housing = load_housing_data()

# In [124]:
# Stratify the train/test split on median income, bucketed into five
# income categories, so both sets share the same income distribution.
split = TrainTestSplitPipeline()
split.fit(housing)
split.stratify(
    column='median_income',
    bins=[0, 1.5, 3.0, 4.5, 6, np.inf],
)
split.split()

# Independent copies of each partition.
train_set, test_set = split.get_train_set(), split.get_test_set()