In [2]:
import transformers
from transformers import AutoModelForMaskedLM
from transformers import AutoTokenizer

2021-12-23 21:09:49.959953: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2021-12-23 21:09:49.959982: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


#### 1. Load the model

We'll load the model and tokenizer from a saved checkpoint.

In [9]:
class LM (object):
    def __init__ (self, model_checkpoint):
        self.model_checkpoint = model_checkpoint
        self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
        self.model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)

In [10]:
lm = LM ("../checkpoints/contextual-word-embeddings/checkpoint-9000/")

#### 2. Forward pass

- Split all the text into chunks of 512 tokens. 
- Run the forward method on each 512 token chunk. 
- For every token in a chunk get a 768*4 token representation from the final four layers.   

In [19]:
import torch
import json
from collections import Counter

In [267]:
def split2chunks (encoded_input, split_len=510):
    # Break into smaller chunks
    input_ids_chunks = list(encoded_input['input_ids'][0].split(split_len))
    mask_chunks = list(encoded_input['attention_mask'][0].split(split_len))
    
    for i in range (len (input_ids_chunks)):
        pad_len = 510 - input_ids_chunks[i].shape[0]
        # check if tensor length satisfies required chunk size
        if pad_len > 0:
            # if padding length is more than 0, we must add padding
            input_ids_chunks[i] = torch.cat([
                input_ids_chunks[i], torch.Tensor([0] * pad_len)
            ])
            mask_chunks[i] = torch.cat([
                mask_chunks[i], torch.Tensor([0] * pad_len)
            ])
        # Append the CLS token (id=101) and the SEP token (id=102)
        input_ids_chunks[i] = torch.cat([
            torch.Tensor([101]), input_ids_chunks[i], torch.Tensor ([102])
        ])
            
        # Add attention masks
        mask_chunks[i] = torch.cat([
            torch.Tensor([1]), mask_chunks[i], torch.Tensor([1])
        ])
        
    # Now aggregate into one example
    input_ids = torch.stack(input_ids_chunks)
    attention_mask = torch.stack(mask_chunks)
        
    input_dict = {
        'input_ids': input_ids.long().clone().detach(), #torch.tensor(input_ids.long()),
        'attention_mask': attention_mask.int().clone().detach() #torch.tensor(attention_mask.int())
    }
    return input_dict

In [338]:
def get_flattened_embeddings (outputs, attention_mask):
    # Let's concatenate the representation of the final four layers
    embeddings = torch.cat((outputs.hidden_states[-1], #12th hidden layer
                            outputs.hidden_states[-2], #11th hidden layer
                            outputs.hidden_states[-3], #10th hidden layer
                            outputs.hidden_states[-4]), dim=2)
    embeddings = torch.flatten (embeddings[:,1:-1,:], start_dim=0, end_dim=1)
    num_nonzero = (attention_mask[:,1:-1].flatten() == 0).nonzero(as_tuple=True)[0].size()[0]
    if num_nonzero == 0:
        index = None
    else:
        index = (attention_mask[:,1:-1].flatten() == 0).nonzero(as_tuple=True)[0][0].item()
    return embeddings[0:index, :]

def tokens_generator (toks):
    last_token = ""
    i = 0
    token_start = 0
    while i < len (toks):
        if i == 0:
            last_token = toks[i]
            token_start = i
        elif toks[i].startswith ("##"):
            last_token = last_token + toks[i][2:]
        else:
            yield token_start, i, last_token
            last_token = toks[i]
            token_start = i
        i += 1
    if len (last_token) > 0:
        yield token_start, i, last_token

In [298]:
with open ("../data/raw/sample.jsonl") as fin, open ("../data/contextual-embeddings/sample.tsv", "w") as fout:
    for line in fin:
        js = json.loads (line.strip())
        # extract text
        text = js["full_text"] # extract additional metadata for later
        paper_id = js["paper_id"]
        # encode the entire text
        encoded_input = lm.tokenizer(text,
                                     add_special_tokens=False,
                                     return_tensors='pt')
        
        with torch.no_grad ():
            # print (encoded_input["input_ids"].size()) # contains approx. these many tokens
            input_dict = split2chunks (encoded_input)
            outputs = lm.model(**input_dict)
            embeddings = get_flattened_embeddings (outputs, input_dict["attention_mask"])
            wordpieces = lm.tokenizer.convert_ids_to_tokens(encoded_input["input_ids"][0])
            tokens = [token for token in tokens_generator(wordpieces)]
            tokenized_text = [token for _,_,token in tokens]
            token_boundaries = [(start, ended) for start, ended, _ in tokens]
            token_embeddings = torch.stack([embeddings[start:end,:].mean(dim=0) for start, end in token_boundaries])
        
        for i, token in enumerate (tokenized_text):
            string_rep = ' '.join(list(map(str,token_embeddings[i].tolist())))
            fout.write (f'{paper_id}\t{i}\t{token}\t{string_rep}\n')

