# Context

After training the model from scratch, we decided it would be better to train a separate model for each phylum. We also opted to use pretrained models. In this notebook, we aim to identify the best-performing pretrained model.

- We compare the models on the majority-class phylum (Chordata), which contains most of the images in the dataset.
- The same preprocessing steps and model pipeline were used for each model to ensure a fair comparison.


# Imports

In [1]:
import zipfile

from google.colab import drive

# Make the Drive folder that holds the dataset archive visible to this runtime.
drive.mount('/content/drive')

# Unpack the whole archive from Drive into `extract_path`.
zip_path = '/content/drive/MyDrive/rare_species 1.zip'
extract_path = '/content/rare_species 1'
with zipfile.ZipFile(zip_path, 'r') as archive:
    archive.extractall(extract_path)

Mounted at /content/drive


In [7]:
import os
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow import data as tf_data
from tensorflow.keras import layers
from tensorflow.keras.applications import VGG16, ResNet50, MobileNetV2, Xception, DenseNet121
from tensorflow.keras.layers import Rescaling, RandAugment

from sklearn.metrics import classification_report

In [3]:
# Dataset root: the folder that contains metadata.csv and the image subfolders.
# Colab layout:
folder_path = '/content/rare_species 1/rare_species 1'

# Local (VS Code) layout -- swap with the line above when running outside Colab:
# folder_path = '../data/rare_species 1'

# One row per image; metadata.csv sits directly under the dataset root.
meta = pd.read_csv(f'{folder_path}/metadata.csv')

In [4]:
# Quick look at the label hierarchy: phyla present, rows per phylum, family count.
phylum_column = meta['phylum']
family_count = meta['family'].nunique()

print(f"the diferent Phylum are: \n{phylum_column.unique()}")
print(f"each phylum contains :  \n{phylum_column.value_counts()}")

print(f"their is {family_count} different families")

# Last expression of the cell: rich display of the metadata table.
meta

the diferent Phylum are: 
['mollusca' 'chordata' 'arthropoda' 'echinodermata' 'cnidaria']
each phylum contains :  
phylum
chordata         9952
arthropoda        951
cnidaria          810
mollusca          210
echinodermata      60
Name: count, dtype: int64
their is 202 different families


Unnamed: 0,rare_species_id,eol_content_id,eol_page_id,kingdom,phylum,family,file_path
0,75fd91cb-2881-41cd-88e6-de451e8b60e2,12853737,449393,animalia,mollusca,unionidae,mollusca_unionidae/12853737_449393_eol-full-si...
1,28c508bc-63ff-4e60-9c8f-1934367e1528,20969394,793083,animalia,chordata,geoemydidae,chordata_geoemydidae/20969394_793083_eol-full-...
2,00372441-588c-4af8-9665-29bee20822c0,28895411,319982,animalia,chordata,cryptobranchidae,chordata_cryptobranchidae/28895411_319982_eol-...
3,29cc6040-6af2-49ee-86ec-ab7d89793828,29658536,45510188,animalia,chordata,turdidae,chordata_turdidae/29658536_45510188_eol-full-s...
4,94004bff-3a33-4758-8125-bf72e6e57eab,21252576,7250886,animalia,chordata,indriidae,chordata_indriidae/21252576_7250886_eol-full-s...
...,...,...,...,...,...,...,...
11978,1fa96ea5-32fa-4a25-b8d2-fa99f6e2cb89,29734618,1011315,animalia,chordata,leporidae,chordata_leporidae/29734618_1011315_eol-full-s...
11979,628bf2b4-6ecc-4017-a8e6-4306849e0cfc,29972861,1056842,animalia,chordata,emydidae,chordata_emydidae/29972861_1056842_eol-full-si...
11980,0ecfdec9-b1cd-4d43-96fc-2f8889ec1ad9,30134195,52572074,animalia,chordata,dasyatidae,chordata_dasyatidae/30134195_52572074_eol-full...
11981,27fdb1e9-c5fb-459a-8b6a-6fb222b1c512,9474963,46559139,animalia,chordata,mustelidae,chordata_mustelidae/9474963_46559139_eol-full-...


# Phylum Splits

This code copies every image listed in the metadata into a folder named after its phylum, keeping the original `phylum_family` subfolder inside it.  
With one folder per phylum, each phylum can be loaded as its own dataset and given a dedicated model.


In [5]:
# Root folder that the `file_path` column of the metadata is relative to.
# With colab
current_locations = '/content/rare_species 1/rare_species 1'

# with vscode
# current_locations = '../data/rare_species 1'

