# Fine-Tuning an LLM for Particle Accelerators

This code is inspired by the following article on instruction fine-tuning Llama 2 with PEFT's QLoRA method:
https://medium.com/@ud.chandra/instruction-fine-tuning-llama-2-with-pefts-qlora-method-d6a801ebb19

In [None]:
# Dependency installation using IPython's `!` shell syntax. The commands are kept
# inside a string literal, so running this cell installs nothing; remove the
# quotes to actually install/upgrade the packages.
# NOTE: several packages imported in the next cell (e.g. trl, unidecode, simcse,
# nltk, bs4) are not listed here and must be installed separately.
'''
!pip install --upgrade accelerate
!pip install --upgrade datasets
!pip install --upgrade bitsandbytes
!pip install --upgrade transformers
!pip install --upgrade peft
!pip install --upgrade deepspeed
!pip install --upgrade optimum
'''

In [None]:
import gzip
import pickle
import re
import sys
from glob import glob
from unidecode import unidecode

import datasets
import matplotlib.pyplot as plt
import numpy as np
import torch
import transformers
from bs4 import BeautifulSoup
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from trl import SFTTrainer
from simcse import SimCSE
from tqdm.notebook import tqdm

sys.path.insert(0, "../code")
import os

from core import *
from nltk.tokenize import sent_tokenize

