<a href="https://colab.research.google.com/github/stewart-lab/kmGPT/blob/fine-tuning/Unsloth_Lora_Fine_Tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the fine-tuning stack used by this notebook: training utilities, PEFT
# adapters, experiment tracking, TRL, quantization, and metrics.
!pip install accelerate
!pip install peft
!pip install wandb
!pip install trl
!pip install bitsandbytes
!pip install scikit-learn
# NOTE(review): four different Unsloth extras (cu118/cu121, with and without
# "ampere") are installed one after another. Presumably only the variant that
# matches the runtime's CUDA version and GPU is needed — confirm and keep one.
!pip install "unsloth[cu118-torch230] @ git+https://github.com/unslothai/unsloth.git"
!pip install "unsloth[cu121-torch230] @ git+https://github.com/unslothai/unsloth.git"
!pip install "unsloth[cu118-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"
!pip install "unsloth[cu121-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"
!pip install xformers



In [None]:
# Installed in its own cell; the model is later loaded with
# attn_implementation = 'flash_attention_2', which presumably needs this package.
!pip install flash-attn

In [None]:
from accelerate import Accelerator
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig
import wandb
import transformers
import torch
import glob
import pandas as pd
from tqdm import tqdm
import numpy as np
import os
import random
from unsloth import FastLanguageModel
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from torch import nn
import sys
import gc
from transformers import AdamW
from accelerate import notebook_launcher
from sklearn.model_selection import train_test_split
from accelerate import DistributedDataParallelKwargs
import time
import re
from transformers import get_cosine_schedule_with_warmup
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import TrainingArguments, BitsAndBytesConfig
import accelerate
import json
from peft import IA3Config, IA3Model, LoraConfig
import jinja2
import math
import bitsandbytes as bnb
from datasets import Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import math
from trl import setup_chat_format
from peft import prepare_model_for_kbit_training

