In [1]:
import torch
import numpy as np
import argparse
import os
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import RobertaTokenizer, Data2VecTextModel
from transformers import LEDModel, LEDConfig
from transformers import AutoModelWithLMHead, AutoTokenizer, Data2VecTextModel, T5Tokenizer, LongT5Model, LongT5EncoderModel
import nibabel
import numpy as np
import mne
import pandas as pd
import mne_bids
from transformers import BertTokenizer, BertModel
import time as tm

In [2]:
def load_data(sub,ses,task):
    bids_path = mne_bids.BIDSPath(
    subject = sub, session = ses, task=task, datatype= "meg",
    root = '.')
    
    raw = mne_bids.read_raw_bids(bids_path)
    raw.load_data().filter(0.5, 30.0, n_jobs=1)
    
    df = raw.annotations.to_data_frame()
    df_new = pd.DataFrame(df.description.apply(eval).to_list())
    
    return df_new

In [3]:
hp_text = []
for i in np.arange(4):
    temp = []
    dd = load_data('01','0',str(i))
    for j in np.arange(dd.shape[0]):
        if 'word' in dd['kind'][j]:
            temp.append(dd['word'][j])
    hp_text.append(temp)
print(len(hp_text))

Extracting SQD Parameters from sub-01/ses-0/meg/sub-01_ses-0_task-0_meg.con...
Creating Raw.info structure...
Setting channel info structure...
Creating Info structure...
Ready.
Reading events from sub-01/ses-0/meg/sub-01_ses-0_task-0_events.tsv.
Reading channel info from sub-01/ses-0/meg/sub-01_ses-0_task-0_channels.tsv.
The stimulus channel "STI 014" is present in the raw data, but not included in channels.tsv. Removing the channel.
Reading 0 ... 395999  =      0.000 ...   395.999 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff fre

In [4]:
# Use GPU if possible
device = "cuda:0" if torch.cuda.is_available() else "cpu"
n_total_layers = 12 # total number of layers in model

@torch.inference_mode()
def get_flan_layer_representations(args, text_array, remove_chars):
    seq_len = args.sequence_length
    nlp_model = args.nlp_model
    word_ind_to_extract = args.word_ind_to_extract

    model = BertModel.from_pretrained('bert-base-uncased').to(device)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    #model.config.max_decoder_position_embeddings = 16384    # Process tokens longer than 1024
    model.eval()

    # get the token embeddings
    token_embeddings = []
    for word in text_array:
        current_token_embedding = get_led_token_embeddings([word], tokenizer, model, remove_chars)
        token_embeddings.append(np.mean(current_token_embedding.detach().numpy(), 1))
    
    # where to store layer-wise led embeddings of particular length
    LED = {}
    for layer in range(n_total_layers):
        LED[layer] = []
    LED[-1] = token_embeddings
    
    Attention = {}
    for layer in range(n_total_layers):
        Attention[layer] = []

    # Before we've seen enough words to make up the seq_len
    # Extract index 0 after supplying tokens 0 to 0, extract 1 after 0 to 1, 2 after 0 to 2, ... , 19 after 0 to 19
    start_time = tm.time()
    
    for truncated_seq_len in range(len(text_array)-seq_len):
        word_seq = text_array[truncated_seq_len:truncated_seq_len+seq_len]
        from_start_word_ind_to_extract = word_ind_to_extract
        #print(word_seq, from_start_word_ind_to_extract)
        LED,attentions = add_avrg_token_embedding_for_specific_word(word_seq, tokenizer, model, remove_chars, 
                                                            from_start_word_ind_to_extract, LED)
        for layer in range(n_total_layers):
            Attention[layer].append(attentions[layer].detach().numpy()[:,:,:,from_start_word_ind_to_extract])
        if truncated_seq_len % 100 == 0:
            print('Completed {} out of {}: {}'.format(truncated_seq_len, len(text_array), tm.time()-start_time))
            start_time = tm.time()
