In [1]:
import sys
sys.path.append('../')
from util import util, corpus, alignment, vectorize, retrieval, tokenize

from tqdm.notebook import tqdm as tq
import pandas as pd

[nltk_data] Downloading package stopwords to /home/sudhi/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /home/sudhi/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
scraped_df = util.decompress_pickle('../../data/other/scraped_df')
scraped_en_docs = util.decompress_pickle('../../data/other/scraped_en_docs')
scraped_de_docs = util.decompress_pickle('../../data/other/scraped_de_docs')

In [None]:
cont_en_doc_vectors = util.decompress_pickle('../../data/other/cont_en_doc_vectors')
cont_de_doc_vectors = util.decompress_pickle('../../data/other/cont_de_doc_vectors')

In [None]:
from sentence_transformers import SentenceTransformer

In [None]:
model = SentenceTransformer('../../data/other/cont_model/')

In [None]:
encoding = vectorize.Encoding(model=[model], model_type="sentence_bert")

In [None]:
cont_retrieval = vectorize.Retrieval(
    vectors=[cont_en_doc_vectors, cont_de_doc_vectors],
    docs=[scraped_en_docs, scraped_de_docs])

In [None]:
query = "adaptive cruise control"
query_vector = encoding.encode(sent=[query])
query_vector[0][:10]

In [None]:
retrieved = cont_retrieval.retrieve(query_vec=query_vector, num_ret=100)
# retrieved

In [None]:
from util import projection

In [None]:
en_vectors = [cont_en_doc_vectors[ind] for ind in retrieved['en_ind']]
de_vectors = [cont_de_doc_vectors[ind] for ind in retrieved['de_ind']]

In [None]:
%%time
# run once, mayve save as pickle file
en_proj_2d = projection.get_projection_2d(cont_en_doc_vectors)
# de_proj_2d = projection.get_projection_2d(cont_de_doc_vectors)

In [None]:
def get_triplets(docs, proj_2d, indices):
    ret_colors = []
    ret_docs = []
    ret_projs = []
    for i in range(len(proj_2d)):
        if i not in indices:
            continue
        elif i in indices[:10]:
            ret_colors.append('red')
            ret_docs.append(docs[i])
            ret_projs.append(proj_2d[i])
        elif i in indices[90:]:
            ret_colors.append('blue')
            ret_docs.append(docs[i])
            ret_projs.append(proj_2d[i])    
    return ret_docs, ret_projs, ret_colors

In [None]:
ret_docs, ret_projs, ret_colors = get_triplets(scraped_en_docs, en_proj_2d, retrieved["en_ind"])
projection.plot_umap_2d(ret_docs, ret_projs, ret_colors)

In [None]:
ret_docs, ret_projs, ret_colors = get_triplets(scraped_de_docs, de_proj_2d, retrieved["de_ind"])
projection.plot_umap_2d(ret_docs, ret_projs, ret_colors)

In [None]:
vectors_comb = [] 
for vec in cont_en_doc_vectors:
    vectors_comb.append(vec)

for vec in cont_de_doc_vectors:
    vectors_comb.append(vec)
    
for vec in query_vector:
    vectors_comb.append(vec)
    
len(vectors_comb)

In [None]:
%%time
proj_comb_2d = projection.get_projection_2d(vectors_comb)
proj_comb_2d.shape

In [None]:
en_proj_2d_ = proj_comb_2d[:len(cont_en_doc_vectors)]
de_proj_2d_ = proj_comb_2d[len(cont_en_doc_vectors):-1]
query_proj_2d_ = [proj_comb_2d[-1]]

print(len(cont_en_doc_vectors), len(cont_de_doc_vectors), len(query_vector))
print(len(en_proj_2d_), len(de_proj_2d_), len(query_proj_2d_))

In [None]:
ret_colors = []
ret_docs = []
ret_projs = []

proj_2d = en_proj_2d_
indices = retrieved['en_ind']
docs = scraped_en_docs


for i in range(len(proj_2d)):
    if i not in indices:
        continue
    elif i in indices[:10]:
        ret_colors.append('red')
        ret_docs.append(docs[i])
        ret_projs.append(proj_2d[i])
    elif i in indices[90:]:
        ret_colors.append('blue')
        ret_docs.append(docs[i])
        ret_projs.append(proj_2d[i])  
        
        
proj_2d = de_proj_2d_
indices = retrieved['de_ind']
docs = scraped_de_docs


for i in range(len(proj_2d)):
    if i not in indices:
        continue
    elif i in indices[:10]:
        ret_colors.append('red')
        ret_docs.append(docs[i])
        ret_projs.append(proj_2d[i])
    else:
        ret_colors.append('blue')
        ret_docs.append(docs[i])
        ret_projs.append(proj_2d[i])  

In [None]:
ret_docs.append(query)
ret_colors.append('green')
ret_projs.extend(query_proj_2d_)

In [None]:
projection.plot_umap_2d(ret_docs, ret_projs, ret_colors)

In [None]:
def plot_umap_for_n(n):
    ret_colors = []
    ret_docs = []
    ret_projs = []
    
    proj_2d = en_proj_2d_
    indices = retrieved['en_ind']
    docs = scraped_en_docs
    
    
    for i in range(len(proj_2d)):
        if i not in indices:
            continue
        elif i in indices[:n]:
            ret_colors.append('red')
            ret_docs.append(docs[i])
            ret_projs.append(proj_2d[i])
        elif i in indices[100-n:]:
            ret_colors.append('blue')
            ret_docs.append(docs[i])
            ret_projs.append(proj_2d[i])  
            
            
    proj_2d = de_proj_2d_
    indices = retrieved['de_ind']
    docs = scraped_de_docs
    
    
    for i in range(len(proj_2d)):
        if i not in indices:
            continue
        elif i in indices[:n]:
            ret_colors.append('red')
            ret_docs.append(docs[i])
            ret_projs.append(proj_2d[i])
        else:
            ret_colors.append('blue')
            ret_docs.append(docs[i])
            ret_projs.append(proj_2d[i])  
            
    ret_docs.append(query)
    ret_colors.append('green')
    ret_projs.extend(query_proj_2d_)
    
    projection.plot_umap_2d(ret_docs, ret_projs, ret_colors)

In [None]:
plot_umap_for_n(5)

In [None]:
plot_umap_for_n(10)

In [None]:
plot_umap_for_n(20)

In [None]:
plot_umap_for_n(20)

In [None]:
plot_umap_for_n(30)

In [None]:
plot_umap_for_n(40)

In [None]:
for doc in retrieved['en_docs'][:5]:
    print(doc)

In [None]:
retrieved['en_docs'][:5]

In [None]:
import spacy
nlp_en = spacy.load("en_core_web_sm")
nlp_de = spacy.load('de_core_news_sm')

In [None]:
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
# define stop words for EN and DE
stop_words_en = set(stopwords.words('english'))
stop_words_de = set(stopwords.words('german'))

In [None]:
for doc in retrieved['en_docs'][:5]:
    noun_chunks = [nc for nc in nlp_en(doc).noun_chunks]
    for nc in noun_chunks:
        nc = str(nc)
        nc = nc.lower()
        ncs = [re.sub(stop, "", nc) for stop in stop_words_en]
        print(ncs)

In [None]:
for doc in retrieved['de_docs'][:5]:
    noun_chunks = [nc for nc in nlp_de(doc).noun_chunks]
    print(noun_chunks)

In [None]:
# function similar to the one under utils/explain.py
import re

TOKENS = 0
NOUN_CHUNKS = 1

