## Project GERALD

### Objective



#### Authors: Suman Senapati
####          Sanjay Rao
####          Pratibha

### Importing Libraries

In [3]:
import zipfile
import json
import xml.etree.ElementTree as ET
from pathlib import Path
from collections import defaultdict, Counter
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
from typing import Dict, List, Set, Tuple, Optional, Any

import os
import glob
import pandas as pd
import seaborn as sns
import matplotlib.patches as patches
from PIL import Image

import os
import cv2
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import xml.etree.ElementTree as ET
from pathlib import Path
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
from sklearn.preprocessing import LabelBinarizer

import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16, MobileNetV2, EfficientNetB0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
import tensorflow.keras.utils as keras_utils
import tensorflow.keras.backend as K

import warnings
warnings.filterwarnings('ignore')

ModuleNotFoundError: No module named 'matplotlib'

## Configuration

In [None]:
OUT_DIR = "../dataset/GERALD_subset"
TARGET_MB = 2400
CLASSES_KEEP = ["main_signal", "distant_signal"]

CATEGORY_ALIASES = {
    "main_signal": ["Hp_0_HV", "Hp_1", "Hp_2", "Hp_0_Sh", "Ks_1", "Ks_2"],
    "distant_signal": ["Vr_0", "Vr_1", "Vr_2"]
}

## Helper classes

#####  Class to create a balanced subset of the GERALD dataset under different weather and light conditions.

In [None]:
class GeraldSignalSubsetCreator:

    """
    Class to create a balanced subset of the GERALD dataset under different weather and light conditions.
    
    Attributes:
        out_dir (Path): Output directory for the subset.
        target_bytes (int): Maximum target dataset size in bytes.
        current_bytes (int): Current accumulated dataset size in bytes.
        class_stats (defaultdict): Counter for objects per class.
        all_classes (set): Set of all classes detected in the dataset.
    """
    
    def __init__(self):
        self.out_dir = Path(OUT_DIR)
        self.target_bytes = TARGET_MB * 1024 * 1024
        self.current_bytes = 0
        self.class_stats = defaultdict(int)
        self.all_classes = set()

        (self.out_dir / "images").mkdir(parents=True, exist_ok=True)
        (self.out_dir / "annotations").mkdir(parents=True, exist_ok=True)

