## 0. Set envs & Load DB, Model and Tokenizer

### Hyperparameter configuration

In [None]:
import os
import easydict

# Locations and run identifiers
EXP_PATH = os.path.dirname(os.getcwd())  # parent of the current working directory
TASK_NAME = 'masked_literal_prediction'
RUN_NAME = 'TransE_NoKGenc_H128'

# Essential hyperparameters, exposed through attribute access (args.seed, args.batch_size, ...)
args = easydict.EasyDict(dict(
    model_type="lxmert",
    model_name_or_path=os.path.join(EXP_PATH, 'pretrained_models', RUN_NAME),
    tokenizer_name="bert-base-uncased",
    cache_dir=None,
    eval_criterion="lang_acc,kg_acc",
    # A non-positive block size is replaced by the tokenizer's maximum length in later cells
    block_size=-1,
    batch_size=1,
    eval_data_file=os.path.join(EXP_PATH, "data/{}/valid".format(TASK_NAME)),
    test_data_file=os.path.join(EXP_PATH, "data/{}/test".format(TASK_NAME)),
    run_name=RUN_NAME,
    seed=1234,
))

### Environment settings

In [None]:
# Base packages
import logging
import math
from dataclasses import dataclass, field
from glob import glob
from typing import Optional
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import ConcatDataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler

# Own implementation
from utils.parameters import parser
from utils.dataset import get_dataset
from utils.data_collator import NodeMasking_DataCollator, NodeClassification_DataCollator, LiteralRegression_DataCollator
from model import LxmertForPreTraining, LxmertForKGTokPredAndMaskedLM

# From Huggingface transformers package
from transformers import (
    CONFIG_MAPPING,
    MODEL_WITH_LM_HEAD_MAPPING,
    LxmertConfig,
    LxmertTokenizer,
    PreTrainedTokenizer,
    # Trainer,
    set_seed,
)

# Set environment: call transformers' set_seed with the configured seed (args.seed)
set_seed(args.seed)

### Load tokenizer

In [None]:
# Instantiate the tokenizer named in the configuration
tokenizer = LxmertTokenizer.from_pretrained(
    args.tokenizer_name,
    cache_dir=args.cache_dir,
)

## Sanity check: print the ids of a probe sentence that includes the special tokens
sanity_tokens = tokenizer.tokenize('Just for sanity check [CLS] [SEP] [MASK] [PAD]')
print(tokenizer.convert_tokens_to_ids(sanity_tokens))

### Load pretrained model

In [None]:
# Read the configuration stored under the pretrained model path
config = LxmertConfig.from_pretrained(
    args.model_name_or_path,
    cache_dir=args.cache_dir,
)

# Restore the trained model; from_tf is True only when the path contains ".ckpt"
model = LxmertForKGTokPredAndMaskedLM.from_pretrained(
    args.model_name_or_path,
    from_tf=(".ckpt" in args.model_name_or_path),
    config=config,
    cache_dir=args.cache_dir,
)
model.eval()

## Sanity check: print the config and the model, each framed by separator rules
separator = '=' * 100
for section in (config, model):
    print(separator)
    print(section)
print(separator)

## 1. Class-wise accuracy

### Dataloader with masking

In [None]:
# Move the model to the GPU: the evaluation loop in this section sends every batch to CUDA
model.cuda()

# Build vocab for graph.
# entity2id.txt: the first line is skipped; every other line is "<entity>\t<id>".
# Ids are shifted by the number of KG special tokens and anything after '^^' in the
# entity string is dropped.
with open(os.path.join(EXP_PATH, 'data', 'entity2id.txt')) as entity_file:
    entity_lines = entity_file.read().splitlines()[1:]
id2entity = {
    int(line.split('\t')[1]) + len(config.kg_special_token_ids): line.split('\t')[0].split('^^')[0]
    for line in entity_lines
}
# The stored id2label mapping is inverted here: label -> entity name
label2entity = {v: id2entity[k] for (k, v) in torch.load(os.path.join(EXP_PATH, "data/{}/id2label".format(TASK_NAME))).items()}

# `tokenizer.max_len` is deprecated in favour of `model_max_length` (and absent in
# transformers v4+); prefer the new attribute and fall back for old releases.
tokenizer_max_len = getattr(tokenizer, "model_max_length", None)
if tokenizer_max_len is None:
    tokenizer_max_len = tokenizer.max_len

if args.block_size <= 0:
    # Our input block size will be the max possible for the model
    args.block_size = tokenizer_max_len
else:
    args.block_size = min(args.block_size, tokenizer_max_len)

