In [None]:
import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from keras import layers, models, Input
from keras.applications import *
from keras.callbacks import EarlyStopping
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import Callback
from keras.optimizers import RMSprop
# from keras.wrappers.scikit_learn import KerasClassifier
from scikeras.wrappers import KerasClassifier

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import RandomizedSearchCV

%matplotlib inline

In [None]:
DATA_DIR_NAME = 'data'

In [None]:
class_names = os.listdir(DATA_DIR_NAME)
class_names.sort()
num_classes = len(class_names)
class_names[:10]

In [None]:
image_paths = []
labels = []
for class_name in class_names:
    pokemon_dir = os.path.join(DATA_DIR_NAME, class_name)
    image_file_names = os.listdir(pokemon_dir)
    image_paths.extend(os.path.join(pokemon_dir, name) for name in image_file_names)
    labels.extend([class_name] * len(image_file_names))

df = pd.DataFrame({'filename': image_paths, 'class': labels})

In [None]:
df.value_counts('class')

In [None]:
# Stratified 60/20/20 split so every Pokemon appears in each subset:
# 20% of all rows go to test, then 25% of the remaining 80% (= 20% overall)
# go to validation. Both splits use the same fixed seed.
train_df, test_df = train_test_split(
    df, test_size=0.2, random_state=42, shuffle=True, stratify=df['class'],
)
train_df, validation_df = train_test_split(
    train_df, test_size=0.25, random_state=42, shuffle=True, stratify=train_df['class'],
)

### Utilities

In [None]:
def make_image_generator(dataframe, image_size, preprocessing_function, batch_size=64,
                         *, shuffle=True, seed=None, **augmentations):
    """Build a batch iterator over the images listed in ``dataframe``.

    Args:
        dataframe: DataFrame with a ``filename`` column (image path) and a
            ``class`` column (label); these are the default ``x_col``/``y_col``
            names used by ``flow_from_dataframe``.
        image_size: ``(height, width)`` every image is resized to.
        preprocessing_function: per-image function handed to
            ``ImageDataGenerator`` (e.g. an application's ``preprocess_input``).
        batch_size: number of images per batch.
        shuffle: shuffle sample order; pass ``False`` for evaluation so that
            predictions line up with ``iterator.classes``. Defaults to ``True``
            (the previous, implicit behaviour).
        seed: optional seed for shuffling/augmentation, for reproducible runs.
        **augmentations: extra ``ImageDataGenerator`` keyword arguments
            (e.g. ``rotation_range``, ``horizontal_flip``).

    Returns:
        The iterator returned by ``flow_from_dataframe`` with
        ``class_mode='categorical'`` (one-hot labels).
    """
    datagen = ImageDataGenerator(
        preprocessing_function=preprocessing_function,
        **augmentations
    )
    return datagen.flow_from_dataframe(
        dataframe,
        target_size=image_size,
        class_mode='categorical',
        batch_size=batch_size,
        shuffle=shuffle,
        seed=seed,
    )

In [None]:
def make_training_generators(image_size, preprocessing_function, batch_size=64, **augmentations):
    """Return ``(train_iterator, validation_iterator)`` over ``train_df`` / ``validation_df``.

    Both iterators share the image size, preprocessing and batch size; the
    ``augmentations`` keyword arguments are applied to the training split only.
    """
    shared_args = (image_size, preprocessing_function, batch_size)
    return (
        make_image_generator(train_df, *shared_args, **augmentations),
        make_image_generator(validation_df, *shared_args),
    )

In [None]:
# Stop when validation loss has not improved for 5 consecutive epochs and
# restore the weights from the best epoch. This single callback object is
# passed to every fit() call in the notebook.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    min_delta=0,
    restore_best_weights=True,
    verbose=1,
)

### Baseline Model

In [None]:
train_generator, validation_generator = make_training_generators(
    image_size=(224, 224),
    preprocessing_function=vgg16.preprocess_input
)

In [None]:
vgg16_base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
vgg16_base_model.trainable = False

vgg16_model = models.Sequential([
  vgg16_base_model,
  layers.Flatten(input_shape=vgg16_base_model.output_shape[1:]),
  layers.Dense(4096, activation='relu'),
  layers.Dense(num_classes, activation='softmax')
])

vgg16_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

In [None]:
vgg16_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=30,
    callbacks=[early_stopping]
)

In [None]:
vgg16_model.save('models/vgg16_model.keras')

### Models

In [None]:
train_generator, validation_generator = make_training_generators(
    image_size=(224, 224),
    preprocessing_function=resnet.preprocess_input
)

