In [1]:
import os
import time
import torch
import csv
from datasets import load_dataset
from transformers import BartForConditionalGeneration, BartTokenizer

In [2]:
# Release cached GPU memory left over from earlier runs in this kernel, then pick a device.
torch.cuda.empty_cache()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# LongBench subset to summarize ("multifieldqa_en" is an alternative subset).
dataset_name = "narrativeqa"
dataset = load_dataset(
    "THUDM/LongBench",
    dataset_name,
    split="test",
    cache_dir="custom_cache_dir",
    # NOTE(review): this lets the dataset repo run its own loading code;
    # keep enabled only for trusted sources.
    trust_remote_code=True,
)

Generating test split:   0%|          | 0/200 [00:00<?, ? examples/s]

In [3]:
# BART summarization checkpoint used by summarize_text; model is moved to `device`.
summ_model_name = 'facebook/bart-large-cnn'
summ_tokenizer = BartTokenizer.from_pretrained(summ_model_name)
summ_model = BartForConditionalGeneration.from_pretrained(summ_model_name)
summ_model = summ_model.to(device)

In [4]:
def summarize_text(input_text, ratio, max_tokens=1024):
    """Summarize one chunk of text with the BART model, targeting a length ratio.

    The input is tokenized and truncated to ``max_tokens`` tokens; the summary
    length is steered toward ``ratio`` times the (truncated) input length.

    Args:
        input_text: text to summarize. Anything beyond ``max_tokens`` tokens
            is silently dropped by the tokenizer's truncation.
        ratio: desired summary length as a fraction of the input token count.
        max_tokens: input truncation limit and upper bound on the summary
            length. Defaults to 1024, the value previously hard-coded.
            NOTE(review): presumably matches BART's positional limit; confirm
            before raising it.

    Returns:
        The decoded summary string with special tokens removed.
    """
    inputs = summ_tokenizer(
        input_text, return_tensors="pt", max_length=max_tokens, truncation=True
    ).to(device)
    input_token_length = inputs["input_ids"].shape[1]
    target_length = int(input_token_length * ratio)
    # Up to 100 tokens of slack above the target, never beyond max_tokens.
    max_length = min(target_length + 100, max_tokens)
    # Keep min_length <= max_length: with ratio > 1 the target can exceed the cap,
    # which would otherwise request an impossible length window.
    min_length = min(target_length, max_length)
    summary_ids = summ_model.generate(
        inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_length=max_length,
        min_length=min_length,
        num_beams=4,  # beam search for better quality
        early_stopping=True,  # stop once all beams have finished
    )
    return summ_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

def generate_multiple_summaries(input_text, ratio, chunk_size=4000):
    """Summarize arbitrarily long text by summarizing fixed-size character chunks.

    The text is cut into consecutive, non-overlapping windows of ``chunk_size``
    characters (no character is skipped or repeated). Each window goes through
    ``summarize_text`` and the partial summaries are joined with single spaces.

    Args:
        input_text: text to summarize.
        ratio: target summary/input length ratio, forwarded to ``summarize_text``.
        chunk_size: characters per chunk (must be positive).
            NOTE(review): 4000 characters may exceed the 1024-token truncation
            in ``summarize_text`` for dense text, in which case the end of the
            chunk is dropped; confirm this size against real inputs.

    Returns:
        The concatenated chunk summaries, or "" for empty input.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    # Step by exactly chunk_size; the previous `end_idx + 1` dropped one
    # character at every chunk boundary.
    chunks = (
        input_text[start:start + chunk_size]
        for start in range(0, len(input_text), chunk_size)
    )
    # join() avoids the stray leading space the old `summary += " " + ...` produced.
    return " ".join(summarize_text(chunk, ratio) for chunk in chunks)

def write_dicts_to_csv(data, filename):
    """Append a list of row dicts to a CSV file, writing the header only once.

    Column order comes from the keys of ``data[0]``. Later rows must not contain
    keys outside that set (``csv.DictWriter`` raises ``ValueError``); missing
    keys are written as empty fields.

    Args:
        data: list of dicts to write. An empty list is a no-op.
        filename: CSV path; created if it does not exist.
    """
    if not data:
        # Nothing to write; also avoids IndexError on data[0].
        return
    fieldnames = list(data[0].keys())
    # Write the header for a new file *or* an existing empty one (e.g. left
    # behind by an interrupted run), so the CSV always starts with a header.
    needs_header = not os.path.isfile(filename) or os.path.getsize(filename) == 0
    # Append mode creates the file when missing, so it covers both cases.
    with open(filename, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerows(data)

In [6]:
def generate_dataset_summaries(ratio, n, start = 0):
    """Summarize the ``context`` of dataset rows ``start`` .. ``start + n - 1``.

    Each row is copied, given a ``"summary"`` field from
    ``generate_multiple_summaries`` and appended to
    ``summarize_{dataset_name}_{ratio as percent}.csv`` as soon as it is done.
    An interrupted run therefore keeps every finished row (a step can take
    minutes). A failed summary is recorded as "Failed to generate summary"
    and the loop continues.

    Args:
        ratio: summary/input length ratio passed to the summarizer.
        n: number of rows to process; clamped to the rows available after ``start``.
        start: index of the first row.

    Returns:
        List of processed row dicts in dataset order (may be shorter than ``n``).
    """
    # round(), not int(): int(0.29 * 100) == 28 because of float rounding error.
    filename = f"summarize_{dataset_name}_{round(ratio * 100)}.csv"
    # Clamp instead of raising IndexError when start + n overshoots the split.
    stop = min(start + n, len(dataset))
    all_data = []

    for i in range(start, stop):
        item = dataset[i].copy()
        start_time = time.perf_counter()
        try:
            item["summary"] = generate_multiple_summaries(item["context"], ratio)
        except Exception as e:
            # Best-effort: record the failure and move on to the next row.
            item["summary"] = "Failed to generate summary"
            print("Failed to generate summary :", e)
        all_data.append(item)
        print(f"Step {i}: {time.perf_counter() - start_time}")
        # Persist immediately so progress survives crashes/interrupts.
        write_dicts_to_csv([item], filename)

    return all_data

In [7]:
# all_data_25 = generate_dataset_summaries(ratio=0.25, n=50, start=0)

In [8]:
# all_data_50 = generate_dataset_summaries(ratio=0.5, n=50, start=0)

In [None]:
# Summarize the first 50 rows at a 0.75 length ratio (writes summarize_<dataset>_75.csv).
all_data_75 = generate_dataset_summaries(start=0, n=50, ratio=0.75)

Step 0: 371.13926339149475