def get_token_import_doc(encoding, query_vec, doc, lang=None, which_type=TOKENS): 
    if which_type == TOKENS:
        tokens = tokenize.get_tokens(doc, lang)
        # tokens = tokenize.get_naive_tokens(doc)
    elif which_type == NOUN_CHUNKS:
        nlp = nlp_en if lang == 'en' else nlp_de
        tokens = [nc for nc in nlp(doc).noun_chunks]
    
    tok_imp = {}
    for tok in tokens:
        doc_wo_tok = re.sub(str(tok), '', doc.lower())
        
        sent_with_tok_vec = encoding.encode(sent=[doc], lang=lang)
        sent_wo_tok_vec = encoding.encode(sent=[doc_wo_tok], lang=lang)
        
        retrieval = vectorize.Retrieval(vectors=[sent_with_tok_vec, sent_wo_tok_vec],
                                docs=[[doc], [doc_wo_tok]]
                               )
        ret = retrieval.retrieve(query_vec=query_vec, num_ret=len(doc))
        ref = ret['en_sim'][0]
        sim = ret['de_sim'][0]        
        imp = (ref-sim)/ref
        
        if imp != 0:
            tok_imp[tok] = imp 
            
    tokens_import = dict(sorted(tok_imp.items(), key=lambda item: item[1]))
    
    return tokens_import 

In [None]:
for ind in range(2):
    doc = retrieved['en_docs'][ind]
    sim = retrieved['en_sim'][ind]
    print(doc, sim)
    print(get_token_import_doc(encoding=encoding, query_vec=query_vector, doc=doc, lang='en', which_type=TOKENS))
    print(get_token_import_doc(encoding=encoding, query_vec=query_vector, doc=doc, lang='en', which_type=NOUN_CHUNKS))
    print("----")

In [None]:
def display_tok_imp(tok_imp):
    for val in tok_imp:
        print(val, tok_imp[val])

In [None]:
ind_1, ind_2 = 0, 80

print(f"Document {ind_1}")
doc = retrieved['en_docs'][ind_1]
sim = retrieved['en_sim'][ind_1]
print(doc, sim)
display_tok_imp(get_token_import_doc(encoding=encoding, query_vec=query_vector, doc=doc, lang='en', which_type=TOKENS))
print("----")
display_tok_imp(get_token_import_doc(encoding=encoding, query_vec=query_vector, doc=doc, lang='en', which_type=NOUN_CHUNKS))
print("----")
print("----")

print(f"Document {ind_2}")
doc = retrieved['en_docs'][ind_2]
sim = retrieved['en_sim'][ind_2]
print(doc, sim)
display_tok_imp(get_token_import_doc(encoding=encoding, query_vec=query_vector, doc=doc, lang='en', which_type=TOKENS))
print("----")
display_tok_imp(get_token_import_doc(encoding=encoding, query_vec=query_vector, doc=doc, lang='en', which_type=NOUN_CHUNKS))
print("----")
print("----")

print(f"Document {ind_1} + {ind_2}")
doc = retrieved['en_docs'][ind_1] + "\n" + retrieved['en_docs'][ind_2]
display_tok_imp(get_token_import_doc(encoding=encoding, query_vec=query_vector, doc=doc, lang='en', which_type=TOKENS))
print("----")
display_tok_imp(get_token_import_doc(encoding=encoding, query_vec=query_vector, doc=doc, lang='en', which_type=NOUN_CHUNKS))
print("----")
print("----")

In [None]:
doc = retrieved['en_docs'][5]
new_doc = re.sub("the main factors", "--", doc)

In [None]:
doc_vec = encoding.encode([doc])
new_doc_vec = encoding.encode([new_doc])

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
cosine_similarity(query_vector, Y=doc_vec, 
                           dense_output=True)[0]

In [None]:
cosine_similarity(query_vector, Y=new_doc_vec, 
                           dense_output=True)[0]

## Intra-similarity

In [None]:
from imp import reload
reload(vectorize)

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

In [None]:
def sort_dict(dict_to_sort):
    return dict(sorted(dict_to_sort.items(), key=lambda item: item[1]))

In [None]:
def get_imp_word(query):
    query_vec = encoding.encode(sent=[query])
    word_import = {}
    for word in query.split(" "):
        word_vec = encoding.encode(sent=[word])
        word_import[word] = cosine_similarity(word_vec, query_vec)[0]
    word_import = sort_dict(word_import)
    imp_word = list(word_import.keys())[-1]
    return imp_word

In [None]:
# query = 'Collision avoidance system'
query = 'adaptive cruise control'

imp_word = get_imp_word(query)

In [None]:
from transformers import BertTokenizer
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')
cont_model = torch.load('/home/sudhi/thesis/thesis_cltr_app/data/models/model_MLM_NSP.pt').module.eval()

In [None]:
word_vec = encoding.encode(sent=[imp_word])
np.asarray(word_vec).shape

In [None]:
docs_with_word_idx = [idx for idx, doc in enumerate(scraped_en_docs) if imp_word in doc][:100]
docs_with_word = [scraped_en_docs[idx] for idx in docs_with_word_idx]

In [None]:
docs_with_word_vec = encoding.encode(docs_with_word)
total_len = len(docs_with_word_vec)

In [None]:
np.asarray(docs_with_word_vec).shape

In [None]:
sim = cosine_similarity(word_vec, Y=docs_with_word_vec, 
                           dense_output=True)[0]

In [None]:
ind = np.argsort(sim)[-10:][::-1]
[docs_with_word[idx] for idx in ind]

In [None]:
ind = np.argsort(sim)[:10]
idxs = [docs_with_word_idx[idx] for idx in ind]
[scraped_en_docs[idx] for idx in idxs]

In [None]:
retrieved

In [None]:
query_vector = encoding.encode(sent=[query])
retrieved = cont_retrieval.retrieve(query_vec=query_vector, num_ret=100)

In [None]:
import re

In [None]:
query_proj_2d = projection.get_projection_2d(query_vector)

In [None]:
def plot_umap_intra_sim(retrieved, imp_word):
    def strip_text(doc, word):
        sp = doc.split(" ")
        index = [idx for idx, s in enumerate(sp) if word in s][0]
        return " ".join(doc.split(" ")[index-20:index+20])
            
    ret_colors = []
    ret_docs = []
    ret_projs = []
    
    proj_2d = en_proj_2d
    indices = retrieved['en_ind']
    docs = scraped_en_docs
    
    for i in range(len(proj_2d)):
        if i in indices[:10]:
            ret_colors.append('red')
            ret_docs.append(docs[i])
            ret_projs.append(proj_2d[i])
        elif i in idxs:
            ret_colors.append('blue')
            stripped_doc = strip_text(docs[i], imp_word)
            ret_docs.append(stripped_doc)
            # ret_docs.append(docs[i])
            ret_projs.append(proj_2d[i])  
            
    ret_docs.append(query)
    ret_colors.append('green')
    ret_projs.extend(query_proj_2d)
    
    projection.plot_umap_2d(ret_docs, ret_projs, ret_colors)

In [None]:
plot_umap_intra_sim(retrieved, imp_word)

## MLM for query expansion

In [None]:
from transformers import BertTokenizer
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')
cont_model = torch.load('/home/sudhi/thesis/thesis_cltr_app/data/models/model_MLM_NSP.pt').module.eval()

In [159]:
def retrieve(query):
    query_vector = encoding.encode(sent=[query])
    retrieved = cont_retrieval.retrieve(query_vec=query_vector, num_ret=100)
    
    en_preds = vectorize.get_preds(cont_model, query, retrieved['en_docs'])
    en_ret_docs = vectorize.get_top_k(en_preds, retrieved['en_docs'], 10)
    
    de_preds = vectorize.get_preds(cont_model.to('cuda:0'), query, retrieved['de_docs'])
    de_ret_docs = vectorize.get_top_k(en_preds, retrieved['de_docs'], 10)    
    
    return en_ret_docs, de_ret_docs

In [None]:
query = 'adaptive crusie control'
query_vector = encoding.encode(sent=[query])
retrieved = cont_retrieval.retrieve(query_vec=query_vector, num_ret=100)

In [None]:
len(retrieved['en_docs'])

In [None]:
en_preds = vectorize.get_preds(cont_model, query, retrieved['en_docs'])
en_ret_docs = vectorize.get_top_k(en_preds, retrieved['en_docs'], 10)

In [None]:
de_preds = vectorize.get_preds(cont_model.to('cuda:0'), query, retrieved['de_docs'])
de_ret_docs = vectorize.get_top_k(en_preds, retrieved['de_docs'], 100)

In [None]:
en_ret_docs

In [None]:
de_ret_docs

