In [211]:
import os
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import torchmetrics

import time
import torch.cuda
from flash.image import ImageClassificationData, ImageClassifier
import flash
from pytorch_lightning.loggers import CSVLogger
from timm.loss import LabelSmoothingCrossEntropy
import seaborn as sns
import matplotlib.pyplot as plt
import imageio as iio

In [212]:
DATA_PATH = '../../Pokemon-data/'
SEED = 42

# define training hyperparameters
INIT_LR = 1e-3
BATCH_SIZE = 64
EPOCHS = 10
# define the train and val splits
TRAIN_SPLIT = 0.75
VAL_SPLIT = 1 - TRAIN_SPLIT

In [213]:
metadata = pd.read_csv('metadata/pokemon.csv')

In [214]:
print(metadata.columns)

Index(['abilities', 'against_bug', 'against_dark', 'against_dragon',
       'against_electric', 'against_fairy', 'against_fight', 'against_fire',
       'against_flying', 'against_ghost', 'against_grass', 'against_ground',
       'against_ice', 'against_normal', 'against_poison', 'against_psychic',
       'against_rock', 'against_steel', 'against_water', 'attack',
       'base_egg_steps', 'base_happiness', 'base_total', 'capture_rate',
       'classfication', 'defense', 'experience_growth', 'height_m', 'hp',
       'japanese_name', 'name', 'percentage_male', 'pokedex_number',
       'sp_attack', 'sp_defense', 'speed', 'type1', 'type2', 'weight_kg',
       'generation', 'is_legendary'],
      dtype='object')


In [215]:
pokemon_names = [x for x in os.listdir(DATA_PATH)]

# remove the IDE metafile that was included in the os.listdir
#pokemon_names = pokemon_names[1:]

In [216]:
# get slice from metadata file based on the pokemon used in training
filtered_metadata = metadata[metadata.name.isin(pokemon_names)]

In [217]:
# convert dataframe to list for easier comparison
filtered_list = filtered_metadata.loc[:,'name'].tolist()

In [218]:
# check for differences in training data and the metadata file
print([x for x in pokemon_names if x not in filtered_list])

[]


In [219]:
filtered_metadata = filtered_metadata.loc[:,['name', 'type1']]

In [220]:
print(filtered_metadata.head)

<bound method NDFrame.head of            name    type1
0     Bulbasaur    grass
1       Ivysaur    grass
2      Venusaur    grass
3    Charmander     fire
4    Charmeleon     fire
..          ...      ...
146     Dratini   dragon
147   Dragonair   dragon
148   Dragonite   dragon
149      Mewtwo  psychic
150         Mew  psychic

[149 rows x 2 columns]>


In [221]:
def compile_training_data_to_list():
    all_data = []
    for pokemon in os.listdir(DATA_PATH):
        all_data += [pokemon + '/' + x for x in os.listdir(DATA_PATH + pokemon)]

    results = create_annotated_dataframe(all_data)
    return results

def create_annotated_dataframe(all_data):
    base_data = {'file_name': [], 'name': [], 'label': []}
    for item in all_data:
        base_data['file_name'].append(item)
        base_data['name'].append(item.split('/')[0])
        # yes, this is a bit ugly, but we have to match with the metadata
        base_data['label'].append(
            filtered_metadata[
                filtered_metadata['name']==(item.split('/')[0])
            ].loc[:,'type1'].tolist()[0])

    results = create_encoded_dataframe(base_data)
    return results

def create_encoded_dataframe(base_data):
    results = pd.DataFrame(base_data, columns = ['file_name', 'name', 'label'])
    le = LabelEncoder()
    labels = le.fit_transform(results['label'])
    results['label'] = np.int64(labels)
    return results

In [222]:
encoded_data = compile_training_data_to_list()

In [223]:
print(encoded_data)

                                       file_name   name  label
