In [None]:
import os
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2

In [None]:
def load_animal_dataset(dataset_path):
    """
    Load an image-classification dataset laid out as one sub-folder per class.

    Expected layout::

        dataset_path/
            <class_name>/
                <image files>

    Every readable image is converted to 3-channel RGB, so all returned
    arrays share the same channel layout regardless of the on-disk mode
    (grayscale, palette, RGBA, ...). Files that cannot be opened as images
    are reported and skipped. Directory entries are visited in sorted order,
    so the returned lists are reproducible across runs and platforms.
    Non-directory entries directly under ``dataset_path`` are ignored.

    Parameters:
        dataset_path : str or path-like
            Root folder of the dataset.

    Returns:
        images : list of np.ndarray
            RGB images, each of shape (H, W, 3), dtype uint8.
        labels : list of str
            Class (folder) name for each entry in ``images``.
        class_images : dict of str -> list of np.ndarray
            The same arrays grouped by class name. A class folder with no
            readable images maps to an empty list.
    """
    images = []
    labels = []
    class_images = {}

    # -------- Load Images --------
    # sorted(): os.listdir order is arbitrary; sorting makes results reproducible.
    for label_name in sorted(os.listdir(dataset_path)):
        class_folder = os.path.join(dataset_path, label_name)

        if not os.path.isdir(class_folder):
            continue

        class_images[label_name] = []

        for file in sorted(os.listdir(class_folder)):
            img_path = os.path.join(class_folder, file)

            try:
                # `with` guarantees the file handle is closed. convert("RGB")
                # normalises palette/RGBA/grayscale images to 3 channels, which
                # downstream RGB->gray conversion and plotting rely on.
                with Image.open(img_path) as img:
                    img_array = np.array(img.convert("RGB"))
            except Exception as e:
                # Best-effort loading: report and skip unreadable files.
                print("Failed to load:", img_path, "Error:", e)
                continue

            images.append(img_array)
            labels.append(label_name)
            class_images[label_name].append(img_array)

    print("Total images loaded:", len(images))
    return images, labels, class_images


def show_samples(class_images, samples_per_class=3):
    """
    Display a random selection of images for every class, one figure per class.

    Parameters:
        class_images : dict of str -> list of np.ndarray
            Mapping from class name to that class's images (for example the
            third value returned by ``load_animal_dataset``).
        samples_per_class : int
            Maximum number of images drawn (at random) per class.

    Classes with no images are reported and skipped rather than rendering an
    empty figure.
    """
    # -------- Display Sample Images --------
    for class_name, img_list in class_images.items():
        print(f"\nShowing samples for: {class_name}")

        if not img_list:
            print("  No images to show.")
            continue

        sample_imgs = random.sample(img_list, min(samples_per_class, len(img_list)))

        plt.figure(figsize=(10, 3))
        for i, img in enumerate(sample_imgs):
            plt.subplot(1, samples_per_class, i + 1)
            # Single-channel (2-D) arrays would otherwise be drawn with the
            # default false-colour colormap; force grayscale for them.
            plt.imshow(img, cmap="gray" if np.ndim(img) == 2 else None)
            plt.title(class_name)
            plt.axis("off")

        plt.show()


In [None]:
# Fixed seed so the randomly sampled preview grids are identical on every
# Restart & Run All.
SEED = 42
random.seed(SEED)

# Root folder of the dataset (one sub-folder per class).
DATASET_DIR = "animal_dataset"

images, labels, class_images = load_animal_dataset(DATASET_DIR)
show_samples(class_images=class_images)

In [None]:
def _to_grayscale(img):
    """Return ``img`` as a 2-D (H, W) single-channel array, converting colour inputs."""
    img = np.asarray(img)
    if img.ndim == 2:
        return img
    if img.ndim != 3:
        raise ValueError(f"Expected a 2-D or 3-D image array, got shape {img.shape}")

    channels = img.shape[2]
    if channels == 1:
        return img[:, :, 0]
    if channels == 3:
        return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    if channels == 4:
        # RGBA input: COLOR_RGB2GRAY would reject a 4-channel array.
        return cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    raise ValueError(f"Unsupported number of channels: {channels}")


