In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from datetime import datetime
from hydra import initialize, compose
from ntd.diffusion_model import Diffusion
from ntd.networks import AdaConv
from ntd.utils.kernels_and_diffusion_utils import OUProcess
from trainer import train_ecog_dbs_model
from prediction import get_all_predictions_fast_simple
from utils import clear_gpu_memory
from pathlib import Path

class ECoGDBSDataset(Dataset):
    """Paired ECoG / DBS samples served as dictionaries.

    Both inputs are converted to ``float32`` tensors once, at construction
    time. Indexing returns the ECoG sample under ``'cond'`` and the DBS
    sample under ``'signal'``.
    """

    def __init__(self, ecog_data, dbs_data):
        """
        Args:
            ecog_data: ECoG samples, expected shape (N, 3, 1000)
            dbs_data: DBS samples, expected shape (N, 1, 1000)
        """
        float32 = torch.float32
        self.ecog_data = torch.tensor(ecog_data, dtype=float32)
        self.dbs_data = torch.tensor(dbs_data, dtype=float32)

    def __len__(self):
        # The sample count is taken from the ECoG tensor alone.
        return len(self.ecog_data)

    def __getitem__(self, idx):
        cond = self.ecog_data[idx]
        signal = self.dbs_data[idx]
        return dict(cond=cond, signal=signal)

def train_base_model(train_dataset, val_dataset, config):
    """Run the shared training routine on the initial dataset.

    Delegates to ``train_ecog_dbs_model`` and returns only the first element
    of the 3-tuple it produces (the trained diffusion model); the other two
    elements are discarded.
    """
    print("Training base model...")
    training_kwargs = {
        "train_dataset": train_dataset,
        "val_dataset": val_dataset,
        "config": config,
    }
    trained_model, _second, _third = train_ecog_dbs_model(**training_kwargs)
    return trained_model

def finetune_and_evaluate(base_model, new_subject_data, new_subject_labels, output_dir, new_sub_name, config,
                         finetune_epoch, batch_size):
    """Fine-tune on a new subject's data, evaluate on its held-out tail, and save the results.

    The first 70% of the samples (in the order given, no shuffling) are used
    for training; the remaining 30% serve both as the validation set during
    training and as the test set for the saved predictions.

    Args:
        base_model: Previously trained model.
            NOTE(review): it is not passed to ``train_ecog_dbs_model``, so
            unless that routine picks up pretrained weights some other way
            (e.g. via ``config``), training here starts from scratch --
            confirm against the trainer's signature.
        new_subject_data: ECoG samples of the new subject (used as ``'cond'``).
        new_subject_labels: DBS samples of the new subject (used as
            ``'signal'``); must have the same length as ``new_subject_data``.
        output_dir: Directory receiving the checkpoint and the predictions.
        new_sub_name: Subject identifier embedded in the output file names.
        config: Training configuration; it is deep-copied, never modified.
        finetune_epoch: Number of epochs to train for.
        batch_size: Batch size used when generating predictions.

    Returns:
        The result mapping from ``get_all_predictions_fast_simple``; its
        ``'real_dbs'`` and ``'imputed_dbs'`` entries are also written to disk.

    Raises:
        ValueError: If data and labels differ in length.
    """
    import copy

    n_samples = len(new_subject_data)
    if n_samples != len(new_subject_labels):
        raise ValueError(
            f"new_subject_data has {n_samples} samples but new_subject_labels "
            f"has {len(new_subject_labels)}"
        )

    # Chronological 70/30 split of the new subject's samples.
    train_size = int(0.7 * n_samples)
    train_dataset = ECoGDBSDataset(new_subject_data[:train_size], new_subject_labels[:train_size])
    test_dataset = ECoGDBSDataset(new_subject_data[train_size:], new_subject_labels[train_size:])

    # Deep copy: a shallow config.copy() would share the nested ``optimizer``
    # node with the caller's config, so the overrides below would leak out and
    # the learning-rate reduction would compound on every call.
    finetune_config = copy.deepcopy(config)
    finetune_config.optimizer.lr *= 0.1  # Lower learning rate for fine-tuning
    finetune_config.optimizer.num_epochs = finetune_epoch  # Fewer epochs for fine-tuning

    print("Fine-tuning model...")
    finetuned_model, _, _ = train_ecog_dbs_model(
        train_dataset=train_dataset,
        val_dataset=test_dataset,  # NOTE(review): the test split doubles as the validation set
        config=finetune_config
    )

    print("Generating predictions...")
    results = get_all_predictions_fast_simple(finetuned_model, test_dataset, batch_size=batch_size)

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Persist the fine-tuned weights (not the untouched base model).
    torch.save(finetuned_model.state_dict(),
               os.path.join(output_dir, f"finetuned_model_{new_sub_name}.pt"))

    # Save actual values and predictions stacked along a new leading axis: [real, imputed].
    np.save(os.path.join(output_dir, f"predicted_results_{new_sub_name}.npy"),
            np.stack([results['real_dbs'], results['imputed_dbs']]))
    return results