# Only the phylum and the relative image path are needed from each metadata row.
for phylum, file_path in zip(meta['phylum'], meta['file_path']):
    file_location = os.path.join(current_locations, file_path)

    # Destination folder: <phylum>/<original subfolder>, so the class
    # subfolders are preserved inside each phylum.
    # with colab (path is relative to the current working directory)
    target_folder = os.path.join(phylum, os.path.dirname(file_path))
    # with vscode
    # target_folder = os.path.join("../data" , phylum, os.path.dirname(file_path))

    # Created before the existence check, as in the original cell.
    os.makedirs(target_folder, exist_ok=True)

    destination = os.path.join(target_folder, os.path.basename(file_path))

    # Report metadata rows whose image is missing instead of failing on them.
    if not os.path.exists(file_location):
        print(f"Couldn't find the file: {file_location}")
    else:
        shutil.copy2(file_location, destination)

# Splits

In [8]:
# Per-phylum image folders.
# NOTE: the variable names below keep their original spelling ("athropoda",
# "athoropa") because other cells may refer to them.
# with colab
path_phylum_athropoda = "/content/arthropoda"
path_phylum_chordata = "/content/chordata"
path_phylum_cnidaria = "/content/cnidaria"
path_phylum_mollusca = "/content/mollusca"

# with vscode
# path_phylum_athropoda = "../data/arthropoda"
# path_phylum_chordata = "../data/chordata"
# path_phylum_cnidaria = "../data/cnidaria"
# path_phylum_mollusca = "../data/mollusca"

# Settings shared by every phylum so the splits are built the same way.
image_size = (224, 224)
seed = 42
batch_size = 32


def _load_train_val(directory):
    """Build the training and validation datasets for one phylum folder.

    Uses the module-level `seed`, `image_size` and `batch_size`, holds out
    20% of the files for validation, and returns both subsets from a single
    `image_dataset_from_directory` call (`subset="both"`).
    """
    return keras.utils.image_dataset_from_directory(
        directory,
        validation_split=0.2,
        subset="both",
        seed=seed,
        image_size=image_size,
        batch_size=batch_size,
    )


train_ds_athoropa, val_ds_athropoda = _load_train_val(path_phylum_athropoda)
train_ds_chordata, val_ds_chordata = _load_train_val(path_phylum_chordata)
train_ds_cnidaria, val_ds_cnidaria = _load_train_val(path_phylum_cnidaria)
train_ds_mollusca, val_ds_mollusca = _load_train_val(path_phylum_mollusca)


Found 951 files belonging to 17 classes.
Using 761 files for training.
Using 190 files for validation.
Found 9952 files belonging to 166 classes.
Using 7962 files for training.
Using 1990 files for validation.
Found 810 files belonging to 13 classes.
Using 648 files for training.
Using 162 files for validation.
Found 210 files belonging to 5 classes.
Using 168 files for training.
Using 42 files for validation.


# Defining the different models

In [9]:
# Model creation functions for different architectures
def _make_transfer_model(base_architecture, input_shape, num_classes):
    """Build a frozen ImageNet backbone with a small softmax classification head.

    Shared by all `make_model_*` functions so that every architecture is
    wrapped in exactly the same pipeline:
    RandAugment -> Rescaling(1/255) -> backbone (frozen) -> Flatten ->
    Dropout(0.1) -> Dense(num_classes, softmax).

    Args:
        base_architecture: A `keras.applications` model constructor
            (e.g. `VGG16`) accepting `include_top`, `input_tensor` and
            `weights`.
        input_shape: Shape of one input image, e.g. `(224, 224, 3)`.
        num_classes: Number of units in the final softmax layer.

    Returns:
        An uncompiled `keras.Model` mapping images to class probabilities.
    """
    inputs = keras.Input(shape=input_shape)

    # Augmentation is declared on raw 0-255 pixels, then pixels are scaled by 1/255.
    # NOTE(review): the Keras Applications docs describe a model-specific
    # `preprocess_input` for each of these backbones; the shared 1/255
    # rescaling is a different transform. It is kept so that every
    # architecture sees the same pipeline -- confirm this trade-off is intended.
    x = RandAugment(value_range=(0, 255))(inputs)
    x = Rescaling(1./255)(x)

    # Pretrained backbone without its ImageNet classifier, frozen so that
    # only the new head is trained.
    base_model = base_architecture(include_top=False, input_tensor=x, weights="imagenet")
    base_model.trainable = False

    # Classification head: flatten the feature maps, light dropout, softmax.
    x = base_model.output
    x = layers.Flatten()(x)
    x = layers.Dropout(0.1)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    return keras.Model(inputs, outputs)


def make_model_vgg16(input_shape, num_classes):
    """Return the shared transfer-learning model built on a frozen VGG16."""
    return _make_transfer_model(VGG16, input_shape, num_classes)


