In [None]:
!pip install openneuro-py mne
!openneuro-py download --dataset=ds004504

In [None]:
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from scipy.signal import butter, csd, filtfilt, sosfiltfilt, welch
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, TensorDataset, random_split

def butter_bandpass(lowcut, highcut, fs, order=4):
    """
    Design a Butterworth band-pass filter and return it as (b, a) coefficients.

    The cutoffs are given in Hz and normalized here by the Nyquist rate
    (fs / 2) before being handed to scipy.signal.butter.

    Note: SciPy recommends the second-order-sections form (output='sos') over
    (b, a) for numerical robustness, especially at low normalized cutoffs;
    this helper keeps the (b, a) form for callers that need it.

    Parameters:
    - lowcut: Lower edge of the pass band, in Hz.
    - highcut: Upper edge of the pass band, in Hz.
    - fs: Sampling frequency, in Hz.
    - order: Order passed to scipy.signal.butter.

    Returns:
    - (b, a): Numerator and denominator polynomial coefficients.
    """
    half_rate = fs / 2.0
    normalized_band = [lowcut / half_rate, highcut / half_rate]
    return butter(order, normalized_band, btype='band')

def filter(data, bands, fs, order=4, axis=-1):
    """
    Band-pass filter `data` once per frequency band and stack the results.

    Each band uses a zero-phase (forward-backward) Butterworth filter designed
    in second-order-sections (SOS) form. SciPy warns that the (b, a)
    polynomial form is numerically fragile for band-pass designs, and the
    low cutoffs used in this notebook (0.5 Hz at fs=500 is 0.002 of Nyquist)
    are exactly the regime where that matters.

    Parameters:
    - data: Input signal array of any dimensionality. Filtering runs along
      `axis`, so e.g. a (n_channels, n_samples) array is filtered per channel.
    - bands: Sequence of (low, high) cutoff pairs in Hz, e.g. [(0.5, 4), (4, 8)].
    - fs: Sampling frequency in Hz.
    - order: Order passed to scipy.signal.butter (the resulting band-pass has
      order 2 * order).
    - axis: Axis of `data` that holds time samples (default: last axis).

    Returns:
    - Array of shape (len(bands), *data.shape); entry k is `data` filtered to
      bands[k].

    Raises:
    - ValueError: if `bands` is empty, or (from SciPy) if a cutoff is not
      strictly between 0 and fs / 2.
    """
    bands = list(bands)
    if not bands:
        raise ValueError("bands must contain at least one (low, high) pair")

    filtered = []
    for low, high in bands:
        # Design directly in Hz via fs=..., in SOS form for numerical stability.
        sos = butter(order, [low, high], btype='band', fs=fs, output='sos')
        filtered.append(sosfiltfilt(sos, data, axis=axis))

    # Band index becomes the new leading axis.
    return np.stack(filtered, axis=0)

# Build per-segment features (relative band power and spectral coherence) for
# every subject whose Group is not "F".
participants = pd.read_csv('ds004504/participants.tsv', sep="\t")
diags = participants["Group"]

print(set(diags))

FS = 500                  # sampling rate (Hz) this pipeline assumes for every recording
length = FS * 30          # samples per 30 s segment
bands = [(0.5, 4), (4, 8), (8, 13), (13, 25), (25, 45)]


def _segment_features(segment, fs=FS, epoch_len=FS):
    """
    Compute SCC and relative band power for one (n_channels, n_samples) segment.

    The segment is band-filtered, cut into consecutive `epoch_len`-sample
    epochs, and features are computed per epoch. n_samples must be a multiple
    of epoch_len.

    Returns:
    - scc_values: (n_epochs, n_bands, n_channels) mean coherence of each
      channel with all other channels.
    - relative_band_power: (n_epochs, n_bands, n_channels) band power divided
      by the total over bands.
    """
    n_channels, n_samples = segment.shape
    n_epochs = n_samples // epoch_len

    # (n_bands, n_channels, n_samples) -> (n_epochs, n_bands, n_channels, epoch_len)
    banded = filter(segment, bands, fs)
    banded = banded.reshape(len(bands), n_channels, n_epochs, epoch_len)
    banded = np.moveaxis(banded, 2, 0)

    # Relative band power: total PSD of each band-filtered epoch, normalized over bands.
    _, psd = welch(banded, fs=fs, nperseg=epoch_len, axis=-1)
    band_power = psd.sum(axis=-1)
    relative_band_power = band_power / band_power.sum(axis=1, keepdims=True)

    # Spectral coherence. |X_i * conj(X_j)| == |X_i| * |X_j|, so the mean over
    # frequency bins of the cross-spectrum magnitude is a matrix product of the
    # amplitude spectra. This avoids materializing an
    # (n_epochs, n_bands, C, C, n_freqs) complex array.
    # NOTE(review): because abs() is taken before averaging, phase is discarded
    # and this is the cosine similarity of the channels' amplitude spectra
    # within each band, not classical magnitude-squared coherence. Confirm this
    # is the intended definition.
    amp = np.abs(np.fft.rfft(banded, axis=-1))
    n_freqs = amp.shape[-1]
    auto_spectral = np.mean(amp ** 2, axis=-1)
    mean_cross = amp @ np.swapaxes(amp, -1, -2) / n_freqs
    denominator = np.sqrt(auto_spectral[..., :, np.newaxis] * auto_spectral[..., np.newaxis, :])
    coherence_matrix = mean_cross / denominator

    # Exclude self-coherence, then average over the remaining channels.
    diag_indices = np.arange(n_channels)
    coherence_matrix[..., diag_indices, diag_indices] = 0
    scc_values = coherence_matrix.sum(axis=-1) / (n_channels - 1)
    return scc_values, relative_band_power