#     for truncated_seq_len in range(1, 1+seq_len):
#         print(text_array[:truncated_seq_len])
#         word_seq = text_array[:truncated_seq_len]
#         from_start_word_ind_to_extract = -1 + truncated_seq_len
#         LED = add_avrg_token_embedding_for_specific_word(word_seq, tokenizer, model, remove_chars, 
#                                                             from_start_word_ind_to_extract, LED)
#         if truncated_seq_len % 100 == 0:
#             print('Completed {} out of {}: {}'.format(truncated_seq_len, len(text_array), tm.time()-start_time))
#             start_time = tm.time()

    word_seq = text_array[:seq_len]
        
    # Then, use sequences of length seq_len, still adding the embedding of the last word in a sequence
    for end_curr_seq in range(len(text_array)-seq_len, len(text_array)):
        #print(text_array[end_curr_seq:len(text_array)])
        word_seq = text_array[end_curr_seq:len(text_array)]
        if word_ind_to_extract > 0: # the index is specified from the end of the array, so invert the index
            if len(word_seq) <= 1:
                from_start_word_ind_to_extract = 0
            else:
                from_start_word_ind_to_extract = word_ind_to_extract
        else:
            from_start_word_ind_to_extract = word_ind_to_extract
        #print(word_seq, from_start_word_ind_to_extract)
        LED,attentions = add_avrg_token_embedding_for_specific_word(word_seq, tokenizer, model, remove_chars,
                                                            from_start_word_ind_to_extract, LED)
        for layer in range(n_total_layers):
            Attention[layer].append(attentions[layer].detach().numpy()[:,:,:,from_start_word_ind_to_extract])

        if end_curr_seq % 100 == 0:
            print('Completed {} out of {}: {}'.format(end_curr_seq, len(text_array), tm.time()-start_time))
            start_time = tm.time()

    print('Done extracting sequences of length {}'.format(seq_len))
    return LED, Attention

# extracts layer representations for all words in words_in_array
# encoded_layers: list of tensors, length num layers. each tensor of dims num tokens by num dimensions in representation
# word_ind_to_token_ind: dict that maps from index in words_in_array to index in array of tokens when words_in_array is tokenized,
#                       with keys: index of word, and values: array of indices of corresponding tokens when word is tokenized
@torch.inference_mode()
def predict_led_embeddings(words_in_array, tokenizer, model, remove_chars):    
    for word in words_in_array:
        if word in remove_chars:
            print('An input word is also in remove_chars. This word will be removed and may lead to misalignment. Proceed with caution.')
            return -1
    
    n_seq_tokens = 0
    seq_tokens = []
    
    word_ind_to_token_ind = {}             # dict that maps index of word in words_in_array to index of tokens in seq_tokens
    
    for i,word in enumerate(words_in_array):
        word_ind_to_token_ind[i] = []      # initialize token indices array for current word
        word_tokens = tokenizer.tokenize(word)
            
        for token in word_tokens:
            if token not in remove_chars:  # don't add any tokens that are in remove_chars
                seq_tokens.append(token)
                word_ind_to_token_ind[i].append(n_seq_tokens)
                n_seq_tokens = n_seq_tokens + 1
    # convert token to vocabulary indices
    indexed_tokens = tokenizer.convert_tokens_to_ids(seq_tokens)
    tokens_tensor = torch.tensor([indexed_tokens]).to(device)

    # Use local attention, do not use global attention
    # attention_mask = torch.ones(tokens_tensor.shape, dtype=torch.long, device=tokens_tensor.device)
    # global_attention_mask = torch.zeros(tokens_tensor.shape, dtype=torch.long, device=tokens_tensor.device)

    outputs = model(tokens_tensor, output_hidden_states=True,output_attentions=True)
    encoder_hidden_states = outputs['hidden_states'][1:]    # This is a tuple: (layer1, layer2, ..., layer6)
    all_layers_hidden_states = encoder_hidden_states
    
    return all_layers_hidden_states, word_ind_to_token_ind, outputs['attentions']
  
