In [31]:
import os

import pytorch_lightning.utilities.seed
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import torchmetrics

import time
import torch.cuda
from flash.image import ImageClassificationData, ImageClassifier
import flash
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from flash import Trainer
from pytorch_lightning.utilities.model_summary import ModelSummary

In [2]:
DATA_PATH = '../../Pokemon-data/'
NORMALIZED_DATA = '../../Training-baseline/'
SEED = 42

# define training hyperparameters
INIT_LR = 1e-3
BATCH_SIZE = 64
EPOCHS = 3
FINETUNE_STRATEGY = 'no_freeze'
VERSION = (str(EPOCHS) + 'Epoch-' + FINETUNE_STRATEGY)
# define the train and val splits
TRAIN_SPLIT = 0.75
VAL_SPLIT = 1 - TRAIN_SPLIT

In [3]:
pytorch_lightning.utilities.seed.seed_everything(SEED)

Global seed set to 42


42

In [4]:
metadata = pd.read_csv('metadata/pokemon.csv')

In [5]:
print(metadata.columns)

Index(['abilities', 'against_bug', 'against_dark', 'against_dragon',
       'against_electric', 'against_fairy', 'against_fight', 'against_fire',
       'against_flying', 'against_ghost', 'against_grass', 'against_ground',
       'against_ice', 'against_normal', 'against_poison', 'against_psychic',
       'against_rock', 'against_steel', 'against_water', 'attack',
       'base_egg_steps', 'base_happiness', 'base_total', 'capture_rate',
       'classfication', 'defense', 'experience_growth', 'height_m', 'hp',
       'japanese_name', 'name', 'percentage_male', 'pokedex_number',
       'sp_attack', 'sp_defense', 'speed', 'type1', 'type2', 'weight_kg',
       'generation', 'is_legendary'],
      dtype='object')


In [6]:
pokemon_names = [x for x in os.listdir(NORMALIZED_DATA)]

# remove the IDE metafile that was included in the os.listdir
#pokemon_names = pokemon_names[1:]

In [7]:
# get slice from metadata file based on the pokemon used in training
filtered_metadata = metadata[metadata.name.isin(pokemon_names)]

In [8]:
# convert dataframe to list for easier comparison
filtered_list = filtered_metadata.loc[:,'name'].tolist()

In [9]:
# check for differences in training data and the metadata file
print([x for x in pokemon_names if x not in filtered_list])

[]


In [10]:
filtered_metadata = filtered_metadata.loc[:,['name', 'type1']]

fire_data = filtered_metadata[filtered_metadata["type1"] == 'fire']
grass_data = filtered_metadata[filtered_metadata["type1"] == 'grass']
filtered_metadata = pd.concat([fire_data, grass_data], ignore_index=True)

In [11]:
print(filtered_metadata.head)

<bound method NDFrame.head of            name  type1
0    Charmander   fire
1    Charmeleon   fire
2     Charizard   fire
3        Vulpix   fire
4     Ninetales   fire
..          ...    ...
106     Chespin  grass
107   Quilladin  grass
108  Chesnaught  grass
109      Skiddo  grass
110      Gogoat  grass

[111 rows x 2 columns]>


In [12]:
print(len(filtered_metadata[filtered_metadata['name'].str.contains('Charmander')]))

1


In [13]:
def compile_training_data_to_list():
    all_data = []
    for pokemon in os.listdir(NORMALIZED_DATA):
        all_data += [pokemon + '/' + x for x in os.listdir(NORMALIZED_DATA + pokemon)]

    results = create_annotated_dataframe(all_data)
    return results

def create_annotated_dataframe(all_data):
    base_data = {'file_name': [], 'name': [], 'label': []}
    for item in all_data:
        if len(filtered_metadata[filtered_metadata['name'].str.contains(item.split('/')[0])]):
            base_data['file_name'].append(item)
            base_data['name'].append(item.split('/')[0])
            # yes, this is a bit ugly, but we have to match with the metadata
            base_data['label'].append(
                filtered_metadata[
                    filtered_metadata['name']==(item.split('/')[0])
                ].loc[:,'type1'].tolist()[0])

    results = create_encoded_dataframe(base_data)
    return results