scc = []
rbp = []
processed_diags = []
processed_subjects = []   # subject id of each segment (for subject-wise train/test splits)

# Pair each diagnosis with its BIDS participant_id instead of assuming that row
# k of participants.tsv is sub-{k:03}.
for sub_id, curr_diag in zip(participants["participant_id"], diags):
    if curr_diag == "F":
        continue  # "F" subjects are excluded from this task; don't load them at all

    raw = mne.io.read_raw_eeglab(f'ds004504/derivatives/{sub_id}/eeg/{sub_id}_task-eyesclosed_eeg.set', preload=True)
    if raw.info["sfreq"] != FS:
        raise ValueError(f"{sub_id}: expected {FS} Hz sampling rate, got {raw.info['sfreq']} Hz")
    raw_data = raw.get_data()  # (n_channels, n_samples)

    # Split into consecutive, non-overlapping `length`-sample segments and drop
    # the incomplete tail. Reshape to (n_channels, n_segments, length) and then
    # move the segment axis first. Reshaping straight to (-1, n_channels, length)
    # would mix samples from different channels into each "channel" row.
    n_channels, n_samples = raw_data.shape
    num_segments = n_samples // length
    segments = raw_data[:, :num_segments * length].reshape(n_channels, num_segments, length)
    segments = segments.transpose(1, 0, 2)

    for segment in segments:
        scc_values, relative_band_power = _segment_features(segment)
        scc.append(scc_values)
        rbp.append(relative_band_power)
    processed_diags.extend([curr_diag] * num_segments)
    processed_subjects.extend([sub_id] * num_segments)

scc = np.array(scc)
rbp = np.array(rbp)
processed_diags = np.array(processed_diags)
processed_subjects = np.array(processed_subjects)

print(scc.shape)
print(rbp.shape)
print(processed_diags.shape)


In [None]:
# Convert the NumPy feature arrays to float32 tensors. Use the GPU when one is
# available and fall back to CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_scc = torch.as_tensor(scc, dtype=torch.float32).to(device)
torch_rbp = torch.as_tensor(rbp, dtype=torch.float32).to(device)

# scc / rbp are stored as (n_segments, n_epochs, n_bands, n_channels) = (N, 30, 5, 19).
# Put channels first -> (N, 19, 30, 5) so the model's Conv2d(19, ...) treats the
# 19 electrodes as input channels over an (epochs x bands) grid.
torch_scc = torch_scc.permute(0, 3, 1, 2)
torch_rbp = torch_rbp.permute(0, 3, 1, 2)

In [None]:
# Binary targets for BCELoss: 0.0 for group "C", 1.0 for group "A".
# An explicit mapping (instead of "anything that is not C becomes 1") makes an
# unexpected group label fail loudly with a KeyError rather than silently
# joining class 1.
LABEL_MAP = {"C": 0.0, "A": 1.0}
new_diags = torch.tensor(
    [LABEL_MAP[group] for group in processed_diags],
    dtype=torch.float32,
    device=torch_scc.device,  # keep labels on the same device as the features
)

In [None]:
!pip install positional-encodings

In [None]:
import torch.nn as nn
from positional_encodings.torch_encodings import PositionalEncoding1D, PositionalEncoding2D, PositionalEncoding3D, Summer