# add the embeddings for a specific word in the sequence
# token_inds_to_avrg: indices of tokens in embeddings output to avrg
@torch.inference_mode()
def add_word_led_embedding(model_dict, embeddings_to_add, token_inds_to_avrg, specific_layer=-1):
    if specific_layer >= 0:  # only add embeddings for one specified layer
        layer_embedding = embeddings_to_add[specific_layer]
        full_sequence_embedding = layer_embedding.cpu().detach().numpy()
        model_dict[specific_layer].append(np.mean(full_sequence_embedding[0,token_inds_to_avrg,:],0))
    else:
        for layer, layer_embedding in enumerate(embeddings_to_add):
            full_sequence_embedding = layer_embedding.cpu().detach().numpy()
            model_dict[layer].append(np.mean(full_sequence_embedding[0,token_inds_to_avrg,:],0)) # avrg over all tokens for specified word
    return model_dict

# predicts representations for specific word in input word sequence, and adds to existing layer-wise dictionary
#
# word_seq: numpy array of words in input sequence
# tokenizer: LED tokenizer
# model: LED model
# remove_chars: characters that should not be included in the represention when word_seq is tokenized
# from_start_word_ind_to_extract: the index of the word whose features to extract, INDEXED FROM START OF WORD_SEQ
# model_dict: where to save the extracted embeddings
@torch.inference_mode()
def add_avrg_token_embedding_for_specific_word(word_seq,tokenizer,model,remove_chars,from_start_word_ind_to_extract,model_dict):
    
    word_seq = list(word_seq)
    all_sequence_embeddings, word_ind_to_token_ind, attentions = predict_led_embeddings(word_seq, tokenizer, model, remove_chars)
    token_inds_to_avrg = word_ind_to_token_ind[from_start_word_ind_to_extract]
    model_dict = add_word_led_embedding(model_dict, all_sequence_embeddings,token_inds_to_avrg)
    
    return model_dict,attentions


# get the LED token embeddings
@torch.inference_mode()
def get_led_token_embeddings(words_in_array, tokenizer, model, remove_chars):    
    for word in words_in_array:
        if word in remove_chars:
            print('An input word is also in remove_chars. This word will be removed and may lead to misalignment. Proceed with caution.')
            return -1
    
    n_seq_tokens = 0
    seq_tokens = []
    
    word_ind_to_token_ind = {}             # dict that maps index of word in words_in_array to index of tokens in seq_tokens
    
    for i,word in enumerate(words_in_array):
        word_ind_to_token_ind[i] = []      # initialize token indices array for current word
        word_tokens = tokenizer.tokenize(word)
        #print(word)
        for token in word_tokens:
            if token not in remove_chars:  # don't add any tokens that are in remove_chars
                seq_tokens.append(token)
                word_ind_to_token_ind[i].append(n_seq_tokens)
                n_seq_tokens = n_seq_tokens + 1
    
    # convert token to vocabulary indices
    indexed_tokens = tokenizer.convert_tokens_to_ids(seq_tokens)
    tokens_tensor = torch.tensor([indexed_tokens]).to(device)
    
    # outputs = model(tokens_tensor, output_hidden_states=True)
    # hidden_states = outputs['encoder_hidden_states']
    # token_embeddings = hidden_states[0].cpu()
    
    input_embedding_module = model.base_model.get_input_embeddings()
    token_embeddings = input_embedding_module(tokens_tensor).cpu()
    
    return token_embeddings

In [5]:
class Args:
  sequence_length = 5
  nlp_model = 'bert'
  word_ind_to_extract = 0

args=Args()

