In [50]:
import torch
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
import torch.nn as nn
from tqdm import tqdm
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader, TensorDataset
from scipy.stats import pearsonr
import torch.nn.functional as F


In [2]:
train_data = torch.load('train.pt')
val_data = torch.load('val.pt')
test_data = torch.load('test.pt')

In [3]:
def prepare_dataset(data):
    eeg_list = []
    stim_list = []
    for eeg, stim in data:
        eeg_list.append(eeg.float())           # (320, 64)
        stim_list.append(stim.float())         # (320,)
    eeg_tensor = torch.stack(eeg_list)         # (N, 320, 64)
    stim_tensor = torch.stack(stim_list)       # (N, 320)
    return eeg_tensor, stim_tensor

In [4]:
X_train, y_train = prepare_dataset(train_data)
X_val, y_val = prepare_dataset(val_data)
X_test, y_test = prepare_dataset(test_data)

In [21]:
class EEGAutoencoder(nn.Module):
    def __init__(self, input_dim=320*64, latent_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, input_dim)
        )
    
    def forward(self, x):
        z = self.encoder(x)
        x_recon = self.decoder(z)
        return x_recon, z

def cluster_with_autoencoder(eeg_data, k=5, latent_dim=32, epochs=50, device='cuda'):
    

    N, T, C = eeg_data.shape
    data_flat = eeg_data.reshape(N, T * C)
    data_tensor = data_flat.clone().detach().float()
    loader = DataLoader(TensorDataset(data_tensor), batch_size=64, shuffle=True)

    model = EEGAutoencoder(input_dim=T*C, latent_dim=latent_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    # Train autoencoder
    model.train()
    print("Training Autoencoder...")

    for epoch in range(epochs):
        epoch_loss = 0.0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch, in pbar:
            batch = batch.to(device)
            recon, _ = model(batch)
            loss = loss_fn(recon, batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

    # Get latent features
    model.eval()
    with torch.no_grad():
        latent_all = model.encoder(data_tensor.to(device)).cpu().numpy()

    # KMeans on latent space
    kmeans = KMeans(n_clusters=k, random_state=0).fit(latent_all)
    labels = kmeans.labels_

    return labels, latent_all, model, kmeans

In [22]:
k = 5
device = 'cuda' if torch.cuda.is_available() else 'cpu'

cluster_labels, latent_embeddings, cluster_model, kmeans = cluster_with_autoencoder(X_train, k=k, latent_dim=32, epochs=5, device=device)

Training Autoencoder...


Epoch 1/5: 100%|██████████████████████████████████████████████████████| 1428/1428 [00:18<00:00, 75.29it/s, Loss=36.9907]
Epoch 2/5: 100%|██████████████████████████████████████████████████████| 1428/1428 [00:18<00:00, 75.93it/s, Loss=29.6252]
Epoch 3/5: 100%|██████████████████████████████████████████████████████| 1428/1428 [00:18<00:00, 75.29it/s, Loss=30.3687]
Epoch 4/5: 100%|██████████████████████████████████████████████████████| 1428/1428 [00:18<00:00, 76.33it/s, Loss=51.6567]
Epoch 5/5: 100%|██████████████████████████████████████████████████████| 1428/1428 [00:18<00:00, 76.69it/s, Loss=52.6811]


In [23]:
clustered_eeg = [X_train[cluster_labels == i] for i in range(k)]

In [24]:
for i in range(k):
    print(clustered_eeg[i].shape) 

torch.Size([71904, 320, 64])
torch.Size([299, 320, 64])
torch.Size([1020, 320, 64])
torch.Size([18065, 320, 64])
torch.Size([93, 320, 64])


In [25]:
clustered_y = [y_train[cluster_labels == i] for i in range(k)]

In [26]:
for i in range(k):
    print(clustered_y[i].shape) 

torch.Size([71904, 320])
torch.Size([299, 320])
torch.Size([1020, 320])
torch.Size([18065, 320])
torch.Size([93, 320])


In [28]:
X_val_flat = X_val.view(X_val.shape[0], -1).to(device)   # shape: (N_val, 320*64)
X_test_flat = X_test.view(X_test.shape[0], -1).to(device)


cluster_model.eval()
with torch.no_grad():
    latent_val = cluster_model.encoder(X_val_flat).cpu().numpy()   # shape: (N_val, latent_dim)
    latent_test = cluster_model.encoder(X_test_flat).cpu().numpy() # shape: (N_test, latent_dim)

    
val_cluster_labels = kmeans.predict(latent_val)
test_cluster_labels = kmeans.predict(latent_test)


clustered_val_eeg = [X_val[val_cluster_labels == i] for i in range(k)]
clustered_val_y   = [y_val[val_cluster_labels == i] for i in range(k)]

clustered_test_eeg = [X_test[test_cluster_labels == i] for i in range(k)]
clustered_test_y   = [y_test[test_cluster_labels == i] for i in range(k)]

In [29]:
def pearson_corr(pred, target):
    pred = pred.detach().cpu().numpy()
    target = target.detach().cpu().numpy()
    correlations = [pearsonr(p, t)[0] for p, t in zip(pred, target)]
    return sum(correlations) / len(correlations)

def cosine_sim(pred, target):
    return F.cosine_similarity(pred, target, dim=1).mean().item()

In [30]:
class LSTMExpert(nn.Module):
    def __init__(self, input_dim=64, hidden_dim=64, num_layers=1, output_dim=320):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):  # x: (B, 320, 64)
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])  # use last time step hidden state
        return out  # (B, 320)

