# In [None]:
import numpy as np
import torch
import gc
import time
from torch.utils.data import DataLoader, random_split
from hydra import compose, initialize
from ntd.diffusion_model import Diffusion, Trainer
from ntd.networks import AdaConv
from ntd.utils.kernels_and_diffusion_utils import OUProcess
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import threading
from torch.utils.data import Subset
import numpy as np
import torch
from scipy import signal
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from tqdm import tqdm
from torch.utils.data import DataLoader
import os
import time

def temporal_train_test_split(dataset, train_ratio=0.7):
    """Split ``dataset`` into a leading training part and a trailing test part.

    Item order is preserved: the first ``int(len(dataset) * train_ratio)``
    items form the training subset and all remaining items the test subset,
    so no test item comes before a training item.

    Args:
        dataset: Any sized, indexable dataset accepted by ``Subset``.
        train_ratio: Fraction of items assigned to training (default: 0.7).

    Returns:
        Tuple ``(train_dataset, test_dataset)`` of ``Subset`` objects.
    """
    n_items = len(dataset)
    cut = int(n_items * train_ratio)

    head = list(range(cut))
    tail = list(range(cut, n_items))

    report = (
        "\nTemporal Split Information:",
        f"Total samples: {n_items}",
        f"Training samples: {len(head)} (indices 0 to {cut-1})",
        f"Testing samples: {len(tail)} (indices {cut} to {n_items-1})",
    )
    for line in report:
        print(line)

    return Subset(dataset, head), Subset(dataset, tail)
    
class ECoGDBSDataset(Dataset):
    """Map-style dataset pairing ECoG epochs (condition) with DBS epochs (target).

    Each item is a dict ``{"signal": dbs, "cond": ecog}`` of float32 tensors of
    shape ``(1, sequence_length)`` and ``(3, sequence_length)``. Both are
    z-scored per channel with statistics computed once over *all* epochs passed
    to the constructor.

    NOTE(review): because the statistics include epochs that a later
    temporal train/test split assigns to the test set, test information leaks
    into the normalisation; consider computing them on training epochs only.
    """

    # Channel counts enforced per item (train_ecog_dbs_model builds the
    # network with signal_channel=1 and cond_dim=3).
    DBS_CHANNELS = 1
    ECOG_CHANNELS = 3

    def __init__(self, ecog_data, dbs_data):
        """Validate the inputs, store them as float32 and compute z-score stats.

        Args:
            ecog_data: array or tensor of shape (num_epochs, 3, sequence_length)
            dbs_data: array or tensor of shape (num_epochs, 1, sequence_length)

        Raises:
            ValueError: if either input is not 3-D, or the two inputs differ in
                number of epochs or sequence length.
        """
        # Tensors may live on the GPU or carry autograd history; .numpy() alone
        # would fail on those, so detach and move to CPU first.
        if isinstance(ecog_data, torch.Tensor):
            ecog_data = ecog_data.detach().cpu().numpy()
        if isinstance(dbs_data, torch.Tensor):
            dbs_data = dbs_data.detach().cpu().numpy()

        print(f"Initial shapes:")
        print(f"ECoG data: {ecog_data.shape}")
        print(f"DBS data: {dbs_data.shape}")

        # Ensure correct dimensions
        if len(ecog_data.shape) != 3:
            raise ValueError(f"ECoG data should be 3D, got shape {ecog_data.shape}")
        if len(dbs_data.shape) != 3:
            raise ValueError(f"DBS data should be 3D, got shape {dbs_data.shape}")

        # Ensure matching dimensions
        if ecog_data.shape[0] != dbs_data.shape[0]:
            raise ValueError(f"Number of epochs don't match: {ecog_data.shape[0]} vs {dbs_data.shape[0]}")
        if ecog_data.shape[2] != dbs_data.shape[2]:
            raise ValueError(f"Sequence lengths don't match: {ecog_data.shape[2]} vs {dbs_data.shape[2]}")

        self.ecog_data = np.ascontiguousarray(ecog_data, dtype=np.float32)
        self.dbs_data = np.ascontiguousarray(dbs_data, dtype=np.float32)

        print(f"\nProcessed shapes:")
        print(f"ECoG data: {self.ecog_data.shape}")
        print(f"DBS data: {self.dbs_data.shape}")

        # Per-channel statistics over epochs and time; keepdims -> shape (1, C, 1).
        self.ecog_mean = np.mean(self.ecog_data, axis=(0, 2), keepdims=True)
        self.ecog_std = np.std(self.ecog_data, axis=(0, 2), keepdims=True)
        self.dbs_mean = np.mean(self.dbs_data, axis=(0, 2), keepdims=True)
        self.dbs_std = np.std(self.dbs_data, axis=(0, 2), keepdims=True)

        # Constant channels would otherwise divide by zero.
        self.ecog_std[self.ecog_std == 0] = 1
        self.dbs_std[self.dbs_std == 0] = 1

        self.num_epochs = self.ecog_data.shape[0]
        self.sequence_length = self.ecog_data.shape[2]
        # No lock: items are built by only reading the arrays above, which are
        # never mutated after construction. A threading.Lock attribute would
        # also make the dataset unpicklable, which DataLoader needs for
        # num_workers > 0 under the spawn start method.

    def __len__(self):
        """Number of epochs."""
        return self.num_epochs

    def __getitem__(self, idx):
        """Return the normalised ``{"signal": dbs, "cond": ecog}`` pair for epoch ``idx``.

        Raises:
            ValueError: if the DBS item is not (1, sequence_length) or the ECoG
                item is not (3, sequence_length), i.e. the inputs had
                unexpected channel counts.
        """
        try:
            # The arithmetic allocates new arrays, so stored data is never
            # modified and the tensors below do not alias it.
            ecog_sample = (self.ecog_data[idx] - self.ecog_mean[0]) / self.ecog_std[0]
            dbs_sample = (self.dbs_data[idx] - self.dbs_mean[0]) / self.dbs_std[0]

            ecog_tensor = torch.from_numpy(ecog_sample).float()
            dbs_tensor = torch.from_numpy(dbs_sample).float()

            # Explicit checks (not asserts, which vanish under `python -O`),
            # parameterised by the actual sequence length.
            expected_dbs = (self.DBS_CHANNELS, self.sequence_length)
            expected_ecog = (self.ECOG_CHANNELS, self.sequence_length)
            if tuple(dbs_tensor.shape) != expected_dbs:
                raise ValueError(f"Expected DBS shape {expected_dbs}, got {tuple(dbs_tensor.shape)}")
            if tuple(ecog_tensor.shape) != expected_ecog:
                raise ValueError(f"Expected ECoG shape {expected_ecog}, got {tuple(ecog_tensor.shape)}")

            return {
                "signal": dbs_tensor,
                "cond": ecog_tensor
            }
        except Exception as e:
            print(f"Error accessing index {idx}: {str(e)}")
            raise
