## Exploratory Data Analysis

In [None]:
# Make the project root (one level above the notebook's working directory)
# importable, so that the `from src...` imports in the next cell resolve.
import sys
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

# Output folders for processed tables and saved figures (kept if they already exist)
for output_dir in ("../data/processed", '../visualizations'):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
from src.s3_loader import get_image_s3
from src.save_fig import save_fig

### Data Overview

In [None]:
# Load the three raw CSV tables into DataFrames:
#   df_regions     - image-to-region metadata
#   df_annotations - point-level benthic annotations
#   df_features    - one row per image: image_id plus binary benthic-attribute columns
df_regions = pd.read_csv('../data/raw/metadata_regions.csv')
df_annotations = pd.read_csv('../data/raw/metadata_annotations.csv')
df_features = pd.read_csv('../data/raw/coral_multilabel_dataset.csv')

In [None]:
# Sanity-check df_features: repeated image_ids, missing values, and row count.
n_rows = df_features.shape[0]
null_count = df_features.isna().values.sum()
duplicate_count = df_features.duplicated(subset='image_id').sum()

print(f"Duplicate image_ids: {duplicate_count}")
print(f"Total null values: {null_count}")
print(f'Total rows: {n_rows}')

# Preview the first rows (displayed as the cell output)
df_features.head()

There are no duplicate images or null values in this dataset. We can see there are 4821 rows (images) and 172 columns (image_id + benthic attributes) in the dataset. The benthic attributes are binary features. We will want to drop image_id for EDA purposes, leaving us with 171 benthic attributes to filter through.

### Feature Selection

#### Label-Aware Filtering
We don't want to overload our CNNs or Vision Transformer with too many benthic attributes to learn. To keep the number of training labels under ~30, we select features that appear at least 100 times in the image collection.

In [None]:
# Feature columns: every column except the 'image_id' key
feature_cols = df_features.columns.drop(['image_id'])

# Positive examples per benthic attribute (the columns are binary, so the
# column sum is the number of images containing that attribute)
label_counts = df_features[feature_cols].sum()

# Keep only attributes with at least `min_positive_count` positives, most frequent first
min_positive_count = 100
filtered_labels = label_counts[label_counts >= min_positive_count].sort_values(ascending=False)

print(f'Attributes ≥{min_positive_count} positives: {len(filtered_labels)}')

# DataFrame of filtered labels + counts (displayed as the cell output)
filtered_labels.to_frame(name="positive_count")

We identified 25 benthic attributes with at least 100 positive examples. One of these is the “Other” category. Since “Other” represents everything outside the full set of 171 benthic attributes, it does not provide meaningful ecological information for our reduced feature set. Because of this, we will exclude it from our modeling subset.

In [None]:
# Remove the catch-all 'Other' attribute from the candidate set; no-op if it is absent
filtered_labels = filtered_labels.drop(index='Other', errors='ignore')

print(f'Attributes ≥100 positives (excluding "Other"): {len(filtered_labels)}')

In [None]:
# Compare the distribution of per-label positive counts before and after the
# label-aware filter, as two side-by-side histograms.
initial_counts = df_features[feature_cols].sum()
after_counts = label_counts[filtered_labels.index]

# (counts, bar colour, panel title) for each panel, left to right
panels = [
    (initial_counts, 'steelblue', "Label Count Distribution (Before Filtering)"),
    (after_counts, 'seagreen', "Label Count Distribution (After Filtering)"),
]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, (counts, color, title) in zip(axes, panels):
    ax.hist(counts.values, bins=15, color=color, edgecolor='black')
    ax.set_title(title)
    ax.set_xlabel("Positive Count")
    ax.set_ylabel("Number of Labels")

plt.tight_layout()
save_fig("label_count_distribution_before_after")
plt.show()

#### Region-Aware Filtering

Our goal is to evaluate the model on a geographic region it has never seen during training. For this to work, each region must contain enough examples of the benthic attributes we plan to model. Before deciding which features to keep, we examine the region metadata to understand how these attributes are distributed across regions.

In [None]:
# Attach each image's region_name. validate='many_to_one' makes pandas raise
# if an image_id appears more than once in df_regions, which would otherwise
# silently duplicate image rows and inflate every count computed below.
df_merged = df_features.merge(
    df_regions[['image_id', 'region_name']],
    on='image_id',
    how='left',
    validate='many_to_one',
)