# From this Gist: https://gist.github.com/ihoromi4/b681a9088f348942b01711f251e5f964
def seed_everything(seed: int) -> None:
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU and every CUDA
    device), exports ``PYTHONHASHSEED``, and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Seed value shared by all generators.
    """
    random.seed(seed)
    # Hash randomization is fixed at interpreter start-up, so this only affects
    # Python processes launched after this call (e.g. worker subprocesses).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs, not just the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning may pick different convolution algorithms from run to run,
    # which defeats the determinism requested above, so it must be disabled.
    torch.backends.cudnn.benchmark = False

In [None]:
!huggingface-cli login --token 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

In [None]:
!wandb login 4a376fd0ab1c0901b9d9886d0734a88b4794a7fd

In [None]:
class config:
    """Static settings for the notebook: base model, sequence limit, and prompt templates."""

    # ---- General ----
    device_type = "gpus"
    model = "unsloth/Phi-3-mini-4k-instruct"

    # ---- Training ----
    max_seq_length = 2048
    trust = True  # forwarded as trust_remote_code when the model is loaded

    # ---- Porpoise One (relevance filtering) ----
    # Hypothesis templates; the {a_term}/{b_term}/{c_term} placeholders are
    # filled in with str.format by the row-processing helpers.
    ab_hypothesis = (
        "There exists an interaction between the disease {a_term} "
        "and the gene {b_term}."
    )
    bc_hypothesis = (
        "There exists an interaction between the drug {c_term} "
        "and the gene {b_term}."
    )
    ac_hypothesis = "The drug {c_term} has an interaction with the disease {a_term}."

    rel_instr = (
        "Classify this abstract as either 0 (Not Relevant) or 1 (Relevant) "
        "for evaluating the provided hypothesis."
    )

    # ---- Porpoise Two (support scoring) ----
    sup_instr = (
        "Explain why (or why not) this biomedical abstract supports the provided statement. "
        "Give a score of 1 for supports and a score of 0 for does not support."
    )

In [None]:
# Load the base model named in config.model together with its tokenizer,
# quantized to 4-bit and limited to config.max_seq_length tokens.
# NOTE(review): attn_implementation = 'flash_attention_2' presumably requires the
# flash-attn package installed in an earlier cell and a compatible GPU — confirm.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = config.model,
    max_seq_length = config.max_seq_length,
    load_in_4bit = True,
    trust_remote_code = config.trust,
    attn_implementation = 'flash_attention_2',
    device_map = "auto",
)

# Wrap the loaded model with rank-16 adapters on the listed projection modules,
# with no adapter dropout and no bias training.
# NOTE(review): lora_alpha is commented out, so the library default applies —
# confirm that default is the intended scaling for r = 16.
# NOTE(review): use_dora = True requests the DoRA variant; verify the installed
# Unsloth/PEFT versions honour this argument.
model = FastLanguageModel.get_peft_model(
    model,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    r = 16,
    # lora_alpha = 32,
    lora_dropout = 0,
    bias = "none",
    use_dora = True,
    loftq_config = None
)

### Relevance data prep

In [None]:
def train_ans_prompt(hyp, abstract, instr, label, cot) -> str:
    """Build a complete training example: prompt, gold score, and explanation.

    Args:
        hyp: Hypothesis sentence the abstract is judged against.
        abstract: Abstract text.
        instr: Instruction line shown to the model.
        label: Gold score inserted after ``Score: ``.
        cot: Explanation text inserted after ``Explanation: ``.

    Returns:
        The formatted example as a single newline-separated string.
    """
    # Honour the caller-supplied instruction instead of silently ignoring the
    # parameter and reading a module-level setting.
    return f"Abstract: {abstract}\nHypothesis: {hyp}\nInstructions: {instr}\nScore: {label}\nExplanation: {cot}"

def test_ans_prompt(hyp, abstract, instr, label) -> str:
	return f"Abstract: {abstract}\nHypothesis: {hyp}\nInstructions: {config.rel_instr}\nScore: {label}"

def eval_ans_prompt(hyp, abstract, instr) -> str:
    """Build the inference prompt, ending right after ``Score: `` so the model supplies the score.

    Args:
        hyp: Hypothesis sentence the abstract is judged against.
        abstract: Abstract text.
        instr: Instruction line shown to the model.

    Returns:
        The formatted prompt as a single newline-separated string.
    """
    # Honour the caller-supplied instruction instead of silently ignoring the
    # parameter and reading a module-level setting.
    return f"Abstract: {abstract}\nHypothesis: {hyp}\nInstructions: {instr}\nScore: "

In [None]:
# Load the relevance (Porpoise 1) train/test splits from tab-separated files.
# The row-processing helpers read the a_term, b_term, c_term, abstract and label
# columns from these frames (plus cot for the training split).
train_rel = pd.read_csv("./data/Porpoise_1/same_dist_train.tsv", sep="\t")
test_rel = pd.read_csv("./data/Porpoise_1/same_dist_test.tsv", sep="\t")

In [None]:
def processRowTrainText(row, prompt_fn):
    """Turn one relevance training row into a full training example.

    The hypothesis template is chosen by which term is missing: a missing
    ``a_term`` selects the drug-gene template, a missing ``b_term`` the
    drug-disease template, and a missing ``c_term`` the disease-gene template.

    Args:
        row: Mapping with ``a_term``, ``b_term``, ``c_term``, ``abstract``,
            ``label`` and ``cot`` entries.
        prompt_fn: Callable taking (hypothesis, abstract, instruction, label, cot).

    Returns:
        Whatever ``prompt_fn`` returns for this row.

    Raises:
        ValueError: If none of the three terms is missing, so no template applies.
    """
    if pd.isnull(row["a_term"]):
        hypothesis = config.bc_hypothesis.format(c_term=row["c_term"], b_term=row["b_term"])
    elif pd.isnull(row["b_term"]):
        hypothesis = config.ac_hypothesis.format(c_term=row["c_term"], a_term=row["a_term"])
    elif pd.isnull(row["c_term"]):
        hypothesis = config.ab_hypothesis.format(a_term=row["a_term"], b_term=row["b_term"])
    else:
        # Fail with a clear message instead of an UnboundLocalError on `hypothesis`.
        raise ValueError(
            "Cannot build a hypothesis: expected one of a_term, b_term, c_term to be null, "
            "but all three are present."
        )
    return prompt_fn(hypothesis, row["abstract"], config.rel_instr, int(row["label"]), row["cot"])

In [None]:
def processRowTestText(row, prompt_fn):
    """Turn one relevance test row into a prompt followed by its gold score.

    The hypothesis template is chosen by which term is missing: a missing
    ``a_term`` selects the drug-gene template, a missing ``b_term`` the
    drug-disease template, and a missing ``c_term`` the disease-gene template.

    Args:
        row: Mapping with ``a_term``, ``b_term``, ``c_term``, ``abstract`` and
            ``label`` entries.
        prompt_fn: Callable taking (hypothesis, abstract, instruction, label).

    Returns:
        Whatever ``prompt_fn`` returns for this row.

    Raises:
        ValueError: If none of the three terms is missing, so no template applies.
    """
    if pd.isnull(row["a_term"]):
        hypothesis = config.bc_hypothesis.format(c_term=row["c_term"], b_term=row["b_term"])
    elif pd.isnull(row["b_term"]):
        hypothesis = config.ac_hypothesis.format(c_term=row["c_term"], a_term=row["a_term"])
    elif pd.isnull(row["c_term"]):
        hypothesis = config.ab_hypothesis.format(a_term=row["a_term"], b_term=row["b_term"])
    else:
        # Fail with a clear message instead of an UnboundLocalError on `hypothesis`.
        raise ValueError(
            "Cannot build a hypothesis: expected one of a_term, b_term, c_term to be null, "
            "but all three are present."
        )
    return prompt_fn(hypothesis, row["abstract"], config.rel_instr, int(row["label"]))

In [None]:
def processRowPrompt(row, prompt_fn):
    """Turn one relevance row into an inference prompt (no label, no explanation).

    The hypothesis template is chosen by which term is missing: a missing
    ``a_term`` selects the drug-gene template, a missing ``b_term`` the
    drug-disease template, and a missing ``c_term`` the disease-gene template.

    Args:
        row: Mapping with ``a_term``, ``b_term``, ``c_term`` and ``abstract`` entries.
        prompt_fn: Callable taking (hypothesis, abstract, instruction).

    Returns:
        Whatever ``prompt_fn`` returns for this row.

    Raises:
        ValueError: If none of the three terms is missing, so no template applies.
    """
    if pd.isnull(row["a_term"]):
        hypothesis = config.bc_hypothesis.format(c_term=row["c_term"], b_term=row["b_term"])
    elif pd.isnull(row["b_term"]):
        hypothesis = config.ac_hypothesis.format(c_term=row["c_term"], a_term=row["a_term"])
    elif pd.isnull(row["c_term"]):
        hypothesis = config.ab_hypothesis.format(a_term=row["a_term"], b_term=row["b_term"])
    else:
        # Fail with a clear message instead of an UnboundLocalError on `hypothesis`.
        raise ValueError(
            "Cannot build a hypothesis: expected one of a_term, b_term, c_term to be null, "
            "but all three are present."
        )
    return prompt_fn(hypothesis, row["abstract"], config.rel_instr)

In [None]:
# Row-wise: "text" is the full training example, "prompt" is the same example
# cut off right after "Score: ". The prompt builder is forwarded through `args`.
train_rel["text"] = train_rel.apply(processRowTrainText, axis=1, args=(train_ans_prompt,))
train_rel["prompt"] = train_rel.apply(processRowPrompt, axis=1, args=(eval_ans_prompt,))

In [None]:
# Row-wise: "text" is the prompt plus gold score, "prompt" stops right after
# "Score: ". The prompt builder is forwarded through `args`.
test_rel["text"] = test_rel.apply(processRowTestText, axis=1, args=(test_ans_prompt,))
test_rel["prompt"] = test_rel.apply(processRowPrompt, axis=1, args=(eval_ans_prompt,))

### Support data prep

In [None]:
def getText(row):
    """Format a support row as a complete example: prompt, score, and explanation.

    Reads the ``abstract``, ``statement``, ``label`` and ``cot`` entries of
    ``row`` and the shared support instruction from ``config.sup_instr``.
    """
    sections = [
        f'Abstract: {row["abstract"]}',
        f'Statement: {row["statement"]}',
        f'Instructions: {config.sup_instr}',
        f'Score: {row["label"]}',
        f'Explanation: {row["cot"]}',
    ]
    return "\n".join(sections)

def getPrompt(row):
    """Format a support row as an inference prompt that stops right after ``Score: ``.

    Reads the ``abstract`` and ``statement`` entries of ``row`` and the shared
    support instruction from ``config.sup_instr``.
    """
    sections = [
        f'Abstract: {row["abstract"]}',
        f'Statement: {row["statement"]}',
        f'Instructions: {config.sup_instr}',
        'Score: ',
    ]
    return "\n".join(sections)

In [None]:
# Load the support (Porpoise 2) train/test splits from tab-separated files.
# getText/getPrompt read the abstract, statement, label and cot columns.
train_sup = pd.read_csv("./data/Porpoise_2/train.tsv", sep = "\t")
test_sup = pd.read_csv("./data/Porpoise_2/test.tsv", sep = "\t")

In [None]:
# Row-wise: "text" holds the full example, "prompt" the version that stops at "Score: ".
train_sup["text"] = train_sup.apply(getText, axis = 1)
train_sup["prompt"] = train_sup.apply(getPrompt, axis = 1)

In [None]:
# Row-wise: "text" holds the full example, "prompt" the version that stops at "Score: ".
test_sup["text"] = test_sup.apply(getText, axis = 1)
test_sup["prompt"] = test_sup.apply(getPrompt, axis = 1)

### Merging Data together

In [None]:
# Stack the two tasks into single series: support rows first, relevance rows
# second. ignore_index = True renumbers the combined rows from 0.
train_text = pd.concat([train_sup["text"], train_rel["text"]], ignore_index = True)
train_prompts = pd.concat([train_sup["prompt"], train_rel["prompt"]], ignore_index = True)

test_text = pd.concat([test_sup["text"], test_rel["text"]], ignore_index = True)
test_prompts = pd.concat([test_sup["prompt"], test_rel["prompt"]], ignore_index = True)

In [None]:
train = pd.DataFrame({"text": train_text, "prompt": train_prompts})

# The final evaluation cells read test["label"], so the gold labels must travel
# with the merged test set. They are stacked in the same order as the text and
# prompts (support rows first, then relevance rows) to stay row-aligned.
test_labels = pd.concat([test_sup["label"], test_rel["label"]], ignore_index = True)
test = pd.DataFrame({"text": test_text, "prompt": test_prompts, "label": test_labels})

In [None]:
# Convert the merged frames into Dataset objects for the trainer. `Dataset` is the
# one from `datasets`: it is imported after torch.utils.data.Dataset and shadows it.
# NOTE(review): the names `train`/`test` are rebound from DataFrames to Datasets,
# so this cell is presumably not safe to re-run without re-running the previous one.
train = Dataset.from_pandas(train)
test = Dataset.from_pandas(test)

In [None]:
# Number of merged training examples (support + relevance).
print(len(train))

In [None]:
# Number of merged test examples (support + relevance).
print(len(test))

# Training

In [None]:
# Start the Weights & Biases run that the trainer and the cells below log to.
# reinit=True is passed so the cell can start a new run when executed again.
wandb.init(project="kmGPT", entity = "morgridge", group = "Porpoise 2.0", name = "Dora?", reinit=True)

In [None]:
from transformers.integrations import WandbCallback
class LLMSampleCB(WandbCallback):
    """W&B callback that scores generated labels on both test sets after every evaluation.

    For each prompt a single new token is generated and the last character of
    the decoded sequence is read as the predicted score. Accuracy and weighted
    precision/recall/F1 are logged to Weights & Biases and printed, first for
    the relevance set and then for the support set.

    Args:
        trainer: Trainer whose ``model`` and ``tokenizer`` are used for generation.
        test_sup: Support test set; must provide "prompt" and "label" columns.
        test_rel: Relevance test set; must provide "prompt" and "label" columns.
    """

    def __init__(self, trainer, test_sup, test_rel):
        super().__init__()
        self.test_rel = test_rel
        self.test_sup = test_sup

        self.y_sup = torch.tensor(self.test_sup["label"])
        self.y_rel = torch.tensor(self.test_rel["label"])

        self.model, self.tokenizer = trainer.model, trainer.tokenizer

    def get_metrics(self, test_set, labels):
        """Generate one token per prompt and score the predictions against ``labels``.

        Args:
            test_set: Table with a "prompt" column.
            labels: Gold 0/1 labels aligned with ``test_set["prompt"]``.

        Returns:
            Tuple ``(accuracy, precision, recall, f1)``; the last three use
            weighted averaging.
        """
        prompts = list(test_set["prompt"])
        y_hat = []

        # Use the model/tokenizer captured in __init__ rather than notebook globals,
        # so the callback does not depend on names defined in other cells.
        FastLanguageModel.for_inference(self.model)
        try:
            for i in tqdm(range(len(prompts))):
                prompt_ids = self.tokenizer(prompts[i], return_tensors="pt")["input_ids"]
                out = self.model.generate(prompt_ids.cuda(), max_new_tokens = 1)[-1]
                response = self.tokenizer.decode(out)
                try:
                    score = int(response[-1])
                except (ValueError, IndexError):
                    # Output that is not a digit is scored as the opposite of the
                    # gold label so that it always counts as an error.
                    score = 1 - int(labels[i])
                y_hat.append(score)
        finally:
            # Evaluation runs in the middle of training: switch the model back to
            # training mode so the following optimisation steps are unaffected.
            FastLanguageModel.for_training(self.model)

        y_hat = torch.tensor(y_hat)

        acc = accuracy_score(labels, y_hat)
        prec = precision_score(labels, y_hat, average='weighted')
        recall = recall_score(labels, y_hat, average='weighted')
        f1 = f1_score(labels, y_hat, average='weighted')

        return acc, prec, recall, f1

    def _log_task(self, name, banner, test_set, labels, epoch):
        """Compute, log and print the metrics of one task."""
        acc, prec, recall, f1 = self.get_metrics(test_set, labels)
        # One log call per task keeps the four metrics on the same W&B step.
        self._wandb.log({
            f"{name} Running Validation Accuracy": acc,
            f"{name} Running Validation Precision": prec,
            f"{name} Running Validation Recall": recall,
            f"{name} Running Validation F1": f1,
        })
        print(f"*********** {banner} RESULTS ***********")
        print(f"Epoch {epoch}:\n\tAccuracy: {acc:.3f}\n\tPrecision: {prec:.3f}\n\tRecall: {recall:.3f}\n\tF-1 Score: {f1:.3f}")

    def on_evaluate(self, args, state, control,  **kwargs):
        """Run both generation-based evaluations after the trainer's own evaluation."""
        super().on_evaluate(args, state, control, **kwargs)
        # Read the epoch from the state handed to the callback, not a global trainer.
        epoch = math.ceil(state.epoch) if state.epoch is not None else 0
        self._log_task("Relevance", "RELEVANCE FILTERING", self.test_rel, self.y_rel, epoch)
        self._log_task("Support", "SUPPORT", self.test_sup, self.y_sup, epoch)


