In [25]:
import torch
import glob
import os
import pickle
from rouge import Rouge
import random

In [26]:
def _get_ngrams(n, text):
    ngram_set = set()
    text_length = len(text)
    max_index_ngram_start = text_length - n
    for i in range(max_index_ngram_start + 1):
        ngram_set.add(tuple(text[i:i + n]))
    return ngram_set

def _block_tri(c, p):
    tri_c = _get_ngrams(3, c.split())
    for s in p:
        tri_s = _get_ngrams(3, s.split())
        if len(tri_c.intersection(tri_s))>0:
            return True
    return False

In [27]:
dataname = "/mnt/d/MLData/Repos/PyTorch-ExtractiveTextSummarization/ckpts/model-2g21tx8c/infer_all.pkl"
with open(dataname, 'rb') as file:
    data = pickle.load(file)

In [28]:
print(data.keys())


dict_keys(['src_txt', 'tgt_txt', 'src_sent_labels', 'scores'])


In [29]:
print(data['src_txt'][0])

['a university of iowa student has died nearly three months after a fall in rome in a suspected robbery attack in rome .', 'andrew mogni , 20 , from glen ellyn , illinois , had only just arrived for a semester program in italy when the incident happened in january .', 'he was flown back to chicago via air ambulance on march 20 , but he died on sunday .', 'andrew mogni , 20 , from glen ellyn , illinois , a university of iowa student has died nearly three months after a fall in rome in a suspected robbery', 'he was taken to a medical facility in the chicago area , close to his family home in glen ellyn .', "he died on sunday at northwestern memorial hospital - medical examiner 's office spokesman frank shuftan says a cause of death wo n't be released until monday at the earliest .", 'initial police reports indicated the fall was an accident but authorities are investigating the possibility that mogni was robbed .', "on sunday , his cousin abby wrote online : ` this morning my cousin andr

In [30]:
print(data['src_sent_labels'][0])

[0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [31]:
data_list = [dict(zip(data,t)) for t in zip(*data.values())]

In [32]:
print(data_list[0].keys())

dict_keys(['src_txt', 'tgt_txt', 'src_sent_labels', 'scores'])


### Evaluate top-k sentence selection

In [33]:
def get_topk_selection(dataset, num_examples,k = 3, ngram=3):
    hyps = []
    refs = []
    for i in range(min(len(dataset),num_examples)):
        example = dataset[i]
        selected = []
        scores = example['scores']
        scores = [(score, i) for i, score in enumerate(scores)]
        scores_sorted = sorted(scores)
        ngram_set = set()
        count = 0
        while scores_sorted and count < k:
            _, cand_idx = scores_sorted.pop()
            cand_sent = example['src_txt'][cand_idx]
            cand_ngram = _get_ngrams(ngram, cand_sent.split(" "))
            overlapping_count = len(ngram_set.intersection(cand_ngram))
            #print(i, overlapping_count)
            if  overlapping_count == 0:
                ngram_set.update(cand_ngram)
                selected.append(cand_sent)
                count += 1
        if not selected:
            continue
        #print(f"selected {len(selected)} sentences for {i}")
        out = " ".join(selected)
        hyps.append(out)
        ref = example['tgt_txt'].replace('<q>',' ')
        refs.append(ref)

    return hyps, refs

In [34]:
hyps, refs = get_topk_selection(data_list, 10000)

In [35]:
print(len(hyps), len(refs))

10000 10000


In [36]:
rouge_metric = Rouge()

rouge_metric.get_scores(hyps, refs, avg=True)

{'rouge-1': {'r': 0.5132470613509534,
  'p': 0.39442132228963556,
  'f': 0.43471181599596287},
 'rouge-2': {'r': 0.24006050494507203,
  'p': 0.17376429468524407,
  'f': 0.1945841089197315},
 'rouge-l': {'r': 0.40419369586770526,
  'p': 0.3087566219052895,
  'f': 0.34110305918238815}}

In [37]:
rouge_metric.get_scores(hyps, refs, avg=False)

[{'rouge-1': {'r': 0.6428571428571429,
   'p': 0.8181818181818182,
   'f': 0.7199999950720002},
  'rouge-2': {'r': 0.5285714285714286,
   'p': 0.6727272727272727,
   'f': 0.5919999950720001},
  'rouge-l': {'r': 0.625, 'p': 0.7954545454545454, 'f': 0.6999999950720001}},
 {'rouge-1': {'r': 0.84375, 'p': 0.391304347826087, 'f': 0.5346534610175474},
  'rouge-2': {'r': 0.35294117647058826,
   'p': 0.1518987341772152,
   'f': 0.21238937632390956},
  'rouge-l': {'r': 0.5625, 'p': 0.2608695652173913, 'f': 0.3564356392353691}},
 {'rouge-1': {'r': 0.48484848484848486,
   'p': 0.22535211267605634,
   'f': 0.3076923033598373},
  'rouge-2': {'r': 0.1891891891891892,
   'p': 0.07608695652173914,
   'f': 0.1085271276918456},
  'rouge-l': {'r': 0.42424242424242425,
   'p': 0.19718309859154928,
   'f': 0.26923076489829884}},
 {'rouge-1': {'r': 0.40625,
   'p': 0.24528301886792453,
   'f': 0.30588234824636684},
  'rouge-2': {'r': 0.02857142857142857,
   'p': 0.015384615384615385,
   'f': 0.0199999954500

### validate oracle baseline

In [38]:
def get_oracle_selection(dataset, num_examples):
    hyps = []
    refs = []
    for i in range(min(len(dataset),num_examples)):
        example = dataset[i]
        selected = []
        for idx, j in enumerate(example['src_sent_labels']):
            if j == 1:
                selected.append(example['src_txt'][idx])
        if not selected:
            continue
        out = " ".join(selected)
        hyps.append(out)
        ref = example['tgt_txt'].replace('<q>',' ')
        refs.append(ref)

    return hyps, refs
        

In [39]:
hyps_oracle, refs_oracle = get_oracle_selection(data_list, 1000)


In [40]:
print(len(hyps_oracle))

986


In [41]:
print(len(refs_oracle))

986


In [42]:
rouge_metric = Rouge()
rouge_metric.get_scores(hyps_oracle, refs_oracle, avg=True)

{'rouge-1': {'r': 0.5676532245422679,
  'p': 0.5775291688161241,
  'f': 0.5508185184090183},
 'rouge-2': {'r': 0.3268545562111377,
  'p': 0.33508357536518946,
  'f': 0.31426476241945256},
 'rouge-l': {'r': 0.4685939421781878,
  'p': 0.4776104916914664,
  'f': 0.4552125771451799}}

### Evaluate random-3 selection

In [43]:
def get_random_k_selection(dataset, num_examples,k = 3, ngram=3):
    hyps = []
    refs = []
    for i in range(min(len(dataset),num_examples)):
        example = dataset[i]
        selected = []
        scores = example['scores']
        scores = [(score, i) for i, score in enumerate(scores)]
        random.shuffle(scores)
        ngram_set = set()
        count = 0
        while scores and count < k:
            _, cand_idx = scores.pop()
            cand_ngram = _get_ngrams(ngram, example['src_txt'][cand_idx])
            if len(ngram_set.intersection(cand_ngram)) == 0:
                ngram_set.update(cand_ngram)
                selected.append(example['src_txt'][cand_idx])
                count += 1
        if not selected:
            continue
        out = " ".join(selected)
        hyps.append(out)
        ref = example['tgt_txt'].replace('<q>',' ')
        refs.append(ref)

    return hyps, refs

In [44]:
hyps_random, refs_random = get_random_k_selection(data_list, 100000)

In [45]:
rouge_metric = Rouge()
rouge_metric.get_scores(hyps_random, refs_random, avg=True)

{'rouge-1': {'r': 0.16794228926298452,
  'p': 0.33537787693738186,
  'f': 0.2144194438849009},
 'rouge-2': {'r': 0.04253921988056252,
  'p': 0.0950766750031537,
  'f': 0.055753131146374084},
 'rouge-l': {'r': 0.12305451876960677,
  'p': 0.24903721935619255,
  'f': 0.15763634599716145}}

### validate lead 3 baseline

In [46]:
def get_lead_k_selection(dataset, num_examples, k=3):
    hyps = []
    refs = []
    for i in range(min(len(dataset),num_examples)):
        example = dataset[i]
        selected = []
        for j in range(min(k, len(example['src_sent_labels']))):
            selected.append(example['src_txt'][j])
        out = " ".join(selected)
        hyps.append(out)
        ref = example['tgt_txt'].replace('<q>',' ')
        refs.append(ref)

    return hyps, refs

In [47]:
hyps_lead, refs_lead = get_lead_k_selection(data_list, 100000)
rouge_metric.get_scores(hyps_lead, refs_lead, avg=True)

{'rouge-1': {'r': 0.5084508230373855,
  'p': 0.3589632929014358,
  'f': 0.4097730984074396},
 'rouge-2': {'r': 0.22330570331577615,
  'p': 0.14606575470871974,
  'f': 0.17037691800181165},
 'rouge-l': {'r': 0.40554740519038524,
  'p': 0.2855664228268964,
  'f': 0.3262607094733872}}