In [141]:
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import os

In [142]:
DATA_PATH = '../../Pokemon-data/'
NORMALIZED_DATA = '../../Training-baseline/'
AUGMENTATION_PATH = '../../pipe_one_aug/'
CSV_NAME = "training-list.csv"
SEED = 42
BATCH_SIZE = 64
EPOCHS = 20
MODEL_NAME = "aug-freeze-customLR"
checkpoint_path = '../saved-models/'
TRAINING_METRICS = '../training-metrics/'
n_epochs_stop = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [143]:
class CustomDataset(Dataset):
    def __init__(self, x, y, img_dir):
        self.x = x
        self.y = y
        self.img_dir = img_dir
        self.classes = np.unique(self.y)


    def __len__(self):
        return len(self.x)


    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.x[idx])
        # use the slice to remove a possible 4th alpha channel
        image = np.load(img_path)[:,:,:3]
        image = image.astype(np.float32)
        label = self.y[idx]
        return image, label

In [144]:
def stratified_split(dataset, labels):
    x_train, x_val, y_train, y_val = train_test_split(dataset['file_name'].to_numpy(),
                                                      labels,
                                                      test_size=0.25,
                                                      stratify=dataset['label'],
                                                      random_state=SEED)

    trainSet = CustomDataset(x_train, y_train, AUGMENTATION_PATH)
    valSet = CustomDataset(x_val, y_val, AUGMENTATION_PATH)
    return trainSet, valSet

In [145]:
# read data
csv_data = pd.read_csv(f"../metadata/{CSV_NAME}", index_col=0)

y_train = np.zeros((len(csv_data["y_train_columns"]),csv_data["y_train_columns"][0]))
for i in range(csv_data["y_train_columns"][0]):
    y_train[:,i] = csv_data["y_train"+str(i)]

encoded_data = csv_data.drop(csv_data.columns[(csv_data.shape[1]-csv_data["y_train_columns"][0]-1):csv_data.shape[1]], axis=1)

In [146]:
train, val = stratified_split(encoded_data, y_train)
print(len(train.x))
print(len(val.x))
print(val.x)
print(type(val.x))

9059
3020
['Girafarig/azvrgkiycubwviye.npy'
 'Arbok/3e4fd8cdf8c740548826b6de29f18258.npy'
 'Poliwhirl/d843aea788dd48f586ffdf8736dd3c4c.npy' ...
 'Drowzee/f3019e459027400182e15bf74e571c92.npy'
 'Haxorus/ifhrniyifefkmknn.npy'
 'Lickitung/52c26d87db7847789148ea4e3c64819c.npy']
<class 'numpy.ndarray'>


In [147]:
extended_csv = pd.read_csv(f"../metadata/aug-training-list.csv", index_col=0)
print(len(extended_csv))
print(extended_csv.head)

25073
<bound method NDFrame.head of                               file_name       name  label  y_train0  y_train1  \
0      Abomasnow/dcedzyqfojskcahp_0.npy  Abomasnow      9       0.0       0.0   
1      Abomasnow/gqfpsmqasdqiknur_0.npy  Abomasnow      9       0.0       0.0   
2      Abomasnow/imzcvkkckbdchpro_0.npy  Abomasnow      9       0.0       0.0   
3      Abomasnow/kzibfmivzksykiwy_0.npy  Abomasnow      9       0.0       0.0   
4      Abomasnow/mjtasvyoonxyilqt_0.npy  Abomasnow      9       0.0       0.0   
...                                 ...        ...    ...       ...       ...   
25068   Zweilous/nrpzbrzmxehydoqj_0.npy   Zweilous      1       0.0       1.0   
25069   Zweilous/qjoppeepmpyujyao_0.npy   Zweilous      1       0.0       1.0   
25070   Zweilous/sihxufnlbmephyeq_0.npy   Zweilous      1       0.0       1.0   
25071   Zweilous/vshewhewmkutsdlp_0.npy   Zweilous      1       0.0       1.0   
25072   Zweilous/ytpdigaymlnyrpbd_0.npy   Zweilous      1       0.0      

In [148]:
for idx, e in enumerate(val.x):
    name = e.split('.')[0] + '_0.' + e.split('.')[1]
    val.x[idx] = name

In [149]:
print(extended_csv[extended_csv.file_name == "Arbok/3e4fd8cdf8c740548826b6de29f18258_0.npy"])

                                        file_name   name  label  y_train0  \
419  Arbok/3e4fd8cdf8c740548826b6de29f18258_0.npy  Arbok     13       0.0   

     y_train1  y_train2  y_train3  y_train4  y_train5  y_train6  ...  \
419       0.0       0.0       0.0       0.0       0.0       0.0  ...   

     y_train9  y_train10  y_train11  y_train12  y_train13  y_train14  \
419       0.0        0.0        0.0        0.0        1.0        0.0   

     y_train15  y_train16  y_train17  y_train_columns  
419        0.0        0.0        0.0               18  

[1 rows x 22 columns]


In [150]:
extended_csv = extended_csv[~extended_csv.file_name.isin(val.x.tolist())]

y_train = np.zeros((len(extended_csv["y_train_columns"]),extended_csv["y_train_columns"][0]))
for i in range(extended_csv["y_train_columns"][0]):
    y_train[:,i] = extended_csv["y_train"+str(i)]

encoded_data = extended_csv.drop(extended_csv.columns[(extended_csv.shape[1]-extended_csv["y_train_columns"][0]-1):extended_csv.shape[1]], axis=1)

print(len(encoded_data))
print(len(y_train))

22055
22055


In [151]:
print(encoded_data)

                              file_name       name  label
0      Abomasnow/dcedzyqfojskcahp_0.npy  Abomasnow      9
1      Abomasnow/gqfpsmqasdqiknur_0.npy  Abomasnow      9
2      Abomasnow/imzcvkkckbdchpro_0.npy  Abomasnow      9
3      Abomasnow/kzibfmivzksykiwy_0.npy  Abomasnow      9
4      Abomasnow/mjtasvyoonxyilqt_0.npy  Abomasnow      9
...                                 ...        ...    ...
25065   Zweilous/itjsfinjnbuxsymt_0.npy   Zweilous      1
25066   Zweilous/izpqcbvdywrnuwwv_0.npy   Zweilous      1
25067   Zweilous/lcsoqwhymtkbtnow_0.npy   Zweilous      1
25069   Zweilous/qjoppeepmpyujyao_0.npy   Zweilous      1
25070   Zweilous/sihxufnlbmephyeq_0.npy   Zweilous      1

[22055 rows x 3 columns]