def create_encoded_dataframe(base_data):
    results = pd.DataFrame(base_data, columns = ['file_name', 'name', 'label'])
    le = LabelEncoder()
    labels = le.fit_transform(results['label'])
    results['label'] = np.int64(labels)
    return results

In [14]:
encoded_data = compile_training_data_to_list()

In [15]:
print(encoded_data)

                            file_name        name  label
0      Abomasnow/dcedzyqfojskcahp.npy   Abomasnow      1
1      Abomasnow/gqfpsmqasdqiknur.npy   Abomasnow      1
2      Abomasnow/imzcvkkckbdchpro.npy   Abomasnow      1
3      Abomasnow/kzibfmivzksykiwy.npy   Abomasnow      1
4      Abomasnow/mjtasvyoonxyilqt.npy   Abomasnow      1
...                               ...         ...    ...
1860  Whimsicott/ldzsotkctiqtdnwa.npy  Whimsicott      1
1861  Whimsicott/tuemvnafwsjqjnhv.npy  Whimsicott      1
1862  Whimsicott/vtsqtsnodkpzuruf.npy  Whimsicott      1
1863  Whimsicott/xyvaxbleknfsaaes.npy  Whimsicott      1
1864  Whimsicott/ywwawzssckbmhbcx.npy  Whimsicott      1

[1865 rows x 3 columns]


In [16]:
class CustomDataset(Dataset):
    def __init__(self, x, y, img_dir):
        self.x = x
        self.y = y
        self.img_dir = img_dir
        self.classes = np.unique(self.y)


    def __len__(self):
        return len(self.x)


    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.x[idx])
        # use the slice to remove a possible 4th alpha channel
        image = np.load(img_path)[:, :, :3]
        image = image.astype(np.float32)
        label = self.y[idx]
        return image, label

In [17]:
def stratified_split(dataset):
    x_train, x_val, y_train, y_val = train_test_split(dataset['file_name'].to_numpy(),
                                                  dataset['label'].to_numpy(),
                                                  test_size=0.25,
                                                  stratify=dataset['label'],
                                                  random_state=SEED)

    train = CustomDataset(x_train, y_train, NORMALIZED_DATA)
    val = CustomDataset(x_val, y_val, NORMALIZED_DATA)
    return train, val

In [18]:
dataset = compile_training_data_to_list()
train, val = stratified_split(dataset)

In [19]:
datamodule = ImageClassificationData.from_datasets(train_dataset=train,
                                                   val_dataset=val,
                                                   batch_size=BATCH_SIZE,
                                                   )
performance_metrics = [torchmetrics.Accuracy(),
                      torchmetrics.F1Score(num_classes=len(train.classes), average='macro')]

In [20]:
model = ImageClassifier(backbone='efficientnet_b0',
                        pretrained=True,
                        labels=train.classes,
                        metrics=performance_metrics,
                        optimizer="AdamW",
                        learning_rate=INIT_LR, )

logger = CSVLogger(save_dir='logs/', version=VERSION, name='pretrained-experiments')

early_stop_callback = EarlyStopping(monitor="val_f1score", patience=3, verbose=False, mode="max")

trainer = flash.Trainer(max_epochs=EPOCHS,
                        gpus=torch.cuda.device_count(),
                        logger=logger,
                        callbacks=[early_stop_callback])

