## Fine-tuning T5 with LoRA on the GoEmotions and DialogSum Datasets

In [4]:
from datasets import load_dataset, concatenate_datasets, DatasetDict

# Load the two source corpora: the "simplified" configuration of GoEmotions and
# the DialogSum dataset. Each call returns an object that is indexed by split
# name ("train", "validation", "test") in the cells below.
goemotions = load_dataset("google-research-datasets/go_emotions", "simplified")
dialogsum = load_dataset("knkarthick/dialogsum")


#preprocess GoEmotions dataset
def preprocess_goemotions(batch, label_names=None):
    """Build text-to-text pairs from a batch of GoEmotions rows.

    Args:
        batch: Mapping with parallel lists under "text" and "labels".
        label_names: Optional sequence that maps a label id to its name. When
            given, each row's label ids are rendered as comma-separated names
            (e.g. "joy, love"). When omitted, the raw value of ``labels`` is
            interpolated as-is, which is the original behaviour.
            NOTE(review): in the "simplified" config the labels look like lists
            of integer class ids, so the default prompt reads "Emotion: [27]";
            confirm against the dataset card. To get readable names, pass
            ``fn_kwargs={"label_names": [...]}`` to ``.map``.

    Returns:
        dict with "input_text" (the emotion-conditioned prompt) and
        "target_text" (the unchanged source text), one entry per input row.
    """
    def render(label):
        # Without a name table keep the value untouched so the default output
        # is identical to the original implementation.
        if label_names is None:
            return label
        ids = label if isinstance(label, (list, tuple)) else [label]
        return ", ".join(label_names[i] for i in ids)

    inputs = [
        f"Emotion: {render(label)} Context: {text}"
        for text, label in zip(batch["text"], batch["labels"])
    ]
    targets = batch["text"]  # use the same text as target for simplicity
    return {"input_text": inputs, "target_text": targets}

# Swap the raw GoEmotions columns for the input_text/target_text pair; the
# column list is taken from the train split and reused for the whole mapping.
goemotions_raw_columns = goemotions["train"].column_names
processed_goemotions = goemotions.map(
    preprocess_goemotions,
    batched=True,
    remove_columns=goemotions_raw_columns,
)


#preprocess DialogSum dataset
def preprocess_dialogsum(batch):
    """Turn a batch of DialogSum rows into summarisation prompt/target pairs.

    Every entry of ``batch["dialogue"]`` is prefixed with "summarize: " to form
    the model input; ``batch["summary"]`` is passed through unchanged as the
    target.
    """
    prompts = []
    for dialogue in batch["dialogue"]:
        prompts.append(f"summarize: {dialogue}")
    return {"input_text": prompts, "target_text": batch["summary"]}

# Swap the raw DialogSum columns for the input_text/target_text pair; the
# column list is taken from the train split and reused for the whole mapping.
dialogsum_raw_columns = dialogsum["train"].column_names
processed_dialogsum = dialogsum.map(
    preprocess_dialogsum,
    batched=True,
    remove_columns=dialogsum_raw_columns,
)

Generating train split: 100%|██████████| 43410/43410 [00:00<00:00, 1303811.99 examples/s]
Generating validation split: 100%|██████████| 5426/5426 [00:00<00:00, 843962.53 examples/s]
Generating test split: 100%|██████████| 5427/5427 [00:00<00:00, 874404.11 examples/s]
Generating train split: 100%|██████████| 12460/12460 [00:00<00:00, 78968.72 examples/s]
Generating validation split: 100%|██████████| 500/500 [00:00<00:00, 65683.79 examples/s]
Generating test split: 100%|██████████| 1500/1500 [00:00<00:00, 126767.20 examples/s]
Map: 100%|██████████| 43410/43410 [00:00<00:00, 227634.11 examples/s]
Map: 100%|██████████| 5426/5426 [00:00<00:00, 316654.75 examples/s]
Map: 100%|██████████| 5427/5427 [00:00<00:00, 324950.93 examples/s]
Map: 100%|██████████| 12460/12460 [00:00<00:00, 301598.73 examples/s]
Map: 100%|██████████| 500/500 [00:00<00:00, 149433.66 examples/s]
Map: 100%|██████████| 1500/1500 [00:00<00:00, 250147.35 examples/s]


In [2]:
# The test splits are folded into the training data because no test-set
# inference or metrics are computed in this notebook.
train_parts = [
    corpus[split]
    for corpus in (processed_goemotions, processed_dialogsum)
    for split in ("train", "test")
]
val_parts = [
    corpus["validation"]
    for corpus in (processed_goemotions, processed_dialogsum)
]

