# Feature Exploration: Understanding Biomass Prediction

**Goal**: Before training more neural networks, let's understand:
1. What visual features predict biomass?
2. Are there correlations between color and biomass values?
3. Can simple models (linear regression on color) predict biomass?
4. Are there data quality issues?

**Why this matters**:
- If simple color features *do* correlate with biomass, ColorJitter is likely hurting training, because it scrambles an important signal.
- Either way, we need to understand which features the CNN should learn.
- If they *don't* correlate at all, we may have data quality issues to fix first.

---
## Part 1: Setup & Data Loading

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, mean_absolute_error
from sklearn.preprocessing import StandardScaler
from scipy import stats

# Global plotting defaults shared by every figure in this notebook.
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print("✓ Imports complete")

In [None]:
# Read the enriched training table, parse dates, and build image paths
# relative to the notebook's working directory.
train_enriched = pd.read_csv('competition/train_enriched.csv')
train_enriched['Sampling_Date'] = pd.to_datetime(train_enriched['Sampling_Date'])
train_enriched['full_image_path'] = train_enriched['image_path'].map('competition/{}'.format)

# Regression targets and their competition weights; the two lists are index-aligned.
target_cols = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']
competition_weights = [0.1, 0.1, 0.1, 0.2, 0.5]

print(
    f"Total samples: {len(train_enriched)}",
    f"Shape: {train_enriched.shape}",
    f"\nTarget columns: {target_cols}",
    f"Competition weights: {competition_weights}",
    sep="\n",
)

---
## Part 2: Target Variable Distributions

Let's understand the distribution of biomass values we're trying to predict.

In [None]:
# Print mean/std/min/max/median for each target, plus how many rows are exactly zero.
banner = "=" * 80
print(banner)
print("TARGET VARIABLE STATISTICS")
print(banner)

n_rows = len(train_enriched)
for col in target_cols:
    values = train_enriched[col]
    print(f"\n{col}:")
    for label, stat in (
        ("Mean: ", values.mean),
        ("Std:  ", values.std),
        ("Min:  ", values.min),
        ("Max:  ", values.max),
        ("Median: ", values.median),
    ):
        print(f"  {label}{stat():.2f}g")

    # Exact zeros are common in sparse targets (e.g. Dry_Clover_g).
    zero_count = (values == 0).sum()
    print(f"  Zeros: {zero_count} ({100 * zero_count / n_rows:.1f}%)")

In [None]:
# Histogram of every target with its mean (red) and median (orange) marked,
# laid out on a 2x3 grid.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for ax, col in zip(axes, target_cols):
    values = train_enriched[col]

    ax.hist(values, bins=30, alpha=0.7, color='steelblue', edgecolor='black')
    ax.axvline(values.mean(), color='red', linestyle='--', linewidth=2, label='Mean')
    ax.axvline(values.median(), color='orange', linestyle='--', linewidth=2, label='Median')

    ax.set_xlabel(f'{col} (grams)', fontsize=11)
    ax.set_ylabel('Count', fontsize=11)
    ax.set_title(f'Distribution of {col}', fontsize=12, fontweight='bold')
    ax.legend()
    ax.grid(alpha=0.3)

# Remove every grid cell without a target (one cell for the current five targets).
for ax in axes[len(target_cols):]:
    fig.delaxes(ax)

