In [None]:
import pickle
from collections import Counter
from datetime import datetime
from pathlib import Path

import joblib
import keras
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
import tensorflow as tf
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)

from utils import *

In [None]:
# Polars display options: show up to 100 rows and allow wide tables / long
# strings so the results table at the end of the notebook is not truncated.
cfg = pl.Config()
cfg.set_tbl_rows(100)
cfg.set_tbl_width_chars(200)
cfg.set_fmt_str_lengths(200)

pc = ProjectConfig()

# Directory holding the saved .h5 models and their training-time files.
BASE_TRAINING_MODEL_DIR = pc.project_root_dir.joinpath("ablation_study").joinpath(
    "model_saves_fc"
)

# Ablation variants of the fully-connected head; the order here defines the
# order of every *_PATHS list below and of the rows in the results table.
MODEL_NAMES = [
    "EfficientNetV2M_SINGLE_LAYER",
    "EfficientNetV2M_FOUR_LAYER",
    "EfficientNetV2M_THREE_LAYER",
    "EfficientNetV2M_TWO_LAYER_WIDE_RELU",
    "EfficientNetV2M_TWO_LAYER_WIDE_SWISH",
]

# Parallel lists: element k of each list belongs to MODEL_NAMES[k].
BASE_MODEL_PATHS, CLASS_LIST_PATHS, BASE_TRAINING_TIME_PATHS = [], [], []
# NOTE(review): never populated in this cell — kept in case later cells use it.
FINE_TUNED_TRAINING_TIME_PATHS = []

for model_name in MODEL_NAMES:
    model_file = BASE_TRAINING_MODEL_DIR.joinpath(f"{model_name}.h5")
    class_list_file = pc.project_root_dir.joinpath("ablation_study").joinpath(
        f"class_list_{model_name}.lzma"
    )
    training_time_file = BASE_TRAINING_MODEL_DIR.joinpath(
        f"{model_name}_TRAINING_TIME.lzma"
    )

    BASE_MODEL_PATHS.append(model_file)
    CLASS_LIST_PATHS.append(class_list_file)
    BASE_TRAINING_TIME_PATHS.append(training_time_file)

In [None]:
# Held-out test split of the fruits dataset.
TEST_DATA = pc.data_root_dir.joinpath("FruitsClassification/Fruits Classification/test")
TEST_BATCH_SIZE = 32
# NOTE(review): confirm this matches the input size the models were trained with.
TEST_IMAGE_SIZE = (224, 224)

# `is_dir()` is already False for a missing path, so no separate `exists()` check.
if not TEST_DATA.is_dir():
    # Fail loudly here: otherwise `test_data` stays undefined and the evaluation
    # cell crashes later with an unrelated-looking NameError.
    raise FileNotFoundError(f"No test data found at {TEST_DATA}, where did it go?")

print("Test data found, loading")
# shuffle=False keeps iteration order fixed; the evaluation cell pairs
# `predict` output with labels read from a separate pass over this dataset.
test_data = tf.keras.utils.image_dataset_from_directory(
    TEST_DATA,
    labels="inferred",
    label_mode="int",
    batch_size=TEST_BATCH_SIZE,
    image_size=TEST_IMAGE_SIZE,
    shuffle=False,
)

In [None]:
# Evaluate every saved ablation model on the test set and collect
# timing + weighted classification metrics into a polars DataFrame.

# Ground-truth labels are the same for every model, so gather them once.
# This relies on `test_data` being unshuffled so this pass and `predict`
# see the samples in the same order.
all_true_labels = np.concatenate([labels.numpy() for _, labels in test_data])

results = []

for model_name, model_path, class_list_path, training_time_path in zip(
    MODEL_NAMES, BASE_MODEL_PATHS, CLASS_LIST_PATHS, BASE_TRAINING_TIME_PATHS
):
    if not model_path.exists():
        # Skip this model: `tf_model` is deleted at the end of every iteration,
        # so falling through would raise NameError at `predict`.
        print(f"Model not found at path: {model_path}")
        continue

    print(f"Model {model_path} found, loading")
    tf_model = tf.keras.models.load_model(model_path)
    tf_model.trainable = False
    print(f"Model {model_path} loaded")

    if class_list_path.exists():
        print(f"Class list found at path: {class_list_path}, loading")
        # NOTE(review): `class_list` is loaded but not used. It presumably should
        # be checked against `test_data.class_names` so label indices match the
        # training order — confirm its stored format before adding that check.
        class_list = joblib.load(class_list_path)
        print("Class list loaded")

    # Wall-clock time for a full pass over the test set (includes the first
    # call's graph tracing overhead).
    start_predict = datetime.now()
    predictions = tf_model.predict(test_data)
    prediction_time = datetime.now() - start_predict

    predicted_classes = np.argmax(predictions, axis=-1)
    all_predictions = np.asarray(predicted_classes)

    results.append(
        {
            "model": model_name,
            "base_training_time": joblib.load(training_time_path),
            "prediction_time": prediction_time,
            "precision": precision_score(
                all_true_labels, all_predictions, average="weighted"
            ),
            "recall": recall_score(all_true_labels, all_predictions, average="weighted"),
            "accuracy": accuracy_score(all_true_labels, all_predictions),
            "f1": f1_score(all_true_labels, all_predictions, average="weighted"),
        }
    )

    # Release the model before loading the next one to keep GPU/host memory bounded.
    print(f"Deleting Model {tf_model}")
    del tf_model
    tf.keras.backend.clear_session()

results_df = pl.DataFrame(results)
results_df