In [3]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login; in a notebook this renders a token-entry
# widget as the cell output.
# NOTE(review): presumably only needed if the model repository loaded below
# requires authentication — confirm.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
from diffusers import DDPMPipeline, DDPMScheduler, DDIMScheduler
from scipy.spatial.distance import cosine
from tqdm import tqdm


class DiffusionModelTester:
    """Base class for diffusion model testing.

    Loads a ``DDPMPipeline`` and provides the helpers shared by the concrete
    testers: seeded noise generation, a manual denoising loop, cosine
    similarity between images, and image saving.
    """

    # Fallbacks used when the UNet config does not declare its input geometry.
    DEFAULT_SAMPLE_SIZE = 128
    DEFAULT_CHANNELS = 3

    def __init__(self, model_path, device='cuda'):
        """
        Initialize the diffusion model tester.

        Args:
            model_path: Path (or Hub id) of the pretrained diffusion model
            device: Device to run on ('cuda' or 'cpu')
        """
        print(f"Loading pipeline from {model_path}...")
        self.device = device
        self.model_path = model_path

        # Load the model pipeline
        self.pipeline = DDPMPipeline.from_pretrained(model_path).to(device)

        # Make cuDNN operations deterministic (this is a process-wide setting)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Get the UNet model
        self.unet = self.pipeline.unet

        # Determine the noise shape from the UNet config
        self.noise_shape = self._infer_noise_shape(self.unet)
        print(f"Using noise shape: {self.noise_shape}")

    @classmethod
    def _infer_noise_shape(cls, unet):
        """
        Return the ``(1, C, H, W)`` noise shape expected by ``unet``.

        Reads ``unet.config.in_channels`` and ``unet.config.sample_size``
        (an int for square inputs or an ``(H, W)`` pair). Missing or ``None``
        values fall back to ``DEFAULT_CHANNELS`` / ``DEFAULT_SAMPLE_SIZE``.
        """
        config = getattr(unet, 'config', None)
        channels = getattr(config, 'in_channels', None) or cls.DEFAULT_CHANNELS
        size = getattr(config, 'sample_size', None)
        if size is None:
            size = cls.DEFAULT_SAMPLE_SIZE
        if isinstance(size, (tuple, list)):
            height, width = size
        else:
            height = width = size
        return (1, channels, height, width)

    def generate_noise(self, seed):
        """
        Generate a deterministic noise tensor using a given seed.

        Args:
            seed: Random seed

        Returns:
            Standard-normal tensor of shape ``self.noise_shape`` on
            ``self.device`` with the UNet's dtype.
        """
        generator = torch.Generator(device=self.device)
        generator.manual_seed(seed)
        return torch.randn(self.noise_shape, generator=generator,
                           device=self.device, dtype=self.unet.dtype)

    def generate_image(self, noise, scheduler, num_inference_steps):
        """
        Generate an image from noise using the provided scheduler.

        Args:
            noise: Input noise tensor (not modified)
            scheduler: The scheduler to use
            num_inference_steps: Number of denoising steps

        Returns:
            Tuple of (PIL Image, numpy array in [0, 1]) of the generated result.
            The array is HxWxC, or HxW for single-channel models.
        """
        # Clone noise to avoid modifying the original
        sample = noise.clone()

        # Set timesteps for the scheduler
        scheduler.set_timesteps(num_inference_steps)

        # Run the reverse diffusion process
        with torch.no_grad():
            for t in scheduler.timesteps:
                model_input = scheduler.scale_model_input(sample, t)
                noise_pred = self.unet(model_input, t).sample
                sample = scheduler.step(noise_pred, t, sample).prev_sample

        # Convert to image. float(): numpy has no bfloat16 type, so .numpy()
        # would fail for bf16 models.
        sample = sample.detach().float().cpu()
        sample = ((sample + 1) / 2).clamp(0, 1)  # Scale from [-1, 1] to [0, 1]
        array = sample.permute(0, 2, 3, 1).numpy()[0]
        if array.shape[-1] == 1:
            # PIL cannot build an image from an HxWx1 array; use HxW (grayscale).
            array = array[..., 0]
        sample_np = array.copy()
        image = Image.fromarray((array * 255).astype(np.uint8))

        return image, sample_np

    def compute_similarity(self, image_a, image_b):
        """
        Compute cosine similarity between two images.

        Args:
            image_a: First image as numpy array
            image_b: Second image as numpy array (same number of elements)

        Returns:
            Cosine similarity of the flattened images (higher = more similar).
            Undefined for an all-zero image (scipy's cosine divides by the norm).
        """
        # float64 keeps the dot products accurate for large images.
        flat_a = np.asarray(image_a, dtype=np.float64).ravel()
        flat_b = np.asarray(image_b, dtype=np.float64).ravel()
        return 1 - cosine(flat_a, flat_b)

    def save_image(self, image, path):
        """
        Save an image to a specified path, creating parent directories.

        Args:
            image: PIL Image (anything with a ``save(path)`` method)
            path: Path to save the image
        """
        # A bare filename has no directory part; os.makedirs('') would raise.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        image.save(path)
        print(f"Saved image to {path}")


