In [1]:
from pathlib import Path

from bertviz import head_view

import torch

In [82]:
from model.swan import SWANPhase1Only
from model.base import chkpt
from test_model import load_config, run_model_for_attention
from common.dataset import Dataset
from learner import *

In [3]:
# Project-relative locations used by the rest of the notebook:
# the working directory, the resource folder and the checkpoint folder.
main_path = Path('.')
data_path = main_path.joinpath('resource')
chpt_path = main_path.joinpath('runs_copy', 'best_SWAN_P1')

In [4]:
# Load the serialized tokenizer and the model checkpoint from the checkpoint
# directory.
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# PyTorch version/arguments — only load files from a trusted source.
tokenizer = torch.load(chpt_path / 'tokenizer.pt')
# NOTE(review): `checkpoint` is not referenced by any later cell visible here.
checkpoint = torch.load(chpt_path / 'SWANPhase1Only.pt')

In [None]:
# Read the run configuration stored with the checkpoint; it is unpacked as
# keyword arguments when building the model below.
config = load_config(chpt_path)
# Build the model from the checkpoint directory (path is passed as a string).
nvix = SWANPhase1Only.create_or_load(path=str(chpt_path), **config)
# Presumably torch.nn.Module.eval(), i.e. switch to inference mode — TODO confirm.
nvix.eval()

In [None]:
# Keep the model's state dict around and list its entry names for inspection.
state_dict = nvix.state_dict()
state_dict.keys()

In [66]:
# Build the dataset object from the PEN dataset JSON file.
dataset_file = data_path / 'dataset' / 'pen.json'
# NOTE(review): number_window=3 equals the 'window' entry shown in the recorded
# config output; presumably the same setting — TODO confirm and consider
# reading it from config instead of hard-coding.
test_data = Dataset(dataset_file, number_window=3)

In [67]:
# Presumably restricts test_data to the items listed in the test split file
# (inferred from the method name) — TODO confirm.
test_data.select_items_with_file(data_path / 'experiments'/ 'pen' /'test')

In [75]:
# Sanity check: how many items test_data currently holds.
test_data.num_items

365

In [85]:
# Fix the random seed from the run configuration before generating.
# NOTE(review): set_seed is not imported by name in the visible cells;
# presumably it comes from `from learner import *` — TODO confirm.
set_seed(config['seed'])
# Request minibatches of size 1 in testing mode; the result is indexed below,
# so `batch` is a sequence of minibatches rather than a single one.
batch = test_data.get_minibatches(batch_size=1, for_testing=True)
# Run the model on the text of the first minibatch only, moved to the model's
# device, with the beam sizes taken from the configuration.
output = nvix(
    text=batch[0].text.to(nvix.device),
    beam=config['beam_for_equation'], 
    beam_expl=config['beam_for_explanation']
)

In [86]:
# Display the raw model output from the forward pass.
output

{'eqn_ignore': {1},
 'var_expl': [Label([[102, 1996, 3091, 1997, 2051, 1997, 27244, 2075, 102, -1, -1, -1, -1]])],
 'num_expl': [Label([[102, 1996, 3292, 1997, 27244, 2075, 1999, 7338, 102, -1, -1], [102, 1996, 3177, 1997, 27244, 2075, 1999, 7338, 2566, 3178, 102]])],
 'explanation': [Explanation(numbers=$[Label([[102, 1996, 3292, 1997, 27244, 2075, 1999, 7338, 102, -1, -1], [102, 1996, 3177, 1997, 27244, 2075, 1999, 7338, 2566, 3178, 102]])], variables=$[Label([[102, 1996, 3091, 1997, 2051, 1997, 27244, 2075, 102, -1, -1, -1, -1]])], worker=$0)],
 'equation': Equation(operator=Label([[0]]), operands=[Label([[-1]]), Label([[-1]])])}

In [38]:
import numpy
import matplotlib.pyplot as plt

# http://stackoverflow.com/questions/14391959/heatmap-in-matplotlib-with-pcolor
def plot_head_map(mma, target_labels, source_labels):
    """Draw a 2-D weight matrix as a blue heatmap and show it.

    Rows of ``mma`` are labelled with ``target_labels`` and columns with
    ``source_labels``. The first row is drawn at the top, and the column
    labels sit along the top edge, rotated by 45 degrees.

    Args:
        mma: 2-D array-like exposing ``.shape``; handed to ``Axes.pcolor``.
        target_labels: Tick labels for the rows (y axis).
        source_labels: Tick labels for the columns (x axis).

    Returns:
        None. The figure is rendered through ``plt.show()``.
    """
    _, ax = plt.subplots()
    ax.pcolor(mma, cmap=plt.cm.Blues)

    n_rows = mma.shape[0]  # one row per entry of target_labels
    n_cols = mma.shape[1]  # one column per entry of source_labels

    # Columns: ticks centred on each cell, limits clamped to the matrix width
    # (otherwise an extra blank column can appear), labels along the top edge.
    ax.set_xticks(0.5 + numpy.arange(n_cols))
    ax.set_xlim(0, int(n_cols))
    ax.xaxis.tick_top()
    ax.set_xticklabels(source_labels)

    # Rows: same treatment, then flip the axis so row 0 is at the top,
    # giving a table-like layout.
    ax.set_yticks(0.5 + numpy.arange(n_rows))
    ax.set_ylim(0, int(n_rows))
    ax.invert_yaxis()
    ax.set_yticklabels(target_labels)

    plt.xticks(rotation=45)

    plt.show()

    
