In [1]:
from transformers import (
    AutoTokenizer,
    LEDForConditionalGeneration,
)
from datasets import load_dataset, load_metric
import torch
import random
from tqdm import tqdm
import json
import pandas as pd

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [3]:
# load our augmented dataset json file
augmented_dataset = None
with open('augmented_test.json', 'r') as f:
    augmented_dataset = json.load(f)

In [4]:
# load our model
TOKENIZER = AutoTokenizer.from_pretrained('allenai/PRIMERA')
MODEL = LEDForConditionalGeneration.from_pretrained('allenai/PRIMERA').to(device)

PAD_TOKEN_ID = TOKENIZER.pad_token_id
DOCSEP_TOKEN_ID = TOKENIZER.convert_tokens_to_ids("<doc-sep>")

In [5]:
def process_document(batch):
    # batch is a list of lists of strings
    input_ids_all=[]

    for documents in batch:
        #### concat with global attention on doc-sep
        input_ids = []
        for doc in documents:
            input_ids.extend(
                TOKENIZER.encode(
                    doc,
                    truncation=True,
                    max_length=4096 // len(documents),
                )[1:-1]
            )
            input_ids.append(DOCSEP_TOKEN_ID)
        input_ids = (
            [TOKENIZER.bos_token_id]
            + input_ids
            + [TOKENIZER.eos_token_id]
        )
        input_ids_all.append(torch.tensor(input_ids).to(device))
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids_all, batch_first=True, padding_value=PAD_TOKEN_ID
    ).to(device)

    return input_ids


def batch_process(batch):
    documents = [doc['documents'] for doc in batch]
    input_ids=process_document(documents)
    # get the input ids and attention masks together
    global_attention_mask = torch.zeros_like(input_ids).to(input_ids.device)
    # put global attention on <s> token

    global_attention_mask[:, 0] = 1
    global_attention_mask[input_ids == DOCSEP_TOKEN_ID] = 1

    generated_ids = MODEL.generate(
        input_ids=input_ids,
        global_attention_mask=global_attention_mask,
        use_cache=True,
        max_length=512,
        num_beams=5,
    )
    generated_str = TOKENIZER.batch_decode(
            generated_ids.tolist(), skip_special_tokens=True
        )
    result={}
    result['generated_summaries'] = generated_str
    result['gt_summaries']=[doc['summary'] for doc in batch]
    return result

In [6]:
# make a results directory if it doesn't exist
import os
if not os.path.exists('results'):
    os.mkdir('results')

In [7]:
# load the indices to test on
sample_indices = pd.read_csv('cohere_sample_indices.csv').values.flatten()
dataset_small = [augmented_dataset[i] for i in sample_indices]

batch_sz = 4
result_small = []

for i in tqdm(range(0, len(dataset_small), batch_sz)):
    result_small.append(batch_process(dataset_small[i:i+batch_sz]))
    # let's save the result every 10 batches
    if i % 20 == 0:
        with open(f"results/primera_baseline_{i}.json", "w") as f:
            json.dump(result_small, f)

with open(f"results/primera_baseline_complete.json", "w") as f:
    json.dump(result_small, f)


100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [14:57<00:00, 17.96s/it]


In [8]:
generated_summaries = [result['generated_summaries'] for result in result_small]
gt_summaries = [result['gt_summaries'] for result in result_small]

In [9]:
rouge = load_metric("rouge")

  rouge = load_metric("rouge")


Downloading builder script:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

In [10]:
score=rouge.compute(predictions=generated_summaries, references=gt_summaries)
print(score['rouge1'].mid)
print(score['rouge2'].mid)
print(score['rougeL'].mid)

Score(precision=0.6459678656859029, recall=0.18840455997611502, fmeasure=0.28547745095948196)
Score(precision=0.19227593320281855, recall=0.05595492330635309, fmeasure=0.08460940671241009)
Score(precision=0.2671092590763626, recall=0.0749439437460636, fmeasure=0.11423989611855263)