# Get datasets
dataset = (
    get_dataset(args, tokenizer=tokenizer, kg_pad=config.kg_special_token_ids["PAD"], evaluate=True)
)

data_collator = NodeClassification_DataCollator(tokenizer=tokenizer, kg_special_token_ids=config.kg_special_token_ids, kg_size = config.vocab_size['kg'])

# Get data loader (sequential order so sample indices are stable between runs)
data_loader = DataLoader(
    dataset,
    sampler=SequentialSampler(dataset),
    batch_size=args.batch_size,
    collate_fn=data_collator,
    drop_last=True,
    pin_memory=True,
)

# Sanity check
print(next(iter(data_loader))['lang_input_ids'])
print(label2entity)

### Metric for class-wise accuracy

In [None]:
def get_ClassWise_Accruacy(logit, label, cwa_dict):
    """Record per-class hit/miss outcomes for one batch.

    Args:
        logit: score tensor; the predicted class is the arg-max over dim 2.
        label: tensor of target class ids compared element-wise with the
            predictions; positions equal to -100 are ignored.
        cwa_dict: dict mapping class id -> list; it is mutated in place.

    Returns:
        The same ``cwa_dict``, with one bool appended per non-ignored position
        under that position's target class id.
    """
    predicted = torch.max(logit, dim=2)[1]
    keep = label.ne(-100)
    hits = predicted.eq(label)[keep].tolist()
    targets = label[keep].tolist()
    for target, hit in zip(targets, hits):
        cwa_dict[target].append(hit)
    return cwa_dict

### Measure class-wise accuracy

In [None]:
# One (initially empty) list of outcomes per language token id and per KG label id
metrics = {
    'lang_cwa': {token_id: list() for token_id in range(config.vocab_size['lang'])},
    'kg_cwa': {label_id: list() for label_id in range(config.num_kg_labels)},
}

for step, inputs in tqdm(enumerate(data_loader), total=len(data_loader)):
    # Send every tensor field of the batch to the CUDA device
    for k, v in inputs.items():
        if not isinstance(v, torch.Tensor):
            continue
        inputs[k] = v.cuda()

    # Score the minibatch without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs, return_dict=True)
        for metric_key, logit_key, label_key in (
            ('lang_cwa', 'lang_prediction_logits', 'lm_label'),
            ('kg_cwa', 'kg_prediction_logits', 'kg_label'),
        ):
            metrics[metric_key] = get_ClassWise_Accruacy(
                outputs[logit_key].data,
                inputs[label_key].data,
                metrics[metric_key],
            )

# Collapse each non-empty outcome list into an accuracy, keyed by a readable name:
# token strings for the language side, entity names for the KG side.
for k in metrics:
    if 'lang' in k:
        to_name = tokenizer.convert_ids_to_tokens
    else:
        to_name = label2entity.__getitem__
    metrics[k] = {
        to_name(label): sum(correct) / len(correct)
        for label, correct in list(metrics[k].items())
        if len(correct) > 0
    }

### Plot class-wise accuracy

In [None]:
# Accuracy threshold and the modality ('kg' or 'lang') whose classes are inspected
P = 0.25
modality = 'kg'

import matplotlib.pyplot as plt

# Keep only the classes whose accuracy is strictly below the threshold
cwa_key = '{}_cwa'.format(modality)
failure_case = dict()
for class_name, accuracy in metrics[cwa_key].items():
    if accuracy < P:
        failure_case[class_name] = accuracy

header = "# {} label under {} : [{}]\n".format(modality, P, len(failure_case))
print(header)
print(failure_case)

## 2. Visualize attention score per head

### Load visualization tool

In [None]:
%%javascript
// Register the CDN paths from which the `d3` and `jquery` modules are resolved
// (presumably via the RequireJS loader provided by the notebook front end — confirm).
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

from bertviz import head_view

def show_head_view(data_info, view_option='language'):
    """Render a bertviz head view for one sample.

    Args:
        data_info: dict with the keys 'lang_tokens', 'kg_tokens',
            'lang_attentions', 'cross_attentions' and 'lang_seq'
            (the structure returned by get_data_info_from_sample_for_viz).
        view_option: 'language' shows the language attentions over the language
            tokens. The historical misspelling 'langauge' is accepted as an
            alias so existing calls keep working. Any other value (e.g.
            'cross') shows the cross attentions over the language tokens
            followed by the KG tokens.

    Returns:
        None. The sentence is printed and the view is rendered by bertviz.
    """
    lang_tokens = data_info['lang_tokens']
    kg_tokens = data_info['kg_tokens']
    lang_attentions = data_info['lang_attentions']
    cross_attentions = data_info['cross_attentions']

    print(data_info['lang_seq'])

    # Previously only the misspelled 'langauge' was matched, so the default
    # value 'language' silently fell through to the cross-modal view.
    if view_option in ('language', 'langauge'):
        head_view(attention=lang_attentions, tokens=lang_tokens)
    else:
        # Language tokens come first; the KG tokens start at sentence_b_start
        all_tokens = lang_tokens + kg_tokens
        sentence_b_start = len(lang_tokens)
        head_view(attention=cross_attentions, tokens=all_tokens, sentence_b_start=sentence_b_start)