In [None]:
from tqdm import tqdm
def get_preds(model, query, doc):
    preds = []
    with torch.no_grad():
        inputs = tokenizer(query, doc, return_tensors='pt',
                      max_length=512, padding=True, truncation=True).to('cuda')
        output = model(**inputs, output_hidden_states=True, output_attentions=True)

    return output  

In [None]:
def mean_pooling(sent):
    with torch.no_grad():
        inputs = tokenizer(sent, return_tensors='pt',
                      max_length=512, padding=True, truncation=True).to('cuda')
        output = cont_model(**inputs, output_hidden_states=True, output_attentions=True)
    
    # last layer gives the best embeddings
    token_embeddings = output['hidden_states'][-1] 
    # Mean Pooling - Take attention mask into account for correct averaging
    input_mask_expanded = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
    sent_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sent_embedding.cpu()

In [None]:
inputs.keys()

In [None]:
mean_pooling(output, inputs['attention_mask'])[0]

In [None]:
input_id_list = inputs['input_ids'][0].tolist() # Batch index 0
print(input_id_list)
tokens = tokenizer.convert_ids_to_tokens(input_id_list)  
print(tokens)
print(len(tokens))

In [None]:
output.attentions[11].size()

In [None]:
query_tokens = tokens[tokens.index('[CLS]')+1: tokens.index('[SEP]')]

In [None]:
import numpy as np
head = 11
num_heads = 12
for head in range(num_heads):
    print(f"Head {head}")
    for ind, token in enumerate(tokens):
        attention = output.attentions[11][0][head][ind]
        indices = np.argsort(attention.cpu().numpy())[::-1][:5]
        print(f"{token}  --> {' '.join([tokens[idx] for idx in indices])}")
    print()

In [None]:
from transformers import pipeline
unmasker = pipeline('fill-mask', model=cont_model.to('cpu'), tokenizer=tokenizer, top_k=10)
from nltk.corpus import stopwords
stop_words_en = set(stopwords.words('english'))
stop_words_de = set(stopwords.words('german'))
import string
punct = list(string.punctuation)
def get_suggestions(query, imp_word):
    masked_query = re.sub(imp_word, "[MASK]", query)
    preds = unmasker(masked_query)
    suggestions = []
    for pred in preds:
        token_str = re.sub(" ", "", pred['token_str'])
        if token_str in stop_words_en:
            continue
        if token_str in stop_words_de:
            continue
        if token_str in punct:
            continue
        # TODO: sometimes, comprising, inclusing, etc. comes
        # add a rule to avoid these
        # avoid numbers
        suggestion = pred['sequence']
        if suggestion.lower() != query.lower():
            suggestions.append(suggestion)
    return suggestions

In [None]:
query = "sensors and cars"
imp_word = get_imp_word(query)
get_suggestions(query, imp_word)

In [None]:
np.argsort(np.asarray([5, 4, 1, 10]))[::-1]

## Retrieving with the contextual model

In [3]:
from transformers import BertTokenizer
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')
cont_model = torch.load('/home/sudhi/thesis/thesis_cltr_app/data/models/model_MLM_NSP.pt').module.eval()

In [4]:
import sys
sys.path.append('../')
from util import util, corpus, alignment, vectorize, retrieval, tokenize

from tqdm.notebook import tqdm as tq
import pandas as pd

In [49]:
scraped_en_docs = util.decompress_pickle('/home/sudhi/thesis/thesis_cltr_app/data/docs/scraped_en_docs') 
scraped_de_docs = util.decompress_pickle('/home/sudhi/thesis/thesis_cltr_app/data/docs/scraped_de_docs')          

# static model
st_en_doc_vectors = util.decompress_pickle('/home/sudhi/thesis/thesis_cltr_app/data/static/st_en_doc_vectors')
st_de_doc_vectors = util.decompress_pickle('/home/sudhi/thesis/thesis_cltr_app/data/static/st_de_doc_vectors')

scraped_en_docs_, st_en_doc_vectors_ = [], []
for doc, vec in zip(scraped_en_docs, st_en_doc_vectors):
    if len(doc.split(' ')) < 10:
        continue
    scraped_en_docs_.append(doc)
    st_en_doc_vectors_.append(vec)
        
scraped_de_docs_, st_de_doc_vectors_ = [], []
for doc, vec in zip(scraped_de_docs, st_de_doc_vectors):
    if len(doc.split(' ')) < 10:
        continue
    scraped_de_docs_.append(doc)
    st_de_doc_vectors_.append(vec)    
        
en_wv_unsup = util.decompress_pickle('/home/sudhi/thesis/thesis_cltr_app/data/static/en_wv_unsup')
de_wv_unsup = util.decompress_pickle('/home/sudhi/thesis/thesis_cltr_app/data/static/de_wv_unsup')

st_encoding = vectorize.Encoding(model=[en_wv_unsup, de_wv_unsup], model_type='static')
st_retrieval = vectorize.Retrieval(vectors=[st_en_doc_vectors_, st_de_doc_vectors_],
                                docs=[scraped_en_docs_, scraped_de_docs_]
                               )

In [51]:
query = 'adaptive cruise control'
query_vector = st_encoding.encode(sent=[query], lang='en')
retrieved = st_retrieval.retrieve(query_vec=query_vector, num_ret=100)
docs = [retrieved['en_docs'], retrieved['de_docs']]

In [53]:
en_preds = vectorize.get_preds(cont_model, query, retrieved['en_docs'])
en_ret_docs = vectorize.get_top_k(en_preds, retrieved['en_docs'], 100)

100%|██████████| 100/100 [00:01<00:00, 89.24it/s]


In [54]:
de_preds = vectorize.get_preds(cont_model.to('cuda:0'), query, retrieved['de_docs'])
de_ret_docs = vectorize.get_top_k(en_preds, retrieved['de_docs'], 100)

100%|██████████| 100/100 [00:01<00:00, 85.50it/s]


**trying to understand why the 500 documents retrieved by the static model are reranked by the contextual model.**

In [11]:
re_ranking = {}
for ct_ind, doc in enumerate(en_ret_docs):
    re_ranking[ct_ind] = retrieved['en_docs'].index(doc)

In [12]:
re_ranking

{0: 34,
 1: 66,
 2: 50,
 3: 65,
 4: 3,
 5: 0,
 6: 36,
 7: 33,
 8: 6,
 9: 41,
 10: 77,
 11: 16,
 12: 24,
 13: 79,
 14: 73,
 15: 52,
 16: 32,
 17: 31,
 18: 21,
 19: 7,
 20: 78,
 21: 54,
 22: 1,
 23: 64,
 24: 8,
 25: 90,
 26: 19,
 27: 59,
 28: 11,
 29: 30,
 30: 17,
 31: 14,
 32: 15,
 33: 74,
 34: 48,
 35: 35,
 36: 70,
 37: 2,
 38: 40,
 39: 68,
 40: 84,
 41: 62,
 42: 20,
 43: 38,
 44: 18,
 45: 47,
 46: 37,
 47: 49,
 48: 44,
 49: 86,
 50: 5,
 51: 51,
 52: 98,
 53: 57,
 54: 9,
 55: 43,
 56: 91,
 57: 55,
 58: 61,
 59: 75,
 60: 26,
 61: 45,
 62: 25,
 63: 13,
 64: 28,
 65: 72,
 66: 85,
 67: 87,
 68: 12,
 69: 23,
 70: 60,
 71: 99,
 72: 53,
 73: 56,
 74: 27,
 75: 94,
 76: 29,
 77: 93,
 78: 71,
 79: 89,
 80: 82,
 81: 80,
 82: 92,
 83: 81,
 84: 97,
 85: 42,
 86: 4,
 87: 22,
 88: 46,
 89: 95,
 90: 69,
 91: 10,
 92: 63,
 93: 96,
 94: 76,
 95: 83,
 96: 67,
 97: 58,
 98: 39,
 99: 88}

In [13]:
for ind in range(100):
    print(f"{ind}\nSTATIC:{retrieved['en_docs'][ind]}\nCONT:{en_ret_docs[ind]}")

