In [1]:

import numpy as np
import torch

from tqdm import tqdm
from torch.utils.data import DataLoader
from rouge_score import rouge_scorer
from nltk.tokenize import sent_tokenize as nltk_sent_tokenize

from model import Model, torch_load_all
from config import CNNConfig
from data_utils import ESDataset, collate_fn

# Shared ROUGE scorer used by calculateSimilarity; built once at import time
# with stemming enabled for the three ROUGE variants listed below.
scorer = rouge_scorer.RougeScorer(
    ['rouge1', 'rouge2', 'rougeLsum'],
    use_stemmer=True,
)

def sent_tokenize(doc):
    """Split a raw text string into sentences.

    Thin wrapper that forwards ``doc`` unchanged to NLTK's ``sent_tokenize``
    and returns whatever it returns.
    """
    sentences = nltk_sent_tokenize(doc)
    return sentences

def get_test_loader(tokenizer, docs, config):
    """Tokenize every document and wrap the encodings in a DataLoader.

    Args:
        tokenizer: Callable tokenizer; invoked once per document.
        docs: Iterable of documents. Each document is sliced to its first
            ``config.MAX_DOC_LEN`` items before tokenization (in this
            notebook a document is a list of sentences).
        config: Object providing ``MAX_DOC_LEN`` and ``MAX_SEQ_LEN``.

    Returns:
        A ``DataLoader`` over an ``ESDataset`` of the encodings that yields
        one document per batch, in input order.
    """
    tokenizer_kwargs = dict(
        truncation=True,
        max_length=config.MAX_SEQ_LEN,
        padding='max_length',
    )
    encoded_docs = [
        tokenizer(document[:config.MAX_DOC_LEN], **tokenizer_kwargs)
        for document in docs
    ]

    # batch_size=1 and shuffle=False keep a one-to-one, order-preserving
    # mapping between input documents and yielded batches.
    return DataLoader(
        ESDataset(encoded_docs),
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn,
    )

def get_all_probs(model, tokenizer, all_doc, config):
    """Run the model over every document and collect its per-document output.

    Inference is performed with gradient tracking disabled and with the model
    temporarily switched to evaluation mode; the model's previous
    training/eval mode is restored before returning, even on error.

    Args:
        model: ``torch.nn.Module`` called as
            ``model(input_ids, document_mask, attention_mask)``.
        tokenizer: Tokenizer forwarded to ``get_test_loader``.
        all_doc: Iterable of documents forwarded to ``get_test_loader``.
        config: Object providing ``device`` plus whatever
            ``get_test_loader`` reads from it.

    Returns:
        A list with one entry per batch yielded by the loader; each entry is
        ``model(...)[0]`` converted to plain Python lists via ``.tolist()``.
    """
    res = []
    test_loader = get_test_loader(tokenizer, all_doc, config=config)

    # Remember the current mode so this helper has no lasting side effect.
    was_training = model.training
    model.eval()
    try:
        # No autograd graph is needed for inference; skipping it saves memory.
        with torch.no_grad():
            for item in tqdm(test_loader):
                ids = item['input_ids'].to(config.device)
                document_mask = item['document_mask'].to(config.device)
                attention_mask = item['attention_mask'].to(config.device)
                # NOTE(review): the loader uses batch_size=1, so [0] presumably
                # selects the single document in the batch -- confirm against
                # Model.forward.
                prob = model(ids, document_mask, attention_mask)[0].tolist()
                res.append(prob)
    finally:
        model.train(was_training)

    return res

def calculateSimilarity(sentence, doc):
    """Score how similar ``sentence`` is to a list of sentences using ROUGE.

    The sentences in ``doc`` are joined with newlines and used as the ROUGE
    target; ``sentence`` is the prediction. The result is the mean of the
    ROUGE-2 and ROUGE-Lsum F-measures computed by the module-level ``scorer``.
    """
    target_text = '\n'.join(doc)
    rouge = scorer.score(target_text, sentence)
    f_measures = [rouge[metric].fmeasure for metric in ('rouge2', 'rougeLsum')]
    return np.mean(f_measures)

