In [2]:
import torch
import numpy as np
import argparse
import os
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import RobertaTokenizer, Data2VecTextModel
from transformers import LEDModel, LEDConfig
from transformers import AutoModelWithLMHead, AutoTokenizer, Data2VecTextModel, T5Tokenizer, LongT5Model, LongT5EncoderModel
import nibabel
import numpy as np
import mne
import pandas as pd
import mne_bids
from transformers import BertTokenizer, BertModel
import time as tm

In [19]:
def load_data(sub, ses, task, root='.'):
    """Read one MEG recording from a BIDS dataset and return its annotations as a table.

    Parameters
    ----------
    sub, ses, task : str
        BIDS subject, session and task identifiers, passed to ``mne_bids.BIDSPath``.
    root : str, optional
        Root directory of the BIDS dataset. Defaults to the current working
        directory, which is what this function previously hard-coded.

    Returns
    -------
    pandas.DataFrame
        One row per annotation; the columns are the keys of the parsed
        annotation descriptions (later cells read the ``kind`` and ``word`` columns).
    """
    # Local import so this cell stays runnable without touching the import cell.
    import ast

    bids_path = mne_bids.BIDSPath(
        subject=sub, session=ses, task=task, datatype="meg", root=root)

    raw = mne_bids.read_raw_bids(bids_path)
    # NOTE(review): the band-pass filtered signal is not used by this function
    # (only the annotations are returned) -- confirm whether this step is needed.
    raw.load_data().filter(0.5, 30.0, n_jobs=1)

    df = raw.annotations.to_data_frame()
    # Each description is assumed to be the repr of a Python dict. It comes from
    # files on disk, so parse it with literal_eval (literals only) rather than
    # eval, which would execute arbitrary expressions embedded in the data.
    df_new = pd.DataFrame(df.description.apply(ast.literal_eval).to_list())

    return df_new

In [30]:
# Gather, for each of the four tasks of subject 01 / session 0, the words
# carried by the annotations whose 'kind' field contains 'word'.
hp_text = []
for task_id in range(4):
    events = load_data('01', '0', str(task_id))
    story_words = [
        events['word'][row]
        for row in range(events.shape[0])
        if 'word' in events['kind'][row]
    ]
    hp_text.append(story_words)
print(len(hp_text))

Extracting SQD Parameters from sub-01/ses-0/meg/sub-01_ses-0_task-0_meg.con...
Creating Raw.info structure...
Setting channel info structure...
Creating Info structure...
Ready.
Reading events from sub-01/ses-0/meg/sub-01_ses-0_task-0_events.tsv.
Reading channel info from sub-01/ses-0/meg/sub-01_ses-0_task-0_channels.tsv.
The stimulus channel "STI 014" is present in the raw data, but not included in channels.tsv. Removing the channel.
Reading 0 ... 395999  =      0.000 ...   395.999 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff fre

In [21]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Number of per-layer lists allocated in the embedding dictionary
# (keys 0 .. n_total_layers - 1).
n_total_layers = 12