# Subset Creation
    def create_subset(self, zip_path: Path):

        """
        Creates a balanced subset from the GERALD dataset ZIP file.
        
        Args:
            zip_path (Path): Path to the GERALD.zip archive.
            
        Raises:
            RuntimeError: If info.json is missing or class mapping fails.
            AssertionError: If dataset validation fails.
        """
        mapping = {}

        with zipfile.ZipFile(zip_path, 'r') as zf:

            # Locate info.json inside GERALD/ directory
            info_candidates = [
                f for f in zf.namelist()
                if f.lower().endswith("info.json")
            ]

            if not info_candidates:
                raise RuntimeError("info.json not found anywhere inside GERALD.zip!")

            info_path = info_candidates[0]
            print("Found info.json at:", info_path)

            raw = zf.read(info_path)

            # Encoding fallback
            for enc in ["utf-8", "utf-16", "latin1", "cp1252"]:
                try:
                    info_json = json.loads(raw.decode(enc))
                    print("Loaded info.json successfully with:", enc)
                    break
                except Exception:
                    continue
            else:
                raise RuntimeError("Could not decode info.json in any common encoding")

            (self.out_dir / "info.json").write_bytes(raw)

            # Collect Annotation XML files
            xml_files = [f for f in zf.namelist() if f.lower().endswith(".xml")]

            print("Scanning XMLs to collect class names!!!")
            for xml_file in tqdm(xml_files, desc="Scanning XMLs"):
                try:
                    with zf.open(xml_file) as f:
                        root = ET.parse(f).getroot()
                        for obj in root.findall("object"):
                            self.all_classes.add(obj.find("name").text.strip())
                except:
                    continue

            # Build alias mapping
            for target_class, aliases in CATEGORY_ALIASES.items():
                for cls in self.all_classes:
                    if any(a.lower() in cls.lower() for a in aliases):
                        mapping[cls] = target_class

            if not mapping:
                raise RuntimeError("No classes matched alias rules!")

            # Build weather/light mapping
            print("Indexing weather/light combinations!!!")
            wls = defaultdict(list)

            for xml_file in tqdm(xml_files, desc="Identifying weather/light"):
                try:
                    with zf.open(xml_file) as f:
                        root = ET.parse(f).getroot()

                    img_name = root.find("filename").text.strip()

                    if img_name not in info_json:
                        continue

                    weather = info_json[img_name]["weather"]
                    light = info_json[img_name]["light"]

                    wls[(weather, light)].append(xml_file)

                except:
                    continue

            # Plot the distribution
            self.plot_distribution(wls)


            # Auto tuning logic ->Auto tune max per-pair 
            #---- logic ------------------------------
            # 1. Compute the size of each weather/light group.
            # 2. Sort the group sizes to find the median group size.
            # 3. If the number of groups is odd, take the middle value as the median.
            #    If even, take the average of the two middle values.
            # 4. Set MAX_PER_PAIR as the median group size but restrict it to a minimum of 50
            #    and a maximum of 300 to avoid selecting too few or too many samples.
            # This ensures a balanced and reasonable selection of images per weather/light combination.

            # ----------------------------------------

            print("\nAuto-tuning max samples per weather/light pair!!!!!")

            group_sizes = [len(v) for v in wls.values()]
            group_sizes_sorted = sorted(group_sizes)

            mid = len(group_sizes_sorted) // 2
            if len(group_sizes_sorted) % 2 == 1:
                median_size = group_sizes_sorted[mid]
            else:
                median_size = (group_sizes_sorted[mid - 1] + group_sizes_sorted[mid]) // 2

            MAX_PER_PAIR = max(50, min(median_size, 300))

            print(f"Group sizes: {group_sizes}")
            print(f"Median = {median_size}")
            print(f"Auto-selected MAX_PER_PAIR = {MAX_PER_PAIR}")


            # Randomly select XML files for each weather/light group.
            # For each group:
                # 1. Shuffle the XML list to ensure random selection.
                # 2. Take up to MAX_PER_PAIR files (or all if the group is smaller).
                # 3. Add them to the final selection.
            
            selected_xmls = []

            for pair, xml_list in wls.items():
                xml_list = xml_list.copy()
                random.shuffle(xml_list)           # RANDOM PICK INSIDE EACH GROUP
                take_n = min(len(xml_list), MAX_PER_PAIR)
                selected_xmls.extend(xml_list[:take_n])

            print(f"Total images selected = {len(selected_xmls)}")

            # Copy balanced dataset and embed metadata
            print("\nCopying balanced dataset!!!!")

            with tqdm(total=self.target_bytes, unit="B", unit_scale=True, desc="Copying dataset") as pbar:
                for xml_file in selected_xmls:
                    if self.current_bytes >= self.target_bytes:
                        break

                    try:
                        with zf.open(xml_file) as f:
                            tree = ET.parse(f)
                            root = tree.getroot()

                        filtered = []
                        for obj in root.findall("object"):
                            name = obj.find("name").text.strip()
                            if name in mapping:
                                obj.find("name").text = mapping[name]
                                filtered.append(obj)

                        if not filtered:
                            continue

                        # rewrite XML with filtered objects only
                        for obj in root.findall("object"):
                            root.remove(obj)
                        for obj in filtered:
                            root.append(obj)

                        img_name = root.find("filename").text.strip()

                        # metadata injection happens here - > (weather + light only)
                        meta = ET.SubElement(root, "metadata")
                        ET.SubElement(meta, "weather").text = info_json[img_name]["weather"]
                        ET.SubElement(meta, "light").text = info_json[img_name]["light"]

                        # Find and copy image
                        img_file_zip = next(
                            (f for f in zf.namelist() if f.endswith(img_name)),
                            None
                        )

                        if not img_file_zip:
                            continue

                        img_bytes = zf.read(img_file_zip)

                        if self.current_bytes + len(img_bytes) > self.target_bytes:
                            break

                        (self.out_dir / "images" / img_name).write_bytes(img_bytes)
                        tree.write(self.out_dir / "annotations" / Path(xml_file).name)

                        for obj in filtered:
                            self.class_stats[obj.find("name").text.strip()] += 1

                        self.current_bytes += len(img_bytes)
                        pbar.update(len(img_bytes))

                    except Exception as e:
                        print(f"Skipping {xml_file} due to {e}")

        # Save mapping and stats
        reverse_map = defaultdict(list)
        for original, target in mapping.items():
            reverse_map[target].append(original)

        with open(self.out_dir / "class_mapping.json", "w") as f:
            json.dump({
                "target_classes": CLASSES_KEEP,
                "gerald_to_target_mapping": mapping,
                "target_to_gerald_mapping": dict(reverse_map),
                "class_statistics": dict(self.class_stats),
            }, f, indent=2)

        print("\nSubset creation completed.")
        print(f"Final dataset size: {self.current_bytes / (1024*1024):.2f} MB")

        self.validate_subset()


    def plot_distribution(self, wls):
        """
        Plots a bar chart showing the number of samples per weather/light combination
        and saves the figure as 'subset_distribution.png' in the output directory.

        Args:
            wls (dict): Dictionary mapping (weather, light) pairs to lists of XML files.
        """
        
        pairs = [f"{w}-{l}" for (w, l) in wls.keys()]
        counts = [len(wls[p]) for p in wls.keys()]

        plt.figure(figsize=(14, 6))
        plt.bar(pairs, counts, color='steelblue')
        plt.xticks(rotation=45, ha='right')
        plt.title("Weather Ã— Light Distribution")
        plt.ylabel("Sample Count")
        plt.tight_layout()

        plt.savefig(self.out_dir / "subset_distribution.png")
        plt.show()

    # Subset Validation
    
    def validate_subset(self):
        
        """
        Validates the generated subset for consistency:
        - All images exist
        - All object classes are correct
        - Class statistics match JSON metadata
        Raises:
            AssertionError: If any validation check fails.
        """
        
        print("\n=== VALIDATING DATASET ===")

        ann_dir = self.out_dir / "annotations"
        img_dir = self.out_dir / "images"
        mapping_path = self.out_dir / "class_mapping.json"

        with open(mapping_path) as f:
            m = json.load(f)

        target_classes = set(m["target_classes"])
        class_stats_json = m["class_statistics"]

        """
        i. Collects all XML annotation files in the annotations directory.
        ii. Initializes counters to track object counts per class, a list for missing images, 
           and a flag to track overall validation success.
        iii. Iterates through each XML file to:
           - Ensure the corresponding image exists; if not, mark it as missing.
           - Count all objects per class and check that each object belongs to a valid target class.
        iv. Compares the counted objects with the stored class statistics from JSON to ensure 
           consistency between the XML annotations and metadata.
        v. Flags any missing images, invalid classes, or count mismatches as validation failures.
        """

        xml_files = list(ann_dir.glob("*.xml"))
        counter = Counter()
        missing_images = []
        all_good = True

        for xml_file in xml_files:
            root = ET.parse(xml_file).getroot()
            img_name = root.find("filename").text

            if not (img_dir / img_name).exists():
                missing_images.append(img_name)
                all_good = False

            for obj in root.findall("object"):
                cls = obj.find("name").text.strip()
                counter[cls] += 1

                if cls not in target_classes:
                    print(f"Invalid class '{cls}' in {xml_file.name}")
                    all_good = False

        if counter != Counter(class_stats_json):
            print("Class statistics mismatch!")
            print("Counts from XML:", dict(counter))
            print("Counts in JSON:", class_stats_json)
            all_good = False

        print("\n=== DATASET SUMMARY ===")
        print(f"Images saved: {len(list(img_dir.glob('*')))}")
        print(f"Annotations: {len(xml_files)}")
        print(f"Total objects: {sum(counter.values())}")

        print("\nObjects per class:")
        for cls, count in counter.items():
            print(f"  - {cls}: {count}")

        print(f"\nSubset size: {self.current_bytes / (1024*1024):.2f} MB")

        if missing_images:
            print("Missing images:", missing_images)
            all_good = False

        if all_good:
            print("\n All checks passed! Dataset is fully consistent!\n")
        else:
            raise AssertionError("Dataset validation failed.")