def make_model_resnet50(input_shape, num_classes):
    """Return the shared transfer-learning model built on a frozen ResNet50."""
    return _make_transfer_model(ResNet50, input_shape, num_classes)


def make_model_mobilenetv2(input_shape, num_classes):
    """Return the shared transfer-learning model built on a frozen MobileNetV2."""
    return _make_transfer_model(MobileNetV2, input_shape, num_classes)


def make_model_xception(input_shape, num_classes):
    """Return the shared transfer-learning model built on a frozen Xception."""
    return _make_transfer_model(Xception, input_shape, num_classes)


def make_model_densenet121(input_shape, num_classes):
    """Return the shared transfer-learning model built on a frozen DenseNet121."""
    return _make_transfer_model(DenseNet121, input_shape, num_classes)

# Train and evaluate the models

In [10]:
def train_and_evaluate_model(model, model_name, train_ds, val_ds, epochs=50):
    """Train a model, keep its best checkpoint, and score that checkpoint.

    The model is compiled here with Adam (exponentially decaying learning
    rate), sparse categorical cross-entropy and a sparse categorical accuracy
    metric named "acc". After every epoch the model is saved to
    `best_model_{model_name}.keras` whenever `val_acc` improves. The saved
    checkpoint is then reloaded and evaluated on `val_ds`.

    Args:
        model: Keras model whose outputs are class probabilities
            (the loss is configured with `from_logits=False`).
        model_name: Short identifier used in the checkpoint file name and in
            the printed report.
        train_ds: Training dataset passed to `model.fit`.
        val_ds: Validation dataset yielding `(images, labels)` batches; used
            for checkpoint selection and for the final report.
        epochs: Number of training epochs.

    Returns:
        dict with keys `model_name`, `history` (the `History.history` dict),
        `accuracy`, `f1_macro`, `f1_weighted` (taken from scikit-learn's
        classification report) and `model_path` (the checkpoint path).
    """

    # Learning rate schedule: starts at 1e-3 and is scaled by 0.9 per
    # 10,000 optimizer steps.
    lr_schedule = keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=1e-3,
        decay_steps=10000,
        decay_rate=0.9
    )

    # Callbacks: keep only the epoch with the highest validation accuracy.
    checkpoint_path = f"best_model_{model_name}.keras"
    callbacks = [
        keras.callbacks.ModelCheckpoint(
            checkpoint_path,
            save_best_only=True,
            monitor="val_acc",  # validation value of the metric named "acc" below
            mode="max",
            verbose=1
        )
    ]

    # Compile the model
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
    )

    # Train the model
    history = model.fit(
        train_ds,
        epochs=epochs,
        callbacks=callbacks,
        validation_data=val_ds,
    )

    # Evaluate the best checkpoint rather than the last-epoch weights.
    best_model = keras.models.load_model(checkpoint_path)

    # Collect labels and predictions in a single pass over the dataset so
    # that each prediction is paired with the label from the same batch.
    # Predicting and reading labels in two separate iterations is only
    # correct when the dataset yields batches in the same order both times.
    y_true_batches = []
    y_pred_batches = []
    for images, labels in val_ds:
        batch_probs = best_model.predict_on_batch(images)
        y_pred_batches.append(np.argmax(batch_probs, axis=1))
        y_true_batches.append(np.asarray(labels))
    y_true = np.concatenate(y_true_batches, axis=0)
    y_pred = np.concatenate(y_pred_batches, axis=0)

    # Print classification report (text for the reader, dict for the metrics)
    print(f"\nClassification Report for {model_name}:")
    report = classification_report(y_true, y_pred, output_dict=True)
    print(classification_report(y_true, y_pred))

    # Return metrics and paths
    return {
        'model_name': model_name,
        'history': history.history,
        'accuracy': report['accuracy'],
        'f1_macro': report['macro avg']['f1-score'],
        'f1_weighted': report['weighted avg']['f1-score'],
        'model_path': checkpoint_path
    }

