In [1]:
# Re-import edited local modules (e.g. mamba_model, models) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:

from mamba_ssm import Mamba
# from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer


import torch
import numpy as np
import torch
import torch.nn as nn
from torch import tensor
import scipy as sp

# import torchvision
# import torchvision.transforms as transforms

import torch
from torch.utils.data import DataLoader, TensorDataset
from mamba_ssm.modules.block import Block
from functools import partial


from mamba_model import MambaEEG
from mamba_ssm.models.config_mamba import MambaConfig
from models import *

In [3]:
from tqdm.notebook import tqdm
import os
import h5py

In [4]:
def load_matlab_string(matlab_extracted_object):
    """
    Decode a MATLAB char array read through h5py into a Python string.

    Every element of the (flattened) array is treated as a character code
    and converted with ``chr``; the characters are joined in order.

    :param matlab_extracted_object:     (h5py)  matlab string object
    :return:
        extracted_string    (str)   translated string
    """
    codes = matlab_extracted_object[:].flatten()
    characters = [chr(code) for code in codes]
    return ''.join(characters)

In [5]:
task = "NR"

rootdir = "/radraid/spanchavati/eegtotext/zuco-benchmark/data/"

print('##############################')
print(f'start processing ZuCo task2-{task}-2.0...')

# Subjects skipped entirely.
# YMH is excluded due to incomplete data because of dyslexia.
EXCLUDED_SUBJECTS = {'YMH'}

dataset_dict = {}

for file in tqdm(os.listdir(rootdir)[::-1]):
    if not file.endswith(task + ".mat"):
        continue
    print(file)

    file_name = os.path.join(rootdir, file)

    # File names look like "results<SUBJECT>_<TASK>.mat" (e.g. resultsYAG_NR.mat).
    # Parse only the basename so directory names cannot affect the subject code.
    stem = file.split("_")[0]
    subject = stem[len("results"):] if stem.startswith("results") else stem

    if subject in EXCLUDED_SUBJECTS:
        continue

    # The context manager guarantees the HDF5 handle is closed. Strings are
    # decoded and arrays copied into numpy before the file closes.
    with h5py.File(file_name, 'r') as f:
        print('keys in f:', list(f.keys()))
        if 'sentenceData' not in f:
            # Some files expose no 'sentenceData' group (empty key list); skip them.
            continue
        sentence_data = f['sentenceData']

        contents = []
        rawEEG = []
        for i in range(sentence_data['rawData'].len()):
            # Each entry is an HDF5 object reference that must be dereferenced via f[...].
            contents.append(load_matlab_string(f[sentence_data['content'][i][0]]))
            rawEEG.append(np.array(f[sentence_data['rawData'][i][0]]))

    dataset_dict[subject] = {'content': contents, 'eeg': rawEEG}


##############################
start processing ZuCo task2-NR-2.0...


  0%|          | 0/54 [00:00<?, ?it/s]

resultsYAG_NR.mat
keys in f: ['#refs#', 'sentenceData']
resultsXAH_NR.mat
keys in f: []
resultsXLS_NR.mat
keys in f: []
resultsYRK_NR.mat
keys in f: ['#refs#', 'sentenceData']
resultsYDR_NR.mat
keys in f: ['#refs#', 'sentenceData']
resultsYRP_NR.mat
keys in f: ['#refs#', 'sentenceData']
resultsYFR_NR.mat
keys in f: ['#refs#', 'sentenceData']
resultsYHS_NR.mat
keys in f: ['#refs#', 'sentenceData']
resultsXDT_NR.mat
keys in f: []
resultsXBB_NR.mat
keys in f: []
resultsYSL_NR.mat
keys in f: ['#refs#', 'sentenceData']
resultsYTL_NR.mat
keys in f: ['#refs#', 'sentenceData']
resultsXWS_NR.mat
keys in f: []
resultsYIS_NR.mat
keys in f: ['#refs#', 'sentenceData']
resultsYMD_NR.mat
keys in f: ['#refs#', 'sentenceData']
resultsXBD_NR.mat
keys in f: []
resultsYFS_NR.mat
keys in f: ['#refs#', 'sentenceData']
resultsXSS_NR.mat
keys in f: []
resultsYSD_NR.mat
keys in f: ['#refs#', 'sentenceData']
resultsYLS_NR.mat
keys in f: ['#refs#', 'sentenceData']
resultsXPB_NR.mat
keys in f: []
resultsXTR_NR.ma