def train_ecog_dbs_model(ecog_data, dbs_data, config_path):
    """Train a DDPM model for ECoG-to-DBS imputation using a temporal split.

    The epochs are split in temporal order; only the leading training part is
    used for optimisation, and the trailing test part is returned untouched for
    later evaluation.

    Args:
        ecog_data: ECoG epochs (conditioning input), passed to ``ECoGDBSDataset``.
        dbs_data: DBS epochs (signal to generate), passed to ``ECoGDBSDataset``.
        config_path: Hydra config directory containing the ``config_ou`` config.

    Returns:
        ``(diffusion, train_dataset, test_dataset)``.

    Raises:
        ValueError: if the temporal split leaves no training samples (the
            per-epoch loss average would otherwise divide by zero).
    """
    with initialize(version_base=None, config_path=config_path):
        cfg = compose(config_name="config_ou")

    dataset = ECoGDBSDataset(ecog_data, dbs_data)

    # Temporal (not random) split so test epochs strictly follow training epochs.
    train_dataset, test_dataset = temporal_train_test_split(
        dataset,
        train_ratio=cfg.dataset.train_test_split
    )
    # Fail before building the model rather than with a ZeroDivisionError
    # after the first (empty) epoch.
    if len(train_dataset) == 0:
        raise ValueError(
            f"Temporal split left no training samples ({len(dataset)} epochs, "
            f"train_test_split={cfg.dataset.train_test_split})"
        )

    # Only the training subset gets a loader here. Shuffling inside it is fine:
    # the temporal train/test boundary is already fixed by the split above.
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.optimizer.train_batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=False
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    network = AdaConv(
        signal_length=cfg.dataset.signal_length,
        signal_channel=1,  # DBS data has 1 channel
        cond_dim=3,       # ECoG data has 3 channels
        hidden_channel=cfg.network.hidden_channel,
        in_kernel_size=cfg.network.in_kernel_size,
        out_kernel_size=cfg.network.out_kernel_size,
        slconv_kernel_size=cfg.network.slconv_kernel_size,
        num_scales=cfg.network.num_scales,
        num_blocks=cfg.network.num_blocks,
        num_off_diag=cfg.network.num_off_diag,
        use_pos_emb=cfg.network.use_pos_emb,
        padding_mode=cfg.network.padding_mode,
        use_fft_conv=cfg.network.use_fft_conv,
    ).to(device)

    # OU process used both as the noise sampler and the Mahalanobis-distance computer.
    ou_process = OUProcess(
        cfg.diffusion_kernel.sigma_squared,
        cfg.diffusion_kernel.ell,
        cfg.dataset.signal_length
    ).to(device)

    diffusion = Diffusion(
        network=network,
        noise_sampler=ou_process,
        mal_dist_computer=ou_process,
        diffusion_time_steps=cfg.diffusion.diffusion_steps,
        schedule=cfg.diffusion.schedule,
        start_beta=cfg.diffusion.start_beta,
        end_beta=cfg.diffusion.end_beta,
    ).to(device)

    optimizer = torch.optim.AdamW(
        network.parameters(),
        lr=cfg.optimizer.lr,
        weight_decay=cfg.optimizer.weight_decay,
    )

    class CustomTrainer:
        """Minimal loop over dict batches ({"signal", "cond"}): one optimiser
        step per batch, with gradient-norm clipping at 1.0."""

        def __init__(self, model, data_loader, optimizer, device):
            self.model = model
            self.data_loader = data_loader
            self.optimizer = optimizer
            self.device = device

        def train_epoch(self):
            """Run one pass over the loader; return [(batch_size, mean_loss), ...]."""
            batchwise_losses = []
            for batch in self.data_loader:
                sig_batch = batch["signal"].to(self.device)
                cond_batch = batch["cond"].to(self.device)

                batch_size = sig_batch.shape[0]

                batch_loss = self.model.train_batch(sig_batch, cond=cond_batch)
                batch_loss = torch.mean(batch_loss)

                batchwise_losses.append((batch_size, batch_loss.item()))

                self.optimizer.zero_grad()
                batch_loss.backward()

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                self.optimizer.step()

            return batchwise_losses

    trainer = CustomTrainer(diffusion, train_loader, optimizer, device)

    print("\nStarting training...")
    for epoch in range(cfg.optimizer.num_epochs):
        try:
            print(f"\nEpoch {epoch}")
            batchwise_losses = trainer.train_epoch()

            # Sample-weighted mean, so a smaller last batch counts proportionally.
            total_samples = sum(batch_size for batch_size, _ in batchwise_losses)
            epoch_loss = sum(
                batch_size * batch_loss for batch_size, batch_loss in batchwise_losses
            ) / total_samples

            print(f"Average loss: {epoch_loss:.6f}")

        except Exception:
            print(f"Error during epoch {epoch}:")
            import traceback
            traceback.print_exc()
            raise

    return diffusion, train_dataset, test_dataset

