In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"


In [2]:
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, T5ForConditionalGeneration
from tqdm.auto import tqdm
import json
import os


In [3]:
raw_datasets = load_dataset("cnn_dailymail", "3.0.0")


In [4]:
tokenizer_path = "./checkpoint/tokenizer"
tokenized_dataset_path = "./checkpoint/tokenized_dataset"

if os.path.exists(tokenizer_path):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
else:
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    tokenizer.save_pretrained(tokenizer_path)


In [5]:
prefix = "summarize: "

def preprocess(example):
    inputs = [prefix + doc for doc in example["article"]]
    model_inputs = tokenizer(
        inputs,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors=None
    )

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            example["highlights"],
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors=None
        )

    # Replace pad_token_id (typically 0) with -100 to ignore padding in loss
    labels["input_ids"] = [
        [(token if token != tokenizer.pad_token_id else -100) for token in label]
        for label in labels["input_ids"]
    ]

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


if os.path.exists(tokenized_dataset_path):
    print("Loading tokenized dataset from checkpoint...")
    tokenized_datasets = load_from_disk(tokenized_dataset_path)
else:
    print("Tokenizing dataset...")
    tokenized_datasets = raw_datasets.map(preprocess, batched=True, remove_columns=["article", "highlights", "id"])
    tokenized_datasets.save_to_disk(tokenized_dataset_path)

# train_data = tokenized_datasets["train"]
# eval_data = tokenized_datasets["validation"]
train_data = tokenized_datasets["train"]
eval_data = tokenized_datasets["validation"]



Loading tokenized dataset from checkpoint...


In [7]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, EarlyStoppingCallback
import evaluate

def get_trainer(model, output_dir):
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        predict_with_generate=True,
        num_train_epochs=5,
        learning_rate=3e-5,
        weight_decay=0.01,
        logging_dir=f'{output_dir}/logs',
        logging_steps=10,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="rougeL",
        generation_max_length=128,       # <-- set appropriate max length
        generation_num_beams=4,  
        report_to="none",
        local_rank=-1,  
    )
    rouge = evaluate.load("rouge")

    import torch
    import numpy as np
    from torch.nn.parallel import DataParallel

    def compute_metrics(eval_pred):
        preds, labels = eval_pred

        # Handle tuple output (common in Hugging Face models)
        if isinstance(preds, tuple):
            preds = preds[0]

        # Ensure tensors are on CPU and handle multi-GPU gathering
        if isinstance(preds, torch.Tensor):
            if torch.cuda.device_count() > 1:
                # If using DataParallel, ensure proper gathering
                preds = preds if preds.dim() > 0 else preds.unsqueeze(0)
            preds = preds.cpu().numpy()
        if isinstance(labels, torch.Tensor):
            if torch.cuda.device_count() > 1:
                labels = labels if labels.dim() > 0 else labels.unsqueeze(0)
            labels = labels.cpu().numpy()

        # Convert to lists
        preds = preds.tolist() if isinstance(preds, np.ndarray) else preds
        labels = labels.tolist() if isinstance(labels, np.ndarray) else labels

        # Debug: Inspect data before cleaning
        print("Sample preds before cleaning:", preds[:1])
        print("Sample labels before cleaning:", labels[:1])

        # Clean predictions and labels: replace -100 with tokenizer.pad_token_id
        # Ensure token IDs are within valid range
        vocab_size = tokenizer.vocab_size
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id  # Fallback if pad_token_id is None
        preds = [
            [token if token != -100 else tokenizer.pad_token_id for token in seq]
            for seq in preds
        ]
        labels = [
            [token if token != -100 else tokenizer.pad_token_id for token in seq]
            for seq in labels
        ]

        # Clamp token IDs to valid range [0, vocab_size - 1]
        preds = [[int(min(max(token, 0), vocab_size - 1)) for token in seq] for seq in preds]
        labels = [[int(min(max(token, 0), vocab_size - 1)) for token in seq] for seq in labels]

        # Debug: Inspect cleaned data
        print("Sample preds after cleaning:", preds[:1])
        print("Sample labels after cleaning:", labels[:1])

        try:
            decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
            # Debug: Inspect decoded outputs
            print("Sample decoded preds:", decoded_preds[:1])
            print("Sample decoded labels:", decoded_labels[:1])
        except Exception as e:
            print("Decoding error:", e)
            print("Sample bad preds:", preds[:1])
            print("Sample bad labels:", labels[:1])
            return {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}

        # Compute ROUGE scores
        result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        
        return {
            "rouge1": round(result["rouge1"], 4),
            "rouge2": round(result["rouge2"], 4),
            "rougeL": round(result["rougeL"], 4),
        }



    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=eval_data,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )
    return trainer


In [8]:
# 1. Prompt Tuning
from peft import PromptTuningConfig, get_peft_model, TaskType

model_pt = T5ForConditionalGeneration.from_pretrained("t5-base")
model_pt.save_pretrained("./checkpoint/t5-prompt-base")
prompt_config = PromptTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    num_virtual_tokens=20,
    tokenizer_name_or_path="t5-base"
)
model_pt = get_peft_model(model_pt, prompt_config)