Using 'efficientnet_b0' provided by rwightman/pytorch-image-models (https://github.com/rwightman/pytorch-image-models).
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [21]:
startTime = time.time()
trainer.finetune(model,
                 datamodule=datamodule,
                 strategy=FINETUNE_STRATEGY)

endTime = time.time()
print(f"[INFO] total time taken to train the model: {(endTime - startTime) / 60 :.2f}min")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(

  | Name          | Type           | Params
-------------------------------------------------
0 | train_metrics | ModuleDict     | 0     
1 | val_metrics   | ModuleDict     | 0     
2 | test_metrics  | ModuleDict     | 0     
3 | adapter       | DefaultAdapter | 4.0 M 
-------------------------------------------------
4.0 M     Trainable params
0         Non-trainable params
4.0 M     Total params
16.040    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
Global seed set to 42
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

[INFO] total time taken to train the model: 1.03min


In [22]:
trainer.save_checkpoint("saved-models/" + VERSION + ".pt")

In [23]:
metrics = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
display(metrics)
del metrics["epoch"]
metrics.set_index("step", inplace=True)

Unnamed: 0,val_accuracy,val_f1score,val_cross_entropy,epoch,step,train_accuracy_epoch,train_f1score_epoch,train_cross_entropy_epoch,train_accuracy_step,train_f1score_step,train_cross_entropy_step
0,0.978587,0.978386,0.120292,0,20,,,,,,
1,,,,0,20,0.919643,0.918629,0.200266,,,
2,0.989293,0.989159,0.050754,1,41,,,,,,
3,,,,1,41,0.986607,0.986456,0.049086,,,
4,,,,2,49,,,,0.984375,0.983189,0.038061
5,0.978587,0.978307,0.060585,2,62,,,,,,
6,,,,2,62,0.979911,0.9797,0.081916,,,


In [24]:
le = LabelEncoder()
labels = le.fit_transform(filtered_metadata['type1'])
filtered_metadata['encoded_label'] = labels

type_list = filtered_metadata['type1'].unique()
enc_type_list = filtered_metadata['encoded_label'].unique()

type_lookup = pd.DataFrame(list(zip(type_list, enc_type_list)), columns=['Type_name', 'encoded_type'])

In [32]:
summary = ModelSummary(model, max_depth=-1)
print(summary)

    | Name                                       | Type                   | Params
----------------------------------------------------------------------------------------
0   | train_metrics                              | ModuleDict             | 0     
1   | train_metrics.accuracy                     | Accuracy               | 0     
2   | train_metrics.f1score                      | F1Score                | 0     
3   | val_metrics                                | ModuleDict             | 0     
4   | val_metrics.accuracy                       | Accuracy               | 0     
5   | val_metrics.f1score                        | F1Score                | 0     
6   | test_metrics                               | ModuleDict             | 0     
7   | test_metrics.accuracy                      | Accuracy               | 0     
8   | test_metrics.f1score                       | F1Score                | 0     
9   | adapter                                    | DefaultAdapter         | 4.0 M

In [29]:
model.eval()

pred_data = ImageClassificationData.from_files(
    predict_files=[
        '../../Training-baseline/Growlithe/2d7043870a8843f08b3267d9c70885c3.npy',
        '../../Training-baseline/Growlithe/0dd396a2ce43499cb2f1feec957a5e0f.npy',
        '../../Training-baseline/Growlithe/2bb7c233adbb4c789a7919a103cb110f.npy',
        '../../Training-baseline/Growlithe/7a9a39e6183747f6b6e810d4b6187142.npy',
        '../../Training-baseline/Growlithe/91a68fd497724a0db5e7f91dec706278.npy',
        '../../Training-baseline/Bellsprout/9ad85f5338424bd2ab7d9b202555d3ad.npy',
        '../../Training-baseline/Bellsprout/4f322b81ba9b49ddb22c94f5c97bb53f.npy',
        '../../Training-baseline/Bellsprout/50454dc01af44a6ea561ca6e9221fa5c.npy',
        '../../Training-baseline/Bellsprout/be83544c5ca94f11a28ec022503f515f.npy',
        '../../Training-baseline/Bellsprout/e65b4a8173644d4fba87c55ae2b20b10.npy',
    ],
    batch_size = 1
)


predictions = trainer.predict(model, datamodule=pred_data, output='labels')

print(predictions)

print(f"""
- Pokemon had the predicted type:
{type_lookup.loc[type_lookup['encoded_type'] == predictions[0][0]]}
---------------------------------------------------------------------
- The Pokemon in Question actually has the type:
{filtered_metadata[filtered_metadata['name'] == 'Growlithe']}
""")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Predicting: 21it [00:00, ?it/s]

[[0], [0], [0], [0], [0], [0], [0], [0], [0], [0]]

- Pokemon had the predicted type:
  Type_name  encoded_type
0      fire             0
---------------------------------------------------------------------
- The Pokemon in Question actually has the type:
        name type1  encoded_label
5  Growlithe  fire              0

