In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from datetime import datetime
from hydra import initialize, compose
from ntd.diffusion_model import Diffusion
from ntd.networks import AdaConv
from ntd.utils.kernels_and_diffusion_utils import OUProcess
from trainer import train_ecog_dbs_model
from prediction import get_all_predictions_fast_simple
from utils import clear_gpu_memory
from pathlib import Path
from copy import deepcopy
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import LinearLR

class ECoGDBSDataset(Dataset):
    """Paired ECoG (conditioning) / DBS (target) dataset with optional augmentation.

    Each item is a dict ``{'cond': ecog, 'signal': dbs}`` of float32 tensors,
    which is the batch layout read by the training loops in this file.
    """

    # Used for any augmentation key the caller does not supply.
    DEFAULT_AUGMENT_PARAMS = {
        'noise_std': 0.01,
        'max_shift': 20,
        'dropout_prob': 0.1,
    }

    def __init__(self, ecog_data, dbs_data, augment=False, augment_params=None):
        """
        Args:
            ecog_data: array-like, expected shape (N, 3, 1000); copied to float32.
            dbs_data: array-like, expected shape (N, 1, 1000); copied to float32.
            augment: whether ``__getitem__`` applies random augmentation.
            augment_params: optional mapping overriding any of ``noise_std``,
                ``max_shift`` and ``dropout_prob``; missing keys fall back to
                ``DEFAULT_AUGMENT_PARAMS``.

        Raises:
            ValueError: if ECoG and DBS have different numbers of samples.
        """
        self.ecog_data = torch.tensor(ecog_data, dtype=torch.float32)
        self.dbs_data = torch.tensor(dbs_data, dtype=torch.float32)
        if len(self.ecog_data) != len(self.dbs_data):
            raise ValueError(
                f"ECoG has {len(self.ecog_data)} samples but DBS has {len(self.dbs_data)}"
            )
        self.augment = augment
        # Merge instead of replace so a partial override (e.g. only noise_std)
        # does not raise KeyError later in apply_augmentation.
        self.augment_params = {**self.DEFAULT_AUGMENT_PARAMS, **(augment_params or {})}

    def __len__(self):
        return len(self.ecog_data)

    def apply_augmentation(self, ecog, dbs):
        """Return a randomly augmented (ecog, dbs) pair.

        ``ecog`` is cloned first: ``__getitem__`` passes views into the stored
        tensors, and the dropout step below writes in place, which would
        otherwise permanently corrupt the dataset.
        """
        ecog = ecog.clone()
        params = self.augment_params

        # Additive Gaussian noise on ECoG only (applied half of the time).
        if torch.rand(1) < 0.5:
            ecog = ecog + torch.randn_like(ecog) * params['noise_std']

        # Circular time shift, identical for ECoG and DBS so they stay aligned.
        if torch.rand(1) < 0.5:
            max_shift = int(params['max_shift'])
            # +1: symmetric range [-max_shift, max_shift]; also keeps randint
            # valid when max_shift == 0.
            shift = np.random.randint(-max_shift, max_shift + 1)
            ecog = torch.roll(ecog, shift, dims=-1)
            dbs = torch.roll(dbs, shift, dims=-1)

        # Pick one random ECoG channel and zero roughly 10% of its time points.
        if torch.rand(1) < params['dropout_prob']:
            channel = np.random.randint(0, ecog.shape[0])
            keep_mask = (torch.rand_like(ecog[channel]) > 0.1).float()
            ecog[channel] = ecog[channel] * keep_mask

        return ecog, dbs

    def __getitem__(self, idx):
        ecog = self.ecog_data[idx]
        dbs = self.dbs_data[idx]

        if self.augment:
            ecog, dbs = self.apply_augmentation(ecog, dbs)

        return {
            'cond': ecog,
            'signal': dbs
        }

def train_base_model(train_dataset, val_dataset, config):
    """Train a fresh base model on the pooled training subjects.

    Thin wrapper around ``train_ecog_dbs_model``: that function returns three
    values, and only the first one (the trained model) is passed back.
    """
    print("Training base model...")
    trained_model, _, _ = train_ecog_dbs_model(
        config=config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
    )
    return trained_model