# Unmatched images get NaN region_name and are dropped by groupby below; report them
missing_region = df_merged['region_name'].isna().sum()
if missing_region:
    print(f"Images without a region match: {missing_region}")

# Counts of filtered_labels per region
region_label_counts = (
    df_merged.groupby('region_name')[filtered_labels.index]
    .sum()
    .T   # transpose: labels = rows, regions = columns
)

# Total positives per region across all filtered labels
region_positive_counts = region_label_counts.sum(axis=0)

# Proportion of all filtered-label positives that come from each region
region_positive_props = region_positive_counts / region_positive_counts.sum()

print("Proportion of filtered-label positives by region:\n")
print(region_positive_props)

region_label_counts

Central Indo-Pacific accounts for roughly 55.7 percent of the filtered-label positives, Western Indo-Pacific for 43.6 percent, and the Tropical Atlantic for less than 1 percent. Because of this distribution, using Western Indo-Pacific as the test region is reasonable, since the model can train on the larger Central Indo-Pacific subset. The table above also shows that several benthic attributes have very low or no overlap between these regions, so additional feature filtering is necessary.

In [None]:
# Horizontal bar chart of each region's share of filtered-label positives,
# smallest share at the bottom.
plt.figure(figsize=(6, 4))
ax = region_positive_props.sort_values().plot.barh()

ax.set_title("Proportion of Filtered-Label Positives by Region")
ax.set_xlabel("Proportion")
ax.set_ylabel("Region")
ax.grid(axis='x', linestyle='--', alpha=0.4)
save_fig("region_positive_proportion_barplot")
plt.show()

We choose a minimum of 100 positive examples per region to ensure each benthic attribute has enough representation for CNNs and Vision Transformers to learn meaningful visual patterns and generalize across geographic regions. It’s acceptable to exclude the Tropical Atlantic from this per-region requirement because it represents less than one percent of the filtered-label positives, so it cannot provide enough examples to support meaningful model training or region-based evaluation on its own.

In [None]:
# Keep only labels that reach `min_per_region` positives in BOTH major regions.
# Other region columns (e.g. Tropical Atlantic) stay in the table but are not
# part of the threshold test.
min_per_region = 100
major_regions = ['Central Indo-Pacific', 'Western Indo-Pacific']

# A label passes when its weakest major-region count still meets the minimum
weakest_major_count = region_label_counts[major_regions].min(axis=1)
region_filtered = region_label_counts[weakest_major_count >= min_per_region]

# Heatmap of Positive Counts Across Regions
plt.figure(figsize=(10, 6))
sns.heatmap(region_filtered, annot=True, fmt='d', cmap='Blues')

plt.title("Positive Counts per Benthic Attribute Across Regions")
plt.xlabel("Region")
plt.ylabel("Benthic Attribute")
save_fig("region_filtered_label_heatmap")
plt.show()

#### Check for Label Co-Occurrence Structure

We check for co-occurrence to ensure that no two benthic attributes always appear together (or are always mutually exclusive), which would make them redundant and prevent the model from learning distinct visual patterns for each label. We flag any pair whose absolute correlation exceeds 0.90.

In [None]:
# Pearson correlation between the final region-filtered labels. The columns
# are binary, so this is the phi coefficient. A label that never varies has
# NaN correlations, which never exceed the threshold.
corr_threshold = 0.90
co_occurrence = df_features[region_filtered.index].corr()

plt.figure(figsize=(8, 6))
sns.heatmap(co_occurrence, cmap='coolwarm', vmin=-1, vmax=1)
plt.title("Correlation Between Selected Benthic Attributes")
save_fig("label_cooccurrence_heatmap")
plt.show()

# Unique off-diagonal pairs (upper triangle, by column position) with |r| above the threshold
selected_labels = list(co_occurrence.columns)
high_corr_mask = co_occurrence.abs() > corr_threshold
high_corr_pairs = [
    (a, b, co_occurrence.loc[a, b])
    for pos, a in enumerate(selected_labels)
    for b in selected_labels[pos + 1:]
    if high_corr_mask.loc[a, b]
]
high_corr_count = len(high_corr_pairs)

# Print result
if not high_corr_pairs:
    print("No high correlation pairs found.")
else:
    print(f"High correlation pairs (>{corr_threshold:.2f}):")
    for a, b, r in high_corr_pairs:
        print(f"  {a} – {b}: {r:.3f}")


#### Per-Region Negative Counts