def load_subject_data(subject_id, data_dir):
    """Read one subject's DBS and ECoG recordings from ``data_dir``.

    Looks for ``<subject_id>_dbs.npy`` and ``<subject_id>_ecog.npy`` and
    returns ``(dbs_data, ecog_data)``, both cast to ``float32``. Per the
    dataset docs in this file the arrays are expected to be
    (n_samples, 1, 1000) and (n_samples, 3, 1000) respectively.
    """
    def _load_as_float32(modality):
        npy_path = os.path.join(data_dir, f'{subject_id}_{modality}.npy')
        return np.load(npy_path).astype(np.float32)

    # DBS is read first, then ECoG -- the same order they are returned in.
    return _load_as_float32('dbs'), _load_as_float32('ecog')

def create_train_test_split(test_subject, data_dir, all_subject_ids=None):
    """Create training and testing arrays using leave-one-subject-out.

    Args:
        test_subject: Identifier of the held-out subject.
        data_dir: Directory containing the per-subject ``.npy`` files.
        all_subject_ids: Iterable of every available subject identifier.
            When omitted, the module-level ``subject_ids`` list is used,
            which must already be defined at call time.

    Returns:
        ``((train_dbs, train_ecog), (test_dbs, test_ecog))`` where the
        training arrays are the other subjects' data concatenated along
        axis 0, in the order the subjects are listed.

    Raises:
        ValueError: If no subject other than ``test_subject`` is available,
            i.e. the training set would be empty.
    """
    candidate_ids = subject_ids if all_subject_ids is None else all_subject_ids
    train_ids = [sid for sid in candidate_ids if sid != test_subject]
    if not train_ids:
        # Fail before any file I/O, with a clearer message than numpy's
        # "need at least one array to concatenate".
        raise ValueError(
            f"No training subjects left after holding out {test_subject!r}"
        )

    # Load test subject data
    test_dbs, test_ecog = load_subject_data(test_subject, data_dir)

    # Load all other subjects' data for training
    train_dbs = []
    train_ecog = []
    for sid in train_ids:
        dbs_data, ecog_data = load_subject_data(sid, data_dir)
        train_dbs.append(dbs_data)
        train_ecog.append(ecog_data)

    # Concatenate all training data along the sample axis
    train_dbs = np.concatenate(train_dbs, axis=0)
    train_ecog = np.concatenate(train_ecog, axis=0)

    return (train_dbs, train_ecog), (test_dbs, test_ecog)