class NoisePerturbationTester(DiffusionModelTester):
    """Tests how diffusion models respond to direct noise perturbations.

    A base noise tensor is perturbed by ``scale * direction`` for several
    scales. Each perturbed tensor is denoised with a DDIM scheduler, and the
    result is compared with the image generated from the unperturbed noise.
    """

    def __init__(self, model_path, device='cuda', output_dir='noise_stability_results'):
        """
        Initialize the noise perturbation tester.

        Args:
            model_path: Path to the diffusion model
            device: Device to run on
            output_dir: Directory to save results
        """
        super().__init__(model_path, device)
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        # Use DDIM scheduler for faster sampling
        self.scheduler = DDIMScheduler.from_config(self.pipeline.scheduler.config)
        # NOTE(review): DDIMScheduler takes eta as an argument of step(), not as an
        # attribute. generate_image() calls step() without eta, so the library
        # default applies (0.0 per the diffusers docs — confirm for your version).
        # The attribute is kept for backward compatibility.
        self.scheduler.eta = 0.0

    def generate_perturbed_noise(self, base_noise, perturbation_scales, samples_per_scale=1, seed=42):
        """
        Generate perturbed versions of base noise by adding controlled amounts of random noise.

        Sample ``i`` uses the same random direction (seeded with ``seed + i``) at
        every scale, so the scales differ only in magnitude. If ``base_noise`` is
        standard normal, ``base + scale * N(0, I)`` has per-element variance
        ``1 + scale**2``. Large scales therefore also inflate the noise norm.

        Args:
            base_noise: Base noise tensor to perturb
            perturbation_scales: Iterable of scales to test
            samples_per_scale: Number of samples per scale
            seed: Random seed for reproducibility

        Returns:
            Dictionary mapping scales to lists of perturbed noise tensors
        """
        print("Generating perturbed noise tensors...")

        # One deterministic direction per sample index, shared by all scales,
        # so it is drawn once rather than once per scale.
        directions = []
        for i in range(samples_per_scale):
            gen = torch.Generator(device=self.device)
            gen.manual_seed(seed + i)
            directions.append(torch.randn(base_noise.shape, generator=gen,
                                          device=self.device, dtype=base_noise.dtype))

        # `+` allocates a new tensor, so base_noise itself is never modified.
        return {
            scale: [base_noise + direction * scale for direction in directions]
            for scale in perturbation_scales
        }

    def test_noise_perturbation(self, perturbation_scales=(0, 0.05, 0.1, 0.5),
                                samples_per_scale=1, base_seed=42, num_inference_steps=50):
        """
        Test the model's stability to direct noise perturbations.

        Args:
            perturbation_scales: Iterable of perturbation scales to test
                (a zero scale is always added as a control)
            samples_per_scale: Number of samples per scale
            base_seed: Random seed for base noise
            num_inference_steps: Number of diffusion steps

        Returns:
            Dictionary mapping each scale to its metrics (see compute_metrics)
        """
        print("\nRunning noise perturbation test...")

        # Generate base noise tensor deterministically
        base_noise = self.generate_noise(base_seed)

        # Work on a copy so the caller's sequence is never modified; make sure
        # the zero-perturbation control is included.
        scales = list(perturbation_scales)
        if 0 not in scales:
            scales.insert(0, 0)

        # Generate perturbed versions. Zero is handled separately below; any
        # non-zero scale (including negative ones) gets a perturbation.
        perturbed_dict = self.generate_perturbed_noise(
            base_noise,
            perturbation_scales=[s for s in scales if s != 0],
            samples_per_scale=samples_per_scale,
            seed=base_seed + 100  # Use different seed from base noise
        )

        # Scale 0: regenerate from the untouched base noise (reproducibility control)
        perturbed_dict[0] = [base_noise.clone() for _ in range(samples_per_scale)]

        # First, generate the reference image from the base noise
        print("\nGenerating reference image from base noise...")
        reference_image, reference_np = self.generate_image(base_noise, self.scheduler, num_inference_steps)
        self.save_image(reference_image, f"{self.output_dir}/reference_image.png")

        results = {}
        for scale in scales:
            print(f"\nProcessing perturbation scale {scale}...")

            # The reference goes first; compute_metrics compares everything to it.
            scale_images = [reference_image]
            scale_np_images = [reference_np]

            for noise in perturbed_dict[scale]:
                image, np_image = self.generate_image(noise, self.scheduler, num_inference_steps)
                scale_images.append(image)
                scale_np_images.append(np_image)

            metrics = self.compute_metrics(scale_np_images)
            self.visualize_results(scale_np_images, metrics, f"scale_{scale}")
            results[scale] = metrics

        return results

    def compute_metrics(self, images):
        """
        Compute similarity metrics between images.

        Args:
            images: List of numpy arrays; images[0] is the reference

        Returns:
            Dictionary with per-image 'cosine_similarity' (vs. the reference),
            plus 'mean_cosine' and 'std_cosine' (0 when there is nothing to compare)
        """
        reference = images[0]
        similarities = [self.compute_similarity(reference, other) for other in images[1:]]

        return {
            'cosine_similarity': similarities,
            'mean_cosine': np.mean(similarities) if similarities else 0,
            'std_cosine': np.std(similarities) if similarities else 0,
        }

    def visualize_results(self, images, metrics, test_name):
        """
        Visualize the generated images and their metrics.

        Saves ``comparison_<test_name>.png`` in the output directory showing up
        to five images (the first is the reference).

        Args:
            images: List of images as numpy arrays
            metrics: Dictionary of metrics
            test_name: Name for saving files
        """
        n_imgs = min(len(images), 5)  # Limit to 5
        fig, axes = plt.subplots(1, n_imgs, figsize=(4*n_imgs, 4))
        if n_imgs == 1:
            axes = [axes]

        for i, (ax, img) in enumerate(zip(axes, images[:n_imgs])):
            ax.imshow(img)
            if i == 0:
                ax.set_title("Reference")
            else:
                metrics_text = ""
                if 'cosine_similarity' in metrics and len(metrics['cosine_similarity']) >= i:
                    metrics_text += f"CosSim: {metrics['cosine_similarity'][i-1]:.4f}\n"
                ax.set_title(f"Sample {i}\n{metrics_text}")
            ax.axis('off')

        plt.tight_layout()
        plt.savefig(f"{self.output_dir}/comparison_{test_name}.png")
        plt.close(fig)