0
STATIC:The adaptive cruise control adapts the vehicle’s speed to the flow of traffic.
CONT:Adaptive cruise control makes riding more convenient for the rider, especially in dense traffic, and could help to prevent rear-end collisions.
1
STATIC:An ideal upgrade for the adaptive cruise assist is also the
CONT:The adaptive cruise assist maintains proper speed and following distance via targeted acceleration and braking. The car automatically adapts its speed to the traffic situation and the route, for example at curves and cross-ways. In stop-and-go traffic as well as traffic jam situations, the ACA can bring the car to a complete stop. Depending on the duration of the stop, the car can start automatically again.
2
STATIC:Traffic jam assist is a subsystem of adaptive cruise control (ACC) or adaptive cruise assist (ACA). In vehicles with an automatic transmission, traffic jam assist can assume certain steering tasks over a speed range up to 65 km/h
CONT:The adaptive cruise assist, which 

In [14]:
l_count = []
for doc in en_ret_docs[:10]:
    l_count.append(len(doc.split(" ")))
    
print(f"{np.mean(np.asarray(l_count))}")


l_count = []
for doc in retrieved['en_docs'][:10]:
    l_count.append(len(doc.split(" ")))
    
print(f"{np.mean(np.asarray(l_count))}")

43.7
24.6


In [15]:
def get_sent_vector_contextual(sent):
    with torch.no_grad():
        inputs = tokenizer(sent, return_tensors='pt',
                      max_length=512, padding=True, truncation=True).to('cuda')
        output = cont_model(**inputs, output_hidden_states=True, output_attentions=True)
    
    # last layer gives the best embeddings
    token_embeddings = output['hidden_states'][-1] 
    # Mean Pooling - Take attention mask into account for correct averaging
    input_mask_expanded = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
    sent_embedding = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return np.asarray(sent_embedding.cpu())

In [16]:
def sort_dict(dict_to_sort):
    return dict(sorted(dict_to_sort.items(), key=lambda item: item[1]))

In [17]:
def get_imp_word(query):
    query_vec = encoding.encode(sent=[query])
    word_import = {}
    for word in query.split(" "):
        word_vec = encoding.encode(sent=[word])
        word_import[word] = cosine_similarity(word_vec, query_vec)[0]
    word_import = sort_dict(word_import)
    imp_word = list(word_import.keys())[-1]
    return imp_word

In [None]:
imp_word = get_imp_word(query)
imp_word

In [None]:
word_vec = get_sent_vector_contextual(sent=[imp_word])
# word_vec = cont_encoding.encode(sent=[imp_word])
word_vec.shape

In [None]:
docs_with_word_idx = [idx for idx, doc in enumerate(scraped_en_docs) if imp_word in doc][:100]
docs_with_word = [scraped_en_docs[idx] for idx in docs_with_word_idx]

In [None]:
docs_with_word_vec = get_sent_vector_contextual(docs_with_word)
# docs_with_word_vec = cont_encoding.encode(docs_with_word)
total_len = len(docs_with_word_vec)
docs_with_word_vec.shape

In [None]:
sim = cosine_similarity(word_vec, Y=docs_with_word_vec, 
                           dense_output=True)[0]

In [None]:
np.argsort(sim), sim[np.argsort(sim)[0]], sim[np.argsort(sim)[-1]]

In [None]:
ind = np.argsort(sim)[:10]
idxs = [docs_with_word_idx[idx] for idx in ind]
other_docs = [scraped_en_docs[idx] for idx in idxs]
other_docs

In [None]:
en_ret_docs_vecs = get_sent_vector_contextual(en_ret_docs)
other_docs_vecs = get_sent_vector_contextual(other_docs)
query_vec = get_sent_vector_contextual(query)

In [None]:
en_ret_docs_vecs.shape, other_docs_vecs.shape, query_vec.shape

In [None]:
vectors = []
vectors.extend(query_vec)
vectors.extend(en_ret_docs_vecs)
vectors.extend(other_docs_vecs)
vectors = np.asarray(vectors)
vectors.shape

In [None]:
proj_2d = projection.get_projection_2d(vectors)

In [None]:
query_proj = proj_2d[:1]
ret_docs_proj = proj_2d[1:len(en_ret_docs)+1]
other_docs_proj = proj_2d[len(en_ret_docs)+1:]

In [None]:
query_proj.shape, ret_docs_proj.shape, other_docs_proj.shape

In [None]:
def plot_umap_intra_sim(query, en_ret_docs, other_docs, imp_word, proj_2d):
    def strip_text(doc, word):
        sp = doc.split(" ")
        index = [idx for idx, s in enumerate(sp) if word in s][0]
        return " ".join(doc.split(" ")[index-20:index+20])
            
    ret_colors = []
    ret_docs = []
    ret_projs = []
    
    query_proj = proj_2d[:1]
    ret_docs_proj = proj_2d[1:len(en_ret_docs)+1]
    other_docs_proj = proj_2d[len(en_ret_docs)+1:]
            
    for doc, proj in zip(en_ret_docs, ret_docs_proj):
        ret_colors.append('red')
        ret_docs.append(doc)
        ret_projs.append(proj)
        
    for doc, proj in zip(other_docs, other_docs_proj):
        ret_colors.append('blue')
        stripped_doc = strip_text(doc, imp_word)
        ret_docs.append(stripped_doc)
        ret_projs.append(proj)
            
    ret_docs.append(query)
    ret_colors.append('green')
    ret_projs.extend(query_proj)
    
    projection.plot_umap_2d(ret_docs, ret_projs, ret_colors)

In [None]:
plot_umap_intra_sim(query, en_ret_docs, other_docs, imp_word, proj_2d)

## Try to explain with attentions

In [20]:
pd.set_option('display.max_rows', 500)

In [21]:
def get_output(sent_1, sent_2=None):
    with torch.no_grad():
        if sent_2:
            inputs = tokenizer(sent_1, sent_2, return_tensors='pt',
                          max_length=512, padding=True, truncation=True).to('cuda')
        else:
            inputs = tokenizer(sent_1, return_tensors='pt',
                          max_length=512, padding=True, truncation=True).to('cuda')            
        output = cont_model(**inputs, output_hidden_states=True, output_attentions=True)
    return inputs, output

In [22]:
# output['attentions'] --> tuple of length 12 --> 12 layers
# each layer has torch.Size([1, 12, 6, 6])
# 12 heads, num of tokens, num of tokens

In [23]:
def get_tokens(inputs):
    input_id_list = inputs['input_ids'][0].tolist() # Batch index 0
    tokens = tokenizer.convert_ids_to_tokens(input_id_list)  
    query_tokens = tokens[tokens.index('[CLS]')+1: tokens.index('[SEP]')]
    doc_tokens = tokens[tokens.index('[SEP]')+1:]
    # print(f"There are {len(tokens)} tokens: {tokens}")
    return tokens, query_tokens, doc_tokens

In [24]:
{f"head_{head}": "" for head in range(6)}

{'head_0': '',
 'head_1': '',
 'head_2': '',
 'head_3': '',
 'head_4': '',
 'head_5': ''}

In [25]:
import numpy as np
def show_attention(output, tokens, layer=11, only_query_tokens=False, query_tokens=None):
    num_heads = 12
    print(f"Attention of Layer: {layer}")
    df = []
    for ind, token in enumerate(tokens):
        df_token = {f"head_{head}": "" for head in range(6)}
        df_token["token"] = token
        for head in range(num_heads):
            attention = output.attentions[layer][0][head][ind]
            # indices = np.argsort(attention.cpu().numpy())[::-1][-1:]
            indices = np.argsort(attention.cpu().numpy())[::-1][-5:]
            if only_query_tokens:
                df_token[f"head_{head}"] = ' '.join([tokens[idx] for idx in indices if tokens[idx] in query_tokens])
            else:
                df_token[f"head_{head}"] = ' '.join([tokens[idx] for idx in indices])
        df.append(df_token)
    
    df = pd.DataFrame(df)
    tokens_col = df.pop("token")
    df.insert(0, "token", tokens_col)
    
    return df

In [28]:
sent = 'automatic emergency braking'
inputs, output = get_output(sent_1=sent)
tokens, query_tokens, doc_tokens = get_tokens(inputs)
df = show_attention(output=output, tokens=tokens, layer=11)
df

Attention of Layer: 11