In [None]:
# Hyperparameters for the fine-tuning run.
# Checkpoints go to ./checkpoints and metrics are reported to Weights & Biases.
# With a per-device batch of 4 and 4 accumulation steps, each optimizer step
# covers 16 sequences per device. Evaluation and checkpointing both happen once
# per epoch, and the loss is logged at every step.
training_args = TrainingArguments(
    output_dir = "checkpoints",
    report_to = "wandb",
    learning_rate = 2e-4,
    warmup_ratio = 0.03,
    lr_scheduler_type = "cosine",
    num_train_epochs = 15,
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 4,
    bf16 = True,
    optim = "paged_adamw_8bit",
    # NOTE(review): newer transformers releases presumably prefer `eval_strategy`
    # over `evaluation_strategy` — confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy = "epoch",
    logging_steps = 1,
    do_eval=True,
    neftune_noise_alpha = 5,
    weight_decay = 0.1,
)

In [None]:
# Supervised fine-tuning on the "text" column of the merged datasets, with
# examples packed into fixed-length sequences.
trainer = SFTTrainer(
    args = training_args,
    model=model,
    # peft_config=peft_config,
    # data_collator=DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer = tokenizer),
    packing = True,
    train_dataset=train,
    eval_dataset=test,
    dataset_text_field="text",
    tokenizer=tokenizer,
    # Reuse the limit the model was loaded with instead of repeating the literal.
    max_seq_length=config.max_seq_length,
)

