In [None]:
# TODO: Change PATH to desired file location where results will be saved.
# PATH is read further down by the cells that export generated summaries
# (pandas CSV export and NumPy .npy save/load).
PATH = '.'

In [None]:
!pip3 install bert_score rouge_score

In [None]:
from datasets import load_dataset
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [None]:
# Load the "section" configuration of the ccdv/pubmed-summarization dataset.
# Later cells index the result by its 'train', 'validation' and 'test' splits
# and read the 'article' and 'abstract' fields of each row.
ds = load_dataset("ccdv/pubmed-summarization", "section")

In [None]:
# Keep the checkpoint identifier in its own name so `model` only ever refers
# to the loaded network (it previously held the string first, then the model).
model_name = "t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [None]:
# Count the rows of each split and report every split's share of the whole dataset.
train_size, val_size, test_size = (
    len(ds[split]) for split in ('train', 'validation', 'test')
)
total_size = sum((train_size, val_size, test_size))

print(f'Number of instances in training set = {train_size}; {train_size / total_size} portion of data')
print(f'Number of instances in validation set = {val_size}; {val_size / total_size} portion of data')
print(f'Number of instances in test set = {test_size}; {test_size / total_size} portion of data')

In [None]:
def chunk_paper(text, max_tokens=512, overlap=50, tok=None):
    """Split `text` into overlapping windows of at most `max_tokens` tokens.

    Args:
        text: The document to split.
        max_tokens: Maximum number of tokens per chunk; must be positive.
        overlap: Number of tokens shared by consecutive chunks; must satisfy
            0 <= overlap < max_tokens so the window always moves forward and
            no tokens are skipped between windows.
        tok: Optional object providing `tokenize(str)` and
            `convert_tokens_to_string(list)`. Defaults to the notebook-level
            `tokenizer`.

    Returns:
        A list of chunk strings in document order (empty when `text` yields
        no tokens).

    Raises:
        ValueError: If `max_tokens` or `overlap` is out of range.
    """
    if max_tokens <= 0:
        raise ValueError("max_tokens must be a positive integer")
    if not 0 <= overlap < max_tokens:
        raise ValueError("overlap must satisfy 0 <= overlap < max_tokens")

    tok = tokenizer if tok is None else tok
    tokens = tok.tokenize(text)
    stride = max_tokens - overlap

    chunks = []
    for start in range(0, len(tokens), stride):
        chunks.append(tok.convert_tokens_to_string(tokens[start:start + max_tokens]))
        if start + max_tokens >= len(tokens):
            # This window already reached the end of the document; another
            # window would only repeat tokens from the overlap region.
            break
    return chunks

# T5 Setup

In [None]:
# Token count of every training abstract, measured with the notebook's tokenizer.
abstract_lengths = [
    len(tokenizer.tokenize(abstract_text))
    for abstract_text in ds['train']['abstract']
]

In [None]:
# Summarize the distribution of abstract token counts collected in
# `abstract_lengths`: mean, maximum and selected quantiles.
print(f'Average abstract length = {np.mean(abstract_lengths)}.')
print(f'Max abstract length = {max(abstract_lengths)}.')
print(f'5th percentile abstract length = {np.quantile(abstract_lengths, 0.05)}.')
print(f'25th percentile abstract length = {np.quantile(abstract_lengths, 0.25)}.')
print(f'75th percentile abstract length = {np.quantile(abstract_lengths, 0.75)}.')
print(f'95th percentile abstract length = {np.quantile(abstract_lengths, 0.95)}.')
print(f'99th percentile abstract length = {np.quantile(abstract_lengths, 0.99)}.')