Unnamed: 0,token,head_0,head_1,head_2,head_3,head_4,head_5,head_6,head_7,head_8,head_9,head_10,head_11
0,[CLS],brak ##ing automatic emergency [CLS],##ing automatic brak emergency [CLS],automatic [SEP] brak emergency [CLS],##ing brak [SEP] emergency [CLS],emergency ##ing [SEP] brak [CLS],brak ##ing [CLS] automatic emergency,automatic brak emergency [CLS] [SEP],automatic ##ing [SEP] emergency [CLS],emergency ##ing brak [SEP] [CLS],[CLS] ##ing brak automatic emergency,automatic ##ing brak [CLS] emergency,automatic ##ing brak emergency [CLS]
1,automatic,automatic [CLS] emergency brak ##ing,emergency [SEP] ##ing brak [CLS],emergency [SEP] ##ing brak [CLS],[SEP] emergency brak ##ing [CLS],emergency [SEP] ##ing brak [CLS],emergency [SEP] ##ing [CLS] brak,automatic emergency ##ing brak [CLS],##ing automatic brak [CLS] emergency,##ing automatic emergency brak [CLS],automatic ##ing emergency brak [CLS],automatic [CLS] brak ##ing emergency,automatic emergency [CLS] ##ing brak
2,emergency,emergency brak automatic [CLS] ##ing,automatic [SEP] brak ##ing [CLS],automatic [SEP] brak ##ing [CLS],emergency automatic brak ##ing [CLS],automatic [SEP] brak ##ing [CLS],automatic [SEP] ##ing brak [CLS],emergency brak automatic ##ing [CLS],brak emergency ##ing automatic [CLS],##ing automatic emergency brak [CLS],emergency automatic brak ##ing [CLS],brak automatic emergency [CLS] ##ing,emergency brak automatic ##ing [CLS]
3,brak,[CLS] brak ##ing automatic emergency,emergency [SEP] automatic ##ing [CLS],[SEP] emergency ##ing automatic [CLS],[SEP] ##ing emergency automatic [CLS],[SEP] emergency automatic ##ing [CLS],[SEP] ##ing automatic emergency [CLS],[SEP] brak emergency automatic [CLS],automatic brak ##ing emergency [CLS],##ing [CLS] automatic brak emergency,brak ##ing emergency automatic [CLS],brak automatic ##ing emergency [CLS],brak ##ing emergency [CLS] automatic
4,##ing,brak [CLS] ##ing automatic emergency,[SEP] automatic brak emergency [CLS],[SEP] brak automatic emergency [CLS],[SEP] brak automatic [CLS] emergency,[SEP] automatic emergency brak [CLS],[SEP] automatic brak [CLS] emergency,automatic ##ing brak [CLS] emergency,##ing brak automatic emergency [CLS],##ing emergency automatic [CLS] brak,##ing brak [CLS] automatic emergency,##ing brak automatic [CLS] emergency,##ing [CLS] brak automatic emergency
5,[SEP],[CLS] ##ing brak [SEP] emergency,automatic brak emergency ##ing [CLS],brak automatic emergency ##ing [CLS],brak ##ing automatic emergency [CLS],brak emergency automatic ##ing [CLS],brak automatic emergency ##ing [CLS],brak emergency automatic ##ing [CLS],[CLS] brak ##ing automatic emergency,brak emergency [CLS] automatic ##ing,automatic brak ##ing emergency [CLS],brak emergency automatic ##ing [CLS],automatic brak emergency ##ing [CLS]


In [None]:
head_view(attention=output.attentions,
            tokens=tokens,
            sentence_b_start=6,
            prettify_tokens=False,
            layer=11,
            heads=None
         )

In [58]:
query, doc = 'adaptive cruise control', en_ret_docs[0]
inputs, output = get_output(sent_1=query, sent_2=doc)
tokens, query_tokens, doc_tokens = get_tokens(inputs)
df = show_attention(output=output, tokens=tokens, layer=11, only_query_tokens=False, query_tokens=query_tokens)
df

Attention of Layer: 11


Unnamed: 0,token,head_0,head_1,head_2,head_3,head_4,head_5,head_6,head_7,head_8,head_9,head_10,head_11
0,[CLS],ada control ##ptive especially dense,##ptive ##s collision - end,##veni for end ##ent -,##ent ada ada ##veni [CLS],"the , and ada .",con ##veni - to the,". ada , , -",##ptive con - end prevent,could help ada ##veni makes,[CLS] for - end cruise,ada ##s ada rear end,"end - to prevent ,"
1,ada,dense more end especially rear,"to the , - ##s","end , for - ##s","- ##ent , ##s end","end ##ent , to ##s","for the , to ##s","and help , to for",prevent end ##veni con rear,"rear ##veni , especially prevent",help ##veni con for ##ent,"- , con for ##s",con ##ent - end ##s
2,##ptive,dense prevent end especially rear,"and ##s the , ,","to ##s for , ,","traffic to con , ,","for and ##ent , ,","and to ##s , ,","##ent more to , for",", for rear end con","cruise rear ##veni prevent ,","for , ##veni con ##ent","rear , , for ##s",##veni for ##s con ##ent
3,cruise,prevent dense end rear especially,"to , and - ##s","and ##s , - ,","##s , and to ,",", - ##veni ##s ,",". ##s and to ,","help and the to ,","prevent con , rear end","dense cruise prevent rear ,","ada ##ent . ##s ,","to and - ##s ,","- rear , ##s ##ent"
4,control,- especially dense rear end,"##s , to rear -","end , , rear -",", to end rear -",end ##veni con rear -,"the , rear to -",for to rear - end,##ent ##veni con rear end,"especially , end rear prevent",", ada ada - rear","##s , end - rear",ada ##s end rear -
5,[SEP],help prevent rear . ##s,". , ##s , -",", ##ent ##s , -",##ptive ada - ada con,", [CLS] to ##s -",". and to ##s ,","and , - ##s ,","dense , ##s end -",ada con ada end ##veni,"##ptive end ada cruise ,","more ##veni for , con",", ##ptive ##ptive ada ada"
6,ada,help dense prevent more rear,"- , the , ##s","- for end , ##s","control , - ##s end","traffic , and , ##s","and the to ##s ,","and more to for ,",- con prevent end rear,##veni collision ada help prevent,##s con end for ##ent,", - con end ##s",##ent con - ##s end
7,##ptive,in dense especially end rear,"the for , , ##s","and ##s for , ,","end to ##s , ,","##s and for , ,","to the and , ,","##s , more to for",for prevent rear con end,", ##veni collision prevent help",##s con for ##veni ##ent,collision rear end for ##s,##veni ##ent for con ##s
8,cruise,prevent end especially dense rear,"- , to and ##s","- and , ##s ,",". ##s to and ,","and . , ##s ,","the ##s and to ,","especially more the , to",", ##s prevent rear end",ada cruise collision prevent help,"for , rear ##s ada",- to end rear ##s,- ##ent end rear ##s
9,control,especially in end dense rear,", rear ##s to -","in , rear , -","rear and , to -",", , rear ada -","and to , the -",more ada - rear end,con ##s prevent rear end,"end , collision prevent help",dense cruise - rear ada,con ##s - end rear,ada ##s end rear -


In [73]:
en_ret_docs[0]

'Adaptive cruise control makes riding more convenient for the rider, especially in dense traffic, and could help to prevent rear-end collisions.'

In [74]:
en = "Adaptive cruise control makes riding more convenient for the rider, especially in dense traffic, and could help to prevent rear-end collisions."
de = "Die adaptive Geschwindigkeitsregelung macht das Fahren für den Fahrer insbesondere bei dichtem Verkehr komfortabler und könnte helfen, Auffahrunfälle zu vermeiden."

In [80]:
query, doc = en, de
inputs, output = get_output(sent_1=query, sent_2=doc)
tokens, query_tokens = get_tokens(inputs)
df = show_attention(output=output, tokens=tokens, layer=11, only_query_tokens=False, query_tokens=query_tokens)
df

