<a href="https://colab.research.google.com/github/stewart-lab/kmGPT/blob/fine-tuning/Unsloth_Lora_Fine_Tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install accelerate
!pip install peft
!pip install wandb
!pip install trl
!pip install bitsandbytes
!pip install scikit-learn
!pip install "unsloth[cu118-torch230] @ git+https://github.com/unslothai/unsloth.git"
!pip install "unsloth[cu121-torch230] @ git+https://github.com/unslothai/unsloth.git"
!pip install "unsloth[cu118-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"
!pip install "unsloth[cu121-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"
!pip install xformers

Collecting accelerate
  Downloading accelerate-0.30.1-py3-none-any.whl.metadata (18 kB)
Collecting huggingface-hub (from accelerate)
  Downloading huggingface_hub-0.23.0-py3-none-any.whl.metadata (12 kB)
Collecting safetensors>=0.3.1 (from accelerate)
  Downloading safetensors-0.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting fsspec>=2023.5.0 (from huggingface-hub->accelerate)
  Downloading fsspec-2024.5.0-py3-none-any.whl.metadata (11 kB)
Collecting tqdm>=4.42.1 (from huggingface-hub->accelerate)
  Downloading tqdm-4.66.4-py3-none-any.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.6/57.6 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
Downloading accelerate-0.30.1-py3-none-any.whl (302 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.6/302.6 kB[0m [31m25.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading safetensors-0.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (

In [2]:
!pip install flash-attn

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [3]:
from accelerate import Accelerator
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig
import wandb
import transformers
import torch
import glob
import pandas as pd
from tqdm import tqdm
import numpy as np
import os
import random
from unsloth import FastLanguageModel
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from torch import nn
import sys
import gc
from transformers import AdamW
from accelerate import notebook_launcher
from sklearn.model_selection import train_test_split
from accelerate import DistributedDataParallelKwargs
import time
import re
from transformers import get_cosine_schedule_with_warmup
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import TrainingArguments, BitsAndBytesConfig
import accelerate
import json
from peft import IA3Config, IA3Model, LoraConfig
import jinja2
import math
import bitsandbytes as bnb
from datasets import Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import math
from trl import setup_chat_format
from peft import prepare_model_for_kbit_training

# From this Gist: https://gist.github.com/ihoromi4/b681a9088f348942b01711f251e5f964
def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


In [4]:
!huggingface-cli login --token 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

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [5]:
!wandb login 4a376fd0ab1c0901b9d9886d0734a88b4794a7fd

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [125]:
class config:
    # General Configuration
    device_type = "gpus"
    model = "unsloth/Phi-3-mini-4k-instruct"

    # Training Configuration
    max_seq_length = 2048
    trust = True

    # Prompt Parameters
    ab_hypothesis = "There exists an interaction between the disease {a_term} and the gene {b_term}."
    bc_hypothesis = "There exists an interaction between the drug {c_term} and the gene {b_term}."
    ac_hypothesis = "The drug {c_term} has an interaction with the disease {a_term}."

    instr = "Classify this abstract as either 0 (Not Relevant) or 1 (Relevant) for evaluating the provided hypothesis."

In [126]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = config.model,
    max_seq_length = config.max_seq_length,
    load_in_4bit = True,
    trust_remote_code = config.trust,
    attn_implementation = 'flash_attention_2',
    device_map = "auto",
)

model = FastLanguageModel.get_peft_model(
    model,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    r = 4,
    lora_alpha = 8,
    lora_dropout = 0,
    bias = "none",
    use_rslora = True,
    loftq_config = None
)

==((====))==  Unsloth: Fast Mistral patching release 2024.5
   \\   /|    GPU: NVIDIA A100 80GB PCIe. Max memory: 79.15 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.0+cu121. CUDA = 8.0. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.26.post1. FA = True.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [127]:
def train_ans_prompt(hyp, abstract, instr, label, cot) -> str:
	return f"Abstract: {abstract}\nHypothesis: {hyp}\nInstructions: {instr}\nScore: {label}\nExplanation: {cot}"

def test_ans_prompt(hyp, abstract, instr, label) -> str:
	return f"Abstract: {abstract}\nHypothesis: {hyp}\nInstructions: {instr}\nScore: {label}"

def eval_ans_prompt(hyp, abstract, instr) -> str:
	return f"Abstract: {abstract}\nHypothesis: {hyp}\nInstructions: {instr}\nScore: "

In [128]:
train = pd.read_csv("./data/bigger_filtered_synthetic_train.tsv", sep="\t")
test = pd.read_csv("./data/test.tsv", sep="\t")

In [129]:
def processRowTrainText(row, prompt_fn):
    if pd.isnull(row["a_term"]):
        hypothesis = config.bc_hypothesis.format(c_term=row["c_term"], b_term=row["b_term"])
    elif pd.isnull(row["b_term"]):
        hypothesis = config.ac_hypothesis.format(c_term=row["c_term"], a_term=row["a_term"])
    elif pd.isnull(row["c_term"]):
        hypothesis = config.ab_hypothesis.format(a_term=row["a_term"], b_term=row["b_term"])
    return prompt_fn(hypothesis, row["abstract"], config.instr, int(row["label"]), row["cot"])

In [130]:
def processRowTestText(row, prompt_fn):
    if pd.isnull(row["a_term"]):
        hypothesis = config.bc_hypothesis.format(c_term=row["c_term"], b_term=row["b_term"])
    elif pd.isnull(row["b_term"]):
        hypothesis = config.ac_hypothesis.format(c_term=row["c_term"], a_term=row["a_term"])
    elif pd.isnull(row["c_term"]):
        hypothesis = config.ab_hypothesis.format(a_term=row["a_term"], b_term=row["b_term"])
    return prompt_fn(hypothesis, row["abstract"], config.instr, int(row["label"]))

In [131]:
def processRowPrompt(row, prompt_fn):
    if pd.isnull(row["a_term"]):
        hypothesis = config.bc_hypothesis.format(c_term=row["c_term"], b_term=row["b_term"])
    elif pd.isnull(row["b_term"]):
        hypothesis = config.ac_hypothesis.format(c_term=row["c_term"], a_term=row["a_term"])
    elif pd.isnull(row["c_term"]):
        hypothesis = config.ab_hypothesis.format(a_term=row["a_term"], b_term=row["b_term"])
    return prompt_fn(hypothesis, row["abstract"], config.instr)

In [132]:
train["text"] = train.apply(lambda row: processRowTrainText(row, train_ans_prompt), axis=1)
train["prompt"] = train.apply(lambda row: processRowPrompt(row, eval_ans_prompt), axis=1)
train = Dataset.from_pandas(train)

In [133]:
test["text"] = test.apply(lambda row: processRowTestText(row, test_ans_prompt), axis=1)
test["prompt"] = test.apply(lambda row: processRowPrompt(row, eval_ans_prompt), axis=1)
test = Dataset.from_pandas(test)

In [134]:
print(train["text"][170])

Abstract: Atrial fibrillation is a common, but potentially preventable, complication following coronary artery bypass graft (CABG) surgery. To assess the nature and consequences of atrial fibrillation after CABG surgery and to develop a comprehensive risk index that can better identify patients at risk for atrial fibrillation. Prospective observational study of 4657 patients undergoing CABG surgery between November 1996 and June 2000 at 70 centers located within 17 countries, selected using a systematic sampling technique. From a derivation cohort of 3093 patients, associations between predictor variables and postoperative atrial fibrillation were identified to develop a risk model, which was assessed in a validation cohort of 1564 patients. New-onset atrial fibrillation after CABG surgery. A total of 1503 patients (32.3%) developed atrial fibrillation after CABG surgery. Postoperative atrial fibrillation was associated with subsequent greater resource use as well as with cognitive cha

In [135]:
print(train["prompt"][170])

Abstract: Atrial fibrillation is a common, but potentially preventable, complication following coronary artery bypass graft (CABG) surgery. To assess the nature and consequences of atrial fibrillation after CABG surgery and to develop a comprehensive risk index that can better identify patients at risk for atrial fibrillation. Prospective observational study of 4657 patients undergoing CABG surgery between November 1996 and June 2000 at 70 centers located within 17 countries, selected using a systematic sampling technique. From a derivation cohort of 3093 patients, associations between predictor variables and postoperative atrial fibrillation were identified to develop a risk model, which was assessed in a validation cohort of 1564 patients. New-onset atrial fibrillation after CABG surgery. A total of 1503 patients (32.3%) developed atrial fibrillation after CABG surgery. Postoperative atrial fibrillation was associated with subsequent greater resource use as well as with cognitive cha

In [136]:
print(test["text"][69])

Abstract: PMID 36092956: The abnormal expression of SEC61G plays an important role in the development of various tumors. This study explored the effects of SEC61G on MAPK signaling pathway and proliferation of cervical cancer (CC) cells. shRNA was used to inhibit the expression of SEC61G and EdU to observe its effect on the proliferation of CC cell SiHa. The effect of SEC61G on invasion was evaluated by Transwell assay. TCGA database was used to analyze the influence of high or low SEC61G expression level on the overall survival of CC patients. Western blot was used to detect the expressions of SEC61G, p-RAF1, Raf1, p-MEK1/2, MEK1/2, and p-ERK1/2 in cells. SiHa cells overexpressing SEC61G (SiHa-SEC61G) and control group (SiHa-mock) were subcutaneously implanted in nude mice. The tumor growth curve was measured at the specified time points between SiHa-SEC61G and SiHa-mock. The inhibitory effect of gefitinib on SEC61G was further evaluated. In patients with CC, high SEC61G expression pr

In [137]:
print(test["prompt"][69])

Abstract: PMID 36092956: The abnormal expression of SEC61G plays an important role in the development of various tumors. This study explored the effects of SEC61G on MAPK signaling pathway and proliferation of cervical cancer (CC) cells. shRNA was used to inhibit the expression of SEC61G and EdU to observe its effect on the proliferation of CC cell SiHa. The effect of SEC61G on invasion was evaluated by Transwell assay. TCGA database was used to analyze the influence of high or low SEC61G expression level on the overall survival of CC patients. Western blot was used to detect the expressions of SEC61G, p-RAF1, Raf1, p-MEK1/2, MEK1/2, and p-ERK1/2 in cells. SiHa cells overexpressing SEC61G (SiHa-SEC61G) and control group (SiHa-mock) were subcutaneously implanted in nude mice. The tumor growth curve was measured at the specified time points between SiHa-SEC61G and SiHa-mock. The inhibitory effect of gefitinib on SEC61G was further evaluated. In patients with CC, high SEC61G expression pr

# Training

In [138]:
wandb.init(project="kmGPT", entity = "morgridge", group = "Fine Tuning", name = "Unslothed RSLora 4 & (More Filtered Labels + CoT) & Phi-3", reinit=True)

VBox(children=(Label(value='14.342 MB of 14.342 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Running Validation Accuracy,▄▁█▅▇
Running Validation F1,▅▁█▅▇
Running Validation Precision,▂▂█▁▅
Running Validation Recall,▄▁█▅▇
eval/loss,██▂▂▁▁▁▁▁▁
eval/runtime,▁▁▆▆██▇▇▆▆
eval/samples_per_second,██▃▃▁▁▂▂▃▃
eval/steps_per_second,██▃▃▁▁▂▂▃▃
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████

0,1
Running Validation Accuracy,0.81119
Running Validation F1,0.81169
Running Validation Precision,0.81227
Running Validation Recall,0.81119
eval/loss,1.39148
eval/runtime,4.4229
eval/samples_per_second,18.088
eval/steps_per_second,2.261
total_flos,2.3718677408907264e+16
train/epoch,4.85981


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113672631068362, max=1.0…

In [139]:
from transformers.integrations import WandbCallback
class LLMSampleCB(WandbCallback):
    def __init__(self, trainer, test_dataset):
        super().__init__()
        self.test = test_dataset
        self.y = torch.tensor(self.test["label"])
        self.model, self.tokenizer = trainer.model, trainer.tokenizer

    def get_metrics(self):
        FastLanguageModel.for_inference(trainer.model)
        y_hat = []
        for i in tqdm(range(len(self.test["prompt"]))):
            prompt = self.test["prompt"][i]
            prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
            out = trainer.model.generate(prompt_ids.cuda(), max_new_tokens = 1)[-1]
            response = tokenizer.decode(out)
            try:
                score = int(response[-1])
            except:
                score = 1 - self.y[i]
            y_hat.append(score)

        y_hat = torch.tensor(y_hat)

        acc = accuracy_score(self.y, y_hat)
        prec = precision_score(self.y, y_hat, average='weighted')
        recall = recall_score(self.y, y_hat, average='weighted')
        f1 = f1_score(self.y, y_hat, average='weighted')

        return acc, prec, recall, f1

    def on_evaluate(self, args, state, control,  **kwargs):
        super().on_evaluate(args, state, control, **kwargs)
        acc, prec, recall, f1 = self.get_metrics()
        self._wandb.log({"Running Validation Accuracy": acc})
        self._wandb.log({"Running Validation Precision": prec})
        self._wandb.log({"Running Validation Recall": recall})
        self._wandb.log({"Running Validation F1": f1})
        epoch = math.ceil(trainer.state.epoch)

        print(f"Epoch {epoch}:\n\tAccuracy: {acc:.3f}\n\tPrecision: {prec:.3f}\n\tRecall: {recall:.3f}\n\tF-1 Score: {f1:.3f}")


In [140]:
training_args = TrainingArguments(
    output_dir = "checkpoints",
    report_to = "wandb",
    learning_rate = 2e-4,
    warmup_ratio = 0.03,
    lr_scheduler_type = "cosine",
    num_train_epochs = 5,
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,
    bf16 = True,
    optim = "paged_adamw_32bit",
    evaluation_strategy="epoch",
    save_strategy = "epoch",
    logging_steps = 1,
    do_eval=True,
    neftune_noise_alpha = 5,
    weight_decay = 0.1,
)



In [141]:
trainer = SFTTrainer(
    args = training_args,
    model=model,
    # peft_config=peft_config,
    # data_collator=DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer = tokenizer),
    packing = True,
    train_dataset=train,
    eval_dataset=test,
    dataset_text_field="text",
    tokenizer=tokenizer,
)



In [142]:
wandb_callback = LLMSampleCB(trainer, test)
trainer.add_callback(wandb_callback)

In [143]:
trainer.train()

Epoch,Training Loss,Validation Loss
0,1.1207,1.393769
1,1.0806,1.386165
2,1.0649,1.401389
4,0.9898,1.421071


100%|██████████| 143/143 [00:10<00:00, 14.13it/s]


Epoch 1:
	Accuracy: 0.783
	Precision: 0.821
	Recall: 0.783
	F-1 Score: 0.749


100%|██████████| 143/143 [00:10<00:00, 14.16it/s]


Epoch 2:
	Accuracy: 0.804
	Precision: 0.820
	Recall: 0.804
	F-1 Score: 0.808


100%|██████████| 143/143 [00:09<00:00, 14.33it/s]


Epoch 3:
	Accuracy: 0.832
	Precision: 0.830
	Recall: 0.832
	F-1 Score: 0.826


100%|██████████| 143/143 [00:10<00:00, 14.27it/s]


Epoch 4:
	Accuracy: 0.818
	Precision: 0.826
	Recall: 0.818
	F-1 Score: 0.821


100%|██████████| 143/143 [00:09<00:00, 14.42it/s]


Epoch 5:
	Accuracy: 0.811
	Precision: 0.815
	Recall: 0.811
	F-1 Score: 0.813


TrainOutput(global_step=130, training_loss=1.0888260034414439, metrics={'train_runtime': 252.5613, 'train_samples_per_second': 4.217, 'train_steps_per_second': 0.515, 'total_flos': 2.374245488472883e+16, 'train_loss': 1.0888260034414439, 'epoch': 4.859813084112149})

In [147]:
model.load_adapter("checkpoints/checkpoint-88", "adapter")

ValueError: Can't find 'adapter_config.json' at 'checkpoints/checkpoint-88'

In [148]:
FastLanguageModel.for_inference(model)
prompt = test["prompt"][0]
prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
out = model.generate(prompt_ids.cuda(), max_new_tokens = 100)
response = tokenizer.decode(out[0])

In [149]:
print(response)

<s> Abstract: PMID 10368299: Several cholinesterase inhibitors are either being utilized for symptomatic treatment of Alzheimer's disease or are in advanced clinical trials. E2020, marketed as Aricept, is a member of a large family of N-benzylpiperidine-based acetylcholinesterase (AChE) inhibitors developed, synthesized and evaluated by the Eisai Company in Japan. These inhibitors were designed on the basis of QSAR studies, prior to elucidation of the three-dimensional structure of Torpedo californica AChE (TcAChE). It significantly enhances performance in animal models of cholinergic hypofunction and has a high affinity for AChE, binding to both electric eel and mouse AChE in the nanomolar range. Our experimental structure of the E2020-TcAChE complex pinpoints specific interactions responsible for the high affinity and selectivity demonstrated previously. It shows that E2020 has a unique orientation along the active-site gorge, extending from the anionic subsite of the active site, at

In [150]:
with torch.inference_mode():
    with torch.cuda.amp.autocast():
        y_hat = []
        cots = []
        for i in tqdm(range(len(test["prompt"]))):
            prompt = test["prompt"][i]
            prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
            out = trainer.model.generate(prompt_ids.cuda(), max_new_tokens = 500)
            response = tokenizer.decode(out)
            prompt, ans = response.split("Score: ")[1]
            score = int(ans[0])
            cot = ans[1:]
            
            y_hat.append(score)
            cots.append(cot)
            # print(score)

100%|██████████| 143/143 [00:09<00:00, 14.44it/s]


In [151]:
y = torch.tensor(test["label"])
y_hat = torch.tensor(y_hat)

In [155]:
data = list(zip(test["prompt"], y_hat, y, cots))
test_table = wandb.Table(columns = ["prompt", "y_hat", "y", "rationale"], data = data)
wandb.log({"Predictions": test_table})

In [156]:
wandb.log({"Validation Accuracy": accuracy_score(y, y_hat)})
wandb.log({"Validation Precision": precision_score(y, y_hat, average='weighted')})
wandb.log({"Validation Recall": recall_score(y, y_hat, average='weighted')})
wandb.log({"Validation F1-Score": f1_score(y, y_hat, average='weighted')})

In [157]:
accuracy_score(y, y_hat)

0.8111888111888111

In [158]:
precision_score(y_hat, y)

0.84375

In [159]:
recall_score(y_hat, y)

0.8709677419354839

In [160]:
f1_score(y_hat, y)

0.8571428571428571

In [161]:
wandb.log({f"Confusion Matrix": wandb.plot.confusion_matrix(y_true=y.tolist(), preds=y_hat.tolist(), class_names=["Irrelevant", "Relevant"], title = "Relevance Confusion Matrix")})

In [162]:
wandb.finish()

VBox(children=(Label(value='29.161 MB of 29.184 MB uploaded (0.004 MB deduped)\r'), FloatProgress(value=0.9992…

0,1
Running Validation Accuracy,▁▄█▆▅
Running Validation F1,▁▆█▇▇
Running Validation Precision,▄▃█▆▁
Running Validation Recall,▁▄█▆▅
Validation Accuracy,▁
Validation F1-Score,▁
Validation Precision,▁
Validation Recall,▁
eval/loss,▃▃▁▁▄▄▇▇██
eval/runtime,▁▁▆▆▇▇██▆▆

0,1
Running Validation Accuracy,0.81119
Running Validation F1,0.81261
Running Validation Precision,0.81478
Running Validation Recall,0.81119
Validation Accuracy,0.81119
Validation F1-Score,0.81261
Validation Precision,0.81478
Validation Recall,0.81119
eval/loss,1.42107
eval/runtime,4.423