def preprocess_images(images, labels, class_map=None, normalize=True, size=32):
    """
    Preprocess raw images for ML models.

    Each image is converted to grayscale, resized to ``size`` x ``size`` and
    flattened into one row of ``X``.

    Parameters:
        images : list of np.array
            Raw images: grayscale (H, W) or (H, W, 1), RGB (H, W, 3) or
            RGBA (H, W, 4).
        labels : list of str
            Class names, one per image.
        class_map : dict, optional
            Mapping from class name to integer. Built from the sorted unique
            labels when omitted.
        normalize : bool
            If True, divide pixel values by 255 (maps uint8 input to [0, 1]).
        size : int
            Side length of the square output images (default 32).

    Returns:
        X : np.ndarray
            Preprocessed flattened images, float32, shape (N, size*size).
            For empty input the shape is still (0, size*size).
        y : np.ndarray
            Integer labels, int32, shape (N,).
        class_map : dict
            Mapping class name -> integer label.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length, or an image
            has an unsupported shape / channel count.
        KeyError: if a label is missing from a user-supplied ``class_map``.
    """
    # zip() would silently drop the tail of the longer list.
    if len(images) != len(labels):
        raise ValueError(
            f"images and labels must have the same length, "
            f"got {len(images)} and {len(labels)}"
        )

    if class_map is None:
        class_map = {cls: idx for idx, cls in enumerate(sorted(set(labels)))}

    # Preallocate so the result keeps shape (N, size*size) even when N == 0.
    X = np.empty((len(images), size * size), dtype=np.float32)
    y = np.empty(len(images), dtype=np.int32)

    for i, (img, lbl) in enumerate(zip(images, labels)):
        gray = _to_grayscale(img)
        X[i] = cv2.resize(gray, (size, size)).reshape(-1)
        y[i] = class_map[lbl]

    if normalize:
        X /= 255.0

    return X, y, class_map


In [None]:
# Convert the raw images into a flat feature matrix plus integer labels.
X, y, class_map = preprocess_images(images, labels)

# Quick sanity report of what came out of preprocessing.
for caption, value in (
    ("X shape:", X.shape),
    ("y shape:", y.shape),
    ("Class map:", class_map),
):
    print(caption, value)


In [None]:
def show_samples_of_class(X, y, class_label, samples=3, img_shape=(32, 32), class_map=None):
    """
    Display sample images of a single class.

    Parameters:
        X : np.ndarray
            Preprocessed images (N, H*W) or (N, H, W).
        y : np.ndarray
            Integer labels (N,).
        class_label : int or str
            Integer label of the class to show, or a class name when
            ``class_map`` is given.
        samples : int
            Maximum number of images to display (chosen at random).
        img_shape : tuple
            Shape used to reshape flattened (2-D ``X``) rows back into images.
        class_map : dict, optional
            Mapping class name -> integer label (as returned by
            ``preprocess_images``). Needed to pass a class name; when given
            with an integer label, the name is shown in the title.

    Raises:
        ValueError: if ``class_label`` is a str that ``class_map`` cannot resolve.

    If no samples carry the requested label, a message is printed and no
    figure is drawn.
    """
    if isinstance(class_label, str):
        # y holds integers, so a raw string comparison would never match.
        if class_map is None or class_label not in class_map:
            raise ValueError(
                f"Cannot resolve class name {class_label!r}; "
                f"pass a class_map that contains it"
            )
        class_name, label_id = class_label, class_map[class_label]
    else:
        label_id = class_label
        # Reverse lookup so the title can show a readable name when possible.
        class_name = next(
            (name for name, idx in (class_map or {}).items() if idx == label_id),
            None,
        )

    # Get indices for the requested class
    indices = np.flatnonzero(np.asarray(y) == label_id)
    if len(indices) == 0:
        print(f"No samples found for label {class_label!r}")
        return

    sample_indices = random.sample(list(indices), min(samples, len(indices)))
    title = f"Label: {label_id}" if class_name is None else f"{class_name} ({label_id})"

    plt.figure(figsize=(10, 3))
    for i, idx in enumerate(sample_indices):
        # Flattened rows need reshaping; (N, H, W) input is already image-shaped.
        img = X[idx].reshape(img_shape) if X.ndim == 2 else X[idx]
        plt.subplot(1, samples, i + 1)
        plt.imshow(img, cmap='gray')
        plt.title(title)
        plt.axis('off')
    plt.show()


In [None]:
# Show samples for every class actually present, rather than hard-coding
# labels 0, 1, 2 (which breaks for any other number of classes).
for label_id in sorted(class_map.values()):
    show_samples_of_class(X, y, class_label=label_id)