We check per-region negative counts to ensure the model sees both the presence and absence of each benthic attribute, since a label cannot be learned reliably if it never appears as a negative example in a region.

In [None]:
# Count, per region, the images in which each selected label is absent (value 0).
# Transposed so labels are rows and regions are columns, which matches the
# positive-count heatmap and the axis labels below. Images with no
# region_name are dropped by groupby.
region_neg_counts = (
    df_merged[region_filtered.index]
    .eq(0)
    .groupby(df_merged['region_name'])
    .sum()
    .T
)

# Heatmap of Negative Counts Across Regions
plt.figure(figsize=(10, 6))
sns.heatmap(region_neg_counts, annot=True, fmt='d', cmap='Reds')

plt.title("Negative Counts per Benthic Attribute Across Regions")
plt.xlabel("Region")
plt.ylabel("Benthic Attribute")
save_fig("negative_counts_heatmap")
plt.show()

In our dataset, both major regions contain substantial negative counts for all selected attributes, which confirms that each label provides a meaningful learning signal and can be distinguished reliably across regions.

#### Prevalence Distribution

We examine the prevalence of each attribute to detect extremely rare or overly common labels that could cause imbalance during training or skew evaluation metrics. This is especially important for multi-label deep learning because extreme imbalance can cause unstable training, biased gradients, and poor recall on rare classes.

In [None]:
# Prevalence = share of all images in which each selected label is present.
# The columns are binary, so the column mean is exactly that share.
prevalence = (
    df_features.loc[:, region_filtered.index]
    .mean()
    .sort_values(ascending=False)
)

# Percent with two decimals for easier reading
prevalence_percent = prevalence.mul(100).round(2)

plt.figure(figsize=(10, 6))
ax = sns.barplot(x=prevalence_percent.values, y=prevalence_percent.index)

ax.set_xlabel("Prevalence (% of images)")
ax.set_ylabel("Benthic Attribute")
ax.set_title("Prevalence of Selected Benthic Attributes")
ax.grid(axis='x', linestyle='--', alpha=0.5)
save_fig("benthic_attribute_prevalence")
plt.show()

prevalence_percent

Most benthic attributes show moderate to high prevalence across the dataset, while even the rarest remaining classes appear in more than five percent of images, suggesting that the final feature set does not suffer from extreme imbalance.

### Image Sampling and Visual Diagnostics
#### Random Image Grid

To get a quick visual sense of the overall image quality in the MERMAID dataset, we sample nine images at random and display them here.

In [None]:
# Show nine randomly chosen images (seeded for reproducibility) in a 3x3 grid.
# An image that fails to load is reported, and its grid cell stays empty.
random.seed(33)
sample_ids = random.sample(df_features['image_id'].tolist(), 9)

plt.figure(figsize=(12, 12))

for slot, img_id in enumerate(sample_ids, start=1):
    try:
        rgb_image = get_image_s3(img_id).convert("RGB")
        ax = plt.subplot(3, 3, slot)
        ax.imshow(rgb_image)
        ax.set_title(img_id, fontsize=8)
        ax.axis("off")
    except Exception as err:
        print(f"Error loading {img_id}: {err}")

plt.tight_layout()
save_fig("random_image_grid")
plt.show()

These samples show clear underwater scenes with good color balance and visible benthic structure, which suggests that the dataset is suitable for image-based learning.

#### Region-Based Image Grid

To compare imagery across geographic realms, we sample several images from each major region and place them side by side for visual inspection.

In [None]:
# Grid of sample images with one row per major region and `images_per_region`
# columns. The subplot position is derived from (row, column), so a failed
# download leaves an empty cell instead of shifting later images into the
# wrong region's row.
random.seed(33)
regions = ["Central Indo-Pacific", "Western Indo-Pacific"]
images_per_region = 3

plt.figure(figsize=(12, 8))

for row, region in enumerate(regions):
    region_ids = df_regions.loc[df_regions["region_name"] == region, "image_id"].tolist()
    # random.sample raises ValueError when asked for more items than exist
    subset = random.sample(region_ids, min(images_per_region, len(region_ids)))

    for col, img_id in enumerate(subset):
        try:
            img = get_image_s3(img_id).convert("RGB")
        except Exception as e:
            print(f"Error loading {img_id}: {e}")
            continue
        plt.subplot(len(regions), images_per_region, row * images_per_region + col + 1)
        plt.imshow(img)
        plt.title(f"{region}\n{img_id}", fontsize=7)
        plt.axis("off")