class SchedulerStabilityTester(DiffusionModelTester):
    """Compares stability between DDPM and DDIM schedulers.

    Two experiments are provided:
    - repeated generation from identical noise (reproducibility), and
    - generation from slightly perturbed noise (perturbation stability).
    Each experiment uses a DDPM and a DDIM scheduler, both built from the
    pipeline's scheduler config.
    """

    def __init__(self, model_path, device='cuda', output_dir='stability_test_results'):
        """
        Initialize the scheduler stability tester.

        Args:
            model_path: Path to the pretrained diffusion model
            device: Device to run on
            output_dir: Directory to save results
        """
        super().__init__(model_path, device)
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def _make_schedulers(self):
        """Return a fresh (DDPM, DDIM) scheduler pair built from the pipeline's scheduler config."""
        ddpm_scheduler = DDPMScheduler.from_config(self.pipeline.scheduler.config)
        ddim_scheduler = DDIMScheduler.from_config(self.pipeline.scheduler.config)
        # NOTE(review): DDIM's eta is an argument of step(), not an attribute;
        # generate_image() does not pass it, so the library default applies.
        ddim_scheduler.eta = 0.0
        return ddpm_scheduler, ddim_scheduler

    def test_reproducibility(self, base_seed=42, num_runs=5, ddpm_steps=1000, ddim_steps=50):
        """
        Test how reproducible the results are when using the same noise multiple times.

        generate_image() passes no generator to ``scheduler.step``. Any noise a
        scheduler draws inside step() (presumably DDPM's variance term) is
        therefore not controlled by ``base_seed``, which is what this test exposes.

        Args:
            base_seed: Base random seed
            num_runs: Number of times to run the generation process
                (needs >= 2 for pairwise statistics to be defined)
            ddpm_steps: Number of steps for DDPM
            ddim_steps: Number of steps for DDIM

        Returns:
            Dict with pairwise similarities and the generated images per scheduler
        """
        print("\n=== Testing Reproducibility ===")

        base_noise = self.generate_noise(base_seed)
        ddpm_scheduler, ddim_scheduler = self._make_schedulers()

        ddpm_images = []
        ddim_images = []
        ddpm_np_images = []
        ddim_np_images = []

        # Run multiple times with the same noise
        for run in range(num_runs):
            print(f"Run {run+1}/{num_runs}:")

            print(f"  Generating with DDPM ({ddpm_steps} steps)...")
            ddpm_img, ddpm_np = self.generate_image(base_noise, ddpm_scheduler, ddpm_steps)
            self.save_image(ddpm_img, f"{self.output_dir}/reproducibility_ddpm_run{run}.png")
            ddpm_images.append(ddpm_img)
            ddpm_np_images.append(ddpm_np)

            print(f"  Generating with DDIM ({ddim_steps} steps)...")
            ddim_img, ddim_np = self.generate_image(base_noise, ddim_scheduler, ddim_steps)
            self.save_image(ddim_img, f"{self.output_dir}/reproducibility_ddim_run{run}.png")
            ddim_images.append(ddim_img)
            ddim_np_images.append(ddim_np)

        # Compare all unordered pairs of runs
        ddpm_similarities = []
        ddim_similarities = []
        for i in range(num_runs):
            for j in range(i + 1, num_runs):
                ddpm_similarities.append(self.compute_similarity(ddpm_np_images[i], ddpm_np_images[j]))
                ddim_similarities.append(self.compute_similarity(ddim_np_images[i], ddim_np_images[j]))

        print("\nSimilarity between runs (cosine similarity, higher = more similar):")
        print(f"  DDPM: mean={np.mean(ddpm_similarities):.4f}, std={np.std(ddpm_similarities):.4f}")
        print(f"  DDIM: mean={np.mean(ddim_similarities):.4f}, std={np.std(ddim_similarities):.4f}")

        self.visualize_reproducibility(ddpm_images, ddim_images, ddpm_similarities, ddim_similarities)

        return {
            'ddpm_similarities': ddpm_similarities,
            'ddim_similarities': ddim_similarities,
            'ddpm_images': ddpm_images,
            'ddim_images': ddim_images
        }

    def test_perturbation_stability(self, base_seed=42, perturbation_scales=(0.01, 0.05, 0.1),
                                    ddpm_steps=1000, ddim_steps=50):
        """
        Test how stable the results are to small perturbations in the input noise.

        Every scale uses the same perturbation direction (seeded with
        ``base_seed + 100``); only its magnitude changes.

        Args:
            base_seed: Base random seed
            perturbation_scales: Sequence of scales for noise perturbation
            ddpm_steps: Number of steps for DDPM
            ddim_steps: Number of steps for DDIM

        Returns:
            Dict keyed by 'ddpm' / 'ddim' holding the images (reference first),
            their numpy arrays, and the per-scale similarities to the reference
        """
        print("\n=== Testing Perturbation Stability ===")

        base_noise = self.generate_noise(base_seed)
        ddpm_scheduler, ddim_scheduler = self._make_schedulers()

        # Generate reference images with unperturbed noise
        print("Generating reference images with unperturbed noise:")
        ddpm_ref_img, ddpm_ref_np = self.generate_image(base_noise, ddpm_scheduler, ddpm_steps)
        self.save_image(ddpm_ref_img, f"{self.output_dir}/perturbation_ddpm_ref.png")

        ddim_ref_img, ddim_ref_np = self.generate_image(base_noise, ddim_scheduler, ddim_steps)
        self.save_image(ddim_ref_img, f"{self.output_dir}/perturbation_ddim_ref.png")

        results = {
            'ddpm': {'images': [ddpm_ref_img], 'np_images': [ddpm_ref_np], 'similarities': []},
            'ddim': {'images': [ddim_ref_img], 'np_images': [ddim_ref_np], 'similarities': []}
        }

        # The direction depends only on a fixed seed, so draw it once.
        pert_gen = torch.Generator(device=self.device)
        pert_gen.manual_seed(base_seed + 100)  # Different seed for perturbation
        direction = torch.randn(self.noise_shape, generator=pert_gen,
                                device=self.device, dtype=self.unet.dtype)

        for scale in perturbation_scales:
            print(f"\nTesting perturbation scale {scale}:")
            perturbed_noise = base_noise + direction * scale

            print(f"  Generating with DDPM ({ddpm_steps} steps)...")
            ddpm_img, ddpm_np = self.generate_image(perturbed_noise, ddpm_scheduler, ddpm_steps)
            self.save_image(ddpm_img, f"{self.output_dir}/perturbation_ddpm_scale{scale}.png")
            ddpm_sim = self.compute_similarity(ddpm_ref_np, ddpm_np)
            results['ddpm']['images'].append(ddpm_img)
            results['ddpm']['np_images'].append(ddpm_np)
            results['ddpm']['similarities'].append(ddpm_sim)

            print(f"  Generating with DDIM ({ddim_steps} steps)...")
            ddim_img, ddim_np = self.generate_image(perturbed_noise, ddim_scheduler, ddim_steps)
            self.save_image(ddim_img, f"{self.output_dir}/perturbation_ddim_scale{scale}.png")
            ddim_sim = self.compute_similarity(ddim_ref_np, ddim_np)
            results['ddim']['images'].append(ddim_img)
            results['ddim']['np_images'].append(ddim_np)
            results['ddim']['similarities'].append(ddim_sim)

            print(f"  Similarity to reference (cosine similarity, higher = more similar):")
            print(f"    DDPM: {ddpm_sim:.4f}")
            print(f"    DDIM: {ddim_sim:.4f}")

        self.visualize_perturbation_stability(results, perturbation_scales)

        return results

    def visualize_reproducibility(self, ddpm_images, ddim_images, ddpm_similarities, ddim_similarities):
        """
        Visualize the reproducibility test results.

        Writes ``reproducibility_comparison.png`` (image grid) and
        ``reproducibility_stats.png`` (similarity boxplots) to the output directory.

        Args:
            ddpm_images: List of DDPM-generated images
            ddim_images: List of DDIM-generated images
            ddpm_similarities: List of similarity scores between DDPM images
            ddim_similarities: List of similarity scores between DDIM images
        """
        num_runs = len(ddpm_images)

        # Image grid: DDPM on the top row, DDIM on the bottom row
        fig = plt.figure(figsize=(10, 8))
        for i in range(num_runs):
            ax = fig.add_subplot(2, num_runs, i + 1)
            ax.imshow(np.array(ddpm_images[i]))
            ax.set_title(f"DDPM Run {i+1}")
            ax.axis('off')
        for i in range(num_runs):
            ax = fig.add_subplot(2, num_runs, i + 1 + num_runs)
            ax.imshow(np.array(ddim_images[i]))
            ax.set_title(f"DDIM Run {i+1}")
            ax.axis('off')

        plt.tight_layout()
        plt.savefig(f"{self.output_dir}/reproducibility_comparison.png")
        plt.close(fig)

        # Similarity statistics
        fig, ax = plt.subplots(figsize=(10, 6))

        box_data = [ddpm_similarities, ddim_similarities]
        box = ax.boxplot(box_data, patch_artist=True)

        colors = ['lightblue', 'lightgreen']
        for patch, color in zip(box['boxes'], colors):
            patch.set_facecolor(color)

        # Jittered individual points. A local RNG keeps the plot deterministic
        # and leaves numpy's global RNG state untouched.
        jitter_rng = np.random.default_rng(0)
        for i, data in enumerate(box_data):
            x = jitter_rng.normal(i + 1, 0.05, len(data))
            ax.scatter(x, data, alpha=0.5)

        ax.set_title("Image Similarity Between Runs (Higher = More Similar)")
        ax.set_ylabel("Cosine Similarity")
        ax.set_xticklabels(['DDPM', 'DDIM'])
        ax.set_ylim([0, 1.05])
        ax.grid(True, linestyle='--', alpha=0.7)

        ax.text(1, 0.1, f"Mean: {np.mean(ddpm_similarities):.4f}\nStd: {np.std(ddpm_similarities):.4f}",
                bbox=dict(facecolor='white', alpha=0.5))
        ax.text(2, 0.1, f"Mean: {np.mean(ddim_similarities):.4f}\nStd: {np.std(ddim_similarities):.4f}",
                bbox=dict(facecolor='white', alpha=0.5))

        plt.tight_layout()
        plt.savefig(f"{self.output_dir}/reproducibility_stats.png")
        plt.close(fig)

    def visualize_perturbation_stability(self, results, perturbation_scales):
        """
        Visualize the perturbation stability test results.

        Writes ``perturbation_comparison.png`` (image grid) and
        ``perturbation_stability_graph.png`` (similarity vs. scale) to the output directory.

        Args:
            results: Dictionary with test results (images hold the reference first)
            perturbation_scales: Sequence of perturbation scales tested
        """
        # Number of columns: the reference plus one per scale
        num_scales = len(perturbation_scales) + 1

        fig = plt.figure(figsize=(12, 6))
        for row, label in enumerate(('ddpm', 'ddim')):
            name = label.upper()
            for i in range(num_scales):
                ax = fig.add_subplot(2, num_scales, i + 1 + row * num_scales)
                ax.imshow(np.array(results[label]['images'][i]))
                if i == 0:
                    ax.set_title(f"{name}\nReference")
                else:
                    ax.set_title(f"{name}\nScale {perturbation_scales[i-1]}")
                ax.axis('off')

        plt.tight_layout()
        plt.savefig(f"{self.output_dir}/perturbation_comparison.png")
        plt.close(fig)

        # Similarity vs perturbation scale
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(perturbation_scales, results['ddpm']['similarities'], 'o-', color='blue', label='DDPM')
        ax.plot(perturbation_scales, results['ddim']['similarities'], 'o-', color='green', label='DDIM')

        ax.set_title("Similarity to Reference Image vs. Perturbation Scale")
        ax.set_xlabel("Perturbation Scale")
        ax.set_ylabel("Cosine Similarity (higher = more similar)")
        ax.set_ylim([0, 1.05])
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.legend()

        # Value callouts: DDPM above each point, DDIM below
        for i, scale in enumerate(perturbation_scales):
            ax.annotate(f"{results['ddpm']['similarities'][i]:.4f}",
                        (scale, results['ddpm']['similarities'][i]),
                        textcoords="offset points", xytext=(0, 10), ha='center')
            ax.annotate(f"{results['ddim']['similarities'][i]:.4f}",
                        (scale, results['ddim']['similarities'][i]),
                        textcoords="offset points", xytext=(0, -15), ha='center')

        plt.tight_layout()
        plt.savefig(f"{self.output_dir}/perturbation_stability_graph.png")
        plt.close(fig)