## Subset creation

In [None]:
if __name__ == "__main__":
    zip_path = Path("../dataset/GERALD.zip")
    if not zip_path.exists():
        print("GERALD.zip not found!")
    else:
        GeraldSignalSubsetCreator().create_subset(zip_path)


GERALD.zip not found!


In [None]:
BASE_DIR = Path("../dataset/GERALD_subset")
IMG_DIR = BASE_DIR / "images"
ANN_DIR = BASE_DIR / "annotations"

## Exploratory Data Analysis

#### EDA Helper Functions

In [None]:
def parse_annotations(ann_dir):
    """
    Parses all XML files in the annotation directory and creates a DataFrame.
    """
    xml_list = []
    xml_files = list(ann_dir.glob("*.xml"))
    
    print(f"Parsing {len(xml_files)} annotation files...")

    for xml_file in xml_files:
        tree = ET.parse(xml_file)
        root = tree.getroot()
        
        filename = root.find('filename').text
        img_width = int(root.find('size').find('width').text)
        img_height = int(root.find('size').find('height').text)
        
        # Extract Metadata (Weather/Light) if available
        weather = "Unknown"
        light = "Unknown"
        meta = root.find('metadata')
        if meta is not None:
            weather = meta.find('weather').text if meta.find('weather') is not None else "Unknown"
            light = meta.find('light').text if meta.find('light') is not None else "Unknown"

        # Extract Objects
        for member in root.findall('object'):
            class_name = member.find('name').text
            bndbox = member.find('bndbox')
            
            # FIXED: Parse as float first to handle decimal coordinates
            xmin = int(float(bndbox.find('xmin').text))
            ymin = int(float(bndbox.find('ymin').text))
            xmax = int(float(bndbox.find('xmax').text))
            ymax = int(float(bndbox.find('ymax').text))
            
            # Calculate derived stats
            bbox_width = xmax - xmin
            bbox_height = ymax - ymin
            area = bbox_width * bbox_height
            aspect_ratio = bbox_width / bbox_height if bbox_height > 0 else 0
            
            # Normalized center coordinates (0-1) for heatmap
            center_x = (xmin + bbox_width/2) / img_width
            center_y = (ymin + bbox_height/2) / img_height

            xml_list.append({
                'filename': filename,
                'width': img_width,
                'height': img_height,
                'class': class_name,
                'xmin': xmin,
                'ymin': ymin,
                'xmax': xmax,
                'ymax': ymax,
                'bbox_width': bbox_width,
                'bbox_height': bbox_height,
                'bbox_area': area,
                'aspect_ratio': aspect_ratio,
                'center_x': center_x,
                'center_y': center_y,
                'weather': weather,
                'light': light
            })
            
    return pd.DataFrame(xml_list)

