In [None]:
# ResNet-Based City Predictions with Softmax Output for the Building-to-Parcel Workflow
# Leonard Schrage, l.schrage@northeastern.edu / lschrage@mit.edu, 2024-25

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import json
import os
import random
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
from collections import Counter
import time
from datetime import datetime
from tqdm.notebook import tqdm

# =============================================================================
# Configuration Parameters
# =============================================================================

# Model Configuration
RESNET_MODEL = 'ResNet50'  # Options: 'ResNet18', 'ResNet50'
# Index i of the model output corresponds to CLASS_NAMES[i]
CLASS_NAMES = ['Boston', 'Charlotte', 'Manhattan', 'Pittsburgh']
MODEL_PATH = "models/softmax-ResNet50_50-ep_128-bs_2024-12-06_10-33.pth"

# Input/Output Configuration
# Maps a short label (also used as the output subdirectory name) to an image directory
TEST_FOLDERS = {
    'brooklyn': "/home/ls/sites/re-blocking/image-generation/brooklyn_comparison/parcels"
}

# Keep only configured folders that exist as directories; report and drop the rest
valid_folders = {}
for name, path in TEST_FOLDERS.items():
    if os.path.isdir(path):
        print(f"Found valid path: {path}")
        valid_folders[name] = path
    else:
        print(f"Path not found or not a directory: {path}")
TEST_FOLDERS = valid_folders
if not TEST_FOLDERS:
    # Otherwise the run would complete silently without processing anything
    print("Warning: none of the configured TEST_FOLDERS exist; no images will be processed.")

SAMPLE_SIZE = None  # Number of images to process per folder (None for all images)
OUTPUT_DIR = "softmax-output/city-predictions"  # Base directory for saving results

# Files to ignore (Mac and hidden files)
IGNORE_PATTERNS = {
    '.DS_Store',
    '._',
    '.AppleDouble',
    '.LSOverride',
    'Icon\r',
    '.Spotlight-V100',
    '.Trashes',
    '__MACOSX',
    'thumbs.db',
    'Thumbs.db',
    '.git',
    '.ipynb_checkpoints'
}

# Model Parameters
NUM_AUGMENTATIONS = 5  # Number of augmented passes for test-time augmentation
# Softmax temperature: logits are divided by this before softmax, so values > 1
# soften (flatten) the probabilities and values < 1 sharpen them
TEMPERATURE = 1.5
IMG_SIZE = 224  # Input image size (images are resized to IMG_SIZE x IMG_SIZE)
NORMALIZE_MEAN = [0.485, 0.456, 0.406]  # Per-channel normalization mean
NORMALIZE_STD = [0.229, 0.224, 0.225]  # Per-channel normalization std

# Augmentation Parameters
ROTATION_DEGREES = 10  # Max rotation degrees for augmentation
BRIGHTNESS_JITTER = 0.1  # Brightness adjustment range
TRANSLATION_RANGE = 0.05  # Max translation as fraction of image size

# Visualization Parameters
CONFIDENCE_BINS = 30  # Number of bins for confidence histogram
PLOT_SIZE_LARGE = (12, 6)  # Size for large plots
PLOT_SIZE_MEDIUM = (10, 6)  # Size for medium plots

# =============================================================================
# Model Setup
# =============================================================================

# Choose the compute device: CUDA first, then the MPS backend, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Preprocessing for every image fed to the model: square resize to IMG_SIZE,
# conversion to a tensor, then per-channel normalization.
transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ]
)

def is_valid_image_file(file_path, ignore_patterns=None):
    """Return True if ``file_path`` is a readable image and not a system/hidden file.

    Args:
        file_path: Path (``pathlib.Path`` or str) of the candidate file.
        ignore_patterns: Names to skip; defaults to the module-level
            ``IGNORE_PATTERNS``. The path is rejected when any of its
            components equals one of these names, or, for the ``'._'`` entry,
            starts with it.

    Returns:
        False for ignored entries, names starting with '.' or '_', and files
        Pillow cannot open and verify; True otherwise.
    """
    if ignore_patterns is None:
        ignore_patterns = IGNORE_PATTERNS

    file_path = Path(file_path)
    file_name = file_path.name

    # Match whole path components instead of substrings of the full path string,
    # so '.git' does not reject everything below e.g. 'project.github/' and '._'
    # does not reject names such as 'v1._final.png'. '._' is a name prefix.
    if any(
        part == pattern or (pattern == '._' and part.startswith(pattern))
        for part in file_path.parts
        for pattern in ignore_patterns
    ):
        return False

    # Treat dot- and underscore-prefixed names as hidden/system files
    if file_name.startswith(('.', '_')):
        return False

    # Best-effort check that Pillow can parse the file; any failure means "not an image"
    try:
        with Image.open(file_path) as img:
            img.verify()
        return True
    except Exception:
        return False