@torch.inference_mode()
def get_flan_layer_representations(args, text_array, remove_chars):
    """Extract one embedding per word and per layer from ``bert-base-uncased``.

    Despite the "flan"/"LED" names, the model loaded here is ``BertModel``.

    Every word is embedded together with the words that precede it: for the
    first ``sequence_length`` words the context grows from 1 word up to
    ``sequence_length`` words and the last word of the context is extracted;
    afterwards a sliding window of exactly ``sequence_length`` words is used and
    the word at ``word_ind_to_extract`` within that window is extracted.

    Parameters
    ----------
    args : object
        Must expose ``sequence_length`` (window size in words) and
        ``word_ind_to_extract`` (index inside the window; negative values
        count from the end of the window).
    text_array : sequence of str
        The words of the text, in order.
    remove_chars : collection of str
        Tokens that are dropped after tokenization. A whole word equal to one
        of these entries is rejected.

    Returns
    -------
    dict
        Keys ``0 .. n_total_layers - 1`` map to lists of per-word vectors (one
        entry per processed word); key ``-1`` maps to the list of
        token-averaged input (non-contextual) embeddings, one per word.

    Raises
    ------
    ValueError
        If a word of ``text_array`` is itself listed in ``remove_chars``.
    """
    seq_len = args.sequence_length
    word_ind_to_extract = args.word_ind_to_extract
    n_words = len(text_array)

    model = BertModel.from_pretrained('bert-base-uncased').to(device)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model.eval()

    # Input-embedding lookup for each word on its own, averaged over its tokens.
    token_embeddings = []
    for word in text_array:
        current_token_embedding = get_led_token_embeddings([word], tokenizer, model, remove_chars)
        if not torch.is_tensor(current_token_embedding):
            # The helper signals a rejected word with a -1 sentinel; fail with a
            # clear message instead of crashing on `.detach()` of an int.
            raise ValueError(
                'Word {!r} is listed in remove_chars and cannot be embedded'.format(word))
        token_embeddings.append(np.mean(current_token_embedding.detach().numpy(), 1))

    # Layer-wise storage; key -1 holds the non-contextual token embeddings.
    LED = {layer: [] for layer in range(n_total_layers)}
    LED[-1] = token_embeddings

    # Ramp-up phase: before seq_len words have been seen, feed words 0..k and
    # extract word k. Capped at the text length so that texts shorter than
    # seq_len no longer fail with a KeyError on a missing word index.
    start_time = tm.time()
    for truncated_seq_len in range(1, 1 + min(seq_len, n_words)):
        word_seq = text_array[:truncated_seq_len]
        from_start_word_ind_to_extract = -1 + truncated_seq_len
        LED = add_avrg_token_embedding_for_specific_word(word_seq, tokenizer, model, remove_chars,
                                                            from_start_word_ind_to_extract, LED)
        if truncated_seq_len % 100 == 0:
            print('Completed {} out of {}: {}'.format(truncated_seq_len, len(text_array), tm.time()-start_time))
            start_time = tm.time()

    # Translate an index given from the end of the window into one from its start.
    if word_ind_to_extract < 0:
        from_start_word_ind_to_extract = seq_len + word_ind_to_extract
    else:
        from_start_word_ind_to_extract = word_ind_to_extract

    # Sliding-window phase: windows of exactly seq_len words ending at end_curr_seq.
    for end_curr_seq in range(seq_len, n_words):
        word_seq = text_array[end_curr_seq-seq_len+1:end_curr_seq+1]
        LED = add_avrg_token_embedding_for_specific_word(word_seq, tokenizer, model, remove_chars,
                                                            from_start_word_ind_to_extract, LED)

        if end_curr_seq % 100 == 0:
            print('Completed {} out of {}: {}'.format(end_curr_seq, len(text_array), tm.time()-start_time))
            start_time = tm.time()

    print('Done extracting sequences of length {}'.format(seq_len))
    return LED

@torch.inference_mode()
def predict_led_embeddings(words_in_array, tokenizer, model, remove_chars):
    """Run the model on a word sequence and return the hidden states of every layer.

    words_in_array: sequence of words forming the model input
    tokenizer:      tokenizer exposing ``tokenize`` and ``convert_tokens_to_ids``
    model:          model called with ``output_hidden_states=True``
    remove_chars:   tokens that are dropped after tokenization

    Returns ``-1`` (after printing a warning) when one of the input words is
    itself in remove_chars. Otherwise returns a 3-tuple:
      * the model's ``hidden_states`` without their first entry,
      * a dict mapping each word index to the list of positions its kept
        tokens occupy in the token sequence,
      * ``None`` (placeholder third element).
    """
    if any(word in remove_chars for word in words_in_array):
        print('An input word is also in remove_chars. This word will be removed and may lead to misalignment. Proceed with caution.')
        return -1

    seq_tokens = []
    word_ind_to_token_ind = {}
    for word_ind, word in enumerate(words_in_array):
        kept_tokens = [tok for tok in tokenizer.tokenize(word) if tok not in remove_chars]
        first_position = len(seq_tokens)
        # positions of this word's tokens inside the flat token sequence
        word_ind_to_token_ind[word_ind] = list(range(first_position, first_position + len(kept_tokens)))
        seq_tokens.extend(kept_tokens)

    # NOTE(review): the token ids are fed to the model exactly as produced here;
    # no special tokens (e.g. [CLS]/[SEP]) are added -- confirm this is intended.
    token_ids = tokenizer.convert_tokens_to_ids(seq_tokens)
    tokens_tensor = torch.tensor([token_ids]).to(device)

    outputs = model(tokens_tensor, output_hidden_states=True)
    # skip the first hidden state and keep the remaining ones
    all_layers_hidden_states = outputs['hidden_states'][1:]

    return all_layers_hidden_states, word_ind_to_token_ind, None
  