class TransferLearningTrainer:
    """Fine-tune a pretrained model with gradual unfreezing and an L2 pull
    toward the pretrained (base) weights.

    Models passed to the training methods must provide
    ``train_batch(signal, cond=...)`` (its output is averaged with
    ``torch.mean``) and ``network.blocks`` (an indexable sequence of modules).
    """

    def __init__(self, base_model, config, device='cuda'):
        self.base_model = base_model.to(device)
        self.config = config
        self.device = device
        self.lambda_reg = 0.01
        # Mixed precision / loss scaling only applies on CUDA; turning it off
        # explicitly on CPU avoids the "CUDA is not available" fallbacks.
        self.use_amp = torch.device(device).type == 'cuda'
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

    def compute_regularization_loss(self, current_model):
        """Return ``lambda_reg`` * sum over parameters of the MSE between
        ``current_model`` and the base model.

        Parameters are matched by name (a structural mismatch raises KeyError
        instead of silently pairing the wrong tensors). Frozen parameters are
        skipped: they receive no gradient, and in ``gradual_unfreeze_finetune``
        they are still identical to the base copy.
        """
        base_params = dict(self.base_model.named_parameters())
        reg_loss = 0.0
        for name, param in current_model.named_parameters():
            if not param.requires_grad:
                continue
            reg_loss = reg_loss + torch.nn.functional.mse_loss(
                param, base_params[name].detach()
            )
        return reg_loss * self.lambda_reg

    def train_one_epoch(self, model, train_loader, optimizer, epoch, scheduler=None):
        """Train for one epoch and return the mean (task + regularization) loss.

        Args:
            scheduler: optional per-batch LR scheduler (e.g. OneCycleLR sized
                with ``steps_per_epoch=len(train_loader)``); stepped after every
                optimizer step.
        """
        model.train()
        running_loss = 0.0
        batch_count = 0

        with tqdm(train_loader, desc=f'Epoch {epoch}') as pbar:
            for batch in pbar:
                optimizer.zero_grad(set_to_none=True)

                sig_batch = batch["signal"].to(self.device, non_blocking=True)
                cond_batch = batch["cond"].to(self.device, non_blocking=True)

                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    loss = torch.mean(model.train_batch(sig_batch, cond=cond_batch))
                    reg_loss = self.compute_regularization_loss(model)
                    batch_loss = loss + reg_loss

                self.scaler.scale(batch_loss).backward()
                self.scaler.step(optimizer)
                self.scaler.update()
                if scheduler is not None:
                    scheduler.step()

                # Accumulate a Python float (not the graph-holding tensor) so
                # the epoch average is over all batches.
                loss_value = batch_loss.item()
                running_loss += loss_value
                batch_count += 1
                pbar.set_postfix({'loss': loss_value})

                if batch_count % 50 == 0:
                    clear_gpu_memory()

        return running_loss / max(batch_count, 1)

    def validate(self, model, val_loader):
        """Return the mean ``train_batch`` loss over ``val_loader`` without
        gradients; ``inf`` if the loader yields no batches."""
        model.eval()
        total_loss = 0.0
        batch_count = 0

        with torch.no_grad():
            for batch in val_loader:
                sig_batch = batch["signal"].to(self.device, non_blocking=True)
                cond_batch = batch["cond"].to(self.device, non_blocking=True)

                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    loss = torch.mean(model.train_batch(sig_batch, cond=cond_batch))

                total_loss += loss.item()
                batch_count += 1

        if batch_count == 0:
            return float('inf')
        return total_loss / batch_count

    def gradual_unfreeze_finetune(self, model, train_loader, val_loader, finetune_config):
        """Fine-tune ``model`` by unfreezing ``model.network.blocks`` one at a
        time, from last to first.

        All of ``model.network`` is frozen first; parameters of ``model``
        outside ``network`` keep their current ``requires_grad``. Each stage
        gets a fresh AdamW + OneCycleLR and runs
        ``finetune_config.optimizer.num_epochs_per_block`` epochs.

        Returns:
            (model, best_val_loss): ``model`` is loaded with the weights that had
            the lowest validation loss across all stages.
        """
        opt_cfg = finetune_config.optimizer
        best_val_loss = float('inf')
        best_model_state = None

        for param in model.network.parameters():
            param.requires_grad = False

        num_blocks = len(model.network.blocks)
        for block_idx in reversed(range(num_blocks)):
            print(f"\nUnfreezing block {block_idx}")

            for param in model.network.blocks[block_idx].parameters():
                param.requires_grad = True

            optimizer = torch.optim.AdamW(
                [p for p in model.parameters() if p.requires_grad],
                lr=opt_cfg.lr,
                weight_decay=opt_cfg.weight_decay
            )
            # OneCycleLR counts optimizer steps (epochs * batches), so it is
            # stepped once per batch inside train_one_epoch, not per epoch.
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=opt_cfg.lr,
                epochs=opt_cfg.num_epochs_per_block,
                steps_per_epoch=len(train_loader)
            )

            for epoch in range(opt_cfg.num_epochs_per_block):
                train_loss = self.train_one_epoch(
                    model, train_loader, optimizer, epoch, scheduler=scheduler
                )
                val_loss = self.validate(model, val_loader)

                print(f"Block {block_idx}, Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_model_state = deepcopy(model.state_dict())

        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        else:
            # No finite validation loss was ever recorded (NaN losses, empty
            # validation set, or no blocks); keep the last-trained weights.
            print("Warning: no improving validation loss recorded; returning last-trained model.")
        return model, best_val_loss

def finetune_and_evaluate(base_model, new_subject_data, new_subject_labels, output_dir, new_sub_name, config,
                         finetune_epoch, batch_size):
    """Fine-tune a pretrained model on one held-out subject, predict, and save.

    The subject's samples are split in their stored order (no shuffling):
    the first 60% for training (with augmentation), the next 20% for
    validation, and the remainder for testing. A deep copy of ``base_model``
    is fine-tuned with ``TransferLearningTrainer.gradual_unfreeze_finetune``;
    ``base_model`` itself is the regularization anchor.

    Args:
        base_model: pretrained diffusion model.
        new_subject_data: ECoG array used as the conditioning input.
        new_subject_labels: DBS array used as the target signal.
        output_dir: directory for the checkpoint and predictions (created if
            missing).
        new_sub_name: subject identifier used in the output file names.
        config: base config; left untouched (a deep copy is adjusted).
        finetune_epoch: stored as ``optimizer.num_epochs`` in the saved
            fine-tuning config. NOTE(review): the gradual-unfreezing loop reads
            ``num_epochs_per_block`` (set to 20 here), not this value.
        batch_size: batch size for test-set prediction.

    Returns:
        (results, finetuned_model). ``results`` is what
        ``get_all_predictions_fast_simple`` returns; its 'real_dbs' and
        'imputed_dbs' entries are saved stacked to
        ``predicted_results_<new_sub_name>.npy``.

    Raises:
        ValueError: if the subject has too few samples for non-empty
            train/validation/test splits.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(output_dir, exist_ok=True)

    # Contiguous 60/20/20 split boundaries.
    n_samples = len(new_subject_data)
    train_end = int(0.6 * n_samples)
    val_end = train_end + int(0.2 * n_samples)
    if train_end == 0 or val_end == train_end or val_end >= n_samples:
        raise ValueError(
            f"Subject {new_sub_name} has too few samples ({n_samples}) for a "
            f"60/20/20 train/val/test split"
        )

    # NOTE(review): augmentation uses the dataset defaults; the values written to
    # config.dataset.augmentation by the driver loop are not forwarded here.
    train_dataset = ECoGDBSDataset(
        new_subject_data[:train_end],
        new_subject_labels[:train_end],
        augment=True
    )
    val_dataset = ECoGDBSDataset(
        new_subject_data[train_end:val_end],
        new_subject_labels[train_end:val_end]
    )
    test_dataset = ECoGDBSDataset(
        new_subject_data[val_end:],
        new_subject_labels[val_end:]
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.optimizer.train_batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.optimizer.train_batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )

    # Fine-tuning uses a 10x smaller peak learning rate than base training.
    finetune_config = deepcopy(config)
    finetune_config.optimizer.lr *= 0.1
    finetune_config.optimizer.warmup_epochs = 5
    finetune_config.optimizer.num_epochs_per_block = 20
    finetune_config.optimizer.num_epochs = finetune_epoch

    trainer = TransferLearningTrainer(base_model, finetune_config, device)

    # Fine-tune a copy so base_model stays the fixed regularization target.
    finetuned_model, best_val_loss = trainer.gradual_unfreeze_finetune(
        deepcopy(base_model),
        train_loader,
        val_loader,
        finetune_config
    )

    print("Generating predictions...")
    results = get_all_predictions_fast_simple(finetuned_model, test_dataset, batch_size=batch_size)

    save_path = os.path.join(output_dir, f"finetuned_model_{new_sub_name}.pt")
    torch.save({
        'model_state_dict': finetuned_model.state_dict(),
        'best_val_loss': best_val_loss,
        'config': finetune_config
    }, save_path)

    results_path = os.path.join(output_dir, f"predicted_results_{new_sub_name}.npy")
    np.save(results_path, np.stack([results['real_dbs'], results['imputed_dbs']]))

    return results, finetuned_model
    
def load_subject_data(subject_id, data_dir):
    """Read one subject's DBS and ECoG recordings from ``data_dir``.

    Expects ``<subject_id>_dbs.npy`` and ``<subject_id>_ecog.npy`` in the same
    directory; both arrays are cast to float32.

    Returns:
        (dbs_data, ecog_data), documented as shapes (n_samples, 1, 1000) and
        (n_samples, 3, 1000) respectively.
    """
    def _read_float32(file_name):
        full_path = os.path.join(data_dir, file_name)
        return np.load(full_path).astype(np.float32)

    dbs_data = _read_float32(f'{subject_id}_dbs.npy')
    ecog_data = _read_float32(f'{subject_id}_ecog.npy')
    return dbs_data, ecog_data

def create_train_test_split(test_subject, data_dir, subject_list=None):
    """Build a leave-one-subject-out split.

    Args:
        test_subject: ID of the held-out subject.
        data_dir: directory containing the per-subject ``.npy`` files.
        subject_list: IDs to draw training subjects from. Defaults to the
            module-level ``subject_ids`` (the previous, implicit behavior).

    Returns:
        ((train_dbs, train_ecog), (test_dbs, test_ecog)), where the training
        arrays are all non-test subjects concatenated along axis 0.

    Raises:
        ValueError: if no subject other than ``test_subject`` is available.
    """
    if subject_list is None:
        subject_list = subject_ids

    test_dbs, test_ecog = load_subject_data(test_subject, data_dir)

    train_dbs = []
    train_ecog = []
    for subject_id in subject_list:
        if subject_id == test_subject:
            continue
        dbs_data, ecog_data = load_subject_data(subject_id, data_dir)
        train_dbs.append(dbs_data)
        train_ecog.append(ecog_data)

    if not train_dbs:
        raise ValueError(f"No training subjects left after holding out {test_subject!r}")

    train_dbs = np.concatenate(train_dbs, axis=0)
    train_ecog = np.concatenate(train_ecog, axis=0)

    return (train_dbs, train_ecog), (test_dbs, test_ecog)

def create_model(config):
    """Build the conditional diffusion model described by ``config``.

    An ``AdaConv`` network and a single ``OUProcess`` instance (passed as both
    ``noise_sampler`` and ``mal_dist_computer``) are wrapped in ``Diffusion``.
    Every module is placed on CUDA when available, otherwise on CPU.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net_cfg = config.network
    signal_length = config.dataset.signal_length

    network = AdaConv(
        signal_length=signal_length,
        signal_channel=net_cfg.signal_channel,
        cond_dim=net_cfg.cond_dim,
        hidden_channel=net_cfg.hidden_channel,
        in_kernel_size=net_cfg.in_kernel_size,
        out_kernel_size=net_cfg.out_kernel_size,
        slconv_kernel_size=net_cfg.slconv_kernel_size,
        num_scales=net_cfg.num_scales,
        num_blocks=net_cfg.num_blocks,
        num_off_diag=net_cfg.num_off_diag,
        use_pos_emb=net_cfg.use_pos_emb,
        padding_mode=net_cfg.padding_mode,
        use_fft_conv=net_cfg.use_fft_conv,
    ).to(target_device)

    kernel_cfg = config.diffusion_kernel
    ou_process = OUProcess(
        kernel_cfg.sigma_squared,
        kernel_cfg.ell,
        signal_length
    ).to(target_device)

    diff_cfg = config.diffusion
    model = Diffusion(
        network=network,
        noise_sampler=ou_process,
        mal_dist_computer=ou_process,
        diffusion_time_steps=diff_cfg.diffusion_steps,
        schedule=diff_cfg.schedule,
        start_beta=diff_cfg.start_beta,
        end_beta=diff_cfg.end_beta,
    )
    return model.to(target_device)

In [None]:
def analyze_data_characteristics(ecog_data, dbs_data, subject_id=None, max_autocorr_samples=1000):
    """
    Summarize ECoG and DBS data to inform augmentation parameters.

    Args:
        ecog_data: array of shape (N, C, T); (N, 3, 1000) in this project.
        dbs_data: array of shape (N, 1, T); only channel 0 is autocorrelated.
        subject_id: currently unused; kept for interface compatibility.
        max_autocorr_samples: number of leading samples used for the
            autocorrelation estimates (they are the expensive part).

    Returns:
        dict with:
            'ecog_mean', 'ecog_std', 'dbs_mean', 'dbs_std': per-channel
                statistics over samples and time.
            'ecog_range', 'dbs_range': 1st/99th percentiles, shape (2, channels).
            'temporal_stats': 'ecog_ch{i}_autocorr' and 'dbs_autocorr', the mean
                raw autocorrelation in ``np.correlate(..., mode='full')`` layout
                (length 2T-1, zero lag at index T-1).
            'cross_corr': (C, C) symmetric matrix of per-sample Pearson
                correlations between ECoG channels, averaged over samples;
                non-finite per-sample values (flat channels) are ignored and
                the diagonal is 0.
    """
    results = {}

    # Per-channel statistics over samples and time.
    results['ecog_mean'] = np.mean(ecog_data, axis=(0, 2))
    results['ecog_std'] = np.std(ecog_data, axis=(0, 2))
    results['dbs_mean'] = np.mean(dbs_data, axis=(0, 2))
    results['dbs_std'] = np.std(dbs_data, axis=(0, 2))

    results['ecog_range'] = np.percentile(ecog_data, [1, 99], axis=(0, 2))
    results['dbs_range'] = np.percentile(dbs_data, [1, 99], axis=(0, 2))

    # Mean autocorrelation per channel over the first few samples.
    results['temporal_stats'] = {}
    n_channels = ecog_data.shape[1]
    n_ecog_auto = min(max_autocorr_samples, ecog_data.shape[0])
    for i in range(n_channels):
        autocorr = np.array([np.correlate(ecog_data[j, i], ecog_data[j, i], mode='full')
                             for j in range(n_ecog_auto)])
        results['temporal_stats'][f'ecog_ch{i}_autocorr'] = np.mean(autocorr, axis=0)

    n_dbs_auto = min(max_autocorr_samples, dbs_data.shape[0])
    autocorr_dbs = np.array([np.correlate(dbs_data[j, 0], dbs_data[j, 0], mode='full')
                             for j in range(n_dbs_auto)])
    results['temporal_stats']['dbs_autocorr'] = np.mean(autocorr_dbs, axis=0)

    # Pearson correlation is symmetric: compute the upper triangle once and mirror.
    cross_corr = np.zeros((n_channels, n_channels))
    # A flat channel makes corrcoef divide by zero; those samples become NaN
    # and are excluded from the average below.
    with np.errstate(invalid='ignore', divide='ignore'):
        for i in range(n_channels):
            for j in range(i + 1, n_channels):
                per_sample = np.array([np.corrcoef(ecog_data[k, i], ecog_data[k, j])[0, 1]
                                       for k in range(ecog_data.shape[0])])
                finite = per_sample[np.isfinite(per_sample)]
                value = finite.mean() if finite.size else np.nan
                cross_corr[i, j] = value
                cross_corr[j, i] = value
    results['cross_corr'] = cross_corr
    return results

def suggest_augmentation_params(results, lag_threshold=0.5, default_max_shift=20):
    """
    Suggest augmentation parameters from ``analyze_data_characteristics`` output.

    Heuristics:
        noise_std: 10% of the mean per-channel ECoG standard deviation.
        max_shift: the first positive lag k at which the channel-averaged ECoG
            autocorrelation falls below ``lag_threshold`` times its zero-lag
            value; ``default_max_shift`` if it never does.
        dropout_prob: ``min(0.2, 1 - mean |off-diagonal cross-correlation|)``.

    Args:
        results: dict with 'ecog_std', 'temporal_stats' (autocorrelations in
            ``np.correlate`` 'full' layout, zero lag at the centre) and
            'cross_corr' (square channel matrix).
        lag_threshold: relative autocorrelation level defining max_shift.
        default_max_shift: fallback when no lag crosses the threshold.

    Returns:
        dict with float 'noise_std', int 'max_shift', float 'dropout_prob'.
    """
    suggestions = {}

    suggestions['noise_std'] = float(np.mean(results['ecog_std']) * 0.1)

    ecog_autocorrs = [value for key, value in results['temporal_stats'].items()
                      if key.startswith('ecog_ch')]
    suggestions['max_shift'] = default_max_shift
    if ecog_autocorrs:
        mean_autocorr = np.mean(ecog_autocorrs, axis=0)
        # 'full' mode has length 2T-1 with zero lag at index T-1; keep lags >= 0
        # (index 0 of the full array is lag -(T-1), not lag 0).
        positive_lags = mean_autocorr[len(mean_autocorr) // 2:]
        below = np.nonzero(positive_lags < lag_threshold * positive_lags[0])[0]
        if below.size > 0:
            suggestions['max_shift'] = int(below[0])

    # Average only the off-diagonal entries; the diagonal is 0 by construction
    # and would bias the mean toward zero.
    cross_corr = np.asarray(results['cross_corr'], dtype=float)
    off_diag = cross_corr[~np.eye(cross_corr.shape[0], dtype=bool)]
    off_diag = off_diag[np.isfinite(off_diag)]
    mean_cross_corr = float(np.mean(np.abs(off_diag))) if off_diag.size else 0.0
    suggestions['dropout_prob'] = float(min(0.2, 1 - mean_cross_corr))

    return suggestions

def analyze_subject_data(data_dir, subject_id):
    """Load one subject, compute its data statistics and print the suggested
    augmentation parameters.

    Returns:
        (results, suggestions) from ``analyze_data_characteristics`` and
        ``suggest_augmentation_params`` respectively.
    """
    dbs_data, ecog_data = load_subject_data(subject_id, data_dir)
    stats = analyze_data_characteristics(ecog_data, dbs_data, subject_id)
    params = suggest_augmentation_params(stats)

    # Human-readable summary of the suggestions.
    print(f"\nSuggested augmentation parameters for subject {subject_id}:")
    print(f"noise_std: {params['noise_std']:.4f}")
    print(f"max_shift: {params['max_shift']}")
    print(f"dropout_prob: {params['dropout_prob']:.4f}")

    return stats, params

In [None]:
# Define the data directory
data_dir = r'E:\data_zixiao\uscf_npy_3d_4s_nor_rmbad_9'
DBS_SUFFIX = '_dbs.npy'

# Subject IDs come from the DBS file names '<subject_id>_dbs.npy'; strip the
# suffix exactly rather than splitting on it.
dbs_files = sorted(f for f in os.listdir(data_dir) if f.endswith(DBS_SUFFIX))
subject_ids = [f[:-len(DBS_SUFFIX)] for f in dbs_files]
if not subject_ids:
    raise FileNotFoundError(f"No '*{DBS_SUFFIX}' files found in {data_dir}")

# Initialize config. NOTE(review): Hydra resolves config_path relative to the
# caller, so this depends on where the notebook runs from.
with initialize(version_base=None, config_path="../ecog_stn_icnworkstation/conf"):
    config = compose(config_name="62_config_ou_200_more_complex_transfer")
output_dir = r"E:\data_zixiao\raw_prediction_62_3"
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Leave-one-subject-out: for each subject, train a base model on all other
# subjects, fine-tune it on part of the held-out subject, and save predictions.
for test_subject in tqdm(subject_ids):
    print(f"\nProcessing subject: {test_subject}")

    (dbs_data, ecog_data), (new_subject_dbs, new_subject_ecog) = create_train_test_split(
        test_subject, data_dir
    )

    print("Training data shapes:")
    print(f"DBS: {dbs_data.shape}, ECoG: {ecog_data.shape}")
    print("Testing data shapes:")
    print(f"DBS: {new_subject_dbs.shape}, ECoG: {new_subject_ecog.shape}")

    # NOTE(review): these statistics use ALL of the held-out subject's data,
    # including the portion later used as its test set.
    data_stats, suggestions = analyze_subject_data(data_dir, test_subject)

    config.dataset.augmentation.noise_std = suggestions['noise_std']
    config.dataset.augmentation.max_shift = suggestions['max_shift']
    config.dataset.augmentation.dropout_prob = suggestions['dropout_prob']

    # Random 70/30 train/val split of the pooled training subjects.
    full_dataset = ECoGDBSDataset(ecog_data, dbs_data)
    train_size = int(0.7 * len(full_dataset))
    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, len(full_dataset) - train_size]
    )
    base_model = train_base_model(train_dataset, val_dataset, config)

    results, finetuned_model = finetune_and_evaluate(
        base_model,
        new_subject_ecog,
        new_subject_dbs,
        output_dir=output_dir,
        config=config,
        new_sub_name=test_subject,
        finetune_epoch=80,
        batch_size=1024
    )

    # Drop every reference to this fold's data and models; otherwise the
    # previous base/fine-tuned models stay resident while the next base model
    # trains, and empty_cache() cannot release their GPU memory.
    del dbs_data, ecog_data, new_subject_dbs, new_subject_ecog
    del full_dataset, train_dataset, val_dataset
    del base_model, finetuned_model
    torch.cuda.empty_cache()