trainer_pt = get_trainer(model_pt, "./checkpoint/t5-prompt-tuning")
trainer_pt.train()
results_pt = trainer_pt.evaluate()
print("Prompt Tuning Results:", results_pt)


[2025-05-28 07:45:34,176] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/thaole/miniconda3/envs/hifed/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/home/thaole/miniconda3/envs/hifed/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/home/thaole/miniconda3/envs/hifed/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/home/thaole/miniconda3/envs/hifed/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/home/thaole/miniconda3/envs/hifed/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::chrono::_V2::steady_clock::now()@GLIBCXX_3.4.19'
/home/thaole/miniconda3/envs/hifed/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)@GLIBCXX_3.4'
/home/thaole/minic

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel
1,2.306,2.343285,0.4009,0.1808,0.2799
2,2.2361,2.322655,0.4093,0.1887,0.2876
3,2.2303,2.316787,0.4105,0.1895,0.2887
4,2.1721,2.31357,0.4106,0.1895,0.2886
5,2.1812,2.312519,0.4106,0.1895,0.2886


Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 160, 1876, 47, 15574, 15, 26, 57, 331, 3026, 45, 18936, 18, 60, 3389, 4741, 14152, 3, 5, 8, 3741, 13, 27503, 19, 12, 36, 11972, 95, 9030, 1135, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 160, 23721, 3, 13804, 95, 28, 600, 331, 3, 5, 1296, 1221, 1204, 15127, 7, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -10



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 160, 23721, 3, 13804, 95, 28, 600, 331, 6, 11, 1296, 1221, 1204, 15127, 7, 3, 5, 96, 20349, 7, 21, 66, 8, 380, 11, 14394, 976, 3, 9, 1670, 30, 3, 9, 13301, 543, 16, 160, 564, 608, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 160, 23721, 3, 13804, 95, 28, 600, 331, 6, 11, 1296, 1221, 1204, 15127, 7, 3, 5, 96, 20349, 7, 21, 66, 8, 380, 11, 14394, 976, 3, 9, 1670, 30, 3, 9, 13301, 543, 16, 160, 564, 608, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 160, 23721, 3, 13804, 95, 28, 600, 331, 6, 11, 1296, 1221, 1204, 15127, 7, 3, 5, 96, 20349, 7, 21, 66, 8, 380, 11, 14394, 976, 3, 9, 1670, 30, 3, 9, 13301, 543, 16, 160, 564, 608, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 160, 23721, 3, 13804, 95, 28, 600, 331, 6, 11, 1296, 1221, 1204, 15127, 7, 3, 5, 96, 20349, 7, 21, 66, 8, 380, 11, 14394, 976, 3, 9, 1670, 30, 3, 9, 13301, 543, 16, 160, 564, 608, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,

In [9]:
# 2. Layer Freezing (freeze encoder)
model_lf = T5ForConditionalGeneration.from_pretrained("t5-base")
model_lf.save_pretrained("./checkpoint/t5-layer-base")
for param in model_lf.encoder.parameters():
    param.requires_grad = False

trainer_lf = get_trainer(model_lf, "./checkpoint/t5-layer-freeze")
trainer_lf.train()
results_lf = trainer_lf.evaluate()
print("Layer Freezing Results:", results_lf)


  trainer = Seq2SeqTrainer(


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel
1,1.5728,1.548251,0.4354,0.2093,0.3052
2,1.5717,1.539878,0.4355,0.2093,0.3053
3,1.5333,1.53702,0.4358,0.2098,0.3057
4,1.5248,1.536475,0.436,0.2095,0.3053
5,1.5302,1.536032,0.4362,0.2097,0.3055


Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 1347, 23721, 3, 13804, 95, 28, 600, 331, 6, 11, 1296, 1221, 1204, 15127, 7, 3, 5, 96, 196, 214, 48, 1297, 2027, 19, 231, 4038, 145, 66, 13, 178, 976, 255, 845, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 1347, 23721, 3, 13804, 95, 28, 600, 331, 6, 11, 1296, 1221, 1204, 15127, 7, 3, 5, 96, 196, 214, 48, 1297, 2027, 19, 231, 4038, 145, 66, 13, 178, 976, 255, 845, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -100, -100, -100, -100, -100, -100, -100, -100]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 1347, 23721, 3, 13804, 95, 28, 600, 331, 6, 11, 1296, 1221, 1204, 15127, 7, 3, 5, 96, 196, 214, 48, 1297, 2027, 19, 231, 4038, 145, 66, 13, 178, 976, 255, 845, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -100, -100, -100, -100, -100, -100, -100, -100]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 1347, 23721, 3, 13804, 95, 28, 600, 331, 6, 11, 1296, 1221, 1204, 15127, 7, 3, 5, 96, 196, 214, 48, 1297, 2027, 19, 231, 4038, 145, 66, 13, 178, 976, 255, 845, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 1347, 23721, 3, 13804, 95, 28, 600, 331, 6, 11, 1296, 1221, 1204, 15127, 7, 3, 5, 96, 196, 214, 48, 1297, 2027, 19, 231, 4038, 145, 66, 13, 178, 976, 255, 845, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100