IndexError: index 0 is out of bounds for dimension 0 with size 0

In [339]:
embeddings = get_flattened_embeddings (outputs, input_dict["attention_mask"])

In [340]:
embeddings.size()

torch.Size([5610, 3072])

In [333]:
wordpieces = lm.tokenizer.convert_ids_to_tokens(encoded_input["input_ids"][0])
tokens = [token for token in tokens_generator(wordpieces)]
tokenized_text = [token for _,_,token in tokens]
token_boundaries = [(start, ended) for start, ended, _ in tokens]
token_embeddings = torch.stack([embeddings[start:end,:].mean(dim=0) for start, end in token_boundaries])
len(wordpieces)

5610

In [325]:
(input_dict["attention_mask"][:,1:-1].flatten() == 0).nonzero(as_tuple=True)[0].size()[0]

0

In [288]:
print(len (tokenized_text))
print(len (token_boundaries))
print(token_embeddings.size())

4687
4687
torch.Size([4687, 3072])


In [253]:
token_boundaries = [(start, ended) for start, ended, token in tokens_generator (toks)]

In [209]:
i = 0
remapped_toks = list ()
token_start = 0
token_ended = 0
while i < len (toks):
    if toks[i].startswith ("##"):
        remapped_toks[-1] = remapped_toks[-1] + toks[i][2:]
        token_ended += 1
    else:
        remapped_toks.append (toks[i])
        # reset start of token
        token_start = i
        token_ended = i
    i += 1

In [212]:
print(toks[0:100])
print(remapped_toks[0:100])