0      Abra/0282b2f3a22745f1a436054ea15a0ae5.jpg   Abra     12
1      Abra/06b9eec4827d4d49b1b4c284308708df.jpg   Abra     12
2      Abra/10a9f06ec6524c66b779ea80354f8519.jpg   Abra     12
3      Abra/1788abb8b51f48509cfac8067bd99e14.jpg   Abra     12
4      Abra/28cfad92ad934d1f9b579cbff4b5d012.jpg   Abra     12
...                                          ...    ...    ...
6791  Zubat/dd387067380e4d1f8672c30d4b567fac.jpg  Zubat     11
6792  Zubat/e1997a18e61641a4b0e701f6bc4c70f4.jpg  Zubat     11
6793  Zubat/e6cba9a117d64d849fcc389e04e92e11.jpg  Zubat     11
6794  Zubat/f8788465c10a4ab8bb0aeb992ec060ce.jpg  Zubat     11
6795  Zubat/fccfe4de71a543349378b09d91d3f745.jpg  Zubat     11

[6796 rows x 3 columns]


In [224]:
class CustomDataset(Dataset):
	def __init__(self, x, y, img_dir):
		self.x = x
		self.y = y
		self.img_dir = img_dir
		self.classes = np.unique(self.y)


	def __len__(self):
		return len(self.x)


	def __getitem__(self, idx):
		img_path = os.path.join(self.img_dir, self.x[idx])
        # use the slice to remove a possible 4th alpha channel
		image = iio.v2.imread(img_path)[:,:,:3]
		label = self.y[idx]
		return image, label

In [225]:
def stratified_split(dataset):
    x_train, x_val, y_train, y_val = train_test_split(dataset['file_name'].to_numpy(),
                                                  dataset['label'].to_numpy(),
                                                  test_size=0.25,
                                                  stratify=dataset['label'],
                                                  random_state=SEED)

    train = CustomDataset(x_train, y_train, DATA_PATH)
    val = CustomDataset(x_val, y_val, DATA_PATH)
    return train, val

In [226]:
dataset = compile_training_data_to_list()
train, val = stratified_split(dataset)

In [227]:
datamodule = ImageClassificationData.from_datasets(train_dataset=train,
                                                   val_dataset=val,
                                                   batch_size=BATCH_SIZE,
                                                   )
performance_metrics = [torchmetrics.Accuracy(),
                      torchmetrics.F1Score(num_classes=len(train.classes), average='macro')]

In [228]:
model = ImageClassifier(backbone='efficientnet_b0',
                        labels=train.classes,
                        metrics=performance_metrics,
                        loss_fn=LabelSmoothingCrossEntropy(0.02),
                        optimizer="AdamW",
                        learning_rate=INIT_LR, )

logger = CSVLogger(save_dir='logs/')

trainer = flash.Trainer(max_epochs=EPOCHS,
                        gpus=torch.cuda.device_count(),
                        logger=logger)

Using 'efficientnet_b0' provided by rwightman/pytorch-image-models (https://github.com/rwightman/pytorch-image-models).
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [229]:
startTime = time.time()
trainer.finetune(model,
                 datamodule=datamodule,
                 strategy='no_freeze')

endTime = time.time()
print(f"[INFO] total time taken to train the model: {(endTime - startTime) / 60 :.2f}min")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type           | Params
-------------------------------------------------
0 | loss_fn       | ModuleDict     | 0     
1 | train_metrics | ModuleDict     | 0     
2 | val_metrics   | ModuleDict     | 0     
3 | test_metrics  | ModuleDict     | 0     
4 | adapter       | DefaultAdapter | 4.0 M 
-------------------------------------------------
4.0 M     Trainable params
0         Non-trainable params
4.0 M     Total params
16.107    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

ValueError: Could not find a backend to open `../../Pokemon-data/Cloyster/ff270ebfab0f46b3b05c3fecd6a15ef9.svg`` with iomode `ri`.

In [None]:
trainer.save_checkpoint("saved-models/B0-Un.pt")

In [None]:
metrics = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
display(metrics)
del metrics["epoch"]
metrics.set_index("step", inplace=True)

In [None]:
plt.figure(figsize=(20,5))
sns.relplot(data=metrics, kind="line", height=20)
plt.grid()