In [2]:
import torch
import matplotlib.pyplot as plt
from ipywidgets import interact, widgets, VBox, HBox
import numpy as np
from src.dataset import ProjectDatasets
from src.models.UNet import UNet

ModuleNotFoundError: No module named 'src'

In [None]:
# Step 1: Load Model Output for Visualization
def load_model_outputs(checkpoint_path, model, dataloader, device='cuda'):
    """
    Load a checkpoint into ``model`` and collect its outputs over ``dataloader``.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load``. Either a dict
            holding the weights under ``'model_state_dict'`` or a bare state dict.
        model: The ``torch.nn.Module`` whose weights are restored and evaluated.
        dataloader: Iterable yielding ``(inputs, targets)`` tensor batches.
        device: Device inference runs on (default ``'cuda'``).

    Returns:
        Tuple ``(original_images, ground_truths, predictions)`` of NumPy arrays,
        each concatenated along axis 0 (the batch axis). ``predictions`` is
        ``outputs.argmax(dim=1)``; this assumes the model emits per-class scores
        on axis 1 (e.g. (N, C, H, W) logits) — TODO confirm for the model used.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.to(device)
    model.eval()

    # map_location remaps stored tensors onto `device`; without it a checkpoint
    # saved from a GPU cannot be loaded on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        # Bare state dict, e.g. saved via torch.save(model.state_dict(), path).
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    original_images = []
    ground_truths = []
    predictions = []

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            original_images.append(inputs.cpu().numpy())
            ground_truths.append(targets.cpu().numpy())
            predictions.append(outputs.argmax(dim=1).cpu().numpy())

    if not original_images:
        # np.concatenate would fail with an opaque message on an empty list.
        raise ValueError("dataloader yielded no batches; nothing to visualize")

    # Combine all batches
    original_images = np.concatenate(original_images, axis=0)
    ground_truths = np.concatenate(ground_truths, axis=0)
    predictions = np.concatenate(predictions, axis=0)

    return original_images, ground_truths, predictions

# Step 2: Define Visualization Function
def display_image(index, original_images, ground_truths, predictions):
    """
    Display original image, ground truth, and model prediction side by side.

    Args:
        index: Position of the sample to show in each of the three arrays.
        original_images: Sequence/array of input images. Each 3-D item is assumed
            to be channel-first (C, H, W) — TODO confirm against the dataset; a
            2-D (H, W) item is shown as-is.
        ground_truths: Sequence/array of label maps, one per image.
        predictions: Sequence/array of predicted label maps, one per image.
    """
    image = np.asarray(original_images[index])
    if image.ndim == 3:
        image = image.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C) for imshow
        if image.shape[-1] == 1:
            # Single-channel input: drop the channel axis so imshow treats it
            # as a 2-D scalar image.
            image = image[..., 0]
        elif np.issubdtype(image.dtype, np.floating) and (image.min() < 0 or image.max() > 1):
            # imshow clips float RGB(A) data to [0, 1]; normalised network inputs
            # would render mostly black/white, so rescale them for display only.
            lo, hi = image.min(), image.max()
            image = (image - lo) / (hi - lo) if hi > lo else np.zeros_like(image)
    image_cmap = "gray" if image.ndim == 2 else None

    truth = np.asarray(ground_truths[index])
    pred = np.asarray(predictions[index])
    # Share one colour scale between the two label maps. Otherwise imshow
    # normalises each map to its own min/max and the same class value can get
    # different grey levels in the two panels.
    vmin = min(truth.min(), pred.min())
    vmax = max(truth.max(), pred.max())

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original Image
    axes[0].imshow(image, cmap=image_cmap)
    axes[0].set_title("Original Image")
    axes[0].axis("off")

    # Ground Truth
    axes[1].imshow(truth, cmap="gray", vmin=vmin, vmax=vmax)
    axes[1].set_title("Ground Truth")
    axes[1].axis("off")

    # Model Prediction
    axes[2].imshow(pred, cmap="gray", vmin=vmin, vmax=vmax)
    axes[2].set_title("Model Prediction")
    axes[2].axis("off")

    plt.show()

# Step 3: Interactive Playback with `ipywidgets`
def create_interactive_visualizer(original_images, ground_truths, predictions):
    """
    Create an interactive visualization with ipywidgets.

    Shows a play button and an index slider above an output area. Moving the
    slider, or letting the play button advance it, re-renders ``display_image``
    for the selected index.

    Args:
        original_images: Input images, indexable by sample position.
        ground_truths: Label maps, same length as ``original_images``.
        predictions: Predicted label maps, same length as ``original_images``.

    Raises:
        ValueError: If there are no images, or the three inputs differ in length.
    """
    num_images = len(original_images)
    if num_images == 0:
        # IntSlider(min=0, max=-1) would otherwise fail with a traitlets error.
        raise ValueError("No images to visualize.")
    if len(ground_truths) != num_images or len(predictions) != num_images:
        raise ValueError("original_images, ground_truths and predictions must have the same length.")

    # Slider to navigate images
    slider = widgets.IntSlider(value=0, min=0, max=num_images - 1, step=1, description="Image:")
    # Play defaults to max=100; bind its range to the dataset so playback covers
    # every image and never tries to push the slider past the last index.
    play_button = widgets.Play(value=0, min=0, max=num_images - 1, step=1, interval=500)
    widgets.jslink((play_button, 'value'), (slider, 'value'))  # Link slider to play button

    def _render(index):
        # Render a single sample; the arrays are captured from the enclosing call.
        display_image(index, original_images, ground_truths, predictions)

    # interactive_output (unlike interact) does not display its own copy of the
    # slider, so the controls appear exactly once, above the plot.
    plot_output = widgets.interactive_output(_render, {'index': slider})

    # Combine play button, slider and plot area.
    # `display` is the IPython/Jupyter builtin (not imported in this file).
    controls = HBox([play_button, slider])
    display(VBox([controls, plot_output]))

# Step 4: Main Function to Load Checkpoints and Visualize
def main(checkpoint_path, model, dataloader, device=None):
    """
    Main function to load model checkpoint, evaluate on dataset, and visualize results.

    Args:
        checkpoint_path: Path passed to ``load_model_outputs``.
        model: Model whose weights are restored from the checkpoint.
        dataloader: Iterable of ``(inputs, targets)`` batches to evaluate.
        device: Device for inference. ``None`` (default) selects ``'cuda'`` when
            ``torch.cuda.is_available()`` and ``'cpu'`` otherwise.
    """
    if device is None:
        # Previously load_model_outputs always fell back to 'cuda', which fails
        # on CPU-only machines.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    original_images, ground_truths, predictions = load_model_outputs(
        checkpoint_path, model, dataloader, device=device
    )
    create_interactive_visualizer(original_images, ground_truths, predictions)

# Step 5: Run the Visualizer
if __name__ == "__main__":
    # Resolve the device first: it is needed to place the model below.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Set up paths and model
    checkpoint_path = "../experiments/experiment_20241220-131602/checkpoints/checkpoint_epoch_10.pth"
    model = UNet(in_channels=3, out_channels=30)
    model = model.to(device)
    dataloader = ...  # TODO: replace with the validation or test DataLoader
    if dataloader is Ellipsis:
        # Fail with an explicit message instead of "'ellipsis' object is not iterable".
        raise RuntimeError("Set `dataloader` to a validation/test DataLoader before running the visualizer.")

    # Run the visualizer
    main(checkpoint_path, model, dataloader, device=device)