In [None]:
# Setup and Imports
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType
import torch
import datetime
import copy

# Import helper functions
from llm_explore.utils import (get_torch_device, print_number_of_model_parameters,
                                make_n_shot_summary_prompt, get_model_completion)

[32m2025-05-10 16:42:21.682[0m | [1mINFO    [0m | [36mllm_explore.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /Users/zeromh/ds/llm_explore[0m


In [2]:
# Hugging Face Hub identifiers used throughout the notebook.
# Edit these two values to experiment with a different dataset or checkpoint.
DATASET_NAME: str = "knkarthick/dialogsum"
MODEL_NAME: str = "google/flan-t5-base"

In [3]:
# Device Configuration
# Ask the project helper which torch device to run on; it prints its choice
# (the captured output below shows which device was selected on this run).
device = get_torch_device()

Returned MPS device


In [7]:

# Dataset and Model Initialization
dataset = load_dataset(DATASET_NAME)

# Load the weights on CPU (from_pretrained's default without a device_map) and take
# the pristine copy *before* moving to the accelerator. Deep-copying after
# `.to(device)` would briefly hold two full copies of the model in device memory.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)

# Keep an untouched copy of the original model for later comparison (stays on CPU).
# This is needed because get_peft_model in the LoRA cell modifies `model` in place.
# NOTE(review): this copy is float16 on CPU; depending on the torch version some
# fp16 CPU ops are slow or unsupported — cast with `.float()` if generation fails.
model_orig = copy.deepcopy(model)

model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# Parameter Inspection
all_params, trainable_params = print_number_of_model_parameters(model)


Total parameters: 247577856
Trainable parameters: 247577856
Percentage of trainable parameters: 100.00%


In [None]:

# Tokenization and Dataset Preparation
def tokenize_function(example,
                      start_prompt="Summarize the following conversation.\n\n",
                      end_prompt="\n\nSummary: "):
    """Tokenize a batch of dialogues (wrapped in a summarization prompt) and their summaries.

    Args:
        example: a batched dataset slice (mapping of column name -> list of values),
            as passed by ``dataset.map(..., batched=True)``. Must contain the
            'dialogue' and 'summary' columns.
        start_prompt: text placed before each dialogue.
        end_prompt: text placed after each dialogue.

    Returns:
        The tokenizer encoding of the prompts (e.g. input_ids, attention_mask) plus a
        'labels' tensor holding the summary token ids, with padding positions set to
        -100 so they are ignored by the model's loss.
    """
    prompts = [start_prompt + dialogue + end_prompt for dialogue in example['dialogue']]
    output = tokenizer(prompts, truncation=True, padding='max_length', return_tensors='pt')

    # With no explicit max_length, padding='max_length' pads every summary up to the
    # tokenizer's model_max_length. Those pad ids must be replaced by -100 (the index
    # Hugging Face seq2seq losses ignore); otherwise the loss is dominated by
    # "predict padding" targets.
    # NOTE(review): summaries are short, so a smaller max_length for labels would save
    # a lot of compute — consider passing one explicitly.
    labels = tokenizer(example['summary'], truncation=True, padding='max_length',
                       return_tensors='pt').input_ids
    labels[labels == tokenizer.pad_token_id] = -100
    output['labels'] = labels
    return output


# Drop every original text column; training only needs the tokenized fields.
tokenized_dataset = dataset.map(tokenize_function, batched=True,
                                remove_columns=dataset['train'].column_names)
# Optional quick-run subset (every 10th example):
#tokenized_dataset_small = tokenized_dataset.filter(lambda example, index: index % 10 == 0, with_indices=True)


In [None]:

# LoRA Configuration and Model Setup
# Seed before the adapters are created: by default peft initialises the LoRA weights
# randomly, so without a seed every run starts from different adapter weights.
LORA_SEED = 42
torch.manual_seed(LORA_SEED)

lora_config = LoraConfig(
    r=32,                       # rank of the low-rank update matrices
    lora_alpha=32,              # scaling numerator (peft scales updates by lora_alpha / r)
    target_modules=["q", "v"],  # names of the modules to adapt (T5 attention query/value)
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)

# get_peft_model modifies the base `model` in place, so re-running this cell without
# first re-running the model-loading cell would inject adapters into an
# already-modified model.
peft_model = get_peft_model(model, lora_config)
print_number_of_model_parameters(peft_model)


Total parameters: 251116800
Trainable parameters: 3538944
Percentage of trainable parameters: 1.41%


(251116800, 3538944)

In [12]:

# Training with LoRA
# Set to True to actually run training (slow — see the timings at the end of this cell).
RUN_TRAINING = False

timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = f"../models/peft-dialogue-summary-training-{timestamp}"

peft_training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    num_train_epochs=1,
    learning_rate=1e-3,
    logging_steps=20,
    per_device_train_batch_size=2,
    max_steps=-1,
    label_names=["labels"],
    include_num_input_tokens_seen=True,
    # fp16=True,  # for mixed-precision training, but doesn't work on apple silicon
)

peft_trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=tokenized_dataset['train']
)

# Release cached Apple-GPU memory before training. torch.mps only works when the MPS
# backend is available, so guard the call instead of failing on CUDA/CPU machines.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

if RUN_TRAINING:
    peft_trainer.train()

# CPU: Took 24 minutes to do 1 step which is 8 samples
# GPU: Took < 1 minute to do 125 samples
# GPU: Took 8:27 minutes to do 1250 samples


In [None]:
# Model Saving
# Not needed: the Trainer above automatically saves checkpoints and the final model to `output_dir`.

# peft_model.save_pretrained(f"../models/peft-dialogue-summary-training-{timestamp}_lora_results")

In [None]:

# Evaluation and Results
# Which example to summarise — presumably an index into `dataset`; verify against
# make_n_shot_summary_prompt.
my_id = 200
prompt = make_n_shot_summary_prompt(summarize_id=my_id, data=dataset)

# Disable dropout (base model and LoRA) for inference: Trainer.train() leaves the model
# in training mode, and generating in that mode gives noisy, non-deterministic output.
# Harmless if get_model_completion already does this — TODO confirm in that helper.
peft_model.eval()
completion = get_model_completion(prompt, tokenizer, peft_model)
print(completion)

