In [1]:
from modules.data import Documents
import modules.beam_search as bs

doc_file = './data/kaggle_news_rouge1.pkl'
docs = Documents(doc_file, vocab_size = 30000)
vocab = docs.vocab

In [2]:
from modules.texts import Vocab, GloVeLoader
import os
from os.path import join
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import modules.extractive as ext
import modules.abstractive as abs
import modules.beam_search as bs
from modules.data import Documents
from torch.utils.data import DataLoader
import numpy as np

d = 200
emb = nn.Embedding(vocab.V, d)
pretrained = emb
vocab_size = vocab.V
emb_size = emb.weight.data.size(1)
n_kernels = 50
kernel_sizes = [1,2,3,4,5]
sent_size = len(kernel_sizes) * n_kernels
hidden_size = 400
num_layers = 1
n_classes = len(docs.dclass2id)
batch_size = 1
torch.manual_seed(7)
torch.cuda.manual_seed(7)

ext_s_enc = ext.SentenceEncoder(vocab_size, emb_size,
                                   n_kernels, kernel_sizes, pretrained)
ext_d_enc = ext.DocumentEncoder(sent_size, hidden_size)
ext_extc = ext.ExtractorCell(sent_size, hidden_size)
ext_d_classifier = ext.DocumentClassifier(sent_size, n_classes)
abs_enc = abs.EncoderRNN(emb, hidden_size, num_layers)
abs_dec = abs.AttnDecoderRNN(emb, hidden_size * 2, num_layers)

# Load the models
prefix = 'hidden400nkernels50epoch10_'
data_dir = join(os.path.expanduser('~'), 'cs671-large')
models = [emb, ext_s_enc, ext_d_enc, ext_extc, ext_d_classifier, abs_enc, abs_dec]
names = ['emb', 'ext_s_enc','ext_d_enc','ext_extc','ext_d_classifier','abs_enc','abs_dec']
for name, model in zip(names, models):
    model.load_state_dict(torch.load(join(data_dir, prefix + name)))

In [5]:
from rouge import Rouge

rouge = Rouge()

examples = []
n_samples = 3000
for i, doc in enumerate(docs):
    # start from testing samples
    if i < n_samples: pass
    try:
        test_input = [torch.LongTensor(sent).view(1,-1) for sent in doc.sents]
        ref_input = torch.LongTensor(doc.head).view(1,-1)
        top1_batch, seqs_batch = bs.generate_title(doc_sents = test_input,
                                                   beam_size = 20,
                                                   models = models,
                                                   max_kernel_size = max(ext_s_enc.kernel_sizes))
        for i in range(len(top1_batch)):
            generated = '_BEGIN_ ' + vocab.id2sents([top1_batch[i][0]])
            title = vocab.id2sents([ref_input[i]])
            contents = vocab.id2sents(doc.sents)
            rouge_batch = rouge.get_scores(generated, title)[0]['rouge-1']['r']
            examples.append((generated, title, contents, rouge_batch))
    
    except Exception as e:
        pass
#         print(e)

In [6]:
f = open('./data/generated_examples_epoch100_recall.txt', 'w')
for generated, title, contents, r_score in sorted(examples, key = lambda x: -x[-1]):
#     if 'firewall' in reference:
#         print(generated)
#         print(reference)
#         print(r_score)
    print(generated)
    print(title)
    print(r_score)
    s = '%s\n%s\n%s\n%.3f\n' % (generated, title, contents, r_score)
    f.write(s)
f.close()

import pickle
with open('./data/generated_examples_epoch100_recall.pkl', 'wb') as f:
    pickle.dump(examples, f)

_BEGIN_ ' fitbits :# reports _END_
_BEGIN_ bmw denies reports of emissions UNK _END_
0.375
_BEGIN_ mandatorily ' _END_
_BEGIN_ amul releases poster on ' dangal ' _END_
0.375
_BEGIN_ ' officer its reports _END_
_BEGIN_ tharoor ' categorically ' denies reports of him joining bjp _END_
0.36363636363636365
_BEGIN_ ' navtej _END_
_BEGIN_ aamir khan ' s ' dangal ' leaked online _END_
0.3333333333333333
_BEGIN_ ' navtej _END_
_BEGIN_ mumbai ' s juhu beach to be redeveloped _END_
0.3
_BEGIN_ ' navtej _END_
_BEGIN_ woman sexually assaulted by ' superhost ' sues airbnb _END_
0.3
_BEGIN_ mandatorily ' _END_
_BEGIN_ release date of ' wonder woman ' sequel announced _END_
0.3
_BEGIN_ mandatorily ' _END_
_BEGIN_ kargil was india ' s first televised war _END_
0.3
_BEGIN_ gondia ' labs -? dayanand _END_
_BEGIN_ ' lipstick under my burkha ' hits the theatres _END_
0.3
_BEGIN_ wolfsburg ' _END_
_BEGIN_ makeshift ' illegal ' shops demolished by noida authority _END_
0.3
_BEGIN_ ' navtej _END_
_BEGIN_ tru