### Dataloader without masking

In [None]:
# Load model on CPU for prediction
model.cpu()

# `tokenizer.max_len` is deprecated in favour of `model_max_length` (and absent in
# transformers v4+); prefer the new attribute and fall back for old releases.
tokenizer_max_len = getattr(tokenizer, "model_max_length", None)
if tokenizer_max_len is None:
    tokenizer_max_len = tokenizer.max_len

if args.block_size <= 0:
    # Our input block size will be the max possible for the model
    args.block_size = tokenizer_max_len
else:
    args.block_size = min(args.block_size, tokenizer_max_len)

# Get datasets
dataset = (
    get_dataset(args, tokenizer=tokenizer, kg_pad=config.kg_special_token_ids["PAD"], evaluate=True)
)

# prediction=True: presumably switches the collator's masking off, as the section
# heading ("without masking") suggests — confirm in utils.data_collator.
data_collator = NodeClassification_DataCollator(tokenizer=tokenizer, kg_special_token_ids=config.kg_special_token_ids, kg_size = config.vocab_size['kg'],prediction=True)

# Get data loader (sequential order so a sample index always selects the same sample)
data_loader = DataLoader(
    dataset,
    sampler=SequentialSampler(dataset),
    batch_size=args.batch_size,
    collate_fn=data_collator,
    drop_last=True,
    pin_memory=True,
)

### Get attention score from sample

In [None]:
# Map KG input ids to entity names.
# entity2id.txt: the first line is skipped; every other line is "<entity>\t<id>".
# Ids are shifted by the number of KG special tokens and anything after '^^' in the
# entity string is dropped. The file is read inside `with` so the handle is closed.
with open(os.path.join(EXP_PATH, 'data', 'entity2id.txt')) as entity_file:
    entity_lines = entity_file.read().splitlines()[1:]
id2entity = {
    int(line.split('\t')[1]) + len(config.kg_special_token_ids): line.split('\t')[0].split('^^')[0]
    for line in entity_lines
}


def get_data_info_from_sample_for_viz(sample_idx=0,
                                      model=model,
                                      data_loader=data_loader,
                                      tokenizer=tokenizer,
                                      kg_id_mapping=id2entity):
    """Run the model on one batch and collect tokens and attentions for visualization.

    Args:
        sample_idx: position of the wanted batch in ``data_loader`` iteration order.
        model: model called with ``output_attentions=True, return_dict=True``.
        data_loader: iterable of batches containing 'lang_input_ids' and 'kg_input_ids'.
        tokenizer: tokenizer used to turn language ids back into tokens / text.
        kg_id_mapping: dict mapping a KG input id to its display name.

    Returns:
        dict with keys 'lang_input_len', 'lang_tokens', 'lang_seq', 'kg_input_len',
        'kg_tokens', 'lang_attentions' (list with one tensor per entry of the
        model's 'language_attentions') and 'cross_attentions' (list of tensors
        with a leading dimension of 1).

    Raises:
        IndexError: if ``data_loader`` yields no batch at position ``sample_idx``.
    """
    input_data = None
    for idx, data in enumerate(data_loader):
        if idx == sample_idx:
            input_data = data
            break
    if input_data is None:
        # Without this guard the code below failed with an unbound-variable error
        raise IndexError("sample_idx {} is out of range for the data loader".format(sample_idx))

    data_info = {}

    # lang part
    lang_temp_input_ids = input_data['lang_input_ids'].cpu().squeeze()
    # Number of non-zero ids; this is the unpadded length only if the pad id is 0
    # and padding is trailing — TODO confirm against the collator.
    lang_input_len = len(lang_temp_input_ids.nonzero())
    lang_tokens = tokenizer.convert_ids_to_tokens(lang_temp_input_ids[:lang_input_len].tolist())
    lang_seq = tokenizer.convert_tokens_to_string(lang_tokens)
    data_info['lang_input_len'] = lang_input_len
    data_info['lang_tokens'] = lang_tokens
    data_info['lang_seq'] = lang_seq

    # kg part (same non-zero-count length convention as the language part)
    kg_temp_input_ids = input_data['kg_input_ids'].cpu().squeeze()
    kg_input_len = len(kg_temp_input_ids.nonzero())
    kg_tokens = [kg_id_mapping[kg_id] for kg_id in kg_temp_input_ids[:kg_input_len].tolist()]
    data_info['kg_input_len'] = kg_input_len
    data_info['kg_tokens'] = kg_tokens

    # output attentions; inference only, so no autograd graph is needed
    with torch.no_grad():
        output_data = model(**input_data, output_attentions=True, return_dict=True)

    # Language attentions: crop both token axes to the language length
    data_info['lang_attentions'] = [l_layer[:, :,
                                            :lang_input_len,
                                            :lang_input_len].cpu() for l_layer in output_data['language_attentions']]

    # Cross attentions: crop both token axes to lang_len + kg_len and stack the layers on dim 0
    total_len = lang_input_len + kg_input_len
    temp_c_attentions = torch.cat([c_layer[:, :,
                                           :total_len,
                                           :total_len].cpu() for c_layer in output_data['cross_encoder_attentions']])
    # Reorder the last axis: the trailing lang_len columns first, then the leading kg_len columns.
    # NOTE(review): this presumes the model lays that axis out as [kg, lang] — confirm in model.py.
    temp_c1_attentions = temp_c_attentions[:, :, :, :kg_input_len]
    temp_c2_attentions = temp_c_attentions[:, :, :, -lang_input_len:]
    c_attentions = torch.cat([temp_c2_attentions, temp_c1_attentions], dim=-1)
    # One tensor per stacked entry, each with a leading dimension of 1
    data_info['cross_attentions'] = [c_attentions[layer_idx, :, :, :].unsqueeze(0)
                                     for layer_idx in range(c_attentions.shape[0])]

    return data_info