In [None]:
# Attach the generation-based evaluation callback; it receives the un-merged
# support and relevance test frames so each task is scored separately.
wandb_callback = LLMSampleCB(trainer, test_sup, test_rel)
trainer.add_callback(wandb_callback)

In [None]:
# Run fine-tuning with the arguments and callback configured above.
trainer.train()

In [None]:
# Smoke test: generate up to 100 new tokens for the first support prompt.
# Switch to Unsloth's inference mode explicitly so this cell does not depend on
# whatever mode an earlier cell or callback happened to leave the model in.
FastLanguageModel.for_inference(model)
with torch.inference_mode():
    with torch.cuda.amp.autocast():
        prompt = test_sup["prompt"][0]
        prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
        out = model.generate(prompt_ids.cuda(), max_new_tokens = 100)
        response = tokenizer.decode(out[0])

In [None]:
# Final evaluation over the merged test set.
# For every prompt a single token is generated and read as the predicted score.
# When the prediction is wrong, a longer continuation is generated and stored
# as the model's rationale. Requires `test` to carry a "label" column that is
# row-aligned with "prompt".
FastLanguageModel.for_inference(trainer.model)

# Read the columns once instead of re-materialising them on every iteration.
prompts = list(test["prompt"])
labels = list(test["label"])