plt.tight_layout()
save_fig("region_based_image_grid")
plt.show()

Although the images originate from different biogeographic regions, this small sample does not exhibit clear region-specific visual cues. This supports our choice of a region-based train-test split, since the model cannot rely on trivial region signatures and must instead learn benthic features that generalize across locations.

#### Image Resolution and Aspect Ratio Analysis

CNNs and ViTs assume consistent image sizes. MERMAID images are usually uniform, but we want to confirm this.

In [None]:
# Measure width, height, and aspect ratio (width / height) for a seeded random
# sample of images, then plot the three distributions side by side.
random.seed(33)
sample_size = 40
sample_ids = random.sample(df_features['image_id'].tolist(), sample_size)

widths, heights, aspect_ratios = [], [], []

for img_id in sample_ids:
    try:
        w, h = get_image_s3(img_id).convert("RGB").size
        widths.append(w)
        heights.append(h)
        aspect_ratios.append(w / h)
    except Exception as e:
        print(f"Error loading {img_id}: {e}")

# (values, colour, title, x-axis label) for each panel, left to right
histograms = [
    (widths, 'royalblue', "Image Width Distribution", "Width (pixels)"),
    (heights, 'seagreen', "Image Height Distribution", "Height (pixels)"),
    (aspect_ratios, 'salmon', "Aspect Ratio Distribution", "Width / Height"),
]

plt.figure(figsize=(18, 5))

for pos, (values, color, title, xlabel) in enumerate(histograms, start=1):
    plt.subplot(1, 3, pos)
    plt.hist(values, bins=20, color=color, edgecolor='black')
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel("Count")

# Only the last (aspect-ratio) panel gets a grid
plt.grid(alpha=0.4)

plt.tight_layout()
save_fig("resolution_and_aspect_ratio_distributions")
plt.show()

The widths and heights fall into several clear high-resolution groups rather than one unified size. This likely reflects different cameras or export settings used in the MERMAID surveys. Even though the resolutions vary, the aspect ratios cluster tightly around a few values near one, which means the images all have a similar overall square shape. Because of this consistency, resizing them to a single standard input size will only introduce mild and uniform distortion, so there is no need for multiple preprocessing pipelines. The dataset is fully compatible with a single global resize step before model training.

#### Pixel Intensity and RGB Distribution

Understanding the pixel intensity and RGB channel distributions matters because underwater images often have uneven lighting, strong blue–green color dominance, reduced red wavelengths, and variable contrast. All of these influence how the model perceives benthic features and determine the type of normalization and augmentation needed for stable training. Each image contains millions of pixels, so twenty images yield tens of millions of color values, which is enough for the pooled RGB and brightness histograms to stabilize. Because pixels within one image are strongly correlated, however, the sample reflects roughly twenty scenes rather than tens of millions of independent observations. These distributions therefore indicate the dataset’s general color characteristics rather than a precise estimate.

In [None]:
# Pool the RGB values of a seeded random sample of images and plot (1) the
# distribution of all channel values together and (2) one histogram per channel.
random.seed(33)
sample_size = 20
sample_ids = random.sample(df_features["image_id"].tolist(), sample_size)

# Collect one (H*W, 3) uint8 array per image and concatenate once at the end.
# Extending Python lists pixel by pixel would create tens of millions of boxed
# numpy scalars, which is very slow and memory-hungry.
all_pixels = []

for img_id in sample_ids:
    try:
        img = get_image_s3(img_id).convert("RGB")
        all_pixels.append(np.array(img).reshape(-1, 3))  # flatten H×W×3 into N×3
    except Exception as e:
        print(f"Error loading {img_id}: {e}")

# Empty (0, 3) fallback keeps the channel slicing valid if every load failed
pixels = np.concatenate(all_pixels) if all_pixels else np.empty((0, 3), dtype=np.uint8)
r_vals = pixels[:, 0]
g_vals = pixels[:, 1]
b_vals = pixels[:, 2]

# Pooled distribution of all channel values (R, G and B together, not grayscale)
plt.figure(figsize=(8, 5))
plt.hist(pixels.ravel(), bins=50, color='gray', edgecolor='black')
plt.title("Overall Pixel Intensity Distribution")
plt.xlabel("Pixel Value (0–255)")
plt.ylabel("Count")
save_fig("pixel_intensity_distribution")
plt.show()

