In [255]:
import os

import pytorch_lightning.utilities.seed
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import torchmetrics

import time
import torch.cuda
from flash.image import ImageClassificationData, ImageClassifier
import flash
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from flash import Trainer

In [256]:
DATA_PATH = '../../Pokemon-data/'
NORMALIZED_DATA = '../../Training-baseline/'
SEED = 42

# define training hyperparameters
INIT_LR = 1e-3
BATCH_SIZE = 64
EPOCHS = 3
FINETUNE_STRATEGY = 'no_freeze'
VERSION = (str(EPOCHS) + 'Epoch-' + FINETUNE_STRATEGY)
# define the train and val splits
TRAIN_SPLIT = 0.75
VAL_SPLIT = 1 - TRAIN_SPLIT

In [257]:
pytorch_lightning.utilities.seed.seed_everything(SEED)

Global seed set to 42


42

In [258]:
metadata = pd.read_csv('metadata/pokemon.csv')

In [259]:
print(metadata.columns)

Index(['abilities', 'against_bug', 'against_dark', 'against_dragon',
       'against_electric', 'against_fairy', 'against_fight', 'against_fire',
       'against_flying', 'against_ghost', 'against_grass', 'against_ground',
       'against_ice', 'against_normal', 'against_poison', 'against_psychic',
       'against_rock', 'against_steel', 'against_water', 'attack',
       'base_egg_steps', 'base_happiness', 'base_total', 'capture_rate',
       'classfication', 'defense', 'experience_growth', 'height_m', 'hp',
       'japanese_name', 'name', 'percentage_male', 'pokedex_number',
       'sp_attack', 'sp_defense', 'speed', 'type1', 'type2', 'weight_kg',
       'generation', 'is_legendary'],
      dtype='object')


In [260]:
pokemon_names = [x for x in os.listdir(NORMALIZED_DATA)]

# remove the IDE metafile that was included in the os.listdir
#pokemon_names = pokemon_names[1:]

In [261]:
# get slice from metadata file based on the pokemon used in training
filtered_metadata = metadata[metadata.name.isin(pokemon_names)]

In [262]:
# convert dataframe to list for easier comparison
filtered_list = filtered_metadata.loc[:,'name'].tolist()

In [263]:
# check for differences in training data and the metadata file
print([x for x in pokemon_names if x not in filtered_list])

[]


In [264]:
filtered_metadata = filtered_metadata.loc[:,['name', 'type1']]

In [265]:
print(filtered_metadata.head)

<bound method NDFrame.head of            name   type1
0     Bulbasaur   grass
1       Ivysaur   grass
2      Venusaur   grass
3    Charmander    fire
4    Charmeleon    fire
..          ...     ...
714     Noivern  flying
715     Xerneas   fairy
716     Yveltal    dark
718     Diancie    rock
720   Volcanion    fire

[703 rows x 2 columns]>


In [266]:
def compile_training_data_to_list():
    all_data = []
    for pokemon in os.listdir(NORMALIZED_DATA):
        all_data += [pokemon + '/' + x for x in os.listdir(NORMALIZED_DATA + pokemon)]

    results = create_annotated_dataframe(all_data)
    return results

def create_annotated_dataframe(all_data):
    base_data = {'file_name': [], 'name': [], 'label': []}
    for item in all_data:
        base_data['file_name'].append(item)
        base_data['name'].append(item.split('/')[0])
        # yes, this is a bit ugly, but we have to match with the metadata
        base_data['label'].append(
            filtered_metadata[
                filtered_metadata['name']==(item.split('/')[0])
            ].loc[:,'type1'].tolist()[0])

    results = create_encoded_dataframe(base_data)
    return results

def create_encoded_dataframe(base_data):
    results = pd.DataFrame(base_data, columns = ['file_name', 'name', 'label'])
    le = LabelEncoder()
    labels = le.fit_transform(results['label'])
    results['label'] = np.int64(labels)
    return results

In [267]:
encoded_data = compile_training_data_to_list()

In [268]:
print(encoded_data)

                            file_name       name  label
0      Abomasnow/dcedzyqfojskcahp.npy  Abomasnow      9
1      Abomasnow/gqfpsmqasdqiknur.npy  Abomasnow      9
2      Abomasnow/imzcvkkckbdchpro.npy  Abomasnow      9
3      Abomasnow/kzibfmivzksykiwy.npy  Abomasnow      9
4      Abomasnow/mjtasvyoonxyilqt.npy  Abomasnow      9
...                               ...        ...    ...
12074   Zweilous/nrpzbrzmxehydoqj.npy   Zweilous      1
12075   Zweilous/qjoppeepmpyujyao.npy   Zweilous      1
12076   Zweilous/sihxufnlbmephyeq.npy   Zweilous      1
12077   Zweilous/vshewhewmkutsdlp.npy   Zweilous      1
12078   Zweilous/ytpdigaymlnyrpbd.npy   Zweilous      1

