# Split Data

This notebook performs an 80/20 split of the data into train and test sets, stratified on the `type_1` column.

## Setup

In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
# Input: the feature table written by the previous pipeline step.
# NOTE(review): the frames displayed below carry an `Unnamed: 0` column, which
# suggests the upstream CSV was saved with its index - confirm whether that
# column is meant to travel into the train/test files.
IN_PATH = "data/2-generate_features.csv"
pokedex = pd.read_csv(filepath_or_buffer=IN_PATH)

## Split

In [3]:
# Hold out 20% of the rows as the test set. The split is stratified on the
# primary type label, and the fixed random_state makes the shuffle repeatable.
primary_type = pokedex["type_1"]
train, test = train_test_split(
    pokedex,
    stratify=primary_type,
    test_size=0.2,
    random_state=441,
    shuffle=True,
)

In [4]:
# Report how many rows landed in the training set, then preview the first five.
print(train.shape[0])
train.head(n=5)

836


Unnamed: 0,generation,status,type_number,type_1,type_2,height_m,weight_kg,abilities_number,total_points,hp,...,damage_from_ground,damage_from_flying,damage_from_psychic,damage_from_bug,damage_from_rock,damage_from_ghost,damage_from_dragon,damage_from_dark,damage_from_steel,damage_from_fairy
531,4.0,Normal,2.0,Dragon,Ground,1.9,95.0,2.0,600.0,108.0,...,1.0,1.0,1.0,1.0,0.5,1.0,2.0,1.0,1.0,2.0
422,3.0,Normal,1.0,Normal,,1.0,22.0,2.0,440.0,60.0,...,1.0,1.0,1.0,1.0,1.0,0.0,1.0,1.0,1.0,1.0
659,5.0,Normal,1.0,Grass,,1.0,28.0,3.0,461.0,75.0,...,0.5,2.0,1.0,2.0,1.0,1.0,1.0,1.0,1.0,1.0
1028,8.0,Legendary,2.0,Fairy,Steel,2.8,355.0,1.0,720.0,92.0,...,2.0,0.5,0.5,0.25,0.5,1.0,0.0,0.5,1.0,0.5
948,8.0,Normal,1.0,Grass,,2.1,90.0,2.0,530.0,100.0,...,0.5,2.0,1.0,2.0,1.0,1.0,1.0,1.0,1.0,1.0


In [5]:
# Report how many rows landed in the test set, then preview the first five.
print(test.shape[0])
test.head(n=5)

209


Unnamed: 0,generation,status,type_number,type_1,type_2,height_m,weight_kg,abilities_number,total_points,hp,...,damage_from_ground,damage_from_flying,damage_from_psychic,damage_from_bug,damage_from_rock,damage_from_ghost,damage_from_dragon,damage_from_dark,damage_from_steel,damage_from_fairy
624,5.0,Normal,2.0,Psychic,Flying,0.4,2.1,3.0,323.0,65.0,...,0.0,1.0,0.5,1.0,2.0,2.0,1.0,2.0,1.0,1.0
336,3.0,Normal,2.0,Water,Flying,0.6,9.5,3.0,270.0,40.0,...,0.0,1.0,1.0,0.5,2.0,1.0,1.0,1.0,0.5,1.0
272,2.0,Normal,2.0,Water,Rock,0.6,5.0,3.0,410.0,65.0,...,2.0,0.5,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
544,4.0,Normal,1.0,Water,,0.4,7.0,3.0,330.0,49.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.5,1.0
416,3.0,Normal,1.0,Water,,0.6,7.4,3.0,200.0,20.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.5,1.0


To verify stratification, we compare the train : test proportions for each `type_1` class:

In [6]:
# Print, for every primary type, the share of its rows that ended up in the
# train and test sets, plus the total row count for that type.
#
# The labels are sorted so the report comes out in the same order on every
# run; iterating a bare set of strings does not guarantee a stable order
# between kernel sessions.
types = sorted(pokedex["type_1"].dropna().unique())

# Count each label once per split instead of rescanning the frames per type.
train_counts = train["type_1"].value_counts()
test_counts = test["type_1"].value_counts()

for t in types:
    # A label missing from one split simply contributes zero rows there.
    train_count = train_counts.get(t, 0)
    test_count = test_counts.get(t, 0)
    total = train_count + test_count
    print(f"{t:<10} {train_count/total:.2f} : {test_count/total:.2f}  (total: {total})")

Steel      0.81 : 0.19  (total: 36)
Ghost      0.81 : 0.19  (total: 42)
Fighting   0.79 : 0.21  (total: 42)
Psychic    0.80 : 0.20  (total: 81)
Grass      0.80 : 0.20  (total: 91)
Poison     0.80 : 0.20  (total: 41)
Dragon     0.80 : 0.20  (total: 41)
Fire       0.80 : 0.20  (total: 65)
Ground     0.80 : 0.20  (total: 41)
Dark       0.80 : 0.20  (total: 46)
Electric   0.79 : 0.21  (total: 62)
Ice        0.81 : 0.19  (total: 37)
Rock       0.80 : 0.20  (total: 60)
Bug        0.80 : 0.20  (total: 81)
Flying     0.75 : 0.25  (total: 8)
Water      0.80 : 0.20  (total: 134)
Normal     0.80 : 0.20  (total: 115)
Fairy      0.77 : 0.23  (total: 22)


## Save Results

In [7]:
# Both output files share a common stem and differ only in their suffix.
OUT_PATH = "data/3-split_data"
TRAIN_PATH = OUT_PATH + ".train.csv"
TEST_PATH = OUT_PATH + ".test.csv"

# Write the train split first, then the test split; the DataFrame index is
# not written to either file.
for split_frame, split_path in ((train, TRAIN_PATH), (test, TEST_PATH)):
    split_frame.to_csv(split_path, index=False)