class DiffusionInterpolationTester(DiffusionModelTester):
    """Performs spherical (slerp) interpolation between different noise vectors."""

    def __init__(self, model_path, device='cuda', output_dir='interpolation_results'):
        """
        Initialize the interpolation test for diffusion models.

        Args:
            model_path: Path to the pretrained diffusion model
            device: Device to run on
            output_dir: Directory to save results
        """
        super().__init__(model_path, device)
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        # Use DDIM scheduler for faster sampling
        self.scheduler = DDIMScheduler.from_config(self.pipeline.scheduler.config)
        # NOTE(review): DDIM's eta is an argument of step(), not an attribute;
        # generate_image() does not pass it, so the library default applies.
        self.scheduler.eta = 0.0
        self.pipeline.scheduler = self.scheduler

    @staticmethod
    def _visualization_indices(total_steps, visualization_steps):
        """Return ascending, evenly spaced step indices in [0, total_steps) that include both ends."""
        if total_steps <= 0 or visualization_steps <= 0:
            return []
        if visualization_steps >= total_steps:
            return list(range(total_steps))
        if visualization_steps == 1:
            return [0]
        step_size = (total_steps - 1) / (visualization_steps - 1)
        return [round(i * step_size) for i in range(visualization_steps)]

    @staticmethod
    def _slerp_angle(noise_a, noise_b):
        """Return the angle between two noise tensors as a float32 scalar tensor."""
        # float32: in fp16 the sum of ~C*H*W squared terms can overflow.
        a = noise_a.float()
        b = noise_b.float()
        cos_omega = (a * b).sum() / (torch.norm(a) * torch.norm(b))
        # Rounding can push |cos| slightly above 1, where acos returns NaN.
        return torch.acos(cos_omega.clamp(-1.0, 1.0))

    @staticmethod
    def _slerp(noise_a, noise_b, t, omega, eps=1e-6):
        """Interpolate spherically at weight t, falling back to linear when sin(omega) ~ 0."""
        a = noise_a.float()
        b = noise_b.float()
        sin_omega = torch.sin(omega)
        if sin_omega.abs().item() < eps:
            # (Anti-)parallel endpoints: the slerp weights would divide by ~0.
            mixed = (1 - t) * a + t * b
        else:
            mixed = (torch.sin((1 - t) * omega) / sin_omega) * a + (torch.sin(t * omega) / sin_omega) * b
        return mixed.to(noise_a.dtype)

    def interpolate_between_noises(self, seed_a=42, seed_b=43, total_steps=100,
                                   visualization_steps=10, inference_steps=50):
        """
        Generate a sequence of images by interpolating between two noise vectors.

        ``total_steps`` sets the resolution of the interpolation grid
        (t = step / (total_steps - 1)). Images are generated only for
        ``visualization_steps`` evenly spaced steps, including both endpoints.
        Each image is also saved to the output directory.

        Args:
            seed_a: Random seed for the first noise
            seed_b: Random seed for the second noise
            total_steps: Number of points on the interpolation grid
            visualization_steps: Number of grid points to actually render
            inference_steps: Number of denoising steps per image

        Returns:
            List of (t, PIL Image) tuples in increasing t
        """
        print(f"Interpolating between seed {seed_a} and seed {seed_b} with {total_steps} steps...")
        print(f"Will visualize {visualization_steps} evenly spaced images")

        noise_a = self.generate_noise(seed_a)
        noise_b = self.generate_noise(seed_b)

        visualization_indices = self._visualization_indices(total_steps, visualization_steps)

        # The angle between the endpoints does not depend on t; compute it once.
        omega = self._slerp_angle(noise_a, noise_b)

        visualization_images = []
        # Only the rendered steps are iterated; skipped steps would do no work.
        for step in tqdm(visualization_indices):
            t = step / (total_steps - 1) if total_steps > 1 else 0.0

            interp_noise = self._slerp(noise_a, noise_b, t, omega)
            img, _ = self.generate_image(interp_noise, self.scheduler, inference_steps)
            visualization_images.append((t, img))

            img.save(f"{self.output_dir}/interp_seed{seed_a}_to_seed{seed_b}_t{t:.3f}.png")

        return visualization_images

    def create_interpolation_grid(self, images_with_t, seed_a, seed_b, cols=5):
        """
        Create a grid visualization of the interpolation sequence.

        Args:
            images_with_t: List of (t, img) tuples where t is the interpolation parameter
            seed_a: First seed
            seed_b: Second seed
            cols: Number of columns in the grid

        Returns:
            The (already closed) matplotlib Figure, or None if there was nothing to plot
        """
        num_images = len(images_with_t)
        if num_images == 0:
            print("No interpolation images to plot; skipping grid.")
            return None
        rows = (num_images + cols - 1) // cols

        # squeeze=False always yields a 2-D array of Axes, whatever rows/cols are.
        fig, axes = plt.subplots(rows, cols, figsize=(cols*3, rows*3), squeeze=False)

        # axes.flat is row-major; slots beyond num_images stay blank.
        for idx, ax in enumerate(axes.flat):
            ax.axis('off')
            if idx < num_images:
                t, img = images_with_t[idx]
                ax.imshow(np.array(img))
                ax.set_title(f"t={t:.3f}")

        grid_path = f"{self.output_dir}/grid_seed{seed_a}_to_seed{seed_b}.png"
        fig.suptitle(f"Interpolation from Seed {seed_a} to Seed {seed_b}")
        fig.tight_layout()
        fig.savefig(grid_path, dpi=300)
        plt.close(fig)

        print(f"Saved visualization grid to {grid_path}")

        return fig

