# In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.utils import load_img, img_to_array
from tensorflow.keras.applications import ResNet50
import matplotlib.pyplot as plt
import json

# Seed the NumPy and TensorFlow random number generators
np.random.seed(42)
tf.random.set_seed(42)

# Dataset root; load_dataset() looks for <split>/images and <split>/labels below it
BASE_PATH = '/kaggle/input/wildlife-dataset'

# Class id -> display name. The id of each name is its position in the list
# below (five names per row, ids 0 through 53).
CLASS_NAMES = dict(enumerate([
    'Zebra', 'Lion', 'Leopard', 'Cheetah', 'Tiger',
    'Bear', 'Butterfly', 'Canary', 'Crocodile', 'Bull',
    'Camel', 'Centipede', 'Caterpillar', 'Duck', 'Squirrel',
    'Spider', 'Ladybug', 'Elephant', 'Horse', 'Fox',
    'Tortoise', 'Frog', 'Kangaroo', 'Deer', 'Eagle',
    'Monkey', 'Snake', 'Owl', 'Swan', 'Goat',
    'Rabbit', 'Giraffe', 'Goose', 'PolarBear', 'Raven',
    'Hippopotamus', 'BrownBear', 'Rhinoceros', 'Woodpecker', 'Sheep',
    'Magpie', 'Ostrich', 'Jaguar', 'Hedgehog', 'Turkey',
    'Raccoon', 'Worm', 'Harbor', 'Panda', 'RedPanda',
    'Otter', 'Lynx', 'Scorpion', 'Koala',
]))

# Input resolution and training hyperparameters
IMG_HEIGHT = IMG_WIDTH = 224
BATCH_SIZE = 32
EPOCHS = 100
LEARNING_RATE = 1e-4
NUM_CLASSES = 54

# True  -> ResNet50 backbone with a new classification head
# False -> AlexNet-style network trained from scratch
USE_TRANSFER_LEARNING = True

print("=" * 70)
print("AlexNet Training - Best Practices Edition")
print(f"Transfer Learning: {USE_TRANSFER_LEARNING}")
print("=" * 70)

def parse_yolo_label(label_path):
    """Return the class id stored at the start of a label file's first line.

    Only the first whitespace-separated token of the first line is read
    (presumably a YOLO annotation line of the form ``<class_id> <box...>``).

    Args:
        label_path: Path to the label ``.txt`` file.

    Returns:
        The class id as an ``int``, or ``None`` when the file does not exist,
        its first line is blank, or the first token is not an integer.
    """
    try:
        with open(label_path, 'r') as f:
            first_line = f.readline().strip()
    except (FileNotFoundError, NotADirectoryError):
        # No label file for this image.
        return None

    if not first_line:
        return None

    try:
        return int(first_line.split()[0])
    except ValueError:
        # A malformed label should only cause this one sample to be skipped,
        # not abort the scan of the whole dataset.
        return None

def load_dataset(split_name):
    """Collect image paths and class ids for one dataset split.

    Scans ``BASE_PATH/<split_name>/images`` in sorted filename order and, for
    every ``.jpg``/``.jpeg``/``.png`` file, reads the class id from the
    same-named ``.txt`` file under ``BASE_PATH/<split_name>/labels``. Images
    whose label is missing or outside ``[0, NUM_CLASSES)`` are skipped.

    Args:
        split_name: Name of the split sub-directory (e.g. ``'train'``).

    Returns:
        A ``(image_files, labels_list)`` pair of equally long lists. Both are
        empty when the images directory does not exist.
    """
    split_dir = os.path.join(BASE_PATH, split_name)
    images_path = os.path.join(split_dir, 'images')
    labels_path = os.path.join(split_dir, 'labels')

    if not os.path.exists(images_path):
        return [], []

    print(f"Loading {split_name}...")

    image_suffixes = ('.jpg', '.jpeg', '.png')
    samples = []
    for file_name in sorted(os.listdir(images_path)):
        if not file_name.lower().endswith(image_suffixes):
            continue

        stem, _ = os.path.splitext(file_name)
        class_id = parse_yolo_label(os.path.join(labels_path, stem + '.txt'))
        if class_id is None or not (0 <= class_id < NUM_CLASSES):
            continue

        samples.append((os.path.join(images_path, file_name), class_id))

    image_files = [path for path, _ in samples]
    labels_list = [class_id for _, class_id in samples]

    print(f"  Loaded {len(image_files)} images")
    return image_files, labels_list