y_hat = []
cots = []
num_wrong = 0
with torch.inference_mode():
    with torch.cuda.amp.autocast():
        for i in tqdm(range(len(prompts))):
            prompt_ids = tokenizer(prompts[i], return_tensors="pt")["input_ids"].cuda()
            out = trainer.model.generate(prompt_ids, max_new_tokens = 1)
            response = tokenizer.decode(out[0])
            try:
                score = int(response[-1])
            except ValueError:
                # Output that is not a digit is scored as the opposite of the gold
                # label, so it counts as an error instead of aborting the loop.
                score = 1 - int(labels[i])
            cot = "Correct! So no explanation was given."

            if score != labels[i]:
                generated = trainer.model.generate(prompt_ids, max_new_tokens = 400)
                # Decode only the newly generated tokens. Splitting the decoded
                # text on "Score: " breaks whenever that marker also occurs in
                # the abstract or in the generated explanation.
                ans = tokenizer.decode(generated[0][prompt_ids.shape[-1]:])
                # Drop the leading score character; the rest is the rationale.
                cot = ans[1:]
                num_wrong += 1
                print("wrong")

            y_hat.append(score)
            cots.append(cot)
print(num_wrong)

In [None]:
# Convert gold labels and collected predictions to tensors for the metric cells.
# Requires `test` to carry a "label" column.
# NOTE(review): `y_hat` is rebound from a list to a tensor, so this cell changes
# the type the previous cell produced — keep that in mind when re-running cells.
y = torch.tensor(test["label"])
y_hat = torch.tensor(y_hat)

