In [15]:
from utils import read_data
from transformers import AutoModelForSeq2SeqLM, PreTrainedTokenizerFast, AutoTokenizer
from transformers.models.bart.modeling_bart import BartForConditionalGeneration
import torch.functional as F
import torch

# load data
# read_data (from utils) returns two collections that later cells use as
# `articles` (input texts) and `summaries` (reference summaries).
file_path = "data/task1/train/eLife_train.jsonl"
articles, summaries = read_data(file_path)

# Tokenizer used by compress_article to count BART tokens per sentence.
bart_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large')


# tokenizer = AutoTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv")
# by default encoder-attention is `block_sparse` with num_random_blocks=3, block_size=64
# model = BigBirdPegasusForConditionalGeneration.from_pretrained("google/bigbird-pegasus-large-arxiv") # 2.31G

Read data from :  data/task1/train/eLife_train.jsonl
The number of data:  4346


## Content Selection by ROUGE Scores
Select salient sentences.

In [26]:
import numpy as np
from rouge import Rouge
from nltk import tokenize
# Shared ROUGE scorer instance; get_rouge2recall_scores_nopad calls its get_scores().
rouge_pltrdy = Rouge()


def get_rouge2recall_scores_nopad(sentences, reference, oracle_type=None):
    """Score each sentence by its ROUGE-2 recall against ``reference``.

    Both the sentences and the reference are lower-cased before scoring because
    the scorer is case sensitive.

    Args:
        sentences: Sequence of sentence strings to score.
        reference: Reference text the sentences are compared with.
        oracle_type: Optional tie-breaking scheme added on top of the raw scores:
            ``'padlead'`` adds a tiny bias that decreases with sentence position
            (earlier sentences win ties); ``'padrand'`` adds tiny Gaussian noise
            (scale 1e-10) drawn from the global NumPy RNG; anything else returns
            the raw scores.

    Returns:
        A 1-D NumPy array with one score per sentence, in input order.
    """
    reference_lc = reference.lower()

    def _recall(candidate):
        # One sentence's ROUGE-2 recall, with the two known scorer failures mapped
        # to fixed values instead of aborting the whole article.
        try:
            return rouge_pltrdy.get_scores(candidate, reference_lc)[0]['rouge-2']['r']
        except ValueError:
            # presumably raised for sentences the scorer cannot handle (e.g. empty
            # after its own preprocessing) -- treated as no overlap
            return 0.0
        except RecursionError:
            # the sentence is simply too long to score; assign a neutral 0.5
            return 0.5

    recalls = np.array([_recall(sentence.lower()) for sentence in sentences])
    n_sentences = len(recalls)

    if oracle_type == 'padlead':
        lead_bias = np.array([(n_sentences - pos) * 1e-12 for pos in range(n_sentences)])
        return recalls + lead_bias
    if oracle_type == 'padrand':
        return recalls + np.random.normal(scale=1e-10, size=(n_sentences,))
    return recalls  # no pad

def compress_article(article, reference=None):
    """Keep the article sentences that best match its reference summary.

    Sentences are ranked by ROUGE-2 recall against ``reference`` (ties broken by
    the tiny random jitter of ``oracle_type='padrand'``), taken greedily from the
    top of the ranking until the token budget is reached, and then re-joined in
    their original order.

    Args:
        article: Article text; split into sentences with ``tokenize.sent_tokenize``.
        reference: Reference summary to score against. When omitted, the summary
            stored at the same position as ``article`` in the module-level
            ``articles`` / ``summaries`` is used (assumes the two are
            index-aligned, as returned together by ``read_data`` -- TODO confirm).

    Returns:
        The selected sentences joined with single spaces.

    Raises:
        ValueError: ``reference`` is omitted and ``article`` is not in ``articles``.
        AssertionError: no sentence has a positive ROUGE-2 recall (this includes
            an article with no sentences).
    """
    if reference is None:
        # The reference used to be hard-coded to summaries[0], which scored every
        # article against the first summary. Look up the matching one instead.
        position = next(
            (k for k, known in enumerate(articles) if known is article or known == article),
            None,
        )
        if position is None:
            raise ValueError(
                "article not found in `articles`; pass its reference summary explicitly"
            )
        reference = summaries[position]

    sentences = tokenize.sent_tokenize(article)

    ## rank by ROUGE
    scores = get_rouge2recall_scores_nopad(sentences, reference, oracle_type='padrand')
    # 'padrand' adds N(0, 1e-10) jitter, so a plain `> 0` test would count roughly
    # half of the zero-recall sentences as positive. Genuine recall values are many
    # orders of magnitude above this floor, while the jitter stays far below it.
    jitter_floor = 1e-8
    num_positive = int(np.sum(scores > jitter_floor))
    rank = np.argsort(scores)[::-1][:num_positive]  # only consider positive ones

    ## select high-ranked sentences
    # presumably BART's 1024-token input limit minus the <s> and </s> tokens
    max_abssum_len = 1024 - 2
    keep_idx = []
    total_length = 0
    for sent_i in rank:
        # The budget is checked before adding, so the last sentence kept may
        # push the total past the limit.
        if total_length >= max_abssum_len:
            break
        total_length += len(bart_tokenizer.encode(sentences[sent_i])[1:-1])  # ignore <s> and </s>
        keep_idx.append(sent_i)
    assert len(keep_idx) > 0, "no sentence has a positive ROUGE-2 recall against the reference"

    # back to original order
    return " ".join(sentences[j] for j in sorted(keep_idx))


