In [1]:
import json
import math
import pandas as pd
import pickle
from tqdm import tqdm

from rouge_metric import PyRouge
import torch

%load_ext autoreload
%autoreload 2

from evaluation.rouge_evaluator import RougeEvaluator
from inference_utils import summarize
from CustomTokenizer import CustomTokenizer
from utils import Seq2SeqTransformer, TokenEmbedding, PositionalEncoding

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
pd.set_option('display.max_columns', None)

In [3]:
with open("data/train.json", 'r') as f:
    train_data = json.load(f)

with open("data/validation.json", 'r') as f:
    validation_data = json.load(f)

with open("data/test.json", 'r') as f:
    test_data = json.load(f)

# ================= LOAD DATASET ===========================
train_articles = [article['article'] for article in train_data]
train_summaries = [article['summary'] for article in train_data]
train_summaries = [summary.replace('\n', ' ') for summary in train_summaries]

val_articles = [article['article'] for article in validation_data]
val_summaries = [article['summary'] for article in validation_data]
val_summaries = [summary.replace('\n', ' ') for summary in val_summaries]

test_articles = [article['article'] for article in test_data]
test_summaries = [article['summary'] for article in test_data]
test_summaries = [summary.replace('\n', ' ') for summary in test_summaries]


# ================= REDUCE SIZE ===========================
size_of_dataset = 10000
train_articles = train_articles[:size_of_dataset]
train_summaries = train_summaries[:size_of_dataset]

val_articles = val_articles[:math.ceil(size_of_dataset/5)]
val_summaries = val_summaries[:math.ceil(size_of_dataset/5)]


test_articles = test_articles[:math.ceil(size_of_dataset/5)]
test_summaries = test_summaries[:math.ceil(size_of_dataset/5)]

In [4]:
model_save_name = "M_limited_dataset_bs16_es512_ffn2048_nh4_nl4_tknzrbpe_fdsFalse" #

In [5]:
epoch_checkpoint = 12
transformer = torch.load(f"models/e{epoch_checkpoint}_{model_save_name}.pt")
with open(f'tokenizers/{model_save_name}_tokenizer.pkl', 'rb') as file:
    tokenizer = pickle.load(file)

In [7]:
def flatten_scores(rouge_scores):
    flattened_scores = {f"{main_key}-{sub_key}": sub_value
                    for main_key, sub_dict in rouge_scores.items()
                    for sub_key, sub_value in sub_dict.items()}
    return flattened_scores

def get_test_scores_for_model(
    model,
    model_tag,
    test_articles,
    test_summaries,
    tokenizer
):
    rouge = PyRouge(rouge_n=(1, 2, 4), rouge_l=True, rouge_w=True, rouge_s=True, rouge_su=True)

    decoding_methods = [
        'greedy',
        #'beam_search',
        'top_k',
        #'top_p',
        'random_sampling'
    ]
    decoding_methods_to_scores = {}

    for decoding_method in decoding_methods:
        print(f"Generating summaries using {decoding_method} method.")
        generated_summaries = []
        for i, test_article in enumerate(tqdm(test_articles, desc=f'Generating summaries with {decoding_method} method.')):
            generated_summary = summarize(model, test_article, tokenizer=tokenizer, decoding_method=decoding_method, max_len=100)
            generated_summaries.append(generated_summary)

        print(f'Calculating ROUGE scores for {decoding_method} method.\n')
        scores = rouge.evaluate(generated_summaries, [[x] for x in test_summaries])
        flattened_scores = flatten_scores(scores)
        decoding_methods_to_scores[f"{decoding_method}_{model_tag}"] = flattened_scores
        print(flattened_scores)
        print()

    return decoding_methods_to_scores


In [None]:
decoding_methods_to_scores = get_test_scores_for_model(transformer, model_save_name, test_articles, test_summaries, tokenizer)

In [15]:
decoding_methods_to_scores_df = pd.DataFrame.from_dict(decoding_methods_to_scores, orient='index')
decoding_methods_to_scores_df

