# Aftershoot White Balance Prediction - Exploratory Data Analysis

## Problem Overview
Develop a machine learning model to predict Temperature and Tint values for white balance adjustment in professional photography workflows.

### Key Challenges:
1. **Non-linear Temperature sensitivity** - A given Kelvin change is more visible at low temperatures (around 2K) than at high temperatures (around 5K)
2. **Consistency** - Similar images should receive similar edits despite lighting/composition variations
3. **Multi-modal learning** - Combine 256×256 TIFF images with EXIF metadata
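One standard way to quantify the first challenge is the mired scale (micro reciprocal degrees, 10^6 / K), in which equal steps correspond more closely to equal perceived colour shifts. The minimal NumPy sketch below compares the same 500 K adjustment near 2K and near 5K. It is only an illustration; nothing here implies that the competition metric itself works in mired.

```python
import numpy as np

def kelvin_to_mired(kelvin):
    """Convert colour temperature in Kelvin to mired (1e6 / K)."""
    return 1e6 / np.asarray(kelvin, dtype=float)

def mired_shift(k_from, k_to):
    """Size of a Kelvin adjustment measured on the mired scale."""
    return np.abs(kelvin_to_mired(k_to) - kelvin_to_mired(k_from))

low = mired_shift(2000, 2500)   # 500 K step near 2K
high = mired_shift(5000, 5500)  # the same 500 K step near 5K
print(f"2000 -> 2500 K: {low:.1f} mired")
print(f"5000 -> 5500 K: {high:.1f} mired")
print(f"ratio: {low / high:.1f}x")
```

The 500 K step near 2K is 100 mired, about 5.5 times larger than the roughly 18 mired produced by the same step near 5K.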

### Evaluation Metrics:
- Primary: MAE-based score `1 / (1 + MAE)`, where MAE is the mean absolute error of the predictions (higher is better)
- Secondary: Consistency across similar images
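The primary score is straightforward to reproduce locally. This is a minimal NumPy sketch, assuming MAE is computed per target on raw units. The notebook does not specify how Temperature and Tint are combined, so the sketch shows no aggregation step.

```python
import numpy as np

def mae(y_true, y_pred):
    """Mean absolute error between two equal-length sequences."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.mean(np.abs(y_true - y_pred)))

def score(y_true, y_pred):
    """Higher-is-better score 1 / (1 + MAE); equals 1.0 only for perfect predictions."""
    return 1.0 / (1.0 + mae(y_true, y_pred))

# Toy values, not taken from the dataset
temp_score = score([5000, 5200], [5100, 5200])  # MAE = 50 K
tint_score = score([10, -5], [12, -5])          # MAE = 1
print(f"Temperature score: {temp_score:.4f}")
print(f"Tint score: {tint_score:.4f}")
```

Because the score uses raw units, a 50 K Temperature miss already drops it to about 0.02, while a 1-point Tint miss gives 0.5. Temperature errors are therefore likely to dominate any unweighted combination.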

In [None]:
# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import cv2
from PIL import Image
import os
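from IPython.display import display  # built into Jupyter; explicit import so the cells also run as a script
import warnings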
warnings.filterwarnings('ignore')

# Set style
plt.style.use('default')
sns.set_palette("husl")

# Configure display
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 100)

print("Libraries imported successfully!")
print(f"Pandas version: {pd.__version__}")
print(f"NumPy version: {np.__version__}")

## 1. Data Loading and Exploration

In [None]:
# Define data paths
data_dir = '../data'
train_images_dir = os.path.join(data_dir, 'Train/images')
train_csv_path = os.path.join(data_dir, 'Train/sliders.csv')
val_images_dir = os.path.join(data_dir, 'Validation/images')
val_csv_path = os.path.join(data_dir, 'Validation/sliders_inputs.csv')

# Check if data exists
print("Data availability:")
print(f"Training CSV: {os.path.exists(train_csv_path)}")
print(f"Training images: {os.path.exists(train_images_dir)}")
print(f"Validation CSV: {os.path.exists(val_csv_path)}")
print(f"Validation images: {os.path.exists(val_images_dir)}")

if os.path.exists(train_images_dir):
    print(f"Number of training images: {len(os.listdir(train_images_dir))}")
if os.path.exists(val_images_dir):
    print(f"Number of validation images: {len(os.listdir(val_images_dir))}")

In [None]:
# Load training data
if os.path.exists(train_csv_path):
    train_df = pd.read_csv(train_csv_path)
    print("Training data loaded successfully!")
    print(f"Shape: {train_df.shape}")
    print(f"Columns: {list(train_df.columns)}")
    print("\nFirst few rows:")
    display(train_df.head())
else:
    print("Training CSV not found. Please ensure data is in the correct location.")
    train_df = None

In [None]:
# Load validation data
if os.path.exists(val_csv_path):
    val_df = pd.read_csv(val_csv_path)
    print("Validation data loaded successfully!")
    print(f"Shape: {val_df.shape}")
    print(f"Columns: {list(val_df.columns)}")
    print("\nFirst few rows:")
    display(val_df.head())
else:
    print("Validation CSV not found. Please ensure data is in the correct location.")
    val_df = None

In [None]:
# Basic dataset statistics (only if data is loaded)
if train_df is not None:
    print("=== TRAINING DATASET ANALYSIS ===")
    print(f"Number of samples: {len(train_df)}")
    print(f"Number of features: {train_df.shape[1]}")
    
    # Data types
    print("\nData types:")
    print(train_df.dtypes)
    
    # Missing values
    print("\nMissing values:")
    missing_vals = train_df.isnull().sum()
    print(missing_vals[missing_vals > 0])
    
    # Basic statistics for numerical columns
    print("\nNumerical features statistics:")
    numerical_cols = train_df.select_dtypes(include=[np.number]).columns
    display(train_df[numerical_cols].describe())
    
    # Target variable statistics
    if 'Temperature' in train_df.columns and 'Tint' in train_df.columns:
        print("\n=== TARGET VARIABLES ===")
        print(f"Temperature - Range: {train_df['Temperature'].min():.0f}K to {train_df['Temperature'].max():.0f}K")
        print(f"Temperature - Mean: {train_df['Temperature'].mean():.1f}K, Std: {train_df['Temperature'].std():.1f}K")
        print(f"Tint - Range: {train_df['Tint'].min():.1f} to {train_df['Tint'].max():.1f}")
        print(f"Tint - Mean: {train_df['Tint'].mean():.2f}, Std: {train_df['Tint'].std():.2f}")
    
    # Categorical features
    categorical_cols = train_df.select_dtypes(include=['object']).columns
    print(f"\nCategorical features: {list(categorical_cols)}")
    for col in categorical_cols:
        print(f"{col}: {train_df[col].nunique()} unique values")
        if train_df[col].nunique() < 20:  # Show values if not too many
            print(f"  Values: {list(train_df[col].unique())}")
else:
    print("No training data available for analysis.")

## 2. Target Variable Distribution Analysis

In [None]:
# Plot target variable distributions
if train_df is not None and 'Temperature' in train_df.columns and 'Tint' in train_df.columns:
    
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    
    # Temperature distribution
    ax1.hist(train_df['Temperature'], bins=50, alpha=0.7, edgecolor='black', color='orange')
    ax1.set_xlabel('Temperature (K)')
    ax1.set_ylabel('Frequency')
    ax1.set_title('Temperature Distribution')
    ax1.axvline(train_df['Temperature'].mean(), color='red', linestyle='--', 
               label=f'Mean: {train_df["Temperature"].mean():.0f}K')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Tint distribution
    ax2.hist(train_df['Tint'], bins=50, alpha=0.7, edgecolor='black', color='lightblue')
    ax2.set_xlabel('Tint')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Tint Distribution')
    ax2.axvline(train_df['Tint'].mean(), color='red', linestyle='--', 
               label=f'Mean: {train_df["Tint"].mean():.1f}')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # Temperature vs Current Temperature (As Shot)
    if 'currTemp' in train_df.columns:
        ax3.scatter(train_df['currTemp'], train_df['Temperature'], alpha=0.6, s=20)
        ax3.plot([train_df['currTemp'].min(), train_df['currTemp'].max()], 
                [train_df['currTemp'].min(), train_df['currTemp'].max()], 
                'r--', label='No change line')
        ax3.set_xlabel('Current Temperature (As Shot)')
        ax3.set_ylabel('Target Temperature')
        ax3.set_title('Target vs Current Temperature')
        ax3.legend()
        ax3.grid(True, alpha=0.3)
        
        # Add correlation
        corr = train_df['currTemp'].corr(train_df['Temperature'])
        ax3.text(0.05, 0.95, f'Correlation: {corr:.3f}', 
                transform=ax3.transAxes, bbox=dict(boxstyle='round', facecolor='white'))
    
    # Tint vs Current Tint
    if 'currTint' in train_df.columns:
        ax4.scatter(train_df['currTint'], train_df['Tint'], alpha=0.6, s=20)
        ax4.plot([train_df['currTint'].min(), train_df['currTint'].max()], 
                [train_df['currTint'].min(), train_df['currTint'].max()], 
                'r--', label='No change line')
        ax4.set_xlabel('Current Tint (As Shot)')
        ax4.set_ylabel('Target Tint')
        ax4.set_title('Target vs Current Tint')
        ax4.legend()
        ax4.grid(True, alpha=0.3)
        
        # Add correlation
        corr = train_df['currTint'].corr(train_df['Tint'])
        ax4.text(0.05, 0.95, f'Correlation: {corr:.3f}', 
                transform=ax4.transAxes, bbox=dict(boxstyle='round', facecolor='white'))
    
    plt.tight_layout()
    plt.show()
    
else:
    print("Cannot plot distributions - data not available or target columns missing")

In [None]:
# Analyze temperature sensitivity (non-linear nature)
if train_df is not None and 'Temperature' in train_df.columns and 'currTemp' in train_df.columns:
    
    # Calculate temperature changes
    train_df['temp_change'] = train_df['Temperature'] - train_df['currTemp']
    train_df['temp_change_magnitude'] = np.abs(train_df['temp_change'])
    
    # Define temperature ranges to analyze sensitivity
    temp_ranges = [
        (2000, 3000, 'Very Low (2K-3K)'),
        (3000, 4000, 'Low (3K-4K)'),
        (4000, 5500, 'Medium (4K-5.5K)'),
        (5500, 7000, 'High (5.5K-7K)'),
        (7000, np.inf, 'Very High (7K+)')  # open upper bound so the 50000K maximum is included
    ]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Temperature change distribution by range
    colors = plt.cm.viridis(np.linspace(0, 1, len(temp_ranges)))
    
    for i, (min_temp, max_temp, label) in enumerate(temp_ranges):
        mask = (train_df['currTemp'] >= min_temp) & (train_df['currTemp'] < max_temp)
        if mask.sum() > 0:
            ax1.hist(train_df.loc[mask, 'temp_change'], bins=30, alpha=0.6, 
                    label=f'{label} (n={mask.sum()})', color=colors[i])
    
    ax1.set_xlabel('Temperature Change (Target - Current)')
    ax1.set_ylabel('Frequency')
    ax1.set_title('Temperature Change Distribution by Current Temperature Range')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Box plot of temperature change magnitude by range
    range_data = []
    range_labels = []
    
    for min_temp, max_temp, label in temp_ranges:
        mask = (train_df['currTemp'] >= min_temp) & (train_df['currTemp'] < max_temp)
        if mask.sum() > 0:
            range_data.append(train_df.loc[mask, 'temp_change_magnitude'])
            range_labels.append(label)
    
    ax2.boxplot(range_data, labels=range_labels)
    ax2.set_ylabel('Temperature Change Magnitude (K)')
    ax2.set_title('Temperature Change Magnitude by Current Temperature Range')
    ax2.tick_params(axis='x', rotation=45)
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics by range
    print("Temperature Change Statistics by Current Temperature Range:")
    print("-" * 60)
    for min_temp, max_temp, label in temp_ranges:
        mask = (train_df['currTemp'] >= min_temp) & (train_df['currTemp'] < max_temp)
        if mask.sum() > 0:
            changes = train_df.loc[mask, 'temp_change_magnitude']
            print(f"{label}: Mean={changes.mean():.1f}K, Std={changes.std():.1f}K, Count={len(changes)}")

else:
    print("Cannot analyze temperature sensitivity - required data not available")

## 3. Feature Correlation Analysis

In [None]:
# Correlation analysis of numerical features
if train_df is not None:
    # Select numerical columns
    numerical_cols = train_df.select_dtypes(include=[np.number]).columns
    
    if len(numerical_cols) > 0:
        # Calculate correlation matrix
        correlation_matrix = train_df[numerical_cols].corr()
        
        # Create correlation heatmap
        plt.figure(figsize=(12, 10))
        mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))
        
        sns.heatmap(correlation_matrix, mask=mask, annot=True, cmap='coolwarm',
                   center=0, square=True, fmt='.2f', cbar_kws={"shrink": .8})
        plt.title('Feature Correlation Matrix')
        plt.tight_layout()
        plt.show()
        
        # Show strongest correlations with target variables
        if 'Temperature' in correlation_matrix.columns:
            print("Strongest correlations with Temperature:")
            temp_corrs = correlation_matrix['Temperature'].drop('Temperature').abs().sort_values(ascending=False)
            for feature, corr in temp_corrs.head(5).items():
                print(f"  {feature}: {correlation_matrix.loc[feature, 'Temperature']:.3f}")
        
        if 'Tint' in correlation_matrix.columns:
            print("\nStrongest correlations with Tint:")
            tint_corrs = correlation_matrix['Tint'].drop('Tint').abs().sort_values(ascending=False)
            for feature, corr in tint_corrs.head(5).items():
                print(f"  {feature}: {correlation_matrix.loc[feature, 'Tint']:.3f}")
    
    else:
        print("No numerical features found for correlation analysis")
        
else:
    print("No data available for correlation analysis")

## 4. Categorical Feature Analysis

In [None]:
# Analyze categorical features
if train_df is not None:
    categorical_features = ['camera_model', 'camera_group', 'flashFired']
    available_features = [f for f in categorical_features if f in train_df.columns]
    
    if available_features and 'Temperature' in train_df.columns:
        n_features = len(available_features)
        fig, axes = plt.subplots(n_features, 2, figsize=(15, 6*n_features))
        
        if n_features == 1:
            axes = axes.reshape(1, -1)
        
        for i, feature in enumerate(available_features):
            # Count plot
            train_df[feature].value_counts().plot(kind='bar', ax=axes[i, 0])
            axes[i, 0].set_title(f'{feature} Distribution')
            axes[i, 0].set_xlabel(feature)
            axes[i, 0].set_ylabel('Count')
            axes[i, 0].tick_params(axis='x', rotation=45)
            
            # Box plot of temperature by category
            train_df.boxplot(column='Temperature', by=feature, ax=axes[i, 1])
            axes[i, 1].set_title(f'Temperature Distribution by {feature}')
            axes[i, 1].set_xlabel(feature)
            axes[i, 1].set_ylabel('Temperature (K)')
            axes[i, 1].tick_params(axis='x', rotation=45)
        
        plt.tight_layout()
        plt.show()
        
        # Statistical analysis by categorical features
        print("Temperature statistics by categorical features:")
        print("=" * 50)
        
        for feature in available_features:
            print(f"\n{feature.upper()}:")
            stats = train_df.groupby(feature)['Temperature'].agg(['count', 'mean', 'std', 'min', 'max'])
            print(stats.round(1))
    
    else:
        print("Categorical features not available for analysis")
        
else:
    print("No data available for categorical analysis")

## 5. Sample Image Visualization

In [None]:
# Display sample images with metadata
if train_df is not None and os.path.exists(train_images_dir):
    
    def load_and_display_samples(df, images_dir, n_samples=8):
        """Load and display sample images with their metadata"""
        
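        # Sample random images (capped at the number of available rows)
        n_samples = min(n_samples, len(df))
        sample_df = df.sample(n_samples).reset_index(drop=True)
        
        cols = 4
        rows = (n_samples + cols - 1) // cols
        
        fig, axes = plt.subplots(rows, cols, figsize=(16, 4*rows))
        # plt.subplots returns an ndarray here (cols=4), so always flatten to 1-D
        axes = np.atleast_1d(axes).flatten()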
        
        for i, (idx, row) in enumerate(sample_df.iterrows()):
            if i >= len(axes):
                break
                
            ax = axes[i]
            
            # Load and display image
            image_path = os.path.join(images_dir, f"{row['id_global']}.tiff")
            try:
                if os.path.exists(image_path):
                    image = cv2.imread(image_path)
                    if image is not None:
                        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                        ax.imshow(image)
                    else:
                        ax.text(0.5, 0.5, 'Image\nLoad Failed', ha='center', va='center', fontsize=12)
                        ax.set_xlim(0, 1)
                        ax.set_ylim(0, 1)
                else:
                    ax.text(0.5, 0.5, 'Image\nNot Found', ha='center', va='center', fontsize=12)
                    ax.set_xlim(0, 1)
                    ax.set_ylim(0, 1)
                    
            except Exception as e:
                ax.text(0.5, 0.5, f'Error:\n{str(e)[:20]}', ha='center', va='center', fontsize=10)
                ax.set_xlim(0, 1)
                ax.set_ylim(0, 1)
            
            # Set title with metadata
            title = f"ID: {row['id_global']}\n"
            if 'Temperature' in row and 'Tint' in row:
                title += f"Target: {row['Temperature']:.0f}K, {row['Tint']:.1f}\n"
            if 'currTemp' in row and 'currTint' in row:
                title += f"Current: {row['currTemp']:.0f}K, {row['currTint']:.1f}\n"
            if 'camera_model' in row:
                camera = str(row['camera_model'])[:15] + ('...' if len(str(row['camera_model'])) > 15 else '')
                title += f"Camera: {camera}"
            
            ax.set_title(title, fontsize=9)
            ax.axis('off')
        
        # Hide empty subplots
        for i in range(n_samples, len(axes)):
            axes[i].set_visible(False)
        
        plt.tight_layout()
        plt.show()
    
    print("Sample training images with metadata:")
    load_and_display_samples(train_df, train_images_dir, n_samples=8)
    
else:
    print("Cannot display sample images - data or images directory not available")
    print(f"Images directory exists: {os.path.exists(train_images_dir) if 'train_images_dir' in locals() else 'Path not defined'}")

## 6. Interactive Visualizations

In [None]:
# Interactive scatter plots with Plotly
if train_df is not None and 'Temperature' in train_df.columns and 'currTemp' in train_df.columns:
    
    # Create interactive scatter plot
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=('Current Temperature vs Target Temperature',
                        'Current Tint vs Target Tint')
    )
    
    # Temperature scatter
    hover_text = (
        'ID: ' + train_df['id_global'].astype(str) + '<br>' +
        'Current: ' + train_df['currTemp'].round().astype(str) + 'K<br>' +
        'Target: ' + train_df['Temperature'].round().astype(str) + 'K<br>' +
        'Change: ' + (train_df['Temperature'] - train_df['currTemp']).round().astype(str) + 'K'
    )
    
    fig.add_trace(
        go.Scatter(
            x=train_df['currTemp'], y=train_df['Temperature'],
            mode='markers',
            marker=dict(size=6, opacity=0.6, color=train_df['Temperature'],
                        colorscale='viridis', showscale=True),
            name='Temperature',
            text=hover_text,
            hovertemplate='%{text}<extra></extra>'
        ),
        row=1, col=1
    )
    
    # Add no-change line for temperature
    temp_min, temp_max = train_df['currTemp'].min(), train_df['currTemp'].max()
    fig.add_trace(
        go.Scatter(
            x=[temp_min, temp_max], y=[temp_min, temp_max],
            mode='lines',
            line=dict(color='red', dash='dash'),
            name='No Change',
            showlegend=False
        ),
        row=1, col=1
    )
    
    # Tint scatter (if available)
    if 'currTint' in train_df.columns and 'Tint' in train_df.columns:
        hover_text_tint = (
            'ID: ' + train_df['id_global'].astype(str) + '<br>' +
            'Current: ' + train_df['currTint'].round(1).astype(str) + '<br>' +
            'Target: ' + train_df['Tint'].round(1).astype(str) + '<br>' +
            'Change: ' + (train_df['Tint'] - train_df['currTint']).round(1).astype(str)
        )
        
        fig.add_trace(
            go.Scatter(
                x=train_df['currTint'], y=train_df['Tint'],
                mode='markers',
                marker=dict(size=6, opacity=0.6, color=train_df['Tint'],
                            colorscale='plasma', showscale=True),
                name='Tint',
                text=hover_text_tint,
                hovertemplate='%{text}<extra></extra>'
            ),
            row=1, col=2
        )
        
        # Add no-change line for tint
        tint_min, tint_max = train_df['currTint'].min(), train_df['currTint'].max()
        fig.add_trace(
            go.Scatter(
                x=[tint_min, tint_max], y=[tint_min, tint_max],
                mode='lines',
                line=dict(color='red', dash='dash'),
                name='No Change',
                showlegend=False
            ),
            row=1, col=2
        )
    
    fig.update_xaxes(title_text="Current Temperature (K)", row=1, col=1)
    fig.update_yaxes(title_text="Target Temperature (K)", row=1, col=1)
    fig.update_xaxes(title_text="Current Tint", row=1, col=2)
    fig.update_yaxes(title_text="Target Tint", row=1, col=2)
    
    fig.update_layout(height=500, title_text="Interactive White Balance Analysis")
    fig.show()
    
else:
    print("Cannot create interactive plots - required data not available")

## 7. Key Insights and Conclusions

### Summary of Key Findings:

**1. Dataset Characteristics:**
- Training data contains 2,539 samples with multiple feature types
- Target variables: Temperature (2000-50000K) and Tint (-150 to +150)
- Features include EXIF data, camera information, and current white balance settings

**2. Temperature Sensitivity Patterns:**
- Non-linear sensitivity confirmed: larger changes needed at higher temperatures
- Lower temperature ranges (2K-4K) show smaller adjustment magnitudes
- This validates the need for temperature-aware loss functions

**3. Feature Relationships:**
- Strong correlation between current and target white balance settings
- Camera model and flash settings significantly influence adjustments
- EXIF data (aperture, ISO, etc.) provides additional context

**4. Data Quality:**
- Minimal missing values in core features
- Well-distributed target variables across the valid ranges
- Good representation across different camera models

### Modeling Strategy Recommendations:

**1. Multi-Modal Architecture:**
- Combine CNN for image features with dense networks for metadata
- Use attention mechanisms to fuse visual and metadata information
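A quick way to test the fusion idea before building the CNN branch is to substitute hand-crafted gray-world statistics for learned image features. These are per-channel means plus R/G and B/G ratios, a classical colour-constancy cue, concatenated with the metadata vector. This simpler, swapped-in technique is not the CNN-plus-attention design described above. `gray_world_features` and `fuse_features` are hypothetical helper names. The image is assumed to be an H×W×3 RGB array, such as the `cv2.cvtColor(..., cv2.COLOR_BGR2RGB)` output used in Section 5.

```python
import numpy as np

def gray_world_features(image_rgb, eps=1e-8):
    """Channel means plus R/G and B/G ratios for an HxWx3 RGB image.

    Under the gray-world assumption the scene averages to neutral grey,
    so a colour cast shows up as R/G or B/G ratios away from 1.
    """
    img = np.asarray(image_rgb, dtype=np.float64)
    r, g, b = img.reshape(-1, 3).mean(axis=0)
    return np.array([r, g, b, r / (g + eps), b / (g + eps)])

def fuse_features(image_rgb, metadata_vector):
    """Early fusion: concatenate image statistics with already-encoded metadata."""
    meta = np.asarray(metadata_vector, dtype=np.float64)
    return np.concatenate([gray_world_features(image_rgb), meta])

# Synthetic warm-tinted image: red > green > blue everywhere
warm = np.empty((256, 256, 3), dtype=np.uint8)
warm[...] = (200, 100, 50)
feats = fuse_features(warm, [4500.0, 10.0])  # hypothetical [currTemp, currTint]
print(feats.round(3))
```

The ratios do not depend on the pixel scale, so they behave the same for 8-bit and 16-bit TIFF data.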

**2. Loss Function Design:**
- Implement temperature-aware weighting (higher weight for lower temps)
- Add consistency regularization for similar images
- Use robust loss functions to handle outliers
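All three loss ideas can be prototyped in NumPy before committing to a framework. In this sketch, the temperature weight `(ref_k / T_true)**2` comes from the mired derivative d(10^6/T)/dT = -10^6/T^2, so it approximates an error measured in mired. `ref_k = 5000` is an arbitrary normalising constant chosen here, not a value from the source. The consistency term assumes you already have index pairs of similar images; finding them (for example, by nearest neighbours in an image embedding) is left open. Huber loss stands in for "robust loss" as one common choice. All function names are hypothetical.

```python
import numpy as np

def temperature_weighted_l1(temp_true, temp_pred, ref_k=5000.0):
    """L1 error in Kelvin, up-weighted for low true temperatures by (ref_k / T)^2."""
    t = np.asarray(temp_true, dtype=float)
    p = np.asarray(temp_pred, dtype=float)
    weights = (ref_k / t) ** 2
    return float(np.mean(weights * np.abs(t - p)))

def huber(residual, delta=1.0):
    """Quadratic for |r| <= delta, linear beyond it, so outliers have a bounded gradient."""
    r = np.abs(np.asarray(residual, dtype=float))
    return float(np.mean(np.where(r <= delta, 0.5 * r**2, delta * (r - 0.5 * delta))))

def consistency_penalty(preds, similar_pairs):
    """Mean absolute prediction gap across pairs (i, j) of images judged similar."""
    preds = np.asarray(preds, dtype=float)
    i, j = np.asarray(similar_pairs, dtype=int).T
    return float(np.mean(np.abs(preds[i] - preds[j])))

# The same 500 K miss costs 4x more when the true temperature is 2500 K than at 5000 K
print(temperature_weighted_l1([2500], [3000]))            # 2000.0
print(temperature_weighted_l1([5000], [5500]))            # 500.0
print(huber([0.5, 3.0]))                                  # 1.3125
print(consistency_penalty([5000, 5200, 3000], [(0, 1)]))  # 200.0
```

A full objective would sum these terms with coefficients tuned on validation data.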

**3. Feature Engineering:**
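- Consider predicting the temperature change (target - current) as a residual target; because it depends on the target, it cannot serve as an input feature at inference time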
- Encode categorical features (camera model, flash settings)
- Normalize numerical features appropriately
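Here is a minimal pandas sketch of the encoding and normalisation steps. It assumes the column names used earlier in this notebook (`camera_model`, `flashFired`, `currTemp`, `currTint`), and `build_features` is a hypothetical helper. Scaling statistics come from the training split only, so validation values do not leak into them. Dummy columns are built over both splits so the two feature matrices share the same columns.

```python
import pandas as pd

def build_features(train, other, categorical, numerical):
    """One-hot encode `categorical` and z-score `numerical` using training statistics."""
    combined = pd.concat([train, other], keys=["train", "other"])
    # Cast to str so booleans/ints (e.g. flashFired) are treated as categories
    cats = pd.get_dummies(combined[categorical].astype(str), columns=categorical)
    mean = train[numerical].mean()
    std = train[numerical].std(ddof=0).replace(0, 1.0)  # avoid divide-by-zero
    nums = (combined[numerical] - mean) / std
    feats = pd.concat([nums, cats], axis=1).astype(float)
    return feats.loc["train"], feats.loc["other"]

# Toy frames with the notebook's column names; values are made up
train = pd.DataFrame({"camera_model": ["A", "B", "A"], "flashFired": [0, 1, 0],
                      "currTemp": [4000.0, 5000.0, 6000.0], "currTint": [0.0, 10.0, -10.0]})
val = pd.DataFrame({"camera_model": ["C"], "flashFired": [1],
                    "currTemp": [5000.0], "currTint": [0.0]})
X_train, X_val = build_features(train, val, ["camera_model", "flashFired"], ["currTemp", "currTint"])
print(list(X_train.columns))
```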

**4. Training Strategy:**
- Use data augmentation for images while preserving metadata
- Implement early stopping based on validation MAE score
- Consider ensemble methods for improved robustness
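This sketch covers two of these points in plain Python/NumPy, with hypothetical names. Augmentation is limited to flips and 90-degree rotations. The assumption is that colour-altering augmentations (brightness, hue, per-channel gains) would shift the image's colour cast and invalidate its Temperature/Tint labels. Metadata passes through unchanged. Early stopping tracks the validation score 1/(1+MAE) and halts after `patience` epochs without improvement.

```python
import numpy as np

def geometric_augment(image, rng):
    """Random horizontal flip plus 0-3 quarter turns; pixel colours are not modified."""
    if rng.random() < 0.5:
        image = image[:, ::-1]
    return np.rot90(image, k=int(rng.integers(4)))

class EarlyStopping:
    """Stop when the validation score (higher is better) stops improving."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch, val_score):
        """Record one epoch's validation score; return True when training should stop."""
        if val_score > self.best + self.min_delta:
            self.best, self.best_epoch, self.bad_epochs = val_score, epoch, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
aug = geometric_augment(img, rng)

stopper = EarlyStopping(patience=3)
for epoch, s in enumerate([0.10, 0.12, 0.11, 0.115, 0.119, 0.13]):  # made-up scores
    if stopper.step(epoch, s):
        break
print(epoch, stopper.best_epoch, stopper.best)
```

On square 256×256 inputs, rotation preserves the shape. For non-square images it would swap height and width.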

### Next Steps:
1. Implement baseline models using metadata only
2. Develop CNN architectures for image processing
3. Design multi-modal fusion approaches
4. Optimize for the competition metric: 1/(1+MAE)
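For step 1, two metadata-only baselines provide a useful floor. The first is a "no change" baseline that returns the as-shot value. The second applies a per-camera median offset learned from training data. Both were chosen here for illustration, and the helper names are hypothetical. They use the `camera_model`, `currTemp`, and `Temperature` columns from the training CSV, and the same functions apply to `currTint`/`Tint`.

```python
import numpy as np
import pandas as pd

def score(y_true, y_pred):
    """Competition-style score 1 / (1 + MAE)."""
    err = np.abs(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))
    return 1.0 / (1.0 + float(err.mean()))

def fit_offset_baseline(train, target, current, group="camera_model"):
    """Median (target - current) per group, plus a global median for unseen groups."""
    delta = train[target] - train[current]
    return delta.groupby(train[group]).median(), float(delta.median())

def predict_offset_baseline(df, per_group, fallback, current, group="camera_model"):
    """As-shot value plus the group's median offset (global median if group unseen)."""
    return df[current] + df[group].map(per_group).fillna(fallback)

# Made-up rows for illustration
train = pd.DataFrame({"camera_model": ["A", "A", "B"],
                      "currTemp": [5000.0, 5200.0, 4000.0],
                      "Temperature": [5100.0, 5300.0, 3900.0]})
val = pd.DataFrame({"camera_model": ["A", "B", "C"], "currTemp": [5000.0, 4000.0, 6000.0]})

per_cam, global_offset = fit_offset_baseline(train, "Temperature", "currTemp")
pred = predict_offset_baseline(val, per_cam, global_offset, "currTemp")
no_change_score = score(train["Temperature"], train["currTemp"])
print(pred.tolist(), round(no_change_score, 4))
```

Any learned model should beat both of these baselines on the validation score before its extra complexity is worth keeping.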