def get_frequency_band_power(signal_data, fs=250, nperseg=256, bands=None):
    """Integrate the Welch PSD of ``signal_data`` over each frequency band.

    Args:
        signal_data: signal(s) to analyse. The PSD is taken along the last
            axis, so a 2-D array yields one power value per row.
        fs: sampling rate in Hz passed to ``scipy.signal.welch`` (default 250).
        nperseg: Welch segment length (default 256).
        bands: optional mapping of band name to ``(low_hz, high_hz)``; both
            edges are inclusive. Defaults to theta 4-8, alpha 8-13,
            beta 13-35 and gamma 35-100 Hz.
            NOTE(review): the PSD plots in this file shade beta as 13-30 and
            gamma as 30-100 Hz; confirm which definition is intended.

    Returns:
        dict mapping band name to integrated power: a scalar for 1-D input,
        otherwise an array with the leading shape of ``signal_data``.
    """
    if bands is None:
        bands = {
            'theta': (4, 8),
            'alpha': (8, 13),
            'beta': (13, 35),
            'gamma': (35, 100)
        }

    f, psd = signal.welch(signal_data, fs=fs, nperseg=nperseg)

    # np.trapz was renamed np.trapezoid in NumPy 2.0, where the old name is deprecated.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    powers = {}
    for band_name, (low_freq, high_freq) in bands.items():
        mask = (f >= low_freq) & (f <= high_freq)
        # Mask the frequency (last) axis so batched input works as well.
        powers[band_name] = integrate(psd[..., mask], f[mask], axis=-1)

    return powers

def get_all_predictions_fast(model, test_dataset, batch_size=64):
    """Sample imputed DBS for every test epoch and compute per-epoch band powers.

    NOTE(review): a later cell in this file redefines get_all_predictions_fast
    with a different return schema (no 'band_powers_*' keys). If that cell runs
    after this one, callers that need band powers (analyze_frequency_bands)
    break.

    Args:
        model: trained diffusion model exposing ``eval()``, ``device`` and
            ``sample(num_samples=..., cond=..., noise_type=...)``.
        test_dataset: dataset yielding dicts with 'cond' (ECoG) and 'signal' (DBS).
        batch_size: number of epochs per ``model.sample`` call.

    Returns:
        dict with
          'real_dbs', 'imputed_dbs': stacked signals, one epoch per row
              (a singleton channel axis is dropped);
          'band_powers_real', 'band_powers_imputed': band name -> array of
              per-epoch powers from ``get_frequency_band_power``.

    Raises:
        ValueError: if ``test_dataset`` is empty.
    """
    if len(test_dataset) == 0:
        raise ValueError("test_dataset is empty; nothing to predict")

    model.eval()

    def _stack_epochs(chunks):
        # Concatenate batches and drop a singleton channel axis: (N, 1, L) -> (N, L).
        # Unlike a bare squeeze(), this keeps the epoch axis when N == 1.
        stacked = np.concatenate(chunks, axis=0)
        if stacked.ndim == 3 and stacked.shape[1] == 1:
            stacked = stacked[:, 0, :]
        return stacked

    loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    total_batches = len(loader)

    real_dbs_all = []
    imputed_dbs_all = []

    print(f"\nProcessing {len(test_dataset)} samples in {total_batches} batches...")

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(loader, desc="Generating predictions")):
            ecog_batch = batch['cond'].to(model.device)
            real_dbs_batch = batch['signal'].cpu().numpy()

            imputed_dbs_batch = model.sample(
                num_samples=ecog_batch.size(0),
                cond=ecog_batch,
                noise_type="alpha_beta"
            ).cpu().numpy()

            real_dbs_all.append(real_dbs_batch)
            imputed_dbs_all.append(imputed_dbs_batch)

            # Drop the GPU reference now and empty the CUDA cache every 10 batches.
            del ecog_batch
            if batch_idx % 10 == 0:
                clear_gpu_memory()

    print("\nProcessing results...")
    real_dbs_all = _stack_epochs(real_dbs_all)
    imputed_dbs_all = _stack_epochs(imputed_dbs_all)

    print("\nCalculating frequency band powers...")
    band_powers_real = {}
    band_powers_imputed = {}
    # Band names follow whatever get_frequency_band_power returns, in its order.
    for real_sig, imputed_sig in tqdm(zip(real_dbs_all, imputed_dbs_all),
                                      total=len(real_dbs_all),
                                      desc="Computing band powers"):
        for band, power in get_frequency_band_power(real_sig).items():
            band_powers_real.setdefault(band, []).append(power)
        for band, power in get_frequency_band_power(imputed_sig).items():
            band_powers_imputed.setdefault(band, []).append(power)

    band_powers_real = {band: np.array(values) for band, values in band_powers_real.items()}
    band_powers_imputed = {band: np.array(values) for band, values in band_powers_imputed.items()}

    return {
        'real_dbs': real_dbs_all,
        'imputed_dbs': imputed_dbs_all,
        'band_powers_real': band_powers_real,
        'band_powers_imputed': band_powers_imputed
    }

