In [1]:
import sys

sys.path.insert(0, '..')

In [2]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import torch.nn.functional as f
from random import randint
from rouge import Rouge
from colr import color

from main.common.common import *
from main.common.vocab import *
from main.common.simple_vocab import SimpleVocab
from main.common.util.file_util import FileUtil
from main.data.cnn_dataloader import *
from main.seq2seq import Seq2Seq
from main.common.glove.embedding import GloveEmbedding

In [3]:
def generate_attention_heatmap_text(words, attn_scores):
    c = (255, 165, 0)
    
    att_words = []
    for idx, word in enumerate(words):
        attn_score = attn_scores[idx]
        attn_color = get_color(c, attn_score)
        
        att_words.append(color(word, back=attn_color))
        
    return ' '.join(att_words)

def get_color(color, opacity):
    r = 255 - opacity * (255 - color[0])
    g = 255 - opacity * (255 - color[1])
    b = 255 - opacity * (255 - color[2])
    return (r, g, b)

def show_attention_heatmap(article, summary, attention):
    attention = attention[:, :-1]
    
    # figure
    figure = plt.figure(figsize=(20, 5))
    ax = figure.add_subplot(111)
    
    cax = ax.matshow(attention, cmap='bone')
    figure.colorbar(cax)
    
    # set up axes
    ax.set_xticklabels([''] + article + ['[STOP]'], rotation=90)
    ax.set_yticklabels([''] + summary)

    # show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    
    plt.show()
    plt.close()
    
def get_score(summary, reference):
    rouge = Rouge()
    
    summary = summary.split()        
    summary = [w for w in summary if w != TK_STOP['word']]
    
    score = rouge.get_scores(' '.join(summary), reference)[0]["rouge-l"]["f"]

    return score

In [4]:
AppContext()

vocab = SimpleVocab(FileUtil.get_file_path(conf('vocab-file')), conf('vocab-size'))

embedding = GloveEmbedding(FileUtil.get_file_path(conf('emb-file')), vocab) if conf('emb-file') is not None else None

seq2seq = cuda(Seq2Seq(vocab, embedding))

checkpoint = t.load(FileUtil.get_file_path(conf('model-file')))

seq2seq.load_state_dict(checkpoint['model_state_dict'])

seq2seq.eval()

data_loader =  CNNDataLoader(FileUtil.get_file_path(conf('train:article-file')),
                             FileUtil.get_file_path(conf('train:summary-file')),
                             FileUtil.get_file_path(conf('train:keyword-file')),
                             conf('train:batch-size'))

DEBUG:SimpleVocab:initialize vocabulary from: /home/vivien/PycharmProjects/kw-txt-summarization/data/train/vocab.txt


In [21]:
samples = data_loader.read_all()

article, keyword, reference = samples[3]

keyword = keyword[0]
reference = reference[0]
    
summary, attention = seq2seq.evaluate(article, keyword)

score = get_score(summary, reference)

print('>>> article: ', article)
print('>>> keyword: ', keyword)
print('========================')
print('>>> reference: ', reference)
print('>>> prediction: ', summary)
print('>>> score: ', score)

article_words = article.split()
summary_words = summary.split()
attention = attention.cpu()

#show_attention_heatmap(article_words, summary_words, attention)

article_words_attention = t.sum(attention, dim=0) / t.max(attention)
article_words_attention = t.clamp(article_words_attention, 0, 1)

heatmap_text = generate_attention_heatmap_text(article_words, article_words_attention)

print('>>> article with heatmap: ', heatmap_text)


>>> article:  ( cnn ) -- qatar plans to build nine fully air - conditioned open - air stadiums to stage matches at the 2022 fifa world cup . click through the gallery above to see how the stadiums will look .
>>> keyword:  
>>> reference:  qatar hopes to host the 2022 world cup in temperatures of over 40 c . it plans to use solar power to air condition its stadiums . qatar will be the first country in the middle east to stage the event .
>>> prediction:  godfather number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number number
>>> score:  0.0
>>> article with heatmap:  [48;2;255;165;0m([0m [48;2;255;165;0mcnn[0m [48;2;255;165;0m)[0m [48;2;255;165;0m--[0m [48;2;255;165;0mqatar[0m [48;2;255;165;0mplans