In [6]:
#text_array = np.load(os.getcwd() + '/stimuli_words.npy')
remove_chars = [",","\"","@"]
embeddings = []
att_emb = []
for i in np.arange(4):
    temp, attention = get_flan_layer_representations(args,hp_text[i],remove_chars)
    embeddings.append(temp)
    att_emb.append(attention)
    #break

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Completed 0 out of 668: 0.024509906768798828
Completed 100 out of 668: 1.8443083763122559
Completed 200 out of 668: 1.822925090789795
Completed 300 out of 668: 1.831390380859375
Completed 400 out of 668: 1.9879462718963623
Completed 500 out of 668: 1.838944911956787
Completed 600 out of 668: 1.9465408325195312
Done extracting sequences of length 5


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Completed 0 out of 1503: 0.021186113357543945
Completed 100 out of 1503: 1.8425421714782715
Completed 200 out of 1503: 1.831035852432251
Completed 300 out of 1503: 1.8355066776275635
Completed 400 out of 1503: 1.816735029220581
Completed 500 out of 1503: 1.8367316722869873
Completed 600 out of 1503: 1.8706519603729248
Completed 700 out of 1503: 1.8433268070220947
Completed 800 out of 1503: 1.8903298377990723
Completed 900 out of 1503: 1.834489345550537
Completed 1000 out of 1503: 1.8771371841430664
Completed 1100 out of 1503: 1.8333868980407715
Completed 1200 out of 1503: 1.83292818069458
Completed 1300 out of 1503: 1.8219003677368164
Completed 1400 out of 1503: 1.852733850479126
Completed 1500 out of 1503: 1.823373556137085
Done extracting sequences of length 5


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Completed 0 out of 2637: 0.01956629753112793
Completed 100 out of 2637: 1.8261232376098633
Completed 200 out of 2637: 1.8190717697143555
Completed 300 out of 2637: 1.8413641452789307
Completed 400 out of 2637: 1.837132215499878
Completed 500 out of 2637: 1.819685697555542
Completed 600 out of 2637: 1.843285083770752
Completed 700 out of 2637: 1.833064079284668
Completed 800 out of 2637: 1.8373029232025146
Completed 900 out of 2637: 1.8290891647338867
Completed 1000 out of 2637: 1.8302440643310547
Completed 1100 out of 2637: 1.842980146408081
Completed 1200 out of 2637: 1.8068034648895264
Completed 1300 out of 2637: 1.8300073146820068
Completed 1400 out of 2637: 1.8509764671325684
Completed 1500 out of 2637: 1.838590145111084
Completed 1600 out of 2637: 1.8170878887176514
Completed 1700 out of 2637: 1.8136839866638184
Completed 1800 out of 2637: 1.8293733596801758
Completed 1900 out of 2637: 1.8208563327789307
Completed 2000 out of 2637: 1.8240318298339844
Completed 2100 out of 2637: 1.

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Completed 0 out of 3753: 0.0420989990234375
Completed 100 out of 3753: 1.8296763896942139
Completed 200 out of 3753: 1.8240387439727783
Completed 300 out of 3753: 1.820549488067627
Completed 400 out of 3753: 1.8373620510101318
Completed 500 out of 3753: 1.8306810855865479
Completed 600 out of 3753: 1.8226678371429443
Completed 700 out of 3753: 1.8175740242004395
Completed 800 out of 3753: 1.8216567039489746
Completed 900 out of 3753: 1.8408787250518799
Completed 1000 out of 3753: 1.819706678390503
Completed 1100 out of 3753: 1.8334987163543701
Completed 1200 out of 3753: 1.8296220302581787
Completed 1300 out of 3753: 1.825845718383789
Completed 1400 out of 3753: 1.8367488384246826
Completed 1500 out of 3753: 1.87437105178833
Completed 1600 out of 3753: 1.9810292720794678
Completed 1700 out of 3753: 1.8242592811584473
Completed 1800 out of 3753: 1.8475041389465332
Completed 1900 out of 3753: 1.865833044052124
Completed 2000 out of 3753: 1.8207762241363525
Completed 2100 out of 3753: 1.8

In [7]:
np.save("bert-base-lw-rh-attention-"+str(args.sequence_length), embeddings)

In [26]:
hp_text[0][0:10]

[]

In [3]:
hp_text = np.load("words-allstories.npy", allow_pickle=True)