def main():
    """Run the diffusion model stability tests selected on the command line.

    Options: ``--model`` (model path / Hub id), ``--output`` (results root),
    ``--tests`` (one of all/noise/scheduler/interpolation) and ``--device``.
    Each test family writes into its own subdirectory of ``--output``.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Diffusion Model Stability Testing')
    parser.add_argument('--model', type=str, default="otausendschoen/ddpm-our-faces-reduced",
                        help='Path to the diffusion model')
    parser.add_argument('--output', type=str, default="results",
                        help='Output directory for results')
    # choices: a typo such as "--tests noize" used to run nothing silently.
    parser.add_argument('--tests', type=str, default="all",
                        choices=["all", "noise", "scheduler", "interpolation"],
                        help='Tests to run: "all", "noise", "scheduler", or "interpolation"')
    parser.add_argument('--device', type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help='Device to use (cuda or cpu)')

    # parse_known_args: when run inside a Jupyter kernel, sys.argv carries the
    # kernel launcher's own flags (typically "-f <connection file>"), and
    # parse_args() would abort with SystemExit on them.
    args, unknown = parser.parse_known_args()
    if unknown:
        print(f"Ignoring unrecognized arguments: {unknown}")

    os.makedirs(args.output, exist_ok=True)

    if args.tests in ["all", "noise"]:
        noise_dir = os.path.join(args.output, "noise_perturbation")
        tester = NoisePerturbationTester(args.model, args.device, output_dir=noise_dir)
        tester.test_noise_perturbation()

    if args.tests in ["all", "scheduler"]:
        scheduler_dir = os.path.join(args.output, "scheduler_stability")
        tester = SchedulerStabilityTester(args.model, args.device, output_dir=scheduler_dir)
        tester.test_reproducibility()
        tester.test_perturbation_stability()

    if args.tests in ["all", "interpolation"]:
        interp_dir = os.path.join(args.output, "interpolation")
        tester = DiffusionInterpolationTester(args.model, args.device, output_dir=interp_dir)
        # Two seed pairs; each sequence also gets a summary grid image.
        for seed_a, seed_b in [(42, 43), (100, 201)]:
            images = tester.interpolate_between_noises(seed_a=seed_a,
                                                       seed_b=seed_b,
                                                       total_steps=500,
                                                       visualization_steps=10,
                                                       inference_steps=50)
            tester.create_interpolation_grid(images, seed_a, seed_b)

    print("\nAll tests completed!")
    print(f"Results saved to: {args.output}")

if __name__ == "__main__":
    # Entry point: runs main() when executed as a script. A notebook cell also
    # executes with __name__ == "__main__", so this fires there too.
    main()