There are 70 tokens: ['[CLS]', 'ada', '##ptive', 'cruise', 'control', 'makes', 'riding', 'more', 'con', '##veni', '##ent', 'for', 'the', 'rider', ',', 'especially', 'in', 'dense', 'traffic', ',', 'and', 'could', 'help', 'to', 'prevent', 'rear', '-', 'end', 'collision', '##s', '.', '[SEP]', 'die', 'ada', '##ptive', 'geschwindigkeit', '##sr', '##ege', '##lung', 'macht', 'das', 'fahren', 'fur', 'den', 'fahrer', 'insbesondere', 'bei', 'dicht', '##em', 'verkehr', 'kom', '##fort', '##able', '##r', 'und', 'konnte', 'hel', '##fen', ',', 'auf', '##fa', '##hru', '##n', '##falle', 'zu', 'ver', '##meid', '##en', '.', '[SEP]']
Attention of Layer: 11


Unnamed: 0,token,head_0,head_1,head_2,head_3,head_4,head_5,head_6,head_7,head_8,head_9,head_10,head_11
0,[CLS],macht insbesondere ##fort dicht dense,##hru ##meid ##falle ##fa end,##able ##em ##hru - ##fa,[CLS] ##able ##ege kom ##fort,"##en ##ege , das ##fa",for to con the ##veni,"- ada the , ,",- con ##ent end prevent,macht ##ege konnte hel ver,rider ##veni ##ent for the,ada ada ver ##meid .,"##hru , die ##fen ##en"
1,ada,bei dicht insbesondere konnte ##meid,##fa ##hru zu ##falle ##n,##n ##fa ##s ##hru zu,##s ##hru ##falle ##em ##n,##em ##hru ##n ##en zu,##falle ##fa ##n zu ##s,##n verkehr ##en konnte zu,##fa ##fort ##fen ##meid ver,##fort ver insbesondere dicht ##meid,". prevent , ##r ##ent","##r zu ##n , ##s",##hru - ##n end ##s
2,##ptive,bei auf ##meid insbesondere konnte,##falle ##s ##en zu ##n,"##s ##n zu , ,","##s , ##n ##falle verkehr",zu und . ##en ##n,". , zu ##s ##n","##s ##fa , konnte zu",##hru ##fort ##fa ver ##meid,[CLS] ##fort verkehr dicht ##meid,"##meid ##s prevent , .",", , zu , ##s","##ent . ##n , ##s"
3,cruise,especially dicht auf insbesondere konnte,##fen ##fa zu ##n ##en,", , ##s , ##fa","##s ##fa ##n zu ,","##n ##fa zu , ##en",", ##s ##en ##fa zu","konnte ##en , ##fen zu","auf , ##en ver ##meid",[CLS] verkehr ver dicht ##meid,"rear prevent ##s , .","ver und ##en , zu","- rear . ##s ,"
4,control,kom ##meid dicht rear auf,"##em auf , ##n ##fa","##n ##hru auf , ##fa","ada ##hru , auf ##fa","##hru - , auf ##fa","##n , - auf ##fa",end dicht auf rear ##fa,##fa ##en prevent ver ##meid,fahren auf verkehr dicht ##meid,- ##fa ##meid prevent rear,. ##fa end auf rear,auf ##fa end - rear
5,makes,bei und konnte dicht auf,##em auf ##hru ##n ##fa,- ##n auf ##hru ##fa,##em auf ##n ##fa ##hru,##falle auf ##fa ##hru ##n,auf ##n ##hru - ##fa,auf ##falle ##hru verkehr ##fa,##hru auf ##meid ver ##fa,verkehr auf ver dicht ##meid,- end ##fa ##hru rear,"zu , auf ##n ##fa",rear - ##hru ##s end
6,riding,insbesondere dicht ##meid kom ver,- ##hru ##fa ##en ##n,##en ver ##hru ##fa ##n,- ##sr ##hru ##em ##n,##en ##hru ##fa ##sr ##n,zu ##fa . - ##n,ver ##fen zu ##n ##en,##en ##hru ##fa ##meid ver,verkehr ##hru dicht geschwindigkeit ##meid,rear ##hru ##sr ada ##fa,"and , zu ver ##en",##n ##hru ##s - end
7,more,ver rear bei ##meid auf,##s ##hru - ##fa ##n,##meid ##hru - ##n ##fa,##em ##hru - ##n ##fa,auf ##sr ##hru ##fa ##n,end ##hru ##n - ##fa,rear - ##hru end ##fa,##fen ##hru dicht auf ver,ver ##hru geschwindigkeit auf ##meid,##hru end ##fa rear ada,"##n zu ##en ver ,",##s . - rear end
8,con,konnte bei rear insbesondere auf,zu ##hru ##em ##n ##fa,##en - zu ##hru ##fa,verkehr auf ##hru ##em ##fa,"##hru zu ##fa ##en ,",end zu auf - ##fa,zu ##fa end rear auf,auf ##meid ##en ##fen ver,fahren dicht [CLS] insbesondere ##meid,##meid . rear end prevent,". zu , auf ,",##s . - rear end
9,##veni,dicht konnte rear insbesondere auf,##hru ##em auf ##n zu,to ##fa ##n zu auf,. bei verkehr ##em fahrer,"##fa auf , ##hru zu",", ##hru rear auf ##fa",end - rear ##fa auf,dicht ver ##en ##meid ##fen,especially ##fort [CLS] insbesondere ##meid,". , end prevent rear",". rear end auf ,",", . - end rear"


In [38]:
query, doc = 'adaptive cruise control', en_ret_docs[0]
inputs, output = get_output(sent_1=query, sent_2=doc)
tokens, query_tokens, doc_tokens = get_tokens(inputs)
df = show_attention(output=output, tokens=tokens, layer=11, only_query_tokens=False, query_tokens=query_tokens)
df

Attention of Layer: 11


Unnamed: 0,token,head_0,head_1,head_2,head_3,head_4,head_5,head_6,head_7,head_8,head_9,head_10,head_11
0,[CLS],ada control ##ptive especially dense,##ptive ##s collision - end,##veni for end ##ent -,##ent ada ada ##veni [CLS],"the , and ada .",con ##veni - to the,". ada , , -",##ptive con - end prevent,could help ada ##veni makes,[CLS] for - end cruise,ada ##s ada rear end,"end - to prevent ,"
1,ada,dense more end especially rear,"to the , - ##s","end , for - ##s","- ##ent , ##s end","end ##ent , to ##s","for the , to ##s","and help , to for",prevent end ##veni con rear,"rear ##veni , especially prevent",help ##veni con for ##ent,"- , con for ##s",con ##ent - end ##s
2,##ptive,dense prevent end especially rear,"and ##s the , ,","to ##s for , ,","traffic to con , ,","for and ##ent , ,","and to ##s , ,","##ent more to , for",", for rear end con","cruise rear ##veni prevent ,","for , ##veni con ##ent","rear , , for ##s",##veni for ##s con ##ent
3,cruise,prevent dense end rear especially,"to , and - ##s","and ##s , - ,","##s , and to ,",", - ##veni ##s ,",". ##s and to ,","help and the to ,","prevent con , rear end","dense cruise prevent rear ,","ada ##ent . ##s ,","to and - ##s ,","- rear , ##s ##ent"
4,control,- especially dense rear end,"##s , to rear -","end , , rear -",", to end rear -",end ##veni con rear -,"the , rear to -",for to rear - end,##ent ##veni con rear end,"especially , end rear prevent",", ada ada - rear","##s , end - rear",ada ##s end rear -
5,[SEP],help prevent rear . ##s,". , ##s , -",", ##ent ##s , -",##ptive ada - ada con,", [CLS] to ##s -",". and to ##s ,","and , - ##s ,","dense , ##s end -",ada con ada end ##veni,"##ptive end ada cruise ,","more ##veni for , con",", ##ptive ##ptive ada ada"
6,ada,help dense prevent more rear,"- , the , ##s","- for end , ##s","control , - ##s end","traffic , and , ##s","and the to ##s ,","and more to for ,",- con prevent end rear,##veni collision ada help prevent,##s con end for ##ent,", - con end ##s",##ent con - ##s end
7,##ptive,in dense especially end rear,"the for , , ##s","and ##s for , ,","end to ##s , ,","##s and for , ,","to the and , ,","##s , more to for",for prevent rear con end,", ##veni collision prevent help",##s con for ##veni ##ent,collision rear end for ##s,##veni ##ent for con ##s
8,cruise,prevent end especially dense rear,"- , to and ##s","- and , ##s ,",". ##s to and ,","and . , ##s ,","the ##s and to ,","especially more the , to",", ##s prevent rear end",ada cruise collision prevent help,"for , rear ##s ada",- to end rear ##s,- ##ent end rear ##s
9,control,especially in end dense rear,", rear ##s to -","in , rear , -","rear and , to -",", , rear ada -","and to , the -",more ada - rear end,con ##s prevent rear end,"end , collision prevent help",dense cruise - rear ada,con ##s - end rear,ada ##s end rear -


