In [None]:
# Import required modules at the top of the file
import os
import time
import numpy as np
import torch
from glob import glob
from torch.utils.data import random_split, Subset, DataLoader
from dataset import ECoGDBSDataset
from hydra import compose, initialize
from utils import clear_gpu_memory, get_frequency_band_power
from visualization import plot_and_save_psd_comparison
from hyperparameter_tuning import tune_hyperparameters_safely
from trainer import train_ecog_dbs_model
from visualization import plot_and_save_psd_comparison, analyze_frequency_bands

def _squeeze_keep_batch(array):
    """Remove singleton axes from ``array`` while always keeping axis 0.

    A bare ``array.squeeze()`` also drops the leading sample axis when it has
    length 1, which turns a one-sample result into a bare signal and breaks
    any code that iterates over samples afterwards.
    """
    singleton_axes = tuple(
        axis for axis in range(1, array.ndim) if array.shape[axis] == 1
    )
    if not singleton_axes:
        return array
    return array.squeeze(axis=singleton_axes)


def get_all_predictions_fast(model, test_dataset, batch_size=2048, device=None):
    """Generate model predictions for a whole dataset and compute band powers.

    The dataset is iterated in order (no shuffling). Each batch must be a
    mapping with a ``'cond'`` tensor (moved to ``device`` and passed to
    ``model.sample`` as the conditioning input) and a ``'signal'`` tensor
    (converted with ``.numpy()``, so it has to live on the CPU).

    Args:
        model: Object exposing ``eval()``, ``sample(num_samples, cond,
            noise_type)`` and, when ``device`` is omitted, a ``device``
            attribute.
        test_dataset: Sized dataset accepted by ``torch.utils.data.DataLoader``.
        batch_size: Number of samples per inference batch.
        device: Target device for the conditioning input; defaults to
            ``model.device``.

    Returns:
        dict with the keys
            ``'real_dbs'`` / ``'imputed_dbs'``: stacked ground-truth and
                sampled signals, singleton non-sample axes removed;
            ``'band_powers_real'`` / ``'band_powers_imputed'``: dict mapping
                ``'theta'``, ``'alpha'``, ``'beta'``, ``'gamma'`` to a numpy
                array with one value per sample;
            ``'test_set_size'``: ``len(test_dataset)``;
            ``'predictions_generated'``: number of stacked samples.

    Raises:
        ValueError: If ``test_dataset`` is empty.
        Exception: Anything raised during sampling is logged and re-raised.
    """
    model.eval()

    # Use model's device if none specified
    if device is None:
        device = model.device

    n_samples = len(test_dataset)
    if n_samples == 0:
        # Fail early with a clear message instead of the opaque
        # "need at least one array to concatenate" further down.
        raise ValueError("test_dataset is empty; no predictions can be generated")

    # Let cuDNN pick the fastest kernels for the (fixed) batch shape
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    # Create DataLoader with optimized settings
    loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2
    )

    real_batches = []
    imputed_batches = []

    print(f"\nProcessing {n_samples} samples in {len(loader)} batches...")

    try:
        with torch.no_grad(), torch.cuda.amp.autocast():  # Use mixed precision
            for batch_idx, batch in enumerate(loader):
                # Move the conditioning input to the device; the target
                # signal stays on the CPU and is only kept for comparison.
                ecog_batch = batch['cond'].to(device, non_blocking=True)
                real_dbs_batch = batch['signal'].numpy()

                # Generate predictions
                imputed_dbs_batch = model.sample(
                    num_samples=ecog_batch.size(0),
                    cond=ecog_batch,
                    noise_type="alpha_beta"
                ).cpu().numpy()

                # Store results
                real_batches.append(real_dbs_batch)
                imputed_batches.append(imputed_dbs_batch)

                # Clear batch data
                del ecog_batch
                if batch_idx % 5 == 0:  # Reduce frequency of memory clearing
                    clear_gpu_memory()
                    print(f"Processed {batch_idx + 1}/{len(loader)} batches")

    except Exception as e:
        print(f"Error during prediction generation: {str(e)}")
        import traceback
        traceback.print_exc()
        clear_gpu_memory()
        raise

    # Stack the per-batch arrays; keep the sample axis even for one sample
    print("\nProcessing results...")
    real_dbs_all = _squeeze_keep_batch(np.concatenate(real_batches, axis=0))
    imputed_dbs_all = _squeeze_keep_batch(np.concatenate(imputed_batches, axis=0))

    # Band powers are computed one signal at a time
    print("\nCalculating frequency band powers...")
    bands = ['theta', 'alpha', 'beta', 'gamma']
    band_powers_real = {band: [] for band in bands}
    band_powers_imputed = {band: [] for band in bands}

    total_signals = len(real_dbs_all)
    for idx, real_signal in enumerate(real_dbs_all):
        if idx % 1000 == 0:
            print(f"Computing band powers: {idx}/{total_signals}")

        powers_real = get_frequency_band_power(real_signal)
        powers_imputed = get_frequency_band_power(imputed_dbs_all[idx])

        for band in bands:
            band_powers_real[band].append(powers_real[band])
            band_powers_imputed[band].append(powers_imputed[band])

    # Convert the per-band lists to numpy arrays
    for band in bands:
        band_powers_real[band] = np.array(band_powers_real[band])
        band_powers_imputed[band] = np.array(band_powers_imputed[band])

    metrics = {
        'real_dbs': real_dbs_all,
        'imputed_dbs': imputed_dbs_all,
        'band_powers_real': band_powers_real,
        'band_powers_imputed': band_powers_imputed,
        'test_set_size': n_samples,
        'predictions_generated': len(real_dbs_all)
    }

    return metrics