def load_model(model_path):
    """Build the ResNet classifier selected by ``RESNET_MODEL`` and load trained weights.

    The backbone's ``fc`` layer is replaced by a head of
    Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear with one output per
    entry in ``CLASS_NAMES``; the checkpoint must match this architecture.

    Args:
        model_path: Checkpoint file loaded with ``torch.load``; either a bare
            state dict or a dict holding it under ``'model_state_dict'``.

    Returns:
        The model moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``RESNET_MODEL`` is not a supported architecture.
        Any error while building the model or loading the checkpoint is
        printed and re-raised.
    """
    # Supported RESNET_MODEL values and their torchvision constructors
    builders = {
        'ResNet18': models.resnet18,
        'ResNet50': models.resnet50,
    }
    try:
        if RESNET_MODEL not in builders:
            raise ValueError(
                f"Unsupported RESNET_MODEL {RESNET_MODEL!r}; expected one of {sorted(builders)}"
            )
        # No pretrained weights: the strict load_state_dict below overwrites every
        # parameter and buffer, so fetching ImageNet weights would only add a download.
        model = builders[RESNET_MODEL](weights=None)

        model.fc = nn.Sequential(
            nn.Linear(model.fc.in_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, len(CLASS_NAMES))
        )

        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        # Accept both a full training checkpoint and a bare state dict
        if 'model_state_dict' in checkpoint:
            checkpoint = checkpoint['model_state_dict']
        model.load_state_dict(checkpoint)

        model = model.to(device)
        model.eval()
        return model
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise

