# GOAL 

Stratified Train Test Split for PySpark 

In [1]:
import pandas as pd
import numpy as np 

In [2]:
import findspark

# Initialise findspark before any `pyspark` import is attempted.
# NOTE(review): presumably this is what lets the `pyspark` imports in the
# following cell resolve against the local Spark installation — confirm.
findspark.init()

In [2]:
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql import functions as pyf
from pyspark.sql.types import *
import functools

import matplotlib.pyplot as plt
import seaborn as sns 

In [3]:
# Global plotting defaults: figure size of (10, 5) and the "whitegrid" style.
# The order matters: `sns.set` applies its own default style first, then
# `set_style` overrides it.
sns.set(rc={'figure.figsize':(10,5)});
sns.set_style("whitegrid") # setting the style

In [4]:
# Start (or reuse) a Spark session on the local master "local[4]".
# NOTE(review): the app name "churn_pred" looks carried over from another
# project — confirm it is intended for this notebook.
# NOTE(review): with a local master the executors presumably run inside the
# driver JVM, so 'spark.executor.memory' likely has no effect here — confirm.
spark = (
    SparkSession.builder
    .master("local[4]")
    .appName("churn_pred")
    .config('spark.driver.memory', '3G')
    .config('spark.executor.memory', '5G')
    .getOrCreate()
)

In [5]:
# Read the input CSV, treating the first line as the header and letting Spark
# infer the column types from the data.
df = (
    spark.read
    .format("csv")
    .options(header=True, inferSchema=True)
    .load('../../input_data/iris_split.csv')
)

In [6]:
# Mark the DataFrame for caching; it is reused by several actions below
# (counts, sampling, and the train/test difference).
df= df.cache()

In [7]:
# Total number of rows loaded from the CSV.
df.count()

60

In [9]:
# Class balance of the full dataset: number of rows per `Color` value.
(
    df.groupBy("Color")
    .count()
    .show()
)

+-----+-----+
|Color|count|
+-----+-----+
|  Red|   30|
| Blue|   20|
|Green|   10|
+-----+-----+



TRAIN TEST SPLIT  
0.8 : 0.2 (train : test), stratified by `Color`

In [12]:
# Fraction of every class to draw into the training split; the remainder
# becomes the test split.
TRAIN_FRACTION = 0.8

# Build the {stratum value: sampling fraction} dict used for stratified
# sampling: every distinct `Color` gets the same fraction. The distinct labels
# are collected to the driver and the dict is built there, instead of adding a
# literal column and round-tripping through an RDD.
fractions = {
    row["Color"]: TRAIN_FRACTION
    for row in df.select("Color").distinct().collect()
}

In [13]:
# Show the per-class sampling fractions.
fractions

{'Red': 0.8, 'Blue': 0.8, 'Green': 0.8}

In [14]:
# Training split: stratified sample on `Color`, using the per-class fractions
# and a fixed seed so the draw is repeatable.
# NOTE(review): the per-class counts of the sample are not exactly
# fraction * class size (see the counts displayed below), so treat the
# fractions as approximate targets.
sampled_df = df.sampleBy("Color", fractions=fractions, seed=42)

In [15]:
# Per-class row counts of the training sample.
(
    sampled_df.groupBy("Color")
    .count()
    .show()
)

+-----+-----+
|Color|count|
+-----+-----+
|  Red|   22|
| Blue|   17|
|Green|    7|
+-----+-----+



In [18]:
# Total number of rows in the training sample.
sampled_df.count()

46

In [16]:
# Test split = every row of `df` that was not drawn into `sampled_df`.
# `exceptAll` (SQL EXCEPT ALL, Spark 2.4+) keeps duplicates: each sampled row
# cancels only one matching copy in `df`, so the train and test row counts add
# up to `df.count()`.
# `subtract` is EXCEPT DISTINCT: it de-duplicates the result and drops every
# copy of a row that also appears in the training sample, which silently
# shrinks the test split whenever the data contains identical rows.
sampled_df_test = df.exceptAll(sampled_df)

In [17]:
# Per-class row counts of the held-out test split.
(
    sampled_df_test.groupBy("Color")
    .count()
    .show()
)

+-----+-----+
|Color|count|
+-----+-----+
|Green|    3|
|  Red|    8|
| Blue|    3|
+-----+-----+



In [19]:
# Shut down the Spark session once the notebook is done with it.
spark.stop()