In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GCNConv
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from torch_geometric.utils import dense_to_sparse

In [2]:
# Load data
train_data = torch.load('train.pt')
val_data = torch.load('val.pt')
test_data = torch.load('test.pt')

In [3]:
# Prepare dataset
def prepare_dataset(data):
    eeg_list = []
    stim_list = []
    for eeg, stim in data:
        eeg_list.append(eeg.float())           # (320, 64)
        stim_list.append(stim.float())         # (320,)
    eeg_tensor = torch.stack(eeg_list)         # (N, 320, 64)
    stim_tensor = torch.stack(stim_list)       # (N, 320)
    return eeg_tensor, stim_tensor

X_train, y_train = prepare_dataset(train_data)
X_val, y_val = prepare_dataset(val_data)
X_test, y_test = prepare_dataset(test_data)

In [4]:
# Hyperparameters
HYPERPARAMS = {
    'subset_ratio': 0.01,  # Use 1% of the data
    'num_epochs': 5,       # Low epoch for initial testing
    'batch_size': 16,
    'learning_rate': 0.001,
    'tfcn_hidden_channels': 32,  # TGCN hidden channels
    'cnn_filters': 64,          # CNN filters
    'lstm_hidden_size': 128,    # LSTM hidden size
    'lstm_num_layers': 2,       # LSTM layers
    'loop_iterations': 4,       # Number of TGCN+CNN+LSTM loops
    'cnn_combine_filters': 32,  # CNN filters in combine block
    'mlp_hidden_sizes': [128, 64],  # MLP hidden layers
    'dropout': 0.3,
    'num_nodes': 64,  # Number of EEG channels
    'time_steps': 320  # Time steps in EEG data
}


In [5]:
# Temporal Graph Convolutional Layer
class TGCN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(TGCN, self).__init__()
        self.conv = GCNConv(in_channels, out_channels)
    
    def forward(self, x, edge_index):
        # x: (batch, time, nodes, features)
        batch_size, time_steps, nodes, features = x.size()
        x_out = []
        for t in range(time_steps):
            x_t = x[:, t, :, :]  # (batch, nodes, features)
            x_t = self.conv(x_t, edge_index)
            x_out.append(x_t.unsqueeze(1))  # (batch, 1, nodes, out_channels)
        return torch.cat(x_out, dim=1)  # (batch, time, nodes, out_channels)

In [6]:
# Single Block (TGCN + CNN + LSTM)
class SingleBlock(nn.Module):
    def __init__(self, in_channels, hidden_channels, cnn_filters, lstm_hidden_size, lstm_num_layers, num_nodes, time_steps, dropout):
        super(SingleBlock, self).__init__()
        self.layers = nn.ModuleList()
        current_channels = in_channels
        for _ in range(HYPERPARAMS['loop_iterations']):
            # TGCN
            tgcn = TGCN(current_channels, hidden_channels)
            # CNN
            cnn = nn.Sequential(
                nn.Conv2d(hidden_channels, cnn_filters, kernel_size=(3, 3), padding=1),
                nn.ReLU(),
                nn.Dropout(dropout)
            )
            # LSTM
            lstm = nn.LSTM(cnn_filters * num_nodes, lstm_hidden_size, lstm_num_layers, batch_first=True, dropout=dropout if lstm_num_layers > 1 else 0)
            self.layers.append(nn.ModuleDict({
                'tgcn': tgcn,
                'cnn': cnn,
                'lstm': lstm
            }))
            current_channels = lstm_hidden_size
        # Output layer
        self.out = nn.Linear(lstm_hidden_size, time_steps)
    
    def forward(self, x, edge_index):
        # x: (batch, time, nodes, features)
        for layer in self.layers:
            # TGCN
            x = layer['tgcn'](x, edge_index)  # (batch, time, nodes, hidden_channels)
            # CNN
            x = x.permute(0, 3, 1, 2)  # (batch, channels, time, nodes)
            x = layer['cnn'](x)  # (batch, cnn_filters, time, nodes)
            x = x.permute(0, 2, 1, 3)  # (batch, time, channels, nodes)
            batch, time, channels, nodes = x.size()
            x = x.reshape(batch, time, channels * nodes)  # (batch, time, channels*nodes)
            # LSTM
            x, _ = layer['lstm'](x)  # (batch, time, lstm_hidden_size)
        # Output
        x = self.out(x)  # (batch, time, time_steps)
        return x

