In [None]:
# TODO: Point PATH at the location where this notebook's results should be saved.
PATH = "."

In [None]:
!pip install -q --upgrade transformers==4.57.1 datasets accelerate peft==0.11.1 sentencepiece evaluate sentence-transformers scikit-learn bitsandbytes rouge_score

In [None]:
import os
from collections import Counter

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from rouge_score import rouge_scorer
from scipy import stats
from sentence_transformers import SentenceTransformer, util
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, EarlyStoppingCallback

In [None]:
import nltk

# Fetch the NLTK data packages this notebook relies on: WordNet (used by METEOR
# for synonym matching) and the Punkt tokenizer models.
for nltk_resource in ("wordnet", "punkt"):
    nltk.download(nltk_resource)

from nltk.translate.meteor_score import meteor_score

In [None]:
# Load the PubMed summarization dataset ("section" configuration); its "train" and
# "test" splits are used throughout this notebook.
ds = load_dataset(path="ccdv/pubmed-summarization", name="section")


In [None]:
# Set up the T5-base tokenizer and pretrained seq2seq model; transformers chooses
# device placement and dtype automatically ("auto").
model_name = "t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, dtype="auto", device_map="auto")

In [None]:
# Token count of every training abstract, as produced by the T5 tokenizer.
abstract_lengths = [len(tokenizer.tokenize(abstract)) for abstract in ds['train']['abstract']]

In [None]:
# Summarize the abstract token-length distribution: mean, maximum and selected percentiles.
print(f'Average abstract length = {np.mean(abstract_lengths)}.')
print(f'Max abstract length = {max(abstract_lengths)}.')
for label, q in (('5th', 0.05), ('25th', 0.25), ('75th', 0.75), ('95th', 0.95), ('99th', 0.99)):
    print(f'{label} percentile abstract length = {np.quantile(abstract_lengths, q)}.')

In [None]:
# Chunking the paper
def chunk_paper(text, max_tokens=512, overlap=50, tok=None):
    """Split ``text`` into overlapping windows of at most ``max_tokens`` tokens.

    Consecutive windows start ``max_tokens - overlap`` tokens apart, so each window
    repeats the last ``overlap`` tokens of the previous one. Windowing stops as soon
    as a window reaches the end of the text. This avoids emitting a trailing window
    that is entirely contained in the previous one, which would only yield a
    duplicate summary.

    Args:
        text: Document to split.
        max_tokens: Maximum number of tokens per window; must be positive.
        overlap: Number of tokens shared by consecutive windows; must satisfy
            ``0 <= overlap < max_tokens``.
        tok: Object providing ``tokenize`` and ``convert_tokens_to_string``.
            Defaults to the notebook-level ``tokenizer``.

    Returns:
        List of window strings; empty when ``text`` yields no tokens.

    Raises:
        ValueError: If ``max_tokens``/``overlap`` would give a non-positive stride.
    """
    if max_tokens <= 0:
        raise ValueError(f"max_tokens must be positive, got {max_tokens}")
    if not 0 <= overlap < max_tokens:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < max_tokens, got overlap={overlap}, max_tokens={max_tokens}"
        )
    if tok is None:
        tok = tokenizer
    tokens = tok.tokenize(text)
    stride = max_tokens - overlap
    chunks = []
    for start in range(0, len(tokens), stride):
        chunks.append(tok.convert_tokens_to_string(tokens[start:start + max_tokens]))
        if start + max_tokens >= len(tokens):
            # This window already reaches the last token; any later window would be a
            # suffix of it.
            break
    return chunks

In [None]:
def get_model_device(model):
    """Return the device holding the model's first parameter.

    Falls back to the CPU device when the model has no parameters at all.
    """
    first_param = next(iter(model.parameters()), None)
    if first_param is None:
        return torch.device("cpu")
    return first_param.device

model_device = get_model_device(base_model)

# T5 expects the lowercase task prefix "summarize: " (the same prefix the evaluation loop
# below uses). It is required on every generation pass, including the final one.
summarize_prefix = "summarize: "
# Encoder budget per input (the same 512 used as max_input for fine-tuning). Chunks leave
# room for the prefix and the EOS token so the chunk text itself is not truncated.
encoder_max_tokens = 512
chunk_max_tokens = encoder_max_tokens - len(tokenizer.tokenize(summarize_prefix)) - 1

paper_chunked = chunk_paper(ds['train'][1]['article'], max_tokens=chunk_max_tokens, overlap=50)
chunk_summaries = []

# Summarization process, pass 1: summarize every chunk independently.
for chunk in paper_chunked:
    inputs = tokenizer(
        summarize_prefix + chunk,
        return_tensors="pt",
        truncation=True,
        max_length=encoder_max_tokens
    ).to(model_device)

    with torch.no_grad():
        summary_ids = base_model.generate(
            **inputs,
            max_length=512,
            min_length=64,
            num_beams=4,
            early_stopping=True
        )

    chunk_summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

# Pass 2: summarize the concatenated chunk summaries into one final summary.
# NOTE(review): up to 1024 input tokens are allowed here, beyond the 512 used above;
# longer concatenations are truncated.
all_summaries = " ".join(chunk_summaries)
final_inputs = tokenizer(
    summarize_prefix + all_summaries,
    return_tensors="pt",
    truncation=True,
    max_length=1024
).to(model_device)