train_set = concatenate_datasets(train_parts)
val_set = concatenate_datasets(val_parts)

# Bundle both splits into a single DatasetDict and display its summary.
dataset = DatasetDict(
    {
        "train": train_set,
        "validation": val_set,
    }
)
dataset

DatasetDict({
    train: Dataset({
        features: ['input_text', 'target_text'],
        num_rows: 62797
    })
    validation: Dataset({
        features: ['input_text', 'target_text'],
        num_rows: 5926
    })
})

In [3]:
from transformers import T5Tokenizer

# Checkpoint identifier; reused further down when the base model is loaded.
model_name = "t5-base"
# Tokenizer for that checkpoint. It is read as a global by tokenize_function and
# later handed to the data collator and the Trainer.
tokenizer = T5Tokenizer.from_pretrained(model_name)

#tokenize the datasets
def tokenize_function(batch):
    """Tokenize a batch of prompt/target pairs for seq2seq training.

    Args:
        batch: Mapping with parallel lists of strings under "input_text" and
            "target_text".

    Returns:
        The tokenizer output for the inputs (truncated to 1024 tokens) with an
        added "labels" entry holding the target token ids (truncated to 256
        tokens).

    No padding is applied here on purpose. Padding inside a batched ``map``
    stretches every row to the longest row of its map chunk, so short rows would
    carry long runs of pad tokens through training. Padding is left to the
    ``DataCollatorForSeq2Seq`` used at training time, which pads each training
    batch to its own longest sequence and fills label padding with -100 so it is
    ignored by the loss.
    """
    model_inputs = tokenizer(
        batch["input_text"],
        truncation=True,
        max_length=1024,
    )
    target_encoding = tokenizer(
        batch["target_text"],
        truncation=True,
        max_length=256,
    )
    # Unpadded ids contain no pad tokens, so no -100 masking is needed here.
    model_inputs["labels"] = target_encoding["input_ids"]
    return model_inputs

# Run tokenize_function over the train and validation splits in batches. No
# remove_columns is passed, so input_text/target_text stay next to the new
# token columns.
tokenized_dataset = dataset.map(tokenize_function, batched=True)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Map: 100%|██████████| 62797/62797 [00:28<00:00, 2167.52 examples/s]
Map: 100%|██████████| 5926/5926 [00:02<00:00, 2828.22 examples/s]


In [5]:
from transformers import T5ForConditionalGeneration
from peft import LoraConfig, get_peft_model, TaskType

# Seed the Python, NumPy and PyTorch RNGs *before* the adapter is created so the
# randomly initialised LoRA weights are identical on every Restart & Run All.
from transformers import set_seed

SEED = 42
set_seed(SEED)

#define LoRA configuration
# No target_modules are given, so PEFT's defaults for this architecture apply.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,  # Sequence-to-sequence task
    r=32,                            # LoRA rank
    lora_alpha=32,                   # Scaling factor
    lora_dropout=0.1,                # Regularization
)

#load pre-trained T5 model
base_model = T5ForConditionalGeneration.from_pretrained(model_name)

# Wrap the base model with the LoRA adapter and report how many parameters
# remain trainable. A separate name for the base model avoids rebinding `model`
# from one kind of object to another within the cell.
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

trainable params: 3,538,944 || all params: 226,442,496 || trainable%: 1.5628


In [6]:
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq
import warnings
# Keep UserWarning messages out of the training output.
warnings.filterwarnings("ignore", category=UserWarning)


# One directory receives the per-epoch checkpoints and, after training, the
# final model and tokenizer files.
output_dir = "./t5-lora"

# Pads inputs and labels per batch at training time.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

training_args = TrainingArguments(
    # where things go
    output_dir=output_dir,
    logging_dir="./logs",
    report_to="none",
    # NOTE(review): presumably needed so the Trainer recognises the label
    # column on the PEFT-wrapped model during evaluation; confirm.
    label_names=["labels"],
    # schedule: evaluate and checkpoint once per epoch, keep only one checkpoint
    num_train_epochs=3,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    logging_steps=10,
    # optimisation
    learning_rate=5e-4,  # higher lr for LoRA
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    fp16=True,  # mixed precision
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    processing_class=tokenizer,
)

trainer.train()

# Persist the trained LoRA-wrapped model and the tokenizer side by side; the
# tokenizer call stays last so its returned file list is the cell's output.
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss
1,0.4701,0.092581
2,0.3866,0.088932
3,0.3705,0.087933


('./t5-lora/tokenizer_config.json',
 './t5-lora/special_tokens_map.json',
 './t5-lora/spiece.model',
 './t5-lora/added_tokens.json')