@torch.inference_mode()
def add_word_led_embedding(model_dict, embeddings_to_add, token_inds_to_avrg, specific_layer=-1):
    """Append the token-averaged vector of one word to the per-layer lists.

    model_dict:         dict of lists keyed by layer index; mutated and returned
    embeddings_to_add:  indexable collection of per-layer tensors, each indexed
                        as [0, token, feature]
    token_inds_to_avrg: positions of the word's tokens to average over
    specific_layer:     when >= 0 only that layer is updated; otherwise every
                        layer in embeddings_to_add is updated
    """
    def _word_vector(layer_embedding):
        # average the selected token rows of the first batch element
        as_array = layer_embedding.cpu().detach().numpy()
        return np.mean(as_array[0, token_inds_to_avrg, :], 0)

    if specific_layer >= 0:
        selected = [(specific_layer, embeddings_to_add[specific_layer])]
    else:
        selected = enumerate(embeddings_to_add)

    for layer, layer_embedding in selected:
        model_dict[layer].append(_word_vector(layer_embedding))
    return model_dict

@torch.inference_mode()
def add_avrg_token_embedding_for_specific_word(word_seq,tokenizer,model,remove_chars,from_start_word_ind_to_extract,model_dict):
    """Embed a word sequence and store the layer-wise vector of one of its words.

    word_seq:      iterable of words making up the input sequence
    tokenizer:     tokenizer forwarded to predict_led_embeddings
    model:         model forwarded to predict_led_embeddings
    remove_chars:  tokens that must not be part of the representation
    from_start_word_ind_to_extract: index of the word whose features are
                   extracted, INDEXED FROM THE START OF word_seq
    model_dict:    dict of per-layer lists that receives the new vectors

    Returns model_dict after one vector has been appended for every layer.

    Raises ValueError when predict_led_embeddings rejects the sequence because
    one of its words is listed in remove_chars.
    """
    word_seq = list(word_seq)
    prediction = predict_led_embeddings(word_seq, tokenizer, model, remove_chars)
    if isinstance(prediction, int):
        # predict_led_embeddings returns the sentinel -1 instead of a 3-tuple
        # when it rejects the input; unpacking it blindly used to fail with an
        # unrelated "cannot unpack non-iterable int" TypeError.
        raise ValueError(
            'Cannot embed sequence: one of its words is listed in remove_chars '
            '({!r})'.format(remove_chars))

    all_sequence_embeddings, word_ind_to_token_ind, _ = prediction
    token_inds_to_avrg = word_ind_to_token_ind[from_start_word_ind_to_extract]
    model_dict = add_word_led_embedding(model_dict, all_sequence_embeddings,token_inds_to_avrg)

    return model_dict


@torch.inference_mode()
def get_led_token_embeddings(words_in_array, tokenizer, model, remove_chars):
    """Look up the input (non-contextual) embeddings of the tokens of some words.

    words_in_array: sequence of words to tokenize
    tokenizer:      tokenizer exposing ``tokenize`` and ``convert_tokens_to_ids``
    model:          model whose ``base_model.get_input_embeddings()`` is queried
    remove_chars:   tokens that are dropped after tokenization

    Returns ``-1`` (after printing a warning) when one of the input words is
    itself in remove_chars; otherwise the embedding module's output for the
    kept tokens, moved to the CPU.
    """
    if any(word in remove_chars for word in words_in_array):
        print('An input word is also in remove_chars. This word will be removed and may lead to misalignment. Proceed with caution.')
        return -1

    # flat list of every kept token, in word order
    seq_tokens = [
        tok
        for word in words_in_array
        for tok in tokenizer.tokenize(word)
        if tok not in remove_chars
    ]

    token_ids = tokenizer.convert_tokens_to_ids(seq_tokens)
    tokens_tensor = torch.tensor([token_ids]).to(device)

    # query the embedding table directly rather than running the whole model
    embedding_layer = model.base_model.get_input_embeddings()
    return embedding_layer(tokens_tensor).cpu()

In [22]:
class Args:
    """Settings container passed to get_flan_layer_representations."""

    # number of words in the context window
    sequence_length = 20
    # label of the language model in use
    nlp_model = 'bert'
    # index of the word to extract inside a window; negative counts from the end
    word_ind_to_extract = -1


args = Args()

In [32]:
# Tokens stripped from the tokenized input before it is fed to the model.
remove_chars = [",", "\"", "@"]