0.2
_BEGIN_ mandatorily ' _END_
_BEGIN_ indian fisherman shot dead by sri lankan navy _END_
0.2
_BEGIN_ ' jail shroff ' _END_
_BEGIN_ 8 indian fishermen arrested by sri lankan officials _END_
0.2
_BEGIN_ mandatorily ' _END_
_BEGIN_ indian - origin businessman shot dead in us _END_
0.2
_BEGIN_ ' navtej _END_
_BEGIN_ don ' t drop catches else smith will score a ton : clarke _END_
0.2
_BEGIN_ wolfsburg ' _END_
_BEGIN_ govt scraps bengaluru ' s ? 1 , 761 cr steel flyover project _END_
0.2
_BEGIN_ ' navtej _END_
_BEGIN_ reliance jio halts prime membership due to overload _END_
0.2
_BEGIN_ ' navtej _END_
_BEGIN_ delhi metro to introduce music on airport line _END_
0.2
_BEGIN_ ' navtej _END_
_BEGIN_ rss leader announces ? 1 cr bounty on kerala cm ' s head _END_
0.2
_BEGIN_ ' navtej _END_
_BEGIN_ flexible material made using water tougher than steel _END_
0.2
_BEGIN_ ' alright ' ryanair reports _END_
_BEGIN_ female engineer accuses tesla of sexism , harassment _END_
0.2
_BEGIN_ wolfsburg ' _EN

0.16666666666666666
_BEGIN_ bjp4delhi ' _END_
_BEGIN_ if i can gain from launching srk ' s son , why won ' t i : karan _END_
0.16666666666666666
_BEGIN_ ' trans :# reports _END_
_BEGIN_ bcci treasurer objected to cash award for team : bcci panel _END_
0.16666666666666666
_BEGIN_ ' navtej _END_
_BEGIN_ molestation case against tvf ceo to be closed : report _END_
0.16666666666666666
_BEGIN_ mandatorily ' _END_
_BEGIN_ ajinkya rahane scratchy , down on confidence : sourav ganguly _END_
0.16666666666666666
_BEGIN_ bjp4delhi ' _END_
_BEGIN_ gilgit - baltistan belongs to india , says british parliament _END_
0.16666666666666666
_BEGIN_ wolfsburg ' _END_
_BEGIN_ gujarat female teacher shows porn to students , dances naked _END_
0.16666666666666666
_BEGIN_ srinu ' _END_
_BEGIN_ kejriwal to face trial in defamation case filed by jaitley _END_
0.16666666666666666
_BEGIN_ ' navtej _END_
_BEGIN_ ips officer suspended for tweeting against yogi govt in up _END_
0.16666666666666666
_BEGIN_ ' navtej _

_BEGIN_ wolfsburg ' _END_
_BEGIN_ not using patanjali can make one anti - national now : kanhaiya _END_
0.14285714285714285
_BEGIN_ ' navtej _END_
_BEGIN_ driver stops train midway for 2 hrs to take bath in bihar _END_
0.14285714285714285
_BEGIN_ mandatorily ' _END_
_BEGIN_ show that lets men guess if woman is fat or pregnant slammed _END_
0.14285714285714285
_BEGIN_ amlekhganj shahbad _END_
_BEGIN_ kejriwal wants evm to work like sisodia does for him : tiwari _END_
0.14285714285714285
_BEGIN_ wolfsburg ' _END_
_BEGIN_ new jio offer classic case of old wine in new bottle : airtel _END_
0.14285714285714285
_BEGIN_ williams bc _END_
_BEGIN_ 137 police personnel for every 1 lakh people in india : govt _END_
0.14285714285714285
_BEGIN_ ' navtej _END_
_BEGIN_ hc reduces sentence saying man wanted to burn , not kill wife _END_
0.14285714285714285
_BEGIN_ ' navtej _END_
_BEGIN_ trump backs usa for joint 2026 wc bid with canada , mexico _END_
0.14285714285714285
_BEGIN_ ' navtej _END_
_BEGIN_ 