# Plot RGB channel histograms
plt.figure(figsize=(12, 4))

for pos, (values, color, title) in enumerate(
    [(r_vals, 'red', "Red Channel"),
     (g_vals, 'green', "Green Channel"),
     (b_vals, 'blue', "Blue Channel")],
    start=1,
):
    plt.subplot(1, 3, pos)
    plt.hist(values, bins=50, color=color, edgecolor='black')
    plt.title(title)

plt.tight_layout()
save_fig("rgb_channel_histograms")
plt.show()

The overall pixel intensity distribution shows that most pixel values fall between roughly 40 and 160, which indicates moderate brightness without severe underexposure or overexposure. The tail above 200 suggests small regions of strong highlights, which is expected when sunlight reflects off sand or bleaching patches. The RGB channel histograms reveal a strong imbalance between color channels. The red channel is shifted the lowest, peaking around 50–100, while the green and blue channels peak higher and cover a wider range. This is characteristic of underwater imagery because red wavelengths are absorbed quickly with depth, leaving blue and green tones far more dominant in the raw pixel data. These distributions confirm the need for normalization and color-jitter augmentations during training so the model can generalize across varying lighting, turbidity, and depth conditions.

#### Region-Based RGB Comparison

To check whether the two major geographic regions differ in overall color characteristics, we compute the mean red, green, and blue values for a small sample of images from each region and compare their distributions.

In [None]:
# Per-image mean R, G and B for a seeded random sample of up to
# `sample_per_region` images from each major region, compared with boxplots.
random.seed(33)
regions = ["Central Indo-Pacific", "Western Indo-Pacific"]
sample_per_region = 50
channels = ["R", "G", "B"]

results = []

for region in regions:
    region_ids = df_regions.loc[df_regions["region_name"] == region, "image_id"].tolist()

    # random.sample cannot draw from an empty population
    if not region_ids:
        print(f"No images found for region: {region}")
        continue

    for img_id in random.sample(region_ids, min(sample_per_region, len(region_ids))):
        try:
            rgb = np.array(get_image_s3(img_id).convert("RGB"))
            # Mean of every pixel, separately for each of the three channels
            channel_means = rgb.reshape(-1, 3).mean(axis=0)
            results.append({"region": region, **dict(zip(channels, channel_means))})
        except Exception as e:
            print(f"Error loading {img_id}: {e}")

df_rgb = pd.DataFrame(results)

# One boxplot panel per channel
fig, axes = plt.subplots(1, 3, figsize=(13, 4))

for ax, ch in zip(axes, channels):
    sns.boxplot(data=df_rgb, x="region", y=ch, ax=ax)
    ax.set(
        title=f"{ch} Channel by Region",
        xlabel="Region",
        ylabel="Mean Pixel Value",
    )

plt.tight_layout()
save_fig("rgb_region_boxplots")
plt.show()

The RGB boxplots show that Central Indo-Pacific images are generally a bit brighter, especially in the green and blue channels, while Western Indo-Pacific images tend to be darker and more variable. These differences are not extreme but they do confirm that the two regions have slightly different lighting and color conditions. This supports using normalization and basic color jitter in our preprocessing so the model focuses on benthic features rather than on small region-specific lighting differences.

#### Annotation Consistency Check

We want to check annotation-level consistency to ensure that the image-level benthic attributes correspond to what is present in the point-level annotation dataset.

In [None]:
# For each final label, compare the number of images tagged positive at the
# image level with the number of point annotations carrying that attribute name.
final_labels = region_filtered.index

annotation_counts = df_annotations['benthic_attribute_name'].value_counts()
# Labels that never occur in the annotation table count as 0 points
annotation_final = annotation_counts.reindex(final_labels).fillna(0)
image_level_counts = label_counts.reindex(final_labels)

consistency_df = (
    pd.DataFrame({
        "image_level_positives": image_level_counts,
        "annotation_points": annotation_final.astype(int),
    })
    .sort_values(by="image_level_positives", ascending=False)
)

print("Image-level vs Annotation-level counts:\n")
print(consistency_df)

# Check 1: flag labels backed by very few annotation points
low_annotation_labels = consistency_df.loc[consistency_df['annotation_points'] < 20]

if low_annotation_labels.empty:
    print("\nNo annotation inconsistencies detected (annotation_points ≥ 20).")
else:
    print("\nPotential annotation inconsistencies (annotation_points < 20):")
    display(low_annotation_labels)