# One layer-wise embedding dictionary per story, in the same order as hp_text.
embeddings = [
    get_flan_layer_representations(args, hp_text[story_idx], remove_chars)
    for story_idx in range(4)
]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Completed 100 out of 668: 1.9971370697021484
Completed 200 out of 668: 2.0290350914001465
Completed 300 out of 668: 2.020437002182007
Completed 400 out of 668: 2.032015085220337
Completed 500 out of 668: 2.150984048843384
Completed 600 out of 668: 2.0917675495147705
Done extracting sequences of length 20


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Completed 100 out of 1503: 1.9843518733978271
Completed 200 out of 1503: 2.0337765216827393
Completed 300 out of 1503: 2.125744104385376
Completed 400 out of 1503: 2.066490888595581
Completed 500 out of 1503: 2.138692617416382
Completed 600 out of 1503: 2.094560146331787
Completed 700 out of 1503: 2.0967442989349365
Completed 800 out of 1503: 2.1158299446105957
Completed 900 out of 1503: 2.1593010425567627
Completed 1000 out of 1503: 2.1351490020751953
Completed 1100 out of 1503: 2.1453917026519775
Completed 1200 out of 1503: 2.115560531616211
Completed 1300 out of 1503: 2.1829957962036133
Completed 1400 out of 1503: 2.124079704284668
Completed 1500 out of 1503: 2.119319200515747
Done extracting sequences of length 20


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Completed 100 out of 2637: 2.115116596221924
Completed 200 out of 2637: 2.1098732948303223
Completed 300 out of 2637: 2.1439590454101562
Completed 400 out of 2637: 2.1545095443725586
Completed 500 out of 2637: 2.212841510772705
Completed 600 out of 2637: 2.1501264572143555
Completed 700 out of 2637: 2.1438345909118652
Completed 800 out of 2637: 2.364421844482422
Completed 900 out of 2637: 2.8701934814453125
Completed 1000 out of 2637: 2.790339708328247
Completed 1100 out of 2637: 2.780317544937134
Completed 1200 out of 2637: 2.795426368713379
Completed 1300 out of 2637: 2.7735438346862793
Completed 1400 out of 2637: 2.781944751739502
Completed 1500 out of 2637: 2.8243918418884277
Completed 1600 out of 2637: 2.834963798522949
Completed 1700 out of 2637: 2.598944664001465
Completed 1800 out of 2637: 2.539745807647705
Completed 1900 out of 2637: 2.525031566619873
Completed 2000 out of 2637: 2.6012203693389893
Completed 2100 out of 2637: 2.5989761352539062
Completed 2200 out of 2637: 2.594

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Completed 100 out of 3753: 2.0793402194976807
Completed 200 out of 3753: 2.39516544342041
Completed 300 out of 3753: 2.557565212249756
Completed 400 out of 3753: 2.5860671997070312
Completed 500 out of 3753: 2.532302141189575
Completed 600 out of 3753: 2.689112663269043
Completed 700 out of 3753: 2.5499701499938965
Completed 800 out of 3753: 2.544285535812378
Completed 900 out of 3753: 2.5627596378326416
Completed 1000 out of 3753: 2.5988969802856445
Completed 1100 out of 3753: 2.629253625869751
Completed 1200 out of 3753: 2.6222033500671387
Completed 1300 out of 3753: 2.6548564434051514
Completed 1400 out of 3753: 2.5204029083251953
Completed 1500 out of 3753: 2.603675365447998
Completed 1600 out of 3753: 2.5384926795959473
Completed 1700 out of 3753: 2.5434272289276123
Completed 1800 out of 3753: 2.5867691040039062
Completed 1900 out of 3753: 2.5772500038146973
Completed 2000 out of 3753: 2.5533056259155273
Completed 2100 out of 3753: 2.6304519176483154
Completed 2200 out of 3753: 2.

In [46]:
# `embeddings` is a list holding one dict per story, so inspect the keys of the
# first story's dict (layer indices, plus -1 for the token embeddings).
embeddings[0].keys()

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -1])

In [33]:
# Persist the per-story embedding dictionaries; the window size is part of the file name.
embeddings_file = "bert-base-lw-" + str(args.sequence_length)
np.save(embeddings_file, embeddings)

In [26]:
# Preview the first ten words of the first story.
hp_text[0][:10]

[]

In [34]:
# The stories have different lengths, so hp_text is a ragged list of lists.
# Letting np.save convert it implicitly emits a deprecation warning on older
# NumPy and raises on NumPy >= 1.24; build the 1-D object array explicitly.
words_by_story = np.empty(len(hp_text), dtype=object)
for story_idx, story_words in enumerate(hp_text):
    words_by_story[story_idx] = story_words
np.save("words-allstories", words_by_story)

  arr = np.asanyarray(arr)