### Visualize attention

In [None]:
# Collect tokens and attention tensors for the batch at index 26 of the data loader
result = get_data_info_from_sample_for_viz(sample_idx=26)

#### Cross

In [None]:
# Any option other than the language one selects the cross-attention view
show_head_view(result, 'cross')

#### Language

In [None]:
# NOTE: 'langauge' (sic) is the spelling show_head_view compares against; keep the two in sync
show_head_view(result, 'langauge')

## <span style="color:skyblue">Supp 1. Visualize attention score in matrix form</span>

### Dataloader without masking

In [None]:
# Load model on CPU for prediction
model.cpu()

# `tokenizer.max_len` is deprecated in favour of `model_max_length` (and absent in
# transformers v4+); prefer the new attribute and fall back for old releases.
tokenizer_max_len = getattr(tokenizer, "model_max_length", None)
if tokenizer_max_len is None:
    tokenizer_max_len = tokenizer.max_len

if args.block_size <= 0:
    # Our input block size will be the max possible for the model
    args.block_size = tokenizer_max_len
else:
    args.block_size = min(args.block_size, tokenizer_max_len)

# Get datasets
dataset = (
    get_dataset(args, tokenizer=tokenizer, kg_pad=config.kg_special_token_ids["PAD"], evaluate=True)
)

# prediction=True: presumably switches the collator's masking off, as the section
# heading ("without masking") suggests — confirm in utils.data_collator.
data_collator = NodeClassification_DataCollator(tokenizer=tokenizer, kg_special_token_ids=config.kg_special_token_ids, kg_size = config.vocab_size['kg'],prediction=True)

# Get data loader (sequential order so SAMPLE_IDX always selects the same sample)
data_loader = DataLoader(
    dataset,
    sampler=SequentialSampler(dataset),
    batch_size=args.batch_size,
    collate_fn=data_collator,
    drop_last=True,
    pin_memory=True,
)

### Get attention score from sample

In [None]:
import matplotlib.pyplot as plt

# Index (in data loader order) of the batch to visualize
SAMPLE_IDX = 26

# Map KG input ids to entity names: skip the first line of entity2id.txt, shift ids by
# the number of KG special tokens, drop anything after '^^' in the entity string.
with open(os.path.join(EXP_PATH, 'data', 'entity2id.txt')) as entity_file:
    entity_lines = entity_file.read().splitlines()[1:]
id2entity = {
    int(line.split('\t')[1]) + len(config.kg_special_token_ids): line.split('\t')[0].split('^^')[0]
    for line in entity_lines
}

# Fetch the requested batch; fail loudly instead of reusing a stale or undefined `input_data`
input_data = None
for idx, data in enumerate(data_loader):
    if idx == SAMPLE_IDX:
        input_data = data
        break