def read_plot_alignment_matrices(source_labels, target_labels, alpha):
    """Convert a weight tensor to a NumPy array and plot it as a heatmap.

    Args:
        source_labels: Column labels, forwarded to ``plot_head_map``.
        target_labels: Row labels, forwarded to ``plot_head_map``.
        alpha: Tensor-like object supporting ``.detach().cpu().numpy()``
            (presumably a 2-D ``torch.Tensor`` — TODO confirm with callers).

    Returns:
        None. The plot is shown by ``plot_head_map``.
    """
    # detach() replaces the legacy ``.data`` attribute: it yields the same
    # values without autograd tracking, and is the supported way to do so.
    mma = alpha.detach().cpu().numpy()

    # This function takes (source, target) while plot_head_map takes
    # (target, source); pass by keyword so the swap is explicit.
    plot_head_map(mma, target_labels=target_labels, source_labels=source_labels)

In [48]:
# NOTE(review): this cell fails as written — its recorded output is a
# ValueError from head_view about the attention tensor's dimensions.
# The tensor fetched here is a state_dict entry whose key names a LayerNorm
# weight, i.e. a stored model parameter rather than attention produced by a
# forward pass. head_view presumably needs the attentions returned by the
# model with output_attentions=True (as the error message suggests) — TODO
# confirm against the bertviz documentation.
attn = state_dict['explanation.encoder.layer.11.output.LayerNorm.weight']
# Not the last expression of the cell, so this value is not displayed.
attn.shape
# NOTE(review): encoded_input is assigned in a cell that appears further down
# in the notebook; run top-to-bottom on a fresh kernel it is undefined here.
head_view(attn, encoded_input)

ValueError: The attention tensor does not have the correct number of dimensions. Make sure you set output_attentions=True when initializing your model.

In [29]:
# Example problem text to encode.
# NOTE(review): the string already contains '##' sub-word markers, so it looks
# like previously tokenized text being encoded again — TODO confirm this is
# intended.
input_str = "the sears tower in chicago is 145 ##0 feet tall . the john hancock center in chicago is 112 ##7 feet tall . suppose you are asked to build a small - scale replica of each . if you make the sears tower 3 meter tall , what would be the approximate height of the john hancock replica ? round your answer to the nearest hundred ##th ."
encoded_input = tokenizer.encode(input_str)
# Show how many ids the encoding produced.
len(encoded_input)

76

In [32]:
# Example explanation text, encoded with the same tokenizer as the input.
output_str = "the height of the sears tower"
encoded_output = tokenizer.encode(output_str)
# Show how many ids the encoding produced.
len(encoded_output)

8

In [27]:
# Display the loaded run configuration.
# NOTE(review): the recorded output includes absolute home-directory paths;
# consider clearing it before sharing the notebook.
config

{'seed': 1,
 'batch_size': 16,
 'beam_for_equation': 3,
 'beam_for_explanation': 5,
 'dataset': '/home/bydelta/SimpleFESTA/resource/dataset/pen.json',
 'learner': {'model': 'SWAN_P1',
  'encoder': 'google/electra-base-discriminator',
  'equation': {'hidden_dim': 0, 'intermediate_dim': 0, 'layer': 6, 'head': 0},
  'explanation': {'encoder': 'google/electra-base-discriminator',
   'shuffle': True}},
 'resource': {'num_gpus': 1.0, 'num_cpus': 1.0},
 'experiment': {'dev': {'split_file': '/home/bydelta/SimpleFESTA/resource/experiments/pen/dev',
   'period': 100},
  'train': {'split_file': '/home/bydelta/SimpleFESTA/resource/experiments/pen/train'},
  'test': {'split_file': '/home/bydelta/SimpleFESTA/resource/experiments/pen/test',
   'period': 500}},
 'grad_clip': 10.0,
 'optimizer': {'type': 'lamb',
  'lr': 0.00176,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'debias': True},
 'scheduler': {'type': 'warmup-linear',
  'num_warmup_epochs': 10.0,
  'num_total_epochs': 500},
 'window': 3}

In [7]:
# Location of the PEN dataset JSON file.
dataset_file = data_path / 'dataset' / 'pen.json'

# Keys of the keyword arguments handed to the experiment runner.
KEY_DATASET = 'dataset'
KEY_SEED = 'seed'

# Base experiment arguments: absolute dataset path (as a string) plus the
# seed taken from the loaded run configuration.
exp_base = {}
exp_base[KEY_DATASET] = str(dataset_file.absolute())
exp_base[KEY_SEED] = config[KEY_SEED]

In [None]:
# Run the project's attention-extraction helper on the checkpoint directory,
# passing the dataset path and seed from exp_base as keyword arguments.
# The helper returns a pair; the meaning of each element is defined in
# test_model — TODO document once confirmed.
test_check, test_results = run_model_for_attention(chpt_path, **exp_base)

In [None]:
# Display the second value returned by run_model_for_attention.
test_results