def create_dataset_splits(full_dataset, train_percentage, test_percentage=0.2, val_ratio=0.2, seed=42):
    """
    Split a dataset into train, validation and test subsets.

    The test subset is drawn first from a generator seeded with ``seed``, so
    repeated calls with the same dataset length, ``test_percentage`` and
    ``seed`` select the same test samples. The remaining samples are then
    divided into a training part, a validation part and an unused remainder
    that is discarded.

    Args:
        full_dataset: Complete dataset
        train_percentage: Fraction of the non-test data used for training
        test_percentage: Fraction of all data reserved for testing
        val_ratio: Validation set size as a fraction of the training set size;
            the validation set is shrunk when training plus validation would
            exceed the non-test data
        seed: Random seed for reproducibility

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``.
    """
    rng = torch.Generator().manual_seed(seed)

    total = len(full_dataset)
    n_test = int(test_percentage * total)
    n_pool = total - n_test

    # Stage 1: carve the held-out test set off the full dataset.
    pool, test_dataset = random_split(full_dataset, [n_pool, n_test], generator=rng)

    # Stage 2: size the training set, then validation relative to it,
    # capped by whatever is left in the pool.
    n_train = int(train_percentage * n_pool)
    n_val = min(int(n_train * val_ratio), n_pool - n_train)
    n_unused = n_pool - n_train - n_val

    train_dataset, val_dataset, _ = random_split(
        pool,
        [n_train, n_val, n_unused],
        generator=rng
    )

    return train_dataset, val_dataset, test_dataset

def plot_training_size_trends(results_summary, output_dir):
    """Draw per-band correlation against training-set size and save it as a PNG.

    Each entry of ``results_summary`` must provide ``'train_percentage'`` (a
    fraction, shown on the x-axis as a percentage) and ``'correlations'``, a
    mapping with the keys ``'theta'``, ``'alpha'``, ``'beta'`` and ``'gamma'``.
    The figure is written to ``training_size_trends.png`` inside
    ``output_dir`` and closed afterwards.
    """
    import matplotlib.pyplot as plt

    percent_axis = [entry['train_percentage'] * 100 for entry in results_summary]

    fig, ax = plt.subplots(figsize=(12, 8))

    # One line per frequency band
    for band_name in ('theta', 'alpha', 'beta', 'gamma'):
        band_corr = [entry['correlations'][band_name] for entry in results_summary]
        ax.plot(percent_axis, band_corr, marker='o', label=f'{band_name} band')

    ax.set_xlabel('Training Data Size (%)')
    ax.set_ylabel('Correlation Coefficient')
    ax.set_title('Model Performance vs Training Data Size')
    ax.legend()
    ax.grid(True)

    target_file = os.path.join(output_dir, 'training_size_trends.png')
    fig.savefig(target_file)
    plt.close(fig)

