## 0. Set envs & Load DB, Model and Tokenizer

### Hyperparameter configuration

In [1]:
import os
os.getcwd()
EXP_PATH = '/home/ubuntu/temp/kg_txt_multimodal/lxmert'
# add env
import sys; sys.path.append(EXP_PATH+'/src')

In [2]:
import os
import easydict

# Set path
# EXP_PATH = os.path.dirname(os.getcwd())
TASK_NAME = 'masked_literal_prediction'
RUN_NAME = 'TransE_NoKGenc_H128'

# Essential Hyperparameters
args = easydict.EasyDict({
    "model_type":"lxmert",
    "model_name_or_path":os.path.join(EXP_PATH,'pretrained_models',RUN_NAME),
    "tokenizer_name":"bert-base-uncased",
    "cache_dir":None,
    "eval_criterion" :"lang_acc,kg_acc",
    "block_size":-1,
    "batch_size":1,
    "eval_data_file": os.path.join(EXP_PATH,"data/{}/valid".format(TASK_NAME)),
    "test_data_file": os.path.join(EXP_PATH, "data/{}/test".format(TASK_NAME)),
    "run_name":RUN_NAME,
    "seed":1234,
})

### Environment settings

In [3]:
# Base packages
import logging
import math
from dataclasses import dataclass, field
from glob import glob
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import ConcatDataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler

# Own implementation
from utils.parameters import parser
from utils.dataset import get_dataset
from utils.data_collator import NodeMasking_DataCollator, NodeClassification_DataCollator, LiteralRegression_DataCollator
from model import LxmertForPreTraining, LxmertForKGTokPredAndMaskedLM

# From Huggingface transformers package
from transformers import (
    CONFIG_MAPPING,
    MODEL_WITH_LM_HEAD_MAPPING,
    LxmertConfig,
    LxmertTokenizer,
    PreTrainedTokenizer,
    # Trainer,
    set_seed,
)

# Set enviroments
set_seed(args.seed)

### Load tokenizer

In [4]:
tokenizer = LxmertTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
## Sanity check
print(tokenizer.convert_tokens_to_ids(tokenizer.tokenize('Just for sanity check [CLS] [SEP] [MASK] [PAD]')))

[2074, 2005, 20039, 4638, 101, 102, 103, 0]


### Load pretrained model

In [5]:
# Load configuration
config = LxmertConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
config.pretrained_kg_embedding = EXP_PATH + '/from_pretrained/TransE_128/transe.ckpt' # temp_modify

# Load trained model
model = LxmertForKGTokPredAndMaskedLM.from_pretrained(
    args.model_name_or_path,
    from_tf=bool(".ckpt" in args.model_name_or_path),
    config=config,
    cache_dir=args.cache_dir,
)
model.eval()

# ## Sanity check
# print('='*100)
# print(config)
# print('='*100)
# print(model)
# print('='*100)

