In [None]:
import ast
import copy
import gc

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.fft as fft
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import pearsonr, zscore
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [None]:
def apply_hilbert_torch(x, envelope=False, do_log=False, compute_val='power', data_srate=250):
    """Compute a Hilbert-transform feature of ``x`` along its last dimension.

    Args:
        x: Real-valued tensor; the transform is taken over the last dimension.
        envelope: Unused; kept so existing callers keep working.
        do_log: Only used with ``compute_val='power'``; applies ``log1p`` to
            the magnitude.
        compute_val: Which feature to return:
            ``'power'``     - magnitude ``|analytic signal|`` (note: the
                              magnitude itself, not its square);
            ``'phase'``     - unwrapped instantaneous phase in radians;
            ``'freqslide'`` - first difference of the unwrapped phase scaled
                              by ``data_srate / (2*pi)``, zero-padded at the
                              end so the length matches ``x``.
        data_srate: Sampling rate used to scale the ``'freqslide'`` output.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``compute_val`` is not one of the supported modes.
    """

    def hilbert_torch(x):
        """Return the analytic signal of ``x`` via the FFT method."""
        N = x.size(-1)

        # Disable mixed precision for FFT
        with torch.cuda.amp.autocast(enabled=False):
            Xf = fft.fft(x.float(), dim=-1)

        # Spectral mask: keep DC (and Nyquist for even N), double the positive
        # frequencies, zero the negative ones.
        h = torch.zeros(N, dtype=torch.complex64, device=x.device)
        if N % 2 == 0:
            h[0] = h[N // 2] = 1
            h[1:N // 2] = 2
        else:
            h[0] = 1
            h[1:(N + 1) // 2] = 2

        return fft.ifft(Xf * h, dim=-1)

    def angle_custom(z):
        """Return the phase angle of a complex tensor in radians."""
        return torch.atan2(z.imag, z.real)

    def unwrap(p, discont=np.pi):
        """Remove 2*pi jumps along the last dimension (same rule as ``numpy.unwrap``)."""
        dp = p[..., 1:] - p[..., :-1]
        # Each step wrapped into [-pi, pi).
        dp_wrapped = torch.remainder(dp + np.pi, 2 * np.pi) - np.pi
        # A step of exactly +pi should stay +pi rather than become -pi.
        dp_wrapped = torch.where(
            (dp_wrapped == -np.pi) & (dp > 0),
            torch.full_like(dp_wrapped, np.pi),
            dp_wrapped,
        )
        # The correction is the difference between the wrapped and the raw
        # step, applied only where the raw step is a real discontinuity.
        correction = torch.where(
            torch.abs(dp) < discont,
            torch.zeros_like(dp),
            dp_wrapped - dp,
        )
        p_unwrapped = p.clone()
        p_unwrapped[..., 1:] = p[..., 1:] + torch.cumsum(correction, dim=-1)
        return p_unwrapped

    def diff(x):
        """First-order difference along the last dimension."""
        return x[..., 1:] - x[..., :-1]

    if compute_val not in ('power', 'phase', 'freqslide'):
        raise ValueError(
            f"compute_val must be 'power', 'phase' or 'freqslide', got {compute_val!r}"
        )

    hilb_sig = hilbert_torch(x)

    if compute_val == 'power':
        out = torch.abs(hilb_sig)
        if do_log:
            out = torch.log1p(out)
    elif compute_val == 'phase':
        out = unwrap(angle_custom(hilb_sig))
    else:  # 'freqslide'
        ang = angle_custom(hilb_sig)
        ang = data_srate * diff(unwrap(ang)) / (2 * np.pi)
        # diff() drops one sample; pad a zero at the end to restore the length.
        out = torch.nn.functional.pad(ang, (0, 1), mode='constant')
        # TO DO: apply median filter (use torch.median or a custom implementation)
    return out
    
class CtxNet(nn.Module):
    """Convolutional regressor with a Hilbert-magnitude stage after block 1.

    ``forward`` takes a ``(batch, Chans, Samples)`` tensor, adds a singleton
    feature dimension, runs three conv/pool blocks (with the Hilbert magnitude
    applied between block 1 and block 2), flattens, and maps the result
    through a dense head to one value per sample.

    Args:
        Chans: Number of input channels (height of the 2-D input).
        Samples: Number of time samples (width of the 2-D input).
        dropoutRate: Dropout probability used in every block and in the head.
        kernLength: Temporal kernel length of the first convolution.
        F1: Number of filters of the first convolution.
        D: Multiplier for the grouped convolution (``F1 * D`` outputs).
        F2: Number of filters in block 2.
        F3: Number of filters in block 3.
        norm_rate: Unused; kept for interface compatibility.
        kernLength_sep: Kernel length of block 2 (block 3 uses half of it).
        do_log: Forwarded to ``apply_hilbert_torch`` (``log1p`` of magnitude).
        data_srate: Forwarded to ``apply_hilbert_torch``.
        base_split: Unused; kept for interface compatibility.
    """

    def __init__(self, Chans=3, Samples=375, dropoutRate=0.65, kernLength=64, F1=4,
                 D=2, F2=8, F3=16, norm_rate=0.25, kernLength_sep=16,
                 do_log=False, data_srate=1, base_split=4):
        super().__init__()
        self.do_log = do_log
        self.data_srate = data_srate

        # Block 1: temporal convolution, then a grouped convolution across
        # channels. Both use 'same' padding, so the channel axis keeps its size.
        self.block1 = nn.Sequential(
            nn.Conv2d(1, F1, (1, kernLength), padding='same', bias=False),
            nn.BatchNorm2d(F1),
            nn.Conv2d(F1, F1*D, (Chans, 1), groups=F1, bias=False, padding='same'),
            nn.BatchNorm2d(F1*D),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(dropoutRate)
        )

        # Block 2: temporal convolution followed by 8x temporal pooling.
        self.block2 = nn.Sequential(
            nn.Conv2d(F1*D, F2, (1, kernLength_sep), bias=False, padding='same'),
            nn.BatchNorm2d(F2),
            nn.ELU(),
            nn.AvgPool2d((1, 8)),
            nn.Dropout(dropoutRate)
        )

        # Block 3: shorter temporal kernel followed by 4x temporal pooling.
        self.block3 = nn.Sequential(
            nn.Conv2d(F2, F3, (1, kernLength_sep//2), bias=False, padding='same'),
            nn.BatchNorm2d(F3),
            nn.ELU(),
            nn.AvgPool2d((1, 4)),
            nn.Dropout(dropoutRate)
        )

        self.flatten = nn.Flatten()

        # Size of the flattened block-3 output, measured with a dummy input.
        flatten_size = self.calculate_flatten_size(Chans, Samples, F3)

        # Dense head: flatten_size -> 128 -> 64.
        self.dense = nn.Sequential(
            nn.Linear(flatten_size, 128),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.Dropout(dropoutRate),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ELU(),
            nn.Dropout(dropoutRate)
        )

        self.output = nn.Sequential(
            nn.BatchNorm1d(64),
            nn.Linear(64, 1)
        )

    def calculate_flatten_size(self, Chans, Samples, F3):
        """Return the number of features block 3 emits for a single sample.

        The probe runs in eval mode so the BatchNorm running statistics and
        batch counters of the freshly built blocks are not modified by the
        dummy input. The Hilbert stage is skipped because it preserves shape.
        ``F3`` is unused and kept for interface compatibility.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # randn is kept so the global RNG stream is consumed exactly
                # as before, which keeps seeded layer initialisation stable.
                x = torch.randn(1, 1, Chans, Samples)
                for block in (self.block1, self.block2, self.block3):
                    x = block(x)
        finally:
            self.train(was_training)
        return x.numel()

    def forward(self, x):
        """Map a ``(batch, Chans, Samples)`` input to a ``(batch, 1)`` output."""
        x = x.unsqueeze(1)  # add the singleton feature dimension for Conv2d
        x = self.block1(x)
        x = self.apply_hilbert(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.flatten(x)
        x = self.dense(x)
        x = self.output(x)
        return x

    def apply_hilbert(self, x):
        """Replace ``x`` by the magnitude of its analytic signal along time."""
        return apply_hilbert_torch(x, do_log=self.do_log, compute_val='power', data_srate=self.data_srate)

def create_model(Chans, Samples=375, dropoutRate=0.65, kernLength=64, F1=4, D=2, F2=8):
    """Build a ``CtxNet`` from the given hyper-parameters.

    Only the arguments listed in the signature are forwarded; every other
    ``CtxNet`` option keeps its constructor default.
    """
    hyper_params = dict(
        Chans=Chans,
        Samples=Samples,
        dropoutRate=dropoutRate,
        kernLength=kernLength,
        F1=F1,
        D=D,
        F2=F2,
    )
    return CtxNet(**hyper_params)

def correlation_loss(y_true, y_pred, epsilon=1e-8):
    """Return ``1 - r`` where ``r`` is the Pearson correlation of the inputs.

    Both tensors are centred on their global mean. ``epsilon`` is added to the
    product of the two norms so a constant input gives a loss of 1 instead of
    a division by zero.
    """
    pred_centered = y_pred - y_pred.mean()
    true_centered = y_true - y_true.mean()

    covariance = (pred_centered * true_centered).sum()
    pred_norm = torch.sqrt((pred_centered ** 2).sum())
    true_norm = torch.sqrt((true_centered ** 2).sum())

    return 1 - covariance / ((pred_norm * true_norm) + epsilon)
    
class CustomDataset(Dataset):
    """In-memory dataset of ``(X[i], y[i])`` pairs stored as float32 tensors.

    With ``normalize=True`` every row of ``X`` is z-scored along its last
    dimension and ``y`` is standardised with ``scaler`` (a dict holding
    ``'y_mean'`` and ``'y_std'``). When no scaler is supplied one is derived
    from ``y`` itself. Without normalisation ``self.scaler`` is ``None``.
    """

    def __init__(self, X, y, normalize=False, scaler=None):
        features = torch.tensor(X, dtype=torch.float32)
        labels = torch.tensor(y, dtype=torch.float32)

        if not normalize:
            self.X = features
            self.y = labels
            self.scaler = None
            return

        # Z-score the inputs along the last (time) dimension.
        time_mean = features.mean(dim=-1, keepdim=True)
        time_std = features.std(dim=-1, keepdim=True)
        self.X = (features - time_mean) / (time_std + 1e-8)

        # Fit the target statistics here unless the caller supplied them.
        if scaler is None:
            scaler = {
                'y_mean': float(labels.mean()),
                'y_std': float(labels.std() + 1e-8)
            }
        self.scaler = scaler

        self.y = (labels - self.scaler['y_mean']) / self.scaler['y_std']

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

    def inverse_transform_y(self, y):
        """Map standardised targets back to the original scale.

        NumPy input is converted to a float32 tensor first. When the dataset
        holds no scaler the values are returned unchanged.
        """
        values = torch.from_numpy(y).float() if isinstance(y, np.ndarray) else y
        if self.scaler is None:
            return values
        return values * self.scaler['y_std'] + self.scaler['y_mean']

def train_base_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=100,
                     checkpoint_path='best_model.pth'):
    """Train ``model`` with AMP, LR warm-up, plateau scheduling and early stopping.

    Args:
        model: Network to train; moved to CUDA when available, else CPU.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        criterion: Callable ``criterion(predictions, targets)`` returning a
            scalar loss.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Maximum number of epochs.
        checkpoint_path: Where the best checkpoint is written each time the
            validation loss improves; ``None`` disables writing.

    Returns:
        ``(best_model, best_loss, best_correlation)``: a deep copy of the
        model at its lowest validation loss, that loss, and the validation
        correlation measured in the same epoch. ``best_model`` is ``None`` if
        the validation loss never dropped below infinity (e.g. NaN losses).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    best_model = None
    best_loss = float('inf')
    best_correlation = -1
    scaler = GradScaler()

    # Linear warm-up over the first epochs, then reduce-on-plateau.
    initial_lr = optimizer.param_groups[0]['lr']
    warmup_epochs = 5
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=7)

    # Early stopping
    patience = 15
    epochs_without_improvement = 0

    def to_numpy(tensor):
        # float() so half-precision autocast outputs become float32;
        # atleast_1d so a batch of one still concatenates.
        return np.atleast_1d(tensor.detach().float().cpu().numpy())

    for epoch in range(num_epochs):
        # Warm-up overrides the learning rate of every parameter group.
        if epoch < warmup_epochs:
            lr = initial_lr * (epoch + 1) / warmup_epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        # Training phase
        model.train()
        train_loss = 0.0
        train_outputs = []
        train_targets = []

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            with autocast():
                outputs = model(inputs)
                # squeeze(-1) drops only the trailing unit dimension, so a
                # batch of one keeps its batch axis.
                predictions = outputs.squeeze(-1)
                loss = criterion(predictions, targets)

            scaler.scale(loss).backward()

            # Gradients must be unscaled before clipping by norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()
            train_outputs.append(to_numpy(predictions))
            train_targets.append(to_numpy(targets))

        train_loss /= len(train_loader)
        train_correlation = np.corrcoef(
            np.concatenate(train_outputs), np.concatenate(train_targets)
        )[0, 1]

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_outputs = []
        val_targets = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                predictions = model(inputs).squeeze(-1)
                loss = criterion(predictions, targets)
                val_loss += loss.item()
                val_outputs.append(to_numpy(predictions))
                val_targets.append(to_numpy(targets))

        val_loss /= len(val_loader)
        val_correlation = np.corrcoef(
            np.concatenate(val_outputs), np.concatenate(val_targets)
        )[0, 1]

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Model checkpointing
        if val_loss < best_loss:
            best_loss = val_loss
            best_correlation = val_correlation
            best_model = copy.deepcopy(model)
            epochs_without_improvement = 0

            if checkpoint_path is not None:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': best_loss,
                    'correlation': best_correlation,
                    'scaler_state_dict': scaler.state_dict()
                }, checkpoint_path)
        else:
            epochs_without_improvement += 1

        # Log before the early-stop check so the final epoch is reported too.
        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f}, Train Correlation: {train_correlation:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val Correlation: {val_correlation:.4f}')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}\n')

        if epochs_without_improvement >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

        # Memory cleanup
        torch.cuda.empty_cache()
        gc.collect()

    return best_model, best_loss, best_correlation

class DataNormalizer:
    """Z-score normaliser that pools statistics over axes 0 and 2."""

    def __init__(self, epsilon=1e-8):
        # Added to the standard deviation to avoid dividing by zero.
        self.epsilon = epsilon

    def normalize(self, data):
        """Return ``data`` standardised per index of axis 1.

        Mean and standard deviation are taken jointly over axes 0 and 2, so
        ``data`` must have at least three dimensions.
        """
        pooled_axes = (0, 2)
        center = np.mean(data, axis=pooled_axes, keepdims=True)
        spread = np.std(data, axis=pooled_axes, keepdims=True) + self.epsilon
        centered = data - center
        return centered / spread

def freeze_layers(model, num_layers_to_freeze):
    """Freeze the first ``num_layers_to_freeze`` parameter tensors of ``model``.

    The count refers to entries of ``model.named_parameters()`` (individual
    weight/bias tensors), not to modules. Every later parameter is explicitly
    made trainable, so any previous ``requires_grad`` setting is overwritten.
    The same model object is returned.
    """
    for position, (_, parameter) in enumerate(model.named_parameters()):
        parameter.requires_grad = position >= num_layers_to_freeze
    return model

class GradualUnfreeze:
    """Re-enable gradients for frozen parameters as training progresses.

    The parameters that are frozen when the object is created are recorded in
    order. ``step(epoch)`` then makes the first
    ``int(epoch * unfreeze_layers_per_epoch)`` of them trainable.
    """

    def __init__(self, model, total_epochs, unfreeze_layers_per_epoch):
        self.model = model
        self.total_epochs = total_epochs
        self.unfreeze_layers_per_epoch = unfreeze_layers_per_epoch
        # Snapshot of the parameters that are frozen right now.
        self.frozen_params = list(
            filter(lambda p: not p.requires_grad, model.parameters())
        )

    def step(self, epoch):
        """Unfreeze as many recorded parameters as ``epoch`` allows."""
        if not self.frozen_params:
            return

        quota = int(epoch * self.unfreeze_layers_per_epoch)
        # max() guards against a negative quota turning into a tail slice.
        for param in self.frozen_params[:max(quota, 0)]:
            param.requires_grad = True
                
def finetune_model(base_model, train_loader, val_loader, criterion, num_epochs=75):
    """Fine-tune a copy of ``base_model`` with gradual unfreezing and mixup.

    ``base_model`` itself is left untouched; training runs on a deep copy.
    Note that ``freeze_layers`` resets ``requires_grad`` on every parameter
    of that copy, so flags set by the caller beforehand have no effect.

    Args:
        base_model: Pre-trained network to start from.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` batches used for model
            selection and early stopping.
        criterion: Callable ``criterion(predictions, targets)`` returning a
            scalar loss.
        num_epochs: Maximum number of epochs.

    Returns:
        A deep copy of the model at its best validation loss, or the model as
        it stands after the last epoch if no epoch ever improved on the
        initial infinite loss (e.g. the loss was NaN throughout).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = copy.deepcopy(base_model)
    model.to(device)

    # Freeze the first six parameter tensors (for a CtxNet that is all of
    # block1) and re-enable them two per epoch.
    model = freeze_layers(model, num_layers_to_freeze=6)
    gradual_unfreeze = GradualUnfreeze(model, num_epochs, unfreeze_layers_per_epoch=2)

    # Every parameter is registered with the optimizer, including the frozen
    # ones. Frozen parameters have no gradient and are skipped by the
    # optimizer, and they start receiving updates as soon as GradualUnfreeze
    # re-enables them. Registering only the parameters that are trainable at
    # this point would leave later-unfrozen parameters without updates.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=1e-6,
        weight_decay=0.05
    )

    # Cosine learning rate scheduler with warm restarts
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=20,  # First restart after 20 epochs
        T_mult=2  # Double the restart interval after each restart
    )
    # NOTE(review): both schedulers drive the same optimizer. The cosine
    # schedule recomputes the learning rate from its base value on every
    # step, so the StepLR halving is overwritten one epoch later. Confirm
    # whether this combination is intended.
    step_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    # Gradient clipping
    max_grad_norm = 1.0

    best_model = None
    best_val_loss = float('inf')
    patience = 25
    min_delta = 0.001  # Minimum improvement required
    patience_counter = 0

    for epoch in range(num_epochs):
        gradual_unfreeze.step(epoch)
        model.train()
        train_loss = 0.0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            # Mixup augmentation is switched on once the epoch index exceeds 20.
            if epoch > 20:
                lam = np.random.beta(0.4, 0.4)
                idx = torch.randperm(inputs.size(0), device=inputs.device)
                mixed_inputs = lam * inputs + (1 - lam) * inputs[idx]
                # squeeze(-1) drops only the trailing unit dimension.
                predictions = model(mixed_inputs).squeeze(-1)
                loss = lam * criterion(predictions, targets) + \
                       (1 - lam) * criterion(predictions, targets[idx])
            else:
                predictions = model(inputs).squeeze(-1)
                loss = criterion(predictions, targets)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            train_loss += loss.item()

        scheduler.step()
        step_scheduler.step()

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                predictions = model(inputs).squeeze(-1)
                loss = criterion(predictions, targets)
                val_loss += loss.item()

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Only a decrease of at least min_delta counts as an improvement.
        if val_loss < best_val_loss - min_delta:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'Early stopping triggered after {epoch + 1} epochs')
                break

        torch.cuda.empty_cache()
        gc.collect()

    # Fall back to the final weights rather than handing None to the caller.
    return best_model if best_model is not None else model

def evaluate_model(model, test_loader, denormalize=True):
    """Run ``model`` over ``test_loader`` and report prediction statistics.

    Args:
        model: Trained network; moved to CUDA when available, else CPU, so it
            lives on the same device as the inputs.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        denormalize: When true and the loader wraps a ``CustomDataset``, both
            predictions and targets are mapped back through that dataset's
            ``inverse_transform_y``.

    Returns:
        ``(predictions, targets)`` as NumPy arrays covering the whole loader.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Loaders without a ``dataset`` attribute (e.g. a plain list of batches)
    # simply skip denormalisation.
    dataset = getattr(test_loader, 'dataset', None)
    can_denormalize = denormalize and isinstance(dataset, CustomDataset)

    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs = inputs.to(device)
            outputs = model(inputs).cpu()  # Move outputs back to CPU immediately

            if batch_idx == 0:
                print("\nFirst batch statistics (before denormalization):")
                print_data_stats(outputs, "Model outputs")
                print_data_stats(targets, "Targets")

            if can_denormalize:
                outputs = dataset.inverse_transform_y(outputs)
                targets = dataset.inverse_transform_y(targets)

                if batch_idx == 0:
                    print("\nFirst batch statistics (after denormalization):")
                    print_data_stats(outputs, "Model outputs")
                    print_data_stats(targets, "Targets")

            all_predictions.extend(outputs.numpy())
            all_targets.extend(targets.numpy())

    # atleast_1d keeps a single-sample result one-dimensional.
    predictions = np.atleast_1d(np.array(all_predictions).squeeze())
    targets = np.array(all_targets)

    print("\nFinal statistics:")
    print_data_stats(torch.tensor(predictions), "Predictions")
    print_data_stats(torch.tensor(targets), "Targets")

    # Calculate correlation
    correlation = np.corrcoef(predictions, targets)[0, 1]
    print(f"\nCorrelation between predictions and targets: {correlation:.4f}")

    return predictions, targets
    
# Modified data loading section
def prepare_data(X_name, y_name, index):
    """Load one recording from disk and standardise inputs and target.

    ``X_name`` and ``y_name`` are ``.npy`` paths. ``index`` selects entries
    along axis 1 of the input array, and column 1 of the target array is used
    as the regression target. Inputs go through ``DataNormalizer``; the
    target is z-scored over the whole recording.

    Returns:
        ``(X_normalized, y_normalized)``.
    """
    signals = np.load(X_name).astype(np.float32)[:, index, :]
    target = np.load(y_name).astype(np.float32)[:, 1]

    X_normalized = DataNormalizer().normalize(signals)

    target_mean = np.mean(target)
    target_scale = np.std(target) + 1e-8
    y_normalized = (target - target_mean) / target_scale

    return X_normalized, y_normalized
    
def print_data_stats(data, name):
    """Print mean, std, min and max of ``data`` under the heading ``name``.

    NumPy arrays are wrapped as tensors first; anything else is used as is.
    """
    tensor = torch.from_numpy(data) if isinstance(data, np.ndarray) else data
    print(f"\n{name} statistics:")
    # Each reduction is evaluated right before it is printed.
    for label, reducer in (("Mean", "mean"), ("Std", "std"), ("Min", "min"), ("Max", "max")):
        print(f"{label}: {getattr(tensor, reducer)():.2e}")

In [None]:
# Architecture of the pre-trained checkpoint; these values must match the
# saved weights, otherwise load_state_dict fails.
model_params = {
    'Chans': 3,
    'kernLength': 64,   # Temporal kernel length of the first convolution
    'F1': 32,           # Filters in the first convolution
    'F2': 64,
    'F3': 128,
    'kernLength_sep': 32,
    'dropoutRate': 0.5,
    'D': 4,             # Multiplier for the grouped convolution
}

batch_size = 3072
finetune_epochs = 75  # NOTE(review): not passed to finetune_model, whose default is also 75
criterion = correlation_loss
ex_folder = r'E:\data_zixiao\raw_prediction_61_7'

# The checkpoint is the same for every recording, so it is read once.
# map_location='cpu' lets a checkpoint saved on a GPU load on any machine;
# finetune_model moves its working copy to the right device.
checkpoint = torch.load(
    r'C:\zixiao_data\Ctxnet_base_model_ucsf9_100.pth',
    weights_only=True,
    map_location='cpu',
)

for row in tqdm(df.itertuples(index=False), total=len(df)):
    f_folder = row.folder
    f_name = row.file
    X_name = f_folder + '\\' + f_name[:-4] + '_ecog.npy'
    y_name = f_folder + '\\' + f_name[:-4] + '_tarstn.npy'
    index = ast.literal_eval(row.index_list)

    X_new, y_new = prepare_data(X_name, y_name, index)

    # Chronological 70/30 split (no shuffling).
    split_idx = int(0.7 * len(X_new))
    X_finetune, X_test = X_new[:split_idx], X_new[split_idx:]
    y_finetune, y_test = y_new[:split_idx], y_new[split_idx:]

    # prepare_data already standardised the arrays, so the datasets do not
    # normalise again.
    finetune_dataset = CustomDataset(X_finetune, y_finetune)
    test_dataset = CustomDataset(X_test, y_test)

    finetune_loader = DataLoader(finetune_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Load and finetune model
    trained_base_model = CtxNet(**model_params)
    trained_base_model.load_state_dict(checkpoint['model_state_dict'])
    # Increase dropout
    for m in trained_base_model.modules():
        if isinstance(m, nn.Dropout):
            m.p = 0.8
    # Only the dense head and the output layer are left trainable here.
    # NOTE(review): finetune_model calls freeze_layers on its copy, which
    # resets requires_grad on every parameter, so these flags are overridden.
    for param in trained_base_model.parameters():
        param.requires_grad = False
    for param in trained_base_model.dense.parameters():
        param.requires_grad = True
    for param in trained_base_model.output.parameters():
        param.requires_grad = True
    # NOTE(review): the held-out split doubles as the validation set for
    # early stopping and model selection, so the reported score is optimistic.
    finetuned_model = finetune_model(trained_base_model, finetune_loader, test_loader, criterion)
    # Evaluate
    y_pred, y_true = evaluate_model(finetuned_model, test_loader)
    # Escaped backslash: same runtime path, without the invalid '\%' escape.
    np.save('%s\\%s' % (ex_folder, f_name[:-8] + '_pred.npy'), np.stack([y_pred, y_true]))
    torch.cuda.empty_cache()
    gc.collect()