# In [None]:
import os
import torch
import time
from glob import glob
from datetime import datetime
from tqdm import tqdm
from utils import format_time, clear_gpu_memory
from prediction import process_single_subject

def _dbs_path_for(ecog_file):
    """Return the DBS file path paired with *ecog_file*.

    *ecog_file* is expected to end in ``ecog.npy`` (the glob in
    ``batch_process_all_subjects`` only yields such names). Only that trailing
    suffix is swapped for ``dbs.npy``. ``str.replace`` would also rewrite an
    ``ecog.npy`` substring elsewhere in the path. It would also leave an
    upper-case ``ECOG.NPY`` name (which Windows' case-insensitive glob can
    match) unchanged, making the ECoG file its own "DBS" partner.
    """
    return ecog_file[: -len("ecog.npy")] + "dbs.npy"


def _write_summary(summary_path, total_subjects, completed, successful, results):
    """Overwrite *summary_path* with run totals followed by one record per result.

    The text is assembled before the file is opened, so an error while
    formatting a record does not leave a truncated summary behind.
    """
    lines = [
        "Processing Summary:\n",
        f"Total subjects: {total_subjects}\n",
        f"Completed: {completed}\n",
        f"Successful: {successful}\n",
        f"Current time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n",
    ]
    for result in results:
        lines.extend([
            f"Subject: {result['subject']}\n",
            f"Success: {result['success']}\n",
            f"Duration: {format_time(result['duration'])}\n",
            f"Processed at: {result['processed_at']}\n\n",
        ])
    with open(summary_path, "w", encoding="utf-8") as f:
        f.writelines(lines)


def batch_process_all_subjects(input_dir, output_dir, config_path):
    """Run ``process_single_subject`` on every ECoG/DBS file pair in a directory.

    Each ``*ecog.npy`` file in *input_dir* (non-recursive, sorted by path) is
    paired with the file whose name ends in ``dbs.npy`` instead. Files without
    a DBS partner are skipped with a warning. After each processed subject,
    ``processing_summary.txt`` in *output_dir* is rewritten with all results
    so far.

    Args:
        input_dir: Directory searched for ``*ecog.npy`` files.
        output_dir: Destination for results and the summary file; created if
            missing.
        config_path: Passed unchanged to ``process_single_subject``.

    Returns:
        List of per-subject result dicts with keys ``subject``, ``success``,
        ``duration`` and ``processed_at``, one per processed (not skipped)
        subject.
    """
    os.makedirs(output_dir, exist_ok=True)
    summary_path = os.path.join(output_dir, 'processing_summary.txt')

    ecog_files = sorted(glob(os.path.join(input_dir, "*ecog.npy")))
    total_subjects = len(ecog_files)

    print(f"\nFound {total_subjects} subjects to process")
    time.sleep(1)  # Give time to read the message

    results_summary = []
    total_start_time = time.time()
    successful_processes = 0
    # Subjects handed to process_single_subject so far, successful or not.
    # The ETA averages over these (not over successes only), so failed
    # subjects still count toward the time spent per subject.
    attempted_processes = 0

    # NOTE(review): clear_gpu_memory is imported at file top but never called
    # here — confirm whether it should run between subjects.
    with tqdm(total=total_subjects, desc="Overall Progress") as pbar:
        for subject_idx, ecog_file in enumerate(ecog_files, 1):
            dbs_file = _dbs_path_for(ecog_file)

            if not os.path.exists(dbs_file):
                # tqdm.write prints above the bar instead of breaking its line.
                tqdm.write(f"Warning: No matching DBS file found for {ecog_file}")
                pbar.update(1)
                continue

            if attempted_processes > 0:
                elapsed_time = time.time() - total_start_time
                avg_time_per_subject = elapsed_time / attempted_processes
                # Subjects still to go, including the one about to start.
                remaining_subjects = total_subjects - subject_idx + 1
                tqdm.write(
                    "\nEstimated time remaining: "
                    f"{format_time(avg_time_per_subject * remaining_subjects)}"
                )

            tqdm.write(f"\nProcessing subject {subject_idx}/{total_subjects}")
            tqdm.write(f"File: {os.path.basename(ecog_file)}")

            # NOTE(review): duration is assumed to be seconds (it is passed to
            # format_time) — confirm against process_single_subject.
            success, duration = process_single_subject(
                ecog_path=ecog_file,
                dbs_path=dbs_file,
                output_dir=output_dir,
                config_path=config_path,
            )
            attempted_processes += 1
            if success:
                successful_processes += 1

            results_summary.append({
                'subject': os.path.basename(ecog_file),
                'success': success,
                'duration': duration,
                'processed_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            })
            pbar.update(1)

            # Rewritten after every subject so results so far are on disk even
            # if a later subject crashes the run.
            _write_summary(
                summary_path,
                total_subjects,
                subject_idx,
                successful_processes,
                results_summary,
            )

    total_duration = time.time() - total_start_time
    print("\nBatch processing completed!")
    print(f"Total time: {format_time(total_duration)}")
    print(f"Successfully processed: {successful_processes}/{total_subjects}")

    # Output folder for PSD figures; nothing in this function writes to it.
    psd_dir = os.path.join(output_dir, 'psd_figures')
    os.makedirs(psd_dir, exist_ok=True)

    return results_summary

if __name__ == "__main__":
    # Run configuration: machine-specific locations, edit before running elsewhere.
    input_dir = r'E:\data_zixiao\uscf_npy_3d_4s_nor_rmbad_9'
    output_dir = r'E:\data_zixiao\raw_prediction_55'
    config_path = "../ecog_stn_icnworkstation/conf"

    # Report the PyTorch / CUDA environment before the long batch job starts,
    # so an accidental CPU-only run is visible immediately.
    print("PyTorch version: {}".format(torch.__version__))
    print("CUDA available: {}".format(torch.cuda.is_available()))
    if torch.cuda.is_available():
        print("CUDA device: {}".format(torch.cuda.get_device_name(0)))
        # memory_allocated returns bytes; divide by 1e9 for decimal gigabytes.
        print(
            "Initial GPU memory allocated: {:.2f} GB".format(
                torch.cuda.memory_allocated(0) / 1e9
            )
        )

    batch_process_all_subjects(
        input_dir=input_dir,
        output_dir=output_dir,
        config_path=config_path,
    )