In [31]:
experts = [LSTMExpert().to(device) for _ in range(k)]

In [32]:
epochs = 5

In [33]:
for i in range(k):
    train_loader = DataLoader(TensorDataset(clustered_eeg[i], clustered_y[i]), batch_size=32, shuffle=True)
    model = experts[i]
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    for epoch in range(epochs):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

In [43]:
class CombinedModel(nn.Module):
    def __init__(self, experts, combine_dim=320, output_dim=320):
        super().__init__()
        self.experts = nn.ModuleList(experts)
        self.combine_layer = nn.Sequential(
            nn.Conv1d(k, 16, kernel_size=3, padding=1),  # k experts’ outputs as channels
            nn.ReLU(),
            nn.Conv1d(16, 1, kernel_size=1),
        )
    
    def forward(self, x):  # x: (B, 320, 64)
        expert_outputs = [expert(x) for expert in self.experts]  # list of (B, 320)
        expert_stack = torch.stack(expert_outputs, dim=1)  # (B, k, 320)

        # Move expert_stack and combine_layer to CPU
        expert_stack_cpu = expert_stack.cpu()
        combine_layer_cpu = self.combine_layer.to('cpu')

        out = combine_layer_cpu(expert_stack_cpu)  # Run combine_layer on CPU
        return out.squeeze(1).to(x.device)  # Move output back to GPU

In [51]:
X_all = torch.cat(clustered_eeg, dim=0)
y_all = torch.cat(clustered_y, dim=0)
train_loader = DataLoader(TensorDataset(X_all, y_all), batch_size=32, shuffle=True)

combiner_model = CombinedModel(experts).to(device)
optimizer = torch.optim.Adam(combiner_model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

print("Training Combiner Model...")
for epoch in range(epochs):
    # Training Phase
    combiner_model.train()
    for expert in experts:
        expert.train()
    
    epoch_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        pred = combiner_model(x)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

    avg_train_loss = epoch_loss / len(train_loader)
    combiner_model.combine_layer = combiner_model.combine_layer.to(device)
    # Validation Phase with cluster-specific experts
    combiner_model.eval()
    for expert in experts:
        expert.eval()

    val_preds = []
    val_targets = []

    with torch.no_grad():
        for i in range(len(X_val)):
            x = X_val[i].unsqueeze(0).to(device)  # (1, 320, 64)
            y = y_val[i].unsqueeze(0).to(device)  # (1, 320)
            cluster_id = val_cluster_labels[i]

            expert_output = experts[cluster_id](x)  # (1, 320)

            # Fill other expert outputs with zeros
            dummy_outputs = [torch.zeros_like(expert_output) for _ in range(len(experts))]
            dummy_outputs[cluster_id] = expert_output
            expert_stack = torch.stack(dummy_outputs, dim=1).to(device)  # <-- Add .to(device)
            print("expert_stack device:", expert_stack.device)
            print("combine_layer weights device:", next(combiner_model.combine_layer.parameters()).device)
            pred = combiner_model.combine_layer(expert_stack).squeeze(1)

            val_preds.append(pred)
            val_targets.append(y)

    val_preds = torch.cat(val_preds, dim=0)       # (N_val, 320)
    val_targets = torch.cat(val_targets, dim=0)   # (N_val, 320)

    val_mse = F.mse_loss(val_preds, val_targets).item()
    val_pearson = pearson_corr(val_preds, val_targets)
    val_cosine = cosine_sim(val_preds, val_targets)

    print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f} | "
          f"Val MSE = {val_mse:.4f} | Pearson = {val_pearson:.4f} | Cosine = {val_cosine:.4f}")

Training Combiner Model...


Epoch 1/5: 100%|███████████████████████████████████████████████████████| 2856/2856 [00:42<00:00, 66.95it/s, Loss=0.2374]


expert_stack device: cuda:0
combine_layer weights device: cuda:0


RuntimeError: GET was unable to find an engine to execute this computation

In [53]:
combiner_model.eval()
    for expert in experts:
        expert.eval()

    val_preds = []
    val_targets = []

    with torch.no_grad():
        for i in range(len(X_val)):
            x = X_val[i].unsqueeze(0).to(device)  # (1, 320, 64)
            y = y_val[i].unsqueeze(0).to(device)  # (1, 320)
            cluster_id = val_cluster_labels[i]

            expert_output = experts[cluster_id](x)  # (1, 320)

            # Fill other expert outputs with zeros
            dummy_outputs = [torch.zeros_like(expert_output) for _ in range(len(experts))]
            dummy_outputs[cluster_id] = expert_output
            expert_stack = torch.stack(dummy_outputs, dim=1).to(device)  # <-- Add .to(device)
            print("expert_stack device:", expert_stack.device)
            print("combine_layer weights device:", next(combiner_model.combine_layer.parameters()).device)
            pred = combiner_model.combine_layer(expert_stack).squeeze(1)

            val_preds.append(pred)
            val_targets.append(y)

    val_preds = torch.cat(val_preds, dim=0)       # (N_val, 320)
    val_targets = torch.cat(val_targets, dim=0)   # (N_val, 320)

    val_mse = F.mse_loss(val_preds, val_targets).item()
    val_pearson = pearson_corr(val_preds, val_targets)
    val_cosine = cosine_sim(val_preds, val_targets)

RuntimeError: Input and parameter tensors are not at the same device, found input tensor at cpu and parameter tensor at cuda:0

Conv1d in_channels: 5
