## **Fine-Tuning LLaMA**

**Installing Dependencies**

In [None]:
%%capture
%pip install accelerate peft bitsandbytes transformers trl

### **Getting Started**

**Imports**

In [None]:
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig
from trl import SFTTrainer

**Building**

In [None]:
class FineTuned_LlaMA:
    """Fine-tune a 4-bit quantized causal LM with LoRA and chat with the result.

    Intended call order::

        load_model() -> load_tokenizer() -> model_trainer()
        -> create_pipeline() -> inference(prompt)

    ``model_trainer`` reads ``self.model`` and ``self.tokenizer``, and
    ``inference`` reads ``self.pipe``, so the earlier steps must run first.
    """

    def __init__(self, base_model, dataset, new_model, split="train"):
        """Store identifiers, load the training data and build static configs.

        Args:
            base_model: Identifier handed to ``from_pretrained`` for both the
                model and the tokenizer.
            dataset: Identifier handed to ``datasets.load_dataset``.
            new_model: Destination passed to ``save_pretrained`` for the
                trained model and tokenizer.
            split: Dataset split used for training. Defaults to ``"train"``.
        """
        self.base_model = base_model
        self.dataset_name = dataset
        self.new_model = new_model
        # Without `split`, load_dataset returns a DatasetDict (one entry per
        # split); the trainer's `train_dataset` needs a single Dataset.
        self.dataset = load_dataset(dataset, split=split)

        # 4-bit quantization config
        self.compute_dtype = torch.float16
        self.quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=self.compute_dtype,
            bnb_4bit_use_double_quant=False,
        )

        # LoRA adapter config. Built here because it depends on nothing but
        # constants, so training does not rely on load_tokenizer() for it.
        self.peft_params = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=64,
            bias="none",
            task_type="CAUSAL_LM",
        )

    def load_model(self):
        """Load the quantized base model onto device 0 into ``self.model``."""
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model,
            quantization_config=self.quant_config,
            device_map={"": 0},
        )
        # The generation KV cache is not needed while training.
        self.model.config.use_cache = False
        self.model.config.pretraining_tp = 1

    def load_tokenizer(self):
        """Load the base model's tokenizer into ``self.tokenizer``.

        The EOS token is reused as the padding token and padding is applied
        on the right.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model, trust_remote_code=True
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

    def model_trainer(self):
        """Train the LoRA adapter, then save model and tokenizer.

        Requires ``load_model()`` and ``load_tokenizer()`` to have been
        called. Checkpoints and logs go to ``./results``; the final model and
        tokenizer are written to ``self.new_model``.
        """
        self.training_params = TrainingArguments(
            # The output directory is where the model predictions and
            # checkpoints will be stored.
            output_dir="./results",
            num_train_epochs=1,  # One training epoch.
            per_device_train_batch_size=4,
            gradient_accumulation_steps=1,
            optim="paged_adamw_32bit",
            save_steps=25,
            logging_steps=25,
            learning_rate=2e-4,
            weight_decay=0.001,
            fp16=False,
            bf16=False,
            max_grad_norm=0.3,
            max_steps=-1,  # -1: let num_train_epochs decide the step count.
            warmup_ratio=0.03,
            group_by_length=True,
            lr_scheduler_type="constant",
            report_to="tensorboard",
        )
        self.trainer = SFTTrainer(
            model=self.model,
            train_dataset=self.dataset,
            peft_config=self.peft_params,
            dataset_text_field="text",
            max_seq_length=None,
            tokenizer=self.tokenizer,
            args=self.training_params,
            packing=False,
        )

        # Run the optimisation loop. Saving without this step would persist
        # freshly initialised adapter weights, i.e. an untrained model.
        self.trainer.train()

        self.trainer.model.save_pretrained(self.new_model)
        # Same tokenizer object that was handed to the trainer above.
        self.tokenizer.save_pretrained(self.new_model)

    def create_pipeline(self):
        """Build a text-generation pipeline in ``self.pipe``.

        Uses ``self.model`` and ``self.tokenizer`` with ``max_length=100``.
        """
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_length=100,
        )

    def inference(self, prompt):
        """Wrap ``prompt`` in the ``[INST]`` template and run the pipeline.

        Args:
            prompt: User text to send to the model.

        Returns:
            Whatever ``self.pipe`` returns for the formatted prompt; the value
            is also kept in ``self.results``.
        """
        # NOTE(review): the tokenizer may prepend its own BOS token, which
        # would duplicate the literal "<s>" below - confirm before changing.
        self.results = self.pipe(f"<s>[INST] {prompt} [/INST]")
        return self.results

### **TESTING THE MODEL**

In [None]:
# Identifier passed to `from_pretrained` for the base model and its tokenizer.
BASE_MODEL = "NousResearch/Llama-2-7b-chat-hf"
# Identifier passed to `load_dataset` for the training data.
# NOTE(review): the constant is spelled GUANANCO while the dataset id says
# "guanaco"; the name is kept because the next cell references it.
GUANANCO_DATASET = "mlabonne/guanaco-llama2-1k"
# Destination passed to `save_pretrained` for the model and tokenizer.
NEW_MODEL = "llama2-7b-vivek"

In [None]:
# Constructing the wrapper already loads the dataset and builds the
# quantization config.
fine_tuned_model = FineTuned_LlaMA(BASE_MODEL,GUANANCO_DATASET,NEW_MODEL)
# Order matters: model_trainer() reads the model and tokenizer attributes set
# by the two load_* calls, and create_pipeline() reads them as well.
fine_tuned_model.load_model()
fine_tuned_model.load_tokenizer()
# Builds the SFTTrainer and writes model + tokenizer files to NEW_MODEL.
fine_tuned_model.model_trainer()
# Must run before the chat loop below, which calls inference().
fine_tuned_model.create_pipeline()

In [None]:
# Interactive chat: keep reading prompts until the user types "exit" or
# presses Ctrl-C (at the prompt or while the model is generating).
try:
    # iter(callable, sentinel) calls input() repeatedly and stops as soon as
    # it returns exactly "exit".
    for prompt in iter(lambda: input("\n<You>: "), "exit"):
        print(f"<LlaMA> : {fine_tuned_model.inference(prompt)}")
except KeyboardInterrupt:
    pass