In [None]:
def plot_class_distribution(df):
    """Plots the count of objects per class."""
    plt.figure(figsize=(10, 6))
    sns.countplot(data=df, x='class', palette='viridis', order=df['class'].value_counts().index)
    plt.title('Object Class Distribution')
    plt.xlabel('Class Name')
    plt.ylabel('Count')
    plt.xticks(rotation=45)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

In [None]:
def plot_metadata_distribution(df):
    """Plots the distribution of Weather and Light conditions."""
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # Weather
    sns.countplot(data=df, x='weather', ax=axes[0], palette='coolwarm', order=df['weather'].value_counts().index)
    axes[0].set_title('Weather Condition Distribution')
    axes[0].tick_params(axis='x', rotation=45)
    
    # Light
    sns.countplot(data=df, x='light', ax=axes[1], palette='magma', order=df['light'].value_counts().index)
    axes[1].set_title('Lighting Condition Distribution')
    axes[1].tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    plt.show()

In [None]:
def plot_bbox_statistics(df):
    """Plots Box sizes and Aspect Ratios."""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # 1. BBox Area Distribution
    sns.histplot(df['bbox_area'], bins=30, kde=True, ax=axes[0], color='skyblue')
    axes[0].set_title('Bounding Box Area Distribution')
    axes[0].set_xlabel('Area (pixels)')
    
    # 2. Aspect Ratio Distribution
    sns.histplot(df['aspect_ratio'], bins=30, kde=True, ax=axes[1], color='orange')
    axes[1].set_title('Aspect Ratio (Width/Height)')
    axes[1].set_xlabel('Ratio')
    axes[1].axvline(1.0, color='red', linestyle='--', label='Square')
    axes[1].legend()

    # 3. Width vs Height Scatter
    sns.scatterplot(data=df, x='bbox_width', y='bbox_height', hue='class', alpha=0.6, ax=axes[2])
    axes[2].set_title('BBox Width vs Height')
    axes[2].plot([0, max(df.bbox_width)], [0, max(df.bbox_width)], 'r--', alpha=0.5) # Diagonal reference
    
    plt.tight_layout()
    plt.show()