In [6]:
import random
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import numpy as np
import torch

class EEGTextDatasetV2(Dataset):
    """
    Map-style dataset pairing each sentence with its normalized, zero-padded EEG recording.

    ``__getitem__`` returns a dict with the tokenized sentence (``input_ids``,
    ``attention_mask``), the EEG tensor (``eeg``), a 0/1 mask over EEG time
    steps (``eeg_attention_mask``) and an integer ``subject_id``.

    Recordings that are not 2-D, or are shorter than ``min_eeg_len`` or longer
    than ``maxlen`` samples along axis 0, are dropped at construction time.
    """

    # Recordings with fewer samples than this along axis 0 are treated as
    # empty/invalid: excluded from normalization statistics and from items.
    min_eeg_len = 100

    def __init__(self, data_dict, subject_keys, tokenizer_name='bert-base-uncased', maxlen=15*500, mode='within'):
        """
        :param data_dict:       {subject: {'content': [str, ...], 'eeg': [ndarray, ...]}}
        :param subject_keys:    subjects from ``data_dict`` to include; ids follow this order
        :param tokenizer_name:  Hugging Face tokenizer checkpoint name
        :param maxlen:          EEG length (samples along axis 0) every item is padded to
        :param mode:            free-form label (e.g. 'within', 'cross', 'zero-shot', 'test');
                                stored on the instance but not used by this class
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.maxlen = maxlen
        self.data = []
        self.subject_to_id = {}
        self.mode = mode

        self.load_data(data_dict, subject_keys)

    def load_data(self, data_dict, subject_keys):
        """
        Normalize, pad and store every usable recording of the given subjects.

        Subject ids are assigned in first-seen order of ``subject_keys``.
        Normalization statistics are computed per subject over all of that
        subject's usable recordings.
        """
        for key in subject_keys:
            patient_data = data_dict[key]
            sentences = patient_data['content']
            eeg_data = patient_data['eeg']

            # len() is evaluated before insertion, so new keys get the next free id.
            subject_id = self.subject_to_id.setdefault(key, len(self.subject_to_id))

            mean, std = self.incremental_mean_std(eeg_data)

            for sentence, eeg in zip(sentences, eeg_data):
                eeg_processed, attention_mask = self.process_eeg(eeg, mean, std)
                if eeg_processed is None:
                    continue
                self.data.append({
                    'sentence': sentence,
                    'eeg': eeg_processed,
                    'eeg_attention_mask': attention_mask,
                    'subject_id': subject_id,
                })

    def __getitem__(self, idx):
        item = self.data[idx]

        # padding='max_length' without max_length pads to the tokenizer's
        # model maximum, so every item has the same token length.
        tokenized = self.tokenizer(item['sentence'], return_tensors='pt', padding='max_length', truncation=True)

        # NaN / +-inf samples (e.g. missing values, or channels with zero std)
        # become 0, i.e. the normalized mean. The stored array is not modified.
        eeg = torch.nan_to_num(torch.tensor(item['eeg']), posinf=0, neginf=0).float()

        return {
            'input_ids': tokenized['input_ids'][0],
            'attention_mask': tokenized['attention_mask'][0],
            'eeg': eeg,
            'eeg_attention_mask': torch.tensor(item['eeg_attention_mask']),
            'subject_id': torch.tensor(item['subject_id'], dtype=torch.long),
        }

    def __len__(self):
        return len(self.data)

    def process_eeg(self, eeg_data, mean, std):
        """
        Normalize one recording with (mean, std) and right-pad it with zeros to ``self.maxlen``.

        :param eeg_data:    2-D array with time along axis 0
                            (assumed (samples, channels) — TODO confirm)
        :param mean, std:   statistics broadcast against ``eeg_data``
        :return: ``(padded_eeg, attention_mask)``. ``padded_eeg`` is float32 with
                 shape ``(maxlen, eeg_data.shape[1])``. ``attention_mask`` is 1 on
                 real samples and 0 on padding. Returns ``(None, None)`` (skip) if
                 the recording is not 2-D, shorter than ``min_eeg_len``, or
                 longer than ``maxlen`` (the latter is also printed).
        """
        if eeg_data.ndim != 2 or eeg_data.shape[0] < self.min_eeg_len:
            return None, None

        length = eeg_data.shape[0]
        if length > self.maxlen:
            print(f"EEG data length {length} exceeds maxlen {self.maxlen}")
            return None, None

        normalized_eeg = (eeg_data - mean) / std

        attention_mask = np.zeros((self.maxlen,))
        attention_mask[:length] = 1

        # float32 matches what __getitem__ hands to the model and halves the
        # memory held in self.data compared to float64 padding.
        padded_eeg = np.zeros((self.maxlen, eeg_data.shape[1]), dtype=np.float32)
        padded_eeg[:length] = normalized_eeg

        return padded_eeg, attention_mask

    def incremental_mean_std(self, data_list):
        """
        Per-axis-0 mean and sample std (ddof=1) over all usable arrays in ``data_list``.

        Each array is merged as a batch with the parallel (Chan et al.) update
        of Welford's algorithm, which is exact for the pooled data. Arrays shorter
        than ``min_eeg_len`` rows are skipped. ``np.nansum`` ignores NaN values
        in the sums, but rows containing NaN still count toward n, so results
        are approximate when NaNs are present.

        :return: (mean, std). Returns (0.0, 1.0) when fewer than two usable rows
                 exist, so normalization becomes a no-op instead of dividing by zero.
        """
        n_total = 0
        mean = 0.0
        m2 = 0.0
        for data in data_list:
            n = data.shape[0]
            if n < self.min_eeg_len:
                continue
            n_total += n
            delta = data - mean
            mean = mean + np.nansum(delta, axis=0) / n_total
            delta2 = data - mean
            m2 = m2 + np.nansum(delta * delta2, axis=0)

        if n_total < 2:
            return 0.0, 1.0

        variance = m2 / (n_total - 1)
        std = np.sqrt(variance)
        return mean, std

def create_data_splits(data_dict, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=None):
    """
    Split every subject's (sentence, EEG) pairs into train / validation / test.

    The test split is sentence-level zero-shot. ``int(n_sentences * test_ratio)``
    distinct sentences (pooled over all subjects) are held out, and every subject's
    recordings of those sentences go to test. Each subject's remaining pairs keep
    their original order: the first ``train_ratio / (train_ratio + val_ratio)``
    fraction goes to train and the rest to validation.

    :param data_dict:   {subject: {'content': [str, ...], 'eeg': [array, ...]}}
    :param seed:        optional seed for choosing held-out sentences; when None,
                        the global ``random`` module state is used
    :return: (train_data, val_data, test_data), each shaped like ``data_dict`` and
             containing every subject key (lists may be empty)
    """
    # Sort before sampling. random.sample rejects sets on Python 3.11+, and
    # set iteration order for str depends on PYTHONHASHSEED, which would make
    # the split irreproducible even with a fixed seed.
    all_sentences = sorted({
        sentence
        for subject_data in data_dict.values()
        for sentence in subject_data['content']
    })
    rng = random.Random(seed) if seed is not None else random
    test_sentences = set(rng.sample(all_sentences, int(len(all_sentences) * test_ratio)))

    train_fraction = train_ratio / (train_ratio + val_ratio)
    train_data, val_data, test_data = {}, {}, {}

    for subject, subject_data in data_dict.items():
        test_data[subject] = {'content': [], 'eeg': []}
        kept_content, kept_eeg = [], []
        for sentence, eeg in zip(subject_data['content'], subject_data['eeg']):
            if sentence in test_sentences:
                test_data[subject]['content'].append(sentence)
                test_data[subject]['eeg'].append(eeg)
            else:
                kept_content.append(sentence)
                kept_eeg.append(eeg)

        train_idx = int(len(kept_content) * train_fraction)
        train_data[subject] = {'content': kept_content[:train_idx], 'eeg': kept_eeg[:train_idx]}
        val_data[subject] = {'content': kept_content[train_idx:], 'eeg': kept_eeg[train_idx:]}

    return train_data, val_data, test_data

def create_datasets(data_dict, tokenizer_name, maxlen):
    """
    Build the train / validation / zero-shot test ``EEGTextDatasetV2`` objects.

    Splits come from ``create_data_splits`` with its default ratios. Each
    dataset is built over every subject key present in its split.

    :return: (train_cross, val_cross, test)
    """
    splits = create_data_splits(data_dict)
    modes = ('cross', 'cross', 'test')
    train_cross, val_cross, test = (
        EEGTextDatasetV2(split, list(split.keys()), tokenizer_name, maxlen, mode=mode)
        for split, mode in zip(splits, modes)
    )
    return train_cross, val_cross, test


In [18]:
# One checkpoint name is shared by the tokenizer and the text encoder.
model_name = 'bert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [19]:
# The dataset tokenizer must match the text encoder checkpoint.
tokenizer_name = model_name

# Maximum EEG length in samples; every recording is right-padded to this and
# longer ones are dropped. Presumably 30 s at a 500 Hz sampling rate — TODO confirm.
EEG_SAMPLING_RATE_HZ = 500
EEG_MAX_SECONDS = 30
maxlen = EEG_MAX_SECONDS * EEG_SAMPLING_RATE_HZ

In [None]:
# Seed Python's `random` module: create_data_splits draws the held-out
# (zero-shot) test sentences with random.sample.
SPLIT_SEED = 0
random.seed(SPLIT_SEED)
train_ds, val_ds, test_ds = create_datasets(dataset_dict, tokenizer_name, maxlen)

In [9]:
# Create dataloaders
# Training batches are reshuffled every epoch; validation and test keep dataset order.
train_dataloader = DataLoader(train_ds, shuffle=True, batch_size=4, num_workers=4)
val_dataloader = DataLoader(val_ds, shuffle=False, batch_size=4, num_workers=4)
test_dataloader = DataLoader(test_ds, shuffle=False, batch_size=4, num_workers=4)


In [10]:
# del model
# del batch

# import gc
# gc.collect()

# torch.cuda.empty_cache()

In [11]:


# Text tower: same checkpoint name as the tokenizer, built with freeze=True.
encoder = HuggingFaceEncoder(model_name, freeze=True)

# Mamba backbone configuration for the EEG tower.
mm = MambaConfig(
    d_model=32,
    n_layer=12,
    ssm_cfg={'layer': 'Mamba1'},
)

# EEG tower over 105 channels, sequences padded to `maxlen` samples.
# NOTE(review): only the training split's subject ids are passed. If EEGEncoder
# sizes a per-subject embedding from them, any batch subject_id outside this
# range would trigger an index assert — worth checking against the
# device-side assert seen during training.
ee = EEGEncoder(
    n_channels=105,
    max_length=maxlen,
    mamba_config=mm,
    embedding='mean',
    patient_ids=list(train_ds.subject_to_id.values()),
)

model = EEGTextCLIP(
    eeg_encoder=ee,
    text_encoder=encoder,
    text_embedding_dims=768,
    projection_dims=256,
    dropout=0.1,
    temperature=1.0,
    weight_decay=1e-6,
    head_lr=1e-4,
    image_encoder_lr=1e-4,  # NOTE(review): presumably the EEG-tower LR — confirm in EEGTextCLIP
    text_encoder_lr=1e-4,
    lr_scheduler_patience=5.0,
    lr_scheduler_factor=0.8,
)

  rank_zero_warn(
  rank_zero_warn(


In [12]:
# # Define callbacks
# checkpoint_callback = ModelCheckpoint(
#     monitor='val/loss',
#     dirpath='checkpoints/',
#     filename='eeg-text-clip-{epoch:02d}-{val_loss:.2f}',
# )

# trainer = Trainer(
#     max_epochs=10,
#     accelerator='gpu',
#     devices=[0],
#     num_sanity_val_steps=10,
#     # fast_dev_run=5,
#     # callbacks=[checkpoint_callback],
#     # log_every_n_steps=1  # Added logging for debugging
# )


# # Train the model
# trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)


<!-- # for batch in train_dataloader:
#     break -->

In [13]:
# 
import os

# Turn off Hugging Face tokenizers' internal parallelism before DataLoader
# workers fork (the tokenizers library warns when forking after it was used).
os.environ.update(TOKENIZERS_PARALLELISM="false")


In [15]:
# batch['subject_id']

In [16]:
# del model
# del batch

# torch.cuda.empty_cache()

In [17]:
from torch.utils.tensorboard import SummaryWriter
import numpy as np

# NOTE(review): hard-coded GPU; the Mamba layers presumably require CUDA anyway.
device = 'cuda:0'

model.to(device)


# Initialize tensorboard writer
writer = SummaryWriter()

# Optimizer and scheduler come from the model's configure_optimizers().
optim_config = model.configure_optimizers()
optimizer = optim_config['optimizer']
# Fetched but never stepped: the scheduler step below was turned off on 7/23.
lr_scheduler = optim_config['lr_scheduler']

num_epochs = 30
best_val_loss = float('inf')

# Close the writer in `finally` so event files are flushed even when an epoch
# raises (e.g. a CUDA device-side assert or a KeyboardInterrupt).
try:
    # Outer bar tracks epochs; the inner bar tracks batches within the epoch.
    epoch_bar = tqdm(range(num_epochs), desc="Training", position=0)

    for epoch in epoch_bar:
        model.train()
        train_loss = 0.0
        steps_per_epoch = len(train_dataloader)

        batch_bar = tqdm(train_dataloader, desc=f"Epoch {epoch}", position=1, leave=False)

        for batch_idx, batch in enumerate(batch_bar):
            batch = {key: value.to(device) for key, value in batch.items()}
            optimizer.zero_grad()

            eeg_embeddings, text_embeddings = model(batch)
            loss = model._compute_losses(eeg_embeddings, text_embeddings).mean()
            loss.backward()
            optimizer.step()

            # One device->host sync per step instead of three .item() calls.
            loss_value = loss.item()
            train_loss += loss_value

            batch_bar.set_postfix({
                'loss': f"{loss_value:.4f}",
                'lr': f"{optimizer.param_groups[0]['lr']:.2e}"
            })

            writer.add_scalar('Loss/train', loss_value, epoch * steps_per_epoch + batch_idx)

        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        avg_train_loss = train_loss / max(steps_per_epoch, 1)

        # Validation loop
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                batch = {key: value.to(device) for key, value in batch.items()}
                eeg_embeddings, text_embeddings = model(batch)
                val_loss += model._compute_losses(eeg_embeddings, text_embeddings).mean().item()

        avg_val_loss = val_loss / max(len(val_dataloader), 1)

        writer.add_scalar('Loss/val', avg_val_loss, epoch)

        # Learning rate scheduler step
        # lr_scheduler.step(avg_val_loss) #TURNED THIS OFF 7/23!!!

        writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

        # Keep the weights with the lowest validation loss seen so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), 'best_model.pth')

        epoch_bar.set_postfix({
            'train_loss': f"{avg_train_loss:.4f}",
            'val_loss': f"{avg_val_loss:.4f}",
            'lr': f"{optimizer.param_groups[0]['lr']:.2e}",
            'best_val_loss': f"{best_val_loss:.4f}"
        })
finally:
    writer.close()

Training:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 0:   0%|          | 0/1063 [00:00<?, ?it/s]

../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [418,0,0], thread: [32,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [418,0,0], thread: [33,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [418,0,0], thread: [34,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [418,0,0], thread: [35,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [418,0,0], thread: [36,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [418,0,0], thread: [37,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [418,

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Return unoccupied cached CUDA memory held by PyTorch's caching allocator.
# No autograd is involved here, so no no_grad() context is needed.
torch.cuda.empty_cache()

In [None]:
from sklearn.metrics.pairwise import cosine_similarity


In [None]:
def get_unique_sentences(data_dict):
    """
    Assign a dense integer index to every distinct sentence in ``data_dict``.

    Indices follow first appearance, iterating subjects in dict order and each
    subject's 'content' list in order.

    :return: {sentence: index}
    """
    first_seen = dict.fromkeys(
        sentence
        for patient_data in data_dict.values()
        for sentence in patient_data['content']
    )
    return {sentence: position for position, sentence in enumerate(first_seen)}

def embed_unique_sentences(model, unique_sentences, tokenizer, device=None):
    """
    Project every unique sentence into the shared embedding space in a single batch.

    :param model:             object exposing ``text_encoder`` and ``text_proj``
    :param unique_sentences:  {sentence: index}; iteration order defines row order
    :param tokenizer:         tokenizer whose output supports ``.to(device)``
    :param device:            where to run; defaults to the device of the model's
                              first parameter (CPU if it has none) instead of
                              relying on a notebook-global ``device``
    :return: (text_embeddings, sentence_order) — row i of ``text_embeddings``
             embeds ``sentence_order[i]``
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    sentence_order = list(unique_sentences.keys())
    inputs = tokenizer(sentence_order, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        text_features = model.text_encoder(inputs.to(device))
        text_embeddings = model.text_proj(text_features.to(device))
    return text_embeddings, sentence_order

def embed_eeg_data(model, dataloader, device=None):
    """
    Run the model over every batch and stack the EEG embeddings in loader order.

    :param model:       callable returning ``(eeg_embeddings, text_embeddings)`` for a batch dict
    :param dataloader:  iterable of dicts of tensors
    :param device:      where to run; defaults to the device of the model's first
                        parameter (CPU if it has none) instead of relying on a
                        notebook-global ``device``
    :return: tensor of EEG embeddings concatenated along dim 0, one row per sample
             (assumes the model returns one row per sample — TODO confirm)
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    eeg_embeddings = []
    with torch.no_grad():
        for batch in dataloader:
            batch = {key: value.to(device) for key, value in batch.items()}
            eeg_embeds, _ = model(batch)
            eeg_embeddings.append(eeg_embeds)
    return torch.cat(eeg_embeddings)

def compute_similarity(embeddings1, embeddings2):
    """
    Pairwise cosine similarity between two sets of embeddings.

    Both tensors are copied to host memory first. sklearn requires 2-D inputs.

    :return: numpy array whose entry [i, j] compares row i of ``embeddings1``
             with row j of ``embeddings2``
    """
    left = embeddings1.cpu().numpy()
    right = embeddings2.cpu().numpy()
    return cosine_similarity(left, right)

def retrieve_closest(similarity_matrix, sentence_order, top_k=5):
    """
    For each row of ``similarity_matrix``, list its ``top_k`` highest-scoring sentences.

    :param similarity_matrix:   2-D array; column j scores ``sentence_order[j]``
    :param sentence_order:      candidate sentences, indexed by column
    :return: one list per row, most similar sentence first
    """
    ranking = np.argsort(-similarity_matrix, axis=1)
    top_indices = ranking[:, :top_k]
    matches = []
    for row in top_indices:
        matches.append([sentence_order[column] for column in row])
    return matches


In [None]:

# Re-create the loaders without shuffling, so row i of each EEG embedding
# matrix lines up with item i of the corresponding dataset.
train_dataloader = DataLoader(train_ds, shuffle=False, batch_size=32, num_workers=4)
val_dataloader = DataLoader(val_ds, shuffle=False, batch_size=32, num_workers=4)
test_dataloader = DataLoader(test_ds, shuffle=False, batch_size=32, num_workers=4)

# Candidate pool: every distinct sentence in the full dataset (all splits).
unique_sentences = get_unique_sentences(dataset_dict)
text_embeddings, sentence_order = embed_unique_sentences(model, unique_sentences, tokenizer)

In [None]:
# Embed EEG data
# For each split: embed the EEG recordings, score them against every candidate
# sentence, and keep the top-5 sentences per recording.
train_eeg_embeddings = embed_eeg_data(model, train_dataloader)
train_similarity_matrix = compute_similarity(train_eeg_embeddings, text_embeddings)
train_closest_matches = retrieve_closest(train_similarity_matrix, sentence_order)

val_eeg_embeddings = embed_eeg_data(model, val_dataloader)
val_similarity_matrix = compute_similarity(val_eeg_embeddings, text_embeddings)
val_closest_matches = retrieve_closest(val_similarity_matrix, sentence_order)

test_eeg_embeddings = embed_eeg_data(model, test_dataloader)
test_similarity_matrix = compute_similarity(test_eeg_embeddings, text_embeddings)
test_closest_matches = retrieve_closest(test_similarity_matrix, sentence_order)

# Show the top candidates (with scores) for the first validation recording.
# unique_sentences[s] is the column of sentence s in the similarity matrix.
print("Top matches for the first EEG embedding in the validation set:")
for sentence in val_closest_matches[0]:
    print(f"Sentence: {sentence}, Similarity: {val_similarity_matrix[0, unique_sentences[sentence]]:.4f}")

In [None]:
# Inspect one candidate sentence from the retrieval pool by its column index.
sentence_order[10]

In [None]:

idx = 10

# Top retrieved sentences for training recording `idx`. train_dataloader is
# unshuffled, so row idx of the similarity matrix corresponds to train_ds[idx].
print(f"Top matches for EEG embedding {idx} in the training set:")
for sentence in train_closest_matches[idx]:
    print(f"Sentence: {sentence}, Similarity: {train_similarity_matrix[idx, unique_sentences[sentence]]:.4f}")

# Ground-truth sentence of the same recording, for comparison.
tokenizer.decode(train_ds[idx]['input_ids'], skip_special_tokens=True)

In [None]:
# len(sentence_order)