Unnamed: 0,rouge-1-r,rouge-1-p,rouge-1-f,rouge-2-r,rouge-2-p,rouge-2-f,rouge-4-r,rouge-4-p,rouge-4-f,rouge-l-r,rouge-l-p,rouge-l-f,rouge-w-1.2-r,rouge-w-1.2-p,rouge-w-1.2-f,rouge-s*-r,rouge-s*-p,rouge-s*-f,rouge-su*-r,rouge-su*-p,rouge-su*-f
greedy_S_bs8_es256_ffn1024_nh2_nl2_tknzrbpe_fdsFalse,0.175094,0.216748,0.193707,0.018176,0.023457,0.020482,0.000826,0.000967,0.000891,0.138871,0.169881,0.152819,0.047501,0.125188,0.06887,0.029778,0.050438,0.037448,0.034928,0.057353,0.043416
top_k_S_bs8_es256_ffn1024_nh2_nl2_tknzrbpe_fdsFalse,0.200761,0.234244,0.216214,0.020728,0.024867,0.022609,0.000574,0.000674,0.00062,0.143812,0.16561,0.153943,0.049029,0.122025,0.069952,0.036538,0.052844,0.043204,0.042449,0.060155,0.049774
random_sampling_S_bs8_es256_ffn1024_nh2_nl2_tknzrbpe_fdsFalse,0.189899,0.224072,0.205575,0.015502,0.0191,0.017114,0.000191,0.000242,0.000214,0.136197,0.159836,0.147073,0.046654,0.118206,0.066903,0.032296,0.046844,0.038233,0.037965,0.054018,0.044591


In [22]:
for i, (article, summary) in enumerate(zip(test_articles[:10], test_summaries[:10])):
    generated_summary = summarize(transformer, article, tokenizer=tokenizer, decoding_method='top_k', max_len=100)
    print(f'**Input article no.{i}**: {article}')
    print(f'**Gold label summary**: {summary}')
    print(f'**Generated summary**: {generated_summary}\n\n')

    # "M_full_dataset_bs32_es512_nh8_nl6_tknzrbpe"

**Input article no.0**: The mother and daughter who survived a tragic car accident this week, which saw three children die, have been reunited. Aluel Manyang was moved from the intensive care unit at the Royal Children's about 5.15pm on Friday, and greeted her distraught mother, Akon Goode, with a 'big hug', her father said. 'She didn't believe that her mum was still alive,' Joseph Manyang said, according to the Herald Sun. Scroll down for videos . Aueel Manyang, pictured here as a baby with her mother Akon Guode, believes her three siblings who died in the crash at a Melbourne lake were eaten by crocodiles in the water . Ms Guode visited her daughter for the first time but did not stay the night in the hospital. Mr Manyang said his daughter was expected to make a '100 per cent' recovery and she should be allowed to go home within four days. The five-year-old girl who survived when a car driven by her mother plunged into a lake believes her three siblings who died in the crash were eat

In [16]:
tmp_txt = """
In the early days computers were much simpler. The var-
ious components of a system, such as the CPU, memory,
mass storage, and network interfaces, were developed to-
gether and, as a result, were quite balanced in their per-
formance. For example, the memory and network inter-
faces were not (much) faster than the CPU at providing
data.
This situation changed once the basic structure of com-
puters stabilized and hardware developers concentrated
on optimizing individual subsystems. Suddenly the per-
formance of some components of the computer fell sig-
nificantly behind and bottlenecks developed. This was
especially true for mass storage and memory subsystems
which, for cost reasons, improved more slowly relative
to other components.
The slowness of mass storage has mostly been dealt with
using software techniques: operating systems keep most
often used (and most likely to be used) data in main mem-
ory, which can be accessed at a rate orders of magnitude
faster than the hard disk. Cache storage was added to the
storage devices themselves, which requires no changes in
the operating system to increase performance.1 For the
purposes of this paper, we will not go into more details
of software optimizations for the mass storage access
"""

summarize(transformer, tmp_txt, tokenizer=tokenizer, decoding_method='greedy', max_len=100)

'The new new new technology is a new new new technology . The new technology is the most expensive is a few years .'

In [6]:
summarize(transformer, train_articles[2], tokenizer=tokenizer, decoding_method='top_k', max_len=100)

'The couple had been arrested after the two weeks of the year . She was taken to a hospital in the UK after being released on the hospital . The boy was found dead at the time of the hospital .'

In [7]:
summarize(transformer, train_articles[2], tokenizer=tokenizer, decoding_method='greedy', max_len=100)

'The man was found in a hospital in the hospital in 2012 . The couple had been charged with a hospital in the hospital . The couple were found in the hospital in the hospital .'