In [None]:
# Two-stage summarization of one training article: summarize each overlapping
# chunk, then summarize the concatenation of the chunk summaries.
paper_chunked = chunk_paper(ds['train'][1]['article'], max_tokens=512, overlap=128)
chunk_summaries = []
for chunk in paper_chunked:
    # Use the same lowercase "summarize: " task prefix as every other cell in
    # this notebook (the previous "Summarize:" had different casing and no space).
    inputs = tokenizer("summarize: " + chunk, return_tensors="pt", truncation=True, max_length=1024)
    summary_ids = model.generate(**inputs, max_length=256, min_length=64)
    chunk_summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

all_summaries = ' '.join(chunk_summaries)
# The second pass is also a summarization request, so it gets the task prefix too.
final_inputs = tokenizer("summarize: " + all_summaries, return_tensors="pt", truncation=True, max_length=1024)
final_summary = model.generate(**final_inputs, min_length=100, max_length=606)
model_summary_text = tokenizer.decode(final_summary[0], skip_special_tokens=True)


In [None]:
from sentence_transformers import SentenceTransformer, util

# Embed the generated summary and the reference abstract with a sentence
# encoder, then report the cosine similarity between the two embeddings.
eval_model = SentenceTransformer('all-mpnet-base-v2')
texts = [model_summary_text, ds['train'][1]['abstract']]
embeddings = eval_model.encode(texts, convert_to_tensor=True)

summary_vec, reference_vec = embeddings[0], embeddings[1]
similarity = util.cos_sim(summary_vec, reference_vec)
print(f"Semantic similarity: {similarity.item():.3f}")


# Full Fine Tuning

In [None]:
!pip install -q peft accelerate

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Checkpoint name shared by the tokenizer and the model that will be fine-tuned.
base_model = "t5-base"

# Rebinds the notebook-level `tokenizer` and creates `t5_model` from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(base_model)
t5_model = AutoModelForSeq2SeqLM.from_pretrained(base_model)

# Move the model to the selected device; as the trailing expression of the
# cell, the returned module is also what the cell displays.
t5_model.to(device)


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

# Re-select the device and make sure the model is on it, so this cell does not
# depend on the earlier device cell having been run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t5_model.to(device)