resnet_base_model = ResNet152(weights='imagenet', include_top=False)
resnet_base_model.trainable = False

resnet_model = models.Sequential([
  resnet_base_model,
  layers.GlobalAveragePooling2D(),
  layers.Dense(num_classes, activation='softmax')
])

resnet_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

In [None]:
resnet_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=10,
    callbacks=[early_stopping]
)

In [None]:
train_generator, validation_generator = make_training_generators(
    image_size=(299, 299),
    preprocessing_function=inception_v3.preprocess_input
)

inception_v3_base_model = InceptionV3(weights='imagenet', include_top=False)
inception_v3_base_model.trainable = False

inception_v3_model = models.Sequential([
  inception_v3_base_model,
  layers.GlobalAveragePooling2D(),
  layers.Dense(num_classes, activation='softmax')
])

inception_v3_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

In [None]:
inception_v3_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=10,
    callbacks=[early_stopping]
)

In [None]:
train_generator, validation_generator = make_training_generators(
    image_size=(299, 299),
    preprocessing_function=inception_resnet_v2.preprocess_input
)

inception_resnet_base_model = InceptionResNetV2(weights='imagenet', include_top=False)
inception_resnet_base_model.trainable = False

inception_resnet_model = models.Sequential([
  inception_resnet_base_model,
  layers.GlobalAveragePooling2D(),
  layers.Dense(num_classes, activation='softmax')
])

inception_resnet_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

In [None]:
inception_resnet_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=10,
    callbacks=[early_stopping]
)

In [None]:
train_generator, validation_generator = make_training_generators(
    image_size=(224, 224),
    preprocessing_function=densenet.preprocess_input
)

densenet_base_model = DenseNet201(weights='imagenet', include_top=False)
densenet_base_model.trainable = False

densenet_model = models.Sequential([
  densenet_base_model,
  layers.GlobalAveragePooling2D(),
  layers.Dense(num_classes, activation='softmax')
])

densenet_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

In [None]:
densenet_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=10,
    callbacks=[early_stopping]
)

In [None]:
train_generator, validation_generator = make_training_generators(
    image_size=(224, 224),
    preprocessing_function=efficientnet.preprocess_input
)

efficientnet_base_model = EfficientNetB7(weights='imagenet', include_top=False)
efficientnet_base_model.trainable = False

efficientnet_model = models.Sequential([
  efficientnet_base_model,
  layers.GlobalAveragePooling2D(),
  layers.Dense(num_classes, activation='softmax')
])

efficientnet_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

In [None]:
efficientnet_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=10,
    callbacks=[early_stopping]
)

### Data Augmentation

In [None]:
# Without data augmentations
train_generator, validation_generator = make_training_generators(
    image_size=(224, 224),
    preprocessing_function=densenet.preprocess_input
)

densenet_base_model = DenseNet201(weights='imagenet', include_top=False)
densenet_base_model.trainable = False

densenet_model = models.Sequential([
  densenet_base_model,
  layers.GlobalAveragePooling2D(),
  layers.Dense(num_classes, activation='softmax')
])

densenet_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

In [None]:
hist1 = densenet_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=20,
    callbacks=[early_stopping]
)

In [None]:
# With data augmentations
train_generator, validation_generator = make_training_generators(
    image_size=(224, 224),
    preprocessing_function=densenet.preprocess_input,
    rotation_range=20,
    zoom_range=0.2,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest'
)

densenet_base_model = DenseNet201(weights='imagenet', include_top=False)
densenet_base_model.trainable = False

densenet_model = models.Sequential([
  densenet_base_model,
  layers.GlobalAveragePooling2D(),
  layers.Dense(num_classes, activation='softmax')
])

densenet_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

In [None]:
hist2 = densenet_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=20,
    callbacks=[early_stopping]
)

In [None]:
# Overlay training/validation loss of the un-augmented (hist1) and augmented
# (hist2) DenseNet201 runs.
loss_curves = [
    (hist1.history['loss'], 'Training Loss (Without Data Augmentation)', 'blue'),
    (hist1.history['val_loss'], 'Validation Loss (Without Data Augmentation)', 'cyan'),
    (hist2.history['loss'], 'Training Loss (With Data Augmentation)', 'red'),
    (hist2.history['val_loss'], 'Validation Loss (With Data Augmentation)', 'pink'),
]

plt.figure()
for values, label, color in loss_curves:
    plt.plot(values, label=label, color=color)

plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

### Experimenting With Model Classification Head