def create_model(config):
    """Assemble an untrained ``Diffusion`` model from ``config``.

    Builds an ``AdaConv`` network and an ``OUProcess``, moves each to CUDA
    when available (CPU otherwise), and wraps them in a ``Diffusion`` object
    in which the same ``OUProcess`` instance is used as both the noise
    sampler and the ``mal_dist_computer``.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    signal_length = config.dataset.signal_length
    net_cfg = config.network
    kernel_cfg = config.diffusion_kernel
    diff_cfg = config.diffusion

    # Every AdaConv argument except signal_length is read verbatim from
    # config.network under the same name.
    network_fields = (
        "signal_channel",
        "cond_dim",
        "hidden_channel",
        "in_kernel_size",
        "out_kernel_size",
        "slconv_kernel_size",
        "num_scales",
        "num_blocks",
        "num_off_diag",
        "use_pos_emb",
        "padding_mode",
        "use_fft_conv",
    )
    network_kwargs = {field: getattr(net_cfg, field) for field in network_fields}
    denoiser = AdaConv(signal_length=signal_length, **network_kwargs).to(target_device)

    ou = OUProcess(kernel_cfg.sigma_squared, kernel_cfg.ell, signal_length).to(target_device)

    model = Diffusion(
        network=denoiser,
        noise_sampler=ou,
        mal_dist_computer=ou,
        diffusion_time_steps=diff_cfg.diffusion_steps,
        schedule=diff_cfg.schedule,
        start_beta=diff_cfg.start_beta,
        end_beta=diff_cfg.end_beta,
    )
    return model.to(target_device)

In [None]:
# Directory holding the per-subject <id>_dbs.npy / <id>_ecog.npy files
data_dir = r'E:\data_zixiao\uscf_npy_3d_4s_nor_rmbad_9'
# Discover subjects from the DBS files; the ID is everything before '_dbs.npy'
dbs_files = sorted(
    filter(lambda name: name.endswith('_dbs.npy'), os.listdir(data_dir))
)
subject_ids = [name.partition('_dbs.npy')[0] for name in dbs_files]

In [None]:
# Directory that receives saved model checkpoints
output_dir = r"E:\data_zixiao\raw_prediction_62_2"

# Compose the Hydra configuration used for training
with initialize(config_path="../ecog_stn_icnworkstation/conf", version_base=None):
    config = compose(config_name="54_config_ou_200_more_complex")

In [None]:
def create_generalized_model(subject_ids, data_dir, config, output_dir):
    """Train one model on data pooled from all subjects, validating on held-out subjects.

    Roughly 30% of the subjects (at least one) are drawn at random with
    ``np.random.choice`` and all of their samples form the validation set;
    the remaining subjects' samples form the training set, so no subject
    contributes to both. Seed NumPy's global RNG beforehand for a
    reproducible split.

    Args:
        subject_ids: Identifiers of the subjects to pool (at least two distinct).
        data_dir: Directory containing the per-subject ``.npy`` files.
        config: Training configuration forwarded to ``train_base_model``.
        output_dir: Directory receiving the checkpoint; created if missing.

    Returns:
        The trained model returned by ``train_base_model``.

    Raises:
        ValueError: If fewer than two distinct subjects are given, or if the
            split leaves no training samples.
    """
    if len(set(subject_ids)) < 2:
        # One subject is always held out for validation, so a single subject
        # would leave nothing to train on.
        raise ValueError(
            "create_generalized_model needs at least two distinct subjects "
            "(one is always held out for validation)"
        )

    all_dbs = []
    all_ecog = []
    all_subject_indices = []
    # Load all subjects' data
    for subject_id in subject_ids:
        dbs_data, ecog_data = load_subject_data(subject_id, data_dir)
        all_dbs.append(dbs_data)
        all_ecog.append(ecog_data)
        # Track which sample belongs to which subject
        all_subject_indices.extend([subject_id] * len(dbs_data))

    # Concatenate all data along the sample axis
    all_dbs = np.concatenate(all_dbs, axis=0)
    all_ecog = np.concatenate(all_ecog, axis=0)
    all_subject_indices = np.array(all_subject_indices)

    # Randomly select whole subjects for validation
    unique_subjects = np.unique(all_subject_indices)
    n_val_subjects = max(1, int(0.3 * len(unique_subjects)))  # At least 1 subject for validation
    val_subjects = np.random.choice(unique_subjects, size=n_val_subjects, replace=False)
    # The selection is random, so record it for later reference.
    print(f"Validation subjects: {sorted(val_subjects.tolist())}")

    # Per-sample masks; the training mask is simply the complement.
    val_mask = np.isin(all_subject_indices, val_subjects)
    train_mask = ~val_mask
    if not train_mask.any():
        raise ValueError("Subject-wise split left no training samples")

    # Create datasets ensuring subject independence
    train_dataset = ECoGDBSDataset(all_ecog[train_mask], all_dbs[train_mask])
    val_dataset = ECoGDBSDataset(all_ecog[val_mask], all_dbs[val_mask])

    # Train generalized model
    model = train_base_model(train_dataset, val_dataset, config)

    # Save the model
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, "generalized_model_usingall9ucsf.pt"))
    return model

# Train and save the pooled model using the notebook-level settings defined above
generalized_model = create_generalized_model(
    subject_ids=subject_ids,
    data_dir=data_dir,
    config=config,
    output_dir=output_dir,
)