In [None]:
import sys
sys.path.append("..")  # add parent directory to system path
import torch
import json
from transformers import BartTokenizer
from model.model_loader import get_model

# Re-importing the necessary libraries
from transformers import BartTokenizer, BartForConditionalGeneration
import torch.nn as nn
import torch.nn.functional as F


import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import pickle
import json
import matplotlib.pyplot as plt

from tqdm import tqdm




from transformers import BartTokenizer, BartForConditionalGeneration


DataLoader

In [None]:
def evaluate_model_simple(dataloaders, device, tokenizer, model, max_batches=10):
    """
    Decode a few test batches and return the predicted strings for inspection.

    Each batch is pushed through the model once, with the target ids passed as
    the model's fourth argument. The highest-scoring token at every position of
    the FIRST sample in the batch is then decoded to text, so with a batch size
    above 1 the remaining samples of each batch are ignored.

    Args:
    - dataloaders (dict): Must contain a 'test' entry that yields 8-tuples
      (input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, _, _, _).
    - device (torch.device): Device the batch tensors are moved to. The model is
      NOT moved by this function, so it must already live on `device`.
    - tokenizer: Object exposing `decode(token_ids) -> str`.
    - model (nn.Module): Called as model(embeddings, masks, inverted_masks, target_ids);
      the returned object must expose `.logits`.
    - max_batches (int or None): Upper bound on the number of batches processed.
      Defaults to 10 (the previously hard-coded limit); None walks the whole loader.

    Returns:
    - list: One predicted string per processed batch.
    """

    model.eval()  # Set model to evaluate mode
    predicted_strings = []  # Store the predicted strings

    # Pure inference: disabling autograd avoids building a graph for every batch.
    with torch.no_grad():
        for i, (input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, _, _, _) in enumerate(dataloaders['test']):
            if max_batches is not None and i >= max_batches:
                break

            # Transfer data to the specified device
            input_embeddings_batch = input_embeddings.to(device).float()
            input_masks_batch = input_masks.to(device)
            target_ids_batch = target_ids.to(device)
            input_mask_invert_batch = input_mask_invert.to(device)

            # Forward pass
            seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch)
            logits = seq2seqLMoutput.logits

            # Softmax is monotonic, so the argmax of the raw logits picks the same
            # token as softmax + top-1. It also always yields a 1-D tensor, even
            # for a length-1 sequence.
            predictions = logits[0].argmax(dim=-1)

            # Keep the text before the first '</s></s>' marker and drop '<s>'.
            predicted_string = tokenizer.decode(predictions).split('</s></s>')[0].replace('<s>', '')

            # Append the predicted string to the list
            predicted_strings.append(predicted_string)

    return predicted_strings

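# evaluate_model_simple is a pared-down counterpart of eval_model (defined near the end of this notebook): it only decodes predictions and skips the loss, BLEU and ROUGE reporting.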


In [None]:
from notebooks.zuco_data import ZuCo_dataset, get_input_sample,evaluate_model


In [None]:
# Machine-specific absolute directory holding the training-config JSON.
# load_json_file concatenates it directly with the file name, so the trailing slash is required.
raw="/Users/michaelholborn/Documents/SoftwareLocal/monotropism/thoughtx/notebooks/"
# File name of the training-configuration JSON inside `raw`.
config_path="task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.json"

In [None]:

def load_json_file(raw, filename):
    """Read and parse a JSON file.

    Args:
        raw (str): Path prefix. It is concatenated with `filename` as-is, so a
            directory must be given with its trailing path separator.
        filename (str): Name appended to `raw` to form the full path.

    Returns:
        The parsed JSON content, exactly as produced by `json.load`.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's default
    # text encoding, which differs between operating systems.
    with open(raw + filename, 'r', encoding='utf-8') as file:
        return json.load(file)

In [None]:
# Parse the training configuration; the bare name on the last line shows it as the cell output.
training_config = load_json_file(raw,config_path)
training_config


In [None]:
# Evaluation settings. Subject, EEG-type, band, task and model choices are read from
# the training configuration loaded above.

# Intended number of samples per evaluation batch.
batch_size = 1

subject_choice = training_config['subjects']
print(f'[INFO]subjects: {subject_choice}')
eeg_type_choice = training_config['eeg_type']
print(f'[INFO]eeg type: {eeg_type_choice}')
bands_choice = training_config['eeg_bands']
print(f'[INFO]using bands: {bands_choice}')

# Passed as `setting=` when the ZuCo test set is built further down.
dataset_setting = 'unique_sent'

# task_name and model_name are used below to name the decoding-results file.
task_name = training_config['task_name']

model_name = training_config['model_name']

In [None]:

# Machine-specific absolute path to the ZuCo dataset folder; load_datasets joins the
# per-task pickle paths onto it.
root = "/Users/michaelholborn/Documents/SoftwareLocal/monotropism/thoughtx/datasets/datasets_eeg_text/zuco"

def load_datasets(task_names, data_root=None):
    """Load the pickled dataset object for each requested task.

    Args:
        task_names: Iterable of task keys. Recognised keys are 'task1',
            'task2', 'task3' and 'taskNRv2'.
        data_root: Directory that contains the per-task folders. Defaults to
            the notebook-level `root` when omitted.

    Returns:
        list: The unpickled objects, in the order of the recognised entries of
        `task_names`. Unrecognised names are skipped, and a warning is emitted
        so that a typo does not silently shrink the evaluation data.
    """
    import warnings

    if data_root is None:
        data_root = root

    relative_paths = {
        'task1': 'task1-SR/pickle/task1-SR-dataset.pickle',
        'task2': 'task2-NR/pickle/task2-NR-dataset.pickle',
        'task3': 'task3-TSR/pickle/task3-TSR-dataset.pickle',
        'taskNRv2': 'task2-NR-2.0/pickle/task2-NR-2.0-dataset.pickle',
    }

    whole_dataset_dicts = []
    for task_name in task_names:
        relative_path = relative_paths.get(task_name)
        if relative_path is None:
            warnings.warn(
                f"load_datasets: unknown task name {task_name!r} skipped; "
                f"expected one of {sorted(relative_paths)}"
            )
            continue

        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only point this at dataset pickles from a trusted source.
        with open(os.path.join(data_root, relative_path), 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))

    return whole_dataset_dicts

In [None]:
# Load the task1, task2 and taskNRv2 pickles into a list (one unpickled object per task).
datasets_loaded=load_datasets(['task1','task2','taskNRv2'])

In [None]:
# BART tokenizer: handed to the dataset constructor below and used by the evaluation
# functions to decode predicted token ids back to text.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')


In [None]:
# Build the 'test' split from the loaded task datasets, using the subject / EEG-type /
# band choices taken from the training configuration.
test_set = ZuCo_dataset(
    datasets_loaded,
    'test',
    tokenizer,
    subject=subject_choice,
    eeg_type=eeg_type_choice,
    bands=bands_choice,
    setting=dataset_setting,
)


In [None]:
# Use the notebook-level `batch_size` setting instead of repeating the literal, so the
# value is configured in one place. Keep it at 1 for the evaluation functions in this
# notebook: they decode only the first sample (`logits[0]`) of every batch.
# shuffle=False keeps the decoding order deterministic.
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)


In [None]:
# The evaluation functions in this notebook look the loader up under the 'test' key.
dataloaders = {'test':test_dataloader}
dataloaders

In [None]:
# Load the pretrained BART-large seq2seq model.
# NOTE(review): the same checkpoint is loaded again in the later cell that builds
# BrainTranslator, so this load looks redundant — confirm before removing.
pretrained_bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')


In [None]:
# NOTE(review): this `model` is rebound further down, where a BrainTranslator is built
# and its checkpoint loaded — confirm whether get_model() is still needed here.
model=get_model()
# Prefer the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



In [None]:

class BrainTranslator(nn.Module):
    """Main architecture for open-vocabulary EEG-to-text decoding.

    EEG feature sequences go through an additional transformer encoder, are
    projected to the decoder embedding size with a linear layer + ReLU, and are
    then handed to the wrapped pretrained seq2seq model as `inputs_embeds`.

    Args:
        pretrained_layers: Module called as
            pretrained_layers(inputs_embeds=..., attention_mask=..., return_dict=True, labels=...).
        in_feature: Size of each EEG feature vector (d_model of the additional encoder).
        decoder_embedding_size: Output size of the projection fed to `pretrained_layers`.
        additional_encoder_nhead: Number of attention heads in the additional encoder.
        additional_encoder_dim_feedforward: Feed-forward width in the additional encoder.
        additional_encoder_num_layers: Depth of the additional encoder. Defaults to 6,
            the previously hard-coded value; checkpoints saved with that depth only
            load with the default.
    """

    def __init__(self, pretrained_layers, in_feature=840, decoder_embedding_size=1024,
                 additional_encoder_nhead=8, additional_encoder_dim_feedforward=2048,
                 additional_encoder_num_layers=6):
        super().__init__()

        self.pretrained = pretrained_layers

        # Additional transformer encoder placed in front of the pretrained model.
        # The template layer is deliberately kept as an attribute: it is a registered
        # submodule, so its parameters appear in state_dict() under
        # 'additional_encoder_layer.*'. Dropping the attribute would change the
        # state_dict keys and break loading of existing checkpoints.
        self.additional_encoder_layer = nn.TransformerEncoderLayer(
            d_model=in_feature,
            nhead=additional_encoder_nhead,
            dim_feedforward=additional_encoder_dim_feedforward,
            batch_first=True,
        )
        self.additional_encoder = nn.TransformerEncoder(
            self.additional_encoder_layer,
            num_layers=additional_encoder_num_layers,
        )

        # Projects encoder output (in_feature) to the embedding size the pretrained model expects.
        self.fc1 = nn.Linear(in_feature, decoder_embedding_size)

    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):
        """Run the additional encoder, project, and call the pretrained model.

        Args:
            input_embeddings_batch: EEG features, batch_size * seq_len * in_feature.
            input_masks_batch: Attention mask for the pretrained model;
                1 is not masked, 0 is masked.
            input_masks_invert: Padding mask for the additional encoder;
                1 is masked, 0 is not masked.
            target_ids_batch_converted: Target token ids passed to the pretrained
                model as `labels`.

        Returns:
            Whatever `pretrained_layers` returns when called with return_dict=True.
        """
        # No positional encoding is applied to the EEG features here.
        encoded_embedding = self.additional_encoder(
            input_embeddings_batch,
            src_key_padding_mask=input_masks_invert,
        )
        encoded_embedding = F.relu(self.fc1(encoded_embedding))

        return self.pretrained(
            inputs_embeds=encoded_embedding,
            attention_mask=input_masks_batch,
            return_dict=True,
            labels=target_ids_batch_converted,
        )

In [None]:
# BART backbone wrapped by BrainTranslator.
pretrained_bart = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

# Use the correct path to your model weights
checkpoint_path = "/Users/michaelholborn/Documents/SoftwareLocal/monotropism/thoughtx/local_checkpoint/task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b1_20_30_5e-05_5e-07_unique_sent.pt"

# Initialize BrainTranslator with the pretrained BART layers
try:
    model = BrainTranslator(pretrained_bart)
except Exception as e:
    # Chain the original exception so its type and traceback stay visible.
    raise ValueError(f"Error initializing BrainTranslator: {str(e)}") from e

# Deserialize onto the CPU first so the checkpoint loads on machines without a GPU.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model_weights = torch.load(
    checkpoint_path, map_location=torch.device("cpu")
)
model.load_state_dict(model_weights)

# The evaluation functions move every batch to `device`; the model has to live on the
# same device or the forward pass fails with a device mismatch on CUDA hosts.
model.to(device)
model.eval()

In [None]:
# model = BrainTranslator(pretrained_bart,
#                           in_feature = 105*len(bands_choice),
#                             decoder_embedding_size = 1024,
#                               additional_encoder_nhead=8,
#                                 additional_encoder_dim_feedforward = 2048)

# model.load_state_dict(torch.load(checkpoint_path))
# model.to(device)

In [None]:
# output_all_results_path= "./results"

# Destination file for the decoding results, named after the configured task and model.
# NOTE(review): the evaluate_model call in the next cell is not given this path — confirm
# whether it should be passed.
output_all_results_path = f'./results/{task_name}-{model_name}-all_decoding_results.txt'


In [None]:
# Run the evaluation routine imported from notebooks.zuco_data (its implementation is
# not part of this notebook).
evaluate_model(dataloaders, device, tokenizer, model)


In [None]:
# Sanity check on the batch structure: the evaluation loops in this notebook unpack
# eight items per batch, so this is expected to print 8.
first_batch = next(iter(dataloaders['test']))
print(len(first_batch))


In [None]:
# NOTE(review): inside a notebook kernel this guard is presumably always true; it only
# matters if the notebook is exported and imported as a module.
if __name__ == '__main__':
    # Decode a few test batches and keep the predicted strings for inspection.

    strings =evaluate_model_simple(dataloaders, device, tokenizer,model)



In [None]:
# Display the predicted strings returned by evaluate_model_simple.
strings

In [None]:

def eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path='./results/temp.txt'):
    """Evaluate `model` on dataloaders['test'], log the decodings, and print loss / BLEU / ROUGE.

    For every batch the model is run once with the target ids as labels (padding
    replaced by -100). The target and the top-1 prediction of the FIRST sample in
    the batch are written to `output_all_results_path` and collected for the
    corpus-level metrics.

    modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

    Args:
        dataloaders (dict): Must contain a 'test' entry yielding 8-tuples
            (input_embeddings, seq_len, input_masks, input_mask_invert, target_ids,
            target_mask, sentiment_labels, sent_level_EEG).
        device (torch.device): Device the batch tensors are moved to; the model must
            already be on it.
        tokenizer: Tokenizer providing `decode`, `convert_ids_to_tokens`,
            `pad_token_id` and `eos_token_id`.
        criterion: Unused; the loss reported is the one returned by the model. Kept
            so existing callers do not break.
        model (nn.Module): Called as model(embeddings, masks, inverted_masks, labels);
            the result must expose `.loss` and `.logits`.
        output_all_results_path (str): Text file that receives the target/predicted
            string pairs. Its parent directory is created if missing.

    Returns:
        None. Results are written to the file and printed.

    Raises:
        ValueError: If the test dataloader yields no samples.
    """
    model.eval()   # Set model to evaluate mode
    running_loss = 0.0
    sample_count = 0  # number of samples (not batches) seen; denominator of the mean loss

    target_tokens_list = []
    target_string_list = []
    pred_tokens_list = []
    pred_string_list = []

    # Create the results folder up front so a fresh checkout does not fail on open().
    output_dir = os.path.dirname(output_all_results_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Pure inference: no_grad avoids building an autograd graph for every batch.
    with open(output_all_results_path, 'w') as f, torch.no_grad():
        for input_embeddings, _, input_masks, input_mask_invert, target_ids, _, _, _ in dataloaders['test']:
            # load in batch
            input_embeddings_batch = input_embeddings.to(device).float()
            input_masks_batch = input_masks.to(device)
            target_ids_batch = target_ids.to(device)
            input_mask_invert_batch = input_mask_invert.to(device)

            target_tokens = tokenizer.convert_ids_to_tokens(target_ids_batch[0].tolist(), skip_special_tokens=True)
            target_string = tokenizer.decode(target_ids_batch[0], skip_special_tokens=True)
            f.write(f'target string: {target_string}\n')

            # add to list for later calculate bleu metric
            # (each target is wrapped in a list: one reference per hypothesis)
            target_tokens_list.append([target_tokens])
            target_string_list.append(target_string)

            # Replace padding ids with -100 in a copy, leaving the batch tensor untouched.
            labels_batch = target_ids_batch.masked_fill(target_ids_batch == tokenizer.pad_token_id, -100)

            # forward
            seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch, labels_batch)

            # use the language modeling loss returned by the model (criterion not used)
            loss = seq2seqLMoutput.loss

            # get predicted tokens: softmax is monotonic, so the argmax of the raw logits
            # picks the same token as softmax + top-1 and always yields a 1-D tensor.
            logits = seq2seqLMoutput.logits
            predictions = logits[0].argmax(dim=-1)
            predicted_string = tokenizer.decode(predictions).split('</s></s>')[0].replace('<s>','')
            f.write(f'predicted string: {predicted_string}\n')
            f.write('################################################\n\n\n')

            # keep only the token ids before the first end-of-sequence token
            truncated_prediction = []
            for token_id in predictions.tolist():
                if token_id == tokenizer.eos_token_id:
                    break
                truncated_prediction.append(token_id)
            pred_tokens = tokenizer.convert_ids_to_tokens(truncated_prediction, skip_special_tokens=True)
            pred_tokens_list.append(pred_tokens)
            pred_string_list.append(predicted_string)

            # statistics: weight the batch-mean loss by the batch size
            batch_len = input_embeddings_batch.size()[0]
            sample_count += batch_len
            running_loss += loss.item() * batch_len

    if sample_count == 0:
        raise ValueError("eval_model: dataloaders['test'] yielded no samples")

    # Divide by the number of samples actually evaluated instead of relying on a
    # `dataset_sizes` global that this notebook never defines.
    epoch_loss = running_loss / sample_count
    print('test loss: {:4f}'.format(epoch_loss))

    # NOTE(review): `corpus_bleu` and `Rouge` are not imported anywhere in the visible
    # notebook — presumably nltk.translate.bleu_score.corpus_bleu and rouge.Rouge;
    # confirm and import them before calling this function.

    # calculate corpus bleu score for BLEU-1 .. BLEU-4
    weights_list = [(1.0,),(0.5,0.5),(1./3.,1./3.,1./3.),(0.25,0.25,0.25,0.25)]
    for weight in weights_list:
        corpus_bleu_score = corpus_bleu(target_tokens_list, pred_tokens_list, weights=weight)
        print(f'corpus BLEU-{len(list(weight))} score:', corpus_bleu_score)

    print()
    # calculate rouge score
    rouge = Rouge()
    rouge_scores = rouge.get_scores(pred_string_list, target_string_list, avg=True)
    print(rouge_scores)