In [None]:
# Setup for classification-head experiments: augmented DenseNet201 iterators
# and a fresh frozen base.
# TODO: no classification head is built or trained in this section yet.
train_generator, validation_generator = make_training_generators(
    preprocessing_function=densenet.preprocess_input,
    image_size=(224, 224),
    horizontal_flip=True,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.2,
    fill_mode='nearest',
)

densenet_base_model = DenseNet201(include_top=False, weights='imagenet')
densenet_base_model.trainable = False

### Fine Tuning

In [None]:
# Augmented DenseNet201 iterators used by the hyperparameter search and the
# fine-tuning cells below.
train_generator, validation_generator = make_training_generators(
    preprocessing_function=densenet.preprocess_input,
    image_size=(224, 224),
    horizontal_flip=True,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.2,
    fill_mode='nearest',
)

In [None]:
#### RANDOM SEARCH FOR HYPERPARAMETERS #####

def build_model(num_layers=0, dropout=0):
    """Build an uncompiled classifier on top of the global ``densenet_base_model``.

    Every call reuses the same ``densenet_base_model`` instance as the first layer.

    Args:
        num_layers: number of hidden ``Dense`` + ``Dropout`` blocks inserted between
            pooling and the output layer; block ``i`` (0-based) has
            ``32 * (i + 1)`` ReLU units.
        dropout: dropout rate applied after each hidden ``Dense`` layer.

    Returns:
        An uncompiled ``Sequential`` model ending in a ``num_classes``-way softmax.
    """
    model = models.Sequential()
    model.add(densenet_base_model)
    model.add(layers.GlobalAveragePooling2D())
    for i in range(num_layers):
        # Must add to the local `model`: these previously called `models.add`
        # (the keras.models module), raising AttributeError for num_layers > 0.
        model.add(layers.Dense(32 * (i + 1), activation='relu'))
        model.add(layers.Dropout(dropout))
    model.add(layers.Dense(num_classes, activation='softmax'))
    return model