In [None]:
# Log one row per test example: prompt, prediction, gold label, and rationale.
# Predictions and labels are converted to plain Python ints so the table holds
# simple values rather than 0-d tensors, matching the confusion-matrix cell.
data = list(zip(test["prompt"], y_hat.tolist(), y.tolist(), cots))
test_table = wandb.Table(columns = ["prompt", "y_hat", "y", "rationale"], data = data)
wandb.log({"Predictions": test_table})

In [None]:
# Log the final metrics on the merged test set. Precision, recall and F1 use
# weighted averaging here; each metric is sent in its own log call.
wandb.log({"Validation Accuracy": accuracy_score(y, y_hat)})
wandb.log({"Validation Precision": precision_score(y, y_hat, average='weighted')})
wandb.log({"Validation Recall": recall_score(y, y_hat, average='weighted')})
wandb.log({"Validation F1-Score": f1_score(y, y_hat, average='weighted')})

In [None]:
# Display accuracy on the merged test set (gold labels first, predictions second).
accuracy_score(y, y_hat)

In [None]:
# scikit-learn metrics take (y_true, y_pred). With the arguments swapped this
# cell reported recall under the name of precision.
precision_score(y, y_hat)

In [None]:
# scikit-learn metrics take (y_true, y_pred). With the arguments swapped this
# cell reported precision under the name of recall.
recall_score(y, y_hat)