def analyze_frequency_bands(results):
    """Correlate real vs. imputed band powers, print them and scatter-plot each band.

    Args:
        results: dict holding 'band_powers_real' and 'band_powers_imputed',
            each mapping band name -> 1-D array of per-epoch powers.

    Returns:
        dict mapping band name -> Pearson r between real and imputed powers.
    """
    real = results['band_powers_real']
    imputed = results['band_powers_imputed']

    correlations = {name: pearsonr(real[name], imputed[name])[0] for name in real}

    print("\nFrequency Band Power Correlations:")
    for name in correlations:
        print(f"{name.capitalize()}: {correlations[name]:.3f}")

    # Fixed 2x2 grid, one panel per band in dictionary order.
    fig, grid = plt.subplots(2, 2, figsize=(15, 15))
    panels = grid.flatten()

    for panel_idx, name in enumerate(real):
        ax = panels[panel_idx]
        x_vals = real[name]
        y_vals = imputed[name]

        ax.scatter(x_vals, y_vals,
                   alpha=0.5, label=f'r = {correlations[name]:.3f}')

        # Identity line over the joint data range: points on it have
        # imputed power equal to real power.
        lo = min(x_vals.min(), y_vals.min())
        hi = max(x_vals.max(), y_vals.max())
        ax.plot([lo, hi], [lo, hi], 'r--', alpha=0.5)

        ax.set_xlabel('Real Power')
        ax.set_ylabel('Imputed Power')
        ax.set_title(f'{name.capitalize()} Band Power')
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

    return correlations


# In [None]:
import os
import torch
import numpy as np
import time
import gc
from glob import glob
from tqdm import tqdm
from datetime import datetime, timedelta

def clear_gpu_memory():
    """Collect unreachable Python objects, then release cached CUDA memory.

    Garbage collection runs first so that tensors kept alive only by reference
    cycles are freed before ``torch.cuda.empty_cache()``. In the reverse order,
    the memory freed by the collector would go back into PyTorch's cache and
    stay reserved. Collection also runs on CPU-only machines.
    """
    gc.collect()
    if torch.cuda.is_available():
        # Wait for all queued CUDA work to finish before releasing cached blocks.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

def format_time(seconds):
    """Render a duration in seconds as ``H:MM:SS``.

    Durations of a day or more are rendered as ``D day(s), H:MM:SS``.
    Fractional seconds are truncated toward zero by ``int``.
    """
    whole_seconds = int(seconds)
    return f"{timedelta(seconds=whole_seconds)}"