In [None]:
def plot_object_heatmap(df):
    """Plots a 2D histogram of where objects appear in images."""
    plt.figure(figsize=(8, 6))
    plt.hist2d(df['center_x'], df['center_y'], bins=50, cmap='inferno', range=[[0, 1], [0, 1]])
    plt.colorbar(label='Count')
    plt.gca().invert_yaxis() # Image coordinates: (0,0) is top-left
    plt.title('Object Location Heatmap (Normalized Coordinates)')
    plt.xlabel('X Position (Normalized 0-1)')
    plt.ylabel('Y Position (Normalized 0-1)')
    plt.show()

In [None]:
def visualize_samples(df, img_dir, num_samples=3):
    """Draws bounding boxes on random sample images."""
    unique_files = df['filename'].unique()
    samples = random.sample(list(unique_files), min(len(unique_files), num_samples))
    
    print(f"\nVisualizing {len(samples)} random samples...")
    
    for filename in samples:
        img_path = img_dir / filename
        if not img_path.exists():
            print(f"Warning: Image {filename} not found.")
            continue
            
        # Get all boxes for this image
        img_data = df[df['filename'] == filename]
        
        # Open Image
        im = Image.open(img_path)
        fig, ax = plt.subplots(1, 1, figsize=(10, 8))
        ax.imshow(im)
        
        # Draw Boxes
        for _, row in img_data.iterrows():
            width = row['xmax'] - row['xmin']
            height = row['ymax'] - row['ymin']
            
            # Create a Rectangle patch
            rect = patches.Rectangle(
                (row['xmin'], row['ymin']), 
                width, height, 
                linewidth=2, 
                edgecolor='lime', 
                facecolor='none'
            )
            ax.add_patch(rect)
            
            # Add Label
            label = f"{row['class']}"
            plt.text(
                row['xmin'], 
                row['ymin'] - 5, 
                label, 
                color='white', 
                fontsize=10, 
                weight='bold', 
                bbox=dict(facecolor='lime', alpha=0.5, pad=2, edgecolor='none')
            )
            
        plt.title(f"Sample: {filename}\nWeather: {img_data.iloc[0]['weather']} | Light: {img_data.iloc[0]['light']}")
        plt.axis('off')
        plt.show()