class PubMedSeq2SeqDataset(Dataset):
    """Torch dataset that turns article/abstract rows into T5 training tensors.

    Each item is a dict with `input_ids` and `attention_mask` for the
    "summarize: "-prefixed article, and `labels` for the abstract with padding
    positions replaced by -100.
    """

    def __init__(self, hf_dataset, tokenizer, inputLen_max=512, targetLen_max=256, size=None):
        """Store the data source and tokenization limits.

        Args:
            hf_dataset: Indexable collection of rows with "article" and
                "abstract" keys; must provide `.select(...)` when `size` is given.
            tokenizer: Callable tokenizer exposing `pad_token_id`.
            inputLen_max: Length articles are truncated/padded to.
            targetLen_max: Length abstracts are truncated/padded to.
            size: If not None, keep only the first `size` rows.
        """
        self.data = hf_dataset if size is None else hf_dataset.select(range(size))
        self.tokenizer = tokenizer
        self.inputLen_max = inputLen_max
        self.targetLen_max = targetLen_max

    def __len__(self):
        """Return the number of rows available."""
        return len(self.data)

    def _encode(self, text, max_length):
        """Tokenize `text`, truncating and padding to exactly `max_length`."""
        return self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return the encoded source and label tensors for row `idx`."""
        record = self.data[idx]

        source = self._encode("summarize: " + record["article"], self.inputLen_max)
        target = self._encode(record["abstract"], self.targetLen_max)

        # Drop the leading batch dimension added by return_tensors="pt".
        target_ids = target["input_ids"].squeeze(0)
        # Replace padding ids in the labels with -100.
        labels = target_ids.masked_fill(target_ids == self.tokenizer.pad_token_id, -100)

        return {
            "input_ids": source["input_ids"].squeeze(0),
            "attention_mask": source["attention_mask"].squeeze(0),
            "labels": labels,
        }


# Build the fine-tuning set from the first 600 training rows (size=600), with
# articles limited to 512 tokens and abstracts to 256 tokens.
trainingData = PubMedSeq2SeqDataset(
    ds["train"],
    tokenizer,
    inputLen_max=512,
    targetLen_max=256,
    size=600,
)
# Batches of two examples, reshuffled every epoch.
train_loader = DataLoader(trainingData, batch_size=2, shuffle=True)

# Trailing expression: shows the dataset length as the cell output.
len(trainingData)


In [None]:
from torch.optim import AdamW

# Full fine-tuning: every parameter of t5_model is handed to AdamW.
t5_model.train()
optimizer = AdamW(t5_model.parameters(), lr=5e-5)

epochs = 5

for epoch in range(epochs):
    batch_losses = []

    for batch in train_loader:
        optimizer.zero_grad()

        # Passing labels makes the model return the training loss directly.
        outputs = t5_model(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            labels=batch["labels"].to(device),
        )
        loss = outputs.loss

        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimizer update.
        torch.nn.utils.clip_grad_norm_(t5_model.parameters(), 1.0)
        optimizer.step()

        batch_losses.append(loss.item())

    totalLoss = sum(batch_losses)
    avg_loss = totalLoss / len(train_loader)
    print(f"Epoch {epoch+1}/{epochs} - train loss: {avg_loss:.4f}")

# Switch back to inference mode; the trailing expression is the cell output.
t5_model.eval()


In [None]:
def generate_ft_summary(article_text, inputLen_max=512, max_sum_len=256):
    """Summarize one article with the notebook-level `t5_model`.

    Args:
        article_text: Raw article text; the "summarize: " prefix is added here.
        inputLen_max: Token limit the prefixed article is truncated to.
        max_sum_len: Maximum length passed to `generate`.

    Returns:
        The decoded summary string with special tokens removed.
    """
    prompt = "summarize: " + article_text
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=inputLen_max,
    ).to(device)

    # Beam search with four beams, stopping once the beams have finished.
    generation_kwargs = dict(max_length=max_sum_len, num_beams=4, early_stopping=True)
    with torch.no_grad():
        output_ids = t5_model.generate(**encoded, **generation_kwargs)

    return tokenizer.decode(
        output_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )


In [None]:
from transformers import AutoModelForSeq2SeqLM
import numpy as np
import torch
from tqdm import tqdm
import pandas as pd

# Load models
# `t5_base` is a second copy loaded from the `base_model` checkpoint, kept
# separate from `t5_model`; both are put in eval mode before generation.
t5_base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
t5_base.eval()
t5_model.eval()

def summary_gen(model, article_text, inputLen_max=512, max_sum_len=512):
    """Summarize one article with the given seq2seq `model`.

    Args:
        model: Model whose `generate` method produces the summary ids.
        article_text: Raw article text; the "summarize: " prefix is added here.
        inputLen_max: Token limit the prefixed article is truncated to.
        max_sum_len: Maximum length passed to `generate`.

    Returns:
        The decoded summary string with special tokens removed.
    """
    prompt = "summarize: " + article_text
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=inputLen_max,
    ).to(device)

    # Beam search with four beams, stopping once the beams have finished.
    generation_kwargs = dict(max_length=max_sum_len, num_beams=4, early_stopping=True)
    with torch.no_grad():
        output_ids = model.generate(**encoded, **generation_kwargs)

    summary_text = tokenizer.decode(
        output_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return summary_text


In [None]:
n_eval = 200

generated_summaries = []
expected_summaries = []

for i in range(n_eval):
    article_text = ds["test"][i]["article"]
    expected_abstract = ds["test"][i]["abstract"]

    inputs = tokenizer(
        "summarize: " + article_text,
        return_tensors="pt",
        truncation=True,
        max_length=512
    ).to(t5_base.device)

    with torch.no_grad():
        gen_ids = t5_base.generate(
            **inputs,
            max_length=512,
            num_beams=4,
            early_stopping=True
        )

    generated_summary = tokenizer.decode(
        gen_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True
    )

    # storing the text summaries
    generated_summaries.append(generated_summary)
    expected_summaries.append(expected_abstract)

    print(f"Baseline: Processed {i+1}/{n_eval}")

df = pd.DataFrame({
    "article": [ds["test"][i]["article"] for i in range(n_eval)],
    "expected_abstract": expected_summaries,
    "generated_abstract": generated_summaries
})

df.to_csv(PATH, index=False)

In [None]:
np.save(
    PATH,
    np.array(generated_summaries, dtype=object),
    allow_pickle=True
)

In [None]:
arr_baseline = np.load(PATH, allow_pickle=True)
print(arr_baseline.shape)

In [None]:
generated_summaries_finetune = []
expected_summaries_finetune = []

for i in range(n_eval):
    article_text = ds["test"][i]["article"]
    expected_abstract = ds["test"][i]["abstract"]

    inputs = tokenizer(
        "summarize: " + article_text,
        return_tensors="pt",
        truncation=True,
        max_length=512
    ).to(t5_model.device)

    with torch.no_grad():
        gen_ids = t5_model.generate(
            **inputs,
            max_length=512,
            num_beams=4,
            early_stopping=True
        )

    generated_summary_finetune = tokenizer.decode(
        gen_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True
    )

    # storing the summaries
    generated_summaries_finetune.append(generated_summary_finetune)
    expected_summaries_finetune.append(expected_abstract)

    print(f"Full Fine Tune: Processed {i+1}/{n_eval}")

df = pd.DataFrame({
    "article": [ds["test"][i]["article"] for i in range(n_eval)],
    "expected_abstract": expected_summaries_finetune,
    "generated_abstract": generated_summaries_finetune
})

# saving into CSV
df.to_csv(PATH, index=False)

In [None]:
np.save(
    PATH,
    np.array(generated_summaries_finetune, dtype=object),
    allow_pickle=True
)

In [None]:
arr_finetune = np.load(PATH, allow_pickle=True)
print(arr_finetune.shape)

# SK-tuning (Final script)

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

base_model = "t5-base"

# Reload tokenizer and model from the base checkpoint. This rebinds
# `t5_model`, so the cells below start from the pretrained weights rather than
# from whatever `t5_model` held before.
tokenizer = AutoTokenizer.from_pretrained(base_model)
t5_model = AutoModelForSeq2SeqLM.from_pretrained(base_model)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t5_model = t5_model.to(device)

print("Loaded T5 Base and moved to device:", device)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
import numpy as np

# Re-select the device and make sure the model is on it (repeats the setup of
# the previous cell so this cell can be run on its own after the imports).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Trailing expression of the cell, so the returned module is what gets displayed.
t5_model.to(device)


In [None]:
class PubMedSeq2SeqDataset(Dataset):
    """Torch dataset that turns article/abstract rows into T5 training tensors.

    This definition replaces the earlier class of the same name once the cell
    runs. It keeps its own defaults (1024-token inputs, first 600 rows) but
    also accepts `size=None`, like the earlier definition, to use every row.

    Each item is a dict with `input_ids` and `attention_mask` for the
    "summarize: "-prefixed article, and `labels` for the abstract with padding
    positions replaced by -100.
    """

    def __init__(self, hf_dataset, tokenizer, inputLen_max=1024, targetLen_max=256, size=600):
        """Store the data source and tokenization limits.

        Args:
            hf_dataset: Indexable collection of rows with "article" and
                "abstract" keys; must provide `.select(...)` when `size` is given.
            tokenizer: Callable tokenizer exposing `pad_token_id`.
            inputLen_max: Length articles are truncated/padded to.
            targetLen_max: Length abstracts are truncated/padded to.
            size: Keep only the first `size` rows; None keeps the whole dataset.
        """
        if size is not None:
            self.data = hf_dataset.select(range(size))
        else:
            # Previously range(None) raised TypeError here.
            self.data = hf_dataset
        self.tokenizer = tokenizer
        self.inputLen_max = inputLen_max
        self.targetLen_max = targetLen_max

    def __len__(self):
        """Return the number of rows available."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the encoded source and label tensors for row `idx`."""
        item = self.data[idx]
        sourceText = "summarize: " + item["article"]
        targetText = item["abstract"]

        enc = self.tokenizer(
            sourceText,
            truncation=True,
            padding="max_length",
            max_length=self.inputLen_max,
            return_tensors="pt",
        )

        dec = self.tokenizer(
            targetText,
            truncation=True,
            padding="max_length",
            max_length=self.targetLen_max,
            return_tensors="pt",
        )

        # Drop the leading batch dimension added by return_tensors="pt".
        input_ids = enc["input_ids"].squeeze(0)
        attention_mask = enc["attention_mask"].squeeze(0)

        labels = dec["input_ids"].squeeze(0)
        # Replace padding ids in the labels with -100.
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

