Part 2 (5 points): Fine-tuning with PEFT and LoRA

In [None]:
# Pinned library versions the rest of the notebook was written against.
%pip install --quiet transformers==4.37.2 accelerate==0.24.0 sentencepiece==0.1.99 optimum==1.13.2 peft==0.5.0 bitsandbytes==0.41.2.post2 datasets==2.14.7

In [None]:
import random

import peft
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm, trange
from transformers import AutoModelForCausalLM, AutoTokenizer
const_seed = 100  # single seed shared by the RNGs used in this notebook

# Seed Python's and torch's generators up front so anything sampled before training
# (e.g. randomly initialised trainable parameters) is reproducible across reruns.
random.seed(const_seed)
torch.manual_seed(const_seed)

In [None]:
!pip install flash-attn --no-build-isolation

Collecting flash-attn
  Downloading flash_attn-2.5.4.tar.gz (2.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m38.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting einops (from flash-attn)
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Collecting ninja (from flash-attn)
  Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m307.2/307.2 kB[0m [31m41.2 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: flash-attn
  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone
  Created wheel for flash-attn: filename=flash_attn-2.5.4-cp310-cp310-linux_x86_64.whl size=120038621 sha256=c4ff6919917d4ace36ed5adb65bf056b2572bb395d4a6d7c3c81fdeb6fb2fb87
  St

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
# NOTE(review): not referenced by the cells shown here; the model is placed via device_map='auto'.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Mistral-7B-Instruct checkpoint shared by the tokenizer and the (4-bit) model.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the EOS id as the padding id.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Flash Attention 2 speeds up inference on supported GPUs
# (https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).
model_load_kwargs = dict(
    device_map='auto',
    low_cpu_mem_usage=True,
    offload_state_dict=True,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_load_kwargs)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Freeze every base-model weight; only the PEFT parameters added by get_peft_model train.
model.requires_grad_(False)

# Recompute activations in the backward pass to save GPU memory.
model.gradient_checkpointing_enable()
# Make the input embeddings require grad so backprop still runs through the
# checkpointed layers even though all base weights are frozen.
model.enable_input_require_grads()

In [None]:
# `df` is a datasets DatasetDict (not a pandas DataFrame); its "train" and
# "validation" splits hold passage / question / answer rows.
df = load_dataset(path="google/boolq")

In [None]:
# Prompt tuning: train only num_virtual_tokens=16 soft-prompt embeddings on top of the
# frozen base model (16 x 4096 = 65,536 trainable params, as print_trainable_parameters reports).
peft_config = peft.PromptTuningConfig(
    num_virtual_tokens=16,
    task_type=peft.TaskType.CAUSAL_LM,
)
# get_peft_model wraps the base model; for most PEFT methods it also modifies `model` in place.
model = peft.get_peft_model(model, peft_config)

In [None]:
model.print_trainable_parameters()  # only the prompt-tuning parameters should be trainable

trainable params: 65,536 || all params: 7,241,797,632 || trainable%: 0.000904968673943746


In [None]:
# Simple plain-text prompt template for BoolQ rows.
def format_prompt(sample):
    """Render one BoolQ row as a plain-text prompt.

    The output starts with a newline, puts each field on its own line indented by
    four spaces ("text", "question", "answer"), and ends with a newline followed by
    four spaces.

    Args:
        sample: mapping with "passage", "question" and "answer" keys.

    Returns:
        str: the formatted prompt, with the answer rendered via str formatting.
    """
    indent = " " * 4
    fields = (
        ("text", sample['passage']),
        ("question", sample['question']),
        ("answer", sample['answer']),
    )
    body = "".join(f"\n{indent}{label}: {value}" for label, value in fields)
    return f"{body}\n{indent}"

In [None]:
# Render every row of each split with format_prompt (iterating a Dataset yields row dicts).
train = [format_prompt(row) for row in df["train"]]
valid = [format_prompt(row) for row in df["validation"]]

In [None]:
len(train)  # number of formatted training prompts

9427

In [None]:
def _strip_answer(prompt):
    """Blank out the value on the final ``answer:`` line of a format_prompt string.

    The previous global ``str.replace`` of "True"/"False"/"true"/... also deleted
    those substrings inside passages and questions (e.g. "untrue" -> "un",
    "construe" -> "cons"). Here only the text after the last ``answer:`` marker is
    dropped. The output has the same "answer: " + newline + four-space tail as
    before, matching the format_prompt layout.

    Args:
        prompt: a prompt string produced by format_prompt.

    Returns:
        str: the prompt with the answer value removed; unchanged if there is no
        ``answer:`` marker.
    """
    head, marker, _ = prompt.rpartition("answer:")
    if not marker:
        return prompt
    return head + marker + " \n    "


train_without_true_false = [_strip_answer(prompt) for prompt in train]
valid_without_true_false = [_strip_answer(prompt) for prompt in valid]

In [None]:
# Fully formatted prompts (answer included) wrapped as datasets for reference.
# train_dataset / valid_dataset are constructed in a later cell; the two extra
# from_dict calls that used to live here were overwritten before any use.
tlabel_dataset = Dataset.from_dict({"prompt": train})
vlabel_dataset = Dataset.from_dict({"prompt": valid})

In [None]:
# Targets are the BoolQ answers themselves ("True"/"False"). The previous version
# copied the whole formatted prompt (passage + question + answer) into `completion`,
# so the target duplicated the passage instead of being the label.
train_labels = [str(answer) for answer in df["train"]["answer"]]
valid_labels = [str(answer) for answer in df["validation"]["answer"]]

# Prompts and labels must line up one-to-one.
assert len(train_labels) == len(train_without_true_false)
assert len(valid_labels) == len(valid_without_true_false)

# Attach the labels to the answer-less prompts as a "completion" column.
train_dataset = Dataset.from_dict({"prompt": train_without_true_false, "completion": train_labels})
valid_dataset = Dataset.from_dict({"prompt": valid_without_true_false, "completion": valid_labels})


In [None]:
print(train_dataset[0])  # one prompt/completion pair, for a visual check

{'prompt': '\n    text: Persian (/ˈpɜːrʒən, -ʃən/), also known by its endonym Farsi (فارسی fārsi (fɒːɾˈsiː) ( listen)), is one of the Western Iranian languages within the Indo-Iranian branch of the Indo-European language family. It is primarily spoken in Iran, Afghanistan (officially known as Dari since 1958), and Tajikistan (officially known as Tajiki since the Soviet era), and some other regions which historically were Persianate societies and considered part of Greater Iran. It is written in the Persian alphabet, a modified variant of the Arabic script, which itself evolved from the Aramaic alphabet.\n    question: do iran and afghanistan speak the same language\n    answer: \n    ', 'completion': '\n    text: Persian (/ˈpɜːrʒən, -ʃən/), also known by its endonym Farsi (فارسی fārsi (fɒːɾˈsiː) ( listen)), is one of the Western Iranian languages within the Indo-Iranian branch of the Indo-European language family. It is primarily spoken in Iran, Afghanistan (officially known as Dari si

In [None]:
from transformers import AutoTokenizer
from datasets import Dataset


def preprocess_function(examples):
    """Batched ``Dataset.map`` callback that passes the ``completion`` column through.

    The previous version also ran ``tokenizer(prompt, padding="max_length",
    truncation=True, max_length=128)`` and discarded the result. That cost a full
    tokenisation pass over both splits without changing the dataset. The columns
    stay as raw text; presumably tokenisation happens in the trainer, which receives
    these raw-text datasets later.

    Args:
        examples: batch dict passed by ``Dataset.map(..., batched=True)``; must
            contain a ``"completion"`` list.

    Returns:
        dict: ``{"completion": examples["completion"]}``. Columns that are not
        returned (e.g. ``"prompt"``) are kept by ``Dataset.map``.
    """
    return {"completion": examples["completion"]}



# Apply the same batched preprocessing to the train split, then the validation split.
tokenized_train_dataset, tokenized_valid_dataset = (
    split.map(preprocess_function, batched=True)
    for split in (train_dataset, valid_dataset)
)

Map:   0%|          | 0/9427 [00:00<?, ? examples/s]

Map:   0%|          | 0/3270 [00:00<?, ? examples/s]

In [None]:
print(tokenized_train_dataset[0])  # same columns as train_dataset[0]: preprocess_function only returns "completion"

{'prompt': '\n    text: Persian (/ˈpɜːrʒən, -ʃən/), also known by its endonym Farsi (فارسی fārsi (fɒːɾˈsiː) ( listen)), is one of the Western Iranian languages within the Indo-Iranian branch of the Indo-European language family. It is primarily spoken in Iran, Afghanistan (officially known as Dari since 1958), and Tajikistan (officially known as Tajiki since the Soviet era), and some other regions which historically were Persianate societies and considered part of Greater Iran. It is written in the Persian alphabet, a modified variant of the Arabic script, which itself evolved from the Aramaic alphabet.\n    question: do iran and afghanistan speak the same language\n    answer: \n    ', 'completion': '\n    text: Persian (/ˈpɜːrʒən, -ʃən/), also known by its endonym Farsi (فارسی fārsi (fɒːɾˈsiː) ( listen)), is one of the Western Iranian languages within the Indo-Iranian branch of the Indo-European language family. It is primarily spoken in Iran, Afghanistan (officially known as Dari si

In [None]:
# Mount Google Drive so the checkpoints and logs configured below land under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
from datetime import datetime

from transformers import TrainingArguments

# Each run gets its own sub-directory. Reusing a fixed output_dir made the Trainer
# save into checkpoint folders left by an earlier run ("already exists and is
# non-empty ... saved results may be invalid").
RUN_NAME = datetime.now().strftime("run-%Y%m%d-%H%M%S")
DRIVE_ROOT = "/content/drive/My Drive"

training_args = TrainingArguments(
    output_dir=f"{DRIVE_ROOT}/model_output/{RUN_NAME}",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,  # 4 x 2 = 8 examples per optimizer step per device
    num_train_epochs=2,
    # NOTE(review): loss moved only 2.057 -> 2.006 over two epochs at this rate;
    # prompt tuning usually tolerates much larger learning rates — worth a sweep.
    learning_rate=2e-5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    logging_dir=f"{DRIVE_ROOT}/logs/{RUN_NAME}",
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    seed=const_seed,  # use the notebook-wide seed for data shuffling too
)

In [None]:
!pip install trl


Collecting trl
  Downloading trl-0.7.11-py3-none-any.whl (155 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m155.3/155.3 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
Collecting tyro>=0.5.11 (from trl)
  Downloading tyro-0.7.3-py3-none-any.whl (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.8/79.8 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
Collecting docstring-parser>=0.14.1 (from tyro>=0.5.11->trl)
  Downloading docstring_parser-0.15-py3-none-any.whl (36 kB)
Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl)
  Downloading shtab-1.7.0-py3-none-any.whl (14 kB)
Installing collected packages: shtab, docstring-parser, tyro, trl
Successfully installed docstring-parser-0.15 shtab-1.7.0 trl-0.7.11 tyro-0.7.3


In [None]:
from trl import SFTTrainer

# Maximum tokens per (packed) training sequence.
# NOTE(review): 1024 is believed to equal trl 0.7.x's fallback when max_seq_length is
# omitted (min(tokenizer.model_max_length, 1024)), so this makes the previous implicit
# value explicit. Tune it against BoolQ passage lengths.
MAX_SEQ_LENGTH = 1024

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_valid_dataset,
    tokenizer=tokenizer,  # reuse the tokenizer configured above (pad id = EOS id)
    max_seq_length=MAX_SEQ_LENGTH,
    packing=True,
)


Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]



In [None]:
# Run training; evaluation on eval_dataset happens at the end of every epoch (evaluation_strategy="epoch").
trainer.train()

Epoch,Training Loss,Validation Loss


Checkpoint destination directory /content/drive/My Drive/model_output/checkpoint-100 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory /content/drive/My Drive/model_output/checkpoint-200 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory /content/drive/My Drive/model_output/checkpoint-300 already exists and is non-empty.Saving will proceed but saved results may be invalid.


Epoch,Training Loss,Validation Loss
0,2.0569,2.022271
1,2.0061,2.008353




TrainOutput(global_step=784, training_loss=2.0248886079204325, metrics={'train_runtime': 18046.3916, 'train_samples_per_second': 0.348, 'train_steps_per_second': 0.043, 'total_flos': 2.7401048139025613e+17, 'train_loss': 2.0248886079204325, 'epoch': 2.0})

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Persist only the PEFT adapter (trained prompt-tuning weights + adapter config).
# torch.save(model.state_dict()) also serialised every frozen base weight: the
# wrapped model has ~7.2B params, of which only 65,536 are trainable.
# Reload later with peft.PeftModel.from_pretrained(base_model, model_path).
# A new path is used because save_pretrained writes a directory, and the old
# "Mistral_model" path may already hold the single file written by torch.save.
model_path = "/content/drive/My Drive/Mistral_model_adapter"
model.save_pretrained(model_path)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


TODO: initialize the Trainer and pass it the training split of our dataset for 2–3 epochs.

Note: carefully set `max_seq_length` and `args` (a `transformers.TrainingArguments` instance).

TODO: save and check your tuned model. Report scores on our 20 validation examples and save the results to a CSV file.

READ OUTPUT