def choose_summary_mmr(doc, prob, k=3, alpha=0.9):
    """Greedily pick up to ``k`` sentences with a relevance/redundancy trade-off.

    The highest-probability sentence is taken first. Each following pick
    maximizes ``alpha * prob[i] - (1 - alpha) * calculateSimilarity(doc[i],
    summary)`` over the sentences not selected yet, where ``summary`` is the
    list of sentences chosen so far.

    Args:
        doc: Sequence of sentences.
        prob: Sequence of per-sentence scores aligned with ``doc``. If the two
            lengths differ, only the common prefix is considered.
        k: Maximum number of sentences to return.
        alpha: Weight of the model score versus the redundancy penalty.

    Returns:
        The selected sentences in document order. Empty when there is nothing
        to select from or ``k <= 0``.
    """
    # Force float so integer inputs are not truncated; np.array copies, so the
    # caller's ``prob`` is never modified.
    scores = np.array(prob, dtype=float)
    n_candidates = min(len(doc), len(scores))
    if n_candidates == 0 or k <= 0:
        return []
    scores = scores[:n_candidates]

    first = int(np.argmax(scores))
    picked = [first]
    # Track selection explicitly instead of zeroing the score, so a sentence
    # whose score is exactly 0 stays eligible and nothing is selected twice.
    picked_set = {first}
    summary = [doc[first]]

    target_size = min(k, n_candidates)
    while len(picked) < target_size:
        best_pos = None
        best_value = None
        for i in range(n_candidates):
            if i in picked_set:
                continue
            value = alpha * scores[i] - (1 - alpha) * calculateSimilarity(doc[i], summary)
            # Strict '>' keeps the earliest sentence on ties, as argmax would.
            if best_value is None or value > best_value:
                best_pos, best_value = i, value
        picked.append(best_pos)
        picked_set.add(best_pos)
        summary.append(doc[best_pos])

    return [doc[i] for i in sorted(picked)]

def choose_summary(doc, prob, k=3):
    """Return the ``k`` highest-scoring sentences of ``doc`` in document order.

    Args:
        doc: Sequence of sentences.
        prob: Sequence of per-sentence scores aligned with ``doc``. If the two
            lengths differ, only the common prefix is considered.
        k: Maximum number of sentences to return; clamped to the number of
            available sentences, because ``torch.topk`` raises when ``k``
            exceeds the input length.

    Returns:
        A list of at most ``k`` sentences, ordered as they appear in ``doc``.
    """
    n_candidates = min(len(doc), len(prob))
    top = min(k, n_candidates)
    if top <= 0:
        return []
    idx = torch.topk(torch.tensor(prob[:n_candidates]), k=top).indices.tolist()
    return [doc[i] for i in sorted(idx)]

def choose_all_summary(docs, all_probs, k=3):
    """Select the top-``k`` sentences of every document, in document order.

    Args:
        docs: Sequence of documents, each a sequence of sentences.
        all_probs: Sequence of per-sentence score sequences; ``all_probs[j]``
            belongs to ``docs[j]``.
        k: Maximum number of sentences per summary; clamped per document to
            the number of available sentences, because ``torch.topk`` raises
            when ``k`` exceeds the input length.

    Returns:
        A list with one summary (list of sentences) per document.
    """
    summaries = []
    for doc_idx, doc in enumerate(docs):
        prob = all_probs[doc_idx]
        # Only score positions that map to an actual sentence.
        n_candidates = min(len(doc), len(prob))
        top = min(k, n_candidates)
        if top <= 0:
            summaries.append([])
            continue
        sent_idx = torch.topk(torch.tensor(prob[:n_candidates]), k=top).indices.tolist()
        summaries.append([doc[s] for s in sorted(sent_idx)])
    return summaries


2021-10-07 00:08:37.980039: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-10.1/lib64:/usr/local/cuda-10.1/lib64:
2021-10-07 00:08:37.980075: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [2]:
# Load the trained checkpoint and rebuild the model for CPU inference.
final_dict = torch_load_all('save/cnn/best-model')
config = CNNConfig()
config.device = 'cpu'

bert = config.bert_type.from_pretrained(config.bert_name)
tokenizer = config.tokenizer_type.from_pretrained(config.bert_name)
model = Model(bert, config).to(config.device)
# This notebook only runs inference: switch to evaluation mode so that any
# dropout / batch-norm layers behave deterministically.
model.eval()
# Kept as the last expression so the cell still displays the key-matching report.
model.load_state_dict(final_dict['model_state_dict'])

Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [3]:
# Read the example article. The encoding is given explicitly so non-ASCII
# characters decode the same way regardless of the platform's default locale.
with open('data/cnndm/example.txt', encoding='utf-8') as f:
    raw_text = f.read()
# `doc` is the list of sentences consumed by the following cells.
doc = sent_tokenize(raw_text)

In [4]:
# Score the single example document and keep its 3 highest-scoring sentences.
docs = [doc]
probs = get_all_probs(model, tokenizer, docs, config)
summaries = choose_all_summary(docs, probs, k=3)
sumamries = summaries  # alias: keeps the original (misspelled) name available
summaries

100%|██████████| 1/1 [00:00<00:00,  2.48it/s]


[['(CNN) — After a long, golden sunset of being installed on fewer and fewer aircraft, the retirement of older aircraft caused by the Covid-19 pandemic means that when air travel resumes, international first class will be very nearly a thing of the past.',
  'Its replacement is a new generation of superbusiness minisuites, more spacious than regular business class, and with a privacy door to create your own space, but without the over-the-top luxury of first class.',
  'The Qsuite is unique to Qatar Airways, but a growing number of airlines offer or plan to offer superbusiness seats, from Delta to China Eastern, JetBlue to British Airways, Shanghai Airlines to Aeroflot, to the very latest, Air China.']]