class DiceyModel(nn.Module):
    def __init__(self):
        super(DiceyModel, self).__init__()
        # Adjusted output channels to 20
        self.rbp_conv1 = nn.Conv2d(19, 20, (5, 5), stride=1)
        self.scc_conv1 = nn.Conv2d(19, 20, (5, 5), stride=1)

        self.positional_encoding = PositionalEncoding1D(20)

        # Corrected the shape of cls_token
        self.rbp_cls_token = nn.Parameter(torch.zeros(1, 1, 20))
        self.scc_cls_token = nn.Parameter(torch.zeros(1, 1, 20))

        # Set batch_first=True
        self.rbp_mha = nn.MultiheadAttention(20, 4, batch_first=True)
        self.scc_mha = nn.MultiheadAttention(20, 4, batch_first=True)

        self.rbp_mlp = nn.Sequential(
            nn.Linear(20, 80), nn.GELU(), nn.Dropout(0.1), nn.Linear(80, 20)
        )

        self.scc_mlp = nn.Sequential(
            nn.Linear(20, 80), nn.GELU(), nn.Dropout(0.1), nn.Linear(80, 20)
        )

        self.rbp_norm1 = nn.LayerNorm(20)
        self.scc_norm1 = nn.LayerNorm(20)

        self.rbp_norm2 = nn.LayerNorm(20)
        self.scc_norm2 = nn.LayerNorm(20)

        # Adjusted MLP input size to match concatenated embeddings for binary output
        self.mlp = nn.Sequential(
            nn.Linear(40, 24), nn.GELU(), nn.Dropout(0.3), nn.Linear(24, 1),
            nn.Sigmoid()  # Added sigmoid activation for binary classification
        )

    def forward(self, rbp, scc):
        rbp = self.rbp_conv1(rbp).squeeze(-1)
        scc = self.scc_conv1(scc).squeeze(-1)

        # GELU activation
        rbp = nn.functional.gelu(rbp)
        scc = nn.functional.gelu(scc)

        # Reshape from (batch_size, 20, 26) to (batch_size, 26, 20)
        rbp = rbp.permute(0, 2, 1)
        scc = scc.permute(0, 2, 1)

        # Apply positional encoding
        rbp = rbp + self.positional_encoding(rbp)
        scc = scc + self.positional_encoding(scc)

        # Concatenate CLS token along the sequence dimension (dim=1)
        rbp_cls_token = self.rbp_cls_token.expand(rbp.size(0), -1, -1)
        scc_cls_token = self.scc_cls_token.expand(scc.size(0), -1, -1)
        rbp = torch.cat((rbp_cls_token, rbp), dim=1)
        scc = torch.cat((scc_cls_token, scc), dim=1)

        # Apply multihead attention and LayerNorm
        rbp_attn_output = self.rbp_mha(rbp, rbp, rbp)[0]
        scc_attn_output = self.scc_mha(scc, scc, scc)[0]
        rbp = self.rbp_norm1(rbp + rbp_attn_output)
        scc = self.scc_norm1(scc + scc_attn_output)

        # Apply MLP and LayerNorm
        rbp_mlp_output = self.rbp_mlp(rbp)
        scc_mlp_output = self.scc_mlp(scc)
        rbp = self.rbp_norm2(rbp + rbp_mlp_output)
        scc = self.scc_norm2(scc + scc_mlp_output)

        # Corrected indexing to extract CLS token embeddings
        embed = torch.cat((rbp[:, 0, :], scc[:, 0, :]), dim=1)
        return self.mlp(embed)  # Output now is binary classification


In [None]:
from sklearn.metrics import accuracy_score

# Train DiceyModel on the segment features and report test loss/accuracy per epoch.
# NOTE(review): random_split below shuffles 30 s *segments*, so segments from the
# same subject end up in both the train and the test set. This leaks subject
# identity and inflates test accuracy. A subject-wise (grouped) split is needed
# for a fair estimate.

SEED = 42
batch_size = 32
test_split_ratio = 0.2  # 20% of the segments are held out for testing
epochs = 100

# Seed the global RNG (model init, dropout, DataLoader shuffling) and give the
# split its own generator so runs are reproducible.
torch.manual_seed(SEED)
device = torch_scc.device  # train wherever the feature tensors already live

# Each sample is (scc features, rbp features, label) for one segment.
dataset = TensorDataset(torch_scc, torch_rbp, new_diags)

test_size = int(test_split_ratio * len(dataset))
train_size = len(dataset) - test_size
if test_size == 0 or train_size == 0:
    raise ValueError(f"dataset of {len(dataset)} samples is too small to split with ratio {test_split_ratio}")

train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(SEED)
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

model = DiceyModel().to(device)

# Weight decay adds a small L2 penalty.
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# NOTE(review): StepLR(step_size=10, gamma=0.1) multiplies the LR by 0.1 every
# 10 epochs, reaching ~1e-12 for the last epochs of a 100-epoch run, so most
# later epochs barely update the weights.
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

criterion = nn.BCELoss()  # the model ends in Sigmoid, so it outputs probabilities

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for batch_scc, batch_rbp, batch_diags in train_loader:
        optimizer.zero_grad()

        # forward(rbp, scc): route each feature tensor to its own branch.
        # squeeze(-1) (not squeeze()) keeps shape (B,) even when the final batch
        # has B == 1; BCELoss requires input and target shapes to match.
        outputs = model(batch_rbp, batch_scc).squeeze(-1)
        loss = criterion(outputs, batch_diags.float())

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}")

    # Evaluate on the held-out segments.
    model.eval()
    test_loss = 0.0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch_scc, batch_rbp, batch_diags in test_loader:
            outputs = model(batch_rbp, batch_scc).squeeze(-1)
            test_loss += criterion(outputs, batch_diags.float()).item()

            # Threshold probabilities at 0.5 to get class labels.
            preds = (outputs > 0.5).float()

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch_diags.cpu().numpy())

    avg_test_loss = test_loss / len(test_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Epoch [{epoch+1}/{epochs}], Test Loss: {avg_test_loss:.4f}")
    print(f"Epoch [{epoch+1}/{epochs}], Test Accuracy: {accuracy:.4f}")

    scheduler.step()
    # With ReduceLROnPlateau instead of StepLR, use: scheduler.step(avg_test_loss)