plt.tight_layout()
plt.savefig('target_distributions.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Target distributions plotted")

---
## Part 3: Visual Inspection - High vs Low Biomass

**Key Question**: Can we visually see the difference between high and low biomass images?

In [None]:
def show_image_grid(df, title, n_images=8, figsize=(16, 4)):
    """Display the first ``n_images`` rows of *df* as a single-row image grid.

    Each panel is titled with that row's Dry_Total_g, Dry_Green_g and
    Dry_Dead_g values, rounded to whole grams.

    Args:
        df: DataFrame with ``full_image_path``, ``Dry_Total_g``,
            ``Dry_Green_g`` and ``Dry_Dead_g`` columns. Rows beyond
            ``n_images`` are ignored; if there are fewer rows, the remaining
            panels are left blank.
        title: Figure super-title.
        n_images: Number of panels to create.
        figsize: Matplotlib figure size.

    Returns:
        The created matplotlib Figure. It is also the current figure, so a
        following ``plt.savefig`` saves it.
    """
    # squeeze=False keeps `axes` an array even when n_images == 1.
    fig, axes = plt.subplots(1, n_images, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    # Hide all panels first, so any unused ones stay blank.
    for ax in axes:
        ax.axis('off')

    for ax, (_, row) in zip(axes, df.head(n_images).iterrows()):
        # The context manager closes the file handle. imshow copies the pixel
        # data while the file is still open.
        with Image.open(row['full_image_path']) as img:
            ax.imshow(img)

        ax.set_title(
            f"Total: {row['Dry_Total_g']:.0f}g\n"
            f"Green: {row['Dry_Green_g']:.0f}g\n"
            f"Dead: {row['Dry_Dead_g']:.0f}g",
            fontsize=9,
        )

    plt.suptitle(title, fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    return fig

In [None]:
# Eight samples with the largest Dry_Total_g; we expect dense, green cover here.
high_biomass = train_enriched.nlargest(n=8, columns='Dry_Total_g')
fig = show_image_grid(high_biomass, title='HIGHEST Dry_Total_g Images (Top 8)')
plt.savefig('high_biomass_images.png', dpi=150, bbox_inches='tight')
plt.show()

print("High biomass stats:")
for label, column, spec, unit in (
    ("Mean Dry_Total_g", 'Dry_Total_g', '.0f', 'g'),
    ("Mean NDVI", 'Pre_GSHH_NDVI', '.3f', ''),
    ("Mean Height", 'Height_Ave_cm', '.1f', 'cm'),
):
    print(f"  {label}: {high_biomass[column].mean():{spec}}{unit}")

In [None]:
# Eight samples with the smallest Dry_Total_g; we expect sparse or brown cover here.
low_biomass = train_enriched.nsmallest(n=8, columns='Dry_Total_g')
fig = show_image_grid(low_biomass, title='LOWEST Dry_Total_g Images (Bottom 8)')
plt.savefig('low_biomass_images.png', dpi=150, bbox_inches='tight')
plt.show()

print("Low biomass stats:")
for label, column, spec, unit in (
    ("Mean Dry_Total_g", 'Dry_Total_g', '.0f', 'g'),
    ("Mean NDVI", 'Pre_GSHH_NDVI', '.3f', ''),
    ("Mean Height", 'Height_Ave_cm', '.1f', 'cm'),
):
    print(f"  {label}: {low_biomass[column].mean():{spec}}{unit}")

In [None]:
# Samples with the most green dry matter; these should look visibly very green.
high_green = train_enriched.nlargest(n=8, columns='Dry_Green_g')
fig = show_image_grid(
    high_green,
    title='HIGHEST Dry_Green_g Images (Green Vegetation)',
)
plt.savefig('high_green_images.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Samples with the most dead dry matter; these should look brown.
high_dead = train_enriched.nlargest(n=8, columns='Dry_Dead_g')
fig = show_image_grid(
    high_dead,
    title='HIGHEST Dry_Dead_g Images (Dead Vegetation)',
)
plt.savefig('high_dead_images.png', dpi=150, bbox_inches='tight')
plt.show()

---
## Part 4: Color Feature Extraction

Extract RGB and HSV statistics from all images to see if they correlate with biomass.

In [None]:
def extract_color_features(image_path, resize=(224, 224)):
    """Compute global RGB and HSV colour statistics for one image.

    The image is converted to RGB, resized to *resize* and scaled to [0, 1]
    before any statistic is computed.

    Args:
        image_path: Path of the image file to read.
        resize: ``(width, height)`` passed to ``PIL.Image.resize``.

    Returns:
        A dict with the keys ``r_mean``, ``g_mean``, ``b_mean``, ``r_std``,
        ``g_std``, ``b_std``, ``green_red_ratio``, ``green_blue_ratio``,
        ``h_mean``, ``s_mean``, ``v_mean``, ``brightness`` and ``variance``.
        Returns ``None`` if anything fails; the error is printed. This is a
        deliberate best-effort, so one bad image does not abort a bulk run.

    Notes:
        * ``variance`` is actually the standard deviation over all pixels and
          channels. The key name is kept for compatibility with existing columns.
        * ``h_mean`` is an arithmetic mean of PIL's hue channel. Hue wraps
          around at red, so this mean is unreliable for predominantly red images.
    """
    try:
        # The context manager closes the underlying file. convert() returns a
        # new, fully loaded image that stays valid after the file is closed.
        with Image.open(image_path) as src:
            img = src.convert('RGB').resize(resize)

        rgb = np.array(img) / 255.0  # H x W x 3, values in [0, 1]
        r_mean = rgb[:, :, 0].mean()
        g_mean = rgb[:, :, 1].mean()
        b_mean = rgb[:, :, 2].mean()

        # PIL's HSV mode stores every channel as 0-255, so scale to [0, 1] too.
        hsv = np.array(img.convert('HSV')) / 255.0

        return {
            'r_mean': r_mean,
            'g_mean': g_mean,
            'b_mean': b_mean,
            'r_std': rgb[:, :, 0].std(),
            'g_std': rgb[:, :, 1].std(),
            'b_std': rgb[:, :, 2].std(),
            # Small epsilon avoids division by zero on all-black channels.
            'green_red_ratio': g_mean / (r_mean + 1e-6),
            'green_blue_ratio': g_mean / (b_mean + 1e-6),
            'h_mean': hsv[:, :, 0].mean(),
            's_mean': hsv[:, :, 1].mean(),
            'v_mean': hsv[:, :, 2].mean(),
            'brightness': rgb.mean(),
            'variance': rgb.std(),
        }
    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        return None

print("✓ Color feature extraction function defined")

In [None]:
# Extract colour features for every image. Rows whose image fails to load are
# reported, skipped here, and end up with NaN features after the join below.
print("Extracting color features from all images...")
print(f"This may take 2-3 minutes for {len(train_enriched)} images...\n")

from tqdm.auto import tqdm

color_features_list = []
failed_paths = []
for idx, row in tqdm(train_enriched.iterrows(), total=len(train_enriched)):
    features = extract_color_features(row['full_image_path'])
    if features is None:
        failed_paths.append(row['full_image_path'])
        continue
    # Remember the source row label so the join below realigns features with rows.
    features['index'] = idx
    color_features_list.append(features)

if not color_features_list:
    # Otherwise set_index('index') would fail with an unhelpful KeyError.
    raise RuntimeError("No color features could be extracted; check the image paths.")

color_features_df = pd.DataFrame(color_features_list)
color_features_df.set_index('index', inplace=True)

# Left join keeps every training row; rows with failed images get NaN features.
data_with_features = train_enriched.join(color_features_df)

print(f"\n✓ Extracted {len(color_features_df)} color feature sets")
if failed_paths:
    print(f"⚠️  {len(failed_paths)} images failed and have NaN color features")
print(f"\nColor features: {list(color_features_df.columns)}")

In [None]:
# count/mean/std/quantiles for every extracted colour feature.
print("Color Feature Statistics:", "=" * 60, color_features_df.describe(), sep="\n")

---
## Part 5: Correlation Analysis

**Critical Question**: Do color features correlate with biomass targets?

In [None]:
# Two feature groups: image-derived colour statistics and tabular metadata.
color_feature_cols = list(color_features_df.columns)
tabular_feature_cols = ['Pre_GSHH_NDVI', 'Height_Ave_cm', 'temp_mean_7d', 'rainfall_7d']
all_feature_cols = [*color_feature_cols, *tabular_feature_cols]

# Pairwise Pearson correlations across every feature and target.
correlation_data = data_with_features[all_feature_cols + target_cols]
correlation_matrix = correlation_data.corr()

# Slice out the features-by-targets part (rows = features, columns = targets).
target_correlations = correlation_matrix.loc[all_feature_cols, target_cols]

print("Top 10 Correlations with Each Target:")
print("=" * 80)

for target in target_cols:
    print(f"\n{target}:")
    # Rank by |r| but print the signed value.
    ranked = target_correlations[target].abs().sort_values(ascending=False)
    for feature in ranked.index[:10]:
        print(f"  {feature:20s}: {target_correlations.loc[feature, target]:+.3f}")

In [None]:
# Left panel: every feature against every target.
# Right panel: how the targets correlate with one another.
fig, axes = plt.subplots(1, 2, figsize=(18, 8))

ax = axes[0]
sns.heatmap(target_correlations, annot=False, cmap='coolwarm', center=0,
            vmin=-1, vmax=1, ax=ax, cbar_kws={'label': 'Correlation'})
ax.set_title('Correlations: Features vs Biomass Targets', fontsize=14, fontweight='bold')
ax.set_xlabel('Biomass Targets', fontsize=12)
ax.set_ylabel('Features', fontsize=12)

ax = axes[1]
target_intercorr = correlation_matrix.loc[target_cols, target_cols]
sns.heatmap(target_intercorr, annot=True, fmt='.2f', cmap='coolwarm', center=0,
            vmin=-1, vmax=1, ax=ax, square=True, cbar_kws={'label': 'Correlation'})
ax.set_title('Biomass Target Intercorrelations', fontsize=14, fontweight='bold')
ax.set_xlabel('', fontsize=12)
ax.set_ylabel('', fontsize=12)

plt.tight_layout()
plt.savefig('correlation_heatmaps.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Correlation heatmaps plotted")

---
## Part 6: Scatter Plots - Visual Correlations

Visualize the strongest correlations to understand relationships.

In [None]:
# For each target, scatter it against the feature with the largest |r| and
# overlay a least-squares trend line.
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for ax, target in zip(axes, target_cols):
    abs_corr = target_correlations[target].abs()
    strongest_feature = abs_corr.idxmax()
    corr_value = target_correlations.loc[strongest_feature, target]

    x = data_with_features[strongest_feature]
    y = data_with_features[target]
    ax.scatter(x, y, alpha=0.5, s=30, color='steelblue')

    # np.polyfit cannot handle NaN, e.g. rows whose colour extraction failed
    # or missing tabular values. Fit only on complete pairs.
    valid = x.notna() & y.notna()
    if valid.sum() >= 2:
        coeffs = np.polyfit(x[valid], y[valid], 1)
        trend = np.poly1d(coeffs)
        x_line = np.linspace(x[valid].min(), x[valid].max(), 100)
        ax.plot(x_line, trend(x_line), "r--", linewidth=2, label=f'Trend (r={corr_value:.3f})')
        ax.legend()

    ax.set_xlabel(strongest_feature, fontsize=11)
    ax.set_ylabel(target, fontsize=11)
    ax.set_title(f'{target} vs {strongest_feature}', fontsize=12, fontweight='bold')
    ax.grid(alpha=0.3)

# Remove grid cells that have no target.
for ax in axes[len(target_cols):]:
    fig.delaxes(ax)

plt.tight_layout()
plt.savefig('strongest_correlations.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Strongest correlation scatter plots created")

In [None]:
# Hand-picked feature/target pairs, one per panel in row-major order:
# (feature column, target column, marker colour, title label, x-axis label).
scatter_specs = [
    ('g_mean', 'Dry_Green_g', 'green', 'Green Channel', 'Mean Green Channel'),
    ('r_mean', 'Dry_Dead_g', 'brown', 'Red Channel', 'Mean Red Channel'),
    ('green_red_ratio', 'Dry_Total_g', 'olive', 'Green/Red Ratio', 'Green/Red Ratio'),
    ('Pre_GSHH_NDVI', 'Dry_Green_g', 'darkgreen', 'NDVI', 'Pre_GSHH_NDVI'),
    ('Pre_GSHH_NDVI', 'Dry_Total_g', 'navy', 'NDVI', 'Pre_GSHH_NDVI'),
    ('Height_Ave_cm', 'Dry_Total_g', 'purple', 'Height', 'Height_Ave_cm'),
]

fig, axes = plt.subplots(2, 3, figsize=(18, 12))

for ax, (feature, target, color, label, xlabel) in zip(axes.flat, scatter_specs):
    ax.scatter(data_with_features[feature], data_with_features[target],
               alpha=0.5, s=30, color=color)
    # Pearson r on pairwise-complete rows.
    corr = data_with_features[[feature, target]].corr().iloc[0, 1]
    ax.set_title(f'{label} vs {target} (r={corr:.3f})', fontweight='bold')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(target)
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('color_vs_biomass_scatter.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Color vs biomass scatter plots created")

---
## Part 7: Simple Baseline Models

**Key Question**: What R² can we achieve with simple linear regression?

This gives us a baseline to beat with neural networks.

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of rows for validation. The fixed seed makes the split reproducible.
# NOTE(review): the earlier comment said this matches the main notebook's split;
# confirm it uses the same test_size, random_state and row order there.
train_data, val_data = train_test_split(
    data_with_features,
    test_size=0.2,
    random_state=42,
)

print(f"Training samples: {len(train_data)}", f"Validation samples: {len(val_data)}", sep="\n")

In [None]:
def evaluate_linear_model(X_train, y_train, X_val, y_val, feature_names, model_name):
    """Fit a standardized multi-output linear regression and print per-target metrics.

    The features are standardized with a ``StandardScaler`` fit on the training
    split only. A single ``LinearRegression`` then predicts all target columns
    jointly.

    Args:
        X_train: Training feature matrix.
        y_train: Training targets; columns follow ``target_cols``.
        X_val: Validation feature matrix (same columns as ``X_train``).
        y_val: Validation targets; columns follow ``target_cols``.
        feature_names: Names printed in the report header (display only).
        model_name: Title line printed above the report.

    Returns:
        ``(results, competition_score)``:

        * ``results`` maps each target to a dict with ``r2_train``, ``r2_val``
          and ``mae_val``.
        * ``competition_score`` is the ``competition_weights``-weighted sum of
          the per-target validation R² values.

        NOTE(review): confirm that this per-target weighted average matches the
        official competition metric.
    """
    # Fit the scaler on training data only, so no validation statistics leak in.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    # Positional column indexing below needs ndarrays. Converting here also
    # lets callers pass DataFrames as targets.
    y_train = np.asarray(y_train)
    y_val = np.asarray(y_val)

    model = LinearRegression()
    model.fit(X_train_scaled, y_train)

    y_pred_train = model.predict(X_train_scaled)
    y_pred_val = model.predict(X_val_scaled)

    results = {}
    competition_score = 0

    print(f"\n{'=' * 70}")
    print(model_name)
    print(f"Features: {', '.join(feature_names)}")
    print(f"{'=' * 70}")

    for i, (target, weight) in enumerate(zip(target_cols, competition_weights)):
        r2_train = r2_score(y_train[:, i], y_pred_train[:, i])
        r2_val = r2_score(y_val[:, i], y_pred_val[:, i])
        mae_val = mean_absolute_error(y_val[:, i], y_pred_val[:, i])

        results[target] = {
            'r2_train': r2_train,
            'r2_val': r2_val,
            'mae_val': mae_val,
        }
        competition_score += weight * r2_val

        print(f"\n{target}:")
        print(f"  Train R²: {r2_train:+.4f}")
        print(f"  Val R²:   {r2_val:+.4f}")
        print(f"  Val MAE:  {mae_val:.2f}g")

    print(f"\n{'=' * 70}")
    print(f"Competition Score: {competition_score:.4f}")
    print(f"{'=' * 70}")

    return results, competition_score

print("✓ Evaluation function defined")

In [None]:
# Model 1: NDVI alone. This is a single tabular sensor reading, not an image
# feature, so it is a reference point rather than a bound on image-only models.
ndvi_cols = ['Pre_GSHH_NDVI']
X_train_ndvi = train_data[ndvi_cols].values
X_val_ndvi = val_data[ndvi_cols].values

# Target matrices (columns follow target_cols), reused by all later models.
y_train = train_data[target_cols].values
y_val = val_data[target_cols].values

ndvi_results, ndvi_score = evaluate_linear_model(
    X_train_ndvi,
    y_train,
    X_val_ndvi,
    y_val,
    ndvi_cols,
    "MODEL 1: NDVI Only",
)

In [None]:
# Model 2: every extracted colour statistic (RGB means/stds, colour ratios,
# HSV means, brightness and pixel spread).
X_train_color = train_data[color_feature_cols].values
X_val_color = val_data[color_feature_cols].values

color_results, color_score = evaluate_linear_model(
    X_train_color, y_train, X_val_color, y_val,
    # Report the full feature list. Printing only the first three made this
    # header identical to Model 3 (RGB-only) and hid what was actually fitted.
    color_feature_cols,
    "MODEL 2: Color Features Only (RGB, HSV, ratios)"
)

In [None]:
# Model 3: only the three per-channel means, the most basic colour summary.
simple_rgb_cols = ['r_mean', 'g_mean', 'b_mean']
X_train_rgb, X_val_rgb = (split[simple_rgb_cols].values for split in (train_data, val_data))

rgb_results, rgb_score = evaluate_linear_model(
    X_train_rgb,
    y_train,
    X_val_rgb,
    y_val,
    simple_rgb_cols,
    f"MODEL 3: Simple RGB Only ({', '.join(simple_rgb_cols)})",
)

In [None]:
# Model 4: colour statistics plus NDVI and Height_Ave_cm.
# NOTE: NDVI and height are tabular measurements, not pixels, so this mixes
# image-derived and sensor features rather than mirroring a pure image model.
combined_cols = color_feature_cols + ['Pre_GSHH_NDVI', 'Height_Ave_cm']
X_train_combined = train_data[combined_cols].values
X_val_combined = val_data[combined_cols].values

combined_results, combined_score = evaluate_linear_model(
    X_train_combined,
    y_train,
    X_val_combined,
    y_val,
    ['Color features', 'NDVI', 'Height'],
    "MODEL 4: Combined (Color + NDVI + Height)",
)

In [None]:
# Compare the linear baselines against the CNN baseline.
# Competition score of the baseline CNN from an earlier training run (hard-coded).
CNN_BASELINE_SCORE = -1.2527

print("\n" + "=" * 80)
print("SIMPLE MODEL COMPARISON")
print("=" * 80)

summary_df = pd.DataFrame({
    'Model': [
        'NDVI Only',
        'Color Features Only',
        'Simple RGB Only',
        'Color + NDVI + Height',
        '---',
        'CNN Baseline (actual)',
    ],
    'Competition Score': [
        ndvi_score,
        color_score,
        rgb_score,
        combined_score,
        np.nan,  # separator row
        CNN_BASELINE_SCORE,
    ],
})

print(summary_df.to_string(index=False))
print("=" * 80)

# Derive the verdict from the scores instead of always claiming the CNN lost.
best_linear_score = max(ndvi_score, color_score, rgb_score, combined_score)
if CNN_BASELINE_SCORE < best_linear_score:
    cnn_comparison = "WORSE than the best simple linear model!"
else:
    cnn_comparison = "at least as good as every simple linear model"

print("\n📊 Key Insights:")
print(f"  • NDVI alone achieves: {ndvi_score:.4f}")
print(f"  • Color features achieve: {color_score:.4f}")
print(f"  • Simple RGB achieves: {rgb_score:.4f}")
print(f"  • Combined achieves: {combined_score:.4f}")
print(f"  • CNN Baseline achieved: {CNN_BASELINE_SCORE:.4f} ({cnn_comparison})")

if color_score > CNN_BASELINE_SCORE:
    print("\n⚠️  WARNING: Simple color features beat the CNN!")
    print("    This suggests the CNN is NOT learning properly.")
    print("    Likely causes:")
    print("      1. ColorJitter destroying color information")
    print("      2. CNN architecture too complex/overparameterized")
    print("      3. Training setup issues (learning rate, loss function)")
else:
    print("\n✓ CNN baseline is at least as good as the color-only linear model (as expected)")

---
## Part 8: Red Flags Analysis

Check for data quality issues that could explain model failures.

In [None]:
# Heuristic data-quality checks. Each failed check appends a message to
# `red_flags`; the messages are summarised at the end.
print("=" * 80)
print("RED FLAGS ANALYSIS")
print("=" * 80)

red_flags = []

# 1. Label consistency: the 20 images with the highest mean green channel
#    should not have below-average Dry_Green_g.
#    NOTE(review): g_mean also rises with overall brightness, so it is only a
#    coarse proxy for "green" images.
high_green_channel = data_with_features.nlargest(20, 'g_mean')
avg_dry_green = high_green_channel['Dry_Green_g'].mean()
overall_avg = data_with_features['Dry_Green_g'].mean()

print("\n1. Label Consistency Check:")
print("   Images with highest green channel:")
print(f"     Avg Dry_Green_g: {avg_dry_green:.2f}g")
print("   Overall dataset:")
print(f"     Avg Dry_Green_g: {overall_avg:.2f}g")

if avg_dry_green < overall_avg:
    red_flags.append("⚠️  Green images have LOWER Dry_Green_g than average (label mismatch?)")
else:
    print("   ✓ Green images have higher Dry_Green_g (labels consistent)")

# 2. Image diversity: `variance` holds each image's pixel standard deviation.
#    A tiny spread of that value across images suggests near-identical images.
print("\n2. Image Variance Check:")
print(f"   RGB variance range: {data_with_features['variance'].min():.3f} - {data_with_features['variance'].max():.3f}")
print(f"   RGB variance mean: {data_with_features['variance'].mean():.3f}")

if data_with_features['variance'].std() < 0.02:
    red_flags.append("⚠️  Very low variance in images (all look similar)")
else:
    print("   ✓ Sufficient variance in images")

# 3. Location dependency: coefficient of variation of mean Dry_Total_g across States.
print("\n3. Location Dependency Check:")
state_biomass = data_with_features.groupby('State')['Dry_Total_g'].mean()
print("   Dry_Total_g by State:")
for state, biomass in state_biomass.items():
    print(f"     {state}: {biomass:.2f}g")

if len(state_biomass) < 2:
    # std() of a single value is NaN, which used to fall through to the
    # "consistent" message. Say explicitly that the check cannot run.
    print("   (fewer than two States present; location dependency cannot be assessed)")
elif state_biomass.std() / state_biomass.mean() > 0.5:
    red_flags.append("⚠️  Large biomass variation by State (location-dependent, hard for CNN)")
else:
    print("   ✓ Biomass relatively consistent across states")

# 4. Outliers per target using the 1.5 * IQR rule.
print("\n4. Outlier Check:")
for target in target_cols:
    Q1 = data_with_features[target].quantile(0.25)
    Q3 = data_with_features[target].quantile(0.75)
    IQR = Q3 - Q1
    outliers = ((data_with_features[target] < Q1 - 1.5 * IQR) |
                (data_with_features[target] > Q3 + 1.5 * IQR)).sum()
    pct_outliers = 100 * outliers / len(data_with_features)
    print(f"   {target}: {outliers} outliers ({pct_outliers:.1f}%)")

# 5. Strongest |r| between any colour feature (or NDVI) and any target.
print("\n5. Feature Correlation Strength:")
max_color_corr = target_correlations.loc[color_feature_cols].abs().max().max()
ndvi_corr = target_correlations.loc['Pre_GSHH_NDVI'].abs().max()
print(f"   Strongest color feature correlation: {max_color_corr:.3f}")
print(f"   Strongest NDVI correlation: {ndvi_corr:.3f}")

if max_color_corr < 0.3:
    red_flags.append("⚠️  Very weak color-biomass correlations (r < 0.3)")
elif max_color_corr < 0.5:
    print("   ⚠️  Moderate color-biomass correlations (0.3 < r < 0.5)")
else:
    print("   ✓ Strong color-biomass correlations (r > 0.5)")

# Summary
print(f"\n{'=' * 80}")
print("RED FLAGS SUMMARY")
print(f"{'=' * 80}")

if red_flags:
    print(f"\nFound {len(red_flags)} potential issues:\n")
    for flag in red_flags:
        print(f"  {flag}")
else:
    print("\n✓ No major red flags detected")
    print("  Data quality appears acceptable for modeling")

---
## Part 9: Summary & Recommendations

In [None]:
# Final report: recap model scores and correlation strengths, then print
# recommendations chosen by thresholds on the colour-only competition score.
print("=" * 80)
print("FINAL SUMMARY & RECOMMENDATIONS")
print("=" * 80)

print("\n📊 FINDINGS:")
print("\n1. Simple Linear Models Performance:")
print(f"   • NDVI only: {ndvi_score:.4f}")
print(f"   • Color features: {color_score:.4f}")
print(f"   • Simple RGB: {rgb_score:.4f}")
print(f"   • Combined: {combined_score:.4f}")

# CNN scores are hard-coded from earlier training runs.
print("\n2. Current CNN Performance:")
print("   • Baseline CNN: -1.2527 (FAILED)")
print("   • Teacher CNN: -2.1383 (FAILED)")
print("   • Student CNN: -2.0922 (FAILED)")

# The reported r is the absolute correlation of the strongest feature per target.
print("\n3. Correlation Strengths:")
max_corr_per_target = {}
for target in target_cols:
    max_corr = target_correlations[target].abs().max()
    max_feature = target_correlations[target].abs().idxmax()
    max_corr_per_target[target] = (max_feature, max_corr)
    print(f"   • {target}: r={max_corr:.3f} (strongest: {max_feature})")

print("\n" + "=" * 80)
print("🔧 RECOMMENDATIONS:")
print("=" * 80)

# Decision logic based on the colour-only linear model's competition score.
if color_score > 0.2:
    print("\n✅ PROCEED WITH CNN (with fixes):")
    print("\n   Color features show promise (R² > 0.2).")
    print("   CNN should be able to learn these patterns.")
    print("\n   Required fixes:")
    print("   1. ❌ REMOVE ColorJitter - it's destroying color signal")
    print("   2. 🔧 SIMPLIFY model architecture - current is too complex")
    print("   3. ⏱️  TRAIN for 5 epochs first - quick validation")
    print(f"   4. 🎯 TARGET: Beat simple linear model (R² > {color_score:.2f})")

elif color_score > 0.0:
    print("\n⚠️  CAUTIOUSLY PROCEED:")
    print("\n   Color features show weak but positive correlation.")
    print("   CNN might work, but expectations should be modest.")
    print("\n   Recommended actions:")
    print("   1. ❌ REMOVE ColorJitter")
    print("   2. 🔧 Try SIMPLER model first (fewer layers)")
    print("   3. 📊 Consider ensemble with linear model")
    print("   4. 🎯 TARGET: R² > 0.1 minimum")

else:
    print("\n❌ STOP - DATA PROBLEM:")
    print("\n   Even simple linear models can't predict biomass from color.")
    print("   This indicates fundamental data quality issues.")
    print("\n   Required investigation:")
    print("   1. 🔍 Verify image-label alignment (correct IDs?)")
    print("   2. 📅 Check date alignment (images from same day as biomass?)")
    print("   3. 🖼️  Manually inspect 20+ images vs labels")
    print("   4. 📊 Investigate why color doesn't correlate with biomass")

print("\n" + "=" * 80)
print("NEXT STEPS:")
print("=" * 80)
print("\n1. Review all visualizations above")
print("2. Check if high-biomass images LOOK greener than low-biomass")
print("3. If yes → Fix CNN training setup (remove ColorJitter, simplify)")
print("4. If no → Investigate data quality issues")
print("5. Train simple baseline for 5 epochs only")
print("6. If R² > 0.0 → Scale up training")
print("7. If R² < 0.0 → Debug further before long training runs")

print("\n" + "=" * 80)
print("✓ Feature exploration complete!")
print("=" * 80)

In [None]:
# Persist the training table together with its colour features, so later
# experiments can skip the slow image pass. index=False drops the row index
# (read_csv's default RangeIndex, carried through the left join).
data_with_features.to_csv('train_with_color_features.csv', index=False)
print("\n✓ Saved enriched data with color features to: train_with_color_features.csv")
print("  You can use this for further analysis or quick experiments")