# CSIRO Image2Biomass - Exploratory Data Analysis

This notebook explores the training data to understand:
- Target variable distributions and correlations
- Metadata feature patterns
- Temporal and geographic variations
- Image characteristics
- Potential data quality issues

In [None]:
import sys
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)
%matplotlib inline

## Load Data

In [None]:
# Input locations: the extracted competition archive sits next to this notebook's folder.
data_dir = Path.cwd().parent / "csiro_biomass_extract"
image_dir = data_dir / "train"
train_csv, test_csv = data_dir / "train.csv", data_dir / "test.csv"

# Read the long-format training table and report its size and schema.
print(f"Loading data from {data_dir}")
df = pd.read_csv(train_csv)
print(f"Loaded {df.shape[0]} rows")
print(f"\nColumns: {df.columns.tolist()}")
df.head()

## Data Structure

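The data is in **long format**: each image has 5 rows, one per target variable (`Dry_Clover_g`, `Dry_Dead_g`, `Dry_Green_g`, `Dry_Total_g`, `GDM_g`). We pivot it to one row per image for analysis.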

In [None]:
# Pivot to wide format for analysis: one row per image, one column per target.
# The image identifier is the part of sample_id before the "__" separator.
df['image_id'] = df['sample_id'].str.split('__').str[0]

meta_cols = ['image_id', 'image_path', 'Sampling_Date', 'State', 'Species', 'Pre_GSHH_NDVI', 'Height_Ave_cm']

# Metadata must be identical across all target rows of an image; otherwise a single
# image would not map to a single wide row.
meta_per_image = df.groupby('image_id')[meta_cols[1:]].nunique(dropna=False)
assert (meta_per_image <= 1).all().all(), "metadata varies within an image"

# Pivot on image_id only, then attach the metadata. Using the metadata columns as the
# pivot index would silently drop every image with a NaN in any of them (e.g. missing
# NDVI or height), because NaN group keys are discarded. `pivot` (unlike
# `pivot_table`) also raises on duplicate (image, target) pairs instead of averaging.
targets_wide = df.pivot(index='image_id', columns='target_name', values='target').reset_index()
df_wide = (
    df[meta_cols]
    .drop_duplicates(subset='image_id')
    .merge(targets_wide, on='image_id', how='left', validate='one_to_one')
    .sort_values('image_id')
    .reset_index(drop=True)
)
assert len(df_wide) == df['image_id'].nunique(), "lost or duplicated images while pivoting"

print(f"Unique images: {len(df_wide)}")
print(f"\nTarget columns: {[col for col in df_wide.columns if 'Dry' in col or 'GDM' in col]}")
df_wide.head()

## Target Variable Analysis

In [None]:
# Descriptive statistics for the five biomass targets (one row per image).
target_cols = [
    'Dry_Clover_g',
    'Dry_Dead_g',
    'Dry_Green_g',
    'Dry_Total_g',
    'GDM_g',
]
for header_line in ("Target Variable Statistics:", "=" * 80):
    print(header_line)
df_wide.loc[:, target_cols].describe()

In [None]:
# One histogram per target with its median marked. The 2x3 grid has one spare
# panel (last position), which is hidden.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
panels = axes.flatten()

for ax, col in zip(panels, target_cols):
    values = df_wide[col]
    median = values.median()
    values.hist(bins=30, ax=ax, edgecolor='black')
    ax.set_title(f'{col} Distribution')
    ax.set_xlabel('Biomass (g)')
    ax.set_ylabel('Count')
    ax.axvline(median, color='red', linestyle='--', label=f'Median: {median:.1f}g')
    ax.legend()

panels[-1].axis('off')
plt.tight_layout()
plt.show()

## Target Correlations

In [None]:
# Pairwise Pearson correlations between the five targets, shown as a heatmap.
plt.figure(figsize=(10, 8))
corr = df_wide[target_cols].corr()
sns.heatmap(corr, annot=True, fmt='.2f', cmap='coolwarm', center=0,
            square=True, linewidths=1, cbar_kws={"shrink": 0.8})
plt.title('Target Variable Correlations')
plt.tight_layout()
plt.show()

# The heatmap only holds pairwise entries, so correlate Dry_Total against the actual
# sum of its components explicitly. Plain `+` propagates NaN (rather than treating a
# missing component as 0, as `.sum(axis=1)` would), and `Series.corr` then ignores
# the incomplete rows.
component_sum = df_wide['Dry_Clover_g'] + df_wide['Dry_Dead_g'] + df_wide['Dry_Green_g']
total_vs_sum = df_wide['Dry_Total_g'].corr(component_sum)

print("\nKey Observations:")
print("- Dry_Total should equal Dry_Clover + Dry_Dead + Dry_Green")
print(f"- Correlation Dry_Total vs (Clover+Dead+Green): {total_vs_sum:.3f}")
print(f"- Correlation Dry_Total vs Dry_Green: {corr.loc['Dry_Total_g', 'Dry_Green_g']:.3f}")

## Verify Biomass Composition Constraint

In [None]:
# Check if Dry_Total = Dry_Clover + Dry_Dead + Dry_Green
TOTAL_ERROR_TOLERANCE_G = 1.0  # |error| above this many grams counts as a violation

df_wide['computed_total'] = df_wide['Dry_Clover_g'] + df_wide['Dry_Dead_g'] + df_wide['Dry_Green_g']
df_wide['total_error'] = df_wide['Dry_Total_g'] - df_wide['computed_total']
abs_error = df_wide['total_error'].abs()

print("Biomass Composition Constraint Check:")
print(f"Mean error: {df_wide['total_error'].mean():.4f} g")
print(f"Std error: {df_wide['total_error'].std():.4f} g")
print(f"Max absolute error: {abs_error.max():.4f} g")
print(f"\nSamples with |error| > {TOTAL_ERROR_TOLERANCE_G:g}g: {(abs_error > TOTAL_ERROR_TOLERANCE_G).sum()}")
# A NaN error (missing total or component) compares False above, so count those separately.
print(f"Samples that cannot be checked (missing values): {df_wide['total_error'].isna().sum()}")

# Plot error distribution (NaN dropped, as in the other histogram cells)
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(df_wide['total_error'].dropna(), bins=50, edgecolor='black')
ax.set_xlabel('Error (Dry_Total - Sum of Components) [g]')
ax.set_ylabel('Count')
ax.set_title('Biomass Composition Constraint Errors')
ax.axvline(0, color='red', linestyle='--', label='Perfect constraint')
ax.legend()
plt.show()

## Metadata Analysis

In [None]:
# Sample counts per State and per Species: printed as tables, then plotted side by side.
state_counts = df_wide['State'].value_counts()
species_counts = df_wide['Species'].value_counts()

print("Geographic Distribution:")
print(state_counts)
print("\nSpecies Distribution:")
print(species_counts)

fig, (ax_state, ax_species) = plt.subplots(1, 2, figsize=(14, 5))

state_counts.plot(kind='bar', ax=ax_state, color='steelblue', edgecolor='black')
ax_state.set(title='Samples by State', ylabel='Count', xlabel='State')

species_counts.plot(kind='bar', ax=ax_species, color='forestgreen', edgecolor='black')
ax_species.set(title='Samples by Species', ylabel='Count', xlabel='Species')
# Species names are long; tilt the tick labels so they do not overlap.
ax_species.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

In [None]:
# Histograms of the two numeric metadata features, each with its median marked.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

histogram_specs = [
    # (column, colour, title, x label, median label template)
    ('Pre_GSHH_NDVI', 'green', 'NDVI Distribution', 'NDVI', 'Median: {:.2f}'),
    ('Height_Ave_cm', 'brown', 'Average Height Distribution', 'Height (cm)', 'Median: {:.1f} cm'),
]

for ax, (col, color, title, xlabel, median_label) in zip(axes, histogram_specs):
    ax.hist(df_wide[col].dropna(), bins=30, edgecolor='black', color=color, alpha=0.7)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Count')
    median = df_wide[col].median()
    ax.axvline(median, color='red', linestyle='--', label=median_label.format(median))
    ax.legend()

plt.tight_layout()
plt.show()

## Temporal Analysis

In [None]:
# Parse sampling dates once and derive the calendar fields used for grouping.
sampling_dates = pd.to_datetime(df_wide['Sampling_Date'])
df_wide['date'] = sampling_dates
df_wide['year'] = sampling_dates.dt.year
df_wide['month'] = sampling_dates.dt.month

samples_per_year = df_wide['year'].value_counts().sort_index()
samples_per_month = df_wide['month'].value_counts().sort_index()

print("Temporal Coverage:")
print(f"Date range: {df_wide['date'].min()} to {df_wide['date'].max()}")
print("\nSamples by year:")
print(samples_per_year)

# Plot temporal distribution: sample counts per year and per calendar month.
fig, (ax_year, ax_month) = plt.subplots(1, 2, figsize=(14, 5))

samples_per_year.plot(kind='bar', ax=ax_year, color='purple', edgecolor='black')
ax_year.set(title='Samples by Year', ylabel='Count', xlabel='Year')

samples_per_month.plot(kind='bar', ax=ax_month, color='orange', edgecolor='black')
ax_month.set(title='Samples by Month', ylabel='Count', xlabel='Month')

plt.tight_layout()
plt.show()

## Biomass vs Metadata Relationships

In [None]:
# Scatter plots of biomass components against NDVI and canopy height, one per panel.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

scatter_specs = {
    # (row, col): (x column, y column, colour or None for default, x label, y label, title)
    (0, 0): ('Pre_GSHH_NDVI', 'Dry_Total_g', None, 'NDVI', 'Dry Total (g)', 'Total Biomass vs NDVI'),
    (0, 1): ('Height_Ave_cm', 'Dry_Total_g', 'brown', 'Height (cm)', 'Dry Total (g)', 'Total Biomass vs Height'),
    (1, 0): ('Pre_GSHH_NDVI', 'Dry_Green_g', 'green', 'NDVI', 'Dry Green (g)', 'Green Biomass vs NDVI'),
    (1, 1): ('Pre_GSHH_NDVI', 'Dry_Dead_g', 'gray', 'NDVI', 'Dry Dead (g)', 'Dead Biomass vs NDVI'),
}

for (r, c), (x_col, y_col, color, xlabel, ylabel, title) in scatter_specs.items():
    ax = axes[r, c]
    # Only pass `color` when one is specified, so the first panel keeps the default colour.
    color_kw = {} if color is None else {'color': color}
    ax.scatter(df_wide[x_col], df_wide[y_col], alpha=0.5, s=20, **color_kw)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Summary Statistics

In [None]:
# One-screen summary of dataset size, target statistics, and metadata coverage.
# Relies on the `date` column created in the Temporal Analysis cell.
print("=" * 80)
print("CSIRO Image2Biomass Dataset Summary")
print("=" * 80)
print(f"Total unique images: {len(df_wide)}")
print(f"Total rows (long format): {len(df)}")
print("\nTarget Variables:")
for col in target_cols:
    print(f"  {col:20s}: mean={df_wide[col].mean():6.2f}g, std={df_wide[col].std():6.2f}g, max={df_wide[col].max():6.2f}g")
print("\nMetadata:")
# Drop NaN before joining: str.join raises TypeError on a float NaN. This also keeps
# the list consistent with nunique(), which excludes NaN.
state_names = df_wide['State'].dropna().astype(str).unique()
print(f"  States: {df_wide['State'].nunique()} ({', '.join(state_names)})")
print(f"  Species: {df_wide['Species'].nunique()}")
print(f"  Date range: {df_wide['date'].min().date()} to {df_wide['date'].max().date()}")
print(f"  NDVI range: {df_wide['Pre_GSHH_NDVI'].min():.2f} to {df_wide['Pre_GSHH_NDVI'].max():.2f}")
print(f"  Height range: {df_wide['Height_Ave_cm'].min():.1f} to {df_wide['Height_Ave_cm'].max():.1f} cm")
print("=" * 80)