[12079 rows x 3 columns]


In [269]:
class CustomDataset(Dataset):
    def __init__(self, x, y, img_dir):
        self.x = x
        self.y = y
        self.img_dir = img_dir
        self.classes = np.unique(self.y)


    def __len__(self):
        return len(self.x)


    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.x[idx])
        # use the slice to remove a possible 4th alpha channel
        image = np.load(img_path)[:, :, :3]
        image = image.astype(np.float32)
        label = self.y[idx]
        return image, label

In [270]:
def stratified_split(dataset):
    x_train, x_val, y_train, y_val = train_test_split(dataset['file_name'].to_numpy(),
                                                  dataset['label'].to_numpy(),
                                                  test_size=0.25,
                                                  stratify=dataset['label'],
                                                  random_state=SEED)

    train = CustomDataset(x_train, y_train, NORMALIZED_DATA)
    val = CustomDataset(x_val, y_val, NORMALIZED_DATA)
    return train, val

In [271]:
dataset = compile_training_data_to_list()
train, val = stratified_split(dataset)

In [272]:
datamodule = ImageClassificationData.from_datasets(train_dataset=train,
                                                   val_dataset=val,
                                                   batch_size=BATCH_SIZE,
                                                   )
performance_metrics = [torchmetrics.Accuracy(),
                      torchmetrics.F1Score(num_classes=len(train.classes), average='macro')]

In [273]:
model = ImageClassifier(backbone='efficientnet_b0',
                        pretrained=False,
                        labels=train.classes,
                        metrics=performance_metrics,
                        optimizer="AdamW",
                        learning_rate=INIT_LR, )

logger = CSVLogger(save_dir='logs/', version=VERSION, name='not-pretrained')

early_stop_callback = EarlyStopping(monitor="val_f1score", patience=3, verbose=False, mode="max")

trainer = flash.Trainer(max_epochs=EPOCHS,
                        gpus=torch.cuda.device_count(),
                        logger=logger,
                        callbacks=[early_stop_callback])

Using 'efficientnet_b0' provided by rwightman/pytorch-image-models (https://github.com/rwightman/pytorch-image-models).
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
startTime = time.time()
trainer.finetune(model,
                 datamodule=datamodule,
                 strategy=FINETUNE_STRATEGY)

endTime = time.time()
print(f"[INFO] total time taken to train the model: {(endTime - startTime) / 60 :.2f}min")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(

  | Name          | Type           | Params
-------------------------------------------------
0 | train_metrics | ModuleDict     | 0     
1 | val_metrics   | ModuleDict     | 0     
2 | test_metrics  | ModuleDict     | 0     
3 | adapter       | DefaultAdapter | 4.0 M 
-------------------------------------------------
4.0 M     Trainable params
0         Non-trainable params
4.0 M     Total params
16.122    Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

In [None]:
trainer.save_checkpoint("saved-models/" + VERSION + ".pt")

In [None]:
metrics = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
display(metrics)
del metrics["epoch"]
metrics.set_index("step", inplace=True)

In [None]:
le = LabelEncoder()
labels = le.fit_transform(filtered_metadata['type1'])
filtered_metadata['encoded_label'] = labels

type_list = filtered_metadata['type1'].unique()
enc_type_list = filtered_metadata['encoded_label'].unique()

type_lookup = pd.DataFrame(list(zip(type_list, enc_type_list)), columns=['Type_name', 'encoded_type'])

In [None]:
model.eval()

pred_data = ImageClassificationData.from_files(
    predict_files=[
        '../../Training-baseline/Growlithe/2d7043870a8843f08b3267d9c70885c3.npy',
        '../../Training-baseline/Cloyster/6cb25d6053c4480e98c5b3073d811ec5.npy'
    ],
    batch_size = 1
)


predictions = trainer.predict(model, datamodule=pred_data, output='labels')

print(predictions)

print(f"""
- Pokemon had the predicted type:
{type_lookup.loc[type_lookup['encoded_type'] == predictions[0][0]]}
---------------------------------------------------------------------
- The Pokemon in Question actually has the type:
{filtered_metadata[filtered_metadata['name'] == 'Growlithe']}
""")