In [1]:
!pip install datasets
! pip install peft
from datasets import load_dataset, load_from_disk
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import torch
import os
import numpy as np

from torch.utils.data import DataLoader
from torch.nn import functional as F

import json
# Mount to google drive either click it or add a block cell

# Change it to your google drive path where this notebook located.
drive_path = '/content/drive/MyDrive/Projects/CryptoniteAnalysis/'
os.chdir(drive_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Collecting datasets
  Downloading datasets-3.0.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.0.0-py3-none-any.whl (474 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.3/474.3 kB[0m [31m36.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
[2K

## Load Model

In [2]:
from transformers import RobertaForMultipleChoice, RobertaConfig, RobertaForCausalLM

def load_roberta_MC2GEN(import_dir):
    # Load the pre-trained multiple-choice model
    # model_name = "path_to_your_roberta_multiple_choice_model"
    # multiple_choice_model = RobertaForMultipleChoice.from_pretrained(model_name)
    # Extract the underlying RoBERTa model
    config = RobertaConfig.from_pretrained(import_dir)
    config.is_decoder = True  # Set the model to be a decoder
    config.add_cross_attention = False  # If you need cross-attention

    # Create a new model with a generative head
    generative_model = RobertaForCausalLM(config)

    trained_model = RobertaForMultipleChoice.from_pretrained(import_dir)

    # Load the weights from the multiple-choice model to the new generative model
    multiple_choice_state_dict = trained_model.roberta.state_dict()

    # Filter out the weights related to cross-attention and pooler
    filtered_state_dict = {k: v for k, v in multiple_choice_state_dict.items() if "crossattention" not in k and "pooler" not in k}

    # Load the filtered state dict
    missing_keys, unexpected_keys = generative_model.roberta.load_state_dict(filtered_state_dict, strict=False)
    return generative_model

def load_normal_model(pretrained_model_dir):
    model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_dir)
    return model

def load_lora_model(pretrained_model_dir, hyperparameters, model_name):
    lora_config_dict = hyperparameters['lora_config']
    lora_config = LoraConfig(**lora_config_dict)

    # load base model
    base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # load the pretrained lora parameters
    lora_params_fp = os.path.join(pretrained_model_dir, "lora_params.pth")
    lora_params = torch.load(lora_params_fp, map_location=device)

    # add the lora parameters to the model
    lora_model = get_peft_model(base_model, lora_config)
    model_dict = lora_model.state_dict()
    model_dict.update(lora_params)
    lora_model.load_state_dict(model_dict)
    return lora_model

def get_dataloaders(tokenized_dataset_fp, batch_size, subsample=False, remove_enumeration=False):
    # load the preprocessed dataset
    tokenized_datasets = load_from_disk(tokenized_dataset_fp)
    # tokenized_datasets = tokenized_datasets.filter(lambda x: x['enumeration'] == '(9)')
    # comment out the following line if you are running robertaMC2GEN
    if remove_enumeration:
        tokenized_datasets = tokenized_datasets.remove_columns(['enumeration'])

    if subsample:
        # for testing purposes
        n = 16 * 2
        tokenized_datasets['test'] = tokenized_datasets['test'].select(range(n))
        tokenized_datasets['validation'] = tokenized_datasets['validation'].select(range(n))
        tokenized_datasets['train'] = tokenized_datasets['train'].select(range(n))

    tokenized_datasets.set_format("torch")

    # initialize dataloaders
    dataloaders = {}
    dataloaders['train'] = DataLoader(tokenized_datasets['train'], batch_size=batch_size, shuffle=True)
    dataloaders['test'] = DataLoader(tokenized_datasets['test'], batch_size=batch_size)
    dataloaders['validation'] = DataLoader(tokenized_datasets['validation'], batch_size=batch_size, shuffle=True)
    # shuffle because we want to subsample

    return dataloaders

In [3]:
# Model information
BART_BASE = 'facebook/bart-base'
BART_LARGE_CNN = 'facebook/bart-large-cnn'
T5_SMALL = 'google-t5/t5-small'
T5_LARGE = 'google-t5/t5-large'
ROBERTA_MC2GEN = 'Roberta_MC2GEN'
ROBERTA_MC = 'roberta-base'
models_info = {
    BART_BASE:{
        "model_name":BART_BASE,
        "pretrained_model_path": 'Baselines/Seq2Seq/TrainingData/bart-base/epoch=5_batch=16_lr=5e-05/epoch=4',
        "pretrained_tokenizer_path": "facebook/bart-base",
        "hyperparameters_path": 'Baselines/Seq2Seq/TrainingData/bart-base/epoch=5_batch=16_lr=5e-05/hyper_parameters.json',
        "tokenized_dataset_path": 'Baselines/Seq2Seq/ProcessedDatasets/bart-base',
        "usage": "load_normal_model, parameters: pretrained_model_path",
    },
    BART_LARGE_CNN:{
        "model_name":BART_LARGE_CNN,
        "pretrained_model_path": 'Baselines/Seq2Seq/TrainingData/bart-large-cnn/epoch=3_batch=16_lr=0.0005_LoRA_teacher/epoch=2',
        "pretrained_tokenizer_path": "facebook/bart-large-cnn",
        "hyperparameters_path": 'Baselines/Seq2Seq/TrainingData/bart-large-cnn/epoch=3_batch=16_lr=0.0005_LoRA_teacher/hyper_parameters.json',
        "tokenized_dataset_path": 'Baselines/Seq2Seq/ProcessedDatasets/bart-large-cnn',
        "usage": "load_lora_model, parameters: pretrained_model_dir, hyperparameters, model_name",
    },
    T5_SMALL: {
        "model_name":T5_SMALL,
        "pretrained_model_path": 'Baselines/Seq2Seq/TrainingData/t5-small/epoch=5_batch=16_lr=5e-05_teacher/epoch=4',
        "pretrained_tokenizer_path": "google-t5/t5-small",
        "hyperparameters_path": 'Baselines/Seq2Seq/TrainingData/t5-small/epoch=5_batch=16_lr=5e-05_teacher/hyper_parameters.json',
        "tokenized_dataset_path": 'Baselines/Seq2Seq/ProcessedDatasets/t5-small',
        "usage": "load_normal_model, parameters: pretrained_model_path",
    },
    T5_LARGE: {
        "model_name":T5_LARGE,
        "pretrained_model_path": 'Baselines/Seq2Seq/TrainingData/t5-large/epoch=3_batch=16_lr=0.0005_LoRA_teacher/epoch=2',
        "pretrained_tokenizer_path": "google-t5/t5-large",
        "hyperparameters_path": 'Baselines/Seq2Seq/TrainingData/t5-large/epoch=3_batch=16_lr=0.0005_LoRA_teacher/hyper_parameters.json',
        "tokenized_dataset_path": 'Baselines/Seq2Seq/ProcessedDatasets/t5-large',
        "usage": "load_lora_model, parameters: pretrained_model_dir, hyperparameters, model_name",
    },
    ROBERTA_MC2GEN: {
        "model_name":ROBERTA_MC2GEN,
        "pretrained_model_path": 'Baselines/MultipleChoices/Models/roberta-base/checkpoint-14500',
        "pretrained_tokenizer_path": "roberta-base",
        "hyperparameters_path": 'Baselines/MultiChoiceToGen/Results/RobertaMultiChoince2Generation/hyper_parameters.json',
        "tokenized_dataset_path": 'Baselines/MultiChoiceToGen/ProcessedDatasets/',
        "usage": "load_roberta_MC2GEN, parameters: pretrained_model_path",
    },
    ROBERTA_MC: {
        "model_name": ROBERTA_MC,
        "pretrained_model_path": 'Baselines/MultipleChoices/Models/roberta-base/checkpoint-14500',
        "pretrained_tokenizer_path": "roberta-base",
        "hyperparameters_path": None,
        "tokenized_dataset_path": 'Baselines/MultipleChoices/ProcessedDatasets/TokenizedDatasets',
        "usage": "RobertaForMultipleChoice.from_pretrained, parameters: pretrained_model_path",
    }
}

# Evaluation

## Multiple Choice

In [4]:
def compute_accuracy(preds, labels):
    return (preds == labels).mean()

def evaluate(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in data_loader:
            batch = {k: v.to(model.device) for k, v in batch.items()}
            input_ids = batch['input_ids']
            labels = batch['labels']

            outputs = model(**batch)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    accuracy = compute_accuracy(all_preds, all_labels)

    return accuracy


In [5]:
pretrained_model_dir = models_info[ROBERTA_MC]['pretrained_model_path']
trained_model = RobertaForMultipleChoice.from_pretrained(pretrained_model_dir)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trained_model.to(device)


data_loader = get_dataloaders(models_info[ROBERTA_MC]['tokenized_dataset_path'], batch_size=16)
data_loader = data_loader['test']
accuracy = evaluate(trained_model, data_loader)
print(f"Accuracy: {accuracy:.4f}")

Accuracy: 0.8912


## Seq2seq

In [6]:
def calculate_accuracy(logits, labels, tokenizer):
    '''
    There can be two way to calculate accuracy:
    1. compare what percentage of the output tokens are the same (expect for special tokens)
    If we want to compare number of tokens to be the same, then we can flatten the tokens and compare one by one.
    predictions = predictions.view(-1)
    labels = labels.view(-1)

    2. compare how many answers are correct in a batch
    Then if we want to compare answers, the dumb way is to first batch decode them, and then compare the decoded strings oe by one.
    Another way is to compare the tokens without decoding. But I am not sure how to deal with special tokens (sometimes it might not
    Generate correct end tokens.)

    '''
    predictions = torch.argmax(logits, dim=-1)

    pred_words = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    gold_standard_words = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # calculate correct predictions
    correct_labels, total_labels = 0, 0
    for i in range(len(pred_words)):
        if pred_words[i] == gold_standard_words[i]:
            correct_labels += 1
        total_labels += 1
    accuracy = correct_labels / total_labels
    return accuracy

def customize_loss_and_accuracy(outputs, target, tokenizer):
    '''
    Potential bugs
    The output of the model is by default the same length as the input sequence.
    Here by chance the input sequence is the same length as the labels (I padded them to be 40)
    If we want to be bug free for calculate accuracy, then we need to consider situations when the input sequence is not the same length to the labels.

    After we batch_decode in accuracy, the evaluation time is super high. So I suggest we don't calculate accuracy during training.
    And also I feel like we don't need to evaluate during training -- take too much time.
    '''
    # make the input and target the correct size (input is (batch* seq_len, dictionary_size), output is (batch*seq_len))
    loss = F.cross_entropy(input=outputs.logits.view(-1, outputs.logits.size(-1)), target=target.view(-1))
    accuracy = calculate_accuracy(logits=outputs.logits, labels=target, tokenizer=tokenizer)
    return loss, accuracy

def evaluate_model(model, tokenizer, dataloader):
    model.to(device)
    '''evaluate means validate or test'''
    # set model to eval mode
    model.eval()
    # calculate number of samples being evaluated
    total_validated_samples = 0
    # calculate total loss and total number of correct labels (weighted acuracy)
    total_loss = 0
    total_accurate = 0
    # turn off grad computation
    with torch.no_grad():
        # evaluate batch by batch
        for batch in dataloader:
            # put everything on the right device
            batch =  {k: v.to(device) for k, v in batch.items()}
            batch_size = batch['labels'].shape[0]

            # forward pass in the model
            outputs = model(**batch)

            # accumulate loss and accuracy
            loss, accuracy = customize_loss_and_accuracy(outputs, target=batch['labels'], tokenizer=tokenizer)
            total_loss += loss.item()
            total_accurate += accuracy * batch_size
            total_validated_samples += batch_size


    # calculate the loss and accuracy
    average_loss = total_loss/total_validated_samples
    accuracy = total_accurate/total_validated_samples

    # record the loss and accuracy
    record = {"avg_loss": average_loss, 'accuracy': accuracy}
    print(record)

    return record




In [8]:
# @title T5-large
model_name = T5_LARGE
pretrained_model_dir = models_info[model_name]['pretrained_model_path']

# get the lora config for the model
hyperparameters_fp = models_info[model_name]['hyperparameters_path']
with open(hyperparameters_fp, 'r') as f:
    hyperparameters = json.load(f)


model = load_lora_model(pretrained_model_dir, hyperparameters, model_name)
tokenizer = AutoTokenizer.from_pretrained(models_info[model_name]['pretrained_tokenizer_path'])
dataloaders = get_dataloaders(models_info[model_name]['tokenized_dataset_path'], batch_size=16, remove_enumeration=True)
dataloader = dataloaders['test']
record = evaluate_model(model, tokenizer, dataloader)
print(record)

  lora_params = torch.load(lora_params_fp, map_location=device)


{'avg_loss': 0.014473656210580515, 'accuracy': 0.012080896127231715}
{'avg_loss': 0.014473656210580515, 'accuracy': 0.012080896127231715}


In [9]:
# @title T5-small
model_name = T5_SMALL
pretrained_model_dir = models_info[model_name]['pretrained_model_path']

model = load_normal_model(pretrained_model_dir)
tokenizer = AutoTokenizer.from_pretrained(models_info[model_name]['pretrained_tokenizer_path'])
dataloaders = get_dataloaders(models_info[model_name]['tokenized_dataset_path'], batch_size=16,remove_enumeration=True)
dataloader = dataloaders['test']
record = evaluate_model(model, tokenizer, dataloader)
print(record)

tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

{'avg_loss': 0.016818452274830795, 'accuracy': 0.0056581412241465}
{'avg_loss': 0.016818452274830795, 'accuracy': 0.0056581412241465}


In [10]:
# @title Bart-large-cnn
model_name = BART_LARGE_CNN
pretrained_model_dir = models_info[model_name]['pretrained_model_path']

# get the lora config for the model
hyperparameters_fp = models_info[model_name]['hyperparameters_path']
with open(hyperparameters_fp, 'r') as f:
    hyperparameters = json.load(f)


model = load_lora_model(pretrained_model_dir, hyperparameters, model_name)
tokenizer = AutoTokenizer.from_pretrained(models_info[model_name]['pretrained_tokenizer_path'])
dataloaders = get_dataloaders(models_info[model_name]['tokenized_dataset_path'], batch_size=16,remove_enumeration=True)
dataloader = dataloaders['test']
record = evaluate_model(model, tokenizer, dataloader)
print(record)

config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

  lora_params = torch.load(lora_params_fp, map_location=device)


vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

{'avg_loss': 0.9476232502071, 'accuracy': 0.0}
{'avg_loss': 0.9476232502071, 'accuracy': 0.0}


In [11]:
# @title Bart-base
model_name = BART_BASE
pretrained_model_dir = models_info[model_name]['pretrained_model_path']

model = load_normal_model(pretrained_model_dir)
tokenizer = AutoTokenizer.from_pretrained(models_info[model_name]['pretrained_tokenizer_path'])
dataloaders = get_dataloaders(models_info[model_name]['tokenized_dataset_path'], batch_size=16,remove_enumeration=True)
dataloader = dataloaders['test']
record = evaluate_model(model, tokenizer, dataloader)
print(record)

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

{'avg_loss': 0.026704550135225432, 'accuracy': 0.013724815536949956}
{'avg_loss': 0.026704550135225432, 'accuracy': 0.013724815536949956}


In [12]:
# @title MC2GEN
model_name = ROBERTA_MC2GEN
model = load_roberta_MC2GEN(models_info[model_name]['pretrained_model_path'])
tokenizer = AutoTokenizer.from_pretrained(models_info[model_name]['pretrained_tokenizer_path'])
dataloaders = get_dataloaders(models_info[model_name]['tokenized_dataset_path'], batch_size=16)
dataloader = dataloaders['test']
record = evaluate_model(model, tokenizer, dataloader)
print(record)

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

{'avg_loss': 8.951577186584473, 'accuracy': 0.0}
{'avg_loss': 8.951577186584473, 'accuracy': 0.0}


# (Legacy ) Some examples

## Example 1
This example is used to see if I can batch decode the results. Notice: in huggingface, almost all tokenizer interit from PreTrainedTokenizer, which inherit from PreTrainedTokenizerBase.

In [None]:
model_name = BART_BASE
tokenized_datasets = get_tokenized_datasets(model_name)
model, tokenizer = load_model_and_tokenizer(model_name)
batch = tokenized_datasets['test'][2:4]
batch =  {k: v.to(device) for k, v in batch.items()}

output = model(**batch)


In [None]:
predictions = torch.argmax(output.logits, dim=-1)

print(tokenizer.batch_decode(predictions, skip_special_tokens=True))
print(tokenizer.batch_decode(batch['labels'], skip_special_tokens=True))

['traicate', 'ca mar']
['dogmatise', 'broad minded']


## Example 2
See how to get all the special tokens during iteration, so that I can skip them when calculate accuracy

In [None]:
tokenized_datasets = get_tokenized_datasets(model_name)

batch = tokenized_datasets['test'][2:4]
batch =  {k: v.to(device) for k, v in batch.items()}

output = model(**batch)
predictions = torch.argmax(output.logits, dim=-1)

tokenizer.batch_decode(batch['labels'])

['<s>dogmatise</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '<s>broad minded</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>']

In [None]:
special_tokens = tokenizer.all_special_tokens
special_token_ids = tokenizer.all_special_ids

print("Special Tokens:", special_tokens)
print("Special Token IDs:", special_token_ids)

Special Tokens: ['<s>', '</s>', '<unk>', '<pad>', '<mask>']
Special Token IDs: [0, 2, 3, 1, 50264]


## Example 3
Get total number of different enumeration values



In [None]:
# load the original dataset
def load_dataset_from_disk():
    data_dir = 'datasets/cryptonite-official-split/'
    train_fp = data_dir + 'cryptonite-train.jsonl'
    val_fp = data_dir + 'cryptonite-val.jsonl'
    test_fp = data_dir + 'cryptonite-test.jsonl'
    datasets = load_dataset('json', data_files={'train': train_fp, 'validation': val_fp, 'test': test_fp})
    return datasets

datasets = load_dataset_from_disk()
unique_values = set(datasets['train']['enumeration'])
num_unique_values = len(unique_values)
num_unique_values

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

## Example 4

Look at number of parameters before and after LoRA

In [None]:
def count_parameters(model):
    '''
    1. Print the model's layers
    2. Print the number of trainable/non-trainable parameters in the model.
    '''
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print(model)
    print(f"Trainable parameters: {trainable_params}")
    print(f"Non-trainable parameters: {non_trainable_params}")

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# define model
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

count_parameters(model)

config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_l

In [None]:
rank = 32
lora_config = LoraConfig(
    r=rank,
    lora_alpha=32,  # todo: idk what this does yet
    lora_dropout=0.05, # todo: idk what this does yet
    # print(model) to see all the linear layers, and do LoRA on all of them
    target_modules=['q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc1', 'fc2'],
    # unfreeze the head of the model too
    modules_to_save=['lm_head']
)
lora_model = get_peft_model(model, lora_config)
count_parameters(lora_model)

PeftModel(
  (base_model): LoraModel(
    (model): BartForConditionalGeneration(
      (model): BartModel(
        (shared): Embedding(50264, 1024, padding_idx=1)
        (encoder): BartEncoder(
          (embed_tokens): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
          (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
          (layers): ModuleList(
            (0-11): 12 x BartEncoderLayer(
              (self_attn): BartSdpaAttention(
                (k_proj): lora.Linear(
                  (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=1024, out_features=32, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=32, out_features=1024, bias=False)
          