LxmertForKGTokPredAndMaskedLM(
  (lxmert): LxmertModel(
    (lang_embeddings): LxmertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128, padding_idx=0)
      (token_type_embeddings): Embedding(2, 128, padding_idx=0)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (kg_embeddings): LxmertEmbeddings(
      (word_embeddings): Embedding(167494, 128, padding_idx=0)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): LxmertEncoder(
      (KG_fc): LxmertKGFeatureEncoder(
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layer): ModuleList(
        (0): LxmertLayer(
          (attention): LxmertSelfAttentionLayer(
            (self): LxmertAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key

## 1. visualization for attention heads

### Dataloader without masking

In [6]:
if args.block_size <= 0:
    args.block_size = tokenizer.max_len
    # Our input block size will be the max possible for the model
else:
    args.block_size = min(args.block_size, tokenizer.max_len)

# Get datasets
dataset = (
    get_dataset(args, tokenizer=tokenizer, kg_pad=config.kg_special_token_ids["PAD"], evaluate=True)
)

data_collator = NodeClassification_DataCollator(tokenizer=tokenizer,
                                                kg_special_token_ids=config.kg_special_token_ids,
                                                kg_size=config.vocab_size['kg'],
                                                prediction=True)

# Get data loader
data_loader = DataLoader(
    dataset,
    sampler=SequentialSampler(dataset),
    batch_size=args.batch_size,
    collate_fn=data_collator,
    drop_last=True,
    pin_memory=True,
)



### Get attention score from sample

In [209]:
id2entity = {int(line.split('\t')[1])+len(config.kg_special_token_ids):\
             line.split('\t')[0].split('^^')[0] for line in open(os.path.join(EXP_PATH,'data','entity2id.txt')).read().splitlines()[1:]}


def get_data_info_from_sample_for_viz(sample_idx=0,
                                      model=model,
                                      data_loader=data_loader,
                                      tokenizer=tokenizer,
                                      kg_id_mapping=id2entity):
    
    for idx, data in enumerate(data_loader):
        if idx==sample_idx:
            input_data = data
            break
    
    data_info = {}
    
    # lang part
    lang_temp_input_ids = input_data['lang_input_ids'].cpu().squeeze()
    lang_input_len = len(lang_temp_input_ids.nonzero())
    lang_tokens = tokenizer.convert_ids_to_tokens(lang_temp_input_ids[:lang_input_len])
    lang_seq = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(lang_temp_input_ids[:lang_input_len]))
    data_info['lang_input_len'] = lang_input_len 
    data_info['lang_tokens'] = lang_tokens
    data_info['lang_seq'] = lang_seq
    
    # kg part
    kg_temp_input_ids = input_data['kg_input_ids'].cpu().squeeze()
    kg_input_len = len(kg_temp_input_ids.nonzero())
    kg_tokens = list(map(lambda x: kg_id_mapping[x], kg_temp_input_ids[:kg_input_len].tolist()))
    data_info['kg_input_len'] = kg_input_len
    data_info['kg_tokens'] = kg_tokens
    
    # output attentions
    output_data = model(**input_data, output_attentions=True, return_dict=True)
    data_info['lang_attentions'] =  [l_layer[:,:,
                                             :lang_input_len,
                                             :lang_input_len].cpu() for l_layer in output_data['language_attentions']]
    temp_c_attentions = torch.cat([c_layer[:,:,
                                          :(lang_input_len+kg_input_len),
                                          :(kg_input_len+lang_input_len)].cpu() for c_layer in output_data['cross_encoder_attentions']])
    temp_c1_attentions = temp_c_attentions[:,:,:,:kg_input_len]
    temp_c2_attentions = temp_c_attentions[:,:,:,-lang_input_len:]
    c_attentions = torch.cat([temp_c2_attentions, temp_c1_attentions], dim=-1)
    data_info['cross_attentions'] = [c_attentions[layer_idx,:,:,:].unsqueeze(0) for layer_idx in range(len(temp_c_attentions[:]))]
    
    return data_info

### show attention(head view)

In [204]:
%%javascript
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>

In [216]:
from bertviz import head_view

def show_head_view(data_info, view_option='language'):
        
    lang_tokens = data_info['lang_tokens']
    kg_tokens = data_info['kg_tokens']
    lang_attentions = data_info['lang_attentions']
    cross_attentions = data_info['cross_attentions']
    
    print(data_info['lang_seq'])
    
    if view_option == 'langauge':
        head_view(attention=lang_attentions, tokens=lang_tokens)
        
    else:
        all_tokens = lang_tokens + kg_tokens
        sentence_b_start = len(lang_tokens)
        head_view(attention=cross_attentions, tokens=all_tokens , sentence_b_start=sentence_b_start)

In [217]:
result = get_data_info_from_sample_for_viz(sample_idx=26)

In [218]:
show_head_view(result, 'langauge')

[CLS] cardiac catheterization [SEP]


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [219]:
show_head_view(result, 'cross')

[CLS] cardiac catheterization [SEP]


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [205]:
# import matplotlib.pyplot as plt
# fig, axs = plt.subplots(model.config.x_layers, model.config.num_attention_heads, figsize=(20,20))
# for layer_idx in range(model.config.x_layers):
#     for head_idx in range(model.config.num_attention_heads):
#         axs[model.config.l_layers-1-layer_idx, head_idx].matshow(
#             output_data['cross_encoder_attentions'][layer_idx][0,head_idx,
#                                                          :lang_input_len,:kg_input_len].cpu().detach().numpy(),
#             #cmap='gray',
#         )
        
#         if (layer_idx == 0) and (head_idx==0):
#             axs[layer_idx, head_idx].set_xticks(range(len(kg_tokens)))
#             axs[layer_idx, head_idx].set_yticks(range(len(lang_tokens)))
#             axs[layer_idx, head_idx].set_xticklabels(kg_tokens, rotation='45',horizontalalignment='left', fontsize=8)
#             axs[layer_idx, head_idx].set_yticklabels(lang_tokens, rotation='45',horizontalalignment='right', fontsize=8)
#         axs[0, head_idx].set_title("Head_{}\n".format(head_idx+1))
#     axs[layer_idx, 0].set_ylabel('Layer_{}\n'.format(model.config.l_layers-layer_idx))
# plt.suptitle('Cross Modal Attention Vis', fontsize=40)

In [206]:
# for idx, data in enumerate(data_loader):
#     if idx==SAMPLE_IDX:
#         input_data = data
#         break
# print('==== Text ====')
# temp = input_data['lang_input_ids'].cpu().squeeze()
# #print('ID seq : {}'.format(temp[:temp.nonzero().shape[0]].tolist()))
# lang_tokens = tokenizer.convert_ids_to_tokens(temp[:temp.nonzero().shape[0]].tolist())
# print('Token seq : {}\n'.format(lang_tokens))
# print('Full sentence : {}\n\n'.format(tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(temp[:temp.nonzero().shape[0]].tolist()))))
# #print('Seq length : {}'.format(len(temp[:temp.nonzero().shape[0]].tolist())))
# lang_len = len(temp[:temp.nonzero().shape[0]])
# print(lang_len)
# print('==== KG ====')
# temp = input_data['kg_input_ids'].cpu().squeeze()
# #print('ID seq : {}'.format(temp[:temp.nonzero().shape[0]].tolist()))
# kg_tokens = list(map(lambda x: id2entity[x],temp[:temp.nonzero().shape[0]].tolist()))
# print('Token seq : {}\n'.format(kg_tokens))
# #print('Seq length : {}'.format(len(temp[:temp.nonzero().shape[0]].tolist())))
# kg_len = len(temp[:temp.nonzero().shape[0]])

# output_data = model(**input_data, output_attentions=True, return_dict=True)

### show attention(model view)

In [214]:
# from bertviz import model_view

# def show_model_view(data_info, view_option='language'):
        
#     lang_tokens = data_info['lang_tokens']
#     kg_tokens = data_info['kg_tokens']
#     lang_attentions = data_info['lang_attentions']
#     cross_attentions = data_info['cross_attentions']
    
#     print(data_info['lang_seq'])
    
#     if view_option == 'langauge':
#         model_view(attention=lang_attentions, tokens=lang_tokens)
        
#     else:
#         all_tokens = lang_tokens + kg_tokens
#         sentence_b_start = len(lang_tokens)
#         model_view(attention=cross_attentions, tokens=all_tokens , sentence_b_start=sentence_b_start)

In [215]:
# show_model_view(result)