In [1]:
!pip install accelerate --upgrade
!pip install peft
!pip install wandb
!pip install trl
!pip install bitsandbytes --upgrade
!pip install torcheval

[0m

In [2]:
!pip install flash-attn --no-build-isolation
!pip install scikit-learn

[0m

In [3]:
from accelerate import Accelerator
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig
import wandb
import transformers
import torch
import glob
import pandas as pd
from tqdm import tqdm
import numpy as np
import os
import random
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from torch import nn
import sys
import gc
from transformers import DataCollatorWithPadding
from transformers import AdamW
from accelerate import notebook_launcher
from sklearn.model_selection import train_test_split
from accelerate import DistributedDataParallelKwargs
import time
import re
from transformers import get_cosine_schedule_with_warmup
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import TrainingArguments
import accelerate
import json
from peft import IA3Config, IA3Model, LoraConfig
import jinja2
import math
import bitsandbytes as bnb
from datasets import Dataset
# From this Gist: https://gist.github.com/ihoromi4/b681a9088f348942b01711f251e5f964
def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

In [4]:
!wandb login 4a376fd0ab1c0901b9d9886d0734a88b4794a7fd

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [7]:
class config:
    # General Configuration
    device_type = "gpus"
    model = "OpenOrca/Mistral-7B-OpenOrca"

    # Prompt Parameters
    ab_hypothesis = "There exists an interaction between the disease {a_term} and the gene {b_term}."
    bc_hypothesis = "There exists an interaction between the drug {c_term} and the gene {b_term}."
    ac_hypothesis = "The drug {c_term} has an interaction with the disease {a_term}."

    sys_prompt = "You are an incredibly capable and intelligent language model specialized in biomedical research. You have spent your whole life reading all the papers on PubMed. Your purpose is to assist researchers by evaluating the relevance and utility of individual biomedical abstracts in relation to specific research hypotheses. You focus on meticulously reviewing each abstract presented to you, determining its significance and potential contribution to the hypothesis in question. Your evaluations are grounded in a deep understanding of biomedical literature, ensuring accuracy and reliability in identifying the value of each abstract without comparing it to other studies."
    cot_instr = "Determine whether or not this abstract is relevant for scientifically evaluating the provided hypothesis. An abstract is considered relevant if it even comments on the hypothesis a little. Analyze the abstract above, and throughly describe your thought process for evaluating the hypothesis. Pay attention to particular details in the abstract as it relates to the hypothesis. Make sure to stay focused on what the hypothesis is specifically saying. Take a deep breath and work on this problem step-by-step."
    ans_context = "Make sure to use information from the provided abstract and hypothesis to support your answer. Remember, an abstract is considered relevant if it even partially comments on the hypothesis."

    # IA3 Parameters
    ia3_config = {
        "task_type": "CAUSAL_LM",
        "target_modules": ["k_proj", "v_proj", "down_proj"],
        "feedforward_modules": ["down_proj"],
    }

In [6]:
from trl import setup_chat_format
from peft import prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(
    config.model,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
    device_map = "auto",
    use_cache=False,)

tokenizer = AutoTokenizer.from_pretrained(config.model, padding_side = "left", add_eos_token=True)
model.config.pad_token_id = tokenizer.pad_token_id

model, tokenizer = setup_chat_format(model, tokenizer)

peft_config = LoraConfig(
    r=64, lora_alpha=16, bias="none", task_type="CAUSAL_LM", lora_dropout=0.05
)

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)

OSError: OpenOrca/Mistral-7B-OpenOrca is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`

In [None]:
# def cot_prompt(sys_prompt: str, hyp: str, abstract: str, cot_instr: str, response: str) -> str:
#     context = {
# 		"sys_prompt": sys_prompt,
# 		"hyp": hyp,
# 		"abstract": abstract,
# 		"cot_instr": cot_instr,
# 		"response": response,
# 	}

#     template = jinja2.Template("""<|im_start|>system
# 	{{sys_prompt}}<|im_end|>
# 	<|im_start|>user
# 	Hypothesis: {{hyp}}
# 	Abstract: {{abstract}}
# 	{{cot_instr}}
# 	Reasoning: <|im_end|>
# 	<|im_start|>assistant
# 	{{response}}
# 	""")
#     return template.render(context)
def train_cot_prompt(sys_prompt: str, hyp: str, abstract: str, cot_instr: str, ans_context: str, cot: str, label: str, tokenizer: AutoTokenizer) -> str:
    output = [{"role": "system", "content": sys_prompt},
	 {"role": "user", "content": f"Hypothesis: {hyp}\nAbstract: {abstract}\n{cot_instr}\nReasoning: "},
	  {"role": "assistant", "content": f"{cot}"}]
    return tokenizer.apply_chat_template(output, tokenize = False)

def eval_cot_prompt(sys_prompt: str, hyp: str, abstract: str, cot_instr: str, ans_context: str, tokenizer: AutoTokenizer) -> str:
	output = [{"role": "system", "content": sys_prompt},
	 {"role": "user", "content": f"Hypothesis: {hyp}\nAbstract: {abstract}\n{cot_instr}\nReasoning: "}]
	return f"{tokenizer.apply_chat_template(output, tokenize = False)}<|im_start|>assistant"

def train_ans_prompt(sys_prompt: str, hyp: str, abstract: str, cot_instr: str, ans_context: str, cot: str, label: str, tokenizer: AutoTokenizer) -> str:
    output = [{"role": "system", "content": sys_prompt},
	 {"role": "user", "content": f"Hypothesis: {hyp}\nAbstract: {abstract}.\nClassify the given abstract as either 0 (Not relevant for scientifically assessing the hypothesis) or 1 (Relevant for scientifically assessing the hypothesis). {ans_context}\nAnswer: "},
	  {"role": "assistant", "content": f"{label}"}]
    return tokenizer.apply_chat_template(output, tokenize = False)

def eval_ans_prompt(sys_prompt: str, hyp: str, abstract: str, cot_instr: str, ans_context: str, tokenizer: AutoTokenizer) -> str:
	output = [{"role": "system", "content": sys_prompt},
	{"role": "user", "content": f"Hypothesis: {hyp}\nAbstract: {abstract}.\nClassify the given abstract as either 0 (Not relevant for scientifically assessing the hypothesis) or 1 (Relevant for scientifically assessing the hypothesis). {ans_context}\nAnswer: "}]
	return f"{tokenizer.apply_chat_template(output, tokenize = False)}<|im_start|>assistant"

def train_cot_and_ans_prompt(sys_prompt: str, hyp: str, abstract: str, cot_instr: str, ans_context: str, cot: str, label: str, tokenizer: AutoTokenizer) -> str:
    output = [{"role": "system", "content": sys_prompt},
	 {"role": "user", "content": f"Hypothesis: {hyp}\nAbstract: {abstract}.\n{cot_instr}\nClassify the given abstract as either 0 (Not relevant for scientifically assessing the hypothesis) or 1 (Relevant for scientifically assessing the hypothesis). {ans_context}\nAnswer: "},
	  {"role": "assistant", "content": f"{label} {cot}"}]
    return tokenizer.apply_chat_template(output, tokenize = False)

def eval_cot_and_ans_prompt(sys_prompt: str, hyp: str, abstract: str, cot_instr: str, ans_context: str, tokenizer: AutoTokenizer) -> str:
	output = [{"role": "system", "content": sys_prompt},
	{"role": "user", "content": f"Hypothesis: {hyp}\nAbstract: {abstract}.\n{cot_instr}\nClassify the given abstract as either 0 (Not relevant for scientifically assessing the hypothesis) or 1 (Relevant for scientifically assessing the hypothesis). {ans_context}\nAnswer: "}]
	return f"{tokenizer.apply_chat_template(output, tokenize = False)}<|im_start|>assistant"

In [None]:
train = pd.read_csv("train.tsv", sep="\t")
test = pd.read_csv("test.tsv", sep="\t")

In [None]:
def processRowTrain(row, prompt_fn):
    if pd.isnull(row["a_term"]):
        hypothesis = config.bc_hypothesis.format(c_term=row["c_term"], b_term=row["b_term"])
    elif pd.isnull(row["b_term"]):
        hypothesis = config.ac_hypothesis.format(c_term=row["c_term"], a_term=row["a_term"])
    elif pd.isnull(row["c_term"]):
        hypothesis = config.ab_hypothesis.format(a_term=row["a_term"], b_term=row["b_term"])
    return prompt_fn(config.sys_prompt, hypothesis, row["abstract"], config.cot_instr, config.ans_context, row.get("cot", ""), int(row["label"]), tokenizer = tokenizer)

In [None]:
def processRowEval(row, prompt_fn):
    if pd.isnull(row["a_term"]):
        hypothesis = config.bc_hypothesis.format(c_term=row["c_term"], b_term=row["b_term"])
    elif pd.isnull(row["b_term"]):
        hypothesis = config.ac_hypothesis.format(c_term=row["c_term"], a_term=row["a_term"])
    elif pd.isnull(row["c_term"]):
        hypothesis = config.ab_hypothesis.format(a_term=row["a_term"], b_term=row["b_term"])
    return prompt_fn(config.sys_prompt, hypothesis, row["abstract"], config.cot_instr, config.ans_context, tokenizer = tokenizer)

In [None]:
train["text"] = train.apply(lambda row: processRowTrain(row, train_cot_and_ans_prompt), axis=1)
train["prompt"] = train.apply(lambda row: processRowEval(row, eval_cot_and_ans_prompt), axis=1)
train = Dataset.from_pandas(train)

In [None]:
test["text"] = test.apply(lambda row: processRowTrain(row, train_cot_and_ans_prompt), axis=1)
test["prompt"] = test.apply(lambda row: processRowEval(row, eval_cot_and_ans_prompt), axis=1)
test = Dataset.from_pandas(test)

In [None]:
print(train["text"][0])

In [None]:
print(train["prompt"][0])

In [None]:
print(test["text"][0])

In [None]:
print(test["prompt"][0])

# Training

In [None]:
wandb.init(project="kmGPT", entity = "morgridge", group = "Fine Tuning", name = "LoRA Training (Labels + CoT Tacked) & Phi-3 + Regularization & Neftune + Validation tester", reinit=True)

In [None]:
zero = tokenizer.encode("0", add_special_tokens=False)[-1]
one = tokenizer.encode("1", add_special_tokens=False)[-1]

In [None]:
# Lowkey really ugly code but it works
def label_filter(logits, labels):
    binary_labels = []
    for row in labels:
        if zero in row:
            binary_labels.append(0)
        elif one in row:
            binary_labels.append(1)
        else:
            raise Exception("No label was found????")
    first_predictions = logits.argmax(-1)[:, 0]

    return 

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    # Calculate accuracy
    accuracy = accuracy_score(labels, preds)

   # Calculate precision, recall, and F1-score
    precision = precision_score(labels, preds, average='weighted')
    recall = recall_score(labels, preds, average='weighted')
    f1 = f1_score(labels, preds, average='weighted')

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

In [None]:
training_args = TrainingArguments(
    output_dir = "mistral_training",
    report_to = "wandb",
    learning_rate = 2e-4,
    warmup_ratio = 0.04,
    lr_scheduler_type = "cosine",
    num_train_epochs = 15,
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 1,
    bf16 = True,
    # evaluation_strategy="epoch",
    evaluation_strategy="no",
    save_strategy = "epoch",
    logging_steps = 1,
    do_eval=True,
    neftune_noise_alpha = 5,
    # load_best_model_at_end = True,
    save_total_limit = 2,
    weight_decay = 0.1,
    # metric_for_best_model = "accuracy",
)

In [None]:
response_template_with_context = "\nAnswer: <|im_end|>\n"
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:4]

In [None]:
tokenizer.tokenize(response_template_with_context, add_special_tokens=False)

In [None]:
trainer = SFTTrainer(
    args = training_args,
    model=model,
    data_collator=DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer = tokenizer),
    train_dataset=train,
    # eval_dataset=test,
    dataset_text_field="text",
    max_seq_length=2048,
    tokenizer=tokenizer,
    # preprocess_logits_for_metrics = label_filter,
    # compute_metrics = compute_metrics
)

In [None]:
# from wandb import WandbCallback
# class LLMSampleCB(WandbCallback):
#     def __init__(self, trainer, test_dataset, max_new_tokens=1024, log_model="checkpoint"):
#         super().__init__()
#         self._log_model = log_model
#         self.sample_dataset = test_dataset.select(range(num_samples))
#         self.model, self.tokenizer = trainer.model, trainer.tokenizer
#         self.gen_config = GenerationConfig.from_pretrained(trainer.model.name_or_path,
#                                                             max_new_tokens=500)
#     def generate(self, prompt):
#         tokenized_prompt = self.tokenizer(prompt, return_tensors='pt')['input_ids'].cuda()
#         with torch.inference_mode():
#             output = self.model.generate(tokenized_prompt, generation_config=self.gen_config)
#         return self.tokenizer.decode(output[0][len(tokenized_prompt[0]):], skip_special_tokens=True)

#     def samples_table(self, examples):
#         records_table = wandb.Table(columns=["prompt", "generation"] + list(self.gen_config.to_dict().keys()))
#         for example in tqdm(examples, leave=False):
#             prompt = example["text"]
#             generation = self.generate(prompt=prompt)
#             records_table.add_data(prompt, generation, *list(self.gen_config.to_dict().values()))
#         return records_table

#     def on_evaluate(self, args, state, control,  **kwargs):
#         super().on_evaluate(args, state, control, **kwargs)
#         records_table = self.samples_table(self.sample_dataset)
#         self._wandb.log({"sample_predictions":records_table})

In [None]:
trainer.train()

In [None]:
with torch.inference_mode():
    trainer.model.eval()
    prompt = train["prompt"][0]
    prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    out = trainer.model.generate(prompt_ids.cuda(), max_new_tokens = 200)
    response = tokenizer.decode(out[0])

In [None]:
response

In [None]:
from transformers import LogitsProcessor, LogitsProcessorList
class AnswerConstraint(LogitsProcessor):
    def __init__(self, allowed_tokens):
        self.allowed_tokens = torch.tensor(allowed_tokens)

    def __call__(self, input_ids, scores):
        # Set the logits of all tokens not in the allowed_tokens set to -inf
        newScores = torch.full(scores.shape, -float("inf")).cuda()
        for token in self.allowed_tokens:
            newScores[0][token] = scores[0][token]
        return newScores

In [None]:
processor = AnswerConstraint([zero, one])

In [None]:
trainer.model.load_adapter("", "adapter")

In [None]:
with torch.inference_mode():
    y_hat = []
    for i in tqdm(range(len(test["prompt"]))):
        prompt = test["prompt"][i]
        prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
        out = trainer.model.generate(prompt_ids.cuda(), max_new_tokens = 1, logits_processor = LogitsProcessorList([processor]))
        response = tokenizer.decode(out[0])
        score = int(response[-1])
        y_hat.append(score)

In [None]:
y = [int(i) for i in test["label"]]
y = torch.tensor(y)

In [None]:
y_hat = torch.tensor(y_hat)

In [None]:
y_hat

In [None]:
accuracy = (y_hat == y).sum() / len(y)
accuracy

In [None]:
from torcheval.metrics.functional.classification import binary_recall, binary_precision, binary_accuracy, binary_f1_score
wandb.log({"Accuracy": binary_accuracy(y_hat, y)})
wandb.log({"Precision": binary_precision(y_hat, y)})
wandb.log({"Recall": binary_recall(y_hat, y)})
wandb.log({"F1-Score": binary_f1_score(y_hat, y)})

In [None]:
binary_precision(y_hat, y)

In [None]:
binary_recall(y_hat, y)

In [None]:
binary_f1_score(y_hat, y)

In [None]:
wandb.log({f"Confusion Matrix": wandb.plot.confusion_matrix(y_true=y.tolist(), preds=y_hat.tolist(), class_names=["Irrelevant", "Relevant"], title = "Relevance Confusion Matrix")})

In [None]:
tokenizer.decode(out[0])

In [None]:
test["text"][i]