In [None]:
if __name__ == "__main__":
    if not ANN_DIR.exists():
        print(f"Error: Annotation directory not found at {ANN_DIR}")
        print("Please run the subset creator script first.")
    else:
        # 1. Load Data
        df = parse_annotations(ANN_DIR)
        
        if df.empty:
            print("No annotations found!")
        else:
            print(f"Loaded {len(df)} objects from {df['filename'].nunique()} images.")
            print("-" * 30)
            
            # 2. Basic Stats
            print("Class Counts:\n", df['class'].value_counts())
            print("\nWeather Counts:\n", df['weather'].value_counts())
            
            # 3. Visualizations
            print("\nGenerating Class Distribution Plot...")
            plot_class_distribution(df)
            
            print("Generating Metadata Distribution Plot...")
            plot_metadata_distribution(df)
            
            print("Generating Bounding Box Statistics...")
            plot_bbox_statistics(df)
            
            print("Generating Object Location Heatmap...")
            plot_object_heatmap(df)
            
            # 4. Visual Samples
            visualize_samples(df, IMG_DIR, num_samples=3)

Error: Annotation directory not found at ../dataset/GERALD_subset/annotations
Please run the subset creator script first.


## Configuration

In [None]:
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 10

## Preprocessing function

In [None]:
class GeraldDataGenerator(keras_utils.Sequence):
    """
    Generates data for Keras on-the-fly to save memory.
    Loads, crops, and processes images in batches.
    """
    def __init__(self, samples, batch_size, img_size, lb, num_classes, shuffle=True):
        self.samples = samples  # List of (img_path, label, bbox)
        self.batch_size = batch_size
        self.img_size = img_size
        self.lb = lb
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.indexes = np.arange(len(self.samples))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        # Denotes the number of batches per epoch
        return math.ceil(len(self.samples) / self.batch_size)

    def __getitem__(self, index):
        # Generate one batch of data
        batch_indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        batch_samples = [self.samples[k] for k in batch_indexes]
        
        X = []
        y = []
        
        for img_path, label, bbox in batch_samples:
            try:
                # Read Image
                img = cv2.imread(str(img_path))
                if img is None: continue
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                h, w, _ = img.shape
                
                # Crop using pre-calculated bbox
                xmin, ymin, xmax, ymax = bbox
                
                # Safety clip
                xmin, ymin = max(0, xmin), max(0, ymin)
                xmax, ymax = min(w, xmax), min(h, ymax)
                
                if xmax - xmin <= 0 or ymax - ymin <= 0:
                    continue

                crop = img[ymin:ymax, xmin:xmax]
                crop = cv2.resize(crop, self.img_size)
                
                # Normalize
                crop = crop.astype('float32') / 255.0
                
                X.append(crop)
                y.append(label)
                
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                continue
        
        if not X: # Handle empty batch edge case
            return np.zeros((0, *self.img_size, 3)), np.zeros((0, self.num_classes))

        X = np.array(X)
        # Transform labels using the fitted binarizer
        y_encoded = self.lb.transform(y)
        
        # Ensure correct shape for binary classification if needed
        if self.num_classes == 2 and y_encoded.shape[1] == 1:
             y_encoded = to_categorical(y_encoded, num_classes=2)
        elif self.num_classes == 2:
             # If lb returns 2 columns already (rare for 2 classes), handle it
             pass

        return X, np.array(y_encoded)

    def on_epoch_end(self):
        # Updates indexes after each epoch
        if self.shuffle:
            np.random.shuffle(self.indexes)