def compare_models(results, dataset_name):
    """Tabulate and plot the metrics of several trained models.

    Args:
        results: Iterable of dicts with the keys `model_name`, `accuracy`,
            `f1_macro` and `f1_weighted`.
        dataset_name: Label used in the printed headings and in the name of
            the saved figure (`model_comparison_{dataset_name}.png`).

    Returns:
        A DataFrame with one row per model, sorted by accuracy (best first).
    """
    # One table row per model.
    rows = []
    for entry in results:
        rows.append({
            'Model': entry['model_name'],
            'Accuracy': entry['accuracy'],
            'F1 (Macro)': entry['f1_macro'],
            'F1 (Weighted)': entry['f1_weighted'],
        })

    # Best accuracy first, so row 0 is the winner.
    comparison = pd.DataFrame(rows).sort_values('Accuracy', ascending=False)

    print(f"\n=== Model Comparison for {dataset_name} ===")
    print(comparison)

    # Two side-by-side bar charts: (column to plot, panel title, y-axis label).
    panels = (
        ('Accuracy', 'Accuracy Comparison', 'Accuracy'),
        ('F1 (Weighted)', 'F1 Score (Weighted) Comparison', 'F1 Score'),
    )
    plt.figure(figsize=(12, 6))
    for position, (column, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.bar(comparison['Model'], comparison[column])
        plt.title(title)
        plt.ylabel(y_label)
        plt.xticks(rotation=45)
        plt.ylim(0, 1)  # both metrics live in [0, 1]

    plt.tight_layout()
    plt.savefig(f'model_comparison_{dataset_name}.png')
    plt.show()

    # Summary of the top-ranked row.
    top_row = comparison.iloc[0]
    print(f"\nBest model for {dataset_name}: {top_row['Model']}")
    print(f"Accuracy: {top_row['Accuracy']:.4f}")
    print(f"F1 Score (Weighted): {top_row['F1 (Weighted)']:.4f}")

    return comparison

# Function to plot learning curves
def plot_learning_curves(results, dataset_name):
    """Plot training and validation curves of every model on shared axes.

    Draws an accuracy panel and a loss panel, with one solid (train) and one
    dashed (validation) line per model, saves the figure as
    `learning_curves_{dataset_name}.png` and shows it.

    Args:
        results: Iterable of dicts with `model_name` and a `history` dict
            containing the series `acc`, `val_acc`, `loss` and `val_loss`.
        dataset_name: Label used in the panel titles and the file name.
    """
    # (train series, validation series, axis/title label, legend position)
    panels = (
        ('acc', 'val_acc', 'Accuracy', 'lower right'),
        ('loss', 'val_loss', 'Loss', 'upper right'),
    )

    plt.figure(figsize=(12, 6))
    for position, (train_key, val_key, label, legend_loc) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)

        for run in results:
            name = run['model_name']
            curves = run['history']
            plt.plot(curves[train_key], label=f"{name} (Train)")
            plt.plot(curves[val_key], label=f"{name} (Val)", linestyle='--')

        plt.title(f'{label} - {dataset_name}')
        plt.xlabel('Epoch')
        plt.ylabel(label)
        plt.legend(loc=legend_loc)
        plt.grid(True)

    plt.tight_layout()
    plt.savefig(f'learning_curves_{dataset_name}.png')
    plt.show()

# Model run

In [None]:
epochs = 20

# Size the classification head from the dataset itself instead of a
# hard-coded count, so it stays correct if the class folders change.
num_classes_chordata = len(train_ds_chordata.class_names)
input_shape = image_size + (3,)  # RGB images of `image_size`

# Build every candidate up front (this also downloads the ImageNet weights).
model_vgg16_chordata = make_model_vgg16(input_shape=input_shape, num_classes=num_classes_chordata)
model_resnet50_chordata = make_model_resnet50(input_shape=input_shape, num_classes=num_classes_chordata)
model_mobilenet_chordata = make_model_mobilenetv2(input_shape=input_shape, num_classes=num_classes_chordata)

# (label for the log heading, name used for the checkpoint file, model)
chordata_candidates = [
    ("VGG16", "vgg16", model_vgg16_chordata),
    ("ResNet50", "resnet50", model_resnet50_chordata),
    ("MobileNetV2", "mobilenetv2", model_mobilenet_chordata),
]

# Train and evaluate models, one after the other, on the same train/val split.
results_chordata = []
for display_name, checkpoint_name, candidate in chordata_candidates:
    print(f"\n=== Training {display_name} on chordata dataset ===")
    results_chordata.append(
        train_and_evaluate_model(
            model=candidate,
            model_name=checkpoint_name,
            train_ds=train_ds_chordata,
            val_ds=val_ds_chordata,
            epochs=epochs,
        )
    )

# Keep the per-model result names available for any cell that refers to them.
result_vgg16, result_resnet50, result_mobilenet = results_chordata

# Compare models
compare_models(results_chordata, "Chordata")

# Plot learning curves
plot_learning_curves(results_chordata, "Chordata")

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m58889256/58889256[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


  base_model = MobileNetV2(include_top=False, input_tensor=x, weights="imagenet")


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
[1m9406464/9406464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step

=== Training VGG16 on chordata dataset ===
Epoch 1/20
[1m  1/249[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m1:35:05[0m 23s/step - acc: 0.0000e+00 - loss: 5.6451