# 🔮 ThermoSight: Inference Demo & Visualization

This notebook demonstrates how to:
- **Load a trained ThermoSight model** 🧠
- **Perform inference on new microscope images** (single and batch) 🖼️
- **Visualize prediction probabilities** 📊
- **Display results clearly** ✨

---

In [None]:
# Import Required Libraries
import os
import sys
import torch
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import warnings
warnings.filterwarnings('ignore')

# Add src to path
sys.path.append('..')
from src.models.vit_model import ViT # Assuming ViT model is defined here
from src.inference.predict import predict_image # Using the standalone predict_image

# Setup: global plotting style/palette, then a short start-up banner
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("viridis")

print("🔬 Inference Demo Initialized!",
      f"PyTorch version: {torch.__version__}",
      sep="\n")

In [None]:
# Configuration
class InferenceConfig:
    """Fixed settings shared by every cell of the inference demo.

    The architecture values (img_size, patch_size, num_classes) and the order
    of class_names have to agree with the checkpoint that gets loaded.
    """

    # Trained checkpoint, and a folder of labelled test images laid out as
    # <sample_image_dir>/<class name>/<image file> (make sure it exists).
    model_path = os.path.join(os.pardir, 'models', 'best_model.pth')
    sample_image_dir = os.path.join(os.pardir, 'data', 'processed', 'test')

    # Architecture -- all three must match training.
    # NOTE(review): 460 is not a multiple of patch_size 8 (460 / 8 = 57.5);
    # confirm the ViT patch embedding handles the leftover pixels the same
    # way it did during training.
    img_size = 460
    patch_size = 8
    num_classes = 4

    # Position i is the label reported for output logit i.
    class_names = ['200°C', '400°C', '600°C', '800°C']

    # Use the GPU whenever PyTorch can see one.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

config = InferenceConfig()

# Echo the effective settings so the notebook output records what was used.
print("⚙️ Configuration Loaded:",
      f"  Model path: {config.model_path}",
      f"  Sample image directory: {config.sample_image_dir}",
      f"  Device: {config.device}",
      sep="\n")

# A missing checkpoint is only reported here, not treated as fatal: load_model
# then returns None and the inference cells skip their work.
if not os.path.exists(config.model_path):
    print(f"🚨 WARNING: Model file not found at {config.model_path}",
          "💡 Please ensure a trained model ('best_model.pth') is in the '../models/' directory.",
          sep="\n")

In [None]:
# Load Trained Model
def load_model(model_path, img_size, patch_size, num_classes, device, *,
               embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0):
    """Build a ViT and load trained weights into it for inference.

    Args:
        model_path: Path to a checkpoint written with ``torch.save``. Either a
            bare ``state_dict`` or a dict nesting one under
            ``'model_state_dict'`` / ``'state_dict'`` is accepted.
        img_size, patch_size, num_classes: Architecture settings; must match
            the values used in training.
        device: Device the weights are mapped to and the model is moved to.
        embed_dim, depth, num_heads, mlp_ratio: Remaining ViT hyperparameters
            (keyword-only). The defaults are the values this notebook has
            always used -- TODO confirm they match the trained checkpoint.

    Returns:
        The model on ``device`` in eval mode, or ``None`` if anything failed.
        Errors are printed rather than raised so later cells can skip cleanly.
    """
    try:
        model = ViT(
            img_size=img_size,
            patch_size=patch_size,
            num_classes=num_classes,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            dropout=0.0,  # no dropout at inference time
        )
        # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
        checkpoint = torch.load(model_path, map_location=device)

        # Some training scripts save a dict that nests the weights under a
        # wrapper key; unwrap it. A plain state_dict has no such keys.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for wrapper_key in ('model_state_dict', 'state_dict'):
                if isinstance(checkpoint.get(wrapper_key), dict):
                    state_dict = checkpoint[wrapper_key]
                    break

        # Weights saved from a DataParallel-wrapped model carry a 'module.'
        # prefix. Strip it per key, only where it is actually present.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()  # evaluation mode: disables dropout / training-only behaviour
        print(f"✅ Model loaded successfully from {model_path} and set to evaluation mode.")
        return model
    except FileNotFoundError:
        print(f"❌ Error: Model file not found at {model_path}.")
        return None
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return None

# Load the model
model = load_model(config.model_path, config.img_size, config.patch_size, config.num_classes, config.device)

In [None]:
# Image Preprocessing
def preprocess_image(image_path, img_size):
    """Load one image file and turn it into a batched, normalized input tensor.

    The transform (fixed resize to ``img_size`` x ``img_size``, ``ToTensor``,
    then mean/std normalization with the standard ImageNet statistics) is meant
    to mirror the validation/test transform from training -- NOTE(review):
    confirm against the training code.

    Args:
        image_path: Path to the image file; any mode PIL can convert to RGB.
        img_size: Side length, in pixels, of the square model input.

    Returns:
        A tensor of shape ``(1, 3, img_size, img_size)``, or ``None`` if the
        file is missing or cannot be processed (the error is printed).
    """
    try:
        # Open inside a context manager so the file handle is released right
        # away instead of lingering until garbage collection. convert() returns
        # a new in-memory image, so it stays usable after the file is closed.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')

        # Same transformations as used during validation/testing.
        preprocess_transform = transforms.Compose([
            # Fixed square resize. An alternative would be
            # Resize(int(img_size * 256 / 224)) followed by CenterCrop(img_size).
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        img_tensor = preprocess_transform(img)
        return img_tensor.unsqueeze(0)  # add the batch dimension
    except FileNotFoundError:
        print(f"❌ Error: Image file not found at {image_path}.")
        return None
    except Exception as e:
        print(f"❌ Error preprocessing image {image_path}: {e}")
        return None

# Select a sample image: a random class sub-folder, then a random image in it.
# Both names default to None so every failure path below leaves them unset.
sample_image_path = None
input_tensor = None

if not os.path.exists(config.sample_image_dir):
    print(f"🚨 Sample image directory not found: {config.sample_image_dir}")
else:
    try:
        random_class = random.choice([
            entry for entry in os.listdir(config.sample_image_dir)
            if os.path.isdir(os.path.join(config.sample_image_dir, entry))
        ])
        class_path = os.path.join(config.sample_image_dir, random_class)
        image_files = [
            fname for fname in os.listdir(class_path)
            if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp'))
        ]
        if not image_files:
            print(f"🚨 No images found in {class_path}")
        else:
            sample_image_name = random.choice(image_files)
            sample_image_path = os.path.join(class_path, sample_image_name)
            print(f"🖼️ Sample image selected: {sample_image_path}")

            # Turn the file into a model-ready tensor (None on failure).
            input_tensor = preprocess_image(sample_image_path, config.img_size)
            if input_tensor is not None:
                print(f"Processed tensor shape: {input_tensor.shape}")
    except IndexError:
        # random.choice raises IndexError when there are no class sub-folders.
        print(f"🚨 No subdirectories (classes) found in {config.sample_image_dir}")
        sample_image_path = input_tensor = None
    except Exception as e:
        print(f"🚨 Error selecting sample image: {e}")
        sample_image_path = input_tensor = None

In [None]:
# Single Image Inference
def predict_single(model, image_tensor, device, class_names):
    """Classify one preprocessed image tensor.

    Args:
        model: Classifier whose output is treated as logits with the class
            axis at dim 1 -- assumes shape (batch, num_classes). The caller is
            responsible for eval mode (``load_model`` already sets it).
        image_tensor: Preprocessed input with a leading batch dimension of
            size 1, e.g. as returned by ``preprocess_image``.
        device: Device the tensor is moved to before the forward pass.
        class_names: Labels indexed by logit position.

    Returns:
        ``(predicted_class_name, confidence, probabilities)``, where
        ``probabilities`` is a 1-D numpy array of softmax scores. If the model
        or tensor is missing, returns ``(None, None, None)`` so callers can
        always unpack three values.
    """
    if model is None or image_tensor is None:
        print("Model or image tensor is not available for prediction.")
        return None, None, None  # was a 2-tuple, breaking 3-way unpacking

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(image_tensor.to(device))

    # Drop only the batch axis. A bare squeeze() would also collapse the class
    # axis when there is a single class, and the indexing below would then fail.
    probabilities = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()
    predicted_class_idx = int(np.argmax(probabilities))
    predicted_class_name = class_names[predicted_class_idx]
    confidence = probabilities[predicted_class_idx]

    return predicted_class_name, confidence, probabilities

# Run prediction on the sample image.
# Use `is not None` rather than `if model`: a torch Module's truthiness is not a
# reliable "was it loaded" test (container modules such as nn.Sequential define __len__).
if model is not None and input_tensor is not None:
    predicted_class, confidence, probs = predict_single(model, input_tensor, config.device, config.class_names)

    if predicted_class:
        print(f"\n🎯 Prediction for: {os.path.basename(sample_image_path) if sample_image_path else 'N/A'}")
        print(f"  Predicted Class: {predicted_class}")
        print(f"  Confidence: {confidence:.4f}")

        best_idx = int(np.argmax(probs))  # computed once, not per bar
        # The ground truth is the parent folder name, the same convention that
        # predict_batch uses. This avoids depending on `random_class` from the
        # selection cell.
        true_class = os.path.basename(os.path.dirname(sample_image_path)) if sample_image_path else 'N/A'

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # Left panel: the input image. The `with` block closes the file handle.
        # imshow converts the PIL image to an array as it is called.
        if sample_image_path:
            with Image.open(sample_image_path) as img_display:
                # cmap only affects single-channel (grayscale) images; RGB(A)
                # images ignore it. Gray matches what the model sees after
                # preprocess_image's RGB conversion.
                axes[0].imshow(img_display, cmap='gray')
        axes[0].set_title(f"Input Image\nTrue Class: {true_class}", fontsize=12)
        axes[0].axis('off')

        # Right panel: class probabilities, with the predicted class highlighted.
        sns.barplot(x=config.class_names, y=probs, ax=axes[1], palette="viridis")
        axes[1].set_title(f"Prediction Probabilities\nPredicted: {predicted_class} ({confidence:.2f})", fontsize=12)
        axes[1].set_ylabel("Probability", fontsize=10)
        axes[1].set_xlabel("Temperature Class", fontsize=10)
        axes[1].tick_params(axis='x', rotation=45)
        # Headroom above 1.0 keeps the value label of a near-certain bar inside the axes.
        axes[1].set_ylim(0, 1.1)

        # Write the probability above each bar and recolour the winning one.
        for i, (bar, prob) in enumerate(zip(axes[1].patches, probs)):
            axes[1].text(bar.get_x() + bar.get_width() / 2,
                         bar.get_height() + 0.02,
                         f'{prob:.2f}',
                         ha='center', va='bottom', fontsize=9)
            if i == best_idx:
                bar.set_color('orangered')

        plt.tight_layout()
        plt.show()
else:
    print("⚠️ Skipping single image inference as model or input tensor is missing.")


In [None]:
# Batch Inference and Visualization
def predict_batch(model, image_paths, img_size, device, class_names):
    """Classify every image in ``image_paths``, one file at a time.

    Each path goes through ``preprocess_image`` and then ``predict_single``.
    Files that fail preprocessing are skipped (preprocess_image prints why).

    Returns:
        A list of dicts with keys ``'path'``, ``'true_class'`` (the name of the
        image's parent folder), ``'predicted_class'`` and ``'confidence'``.
        Returns an empty list when no model is available.
    """
    if model is None:
        print("Model is not available for batch prediction.")
        return []

    records = []
    for path in image_paths:
        tensor = preprocess_image(path, img_size)
        if tensor is None:
            continue
        label, score, _probs = predict_single(model, tensor, device, class_names)
        if not label:
            continue
        records.append({
            'path': path,
            'true_class': os.path.basename(os.path.dirname(path)),  # folder name doubles as label
            'predicted_class': label,
            'confidence': score,
        })
    return records

# Select a few random images for batch prediction.
# The expected layout is <sample_image_dir>/<class name>/<image file>.
batch_image_paths = []
if os.path.exists(config.sample_image_dir):
    all_image_files = []
    for cls_name in os.listdir(config.sample_image_dir):
        cls_dir_path = os.path.join(config.sample_image_dir, cls_name)
        if not os.path.isdir(cls_dir_path):
            continue  # ignore stray files next to the class folders
        all_image_files.extend(
            os.path.join(cls_dir_path, fname)
            for fname in os.listdir(cls_dir_path)
            if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp'))
        )

    if all_image_files:
        num_batch_samples = min(len(all_image_files), 4)  # predict on up to 4 images
        batch_image_paths = random.sample(all_image_files, num_batch_samples)
        print(f"\nSelected {num_batch_samples} images for batch prediction.")
    else:
        print("No images found for batch prediction.")

# `is not None` rather than relying on the Module's truthiness.
if model is not None and batch_image_paths:
    batch_results = predict_batch(model, batch_image_paths, config.img_size, config.device, config.class_names)

    print("\n📦 Batch Prediction Results:")
    for res in batch_results:
        print(f"  Image: {os.path.basename(res['path'])} | True: {res['true_class']} | Predicted: {res['predicted_class']} | Confidence: {res['confidence']:.3f}")

    # Visualize batch predictions
    if batch_results:
        num_images = len(batch_results)
        # squeeze=False always returns a 2-D array of Axes, so a single result
        # needs no special case.
        fig, axes = plt.subplots(num_images, 1, figsize=(8, 3 * num_images), squeeze=False)

        for ax, res in zip(axes[:, 0], batch_results):
            # Close each file as soon as it has been drawn instead of leaking
            # one handle per image. cmap only affects grayscale images.
            with Image.open(res['path']) as img:
                ax.imshow(img, cmap='gray')
            title = (f"True: {res['true_class']}\n"
                     f"Predicted: {res['predicted_class']} (Conf: {res['confidence']:.2f})")
            ax.set_title(title, fontsize=10)
            ax.axis('off')

        plt.tight_layout()
        plt.show()
else:
    print("⚠️ Skipping batch inference as model or image paths are missing.")

## 🤔 Advanced: Visualizing Model Attention (e.g., Grad-CAM)

To understand *why* the model makes certain predictions, techniques like Grad-CAM can be used to highlight the regions of the image that were most influential. Implementing Grad-CAM for Vision Transformers is more involved than for CNNs. A ViT has no final convolutional feature map, so spatial relevance has to be recovered from the patch tokens. It is nevertheless possible.

**Conceptual Steps for ViT Grad-CAM:**
1.  Hook into the last attention layer or a specific block.
2.  Compute gradients of the predicted class score with respect to the attention map outputs.
3.  Weight the attention maps using these gradients.
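4.  Average the weighted maps across heads and drop the class (`[CLS]`) token if the model uses one. Then reshape the remaining patch tokens into the 2-D patch grid and upsample the resulting heatmap to the input resolution so it can be overlaid on the image.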

This typically requires a library that supports ViT attention visualization, or a custom implementation. For brevity, it is not implemented here, but it is a valuable next step for model interpretability.

---

✅ **Inference Demo Complete!** To classify your own images, pass a single file through `preprocess_image` and then `predict_single`, or pass a list of image paths to `predict_batch`.