# Taken from ChatGPT - temporary solution to collect history across fits.
class CollectHistory(Callback):
    """Keras callback that appends per-epoch metrics to ``self.history``.

    ``self.history`` maps each name in ``METRICS`` to a list with one entry per
    completed epoch (``None`` when a metric is absent from the epoch logs).
    Nothing is ever reset, so when one instance is passed to several ``fit``
    calls the epochs of all runs are appended back to back.
    """

    METRICS = ('loss', 'accuracy', 'val_loss', 'val_accuracy')

    def __init__(self):
        # Initialise the base Callback state (model/params) before Keras uses it.
        super().__init__()
        self.history = {name: [] for name in self.METRICS}

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metrics; ``logs`` may be ``None``."""
        logs = logs or {}
        for name in self.METRICS:
            self.history[name].append(logs.get(name))

In [None]:
#### RANDOM SEARCH FOR HYPERPARAMETERS (cont) #####

params = {
    'num_layers': [0, 1, 2],
    "dropout": [0, 0.1, 0.2]
}

# scikeras: `model=` replaces the deprecated `build_fn=`. The build-function
# arguments get scalar defaults; the search overrides them per candidate.
# Previously the whole candidate lists were passed as defaults.
# build_model returns an uncompiled model, so loss/optimizer/metrics are given
# here for scikeras to compile it.
model = KerasClassifier(
    model=build_model,
    num_layers=0,
    dropout=0,
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
    verbose=0,
)

# Custom callback to collect history during RandomizedSearchCV
collect_history = CollectHistory()

n_iter_search = 5   # number of parameter combinations to sample
# NOTE(review): RandomizedSearchCV performs its own cv=3 splitting of X/y, which
# presumably cannot be done on a Keras DirectoryIterator such as train_generator.
# X and y likely need to be in-memory arrays here - TODO confirm before relying
# on this cell.
random_search = RandomizedSearchCV(
    model, param_distributions=params, n_iter=n_iter_search, cv=3,
    random_state=42,  # reproducible candidate sampling, same seed as the data split
)
random_search.fit(train_generator, validation_data=validation_generator, callbacks=[collect_history])

print(collect_history.history)

os.makedirs('histories', exist_ok=True)
with open('histories/finetuning_history.pkl', 'wb') as file:
    pickle.dump(collect_history.history, file)

# scikeras stores the fitted Keras model in `model_`; `.model` is the build
# function passed above and has no `.save`.
best_model = random_search.best_estimator_.model_
os.makedirs('models', exist_ok=True)
best_model.save('models/model_finetuned.keras')

results = pd.DataFrame(random_search.cv_results_)
print(results)

In [None]:
# Frozen DenseNet201 base + dropout-regularised softmax head, built with the
# functional API so the base can be called with training=False.
densenet_base_model = DenseNet201(weights='imagenet', include_top=False)
densenet_base_model.trainable = False

inputs = Input(shape=(224, 224, 3))
x = densenet_base_model(inputs, training=False)  # Ensure batchnorm layers run in inference mode
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(num_classes, activation='softmax')(x)
# `Model` was never imported (NameError); use it via the imported `models` module.
model = models.Model(inputs, outputs)

model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

history_train = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=30,
    callbacks=[early_stopping]
)

In [None]:
# Persist the head-only model and its training history; create the output
# directories first because saving/opening fails if they are missing.
os.makedirs('models', exist_ok=True)
os.makedirs('histories', exist_ok=True)
model.save('models/model_untuned.keras')
with open('histories/model_training.pkl', 'wb') as file:
    pickle.dump(history_train.history, file)

In [None]:
temp = DenseNet201(weights=None, include_top=False)

In [None]:
for i, layer in enumerate(temp.layers):
    print(i, layer.name)

In [None]:
# Unfreeze layers in conv5 and fine-tune them together with the trained head.
# The base model itself must be made trainable first: while
# `densenet_base_model.trainable` is False, Keras reports no trainable weights
# for any of its sub-layers. Flipping only `layer.trainable` (as this cell used
# to) therefore left the whole base frozen.
# As before, only layers at index >= 481 whose name starts with 'conv5' are
# unfrozen; every other base layer stays frozen.
FIRST_CONV5_LAYER_INDEX = 481
densenet_base_model.trainable = True
for index, layer in enumerate(densenet_base_model.layers):
    layer.trainable = index >= FIRST_CONV5_LAYER_INDEX and layer.name.startswith('conv5')

# Recompile so the changed trainable flags are picked up. The very small
# learning rate keeps updates to the pretrained weights small.
model.compile(
    loss='categorical_crossentropy',
    optimizer=RMSprop(learning_rate=1e-5),
    metrics=['accuracy']
)

history_fine = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=10,
    callbacks=[early_stopping]
)

In [None]:
# Persist the fine-tuned model and its fine-tuning history; create the output
# directories first because saving/opening fails if they are missing.
os.makedirs('models', exist_ok=True)
os.makedirs('histories', exist_ok=True)
model.save('models/model_tuned.keras')
with open('histories/model_tuning.pkl', 'wb') as file:
    pickle.dump(history_fine.history, file)

#### Testing

In [None]:
# Load the model to evaluate before building the test iterator.
# NOTE(review): no cell in this notebook writes 'models/densenet201_model.keras'
# (the fine-tuned DenseNet201 is saved as 'models/model_tuned.keras'); confirm
# which checkpoint is meant, otherwise Restart & Run All fails here.
model = models.load_model(
    'models/densenet201_model.keras'
)

In [None]:
# Iterator over the held-out test split.
# - Preprocessing matches the DenseNet201 training iterators
#   (densenet.preprocess_input, not a bare 1/255 rescale). This assumes the
#   evaluated model is one of the DenseNet201 models trained above.
# - No `directory`: test_df filenames already start with DATA_DIR_NAME, so
#   joining them again produced non-existent paths like 'data/data/...'.
# - shuffle=False keeps predictions aligned with test_generator.classes.
test_datagen = ImageDataGenerator(
    preprocessing_function=densenet.preprocess_input,
)

test_generator = test_datagen.flow_from_dataframe(
    test_df,
    target_size=(224, 224),
    batch_size=128,
    class_mode='categorical',
    shuffle=False,
)

In [None]:
results = model.evaluate(test_generator)

In [None]:
predictions = model.predict(test_generator)
prediction_labels = np.argmax(predictions, axis = -1)

In [None]:
print(classification_report(test_generator.classes,prediction_labels,target_names = class_names))

In [None]:
confusion_matrix = confusion_matrix(test_generator.classes,prediction_labels)

In [None]:
most_common_misclassifications = []

for i in range(len(confusion_matrix)):
    for j in range(len(confusion_matrix[0])):
        if i != j and confusion_matrix[i][j] > 0:
            most_common_misclassifications.append((i, j, confusion_matrix[i][j]))

most_common_misclassifications.sort(key=lambda x: x[2], reverse=True)

print("Top Misclassifications:")
for misclassification in most_common_misclassifications:
    print(f"True class {class_names[misclassification[0]]} misclassified as {class_names[misclassification[1]]}: Count {misclassification[2]}")