def process_single_subject_with_varying_sizes(ecog_path, dbs_path, output_dir, config_path):
    """Train and evaluate one subject at several training-set sizes.

    For every training fraction the function builds train/val/test splits,
    runs hyperparameter tuning, trains a final model, samples predictions for
    the test split, writes a PSD comparison plot and a per-size ``.npz``
    result file, and finally stores a summary plus a trend plot in the
    subject's output directory.

    Args:
        ecog_path: Path to the subject's ECoG ``.npy`` file. The base name
            must look like ``<subject>_<side>...``; the first four characters
            of the second ``_``-separated field are used as the side.
        dbs_path: Path to the matching DBS ``.npy`` file.
        output_dir: Root directory; a ``<subject>_<side>`` sub-directory is
            created inside it.
        config_path: Config directory handed to ``hydra.initialize``.

    Returns:
        ``(True, duration_in_seconds)`` on success, ``(False, 0)`` when the
        file name cannot be parsed or any step raises (the error and
        traceback are printed, not propagated).
    """
    # Extract subject ID. A malformed name is reported through the normal
    # failure return instead of an IndexError that would abort the batch.
    name_parts = os.path.basename(ecog_path).split('_')
    if len(name_parts) < 2:
        print(f"Error processing {ecog_path}: file name does not match '<subject>_<side>...'")
        return False, 0
    subject_id = name_parts[0]
    side = name_parts[1][:4]
    full_subject_id = f"{subject_id}_{side}"

    try:
        # Create subject-specific output directory
        subject_output_dir = os.path.join(output_dir, full_subject_id)
        os.makedirs(subject_output_dir, exist_ok=True)

        print(f"\nProcessing subject {full_subject_id}")
        start_time = time.time()

        # Load data
        print("\nLoading data...")
        ecog_data = np.load(ecog_path)
        dbs_data = np.load(dbs_path)

        # Initialize config
        with initialize(version_base=None, config_path=config_path):
            base_config = compose(config_name="53_config_ou_200")

        # Set signal_channel and cond_dim on the network config.
        # NOTE(review): if these keys are not already defined in the composed
        # config, a struct-mode config may reject the assignment — confirm.
        base_config.network.signal_channel = 1
        base_config.network.cond_dim = 3

        # Create full dataset
        full_dataset = ECoGDBSDataset(ecog_data, dbs_data)

        # Training percentages to evaluate (including full dataset).
        # NOTE(review): at 1.0 create_dataset_splits leaves no samples for the
        # validation split — confirm the trainer tolerates an empty val set.
        train_percentages = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
        results_summary = []

        for train_pct in train_percentages:
            # Round once: 0.3 * 100 is 30.000000000000004 in floating point,
            # and int() would truncate values such as 0.57 * 100 to 56.
            pct_label = round(train_pct * 100)
            print(f"\nEvaluating with {pct_label}% training data")

            # Create splits with consistent test set
            train_dataset, val_dataset, test_dataset = create_dataset_splits(
                full_dataset,
                train_percentage=train_pct
            )

            print(f"Split sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

            # Create size-specific output directory
            size_output_dir = os.path.join(subject_output_dir, f"train_{pct_label}pct")
            os.makedirs(size_output_dir, exist_ok=True)

            # Run hyperparameter tuning with reduced trials for efficiency.
            # NOTE(review): tuning receives the full ecog_data/dbs_data arrays
            # rather than the train/val split — confirm the tuner holds out the
            # test samples itself, otherwise they leak into model selection.
            print("\nRunning hyperparameter tuning...")
            best_params, _, final_config = tune_hyperparameters_safely(
                ecog_data=ecog_data,
                dbs_data=dbs_data,
                base_config=base_config,
                output_dir=size_output_dir,
                n_trials=20,  # Reduced number of trials
                timeout=3600*2,  # Reduced timeout
                tuning_epochs=30,  # Reduced epochs for tuning
                final_epochs=200  # Reduced final epochs
            )

            # Train final model with best parameters
            print("\nTraining final model...")
            diffusion_model, _, _ = train_ecog_dbs_model(
                train_dataset=train_dataset,
                val_dataset=val_dataset,
                config=final_config
            )

            # Generate predictions on test set
            print("\nGenerating predictions on test set...")
            results = get_all_predictions_fast(diffusion_model, test_dataset, batch_size=2048)

            # Plot and save PSD comparison
            plot_and_save_psd_comparison(
                results,
                f"{full_subject_id}_train{pct_label}pct",
                size_output_dir
            )

            # Analyze frequency bands
            correlations = analyze_frequency_bands(results)

            # Save results for this training size. The dict-valued entries are
            # stored as object arrays, so reading them back needs
            # np.load(..., allow_pickle=True).
            results_path = os.path.join(
                size_output_dir,
                f"{full_subject_id}_train{pct_label}pct_results.npz"
            )
            np.savez(
                results_path,
                real_dbs=results['real_dbs'],
                imputed_dbs=results['imputed_dbs'],
                band_powers_real=results['band_powers_real'],
                band_powers_imputed=results['band_powers_imputed'],
                best_hyperparameters=best_params,
                correlations=correlations,
                train_size_percentage=train_pct
            )

            # Store summary metrics
            results_summary.append({
                'train_percentage': train_pct,
                'train_samples': len(train_dataset),
                'correlations': correlations,
                'best_params': best_params
            })

            # Release the model before the next, independent training run
            del diffusion_model
            clear_gpu_memory()

        # Save overall summary
        summary_path = os.path.join(subject_output_dir, 'training_size_analysis_summary.npz')
        np.savez(summary_path, results=results_summary)

        # Calculate and plot trends
        plot_training_size_trends(results_summary, subject_output_dir)

        duration = time.time() - start_time
        print(f"\nCompleted processing {full_subject_id}")
        print(f"Total duration: {duration:.2f} seconds")

        return True, duration

    except Exception as e:
        # Best-effort boundary: one failing subject must not stop the batch.
        print(f"Error processing {full_subject_id}: {str(e)}")
        import traceback
        traceback.print_exc()
        clear_gpu_memory()
        return False, 0

def batch_process_all_subjects_with_varying_sizes(input_dir, output_dir, config_path):
    """Run the varying-training-size analysis for every subject in a directory.

    Every file in ``input_dir`` whose name ends in ``ecog.npy`` is treated as
    one subject; the matching DBS file is expected next to it with the
    trailing ``ecog.npy`` replaced by ``dbs.npy``. Subjects without a DBS file
    are reported and skipped.

    Args:
        input_dir: Directory containing the ``*ecog.npy`` / ``*dbs.npy`` pairs.
        output_dir: Directory for all results; created if missing.
        config_path: Config directory forwarded to the per-subject routine.

    Returns:
        List with one dict per processed subject, holding ``'subject'``
        (ECoG file base name), ``'success'`` and ``'duration'``. Skipped
        subjects are not included.
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Get all ECoG files
    ecog_suffix = 'ecog.npy'
    ecog_files = sorted(glob(os.path.join(input_dir, "*ecog.npy")))
    total_subjects = len(ecog_files)

    print(f"\nFound {total_subjects} subjects to process")

    # Initialize progress tracking
    results_summary = []
    total_start_time = time.time()
    successful_processes = 0
    skipped_subjects = 0

    for ecog_file in ecog_files:
        # Swap only the trailing suffix; str.replace would also rewrite an
        # 'ecog.npy' that happens to occur in a directory name.
        dbs_file = ecog_file[:-len(ecog_suffix)] + 'dbs.npy'

        if not os.path.exists(dbs_file):
            skipped_subjects += 1
            print(f"Skipping {os.path.basename(ecog_file)}: no matching DBS file at {dbs_file}")
            continue

        success, duration = process_single_subject_with_varying_sizes(
            ecog_path=ecog_file,
            dbs_path=dbs_file,
            output_dir=output_dir,
            config_path=config_path
        )

        if success:
            successful_processes += 1

        results_summary.append({
            'subject': os.path.basename(ecog_file),
            'success': success,
            'duration': duration
        })

    # Report the outcome of the whole batch
    total_duration = time.time() - total_start_time
    print("\nBatch processing completed!")
    print(f"Successfully processed: {successful_processes}/{total_subjects}")
    if skipped_subjects:
        print(f"Skipped (missing DBS file): {skipped_subjects}/{total_subjects}")
    print(f"Total duration: {total_duration:.2f} seconds")

    return results_summary

In [None]:
# Configuration
# Machine-specific absolute Windows paths (raw strings keep the backslashes
# literal); adjust them before running this cell elsewhere.
# input_dir is scanned for "*ecog.npy" files and their "dbs.npy" counterparts.
input_dir = r'E:\data_zixiao\uscf_npy_3d_4s_nor_rmbad_9'
# output_dir receives one sub-directory per subject; it is created if missing.
output_dir = r'E:\data_zixiao\raw_prediction_56'
# Relative config directory, forwarded unchanged to hydra.initialize by the
# per-subject routine.
config_path = "../ecog_stn_icnworkstation/conf"

# Print system info
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only device index 0 is reported
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# Process all subjects with varying training sizes.
# `results` is the list of per-subject dicts (subject / success / duration)
# returned by the batch routine.
results = batch_process_all_subjects_with_varying_sizes(input_dir, output_dir, config_path)