# Build the training set from the first 600 training rows; input and target
# lengths are left at the class defaults.
trainingData = PubMedSeq2SeqDataset(ds["train"], tokenizer, size=600)
# Batches of two examples, reshuffled every epoch.
train_loader = DataLoader(trainingData, batch_size=2, shuffle=True)

# Trailing expression: shows the dataset length as the cell output.
len(trainingData)


In [None]:
# Natural-language description of the task used to initialise the soft prompt.
sk_init_text = (
    "Summarize biomedical research articles into clear, structured abstracts "
    "with background, methods, results, and conclusions using formal scientific language."
)

# Number of trainable prompt vectors prepended to every input, and the weight
# of the term that keeps them close to their initial values.
num_virtual_tokens = 20
lambda_semantic = 0.1

embed_tokens = t5_model.encoder.embed_tokens

with torch.no_grad():
    # Average the token embeddings of the init text into a single vector.
    init_ids = tokenizer(sk_init_text, return_tensors="pt").input_ids.to(device)
    init_embeds = embed_tokens(init_ids)
    init_mean = init_embeds.mean(dim=1).squeeze(0)

# Every virtual token starts as a copy of that mean embedding.
prompt_init = init_mean.repeat(num_virtual_tokens, 1).to(device)

# Frozen reference copy used by the semantic regularisation term.
fixed_prompt_embeddings = prompt_init.clone().detach()