In [None]:
def scan_dataset_metadata(ann_dir, img_dir):
    """
    Scans XMLs and returns a list of metadata tuple: (img_path, label, bbox).
    Does NOT load images into memory.
    """
    samples = []
    labels_list = []
    
    xml_files = list(ann_dir.glob("*.xml"))
    print(f"Scanning {len(xml_files)} annotation files...")
    
    for xml_file in tqdm(xml_files):
        try:
            tree = ET.parse(xml_file)
            root = tree.getroot()
            filename = root.find('filename').text
            img_path = img_dir / filename
            
            if not img_path.exists():
                continue
                
            for obj in root.findall('object'):
                label = obj.find('name').text
                bndbox = obj.find('bndbox')
                
                xmin = int(float(bndbox.find('xmin').text))
                ymin = int(float(bndbox.find('ymin').text))
                xmax = int(float(bndbox.find('xmax').text))
                ymax = int(float(bndbox.find('ymax').text))
                
                samples.append((img_path, label, (xmin, ymin, xmax, ymax)))
                labels_list.append(label)
                
        except Exception as e:
            print(f"Error reading {xml_file}: {e}")
            continue

    return samples, labels_list

## Model Training

In [None]:
def build_model(model_name, num_classes):
    input_tensor = Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3))
    
    if model_name == "ResNet50":
        base_model = ResNet50(weights='imagenet', include_top=False, input_tensor=input_tensor)
    elif model_name == "VGG16":
        base_model = VGG16(weights='imagenet', include_top=False, input_tensor=input_tensor)
    elif model_name == "MobileNetV2":
        base_model = MobileNetV2(weights='imagenet', include_top=False, input_tensor=input_tensor)
    elif model_name == "EfficientNetB0":
        base_model = EfficientNetB0(weights='imagenet', include_top=False, input_tensor=input_tensor)
    else:
        raise ValueError(f"Unknown model: {model_name}")
    
    base_model.trainable = False
    
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dropout(0.5)(x)
    output = Dense(num_classes, activation='softmax')(x)
    
    model = Model(inputs=base_model.input, outputs=output)
    
    # Use Built-in TF 2.13+ F1Score metric
    f1_metric = tf.keras.metrics.F1Score(average='weighted', name='f1_score')
    
    model.compile(optimizer=Adam(learning_rate=0.0001),
                  loss='categorical_crossentropy',
                  metrics=['accuracy', f1_metric])
    
    return model

In [None]:
def plot_history(histories):
    """Plots Accuracy and F1 Score curves."""
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # Plot Accuracy
    for name, history in histories.items():
        axes[0].plot(history.history['val_accuracy'], label=f'{name}')
    axes[0].set_title('Validation Accuracy')
    axes[0].set_xlabel('Epochs')
    axes[0].set_ylabel('Accuracy')
    axes[0].legend()
    axes[0].grid(True)

    # Plot F1 Score
    for name, history in histories.items():
        # Check for standard 'f1_score' or 'val_f1_score' keys
        metric_key = 'val_f1_score'
        
        # Sometimes Keras adds indices if multiple metrics exist, handle robustly
        keys = history.history.keys()
        if metric_key not in keys:
            # Try to find a key containing 'f1'
            possible_keys = [k for k in keys if 'f1' in k and 'val' in k]
            if possible_keys:
                metric_key = possible_keys[0]
            else:
                 print(f"Warning: F1 metric key not found for {name}. Available: {keys}")
                 continue

        axes[1].plot(history.history[metric_key], label=f'{name}')

    axes[1].set_title('Validation F1 Score')
    axes[1].set_xlabel('Epochs')
    axes[1].set_ylabel('F1 Score')
    axes[1].legend()
    axes[1].grid(True)
    
    plt.tight_layout()
    plt.show()