['previous', 'work', 'has', 'shown', 'that', 'the', 'problem', 'of', 'structural', 'differences', 'between', 'language', 'pairs', 'in', 'sm', '##t', 'can', 'be', 'alleviate', '##d', 'by', 'source', '-', 'side', 'syn', '##ta', '##ctic', 're', '##ord', '##ering', '.', 'taking', 'account', 'for', 'the', 'integration', 'with', 'sm', '##t', 'systems', ',', 'these', 'methods', 'can', 'be', 'divided', 'into', 'two', 'different', 'kinds', 'of', 'approaches', ':', 'the', 'deter', '##mini', '##stic', 're', '##ord', '##ering', 'and', 'the', 'non', '##de', '##ter', '##mini', '##stic', 're', '##ord', '##ering', 'approach', '.', 'to', 'carry', 'out', 'the', 'deter', '##mini', '##stic', 'approach', ',', 'syn', '##ta', '##ctic', 're', '##ord', '##ering', 'is', 'performed', 'uniformly', 'on', 'the', 'training', ',', 'dev', '##set', 'and', 'tests', '##et', 'before']
['previous', 'work', 'has', 'shown', 'that', 'the', 'problem', 'of', 'structural', 'differences', 'between', 'language', 'pairs', 'in', 'sm

In [207]:

for tok in toks:
    if tok.startswith ("##"):
        print (tok)

##t
##d
##ta
##ctic
##ord
##ering
##t
##mini
##stic
##ord
##ering
##de
##ter
##mini
##stic
##ord
##ering
##mini
##stic
##ta
##ctic
##ord
##ering
##set
##et
##t
##ord
##ered
##t
##ta
##ctic
##ord
##ering
##se
##ord
##ered
##t
##ders
##ona
##izan
##ifiers
##ord
##ering
##ta
##ctic
##mini
##stic
##ders
##ord
##ering
##s
##ord
##ered
##s
##ta
##ctic
##t
##s
##ord
##ering
##s
##ord
##ering
##ta
##ctic
##ser
##der
##ta
##ctic
##ara
##bic
##ta
##ctic
##ord
##ering
##b
##t
##eng
##lish
##ta
##ctic
##ord
##ering
##s
##mt
##bs
##mt
##xa
##mined
##ord
##ering
##de
##ter
##mini
##stic
##st
##ru
##ction
##osition
##bank
##ta
##ctic
##ord
##ering
##ord
##ering
##s
##ta
##ctic
##ord
##ering
##ta
##ctic
##ord
##ering
##t
##mini
##stic
##ta
##ctic
##ord
##ering
##une
##d
##ate
##mt
##ta
##ctic
##ord
##ering
##ta
##ctic
##ord
##ering
##tonic
##s
##ta
##ctic
##ord
##ering
##se
##fi
##lter
##ed
##s
##ord
##ering
##s
##b
##t
##ta
##ctic
##ord
##ering
##ness
##ta
##ctic
##ord
##ering
##s
##a
##ta
##ctic
##o

In [193]:
word2wordpieces = dict ()
for token in tokenized_text:
    if token in word2wordpieces:
        wordpieces = word2wordpieces[token]
    else:
        wordpieces = lm.tokenizer (token, add_special_tokens=False)["input_ids"]
        word2wordpieces[token] = wordpieces

In [194]:
word2wordpieces

{'Previous': [3025],
 'work': [2147],
 'has': [2038],
 'shown': [3491],
 'that': [2008],
 'the': [1996],
 'problem': [3291],
 'of': [1997],
 'structural': [8332],
 'differences': [5966],
 'between': [2090],
 'language': [2653],
 'pairs': [7689],
 'in': [1999],
 'SMT': [15488, 2102],
 'can': [2064],
 'be': [2022],
 'alleviated': [24251, 2094],
 'by': [2011],
 'source-side': [3120, 1011, 2217],
 'syntactic': [19962, 2696, 13306],
 'reordering.': [2128, 8551, 7999, 1012],
 'Taking': [2635],
 'account': [4070],
 'for': [2005],
 'integration': [8346],
 'with': [2007],
 'systems,': [3001, 1010],
 'these': [2122],
 'methods': [4725],
 'divided': [4055],
 'into': [2046],
 'two': [2048],
 'different': [2367],
 'kinds': [7957],
 'approaches': [8107],
 ':': [1024],
 'deterministic': [28283, 25300, 10074],
 'reordering': [2128, 8551, 7999],
 'and': [1998],
 'nondeterministic': [2512, 3207, 3334, 25300, 10074],
 'approach.': [3921, 1012],
 'To': [2000],
 'carry': [4287],
 'out': [2041],
 'approach,

In [178]:
print (len (tokenized_text)) # our original sequence
print (encoded_input["input_ids"].size()) # encoded sequence

3888
torch.Size([1, 5372])


In [124]:
print(outputs.hidden_states[-1][0].size()) #12th hidden layer
print(outputs.hidden_states[-2][0]) #11th hidden layer
print(outputs.hidden_states[-3][0]) #10th hidden layer
print(outputs.hidden_states[-4][0]) #9th hidden layer

torch.Size([512, 768])
tensor([[-1.6916e-01, -2.3938e-02, -6.3066e-02,  ...,  3.6686e-01,
         -2.8920e-01,  2.4723e-02],
        [-5.8806e-02,  4.8172e-01, -9.5538e-01,  ..., -1.0929e+00,
          3.6555e-01,  8.4185e-01],
        [-4.7472e-01,  1.0298e+00, -1.0489e+00,  ..., -6.4436e-01,
         -5.2447e-02,  4.9744e-01],
        ...,
        [-6.0153e-04, -7.5476e-01,  1.1760e-01,  ...,  7.1302e-01,
         -8.9428e-01, -3.3192e-02],
        [-5.5412e-01, -2.9986e-01, -3.2172e-01,  ...,  1.0814e+00,
         -2.6647e-01,  1.2822e-01],
        [-1.6653e-01,  2.9930e-02,  4.8501e-02,  ...,  3.3610e-01,
         -3.5506e-01,  5.3574e-02]])
tensor([[-0.3802, -0.2830,  0.2516,  ...,  0.3296,  0.0750, -0.0262],
        [ 0.1706,  0.1378, -1.0766,  ..., -0.6881,  0.8039,  0.7762],
        [-0.4868,  0.6286, -1.7008,  ..., -0.3775, -0.0198,  0.6475],
        ...,
        [-0.1883, -0.8068, -0.0559,  ...,  0.9973, -0.6118,  0.2201],
        [-0.9317, -0.1211, -0.6334,  ...,  1.1118, -

In [106]:
input_dict["input_ids"].size()

torch.Size([11, 512])

In [86]:
input_dict["input_ids"].size()

torch.Size([11, 512])

In [83]:
output.hidden_states[-1].size()

torch.Size([1, 512, 768])

In [60]:
with torch.no_grad(): # do this so that the costly gradients are not calculated
    text = "Replace me by any text you'd like."
    encoded_input = lm.tokenizer(text, 
                                 add_special_tokens=False,
                                 return_tensors='pt')
    print (encoded_input)
    input_ids_chunks = encoded_input['input_ids'][0].split(510)
    mask_chunks = encoded_input['attention_mask'][0].split(510)
    print (mask_chunks)
    #output = lm.model(**encoded_input)

{'input_ids': tensor([[5672, 2033, 2011, 2151, 3793, 2017, 1005, 1040, 2066, 1012]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
(tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),)


{'input_ids': tensor([[ 101, 5672, 2033, 2011, 2151, 3793, 2017, 1005, 1040, 2066, 1012,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [27]:
encoded_input

{'input_ids': tensor([[ 101, 5672, 2033, 2011, 2151, 3793, 2017, 1005, 1040, 2066, 1012,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [43]:
lm.tokenizer.convert_ids_to_tokens (encoded_input["input_ids"][0],
                                    skip_special_tokens=True)

['replace', 'me', 'by', 'any', 'text', 'you', "'", 'd', 'like', '.']

In [42]:
embeddings = torch.cat((output.hidden_states[-1][0], 
                        output.hidden_states[-2][0],
                        output.hidden_states[-3][0],
                        output.hidden_states[-4][0]), dim=1)
embeddings = embeddings[1:-1, :]
print (embeddings.size())

torch.Size([10, 3072])


In [32]:
print(output.hidden_states[-1][0]) #12th hidden layer
print(output.hidden_states[-2][0]) #11th hidden layer
print(output.hidden_states[-3][0]) #10th hidden layer
print(output.hidden_states[-4][0]) #9th hidden layer

tensor([[ 0.1510, -0.2819,  0.5040,  ..., -0.0091,  0.0294,  0.1366],
        [ 1.0676, -0.1254,  0.9921,  ...,  0.6519,  0.4978,  0.1428],
        [ 0.1723, -0.2149,  0.5841,  ..., -0.5973,  0.2768,  0.8791],
        ...,
        [ 0.3494,  0.1025,  0.7701,  ..., -0.9383, -0.5957, -0.0526],
        [ 0.2089, -0.3424, -0.0376,  ..., -0.0370, -0.1770, -0.5604],
        [-0.1014, -0.1054,  0.5419,  ...,  0.0048, -0.1940,  0.0583]])
tensor([[ 0.1486, -0.4854,  0.4405,  ..., -0.0168, -0.1848, -0.2286],
        [ 1.3134, -0.2420,  1.1080,  ...,  0.4739, -0.1670,  0.0852],
        [ 0.4342, -0.3302,  0.4590,  ..., -0.6905,  0.0618,  1.1912],
        ...,
        [ 0.1250,  0.3564,  1.0470,  ..., -1.4132,  0.0486,  0.0545],
        [ 0.0532, -0.1248, -0.1468,  ..., -0.4602, -0.8243, -0.7510],
        [ 0.0078, -0.2642,  0.3153,  ...,  0.0701, -0.1611, -0.2113]])
tensor([[-0.0364, -0.4436,  0.7083,  ...,  0.2594, -0.1732, -0.1971],
        [ 0.9671, -0.5260,  1.2343,  ...,  0.9798, -0.1338,  0

In [29]:
for i in range (len(output.hidden_states)):
    print(i, output.hidden_states[-1][0].size())

0 torch.Size([12, 768])
1 torch.Size([12, 768])
2 torch.Size([12, 768])
3 torch.Size([12, 768])
4 torch.Size([12, 768])
5 torch.Size([12, 768])
6 torch.Size([12, 768])
7 torch.Size([12, 768])
8 torch.Size([12, 768])
9 torch.Size([12, 768])
10 torch.Size([12, 768])
11 torch.Size([12, 768])
12 torch.Size([12, 768])