There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 1347, 23721, 3, 13804, 95, 28, 600, 331, 6, 11, 1296, 1221, 1204, 15127, 7, 3, 5, 96, 196, 214, 48, 1297, 2027, 19, 231, 4038, 145, 66, 13, 178, 976, 255, 845, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -100, -100, -100, -100, -100, -100, -100, -100]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100

In [10]:
# 3. LoRA Fine-tuning
from peft import LoraConfig

model_lora = T5ForConditionalGeneration.from_pretrained("t5-base")
model_lora.save_pretrained("./checkpoint/t5-lora-base")
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)
model_lora = get_peft_model(model_lora, lora_config)

trainer_lora = get_trainer(model_lora, "./checkpoint/t5-lora")
trainer_lora.train()
results_lora = trainer_lora.evaluate()
print("LoRA Results:", results_lora)


  trainer = Seq2SeqTrainer(
No label_names provided for model class `PeftModelForSeq2SeqLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel
1,1.6046,1.577658,0.4342,0.2078,0.3036
2,1.6085,1.569851,0.435,0.2084,0.3046
3,1.5659,1.566509,0.4354,0.2089,0.3048
4,1.5668,1.565153,0.4355,0.2089,0.3046
5,1.5721,1.565161,0.4351,0.2086,0.3045


Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 1347, 23721, 3, 13804, 95, 28, 600, 331, 3, 5, 7643, 1221, 33, 4281, 15127, 7, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -100, -100, -100, -100, -100, -100, -100]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 1347, 23721, 3, 13804, 95, 28, 600, 331, 3, 5, 7643, 1221, 33, 4281, 15127, 7, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 1347, 23721, 3, 13804, 95, 28, 600, 331, 3, 5, 7643, 1221, 33, 4281, 15127, 7, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -100, -100, -100, -100, -100, -100, -100]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 1347, 23721, 47, 15574, 15, 26, 57, 600, 331, 3, 5, 7643, 1221, 33, 4281, 15127, 7, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -100, -100, -100, -100, -100, -100, -100]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 1347, 23721, 47, 15574, 15, 26, 57, 600, 331, 3, 5, 7643, 1221, 33, 4281, 15127, 7, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -100, -100, -100, -100, -100, -100, -100]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,



Sample preds before cleaning: [[0, 1811, 6073, 4027, 302, 7, 986, 1891, 80, 13, 160, 11546, 7, 12, 3, 9, 13037, 3, 5, 1347, 23721, 3, 13804, 95, 28, 600, 331, 3, 5, 7643, 1221, 33, 4281, 15127, 7, 3, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -100, -100, -100, -100, -100, -100, -100]]
Sample labels before cleaning: [[1811, 6073, 4027, 302, 7, 986, 1500, 12, 428, 3, 9, 11546, 12, 3, 9, 13037, 3, 5, 71, 126, 1218, 478, 2139, 160, 9294, 18421, 15127, 7, 21, 1296, 11546, 1221, 3, 5, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -

In [11]:
pip install tabulate

Note: you may need to restart the kernel to use updated packages.


In [12]:
from tabulate import tabulate
print(tabulate([
    ["Prompt Tuning", results_pt['eval_rouge1'], results_pt['eval_rouge2'], results_pt['eval_rougeL']],
    ["Layer Freezing", results_lf['eval_rouge1'], results_lf['eval_rouge2'], results_lf['eval_rougeL']],
    ["LoRA", results_lora['eval_rouge1'], results_lora['eval_rouge2'], results_lora['eval_rougeL']]
], headers=["Method", "ROUGE-1", "ROUGE-2", "ROUGE-L"]))


Method            ROUGE-1    ROUGE-2    ROUGE-L
--------------  ---------  ---------  ---------
Prompt Tuning      0.4105     0.1895     0.2887
Layer Freezing     0.4358     0.2098     0.3057
LoRA               0.4354     0.2089     0.3048


In [13]:
from transformers import pipeline

summarizer = pipeline("summarization", model=model_lora, tokenizer=tokenizer)
article = """
NASA's Perseverance rover has successfully collected samples from Mars that may contain signs of ancient microbial life. Scientists are now preparing to bring the samples back to Earth for further analysis, hoping to answer the age-old question of whether life ever existed on the red planet.
"""
summary = summarizer("summarize: " + article, max_length=128, min_length=30, do_sample=False)
print("\nExample Article:", article)
print("\nExample Summary:\n", summary[0]['summary_text'])


Device set to use cuda:0
Your max_length is set to 128, but your input_length is only 67. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=33)
Both `max_new_tokens` (=256) and `max_length`(=128) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



Example Article: 
NASA's Perseverance rover has successfully collected samples from Mars that may contain signs of ancient microbial life. Scientists are now preparing to bring the samples back to Earth for further analysis, hoping to answer the age-old question of whether life ever existed on the red planet.


Example Summary:
 NASA's Perseverance rover has successfully collected samples from Mars . Scientists are now preparing to bring the samples back to Earth for further analysis .