# Build the Parameter from a tensor that is already on `device`. Calling
# .to(device) on an nn.Parameter afterwards only hands back the same trainable
# leaf when no transfer is needed; otherwise it yields a non-leaf tensor that
# the optimizer cannot update.
prompt_embeddings = nn.Parameter(prompt_init.clone().detach())

# Only the prompt vectors are optimised.
optimizer = torch.optim.AdamW([prompt_embeddings], lr=5e-3)

# Freeze every T5 weight so gradients reach the prompt vectors only.
for p in t5_model.parameters():
    p.requires_grad = False

# Trailing expression of the cell, so the returned module is what gets displayed.
t5_model.train()


In [None]:
epochs = 10

for epoch in range(epochs):
    # Running sums of the combined, task-only and semantic-only losses.
    totalLoss = total_task_loss = total_sem_loss = 0.0

    for batch in train_loader:
        input_ids, attention_mask, labels = (
            batch[key].to(device) for key in ("input_ids", "attention_mask", "labels")
        )
        batch_size = input_ids.size(0)

        # Prepend the shared soft prompt to each sequence's token embeddings.
        prompt_batch = prompt_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        inputs_embeds = torch.cat([prompt_batch, embed_tokens(input_ids)], dim=1)

        # Extend the attention mask with ones covering the prompt positions.
        prompt_mask = torch.ones(
            batch_size,
            num_virtual_tokens,
            dtype=attention_mask.dtype,
            device=device,
        )
        extended_attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)

        optimizer.zero_grad()

        # Passing labels makes the model return the sequence loss directly.
        task_loss = t5_model(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_attention_mask,
            labels=labels,
        ).loss

        # Penalise the prompt for drifting away (in cosine terms) from its
        # initial embeddings.
        semantic_loss = 1.0 - F.cosine_similarity(
            prompt_embeddings, fixed_prompt_embeddings, dim=-1
        ).mean()

        loss = task_loss + lambda_semantic * semantic_loss
        loss.backward()
        optimizer.step()

        totalLoss += loss.item()
        total_task_loss += task_loss.item()
        total_sem_loss += semantic_loss.item()

    n_batches = len(train_loader)
    avg_loss = totalLoss / n_batches
    avg_task = total_task_loss / n_batches
    avg_sem = total_sem_loss / n_batches

    print(
        f"Epoch {epoch+1}: total={avg_loss:.4f}, "
        f"task={avg_task:.4f}, semantic={avg_sem:.4f}"
    )