with torch.no_grad():
    # NOTE(review): the 100/606 length bounds are hard-coded; their origin is not documented.
    final_summary_ids = base_model.generate(
        **final_inputs,
        min_length=100,
        max_length=606,
        num_beams=4,
        early_stopping=True
    )

# Decode and show the final summary so the generated result is actually used.
final_summary = tokenizer.decode(final_summary_ids[0], skip_special_tokens=True)
print(final_summary)


In [None]:
# Small subsets keep fine-tuning cheap. The in-training evaluation set (which drives early
# stopping / model selection) comes from the validation split. That way the test articles
# scored after training are never used to choose the model.
train_small = ds["train"].select(range(600))
eval_small  = ds["validation"].select(range(200))

In [None]:
from peft import LoraConfig, get_peft_model, TaskType

# LoRA settings: rank 32, alpha 64, dropout 0.1, no bias training. The adapters target the
# modules named q/k/v/o (T5 attention projections) and wi/wo (T5 feed-forward layers).
config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    target_modules=["q", "k", "v", "o", "wi", "wo"],
    r=32,
    lora_alpha=64,
    lora_dropout=0.1,
    bias="none",
)

# Wrap the pretrained model with the adapters and report how many weights are trainable.
model = get_peft_model(base_model, config)
model.print_trainable_parameters()


In [None]:
max_input = 512
max_target = 256

def preprocess(batch):
    """Tokenize a batch of articles (model inputs) and abstracts (labels) for fine-tuning.

    Articles get the same "summarize: " task prefix used when generating summaries and
    are truncated to ``max_input`` tokens; abstracts are truncated to ``max_target``.
    No padding is applied here. DataCollatorForSeq2Seq pads each batch dynamically and
    pads ``labels`` with -100, so padded positions are ignored by the loss. Padding the
    labels with the pad token id instead would train the model to generate padding.

    Args:
        batch: Mapping with "article" and "abstract" lists, as passed by
            ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output for the articles, with an added "labels" entry holding
        the abstract token ids.
    """
    inputs = tokenizer(
        ["summarize: " + article for article in batch["article"]],
        truncation=True,
        max_length=max_input
    )
    targets = tokenizer(
        batch["abstract"],
        truncation=True,
        max_length=max_target
    )
    inputs["labels"] = targets["input_ids"]
    return inputs

train_tokenized = train_small.map(preprocess, batched=True)
eval_tokenized  = eval_small.map(preprocess, batched=True)


In [None]:
from transformers import DataCollatorForSeq2Seq, TrainingArguments, Trainer

# Pads each batch on the fly (inputs with the pad token, labels with -100).
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

training_args = TrainingArguments(
    output_dir="t5_lora_pubmed",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=5e-5,
    num_train_epochs=10,
    logging_steps=10,
    warmup_ratio=0.1,
    # Evaluate and checkpoint once per epoch (load_best_model_at_end requires matching
    # strategies). When training ends, restore the checkpoint with the lowest eval loss;
    # otherwise early stopping leaves the weights of the last, already-worse epochs in `model`.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=eval_tokenized,
    data_collator=data_collator,
    callbacks=[
        EarlyStoppingCallback(
            early_stopping_patience=2,    # stop after 2 consecutive epochs without eval-loss improvement
            early_stopping_threshold=0.0
        )
    ]
)

In [None]:
# Run LoRA fine-tuning; the returned TrainOutput is displayed as the cell output.
train_output = trainer.train()
train_output


In [None]:
n_eval = 200

# Generation must run in inference mode. Otherwise LoRA dropout (0.1) and T5 dropout stay
# active and the generated summaries become non-deterministic.
model.eval()

# Materialize the evaluation rows once instead of re-reading ds["test"][i] per field.
test_subset = ds["test"].select(range(n_eval))
articles = list(test_subset["article"])
expected_summaries = list(test_subset["abstract"])
generated_summaries = []

for i, article_text in enumerate(articles):
    # T5 summarization task prefix; inputs truncated to the 512-token encoder budget.
    inputs = tokenizer(
        "summarize: " + article_text,
        return_tensors="pt",
        truncation=True,
        max_length=512
    ).to(model.device)

    with torch.no_grad():
        gen_ids = model.generate(
            **inputs,
            max_length=512,
            num_beams=4,
            early_stopping=True
        )
    generated_summaries.append(
        tokenizer.decode(gen_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    )

    if (i+1) % 10 == 0:
        print(f"Processed test article {i+1}/{n_eval}")

df = pd.DataFrame({
    "article": articles,
    "expected_abstract": expected_summaries,
    "generated_abstract": generated_summaries
})

# Save next to the other results. PATH may be a directory (default '.') or a file path,
# in which case its parent directory is used. This replaces a hardcoded Colab Drive
# location that does not exist outside Colab.
results_dir = PATH if os.path.isdir(PATH) else (os.path.dirname(PATH) or ".")
df.to_csv(os.path.join(results_dir, "lora-summaries.csv"), index=False)


In [None]:
np.save(
    PATH,
    np.array(generated_summaries, dtype=object),
    allow_pickle=True
)


In [None]:
arr = np.load(PATH, allow_pickle=True)
print(arr.shape)