def plot_and_save_psd_comparison(results, subject_id, output_dir):
    """Plot the mean ± std Welch PSD of real vs. imputed DBS and save figure and data.

    Writes ``<output_dir>/psd_figures/<subject_id>_psd_comparison.png`` and
    ``<subject_id>_psd_data.npz`` (frequencies plus mean/std PSD arrays).
    This is best-effort: any exception is printed and swallowed, so a plotting
    failure does not abort the caller. The figure is always closed.

    Args:
        results: dict with 'real_dbs' and 'imputed_dbs' arrays, one epoch per
            row along the last axis (a single 1-D epoch is also accepted).
        subject_id: label used in the title and output file names.
        output_dir: base directory; its ``psd_figures`` subfolder is created.
    """
    fig = None
    try:
        fig_dir = os.path.join(output_dir, 'psd_figures')
        os.makedirs(fig_dir, exist_ok=True)

        print("\nCalculating average PSDs...")
        real_dbs = np.asarray(results['real_dbs'])
        imputed_dbs = np.asarray(results['imputed_dbs'])

        # Flatten to (rows, samples): welch then transforms every row in one
        # call (no quadratic re-stacking), and even a single 1-D epoch gives a
        # 2-D PSD, so the mean/std below stay per-frequency.
        real_rows = real_dbs.reshape(-1, real_dbs.shape[-1])
        imputed_rows = imputed_dbs.reshape(-1, imputed_dbs.shape[-1])
        f, psd_real_all = signal.welch(real_rows, fs=250, nperseg=256, axis=-1)
        _, psd_imputed_all = signal.welch(imputed_rows, fs=250, nperseg=256, axis=-1)

        psd_real_mean = np.mean(psd_real_all, axis=0)
        psd_real_std = np.std(psd_real_all, axis=0)
        psd_imputed_mean = np.mean(psd_imputed_all, axis=0)
        psd_imputed_std = np.std(psd_imputed_all, axis=0)

        fig = plt.figure(figsize=(12, 8))

        plt.semilogy(f, psd_real_mean, 'b', label='Real DBS (mean)', alpha=0.7)
        plt.fill_between(f,
                         psd_real_mean - psd_real_std,
                         psd_real_mean + psd_real_std,
                         color='b', alpha=0.2)

        plt.semilogy(f, psd_imputed_mean, 'r', label='Imputed DBS (mean)', alpha=0.7)
        plt.fill_between(f,
                         psd_imputed_mean - psd_imputed_std,
                         psd_imputed_mean + psd_imputed_std,
                         color='r', alpha=0.2)

        # Band shading for orientation only.
        # NOTE(review): get_frequency_band_power in this file integrates beta
        # over 13-35 Hz and gamma over 35-100 Hz, so this shading differs from
        # the analysed bands; confirm which definition is intended.
        bands = {
            'theta': (4, 8),
            'alpha': (8, 13),
            'beta': (13, 30),
            'gamma': (30, 100)
        }
        colors = ['lightgray', 'lightblue', 'lightgreen', 'lightpink']
        for (band, (fmin, fmax)), color in zip(bands.items(), colors):
            plt.axvspan(fmin, fmax, color=color, alpha=0.2)
            plt.text((fmin + fmax)/2, plt.ylim()[0]*1.1, band,
                     horizontalalignment='center', verticalalignment='bottom')

        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Power Spectral Density')
        plt.title(f'Average Power Spectral Density with Std Dev\n{subject_id}')
        plt.xlim(0, 125)
        plt.legend()
        plt.grid(True)

        fig_path = os.path.join(fig_dir, f'{subject_id}_psd_comparison.png')
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')

        psd_data_path = os.path.join(fig_dir, f'{subject_id}_psd_data.npz')
        np.savez(psd_data_path,
                 frequencies=f,
                 psd_real_mean=psd_real_mean,
                 psd_real_std=psd_real_std,
                 psd_imputed_mean=psd_imputed_mean,
                 psd_imputed_std=psd_imputed_std)

        print(f"PSD comparison saved to {fig_path}")

    except Exception as e:
        print(f"Error plotting PSD comparison: {str(e)}")
    finally:
        # Close even on failure so figures do not accumulate across subjects.
        if fig is not None:
            plt.close(fig)
        
def process_single_subject(ecog_path, dbs_path, output_dir, config_path):
    """Train, evaluate and save results for one subject, including PSD plots.

    The subject label is built from the ECoG file name: the text before the
    first '_' plus the first four characters of the next '_'-separated part
    (stored as ``side``).

    Args:
        ecog_path: ``.npy`` file with the ECoG epochs (loaded with ``np.load``).
        dbs_path: ``.npy`` file with the matching DBS epochs.
        output_dir: directory for the results ``.npz`` and the PSD figures.
        config_path: Hydra config directory forwarded to ``train_ecog_dbs_model``.

    Returns:
        ``(True, duration_seconds)`` on success. ``(False, 0)`` if any step
        raised; the error and its traceback are printed, never re-raised.
    """
    # Fallback label so the error handler can always name the subject, even
    # when parsing the file name is itself what failed (otherwise the handler
    # would hit an UnboundLocalError and abort the whole batch).
    full_subject_id = os.path.basename(ecog_path)
    try:
        file_name = os.path.basename(ecog_path)
        subject_id = file_name.split('_')[0]
        side = file_name.split('_')[1][:4]
        full_subject_id = f"{subject_id}_{side}"
        print(f"\nProcessing subject {full_subject_id}")

        with tqdm(total=2, desc="Loading data") as pbar:
            ecog_data = np.load(ecog_path)
            pbar.update(1)
            dbs_data = np.load(dbs_path)
            pbar.update(1)

        print(f"Data shapes - ECoG: {ecog_data.shape}, DBS: {dbs_data.shape}")
        if torch.cuda.is_available():
            print(f"GPU memory before training: {torch.cuda.memory_allocated(0)/1e9:.2f} GB")

        output_path = os.path.join(output_dir, f"{full_subject_id}_prediction_results.npz")

        start_time = time.time()

        print("\nTraining model...")
        diffusion_model, train_dataset, test_dataset = train_ecog_dbs_model(
            ecog_data=ecog_data,
            dbs_data=dbs_data,
            config_path=config_path
        )

        # Raw arrays are no longer needed once wrapped in the datasets.
        del ecog_data, dbs_data
        clear_gpu_memory()
        if torch.cuda.is_available():
            print(f"GPU memory after training: {torch.cuda.memory_allocated(0)/1e9:.2f} GB")

        print("\nGenerating predictions...")
        results = get_all_predictions_fast(diffusion_model, test_dataset,
                                           batch_size=2048)

        plot_and_save_psd_comparison(results, full_subject_id, output_dir)

        del diffusion_model, train_dataset
        clear_gpu_memory()

        if torch.cuda.is_available():
            print(f"GPU memory after predictions: {torch.cuda.memory_allocated(0)/1e9:.2f} GB")

        print("\nAnalyzing frequency bands...")
        correlations = analyze_frequency_bands(results)

        # NOTE(review): the band-power entries are dicts, which np.savez stores
        # as pickled object arrays; reading them back needs
        # np.load(..., allow_pickle=True).
        print(f"\nSaving results to {output_path}")
        with tqdm(total=1, desc="Saving results") as pbar:
            np.savez(output_path,
                     real_dbs=results['real_dbs'],
                     imputed_dbs=results['imputed_dbs'],
                     band_powers_real=results['band_powers_real'],
                     band_powers_imputed=results['band_powers_imputed'])
            pbar.update(1)

        duration = time.time() - start_time
        print(f"\nCompleted in {format_time(duration)}")

        del results, test_dataset, correlations
        clear_gpu_memory()
        if torch.cuda.is_available():
            print(f"Final GPU memory: {torch.cuda.memory_allocated(0)/1e9:.2f} GB")

        return True, duration

    except Exception as e:
        import traceback
        print(f"Error processing {full_subject_id}: {str(e)}")
        traceback.print_exc()
        clear_gpu_memory()
        return False, 0

