In [None]:
import sys

# Put the parent directory on the import path so the project-local `main.*`
# packages imported in the next cell resolve. NOTE(review): assumes the notebook
# is launched from a direct subdirectory of the project root — confirm.
sys.path.insert(0, '..')

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import torch.nn.functional as f
from random import randint
from rouge import Rouge
from colr import color

from main.common.common import *
from main.common.vocab import *
from main.common.simple_vocab import SimpleVocab
from main.common.util.file_util import FileUtil
from main.data.cnn_dataloader import *
from main.seq2seq import Seq2Seq
from main.common.glove.embedding import GloveEmbedding

In [None]:
def generate_attention_heatmap_text(words, attn_scores, dec=False):
    """Render ``words`` as one string, tinting each word's background by its attention score.

    Args:
        words: sequence of tokens to render.
        attn_scores: indexable sequence of attention weights (list, tensor, ...).
            A score of 0 gives a white background and 1 gives full orange;
            ``evaluate`` clamps the scores to [0, 1] before calling this.
        dec: when truthy, ``words`` is the generated summary and the scores are
            offset by one (``attn_scores[k]`` belongs to ``words[k + 1]``), so the
            first word has no score and is rendered without a background colour.

    Returns:
        The (colour-escaped) words joined by single spaces.
    """
    base_color = (255, 165, 0)  # orange at full attention

    att_words = []
    for idx, word in enumerate(words):
        if dec and idx == 0:
            # No score for the first decoder word; keep it so the printed
            # summary is complete, just untinted.
            att_words.append(word)
            continue

        # float() turns 0-d tensors / numpy scalars into a plain number before
        # the colour arithmetic and the hand-off to colr.
        attn_score = float(attn_scores[idx - 1 if dec else idx])
        att_words.append(color(word, back=get_color(base_color, attn_score)))

    return ' '.join(att_words)

def get_color(color, opacity):
    """Blend ``color`` toward white by ``1 - opacity``.

    An opacity of 0 yields white (255, 255, 255), an opacity of 1 yields the
    first three channels of ``color`` unchanged, and values in between
    interpolate linearly. Only ``color[0]``..``color[2]`` are read; the result
    is an (r, g, b) tuple whose components may be non-integers.
    """
    return tuple(255 - opacity * (255 - color[channel]) for channel in range(3))

def show_attention_heatmap(article, summary, attention):
    """Plot an attention matrix as a heatmap labelled with article and summary tokens.

    Args:
        article: article tokens; they label the columns, followed by '[STOP]'.
        summary: summary tokens; they label the rows.
        attention: 2-D array-like (e.g. tensor or ndarray) of attention weights,
            plotted with rows against ``summary`` and columns against ``article``.
            Its last column is dropped before plotting.
            NOTE(review): presumably that column is an extra/padding position — confirm.
    """
    attention = attention[:, :-1]
    n_rows, n_cols = attention.shape[0], attention.shape[1]

    figure, ax = plt.subplots(figsize=(20, 5))

    cax = ax.matshow(attention, cmap='bone')
    figure.colorbar(cax)

    # Pin exactly one tick per cell and supply exactly one label per tick.
    # Labelling automatically located ticks (the old ''-prefix + MultipleLocator
    # trick) depends on where the locator happens to start its ticks.
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(_fit_labels(list(article) + ['[STOP]'], n_cols), rotation=90)
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(_fit_labels(summary, n_rows))

    plt.show()
    plt.close(figure)


def _fit_labels(labels, count):
    """Truncate ``labels`` to ``count`` entries, padding with '' if there are fewer."""
    fitted = list(labels)[:count]
    return fitted + [''] * (count - len(fitted))
    