In [60]:
query, doc = 'adaptive cruise control', en_ret_docs[0]
inputs, output = get_output(sent_1=query, sent_2=doc)
tokens, query_tokens, doc_tokens = get_tokens(inputs)
df = show_attention(output=output, tokens=tokens, layer=10, only_query_tokens=True, query_tokens=doc_tokens)
df

Attention of Layer: 10


Unnamed: 0,token,head_0,head_1,head_2,head_3,head_4,head_5,head_6,head_7,head_8,head_9,head_10,head_11
0,[CLS],rear ##s - end prevent,##ptive collision ##s ada,to prevent ##veni end -,", and ##s to",collision end - prevent rear,##ptive ada - rear end,- prevent ##s to con,prevent to rear end -,riding ada control ##ptive cruise,rear ##s collision .,especially help to and .,##s con to end -
1,ada,dense collision - rear end,"for to , - ##s",", control end ##s -","prevent end , - ##s","for in traffic , ,",collision rear ##s end -,"- to , end ##s",", for end - ##s",", control ada ##ptive cruise",to ##veni ##ptive ada,collision rear ##s end -,and - prevent collision ##s
2,##ptive,collision - traffic end ##s,"end . ##s , -","to , ##s - for",", for - end ##s","for , end and especially",cruise - collision end ##s,", prevent ##s for end","collision prevent , - ##s",", control ##ptive cruise ada",##ptive ada ##ptive ada,collision ##s rear - end,end for - collision ##s
3,cruise,"##s and , traffic ,","prevent , end ##s -","##veni for to - ,","rear , - end ##s","rear in especially , and",ada to - collision ##s,"to ##s end - ,",", to - , ##s",more control cruise ##ptive ada,##ptive ada ##ptive ada,collision ##s rear - end,"end prevent , - ##s"
4,control,", in - and to","- rear , end .",to for end rear -,", - rear end ##s",in dense rear - end,to ##s rear end -,", to rear - end",to ##s end rear -,control con ##ptive cruise ada,control ada ##ptive ada,##s collision - rear end,", rear ##s end -"
5,[SEP],and prevent ##veni ##ent end,"con to - ##s ,",rear ##s end to -,"ada , , ##s -",rear ada end - to,to con rear end -,"for ##s , to -",##veni ##s to ada for,"to end , ##s -",the ##s in end rear,for rear ##s - end,##ent end prevent ##s -
6,ada,dense collision - end rear,end for - ##s,makes prevent ##s end -,"end to ##s - ,","prevent ##s , rear ,",collision ##s rear - end,prevent - to end ##s,", for - end ##s",##s cruise - end prevent,rear ada prevent cruise collision,prevent collision ##s end -,", and - collision ##s"
7,##ptive,traffic - collision end ##s,ada for end ##s,##ent collision for end -,", - end to ##s",and - for end especially,rear - collision ##s end,to prevent con end for,- prevent collision for ##s,collision ada - end prevent,prevent ##ptive ada cruise collision,##s collision rear - end,and end - collision ##s
8,cruise,"rear , traffic ##s collision",##ptive end ada ##s,and ada prevent to -,rear - ##s end to,- especially end and rear,the collision prevent ##s rear,"##s , - end to","for , - , ##s",and - to end prevent,prevent ##ptive cruise ada collision,riding dense - rear end,"and , end ##s -"
9,control,"- , prevent to rear",##ptive end ##s ada,##veni end for rear -,to - end rear ##s,especially dense - rear end,to prevent - end rear,"##s , - end to",##s in end - rear,con to - prevent end,##s cruise ##ptive ada collision,collision prevent - rear end,collision rear ##s end -


In [67]:
df[:6].to_csv('dummy.csv')

In [195]:
import numpy as np
def get_attended_tokens(output, tokens, query_tokens, layer=11):
    metrics = []
    num_heads = 12
    doc_start_ind = tokens.index('[SEP]') + 1
    for ind, token in enumerate(tokens):
        if ind < doc_start_ind:
            continue
        attended_tokens = []
        for head in range(num_heads):
            attention = output.attentions[layer][0][head][ind]
            # indices = np.argsort(attention.cpu().numpy())[::-1][-1:]
            indices = np.argsort(attention.cpu().numpy())[::-1][-len(query_tokens):]
            attended = [tokens[idx] for idx in indices if tokens[idx] in query_tokens]
            m = get_metrics(query_tokens, attended)
            metrics.append( 
                {
                'prec': m['prec'],
                'recall': m['recall'],
                'token': tokens[ind],
                'head': head
                }
            )
            attended_tokens.extend(attended)
        attended_tokens = list(set(attended_tokens)) 
    return attended_tokens, metrics

In [185]:
def get_metrics(query_tokens, attended_tokens):
    tp, fp, tn, fn = 0, 0, 0, 0
    for token in query_tokens:
        if token in attended_tokens:
            tp += 1
        else:
            fn += 1
            
    for token in attended_tokens:
        if token not in query_tokens:
            fp += 1
        
    prec, recall = 0, 0
    if tp + fp:
        prec = tp / (tp + fp)
    if tp + fn:
        recall = tp / (tp + fn)
    
    metrics = {
        'prec': prec,
        'recall': recall
    }
    return metrics

In [232]:
from collections import Counter
def show_scores(docs):
    layer_counts = []
    head_counts = []
    for idx, doc in enumerate(docs):
        query, doc = 'adaptive cruise control', doc
        inputs, output = get_output(sent_1=query, sent_2=doc)
        tokens, query_tokens, doc_tokens = get_tokens(inputs)
        macro = {'prec': 0, 'recall': 0}
        layer_count = []
        token_count = []
        head_count = []
        for layer in range(12):
            attended_tokens, metrics = get_attended_tokens(output, tokens, query_tokens, layer=layer)
            # metrics = get_metrics(query_tokens, attended_tokens)
            # macro['prec'] += metrics['prec']
            # macro['recall'] += metrics['recall']            
            for metric in metrics:
                macro['prec'] += metric['prec']
                macro['recall'] += metric['recall']
                if metric['prec'] == 1 and metric['recall'] == 1:
                    layer_count.append(layer)
                    token_count.append(metric['token'])
                    head_count.append(metric['head'])
            macro['prec'] /= len(metrics)
            macro['recall'] /= len(metrics)
            f_score = (macro['prec'] + macro['recall']) / 2
        

        macro['prec'] /= 12
        macro['recall'] /= 12
        macro['f_score'] = (macro['prec'] + macro['recall']) /2
        
        print(f"Retrieved at rank {idx}", macro['f_score'])
        print('Layer Count', Counter(layer_count).most_common(3))
        print('Token Count', Counter(token_count).most_common(3))
        print('Head Count', Counter(head_count).most_common(3))
        
        layer_counts.append(Counter(layer_count).most_common(1)[0])
        head_counts.append(Counter(head_count).most_common(1)[0])
        
    count = {
        'layer': Counter(layer_counts).most_common(1)[0],
        'head': Counter(head_counts).most_common(1)[0]
    }
        
    return count

In [233]:
query = 'adaptive cruise control'
en_ret_docs, de_ret_docs = retrieve(query)
print("EN")
count = show_scores(en_ret_docs)
print(f"Layer: {Counter(count['layer']).most_common(1)[0][0]}, Head: {Counter(count['head']).most_common(1)[0][0]}")
print("DE")
count = show_scores(de_ret_docs)
print(f"Layer: {Counter(count['layer']).most_common(1)[0][0]}, Head: {Counter(count['head']).most_common(1)[0][0]}")