def print_trainable_parameters(model):
    """
    Print how many of the model's parameters are trainable.

    Sums ``param.numel()`` over ``model.named_parameters()`` and prints, on one
    line, the number of parameters with ``requires_grad=True``, the total number
    of parameters and the trainable percentage.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module`` or a PEFT-wrapped model).
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # A model without parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )

    
def prompt_formatter(question, answer=""):
    """
    Render a question/answer pair in the "### Human / ### Assistant" format.

    Each header is followed by a newline and its text. With the default empty
    ``answer`` the string ends right after the ``### Assistant:`` header line.

    An alternative Vicuna-style template would be
    ``'USER: {question}\\nASSISTANT: {answer}</s>'``.
    """
    template = '### Human:\n{question}\n### Assistant:\n{answer}'
    return template.format(question=question, answer=answer)


In [None]:
# NOTE(review): `.*` is greedy, so on a line containing two image tags this
# matches from the first "<image:" to the LAST ">" (including the text between
# them). It is not used in this notebook — confirm against its users before
# relying on it (a non-greedy "<image:[^>]*>" may be what is intended).
IMG_TAG = re.compile("<image:.*>")
# A hyphen directly followed by a line break — presumably used to re-join words
# hyphenated across lines.
NEW_LINE = re.compile(r"-\n")

# Set before the Trainer is created so its Weights & Biases integration stays off.
os.environ["WANDB_DISABLED"] = "true"


def smoothen(l, N=12):
    """
    Return the moving average of ``l`` over a window of ``N`` samples.

    Uses ``np.convolve(..., mode="valid")``, so the result has
    ``len(l) - N + 1`` elements. If ``l`` is shorter than the window, the
    window is shrunk to ``len(l)`` (yielding the single overall mean). Without
    this, ``np.convolve`` would swap its operands and return meaningless values.
    An empty input yields an empty array.

    Raises:
        ValueError: if ``N`` is smaller than 1.
    """
    if N < 1:
        raise ValueError("N must be a positive integer")
    values = np.asarray(l)
    if values.size == 0:
        return np.array([], dtype=float)
    window = min(N, values.size)
    return np.convolve(values, np.ones(window) / window, mode="valid")

# Base checkpoint to fine-tune. Alternative checkpoints (commented-out options):
#   "EleutherAI/gpt-neox-20b"
#   "lmsys/vicuna-7b-v1.5"
#   "openlm-research/open_llama_3b", "openlm-research/open_llama_7b"
#   "NousResearch/Nous-Hermes-llama-2-7b", "NousResearch/Nous-Hermes-13b"
#   "NousResearch/Llama-2-13b-chat-hf", "NousResearch/Llama-2-7b-chat-hf"
#   "tiiuae/falcon-7b", "tiiuae/falcon-7b-instruct"
# NOTE: GPT-NeoX / Falcon style models need different LoRA target_modules
# (e.g. "query_key_value") than the Llama-style names used further below.
model_id = "lmsys/vicuna-7b-v1.5-16k"

# 4-bit NF4 quantization with double quantization and bf16 compute (QLoRA-style).
# It is only used if `quantization_config=bnb_config` is enabled in the
# from_pretrained call below; as written, the model is loaded unquantized.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Slow (pure-Python) tokenizer with legacy behaviour enabled.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, legacy=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # unquantized half-precision weights
    device_map="auto",  # spread the layers over the available devices
    # NOTE(review): trust_remote_code=True allows executing Python code shipped
    # with the checkpoint repository. Llama-family checkpoints such as Vicuna
    # presumably load without it — consider False unless a model requires it.
    trust_remote_code=True,
    # Quantized alternatives, currently disabled:
    # load_in_8bit=True,
    # quantization_config=bnb_config,
)


# Recompute activations during the backward pass to save memory.
model.gradient_checkpointing_enable()
# NOTE(review): prepare_model_for_kbit_training targets quantized (k-bit) models,
# but quantization_config is disabled in the loading cell, so this is a plain
# fp16 model. Depending on the PEFT version, this call may upcast fp16 weights
# to fp32 (roughly doubling memory) — confirm this is intended.
model = prepare_model_for_kbit_training(model)
# The tokenizer presumably has no dedicated padding token, so EOS is reused.
tokenizer.pad_token = tokenizer.eos_token

# LoRA adapters on the attention projections. Llama-family checkpoints (e.g.
# Vicuna) name the attention output projection "o_proj"; a name such as
# "out_proj" matches no module and would silently leave that projection without
# an adapter. For GPT-NeoX / Falcon models use target_modules=["query_key_value"].
config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
print_trainable_parameters(model)


# See https://huggingface.co/docs/transformers/v4.32.0/en/perf_train_gpu_one#optimizer-choice
# BetterTransformer (faster attention kernels) is not enabled because it caused
# problems with saving checkpoints during the training loop.
# model = model.to_bettertransformer()


## Dataset loading

In [None]:
def prepare_pretrain(filename, crit=lambda text: True):
    """
    Load raw text chunks for the unsupervised (pre-training style) fine-tuning.

    The file is a gzip-compressed pickle holding an iterable of records; the
    text of each record is read from ``record['metadata']['text']`` (preprocessed
    MMD from nougat). Every text accepted by ``crit`` is normalised — all ``#``
    characters are removed, then leading/trailing whitespace is stripped — and
    kept only the first time that normalised text is seen.

    Args:
        filename: path to the ``.pickle.gzip`` file.
        crit: predicate applied to the raw (un-normalised) text; texts for which
            it returns a falsy value are skipped. Defaults to keeping everything.

    Returns:
        list[str]: the normalised, de-duplicated texts in input order.
    """
    train_sentences_ = []
    seen = set()
    # NOTE: pickle.load can execute arbitrary code — only load trusted data files.
    with gzip.open(filename, 'rb') as f:
        records = pickle.load(f)
    for qa in tqdm(records):
        text = qa['metadata']['text']
        if not crit(text):
            continue
        # Normalise BEFORE the duplicate check: the set holds normalised texts, so
        # checking the raw text would let near-duplicates (differing only in
        # surrounding whitespace or '#') through. '#' is removed first so that
        # whitespace it leaves at the ends (e.g. "# Intro") is stripped as well.
        text = text.replace('#', '').strip()
        if text in seen:
            continue
        seen.add(text)
        train_sentences_.append(text)
    return train_sentences_

# Label prefixes vicuna puts in front of generated answers / questions.
_ANSWER_PREFIX = re.compile(r'^\s*(?:Answers?|A):')
_QUESTION_PREFIX = re.compile(r'^\s*(?:Questions?|Q):')
# Enumeration such as "1. " in front of the text. The look-ahead keeps leading
# decimal numbers intact ("3.5 GeV" must not become "5 GeV").
_LEADING_NUMBER = re.compile(r'^\d+\.(?!\d)\s*')


def _clean_qa_text(text, prefix_pattern):
    """Strip surrounding whitespace, one label prefix and a leading "N. " enumeration."""
    # Strip again after each removal so the next anchored pattern can match.
    text = prefix_pattern.sub('', text.strip()).strip()
    return _LEADING_NUMBER.sub('', text).strip()


def prepare_finetune(filename):
    """
    Load generated question/answer pairs for the supervised fine-tuning.

    The file is a gzip-compressed pickle holding an iterable of records, each
    with a ``'pairs'`` list of dicts with ``'question'`` and ``'answer'`` keys
    (QA generated by vicuna from preprocessed nougat MMD). Each question and
    answer is cleaned:

    * surrounding whitespace is stripped,
    * one generated label prefix is removed ("Question:"/"Questions:"/"Q:" from
      the question, "Answer:"/"Answers:"/"A:" from the answer),
    * a leading enumeration such as "1. " is removed.

    Each cleaned pair is then formatted with ``prompt_formatter``.

    Args:
        filename: path to the ``.pickle.gzip`` file.

    Returns:
        list[str]: one formatted prompt per question/answer pair.
    """
    train_sentences_ = []
    # NOTE: pickle.load can execute arbitrary code — only load trusted data files.
    with gzip.open(filename, 'rb') as f:
        records = pickle.load(f)
    for record in tqdm(records):
        for qa in record['pairs']:
            q = _clean_qa_text(qa['question'], _QUESTION_PREFIX)
            a = _clean_qa_text(qa['answer'], _ANSWER_PREFIX)
            train_sentences_.append(prompt_formatter(q, a))
    return train_sentences_

# Training corpus: unsupervised text chunks (*_pretrain files) plus generated
# Human/Assistant QA prompts (*_qa_vicuna files), concatenated in this order.
# The pretrain loaders take a filter on the raw text, e.g. `lambda x: 'DESY' in x`
# keeps only chunks mentioning DESY; `lambda x: True` keeps everything.
data_folder = '../data/'
sentences = []
sentences.extend(prepare_pretrain(os.path.join(data_folder, 'arxiv_pretrain.pickle.gzip'), lambda x: True))
sentences.extend(prepare_pretrain(os.path.join(data_folder, 'books_accelerators_pretrain.pickle.gzip'), lambda x: True))
sentences.extend(prepare_finetune(os.path.join(data_folder, 'books_accelerators_qa_vicuna.pickle.gzip')))
sentences.extend(prepare_pretrain(os.path.join(data_folder, 'proc_pretrain.pickle.gzip'), lambda x: True))
sentences.extend(prepare_finetune(os.path.join(data_folder, 'proc_qa_vicuna.pickle.gzip')))

# Keep samples whose token count lies in (MIN_LEN, MAX_LEN].
# NOTE(review): for tokenizers without a configured limit, model_max_length can
# be a huge sentinel value — confirm it is meaningful for the chosen checkpoint.
MAX_LEN = tokenizer.model_max_length
MIN_LEN = 16
train_sentences = []
for sent in tqdm(sentences):
    # Transliterate first and measure the text that is actually trained on:
    # unidecode can change the number of tokens.
    ascii_sent = unidecode(sent)
    LEN = len(tokenizer.encode(ascii_sent))
    if MIN_LEN < LEN <= MAX_LEN:
        train_sentences.append(ascii_sent)

In [None]:
# Report the raw sample count and the count kept after the length filter.
print(*map(len, (sentences, train_sentences)))
# Free the unfiltered corpus; the following cells work with train_sentences.
del sentences

In [None]:
# Wrap the filtered texts in a Hugging Face dataset (shuffled, fixed seed) with
# a single "train" split.
dataset = datasets.Dataset.from_dict({"text": train_sentences}).shuffle(seed=42)
data = datasets.DatasetDict(train=dataset)
# Add the tokenizer outputs (e.g. input_ids / attention_mask) as extra columns.
# NOTE(review): SFTTrainer is later given dataset_text_field='text'; depending on
# the trl version it may tokenize "text" again — confirm this step is needed.
data = data.map(lambda batch: tokenizer(batch["text"]), batched=True)

In [None]:
# Supervised fine-tuning on the mixed corpus (raw text + Human/Assistant prompts).
trainer = SFTTrainer(
    model=model,
    # Use the tokenizer configured above (slow/legacy, pad_token = eos_token).
    # Without it, SFTTrainer (in the trl versions that take these keyword
    # arguments) loads its own tokenizer from the checkpoint with default settings.
    tokenizer=tokenizer,
    train_dataset=data["train"],
    dataset_text_field='text',
    # Match the length filter used to build train_sentences. Otherwise SFTTrainer
    # falls back to its default (min(model_max_length, 1024) in trl 0.7.x) and
    # silently truncates longer samples.
    max_seq_length=MAX_LEN,
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    args=transformers.TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=16,
        warmup_ratio=0.05,
        num_train_epochs=4,
        # Retry with a smaller batch size if a batch does not fit in memory.
        auto_find_batch_size=True,
        learning_rate=5e-5,
        fp16=True,
        logging_steps=10,
        output_dir="outputs/human_assistant_prompt_all_papers",
        save_steps=100,
        # Alternatives: lr_scheduler_type='cosine', max_steps=100,
        # per_device_train_batch_size=1 with gradient_accumulation_steps=4,
        # optim='paged_adamw_8bit' or optim='paged_adamw_32bit'.
    ),
)
# Disable the KV cache during training (silences the gradient-checkpointing
# warnings). Re-enable it for inference!
model.config.use_cache = False
trainer.train(resume_from_checkpoint=None)

In [None]:
# Training loss as logged by the Trainer (one entry every `logging_steps` steps),
# plotted against the global step of each log entry rather than its list index.
loss_logs = [t for t in trainer.state.log_history if "loss" in t]
l = np.array([t["loss"] for t in loss_logs])
steps = np.array([t["step"] for t in loss_logs])
plt.plot(steps, l)
plt.xlabel("step")
plt.ylabel("training loss")
plt.show()