In [19]:
# Compress every article and collect the resulting texts in input order.
compressed_articles = []
for i, article in enumerate(articles):
    print(i)
    filtered_input_text = compress_article(article)
    # Store the compressed text; appending `compressed_articles` itself would only
    # fill the list with references to the list.
    compressed_articles.append(filtered_input_text)

There are 126 sentences.
count_nonzero_rouge2recall= 73


In [22]:
# with open(out_path, "w") as f:
#     f.write(filtered_input_text)
# print("write:", out_path)

'However , there is limited information on the timing and the relative magnitudes of maximum and minimum mortality , by local climate , age group , sex and medical cause of death . We used geo-coded mortality data and wavelets to analyse the seasonality of mortality by age group and sex from 1980 to 2016 in the USA and its subnational climatic regions . In adolescents and young adults , especially in males , death rates peaked in June/July and were lowest in December/January , driven by injury deaths . It is well-established that death rates vary throughout the year , and in temperate climates there tend to be more deaths in winter than in summer ( Campbell , 2017; Fowler et al . In a large country like the USA , which possesses distinct climate regions , the seasonality of mortality may vary geographically , due to geographical variations in mortality , localized weather patterns , and regional differences in adaptation measures such as heating , air conditioning and healthcare ( Davi

## Train BART

In [None]:
# load model
# AutoTokenizer is used on purpose: loading PreTrainedTokenizerFast directly gave a
# tokenizer with no <pad> token, which the padded encodings below need.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# text encoding
# `lst` is not defined anywhere in this notebook, so the first example is taken
# from the `articles` / `summaries` loaded by read_data instead.
# NOTE(review): confirm these correspond to the 'article' / 'lay_summary' fields.
# with tokenizer.as_target_tokenizer(): # same behaviour with/w.o context manager
summary = tokenizer([summaries[0]], return_tensors="pt", padding="longest")
target_ids, target_mask = summary["input_ids"], summary["attention_mask"] # (bsz, target_seq_len)

article = tokenizer([articles[0]], return_tensors="pt", padding="max_length", truncation=True)
input_ids, attention_mask = article["input_ids"], article["attention_mask"] # (bsz, 1024)


def sequence_cross_entropy_with_logits(logits, shifted_target_ids, shifted_target_mask):
    """Token-level cross entropy averaged over the unmasked target positions.

    Args:
        logits: Float tensor whose last dimension is the vocabulary; the leading
            dimensions are flattened into one row per token.
        shifted_target_ids: Integer tensor of target token ids with one entry per
            logits row once flattened.
        shifted_target_mask: Tensor with the same number of entries as
            ``shifted_target_ids``; non-zero marks tokens that count towards the
            loss, zero marks tokens (e.g. padding) that are ignored.

    Returns:
        A scalar tensor: the sum of the per-token losses weighted by the mask,
        divided by the mask total. If the mask is all zeros the result is 0.
    """
    # flatten (reshape also accepts non-contiguous inputs, unlike view)
    logits_flat = logits.reshape(-1, logits.size(-1))
    targets_flat = shifted_target_ids.reshape(-1).to(device=logits_flat.device, dtype=torch.long)

    # cross_entropy's third positional parameter is `weight` (per-class weights),
    # not a token mask, so the mask has to be applied to the per-token losses.
    token_losses = torch.nn.functional.cross_entropy(logits_flat, targets_flat, reduction="none")
    mask_flat = shifted_target_mask.reshape(-1).to(device=token_losses.device, dtype=token_losses.dtype)

    # clamp keeps an all-zero mask from producing 0/0 = NaN
    return (token_losses * mask_flat).sum() / mask_flat.sum().clamp(min=1.0)



# Teacher-forced forward pass: the decoder is fed the target without its last
# token and is trained to predict the target without its first token.
bart_output = bart(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=target_ids[:, :-1].contiguous(),
    decoder_attention_mask=target_mask[:, :-1].contiguous(),
    use_cache=False,
    return_dict=True,
)
logits = bart_output.logits # (bsz, target_seq_len-1, vocab_size), '-1' for the last position
# .long() converts the dtype without forcing the tensors onto the CPU, which
# .type(torch.LongTensor) would do.
shifted_target_ids = target_ids[:, 1:].long().contiguous() # (bsz, target_seq_len-1), '-1' for the first position
shifted_target_mask = target_mask[:, 1:].long().contiguous() # (bsz, target_seq_len-1)
# The loss helper takes exactly three arguments: logits, target ids, target mask.
loss = sequence_cross_entropy_with_logits(logits, shifted_target_ids, shifted_target_mask)