In [None]:
# Pass gold labels first, predictions second, as the scikit-learn signature
# (y_true, y_pred) expects and as the accuracy cell already does.
f1_score(y, y_hat)

In [None]:
# Build the confusion-matrix plot from plain Python lists, then log it.
confusion_plot = wandb.plot.confusion_matrix(
    y_true=y.tolist(),
    preds=y_hat.tolist(),
    class_names=["Irrelevant", "Relevant"],
    title = "Relevance Confusion Matrix",
)
wandb.log({"Confusion Matrix": confusion_plot})

In [None]:
# Close the Weights & Biases run; nothing below this cell is logged.
wandb.finish()

In [None]:
# Save the model and tokenizer locally under "Porpoise1" using the
# "merged_16bit" save method.
model.save_pretrained_merged("Porpoise1", tokenizer, save_method = "merged_16bit")

In [None]:
# Push the model and tokenizer to the Hub using the "merged_16bit" save method.
# NOTE(review): the target repo id "hf/porpoise1" differs from the one loaded in
# the next cell ("lexu14/porpoise1") — confirm which repository is intended.
model.push_to_hub_merged("hf/porpoise1", tokenizer, save_method = "merged_16bit")

In [None]:
# Reload the published checkpoint. Earlier in this notebook from_pretrained is
# unpacked as (model, tokenizer), so `m` holds both until the next cell.
m = FastLanguageModel.from_pretrained("lexu14/porpoise1")

In [None]:
# Keep only the first element (the model); the reloaded tokenizer is discarded.
# NOTE(review): `m` is rebound in place, so running this cell a second time
# would index into the model itself — re-run the previous cell first.
m = m[0]

In [None]:
# Smoke test of the reloaded model: generate up to 100 new tokens for the first
# merged test prompt. `tokenizer` is the one loaded with the base model earlier
# in the notebook, not the one returned alongside `m`.
with torch.inference_mode():
    with torch.cuda.amp.autocast():
        prompt = test["prompt"][0]
        prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
        out = m.generate(prompt_ids.cuda(), max_new_tokens = 100)
        response = tokenizer.decode(out[0])

In [None]:
# Display the decoded generation from the previous cell.
response