# Split Data

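This notebook performs an 80/20 split of the Pokédex feature table into train and test sets. The split is stratified on the primary type (`type_1`), so each type is represented in roughly the same proportion in both sets.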

## Setup

In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
# Feature table produced by the previous pipeline step.
# NOTE(review): the loaded frame contains an "Unnamed: 0" column, presumably a
# row index written to disk by the previous step. It is carried unchanged into
# the train/test CSVs saved at the end of this notebook. Confirm that downstream
# code does not treat it as a feature, or read this file with index_col=0.
IN_PATH = "../data/2-generate_features.csv"
pokedex = pd.read_csv(filepath_or_buffer=IN_PATH)

## Split

In [3]:
# Hold out a fixed fraction of rows for testing. Stratifying on the primary
# type keeps each type's share approximately equal in both splits. The fixed
# seed makes the split reproducible across kernel restarts.
TEST_SIZE = 0.2
SPLIT_SEED = 441
STRATIFY_COL = "type_1"

train, test = train_test_split(
    pokedex,
    test_size=TEST_SIZE,
    shuffle=True,
    random_state=SPLIT_SEED,
    stratify=pokedex[STRATIFY_COL],
)

# Sanity checks: the two splits must partition the input, so every row of
# `pokedex` lands in exactly one of them.
assert len(train) + len(test) == len(pokedex), "split lost or duplicated rows"
assert train.index.intersection(test.index).empty, "train and test overlap"

In [4]:
# Number of rows in the training split, followed by a preview of its first rows.
print(train.shape[0])
train.head(n=5)

836


Unnamed: 0,generation,status,type_number,type_1,type_2,height_m,weight_kg,abilities_number,total_points,hp,...,sprite_red_mean,sprite_green_mean,sprite_blue_mean,sprite_brightness_mean,sprite_red_sd,sprite_green_sd,sprite_blue_sd,sprite_brightness_sd,sprite_overflow_vertical,sprite_overflow_horizontal
531,4.0,Normal,2.0,Dragon,Ground,1.9,95.0,2.0,600.0,108.0,...,0.431611,0.399533,0.534143,0.455096,0.259289,0.2057,0.339243,0.21828,0.0,0.0
422,3.0,Normal,1.0,Normal,,1.0,22.0,2.0,440.0,60.0,...,0.489944,0.528789,0.327594,0.448776,0.267602,0.273251,0.149213,0.219278,0.0,0.0
659,5.0,Normal,1.0,Grass,,1.0,28.0,3.0,461.0,75.0,...,0.426992,0.461558,0.325674,0.404741,0.317187,0.320635,0.220401,0.265111,0.0,0.0
1028,8.0,Legendary,2.0,Fairy,Steel,2.8,355.0,1.0,720.0,92.0,...,0.472524,0.460141,0.384571,0.439079,0.34165,0.313592,0.274833,0.287598,0.196429,0.0
948,8.0,Normal,1.0,Grass,,2.1,90.0,2.0,530.0,100.0,...,0.318634,0.370976,0.286005,0.325205,0.239398,0.222001,0.168121,0.197868,0.0,0.0


In [5]:
# Number of rows in the test split, followed by a preview of its first rows.
print(test.shape[0])
test.head(n=5)

209


Unnamed: 0,generation,status,type_number,type_1,type_2,height_m,weight_kg,abilities_number,total_points,hp,...,sprite_red_mean,sprite_green_mean,sprite_blue_mean,sprite_brightness_mean,sprite_red_sd,sprite_green_sd,sprite_blue_sd,sprite_brightness_sd,sprite_overflow_vertical,sprite_overflow_horizontal
624,5.0,Normal,2.0,Psychic,Flying,0.4,2.1,3.0,323.0,65.0,...,0.339377,0.386275,0.4409,0.38885,0.279958,0.323926,0.380297,0.319476,0.0,0.0
336,3.0,Normal,2.0,Water,Flying,0.6,9.5,3.0,270.0,40.0,...,0.51246,0.51946,0.533966,0.521962,0.430521,0.43002,0.449841,0.430802,0.0,0.0
272,2.0,Normal,2.0,Water,Rock,0.6,5.0,3.0,410.0,65.0,...,0.602974,0.394052,0.415114,0.470714,0.434139,0.28653,0.297301,0.330996,0.0,0.0
544,4.0,Normal,1.0,Water,,0.4,7.0,3.0,330.0,49.0,...,0.42693,0.437081,0.504963,0.456325,0.232812,0.232317,0.295431,0.235667,0.0,0.0
416,3.0,Normal,1.0,Water,,0.6,7.4,3.0,200.0,20.0,...,0.456187,0.435818,0.394202,0.428736,0.344905,0.316402,0.321709,0.307814,0.0,0.0


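To verify the stratification, we compute, for each primary type (`type_1`), the fraction of its rows that ended up in train vs. test. Every type should be close to 0.80 : 0.20. Very small types such as Flying (8 rows) deviate slightly because their counts cannot be divided exactly 80/20.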

In [6]:
# For each primary type, report the fraction of its rows that landed in train
# vs. test. With a 20% test split, each line should read close to 0.80 : 0.20.
# Types are iterated in sorted order so the report is identical on every run.
# Iterating a set of strings follows hash order, which Python randomizes per
# process, so the order could change between kernel restarts.
train_type_counts = train["type_1"].value_counts()
test_type_counts = test["type_1"].value_counts()
for type_name in sorted(pokedex["type_1"].dropna().unique()):
    train_count = train_type_counts.get(type_name, 0)
    test_count = test_type_counts.get(type_name, 0)
    total = train_count + test_count
    print(f"{type_name:<10} {train_count/total:.2f} : {test_count/total:.2f}  (total: {total})")

Psychic    0.80 : 0.20  (total: 81)
Fighting   0.79 : 0.21  (total: 42)
Grass      0.80 : 0.20  (total: 91)
Ground     0.80 : 0.20  (total: 41)
Ice        0.81 : 0.19  (total: 37)
Fire       0.80 : 0.20  (total: 65)
Poison     0.80 : 0.20  (total: 41)
Electric   0.79 : 0.21  (total: 62)
Rock       0.80 : 0.20  (total: 60)
Bug        0.80 : 0.20  (total: 81)
Dark       0.80 : 0.20  (total: 46)
Steel      0.81 : 0.19  (total: 36)
Flying     0.75 : 0.25  (total: 8)
Normal     0.80 : 0.20  (total: 115)
Ghost      0.81 : 0.19  (total: 42)
Dragon     0.80 : 0.20  (total: 41)
Water      0.80 : 0.20  (total: 134)
Fairy      0.77 : 0.23  (total: 22)


## Save Results

In [7]:
# Write each split next to the input data as "<OUT_PATH>.<split>.csv".
# index=False keeps the pandas row index out of the written files.
OUT_PATH = "../data/3-split_data"
TRAIN_PATH = OUT_PATH + ".train.csv"
TEST_PATH = OUT_PATH + ".test.csv"
for split_frame, split_path in ((train, TRAIN_PATH), (test, TEST_PATH)):
    split_frame.to_csv(split_path, index=False)