100%|██████████| 100/100 [00:01<00:00, 91.33it/s]
100%|██████████| 100/100 [00:01<00:00, 88.99it/s]


EN
Retrieved at rank 0 0.03549626239652399
Layer Count [(2, 34), (7, 30), (8, 22)]
Token Count [(',', 20), ('to', 13), ('dense', 12)]
Head Count [(7, 39), (9, 30), (8, 14)]
Retrieved at rank 1 0.01750645690259295
Layer Count [(2, 65), (6, 13), (0, 10)]
Token Count [('the', 8), ('in', 7), ('and', 6)]
Head Count [(7, 67), (11, 11), (9, 9)]
Retrieved at rank 2 0.014058686254523133
Layer Count [(2, 90), (0, 13), (6, 9)]
Token Count [('the', 14), ('.', 8), (',', 6)]
Head Count [(7, 90), (9, 16), (11, 7)]
Retrieved at rank 3 0.011839308072262631
Layer Count [(2, 100), (0, 13), (6, 4)]
Token Count [('the', 13), ('.', 12), ('and', 8)]
Head Count [(7, 101), (9, 13), (11, 3)]
Retrieved at rank 4 0.041019929976200555
Layer Count [(7, 37), (2, 34), (5, 15)]
Token Count [('to', 24), ('the', 23), ('preceding', 11)]
Head Count [(7, 43), (9, 30), (3, 15)]
Retrieved at rank 5 0.05170020736174518
Layer Count [(2, 35), (7, 35), (5, 25)]
Token Count [('the', 23), ('to', 18), ('flow', 16)]
Head Count [(7, 

In [234]:
query = 'automatische notbremsung'
en_ret_docs, de_ret_docs = retrieve(query)
print("EN")
count = show_scores(en_ret_docs)
print(f"Layer: {Counter(count['layer']).most_common(1)[0][0]}, Head: {Counter(count['head']).most_common(1)[0][0]}")
print("DE")
count = show_scores(de_ret_docs)
print(f"Layer: {Counter(count['layer']).most_common(1)[0][0]}, Head: {Counter(count['head']).most_common(1)[0][0]}")

100%|██████████| 100/100 [00:01<00:00, 94.08it/s]
100%|██████████| 100/100 [00:01<00:00, 91.26it/s]


EN
Retrieved at rank 0 0.044989418578295956
Layer Count [(7, 42), (11, 34), (2, 29)]
Token Count [('pe', 17), ('##rians', 14), ('.', 13)]
Head Count [(7, 40), (10, 33), (2, 28)]
Retrieved at rank 1 0.049159649762418496
Layer Count [(7, 54), (2, 25), (8, 20)]
Token Count [('and', 23), ('to', 15), ('-', 14)]
Head Count [(9, 29), (7, 29), (3, 24)]
Retrieved at rank 2 0.046987620874196376
Layer Count [(7, 50), (2, 25), (11, 19)]
Token Count [('and', 25), ('avoid', 13), ('the', 13)]
Head Count [(7, 38), (3, 23), (9, 22)]
Retrieved at rank 3 0.0031287534644060407
Layer Count [(2, 177), (0, 73), (1, 5)]
Token Count [('the', 36), ('to', 19), (',', 15)]
Head Count [(7, 178), (9, 70), (0, 3)]
Retrieved at rank 4 0.0053520807143551214
Layer Count [(2, 164), (0, 62), (1, 5)]
Token Count [('the', 22), (',', 14), ('.', 11)]
Head Count [(7, 165), (9, 60), (10, 3)]
Retrieved at rank 5 0.07128514758870197
Layer Count [(7, 52), (8, 48), (11, 47)]
Token Count [('##ing', 35), ('of', 29), ('live', 29)]
Hea

In [235]:
query = 'driver assistance systems'
en_ret_docs, de_ret_docs = retrieve(query)
print("EN")
count = show_scores(en_ret_docs)
print(f"Layer: {Counter(count['layer']).most_common(1)[0][0]}, Head: {Counter(count['head']).most_common(1)[0][0]}")
print("DE")
count = show_scores(de_ret_docs)
print(f"Layer: {Counter(count['layer']).most_common(1)[0][0]}, Head: {Counter(count['head']).most_common(1)[0][0]}")

100%|██████████| 100/100 [00:01<00:00, 94.00it/s]
100%|██████████| 100/100 [00:01<00:00, 89.74it/s]


EN
Retrieved at rank 0 0.03911295165067779
Layer Count [(7, 39), (2, 29), (5, 26)]
Token Count [('are', 22), ('before', 18), ('the', 18)]
Head Count [(7, 34), (9, 27), (3, 26)]
Retrieved at rank 1 0.06365590821295922
Layer Count [(7, 48), (11, 39), (2, 31)]
Token Count [('the', 64), ('.', 35), ('to', 31)]
Head Count [(8, 49), (9, 40), (7, 28)]
Retrieved at rank 2 0.027085338612843207
Layer Count [(2, 44), (11, 14), (5, 13)]
Token Count [('.', 9), ('and', 9), ('assist', 7)]
Head Count [(7, 47), (8, 19), (9, 10)]
Retrieved at rank 3 0.07263596986225701
Layer Count [(10, 83), (11, 81), (7, 76)]
Token Count [('all', 45), ('and', 44), ('protect', 44)]
Head Count [(9, 53), (8, 53), (7, 50)]
Retrieved at rank 4 0.0065648924725996765
Layer Count [(2, 110), (0, 27), (6, 3)]
Token Count [('the', 13), (',', 10), ('.', 10)]
Head Count [(7, 111), (9, 25), (11, 3)]
Retrieved at rank 5 0.05441390357695708
Layer Count [(7, 43), (2, 36), (5, 32)]
Token Count [('the', 30), ('.', 21), ('a', 19)]
Head Cou

In [236]:
query = 'adaptive Geschwindigkeitsregelung'
en_ret_docs, de_ret_docs = retrieve(query)
print("EN")
count = show_scores(en_ret_docs)
print(f"Layer: {Counter(count['layer']).most_common(1)[0][0]}, Head: {Counter(count['head']).most_common(1)[0][0]}")
print("DE")
count = show_scores(de_ret_docs)
print(f"Layer: {Counter(count['layer']).most_common(1)[0][0]}, Head: {Counter(count['head']).most_common(1)[0][0]}")

100%|██████████| 100/100 [00:01<00:00, 90.92it/s]
100%|██████████| 100/100 [00:01<00:00, 87.61it/s]


EN
Retrieved at rank 0 0.03549626239652399
Layer Count [(2, 34), (7, 30), (8, 22)]
Token Count [(',', 20), ('to', 13), ('dense', 12)]
Head Count [(7, 39), (9, 30), (8, 14)]
Retrieved at rank 1 0.01750645690259295
Layer Count [(2, 65), (6, 13), (0, 10)]
Token Count [('the', 8), ('in', 7), ('and', 6)]
Head Count [(7, 67), (11, 11), (9, 9)]
Retrieved at rank 2 0.041019929976200555
Layer Count [(7, 37), (2, 34), (5, 15)]
Token Count [('to', 24), ('the', 23), ('preceding', 11)]
Head Count [(7, 43), (9, 30), (3, 15)]
Retrieved at rank 3 0.05170020736174518
Layer Count [(2, 35), (7, 35), (5, 25)]
Token Count [('the', 23), ('to', 18), ('flow', 16)]
Head Count [(7, 34), (9, 27), (5, 18)]
Retrieved at rank 4 0.024048274319439858
Layer Count [(2, 53), (10, 12), (0, 11)]
Token Count [('and', 9), ('.', 6), ('automatically', 6)]
Head Count [(7, 54), (9, 28), (11, 12)]
Retrieved at rank 5 0.01804908504518042
Layer Count [(2, 89), (0, 16), (6, 8)]
Token Count [(',', 13), ('the', 12), ('in', 6)]
Head C

**Observations**
- mostly Layers **2** and **7** have heads where tokens attend to almost all the query tokens.
- mostly Heads **7** and **9** where tokens attend to almost all the query tokens.
- most of the articles, conjunctions and prepositions (a, an, and, die, der, mit, ) attend to all the query tokens.