def get_score(summary, reference):
    """Return the ROUGE-L F score of a generated summary against a reference.

    Tokens from the first STOP token (``TK_STOP['word']``) onwards are ignored.

    Args:
        summary: generated summary text (whitespace-separated tokens).
        reference: reference summary text.

    Returns:
        The ROUGE-L F score. Returns 0.0 when nothing precedes the STOP token,
        because ``rouge`` raises ValueError for an empty hypothesis.
    """
    words = summary.split()

    stop_word = TK_STOP['word']
    if stop_word in words:
        words = words[:words.index(stop_word)]

    if not words:
        # e.g. the model emitted STOP as its very first token
        return 0.0

    rouge = Rouge()
    return rouge.get_scores(' '.join(words), reference)[0]["rouge-l"]["f"]

def evaluate(model, example):
    """Run ``model`` on one example and print an attention-annotated report.

    Prints, in order: the article tinted by encoder attention, the keyword, the
    reference summary, the generated summary (tinted by intra-decoder attention
    when the model returns it), and the ROUGE-L F score.

    Args:
        model: object whose ``evaluate(article, keyword)`` returns
            ``(summary, enc_attention, dec_attention)``, where ``summary`` is text
            and ``dec_attention`` may be None.
        example: ``(article, keyword, reference)`` triple; ``article`` and
            ``reference`` are text, ``keyword`` is printed as-is.
    """
    article, keyword, reference = example

    summary, enc_attention, dec_attention = model.evaluate(article, keyword)
    score = get_score(summary, reference)

    article_words = article.split()
    summary_words = summary.split()

    enc_heatmap_text = generate_attention_heatmap_text(article_words, _average_attention(enc_attention))

    if dec_attention is not None:
        generated_text = generate_attention_heatmap_text(summary_words, _average_attention(dec_attention), dec=True)
    else:
        # Print the summary as text, not as the repr of a Python list.
        generated_text = ' '.join(summary_words)

    print()
    _print_section('Article', enc_heatmap_text)
    _print_section('Keyword', keyword)
    _print_section('Reference Summary', reference)
    _print_section('Generated Summary', generated_text)
    _print_section('Rouge-L', '%.3f' % score)


def _average_attention(attention):
    """Move ``attention`` to CPU, average it over its first dimension and clamp to [0, 1]."""
    attention = attention.cpu()
    return t.clamp(t.sum(attention, dim=0) / len(attention), 0, 1)


def _print_section(title, body):
    """Print ``title`` in bold (ANSI escape), then ``body``, then a blank line."""
    print('\033[1m' + title + '\033[0m')
    print(body)
    print()

In [None]:
# Load the evaluation config, rebuild the model and restore its trained weights.
AppContext('main/conf/eval/config.yml')

vocab = SimpleVocab(FileUtil.get_file_path(conf('vocab-file')), conf('vocab-size'))

# GloVe embeddings are optional: only loaded when 'emb-file' is configured.
emb_file = conf('emb-file')
embedding = GloveEmbedding(FileUtil.get_file_path(emb_file), vocab) if emb_file is not None else None

seq2seq = cuda(Seq2Seq(vocab, embedding))

# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host;
# load_state_dict then copies the tensors into the model's own parameters.
# NOTE(review): t.load unpickles the file — only load checkpoints you trust.
checkpoint = t.load(FileUtil.get_file_path(conf('eval:load-model-file')), map_location='cpu')

seq2seq.load_state_dict(checkpoint['model_state_dict'])

# evaluation mode (disables training-only layer behaviour such as dropout)
seq2seq.eval()

data_loader = CNNDataLoader(FileUtil.get_file_path(conf('eval:article-file')),
                            FileUtil.get_file_path(conf('eval:summary-file')),
                            FileUtil.get_file_path(conf('eval:keyword-file')),
                            conf('eval:batch-size'), 'eval')

In [None]:
# Read every evaluation example up front; evaluate() unpacks each one as
# (article, keyword, reference).
samples = data_loader.read_all()

In [None]:
# Evaluate one randomly chosen example; re-run this cell to see another.
assert samples, 'no evaluation samples were loaded'

idx = randint(0, len(samples) - 1)
example = samples[idx]

print('idx: ', idx)

evaluate(seq2seq, example)