In [None]:
if __name__ == "__main__":
    
    # 1. Scan Metadata (No Image Loading yet)
    samples, labels_raw = scan_dataset_metadata(ANN_DIR, IMG_DIR)
    
    if len(samples) == 0:
        print("No data found! Please run the subset creator first.")
        exit()

    print(f"\nTotal Samples Found: {len(samples)}")
    print(f"Class distribution: {np.unique(labels_raw, return_counts=True)}")
    
    # 2. Prepare Labels
    lb = LabelBinarizer()
    lb.fit(labels_raw) # Fit once on all data
    num_classes = len(lb.classes_)
    print(f"Classes: {lb.classes_}")
    
    # 3. Split Metadata
    train_samples, test_samples = train_test_split(
        samples, test_size=0.2, random_state=42, stratify=labels_raw
    )
    
    # 4. Create Generators
    train_gen = GeraldDataGenerator(train_samples, BATCH_SIZE, IMG_SIZE, lb, num_classes, shuffle=True)
    test_gen = GeraldDataGenerator(test_samples, BATCH_SIZE, IMG_SIZE, lb, num_classes, shuffle=False)
    
    # 5. Train Models
    models_to_train = ["ResNet50", "VGG16", "MobileNetV2", "EfficientNetB0"]
    histories = {}
    results = []
    
    for model_name in models_to_train:
        print(f"\n{'='*20} Training {model_name} {'='*20}")
        
        model = build_model(model_name, num_classes)
        
        history = model.fit(
            train_gen,
            epochs=EPOCHS,
            validation_data=test_gen,
            verbose=1
        )
        histories[model_name] = history
        
        # Evaluate using generator (Need prediction loop for sklearn metrics)
        print(f"Evaluating {model_name}...")
        
        # Re-instantiate test gen with shuffle=False to ensure order matches for evaluation
        eval_gen = GeraldDataGenerator(test_samples, BATCH_SIZE, IMG_SIZE, lb, num_classes, shuffle=False)
        
        y_pred_prob = model.predict(eval_gen)
        y_pred = np.argmax(y_pred_prob, axis=1)
        
        # Extract ground truth from test_samples manually for comparison
        y_test_labels = [s[1] for s in test_samples]
        y_test_encoded = lb.transform(y_test_labels)
        if num_classes == 2:
             # Handle binary shape differences from sklearn vs keras
             if y_test_encoded.shape[1] == 1:
                 # If binary, argmax won't work on (N,1). We need 0 or 1 directly.
                 # But model.predict outputs (N, 2) usually if we force to_categorical in generator.
                 # Let's align:
                 y_true = y_test_encoded.ravel()
             else:
                 y_true = np.argmax(y_test_encoded, axis=1)
        else:
             y_true = np.argmax(y_test_encoded, axis=1)
             
        # Robustness check for sizes
        min_len = min(len(y_true), len(y_pred))
        y_true = y_true[:min_len]
        y_pred = y_pred[:min_len]

        acc = accuracy_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred, average='weighted')
        
        results.append({
            "Model": model_name,
            "Accuracy": acc,
            "F1-Score (Weighted)": f1
        })
        
        # Confusion Matrix
        cm = confusion_matrix(y_true, y_pred)
        plt.figure(figsize=(6, 5))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=lb.classes_, yticklabels=lb.classes_)
        plt.title(f'Confusion Matrix: {model_name}')
        plt.show()

    # 6. Final Comparison
    print("\n" + "="*40)
    print("FINAL RESULTS")
    print("="*40)
    results_df = pd.DataFrame(results)
    print(results_df)
    
    plot_history(histories)
    
    best_model_name = results_df.sort_values(by="Accuracy", ascending=False).iloc[0]["Model"]
    print(f"\nBest performing model: {best_model_name}")

Scanning 0 annotation files...


0it [00:00, ?it/s]

No data found! Please run the subset creator first.

Total Samples Found: 0
Class distribution: (array([], dtype=float64), array([], dtype=int64))





ValueError: y has 0 samples: []

: 