# Notebook Overview
This notebook demonstrates how to fine-tune the Mistral-7B-Instruct-v0.2 model with LoRA adapters on a 4-bit quantized base, using our proprietary paper-review datasets.


# Environment Setup

In [None]:
import os

import torch
import wandb
from datasets import load_dataset
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)

from peft import LoraConfig
from trl import SFTTrainer, setup_chat_format

In [None]:
# Session configuration for Weights & Biases and the Hugging Face Hub.
# The values are placeholders; they are read back through the environment
# by the login calls.
os.environ.update(
    {
        "WANDB_PROJECT": "",
        "WANDB_LOG_MODEL": "",
        "WANDB_KEY": "",
        "HUGGINGFACE_ACCESS_TOKEN": "",
    }
)

In [None]:
# Authenticate with the Hugging Face Hub using the token from the environment.
login(token=os.environ.get("HUGGINGFACE_ACCESS_TOKEN"))
# Authenticate with Weights & Biases using the key from the environment.
wandb.login(key=os.environ.get("WANDB_KEY"))

# Load Dataset

In [None]:
# Load the train and test splits of the abstract dataset from the local data folder.
train_dataset = load_dataset(split="train", path="../data/abstract_dataset/")
test_dataset = load_dataset(split="test", path="../data/abstract_dataset")

# Print the size of each split as a quick check that both loaded.
for split_name, split_data in (("train", train_dataset), ("test", test_dataset)):
    print(f"Length of {split_name} dataset: ", len(split_data))

In [None]:
# Checkpoint that is quantized and fine-tuned in this notebook.
model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# 4-bit NF4 quantization with double quantization and a bfloat16 compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized model; device_map="auto" lets the library place the weights.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    # attn_implementation="flash_attention_2",  # uncomment if you have an Ampere GPU
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.padding_side = "right"
# Assigning pad_token also determines pad_token_id, so the id is not set separately.
# NOTE(review): setup_chat_format below may assign its own pad token - confirm
# which padding token is intended for training.
tokenizer.pad_token = tokenizer.unk_token

# Switch model and tokenizer to the OAI ChatML chat format.
# The Instruct checkpoint already ships a chat template, and TRL releases that
# check for an existing template raise ValueError instead of overwriting it.
# Clearing the template first keeps the ChatML override working on both older
# and newer TRL releases. Remove these two lines to keep the checkpoint's own
# template instead.
tokenizer.chat_template = None
model, tokenizer = setup_chat_format(model, tokenizer)

# Prepare for Training

In [None]:
# Hyperparameters for the fine-tuning run.
training_args = TrainingArguments(
    output_dir="models/mistral-7b-instruct-v0.2-lora",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    optim="adamw_torch",
    logging_steps=10,
    # NOTE(review): a positive max_steps overrides num_train_epochs, so this run
    # stops after 10 optimizer steps. Looks like a smoke-test setting - confirm
    # before launching a real training run.
    max_steps=10,
    save_strategy="steps",
    evaluation_strategy="steps",
    learning_rate=2e-4,
    # The model is loaded with torch_dtype=torch.bfloat16 and a bfloat16 4-bit
    # compute dtype, so mixed precision must be bf16 as well. fp16 mixed
    # precision uses a gradient scaler that is not compatible with bfloat16
    # weights and gradients.
    bf16=True,
    max_grad_norm=0.3,
    # NOTE(review): the "constant" scheduler does not apply warmup; use
    # "constant_with_warmup" if warmup_ratio is meant to take effect.
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    run_name="mistral-7b-instruct-v0.2-lora",
    load_best_model_at_end=True,
    auto_find_batch_size=True,
)

# Set up LoRA configuration.
# target_modules uses PEFT's "all-linear" shortcut, which must be passed as a
# bare string. Inside a list, an entry such as "all_linear" is treated as a
# module-name suffix, matches nothing, and is silently ignored.
peft_config = LoraConfig(
    lora_alpha=128,
    lora_dropout=0.05,
    r=256,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

In [None]:
# Build the supervised fine-tuning trainer from the model, tokenizer, LoRA
# config, training arguments and the two dataset splits.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    peft_config=peft_config,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    # Tokenization options forwarded to the trainer's dataset preparation.
    dataset_kwargs=dict(add_special_tokens=False, append_concat_token=False),
)

In [None]:
# Run fine-tuning, then write the resulting model to the configured output
# directory (the same location save_model() uses when called without arguments).
trainer.train()
trainer.save_model(training_args.output_dir)