In [1]:
import sys

from peft import LoraConfig, TaskType
from transformers import AutoModelForSeq2SeqLM

sys.path.append("../")
from src import PromtTuningConfig, HybridPeftWrapper

In [2]:
def print_number_of_trainable_model_params(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    
    print(f"Trainable model parameters: {trainable_model_params}\nAll model parameters: {all_model_params}\nPercentage of trainable parameters: {trainable_model_params/all_model_params*100:.2f}%")

Original Model

In [3]:
model_name = "google/flan-t5-base"

In [4]:
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [5]:
print_number_of_trainable_model_params(original_model)

Trainable model parameters: 247577856
All model parameters: 247577856
Percentage of trainable parameters: 100.00%


PEFT Model - Original Model Only 

In [6]:
peft_model = HybridPeftWrapper.from_config(original_model)

In [7]:
print_number_of_trainable_model_params(peft_model)

Trainable model parameters: 247577856
All model parameters: 247577856
Percentage of trainable parameters: 100.00%


PEFT Model - LoRA

In [8]:
lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)

peft_model = HybridPeftWrapper.from_config(original_model, lora_config=lora_config)

In [9]:
print_number_of_trainable_model_params(peft_model)

Trainable model parameters: 442368
All model parameters: 248020224
Percentage of trainable parameters: 0.18%


PEFT Model - Prompt Tuning

In [10]:
pt_config = PromtTuningConfig(
    n_tokens=30,
    initialize_from_vocab=True,
)

peft_model = HybridPeftWrapper.from_config(original_model, pt_config=pt_config)

In [11]:
print_number_of_trainable_model_params(peft_model)

Trainable model parameters: 23040
All model parameters: 248043264
Percentage of trainable parameters: 0.01%


PEFT Model - LoRA and Prompt Tuning

In [12]:
peft_model = HybridPeftWrapper.from_config(original_model, lora_config=lora_config, pt_config=pt_config)

In [13]:
print_number_of_trainable_model_params(peft_model)

Trainable model parameters: 465408
All model parameters: 248043264
Percentage of trainable parameters: 0.19%