# Switch back to inference mode; the trailing expression is the cell output.
t5_model.eval()


In [None]:
def generate_sk_summary(article_text, inputLen_max=512, max_sum_len=256):
    """Summarize one article with `t5_model` and the learned soft prompt.

    The notebook-level `prompt_embeddings` are prepended to the article's
    token embeddings and the attention mask is extended to cover them.

    Args:
        article_text: Raw article text; the "summarize: " prefix is added here.
        inputLen_max: Token limit the prefixed article is truncated to
            (the prompt vectors are added on top of this).
        max_sum_len: Maximum length passed to `generate`.

    Returns:
        The decoded summary string with special tokens removed.
    """
    encoded = tokenizer(
        "summarize: " + article_text,
        return_tensors="pt",
        truncation=True,
        max_length=inputLen_max,
    ).to(device)
    token_ids, token_mask = encoded["input_ids"], encoded["attention_mask"]
    n_rows = token_ids.size(0)

    # Soft prompt first, then the regular token embeddings.
    soft_prompt = prompt_embeddings.unsqueeze(0).expand(n_rows, -1, -1)
    encoder_inputs = torch.cat([soft_prompt, embed_tokens(token_ids)], dim=1)

    # Ones for the prompt positions, followed by the tokenizer's own mask.
    soft_prompt_mask = torch.ones(
        n_rows, num_virtual_tokens,
        dtype=token_mask.dtype,
        device=device,
    )
    encoder_mask = torch.cat([soft_prompt_mask, token_mask], dim=1)

    with torch.no_grad():
        output_ids = t5_model.generate(
            inputs_embeds=encoder_inputs,
            attention_mask=encoder_mask,
            max_length=max_sum_len,
            num_beams=4,
            early_stopping=True,
        )

    return tokenizer.decode(
        output_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )


In [None]:
# import nltk
# nltk.download("wordnet")
# nltk.download("punkt")

# from nltk.translate.meteor_score import meteor_score

In [None]:
n_eval = 200

generated_summaries_sk = []
expected_summaries_sk = []

for i in range(n_eval):
    article = ds["test"][i]["article"]
    ref_abs = ds["test"][i]["abstract"]

    pred_sk = generate_sk_summary(article)

    # store text
    generated_summaries_sk.append(pred_sk)
    expected_summaries_sk.append(ref_abs)

    print(f"[SK] Processed {i+1}/{n_eval}")

# Build dataframe
df = pd.DataFrame({
    "article": [ds["test"][i]["article"] for i in range(n_eval)],
    "expected_abstract": expected_summaries_sk,
    "generated_abstract": generated_summaries_sk
})

# saving into CSV
df.to_csv(PATH, index=False)

In [None]:
np.save(
    PATH,
    np.array(generated_summaries_sk, dtype=object),
    allow_pickle=True
)

In [None]:
arr_baseline = np.load(PATH, allow_pickle=True)
print(arr_baseline.shape)