In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from keras import layers, models
from keras.applications import *
from keras.callbacks import EarlyStopping
from keras.preprocessing.image import ImageDataGenerator

%matplotlib inline

In [2]:
DATA_DIR_NAME = 'data'

In [3]:
class_names = os.listdir(DATA_DIR_NAME)
class_names.sort()
num_classes = len(class_names)
class_names[:10]

['Abra',
 'Aerodactyl',
 'Alakazam',
 'Arbok',
 'Arcanine',
 'Articuno',
 'Beedrill',
 'Bellsprout',
 'Blastoise',
 'Bulbasaur']

In [4]:
image_paths = []
labels = []
for class_name in class_names:
    pokemon_dir = os.path.join(DATA_DIR_NAME, class_name)
    image_file_names = os.listdir(pokemon_dir)
    image_paths.extend(os.path.join(pokemon_dir, name) for name in image_file_names)
    labels.extend([class_name] * len(image_file_names))

df = pd.DataFrame({'filename': image_paths, 'class': labels})

In [6]:
df.value_counts('class')

class
Pikachu      286
Charizard    167
Venusaur     162
Sandslash    142
Gengar       140
            ... 
Poliwrath     61
Nidoking      60
Dratini       57
Nidoran♂      50
Nidoran♀      44
Name: count, Length: 151, dtype: int64

In [5]:
train_df, test_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    shuffle=True,
    stratify=labels
)
train_df, validation_df = train_test_split(
    train_df,
    test_size=0.25,
    random_state=42,
    shuffle=True,
    stratify=labels
)

### Baseline Model

In [13]:
BATCH_SIZE = 64
vgg16_input_size = (224, 224)

datagen = ImageDataGenerator(preprocessing_function=vgg16.preprocess_input)

train_generator = datagen.flow_from_dataframe(
    train_df,
    target_size=vgg16_input_size,
    class_mode='categorical',
    batch_size=BATCH_SIZE
)

validation_generator = datagen.flow_from_dataframe(
    validation_df,
    target_size=vgg16_input_size,
    class_mode='categorical',
    batch_size=BATCH_SIZE
)

Found 11172 validated image filenames belonging to 151 classes.
Found 2793 validated image filenames belonging to 151 classes.


In [14]:
vgg16_base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
vgg16_base_model.trainable = False

vgg16_model = models.Sequential([
  vgg16_base_model,
  layers.Flatten(input_shape=vgg16_base_model.output_shape[1:]),
  layers.Dense(4096, activation='relu'),
  layers.Dense(num_classes, activation='softmax')
])

vgg16_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

In [None]:
early_stopping = EarlyStopping(
    monitor='val_loss',
    min_delta=0.001,
    patience=5,
    verbose=1,
    restore_best_weights=True
)

In [None]:
vgg16_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=30
)

In [None]:
VGG16_MODEL_PATH = 'models/vgg16_model.keras'
vgg16_model.save(VGG16_MODEL_PATH)

In [29]:
densenet201_base_model = DenseNet201(weights='imagenet', include_top=False)
densenet201_base_model.trainable = False

densenet201_model = models.Sequential([
  densenet201_base_model,
  layers.GlobalAveragePooling2D(),
  layers.Dense(num_classes, activation='softmax')
])

In [30]:
densenet201_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

In [31]:
densenet201_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=20
)

Epoch 1/20




Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.src.callbacks.History at 0x7816defaa7a0>

In [32]:
DENSENET201_MODEL_PATH = 'models/densenet201_model.keras'
densenet201_model.save(DENSENET201_MODEL_PATH)