if input_data is None:
    raise IndexError("SAMPLE_IDX {} is out of range for the data loader".format(SAMPLE_IDX))

print('==== Text ====')
lang_ids = input_data['lang_input_ids'].cpu().squeeze()
# The non-zero count is used as the sequence length (assumes pad id 0 with trailing padding — TODO confirm)
lang_tokens = tokenizer.convert_ids_to_tokens(lang_ids[:lang_ids.nonzero().shape[0]].tolist())
print('Token seq : {}\n'.format(lang_tokens))
print('Full sentence : {}\n\n'.format(tokenizer.convert_tokens_to_string(lang_tokens)))

print('==== KG ====')
kg_ids = input_data['kg_input_ids'].cpu().squeeze()
kg_tokens = [id2entity[kg_id] for kg_id in kg_ids[:kg_ids.nonzero().shape[0]].tolist()]
print('Token seq : {}\n'.format(kg_tokens))

# Inference only: no autograd graph is needed for the attention maps
with torch.no_grad():
    output_data = model(**input_data, output_attentions=True, return_dict=True)

### Cross modal attention

In [None]:
import matplotlib.pyplot as plt
# Grid of heat maps: one row per cross-modal layer, one column per attention head.
num_layers = model.config.x_layers
num_heads = model.config.num_attention_heads
# squeeze=False keeps `axs` two-dimensional even with a single layer or a single head
fig, axs = plt.subplots(num_layers, num_heads, figsize=(20,20), squeeze=False)
for layer_idx in range(num_layers):
    # Rows are filled bottom-up: layer 0 lands on the last row. The row index must be
    # derived from x_layers (the grid height), not l_layers.
    row = num_layers - 1 - layer_idx
    for head_idx in range(num_heads):
        ax = axs[row, head_idx]
        ax.matshow(
            output_data['cross_encoder_attentions'][layer_idx][0,head_idx,
                                                         :len(lang_tokens),:len(kg_tokens)].cpu().detach().numpy(),
            #cmap='gray',
        )
        if (row == 0) and (head_idx == 0):
            # Label only the top-left panel. This is done AFTER matshow on that panel,
            # because matshow installs its own tick locators; ticks are fixed before
            # the labels so that each label is attached to its own tick.
            ax.set_xticks(range(len(kg_tokens)))
            ax.set_yticks(range(len(lang_tokens)))
            ax.set_xticklabels(kg_tokens, rotation=45, horizontalalignment='left', fontsize=8)
            ax.set_yticklabels(lang_tokens, rotation=45, horizontalalignment='right', fontsize=8)
        axs[0, head_idx].set_title("Head_{}\n".format(head_idx+1))
    # 1-based layer number of the layer drawn on this row
    axs[row, 0].set_ylabel('Layer_{}\n'.format(layer_idx + 1))
plt.suptitle('Cross Modal Attention Vis', fontsize=40)

### Language attention

In [None]:
import matplotlib.pyplot as plt
# Grid of heat maps: one row per language layer, one column per attention head.
num_layers = model.config.l_layers
num_heads = model.config.num_attention_heads
# squeeze=False keeps `axs` two-dimensional even with a single layer or a single head
fig, axs = plt.subplots(num_layers, num_heads, figsize=(18,20), squeeze=False)
for layer_idx in range(num_layers):
    # Rows are filled bottom-up: layer 0 lands on the last row
    row = num_layers - 1 - layer_idx
    for head_idx in range(num_heads):
        ax = axs[row, head_idx]
        ax.matshow(
            output_data['language_attentions'][layer_idx][0,head_idx,
                                                          :len(lang_tokens),
                                                          :len(lang_tokens)].cpu().detach().numpy(),
            #cmap='gray',
        )
        if (row == 0) and (head_idx == 0):
            # Label only the top-left panel. This is done AFTER matshow on that panel,
            # because matshow installs its own tick locators and would otherwise
            # detach the token labels from their ticks.
            ax.set_xticks(range(len(lang_tokens)))
            ax.set_yticks(range(len(lang_tokens)))
            ax.set_xticklabels(lang_tokens, rotation=45, horizontalalignment='left', fontsize=8)
            ax.set_yticklabels(lang_tokens, rotation=45, horizontalalignment='right', fontsize=8)
        axs[0, head_idx].set_title("Head_{}\n".format(head_idx+1))
    # 1-based layer number of the layer drawn on this row
    axs[row, 0].set_ylabel('Layer_{}\n'.format(layer_idx + 1))
plt.suptitle('Language Attention Vis', fontsize=40)