# Check 2: presumably every positive image contributes at least one point of
# that label, so the point count should not fall below the image-level count
consistency_df["annotation_ge_image"] = consistency_df["annotation_points"].ge(
    consistency_df["image_level_positives"]
)
violations = consistency_df.loc[~consistency_df["annotation_ge_image"]]

if violations.empty:
    print("\nAll labels are logically consistent: annotation_points >= image_level_positives.")
else:
    print("\nLogical inconsistency detected. These labels violate annotation >= image rule:")
    display(violations)

In [None]:
# Scatter of image-level positives (x) against annotation points (y), with one
# labelled marker per final benthic attribute.
x_vals = consistency_df["image_level_positives"]
y_vals = consistency_df["annotation_points"]

plt.figure(figsize=(8, 6))
plt.scatter(x_vals, y_vals, s=60)

for label, x, y in zip(consistency_df.index, x_vals, y_vals):
    plt.text(x, y, label, fontsize=8)

plt.xlabel("Image-Level Positives")
plt.ylabel("Annotation Points")
plt.title("Annotation Points vs Image-Level Positives")
plt.grid(alpha=0.4)
save_fig("image_vs_annotation_point_scatter")
plt.show()

This scatter plot compares the number of image-level positives to the number of point-level annotations for each selected attribute, and the upward-right trend shows that attributes with more positive images also have more supporting annotation points, indicating strong and consistent ecological labeling.

#### Annotated Image 

To visualize how the point-level labels align with the raw images, we overlay the MERMAID benthic annotations directly on each photo, showing both the grid structure and the specific benthic attribute at every sampled point.

In [None]:
# Overlay the point annotations on a few seeded-random annotated images: a
# yellow dot at each (col, row) position plus the attribute name beside it.
random.seed(34)
sample_ids = random.sample(
    df_annotations["image_id"].unique().tolist(),
    3
)

for img_id in sample_ids:
    # Skip images that fail to load instead of aborting the whole cell,
    # consistent with the other image cells in this notebook
    try:
        img = get_image_s3(img_id).convert("RGB")
    except Exception as e:
        print(f"Error loading {img_id}: {e}")
        continue
    arr = np.array(img)

    # All annotation rows for this image
    ann = df_annotations[df_annotations["image_id"] == img_id]

    plt.figure(figsize=(9, 9))
    plt.imshow(arr)

    # Yellow dots: one scatter call for all points of this image
    plt.scatter(ann["col"], ann["row"], c="yellow", s=40, edgecolors="black")

    # Attribute name slightly offset from each point
    for x, y, label in zip(ann["col"], ann["row"], ann["benthic_attribute_name"]):
        plt.text(
            x + 5, y - 5,
            label,
            color="white",
            fontsize=7,
            bbox=dict(facecolor="black", alpha=0.5, pad=1)
        )

    plt.title(f"Annotated Benthic Labels for Image {img_id}")
    plt.axis("off")
    save_fig(f"annotated_image_{img_id}")
    plt.show()

These examples confirm that the point-level annotations are dense, well-distributed, and accurately linked to their images, providing a strong ecological signal that supports our multi-label training labels.

### Save Processed Data

In [None]:
# Build and save the final datasets: the full label table plus a region-based
# train/test split with Western Indo-Pacific held out for testing.

# Final labels list
final_label_cols = ["image_id"] + list(final_labels)

# image_id + region_name + selected labels
final_df = df_merged[["image_id", "region_name"] + list(final_labels)]

# Define regions. Tropical Atlantic was only excluded from the per-region
# threshold; its images are placed in the training split.
test_region = "Western Indo-Pacific"
train_regions = ["Central Indo-Pacific", "Tropical Atlantic"]

# Create datasets
df_full = df_features[final_label_cols]
df_train = final_df[final_df["region_name"].isin(train_regions)].reset_index(drop=True)
df_test = final_df[final_df["region_name"] == test_region].reset_index(drop=True)

# Images with a missing or unlisted region fall into neither split; report them
n_unassigned = len(final_df) - len(df_train) - len(df_test)
if n_unassigned:
    print(f"Images excluded from both train and test: {n_unassigned}")

# Save datasets
df_full.to_csv("../data/processed/final_labels_full.csv", index=False)
df_train.to_csv("../data/processed/final_labels_train.csv", index=False)
df_test.to_csv("../data/processed/final_labels_test.csv", index=False)

print("Processed EDA outputs saved to ../data/processed/")