# Read the three dataset splits (image paths and class ids), in the order
# train -> valid -> test
(train_images, train_labels), (val_images, val_labels), (test_images, test_labels) = (
    load_dataset(split) for split in ('train', 'valid', 'test')
)

print(f"\nTraining: {len(train_images)}, Validation: {len(val_images)}, Test: {len(test_images)}")

# Data generator
class WildlifeDataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that loads image batches from disk on demand.

    Each batch is read with ``keras.utils.load_img``, resized to ``img_size``,
    scaled to ``[0, 1]`` by dividing by 255 and paired with one-hot labels.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of integer class ids, aligned with ``image_paths``.
        batch_size: Maximum number of samples per batch.
        img_size: Value forwarded to ``load_img(target_size=...)``.
        num_classes: Width of the one-hot label vectors.
        augment: When true, apply random flip/rotation/zoom/contrast to every
            batch.
        shuffle: Whether to reshuffle the sample order at construction and
            after every epoch. ``None`` (the default) follows ``augment``,
            which matches the previous behaviour.
    """

    def __init__(self, image_paths, labels, batch_size, img_size, num_classes,
                 augment=False, shuffle=None):
        super().__init__()
        self.image_paths = image_paths
        self.labels = labels
        self.batch_size = batch_size
        self.img_size = img_size
        self.num_classes = num_classes
        self.augment = augment
        self.shuffle = augment if shuffle is None else shuffle
        self.indices = np.arange(len(self.image_paths))
        self.on_epoch_end()

        self.data_augmentation = None
        if self.augment:
            self.data_augmentation = keras.Sequential([
                layers.RandomFlip("horizontal"),
                layers.RandomRotation(0.15),
                layers.RandomZoom(0.15),
                layers.RandomContrast(0.15),
            ])

    def __len__(self):
        """Return the number of batches per epoch, counting a final partial batch."""
        # Ceil division: floor division silently dropped the trailing samples,
        # so validation/test metrics were computed on a subset of the data and
        # a dataset smaller than one batch produced zero batches.
        return (len(self.image_paths) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Load batch ``idx`` and return ``(images, one_hot_labels)``."""
        # Slicing clamps at the end of the index array, so the last batch may
        # hold fewer than ``batch_size`` samples.
        batch_indices = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]

        batch_images = []
        batch_labels = []

        for i in batch_indices:
            img = keras.utils.load_img(self.image_paths[i], target_size=self.img_size)
            img_array = keras.utils.img_to_array(img) / 255.0
            batch_images.append(img_array)
            batch_labels.append(self.labels[i])

        batch_images = np.array(batch_images)
        batch_labels = keras.utils.to_categorical(batch_labels, self.num_classes)

        if self.data_augmentation is not None:
            # training=True is required for the random layers to be active.
            batch_images = self.data_augmentation(batch_images, training=True)

        return batch_images, batch_labels

    def on_epoch_end(self):
        """Reshuffle the sample order in place when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)

# Create AlexNet
def create_alexnet(num_classes):
    """Build the AlexNet-style classifier used when transfer learning is off.

    Five convolution blocks (each followed by batch normalization, with max
    pooling after blocks 1, 2 and 5) feed a global-average-pooled classifier
    head with two dropout-regularized dense layers.

    Args:
        num_classes: Number of units in the final softmax layer.

    Returns:
        An uncompiled ``Sequential`` model named ``'AlexNet'`` that takes
        ``(IMG_HEIGHT, IMG_WIDTH, 3)`` inputs.
    """
    # Block 1 is the only one with a large strided kernel and an input shape.
    stack = [
        layers.Conv2D(96, 11, strides=4, activation='relu',
                      input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
        layers.BatchNormalization(),
        layers.MaxPooling2D(3, strides=2),
    ]

    # Blocks 2-5: (filters, kernel size, max-pool after the block?)
    conv_blocks = (
        (256, 5, True),
        (384, 3, False),
        (384, 3, False),
        (256, 3, True),
    )
    for filters, kernel_size, pooled in conv_blocks:
        stack.append(layers.Conv2D(filters, kernel_size, padding='same', activation='relu'))
        stack.append(layers.BatchNormalization())
        if pooled:
            stack.append(layers.MaxPooling2D(3, strides=2))

    # Classifier head: global average pooling instead of Flatten.
    stack.append(layers.GlobalAveragePooling2D())
    for units in (1024, 512):
        stack.append(layers.Dense(units, activation='relu'))
        stack.append(layers.Dropout(0.5))
    stack.append(layers.Dense(num_classes, activation='softmax'))

    return models.Sequential(stack, name='AlexNet')

# Create transfer learning model
def create_transfer_model(num_classes):
    """Build a classifier on top of a frozen, ImageNet-pretrained ResNet50.

    Args:
        num_classes: Number of units in the final softmax layer.

    Returns:
        A ``(model, base_model)`` tuple: the uncompiled ``Sequential`` model
        named ``'TransferLearning_ResNet50'`` and its ResNet50 backbone, which
        is returned separately so the caller can unfreeze it later.
    """
    # NOTE(review): the ImageNet weights were presumably trained on inputs
    # passed through keras.applications.resnet50.preprocess_input, while the
    # data generator in this file feeds images scaled to [0, 1] - confirm
    # whether a matching preprocessing step should be added.
    backbone = ResNet50(
        include_top=False,
        weights='imagenet',
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
    )
    # Start with the backbone frozen; only the new head is trainable.
    backbone.trainable = False

    head = [
        layers.GlobalAveragePooling2D(),
        layers.BatchNormalization(),
        layers.Dense(512, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(num_classes, activation='softmax'),
    ]
    classifier = models.Sequential([backbone, *head], name='TransferLearning_ResNet50')

    return classifier, backbone

# Batch generators for the three splits; only the training one augments
train_gen, val_gen, test_gen = (
    WildlifeDataGenerator(paths, targets, BATCH_SIZE, (IMG_HEIGHT, IMG_WIDTH), NUM_CLASSES, augment=use_augmentation)
    for paths, targets, use_augmentation in (
        (train_images, train_labels, True),
        (val_images, val_labels, False),
        (test_images, test_labels, False),
    )
)

# Pick the architecture. `base_model` only exists in the transfer-learning case.
if not USE_TRANSFER_LEARNING:
    print("\nBuilding AlexNet from scratch...")
    model = create_alexnet(NUM_CLASSES)
else:
    print("\nBuilding Transfer Learning Model (ResNet50)...")
    model, base_model = create_transfer_model(NUM_CLASSES)

# Metric order matters later: evaluate() returns [loss, accuracy, top5_acc]
model.compile(
    loss='categorical_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    metrics=['accuracy', keras.metrics.TopKCategoricalAccuracy(k=5, name='top5_acc')],
)

model.summary()

# Training callbacks, appended in the order: checkpoint, early stopping, LR schedule
callbacks = []

# Keep the weights of the epoch with the best validation accuracy on disk
callbacks.append(keras.callbacks.ModelCheckpoint(
    filepath='best_model.h5',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
))

# Stop after 20 epochs without a validation-accuracy improvement
callbacks.append(keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    patience=20,
    restore_best_weights=True,
    verbose=1,
))

# Multiply the learning rate by 0.3 after 8 epochs without a validation-loss improvement
callbacks.append(keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.3,
    patience=8,
    min_lr=1e-7,
    verbose=1,
))

# Phase 1: train the model as built. With transfer learning the backbone is
# still frozen here, so only the new head is updated (30 epochs); otherwise
# the whole AlexNet is trained for EPOCHS epochs.
print("\n" + "="*70)
print("Phase 1: Training...")
print("="*70)

history1 = model.fit(
    train_gen,
    epochs=EPOCHS if not USE_TRANSFER_LEARNING else 30,
    validation_data=val_gen,
    callbacks=callbacks,
    verbose=1
)

# Phase 2: fine-tune the backbone (transfer learning only)
if USE_TRANSFER_LEARNING:
    print("\n" + "="*70)
    print("Phase 2: Fine-tuning (unfreezing layers)...")
    print("="*70)

    # Unfreeze the base model
    base_model.trainable = True

    # Keep the backbone's BatchNormalization layers frozen. A frozen BN layer
    # runs in inference mode, so its pretrained moving statistics are not
    # overwritten by small fine-tuning batches, which would otherwise disrupt
    # the features learned in phase 1.
    for layer in base_model.layers:
        if isinstance(layer, layers.BatchNormalization):
            layer.trainable = False

    # Recompile so the new trainable flags take effect, with a 10x lower
    # learning rate to avoid large updates to the pretrained weights.
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE/10),
        loss='categorical_crossentropy',
        metrics=['accuracy', keras.metrics.TopKCategoricalAccuracy(k=5, name='top5_acc')]
    )

    history2 = model.fit(
        train_gen,
        epochs=30,
        validation_data=val_gen,
        callbacks=callbacks,
        verbose=1
    )

    # Append the phase-2 curves to the phase-1 curves. setdefault tolerates a
    # key that only one of the two runs recorded.
    for key, values in history2.history.items():
        history1.history.setdefault(key, []).extend(values)

history = history1

# Evaluate on the held-out test split
print(f"\n{'=' * 70}")
print("Final Evaluation...")
print("=" * 70)

# Result order follows compile(): [loss, accuracy, top5_acc]
test_results = model.evaluate(test_gen, verbose=1)
print(f"\nTest Accuracy: {test_results[1]*100:.2f}%")
print(f"Test Top-5 Accuracy: {test_results[2]*100:.2f}%")

# Plot the training curves on a 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# First three panels: train vs. validation curve of one metric each
metric_panels = (
    (axes[0, 0], 'accuracy', 'val_accuracy', 'Accuracy'),
    (axes[0, 1], 'loss', 'val_loss', 'Loss'),
    (axes[1, 0], 'top5_acc', 'val_top5_acc', 'Top-5 Accuracy'),
)
for panel, train_key, val_key, panel_title in metric_panels:
    panel.plot(history.history[train_key], label='Train', linewidth=2)
    panel.plot(history.history[val_key], label='Val', linewidth=2)
    panel.set_title(panel_title, fontsize=14, fontweight='bold')
    panel.legend()
    panel.grid(alpha=0.3)

# Fourth panel: train accuracy minus validation accuracy per epoch
gap = np.array(history.history['accuracy']) - np.array(history.history['val_accuracy'])
gap_panel = axes[1, 1]
gap_panel.plot(gap, linewidth=2, color='red')
gap_panel.axhline(y=0, color='black', linestyle='--', alpha=0.5)
gap_panel.set_title('Overfitting Gap', fontsize=14, fontweight='bold')
gap_panel.grid(alpha=0.3)

plt.tight_layout()
# Save before show(), which may leave the figure empty afterwards
plt.savefig('final_training_history.png', dpi=300)
print("Plot saved!")
plt.show()

# Persist the trained model and the class-id -> name mapping
model.save('final_wildlife_model.h5')
print("Model saved!")

# JSON object keys are strings, so the integer class ids are written as "0", "1", ...
with open('class_names.json', 'w') as names_file:
    json.dump(CLASS_NAMES, names_file, indent=2)

print(f"\n{'=' * 70}")
print("COMPLETE!")
print("=" * 70)