def batch_process_all_subjects(input_dir, output_dir, config_path):
    """Run ``process_single_subject`` for every ``*ecog.npy`` file in ``input_dir``.

    Each ECoG file is paired with the same path ending in ``dbs.npy`` instead;
    files without a partner are skipped with a warning. After every processed
    subject, ``processing_summary.txt`` in ``output_dir`` is rewritten, so
    progress is recorded even if the run is interrupted.

    Args:
        input_dir: directory searched (non-recursively) for ``*ecog.npy``.
        output_dir: destination for results, figures and the summary file.
        config_path: Hydra config directory forwarded to each subject.
    """
    os.makedirs(output_dir, exist_ok=True)

    ecog_files = sorted(glob(os.path.join(input_dir, "*ecog.npy")))
    total_subjects = len(ecog_files)

    print(f"\nFound {total_subjects} subjects to process")
    time.sleep(1)  # Give time to read the message

    results_summary = []
    total_start_time = time.time()
    successful_processes = 0
    summary_path = os.path.join(output_dir, 'processing_summary.txt')
    ecog_suffix = 'ecog.npy'

    with tqdm(total=total_subjects, desc="Overall Progress") as pbar:
        for subject_idx, ecog_file in enumerate(ecog_files, 1):
            # Replace only the suffix (guaranteed by the glob pattern), never a
            # matching substring elsewhere in the path.
            dbs_file = ecog_file[:-len(ecog_suffix)] + 'dbs.npy'

            if not os.path.exists(dbs_file):
                print(f"Warning: No matching DBS file found for {ecog_file}")
                pbar.update(1)
                continue

            # ETA = mean wall time per subject attempted so far (failures take
            # time too) x subjects still to go, including this one.
            attempted = len(results_summary)
            if attempted > 0:
                elapsed_time = time.time() - total_start_time
                estimated_remaining = elapsed_time / attempted * (total_subjects - subject_idx + 1)
                print(f"\nEstimated time remaining: {format_time(estimated_remaining)}")

            print(f"\nProcessing subject {subject_idx}/{total_subjects}")
            print(f"File: {os.path.basename(ecog_file)}")

            success, duration = process_single_subject(
                ecog_path=ecog_file,
                dbs_path=dbs_file,
                output_dir=output_dir,
                config_path=config_path
            )

            if success:
                successful_processes += 1

            results_summary.append({
                'subject': os.path.basename(ecog_file),
                'success': success,
                'duration': duration,
                'processed_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            })

            pbar.update(1)

            _write_processing_summary(summary_path, total_subjects, subject_idx,
                                      successful_processes, results_summary)

    total_duration = time.time() - total_start_time
    print(f"\nBatch processing completed!")
    print(f"Total time: {format_time(total_duration)}")
    print(f"Successfully processed: {successful_processes}/{total_subjects}")

    psd_dir = os.path.join(output_dir, 'psd_figures')
    os.makedirs(psd_dir, exist_ok=True)


def _write_processing_summary(summary_path, total_subjects, completed, successful, results_summary):
    """Overwrite ``summary_path`` with the running batch-processing report."""
    with open(summary_path, 'w', encoding='utf-8') as f:
        f.write("Processing Summary:\n")
        f.write(f"Total subjects: {total_subjects}\n")
        f.write(f"Completed: {completed}\n")
        f.write(f"Successful: {successful}\n")
        f.write(f"Current time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

        for result in results_summary:
            f.write(f"Subject: {result['subject']}\n")
            f.write(f"Success: {result['success']}\n")
            f.write(f"Duration: {format_time(result['duration'])}\n")
            f.write(f"Processed at: {result['processed_at']}\n\n")
    
def create_summary_plot(psd_dir=None):
    """Overlay every subject's mean real/imputed PSD in one summary figure.

    Reads each ``*_psd_data.npz`` in ``psd_dir`` (the files written by
    ``plot_and_save_psd_comparison``) and saves ``summary_psd_comparison.png``
    to the same directory. This is best-effort: errors are printed, not raised.

    Args:
        psd_dir: directory with the per-subject PSD files, normally
            ``os.path.join(output_dir, 'psd_figures')``. If omitted, a
            module-level ``psd_dir`` variable is used when one is defined.
            ``batch_process_all_subjects`` only creates a local ``psd_dir``, so
            without a global one the zero-argument call has nothing to read.
    """
    if psd_dir is None:
        psd_dir = globals().get('psd_dir')
    if psd_dir is None:
        print("Error creating summary plot: no PSD directory given "
              "(call create_summary_plot(psd_dir=...))")
        return

    fig = None
    try:
        print("\nCreating summary PSD plot...")
        fig = plt.figure(figsize=(15, 10))

        # Sorted so the drawing order is deterministic.
        psd_files = sorted(glob(os.path.join(psd_dir, '*_psd_data.npz')))

        # One faint line per subject: blue = real, red = imputed.
        for psd_file in psd_files:
            # The context manager closes the .npz file handle.
            with np.load(psd_file) as data:
                freqs = data['frequencies']
                plt.semilogy(freqs, data['psd_real_mean'],
                             alpha=0.2, color='blue')
                plt.semilogy(freqs, data['psd_imputed_mean'],
                             alpha=0.2, color='red')

        # NOTE(review): get_frequency_band_power in this file integrates beta
        # over 13-35 Hz and gamma over 35-100 Hz; this shading uses 13-30 and
        # 30-100 Hz. Confirm which definition is intended.
        bands = {
            'theta': (4, 8),
            'alpha': (8, 13),
            'beta': (13, 30),
            'gamma': (30, 100)
        }
        colors = ['lightgray', 'lightblue', 'lightgreen', 'lightpink']
        for (band, (fmin, fmax)), color in zip(bands.items(), colors):
            plt.axvspan(fmin, fmax, color=color, alpha=0.2)
            plt.text((fmin + fmax)/2, plt.ylim()[0]*1.1, band,
                     horizontalalignment='center', verticalalignment='bottom')

        plt.xlabel('Frequency (Hz)')
        plt.ylabel('Power Spectral Density')
        plt.title('Summary of All Subjects\nBlue: Real DBS, Red: Imputed DBS')
        plt.xlim(0, 125)
        plt.grid(True)

        summary_path = os.path.join(psd_dir, 'summary_psd_comparison.png')
        plt.savefig(summary_path, dpi=300, bbox_inches='tight')

        print(f"Summary PSD plot saved to {summary_path}")

    except Exception as e:
        print(f"Error creating summary plot: {str(e)}")
    finally:
        if fig is not None:
            plt.close(fig)

# In [None]:
import numpy as np
import torch
from scipy import signal
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
from tqdm import tqdm
from torch.utils.data import DataLoader
import os

def get_all_predictions_fast(model, test_dataset, device=None, batch_size=64):
    """Get all real and imputed DBS epochs plus per-epoch MSE and Pearson r.

    NOTE(review): an earlier cell in this file defines a
    get_all_predictions_fast that also returns 'band_powers_real' /
    'band_powers_imputed' (needed by analyze_frequency_bands). This definition
    replaces it when this cell runs and does not return those keys.

    Args:
        model: trained diffusion model exposing ``eval()`` and
            ``sample(num_samples=..., cond=..., noise_type=...)``.
        test_dataset: dataset yielding dicts with 'cond' (ECoG) and 'signal' (DBS).
        device: device for the conditioning batches. ``None`` (default) means
            CUDA when available, otherwise CPU.
        batch_size: number of epochs per ``model.sample`` call.

    Returns:
        dict with 'real_dbs' and 'imputed_dbs' (one epoch per row; a singleton
        channel axis is dropped), plus 'mse_scores' and 'corr_scores' (one
        value per epoch).

    Raises:
        ValueError: if ``test_dataset`` is empty.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if len(test_dataset) == 0:
        raise ValueError("test_dataset is empty; nothing to predict")

    model.eval()

    def _stack_epochs(chunks):
        # Concatenate batches and drop a singleton channel axis: (N, 1, L) -> (N, L).
        # Unlike a bare squeeze(), this keeps the epoch axis when N == 1.
        stacked = np.concatenate(chunks, axis=0)
        if stacked.ndim == 3 and stacked.shape[1] == 1:
            stacked = stacked[:, 0, :]
        return stacked

    loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    total_batches = len(loader)

    real_dbs_all = []
    imputed_dbs_all = []

    print(f"\nProcessing {len(test_dataset)} samples in {total_batches} batches...")

    with torch.no_grad():
        for batch in tqdm(loader, total=total_batches, desc="Generating predictions"):
            ecog_batch = batch['cond'].to(device)
            real_dbs_all.append(batch['signal'].cpu().numpy())
            imputed_dbs_all.append(
                model.sample(
                    num_samples=ecog_batch.size(0),
                    cond=ecog_batch,
                    noise_type="alpha_beta"
                ).cpu().numpy()
            )

    print("\nProcessing results...")
    real_dbs_all = _stack_epochs(real_dbs_all)
    imputed_dbs_all = _stack_epochs(imputed_dbs_all)

    print("Calculating metrics...")
    mse_scores = np.mean((real_dbs_all - imputed_dbs_all) ** 2, axis=1)

    print("Calculating correlations...")
    corr_scores = np.array([
        pearsonr(real_sig, imputed_sig)[0]
        for real_sig, imputed_sig in tqdm(zip(real_dbs_all, imputed_dbs_all),
                                          total=len(real_dbs_all),
                                          desc="Computing correlations")
    ])

    return {
        'real_dbs': real_dbs_all,
        'imputed_dbs': imputed_dbs_all,
        'mse_scores': mse_scores,
        'corr_scores': corr_scores
    }

def analyze_predictions_fast(results, max_samples_psd=1000):
    """Print MSE/correlation statistics and plot PSD and correlation summaries.

    The PSD is averaged over a random subset of at most ``max_samples_psd``
    epochs. The subset is drawn with ``np.random.choice`` without replacement,
    so it follows the global NumPy seed.

    Args:
        results: dict with 'real_dbs', 'imputed_dbs' (one epoch per row, or a
            single 1-D epoch), 'mse_scores' and 'corr_scores'.
        max_samples_psd: upper bound on the number of epochs used for the PSD.

    Returns:
        dict with 'f' (frequencies) and the mean/std PSD of real and imputed DBS.
    """
    # A single epoch may arrive as 1-D; treat it as one row, not as L samples.
    real_dbs = np.atleast_2d(results['real_dbs'])
    imputed_dbs = np.atleast_2d(results['imputed_dbs'])
    mse_scores = results['mse_scores']
    corr_scores = results['corr_scores']

    print("\nOverall Statistics:")
    print(f"Number of test samples: {len(real_dbs)}")
    print(f"\nMSE scores:")
    print(f"Mean: {np.mean(mse_scores):.6f}")
    print(f"Std: {np.std(mse_scores):.6f}")
    print(f"Min: {np.min(mse_scores):.6f}")
    print(f"Max: {np.max(mse_scores):.6f}")

    print(f"\nCorrelation scores:")
    print(f"Mean: {np.mean(corr_scores):.6f}")
    print(f"Std: {np.std(corr_scores):.6f}")
    print(f"Min: {np.min(corr_scores):.6f}")
    print(f"Max: {np.max(corr_scores):.6f}")

    # Report the number of epochs actually used, which may be below the cap.
    n_psd = min(max_samples_psd, len(real_dbs))
    print(f"\nCalculating PSDs on {n_psd} samples...")
    indices = np.random.choice(len(real_dbs), n_psd, replace=False)

    # welch works along the last axis, so all selected epochs go in one call.
    f, psd_real_all = signal.welch(real_dbs[indices], fs=250, nperseg=256)
    _, psd_imputed_all = signal.welch(imputed_dbs[indices], fs=250, nperseg=256)

    psd_real_mean = np.mean(psd_real_all, axis=0)
    psd_real_std = np.std(psd_real_all, axis=0)
    psd_imputed_mean = np.mean(psd_imputed_all, axis=0)
    psd_imputed_std = np.std(psd_imputed_all, axis=0)

    print("\nGenerating plots...")

    plt.figure(figsize=(12, 6))
    plt.semilogy(f, psd_real_mean, 'b', label='Real DBS (mean)', alpha=0.7)
    plt.fill_between(f,
                     psd_real_mean - psd_real_std,
                     psd_real_mean + psd_real_std,
                     color='b', alpha=0.2)
    plt.semilogy(f, psd_imputed_mean, 'r', label='Imputed DBS (mean)', alpha=0.7)
    plt.fill_between(f,
                     psd_imputed_mean - psd_imputed_std,
                     psd_imputed_mean + psd_imputed_std,
                     color='r', alpha=0.2)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power Spectral Density')
    plt.title('Average Power Spectral Density with Std Dev')
    plt.xlim(0, 125)
    plt.legend()
    plt.grid(True)
    plt.show()

    plt.figure(figsize=(8, 6))
    plt.hist(corr_scores, bins=50, alpha=0.7)
    plt.axvline(np.mean(corr_scores), color='r', linestyle='dashed',
                label=f'Mean ({np.mean(corr_scores):.3f})')
    plt.xlabel('Correlation Coefficient')
    plt.ylabel('Count')
    plt.title('Distribution of Correlation Scores')
    plt.legend()
    plt.grid(True)
    plt.show()

    return {
        'f': f,
        'psd_real_mean': psd_real_mean,
        'psd_real_std': psd_real_std,
        'psd_imputed_mean': psd_imputed_mean,
        'psd_imputed_std': psd_imputed_std
    }

if __name__ == "__main__":
    # Standalone evaluation of an already-trained model.
    # NOTE(review): `diffusion_model` and `test_dataset` are not created here;
    # they must already exist as globals (e.g. from an earlier notebook cell).
    output_path = r'D:\data_zixiao\raw_prediction_46\dbs_prediction_results.npz'
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    start_time = time.time()

    print("Starting prediction generation...")
    results = get_all_predictions_fast(diffusion_model, test_dataset, batch_size=2048)

    print("\nStarting analysis...")
    psd_results = analyze_predictions_fast(results, max_samples_psd=1000)

    print(f"\nSaving results to {output_path}")
    # Raw signals and per-epoch scores come from `results`; the spectral
    # summary comes from `psd_results` (its 'f' is stored as 'frequencies').
    saved_arrays = {key: results[key]
                    for key in ('real_dbs', 'imputed_dbs', 'mse_scores', 'corr_scores')}
    saved_arrays['frequencies'] = psd_results['f']
    saved_arrays.update(
        (key, psd_results[key])
        for key in ('psd_real_mean', 'psd_real_std', 'psd_imputed_mean', 'psd_imputed_std')
    )
    np.savez(output_path, **saved_arrays)

    end_time = time.time()
    duration = end_time - start_time
    minutes, seconds = divmod(duration, 60)

    print(f"\nCompleted in {minutes:.0f}m {seconds:.1f}s")
    print(f"Results saved to: {output_path}")