In [7]:
# Combined Model
class EEGEnvModel(nn.Module):
    def __init__(self):
        super(EEGEnvModel, self).__init__()
        # Create adjacency matrix (simple identity for demo, replace with real graph)
        adj = torch.eye(HYPERPARAMS['num_nodes']).to(device)
        edge_index, _ = dense_to_sparse(adj)
        self.register_buffer('edge_index', edge_index)
        
        # Three parallel blocks
        self.block_pearson = SingleBlock(
            in_channels=1,  # Assuming single feature per node
            hidden_channels=HYPERPARAMS['tfcn_hidden_channels'],
            cnn_filters=HYPERPARAMS['cnn_filters'],
            lstm_hidden_size=HYPERPARAMS['lstm_hidden_size'],
            lstm_num_layers=HYPERPARAMS['lstm_num_layers'],
            num_nodes=HYPERPARAMS['num_nodes'],
            time_steps=HYPERPARAMS['time_steps'],
            dropout=HYPERPARAMS['dropout']
        )
        self.block_cosine = SingleBlock(
            in_channels=1,
            hidden_channels=HYPERPARAMS['tfcn_hidden_channels'],
            cnn_filters=HYPERPARAMS['cnn_filters'],
            lstm_hidden_size=HYPERPARAMS['lstm_hidden_size'],
            lstm_num_layers=HYPERPARAMS['lstm_num_layers'],
            num_nodes=HYPERPARAMS['num_nodes'],
            time_steps=HYPERPARAMS['time_steps'],
            dropout=HYPERPARAMS['dropout']
        )
        self.block_mse = SingleBlock(
            in_channels=1,
            hidden_channels=HYPERPARAMS['tfcn_hidden_channels'],
            cnn_filters=HYPERPARAMS['cnn_filters'],
            lstm_hidden_size=HYPERPARAMS['lstm_hidden_size'],
            lstm_num_layers=HYPERPARAMS['lstm_num_layers'],
            num_nodes=HYPERPARAMS['num_nodes'],
            time_steps=HYPERPARAMS['time_steps'],
            dropout=HYPERPARAMS['dropout']
        )
        # Combine block
        self.combine_cnn = nn.Sequential(
            nn.Conv1d(3 * HYPERPARAMS['time_steps'], HYPERPARAMS['cnn_combine_filters'], kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Dropout(HYPERPARAMS['dropout'])
        )
        mlp_layers = []
        in_features = HYPERPARAMS['cnn_combine_filters'] * HYPERPARAMS['time_steps']
        for hidden_size in HYPERPARAMS['mlp_hidden_sizes']:
            mlp_layers.extend([
                nn.Linear(in_features, hidden_size),
                nn.ReLU(),
                nn.Dropout(HYPERPARAMS['dropout'])
            ])
            in_features = hidden_size
        mlp_layers.append(nn.Linear(in_features, HYPERPARAMS['time_steps']))
        self.mlp = nn.Sequential(*mlp_layers)
    
    def forward(self, x):
        # x: (batch, time, nodes)
        x = x.unsqueeze(-1)  # (batch, time, nodes, 1)
        # Parallel blocks
        out_pearson = self.block_pearson(x, self.edge_index)  # (batch, time, time_steps)
        out_cosine = self.block_cosine(x, self.edge_index)
        out_mse = self.block_mse(x, self.edge_index)
        # Combine
        combined = torch.cat([out_pearson, out_cosine, out_mse], dim=2)  # (batch, time, 3*time_steps)
        combined = combined.permute(0, 2, 1)  # (batch, 3*time_steps, time)
        combined = self.combine_cnn(combined)  # (batch, cnn_combine_filters, time)
        combined = combined.permute(0, 2, 1).reshape(combined.size(0), -1)  # (batch, time*cnn_combine_filters)
        out = self.mlp(combined)  # (batch, time_steps)
        return out_pearson, out_cosine, out_mse, out

In [8]:
# Loss Functions
def pearson_loss(pred, target):
    pred = pred - pred.mean(dim=1, keepdim=True)
    target = target - target.mean(dim=1, keepdim=True)
    num = (pred * target).sum(dim=1)
    denom = torch.sqrt((pred ** 2).sum(dim=1) * (target ** 2).sum(dim=1))
    return - (num / (denom + 1e-8)).mean()

cosine_loss = nn.CosineEmbeddingLoss()
mse_loss = nn.MSELoss()

In [9]:
# Training Function
def train_model(model, X_train, y_train, X_val, y_val):
    optimizer = optim.Adam(model.parameters(), lr=HYPERPARAMS['learning_rate'])
    train_losses = {'pearson': [], 'cosine': [], 'mse': [], 'combined': []}
    val_metrics = {'pearson': {'mse': [], 'pearson': [], 'cosine': []},
                   'cosine': {'mse': [], 'pearson': [], 'cosine': []},
                   'mse': {'mse': [], 'pearson': [], 'cosine': []},
                   'combined': {'mse': [], 'pearson': [], 'cosine': []}}
    
    # Subsample 1% of data
    num_samples = int(len(X_train) * HYPERPARAMS['subset_ratio'])
    indices = torch.randperm(len(X_train))[:num_samples]
    X_train_subset = X_train[indices].to(device)
    y_train_subset = y_train[indices].to(device)
    
    for epoch in range(HYPERPARAMS['num_epochs']):
        model.train()
        epoch_losses = {'pearson': 0, 'cosine': 0, 'mse': 0, 'combined': 0}
        num_batches = 0
        
        # Training loop with tqdm
        for i in tqdm(range(0, len(X_train_subset), HYPERPARAMS['batch_size']), desc=f'Epoch {epoch+1}'):
            batch_X = X_train_subset[i:i+HYPERPARAMS['batch_size']]
            batch_y = y_train_subset[i:i+HYPERPARAMS['batch_size']]
            
            optimizer.zero_grad()
            out_pearson, out_cosine, out_mse, out_combined = model(batch_X)
            
            # Losses
            loss_pearson = pearson_loss(out_pearson, batch_y)
            loss_cosine = cosine_loss(out_cosine.reshape(-1), batch_y.reshape(-1), torch.ones(batch_y.size(0)).to(device))
            loss_mse = mse_loss(out_mse, batch_y)
            loss_combined = mse_loss(out_combined, batch_y)  # Combined block uses MSE
            
            total_loss = loss_pearson + loss_cosine + loss_mse + loss_combined
            total_loss.backward()
            optimizer.step()
            
            epoch_losses['pearson'] += loss_pearson.item()
            epoch_losses['cosine'] += loss_cosine.item()
            epoch_losses['mse'] += loss_mse.item()
            epoch_losses['combined'] += loss_combined.item()
            num_batches += 1
        
        # Average losses
        for key in epoch_losses:
            train_losses[key].append(epoch_losses[key] / num_batches)
        
        # Validation
        model.eval()
        with torch.no_grad():
            out_pearson, out_cosine, out_mse, out_combined = model(X_val.to(device))
            outputs = {
                'pearson': out_pearson,
                'cosine': out_cosine,
                'mse': out_mse,
                'combined': out_combined
            }
            y_val_device = y_val.to(device)
            
            for block in outputs:
                pred = outputs[block].cpu().numpy()
                true = y_val.cpu().numpy()
                
                # MSE
                mse = np.mean((pred - true) ** 2)
                val_metrics[block]['mse'].append(mse)
                
                # Pearson
                pearsons = [pearsonr(pred[i], true[i])[0] for i in range(len(pred))]
                val_metrics[block]['pearson'].append(np.mean(pearsons))
                
                # Cosine
                cos_sim = np.mean([np.dot(pred[i], true[i]) / (np.linalg.norm(pred[i]) * np.linalg.norm(true[i]) + 1e-8) for i in range(len(pred))])
                val_metrics[block]['cosine'].append(cos_sim)
        
        # Print validation metrics
        print(f"\nEpoch {epoch+1} Validation Metrics:")
        for block in val_metrics:
            print(f"{block.capitalize()} Block - MSE: {val_metrics[block]['mse'][-1]:.4f}, "
                  f"Pearson: {val_metrics[block]['pearson'][-1]:.4f}, "
                  f"Cosine: {val_metrics[block]['cosine'][-1]:.4f}")
    
    return train_losses, val_metrics

In [10]:
# Plotting Function
def plot_results(train_losses, val_metrics, X_test, y_test, model):
    # Loss Curves
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    axes = axes.flatten()
    metrics = ['pearson', 'cosine', 'mse', 'combined']
    for i, metric in enumerate(metrics):
        axes[i].plot(train_losses[metric], label='Train Loss')
        axes[i].plot(val_metrics[metric]['mse'], label='Val MSE')
        axes[i].set_title(f'{metric.capitalize()} Block Loss')
        axes[i].set_xlabel('Epoch')
        axes[i].set_ylabel('Loss')
        axes[i].legend()
    plt.tight_layout()
    plt.show()
    
    # Test Regression Plots
    model.eval()
    with torch.no_grad():
        _, _, _, out_combined = model(X_test[:5].to(device))
        pred = out_combined.cpu().numpy()
        true = y_test[:5].numpy()
    
    fig, axes = plt.subplots(5, 1, figsize=(10, 15))
    for i in range(5):
        axes[i].plot(true[i], label='Original')
        axes[i].plot(pred[i], label='Predicted')
        axes[i].set_title(f'Test Sample {i+1}')
        axes[i].legend()
    plt.tight_layout()
    plt.show()

In [14]:
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize and train model
model = EEGEnvModel().to(device)
train_losses, val_metrics = train_model(model, X_train, y_train, X_val, y_val)

Epoch 1:   0%|                                                                                   | 0/58 [00:01<?, ?it/s]


ValueError: not enough values to unpack (expected 4, got 3)

In [None]:
# Plot results
plot_results(train_losses, val_metrics, X_test, y_test, model)