def predict_single_pass(model, image_tensor, temperature=None):
    """Run one forward pass and return temperature-scaled class probabilities.

    Args:
        model: Callable mapping ``image_tensor`` to logits with the batch on
            dim 0 and classes on dim 1.
        image_tensor: Input batch. Only the first sample's probabilities are
            returned, so callers pass a batch of one.
        temperature: Positive divisor applied to the logits before softmax;
            defaults to the module-level ``TEMPERATURE``. Values > 1 flatten
            the distribution, values < 1 sharpen it.

    Returns:
        1-D tensor of probabilities for the first sample.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature is None:
        temperature = TEMPERATURE
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    logits = model(image_tensor)
    return F.softmax(logits / temperature, dim=1)[0]

def predict_with_tta(model, image_path, num_augmentations=NUM_AUGMENTATIONS):
    """Predict the class of one image using test-time augmentation (TTA).

    The un-augmented image is scored once. Then ``num_augmentations`` extra
    passes are made, each applying two randomly chosen augmentations (flip,
    rotation, brightness jitter, affine shift) to the PIL image before the
    standard ``transform`` pipeline. The probability vectors of all passes
    are averaged.

    Args:
        model: Classifier, expected to be in eval mode and on ``device``
            (as returned by ``load_model``).
        image_path: Path to the image file.
        num_augmentations: Number of augmented passes on top of the base pass.

    Returns:
        ``(probabilities, predicted_class, confidence)``: the averaged
        probabilities as a NumPy array, the index of the largest one, and that
        probability as a float. On any error the message is printed and
        ``(None, None, None)`` is returned.
    """
    # Augment in pixel space (PIL), before normalization. On an already-normalized
    # tensor, torchvision's brightness adjustment clamps float values to [0, 1],
    # erasing every negative normalized value. Rotation/affine fills would also
    # land in normalized space instead of being black pixels.
    tta_transforms = [
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.RandomRotation(ROTATION_DEGREES),
        transforms.ColorJitter(brightness=BRIGHTNESS_JITTER),
        transforms.RandomAffine(ROTATION_DEGREES, translate=(TRANSLATION_RANGE, TRANSLATION_RANGE)),
    ]

    try:
        # Context manager releases the file handle once the RGB copy exists
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        def _score(pil_image):
            # Preprocess one PIL view into a batch of one and return its probabilities
            batch = transform(pil_image).unsqueeze(0).to(device)
            return predict_single_pass(model, batch)

        with torch.no_grad():
            # Base prediction on the un-augmented image
            predictions = [_score(image)]

            # TTA predictions: two randomly chosen augmentations per pass
            for _ in range(num_augmentations):
                augmented = image
                for aug in random.sample(tta_transforms, 2):
                    augmented = aug(augmented)
                predictions.append(_score(augmented))

            final_pred = torch.mean(torch.stack(predictions), dim=0)
            predicted_class = torch.argmax(final_pred).item()
            confidence = float(final_pred[predicted_class])

        return final_pred.cpu().numpy(), predicted_class, confidence

    except Exception as e:
        print(f"Error predicting image {image_path}: {str(e)}")
        return None, None, None
    
def analyze_results(results, class_names=None):
    """Summarize a list of prediction records.

    Args:
        results: List of dicts with at least 'predicted_class', 'confidence'
            and 'probabilities' (mapping class name -> probability), as built
            by ``process_folder``.
        class_names: Classes to report probability statistics for; defaults
            to the module-level ``CLASS_NAMES``.

    Returns:
        Dict with 'confidence_stats', 'class_distribution' and
        'probability_stats'. All numbers are plain Python int/float, so the
        result can be passed straight to ``json.dump``. Every section is an
        empty dict when ``results`` is empty.
    """
    if not results:
        return {
            'confidence_stats': {},
            'class_distribution': {},
            'probability_stats': {}
        }
    if class_names is None:
        class_names = CLASS_NAMES

    df = pd.DataFrame(results)
    confidence = df['confidence']

    # pandas .std() is the sample standard deviation (ddof=1); it is NaN for a single result
    confidence_stats = {
        'mean_confidence': float(confidence.mean()),
        'median_confidence': float(confidence.median()),
        'min_confidence': float(confidence.min()),
        'max_confidence': float(confidence.max()),
        'std_confidence': float(confidence.std())
    }

    # Cast counts explicitly: depending on the pandas version, value_counts().to_dict()
    # may yield numpy.int64, which json cannot serialize
    class_distribution = {
        cls: int(count) for cls, count in df['predicted_class'].value_counts().items()
    }

    prob_stats = {}
    for class_name in class_names:
        probs = np.asarray([r['probabilities'][class_name] for r in results], dtype=float)
        prob_stats[class_name] = {
            'mean': float(np.mean(probs)),
            'median': float(np.median(probs)),
            'min': float(np.min(probs)),
            'max': float(np.max(probs)),
            # np.std defaults to the population standard deviation (ddof=0)
            'std': float(np.std(probs))
        }

    return {
        'confidence_stats': confidence_stats,
        'class_distribution': class_distribution,
        'probability_stats': prob_stats
    }

def generate_visualizations(results, output_dir):
    """Write per-folder summary plots for a list of prediction records.

    Three PNG files are saved into ``output_dir``:
    ``confidence_distribution.png`` (histogram of the 'confidence' field),
    ``class_distribution.png`` (count of each 'predicted_class') and
    ``probability_distribution.png`` (box plot of every class probability).
    Nothing is written when ``results`` is empty.
    """
    if not results:
        return

    frame = pd.DataFrame(results)

    def _save_and_close(filename):
        # Persist the current figure under output_dir and release it
        plt.savefig(os.path.join(output_dir, filename))
        plt.close()

    # Histogram of per-image confidence values
    plt.figure(figsize=PLOT_SIZE_MEDIUM)
    sns.histplot(data=frame, x='confidence', bins=CONFIDENCE_BINS)
    plt.title('Distribution of Prediction Confidence')
    plt.xlabel('Confidence')
    plt.ylabel('Count')
    _save_and_close('confidence_distribution.png')

    # How often each class was predicted
    plt.figure(figsize=PLOT_SIZE_MEDIUM)
    sns.countplot(data=frame, x='predicted_class')
    plt.title('Distribution of Predicted Classes')
    plt.xticks(rotation=45)
    plt.tight_layout()
    _save_and_close('class_distribution.png')

    # Long format: one row per (image, class) probability
    long_probs = pd.DataFrame([
        {'class': cls, 'probability': value}
        for record in results
        for cls, value in record['probabilities'].items()
    ])
    plt.figure(figsize=PLOT_SIZE_LARGE)
    sns.boxplot(data=long_probs, x='class', y='probability')
    plt.title('Probability Distribution by Class')
    plt.xticks(rotation=45)
    plt.tight_layout()
    _save_and_close('probability_distribution.png')

def generate_comparative_analysis(all_results, base_output_dir):
    """Compare per-folder results and save comparison plots plus a JSON report.

    Args:
        all_results: Mapping folder name -> {'predictions': list, 'analysis': dict},
            where 'analysis' has the structure returned by ``analyze_results``.
            Folders without predictions are skipped because they carry no
            statistics.
        base_output_dir: Directory (str or Path). Outputs go to its
            ``comparative_analysis`` subdirectory, which is created if needed.

    Returns:
        The report dict that is also written to
        ``comparative_analysis/comparative_report.json``.

    Raises:
        ValueError: If no folder in ``all_results`` has any predictions.
    """
    confidence_rows = []
    class_rows = []
    prob_rows = []
    for folder_name, results in all_results.items():
        if not results['predictions']:
            # analyze_results returns empty stat sections for a folder with no predictions
            continue

        confidence_stats = results['analysis']['confidence_stats']
        confidence_rows.append({
            'folder': folder_name,
            'mean': confidence_stats['mean_confidence'],
            'median': confidence_stats['median_confidence'],
            'std': confidence_stats['std_confidence']
        })

        # One row per class, with 0 for classes never predicted in this folder
        class_dist = results['analysis']['class_distribution']
        class_rows.extend(
            {'folder': folder_name, 'class': class_name, 'count': class_dist.get(class_name, 0)}
            for class_name in CLASS_NAMES
        )

        prob_rows.extend(
            {'folder': folder_name, 'class': class_name, 'probability': prob}
            for pred in results['predictions']
            for class_name, prob in pred['probabilities'].items()
        )

    if not confidence_rows:
        raise ValueError("No folder in all_results has predictions to compare")

    comparative_dir = Path(base_output_dir) / 'comparative_analysis'
    comparative_dir.mkdir(exist_ok=True)

    confidence_df = pd.DataFrame(confidence_rows)
    class_dist_df = pd.DataFrame(class_rows)
    prob_dist_df = pd.DataFrame(prob_rows)

    # 1. Confidence metrics per folder (grouped bars)
    plt.figure(figsize=(15, 8))
    confidence_summary = confidence_df.melt(
        id_vars=['folder'],
        value_vars=['mean', 'median', 'std'],
        var_name='metric'
    )
    sns.barplot(data=confidence_summary, x='folder', y='value', hue='metric')
    plt.title('Confidence Metrics Comparison Across Folders')
    plt.xticks(rotation=45)
    plt.legend(title='Metric', bbox_to_anchor=(1.05, 1))
    plt.tight_layout()
    plt.savefig(comparative_dir / 'confidence_comparison.png', bbox_inches='tight')
    plt.close()

    # 2. Stacked class counts per folder.
    # DataFrame.plot opens a new figure unless it is given an Axes. Creating the
    # sized figure first and drawing into it avoids leaving an empty 15x8 figure
    # open while the chart is rendered at the default size.
    pivot_dist = class_dist_df.pivot(index='folder', columns='class', values='count')
    fig, ax = plt.subplots(figsize=(15, 8))
    pivot_dist.plot(kind='bar', stacked=True, ax=ax)
    ax.set_title('Class Distribution Comparison Across Folders')
    ax.set_xlabel('Folder')
    ax.set_ylabel('Number of Images')
    ax.legend(title='Predicted Class', bbox_to_anchor=(1.05, 1))
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    fig.savefig(comparative_dir / 'class_distribution_comparison.png', bbox_inches='tight')
    plt.close(fig)

    # 3. Mean probability of each class per folder
    plt.figure(figsize=(15, 10))
    pivot_prob = prob_dist_df.groupby(['folder', 'class'])['probability'].mean().unstack()
    sns.heatmap(pivot_prob, annot=True, fmt='.2f', cmap='YlOrRd', cbar_kws={'label': 'Mean Probability'})
    plt.title('Mean Prediction Probability Heatmap')
    plt.ylabel('Folder')
    plt.xlabel('Predicted Class')
    plt.tight_layout()
    plt.savefig(comparative_dir / 'probability_heatmap.png')
    plt.close()

    # Report values are cast to built-in int/float/str so json.dump accepts them
    comparative_report = {
        'confidence_summary': {
            row['folder']: {
                'mean': float(row['mean']),
                'median': float(row['median']),
                'std': float(row['std'])
            }
            for row in confidence_rows
        },
        'class_distribution': {
            folder: {cls: int(count) for cls, count in counts.items()}
            for folder, counts in pivot_dist.to_dict('index').items()
        },
        'probability_matrix': {
            folder: {cls: float(p) for cls, p in probs.items()}
            for folder, probs in pivot_prob.round(3).to_dict('index').items()
        },
        'relative_metrics': {
            'highest_confidence_folder': str(confidence_df.loc[confidence_df['mean'].idxmax(), 'folder']),
            # Smallest spread of per-class counts = most even spread of predictions
            'most_diverse_predictions': str(class_dist_df.groupby('folder')['count'].std().idxmin()),
        }
    }

    with open(comparative_dir / 'comparative_report.json', 'w', encoding='utf-8') as f:
        json.dump(comparative_report, f, indent=4)

    return comparative_report

# Per-folder processing: find valid images and generate predictions for each
def process_folder(folder_name, folder_path, model, sample_size=SAMPLE_SIZE):
    """Run TTA prediction on every valid image found recursively under ``folder_path``.

    Args:
        folder_name: Label used in progress messages.
        folder_path: Directory to search (str or Path); subdirectories are included.
        model: Classifier passed through to ``predict_with_tta``.
        sample_size: If not None, randomly sample at most this many images.

    Returns:
        List of dicts with 'image_path', 'predicted_class', 'confidence' and
        'probabilities' (class name -> float). Images whose prediction failed
        are omitted.
    """
    folder_path = Path(folder_path)
    print(f"Processing folder: {folder_name} ({folder_path})")

    image_suffixes = {'.jpg', '.jpeg', '.png'}
    # Sorted so output order (and seeded sampling) is reproducible; glob order is arbitrary
    image_files = sorted(
        path for path in folder_path.rglob('*')
        if path.is_file()
        and path.suffix.lower() in image_suffixes
        and is_valid_image_file(path)
    )

    # Explicit None check so sample_size=0 is not mistaken for "all images"
    if sample_size is not None and len(image_files) > sample_size:
        image_files = random.sample(image_files, sample_size)

    print(f"Found {len(image_files)} valid images to process")

    results = []
    for img_path in tqdm(image_files, desc=f"Predicting {folder_name}"):
        probs, class_idx, confidence = predict_with_tta(model, img_path)
        if probs is None:
            continue  # predict_with_tta has already printed the error

        results.append({
            'image_path': str(img_path),
            'predicted_class': CLASS_NAMES[class_idx],
            'confidence': confidence,
            'probabilities': {class_name: float(prob) for class_name, prob in zip(CLASS_NAMES, probs)}
        })

    return results

# Set up output directory: every run writes into its own timestamped subfolder
output_dir = Path(OUTPUT_DIR)
output_dir.mkdir(exist_ok=True, parents=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = output_dir / f"run_{timestamp}"
run_dir.mkdir(exist_ok=True)


def _json_default(obj):
    """json.dump fallback: convert NumPy scalars (e.g. numpy.int64) to built-in numbers."""
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Load model once
print("Loading model...")
model = load_model(MODEL_PATH)

# Process each folder
all_results = {}
for folder_name, folder_path in TEST_FOLDERS.items():
    folder_output_dir = run_dir / folder_name
    folder_output_dir.mkdir(exist_ok=True)

    predictions = process_folder(folder_name, folder_path, model)
    analysis = analyze_results(predictions)
    generate_visualizations(predictions, folder_output_dir)

    with open(folder_output_dir / 'predictions.json', 'w', encoding='utf-8') as f:
        json.dump(predictions, f, indent=4, default=_json_default)

    with open(folder_output_dir / 'analysis.json', 'w', encoding='utf-8') as f:
        json.dump(analysis, f, indent=4, default=_json_default)

    # Store results for comparative analysis
    all_results[folder_name] = {
        'predictions': predictions,
        'analysis': analysis
    }

    print(f"Processed {len(predictions)} images from {folder_name}")
    if predictions:
        print(f"Class distribution: {analysis['class_distribution']}")
        print(f"Mean confidence: {analysis['confidence_stats']['mean_confidence']:.4f}")
    else:
        # analyze_results returns empty sections here, so there are no stats to print
        print(f"No predictions were produced for {folder_name}; skipping summary statistics")

# Compare folders only when at least two of them produced predictions
comparable_results = {name: res for name, res in all_results.items() if res['predictions']}
if len(comparable_results) > 1:
    comparative_report = generate_comparative_analysis(comparable_results, run_dir)
    print("Generated comparative analysis across all folders")

print(f"All processing complete. Results saved to {run_dir}")