In [1]:
import glob
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold

from splitter import get_skfold_data

### Files contain the (processed) image data and filename contains label information

In [2]:
image_paths = glob.glob("../data/stage1_imgs/*.jpg")

In [3]:
print(len(image_paths))

print(image_paths[0])

# move list into pandas so we can manipulate it easier
data = pd.DataFrame(image_paths, columns=["file_path"])

data.head()

206616
../data/stage1_imgs/5766_86.jpg


Unnamed: 0,file_path
0,../data/stage1_imgs/5766_86.jpg
1,../data/stage1_imgs/185146_76.jpg
2,../data/stage1_imgs/107259_53.jpg
3,../data/stage1_imgs/28210_68.jpg
4,../data/stage1_imgs/45715_105.jpg


#### Figure out the labels from the file names and create a new column in the dataframe

In [4]:
# make a function so the dataframe apply isn't super messy
def path_to_label(path):
    '''
    Take in a path like: 
    ../data/stage1_imgs/5766_86.jpg
    and extract the label (86)
    
    Returns:
    int label: 86
    '''
    temp = path.split("/")[-1].split(".")[0].split("_")[-1]
    temp = int(temp)
    return temp

In [5]:
data.loc[:, "y"] = data.loc[:, "file_path"].apply(path_to_label)
data.head()

Unnamed: 0,file_path,y
0,../data/stage1_imgs/5766_86.jpg,86
1,../data/stage1_imgs/185146_76.jpg,76
2,../data/stage1_imgs/107259_53.jpg,53
3,../data/stage1_imgs/28210_68.jpg,68
4,../data/stage1_imgs/45715_105.jpg,105


### Convert labels to one-hots

In [6]:
y_hots = pd.get_dummies(data.loc[:, "y"])
y_hots[:5]

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,...,119,120,121,122,123,124,125,126,127,128
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


#### Break into 10 splits even though we'll probably only end up using a few, depending on the training time

In [7]:
skf = StratifiedKFold(n_splits=10, shuffle=True)
skf.get_n_splits(data.loc[:, "file_path"], data.loc[:, "y"])

10

In [8]:
idx = 0
for train_index, test_index in skf.split(data.loc[:, "file_path"], data.loc[:, "y"]):
    idx += 1
    
    if idx > 3:
        break
        
    print("train", len(train_index))
    print("test", len(test_index))

train 185903
test 20713
train 185914
test 20702
train 185927
test 20689


In [9]:
test_y = data.loc[test_index, "y"]

In [10]:
label_counts = np.zeros(128, dtype=int)

for item in sorted(test_y.unique()):
    temp = test_y[test_y == item]
    label_counts[item-1] = len(temp)

In [11]:
for label, idx in zip(label_counts, np.arange(1, 129)):
    if idx == 83 or idx == 20 or idx == 2:
        print(idx, "count", label)

2 count 148
20 count 395
83 count 64


#### class 83 (smallest class) should have ~64 items (0.1 fold * 320 files * 2 augmentation) and it has exactly 64 for this fold
#### class 20 (largest class) should have ~394 (0.1 fold * 3940 files * 1 aug)
#### class 2 (mid sized class) should have ~148 (0.1 fold * 1480 files * 1 aug)

### See if the splitter library is working

In [12]:
data_link_dict = get_skfold_data()

In [13]:
list(data_link_dict.keys())

['X_train_1',
 'y_train_1',
 'X_test_1',
 'y_test_1',
 'X_train_2',
 'y_train_2',
 'X_test_2',
 'y_test_2',
 'X_train_3',
 'y_train_3',
 'X_test_3',
 'y_test_3',
 'X_train_4',
 'y_train_4',
 'X_test_4',
 'y_test_4',
 'X_train_5',
 'y_train_5',
 'X_test_5',
 'y_test_5',
 'X_train_6',
 'y_train_6',
 'X_test_6',
 'y_test_6',
 'X_train_7',
 'y_train_7',
 'X_test_7',
 'y_test_7',
 'X_train_8',
 'y_train_8',
 'X_test_8',
 'y_test_8',
 'X_train_9',
 'y_train_9',
 'X_test_9',
 'y_test_9',
 'X_train_10',
 'y_train_10',
 'X_test_10',
 'y_test_10']

In [14]:
data_link_dict["y_train_1"]

array([ 86,  76,  53, ...,   2,  78, 121])

In [26]:
label_counts2 = np.zeros(128, dtype=int)

for item in set(data_link_dict["y_test_1"]):
    temp = data_link_dict["y_test_1"][data_link_dict["y_test_1"] == item]
    label_counts2[item-1] = len(temp)

In [27]:
for label, idx in zip(label_counts2, np.arange(1, 129)):
    if idx == 83 or idx == 20 or idx == 2:
        print(idx, "count", label)

2 count 148
20 count 395
83 count 64
