## Chat to determine PII field

This notebook takes a model checkpoint trained for sequence classification and checks how well it performs.

## Load libraries

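Add the parent directory to `sys.path` so the project's `tab_exp` package can be imported, then call `model_choice()` to get the model selection used in the cells below.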

In [None]:
import sys
from pathlib import Path
# Make the parent of the current working directory importable
# (presumably where the `tab_exp` package lives — confirm against the repo layout).
sys.path.append(str(Path.cwd().parent))
from tab_exp.viz import model_choice, get_hf_model

# model_choice() returns three objects; `model_name.value` is read in later cells
# to decide which model to load.
training, new_data, model_name = model_choice()

## Create an optimized model

Define a wrapper class that loads the model with a 4-bit quantization config and, when training, attaches LoRA adapters.

In [None]:
from typing import Literal
import torch
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils.quantization_config import BitsAndBytesConfig
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, pipeline, AutoTokenizer
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model # type: ignore

In [None]:
class Llama3:
    def __init__(
        self, 
        model_name: str,
        model_type: Literal["text_classifier", "causal"],
        train: bool,
        init_prediction: bool = True
    ):
        self.model_name = model_name
        self.model_id = get_hf_model(model_name)
        self.tokenizer = self.init_tokenizer()
        self.terminators = [self.tokenizer.eos_token_id,
                            self.tokenizer.convert_tokens_to_ids("")]
        quant_cfg = self.create_4bit_cfg()
        match model_type:
            case "text_classifier":
                self.model = self.create_sequence_model(quant_cfg)
            case _:
                self.model = self.create_causal_model(quant_cfg)
        self.config_model(train)
        self.generator = pipeline("text-generation", 
                                  model=self.model,
                                  tokenizer=self.tokenizer,
                                  )

    def create_4bit_cfg(self) -> BitsAndBytesConfig:
        if not torch.cuda.is_available():
            raise Exception("GPU must be available for training")

        quantization_config = BitsAndBytesConfig(
            load_in_4bit = True, 
            bnb_4bit_quant_type = 'nf4',
            bnb_4bit_use_double_quant = True, 
            bnb_4bit_compute_dtype = torch.bfloat16 
        )
        return quantization_config
    
    def _make_quant_cfg(self, quantization_cfg: BitsAndBytesConfig):
        return dict(quantization_config=quantization_cfg,
                    num_labels=4,
                    device_map='auto')

    def create_sequence_model(self, quant_cfg: BitsAndBytesConfig):
        """Creates a model used for text/sequence classification

        Args:
            quant_cfg (BitsAndBytesConfig): _description_

        Returns:
            AutoModelForSequenceClassification: the model
        """
        print(f"Using model {self.model_name}")

        # Create a model for text classification.  Normally llama3 is used for CausalLLM (question/answer)
        model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            **self._make_quant_cfg(quant_cfg)
        )
        return model
    
    def create_causal_model(self, quant_cfg: BitsAndBytesConfig):
        """Creates a model used for text generation/prompting

        Args:
            quant_cfg (BitsAndBytesConfig): _description_

        Returns:
            AutoModelForCausalLM: the model
        """
        print(f"Using model {self.model_name}")

        # Create a model for chat prompting. 
        return AutoModelForCausalLM.from_pretrained(
            self.model_name,
            **self._make_quant_cfg(quant_cfg)
        )
        

    def init_tokenizer(self) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
        """Retrieves and configures tokenizer based on the model

        Returns:
            PreTrainedTokenizer | PreTrainedTokenizerFast: _description_
        """
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        if "llama3" in self.model_id or "Llama-3" in self.model_id:
            # The llama3 tokenizer doesn't do padding like other models.  So set them as End of Sequence
            print("Setting tokenizer for llama3")
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token
        print('initialized tokenizer')
        return tokenizer
    
    def lora_config(self) -> LoraConfig:
        """LoRA configuration to train only needed weights in the model

        Returns:
            LoraConfig:
        """
        return LoraConfig(
            r = 16, 
            lora_alpha = 8,
            target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
            lora_dropout = 0.05, 
            bias = 'none',
            task_type = 'SEQ_CLS'
        )
    
    def config_model(self, training: bool):
        # The model is now optimized to make training faster, if a little less accurate
        if training:
            print("Configuring model for training")
            self.model = prepare_model_for_kbit_training(self.model)
            self.model = get_peft_model(self.model, self.lora_config())
            # set some llama3 tokenizer specific settings
            self.model.config.use_cache = False  # type: ignore
            self.model.config.pretraining_tp = 1 # type: ignore
        else:
            print("Using checkpointed model to get predictions")

        self.model.config.pad_token_id = self.tokenizer.pad_token_id  # type: ignore
  
    def get_response(
          self, 
          query: str, 
          message_history: list[dict] | None = None, 
          max_tokens=1024*124, 
          temperature=0.6, 
          top_p=0.9
      ):
        if message_history is None:
            message_history = []
        user_prompt = message_history + [{"role": "user", "content": query}]
        prompt = self.tokenizer.apply_chat_template(
            user_prompt, tokenize=False, add_generation_prompt=True
        )
        print(prompt)

        outputs = self.generator(
            prompt,
            max_new_tokens=max_tokens,
            eos_token_id=self.terminators,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )
        print(outputs)
        response = outputs[0]["generated_text"][len(prompt):] # type: ignore
        return response, user_prompt + [{"role": "assistant", "content": response}]
    
    def chatbot(self, system_instructions=""):
        conversation = [{"role": "system", "content": system_instructions}]
        # self.generator = pipeline(
        #     "text-generation",
        #     model=self.model_id,
        #     model_kwargs={
        #         "torch_dtype": torch.float16,
        #         "quantization_config": {"load_in_4bit": True},
        #         "low_cpu_mem_usage": True,
        #     }
        # )

        while True:
            user_input = input("User: ")
            if user_input.lower() in ["exit", "quit"]:
                print("Exiting the chatbot. Goodbye!")
                break
            response, conversation = self.get_response(user_input, conversation)
            print(f"Assistant: {response}")

In [None]:
# Read the chosen model name from `model_name.value` and show it as the cell output.
selected_model = str(model_name.value)
selected_model

In [None]:
# Load the selected model as a causal LM; train=False skips the k-bit/LoRA training setup.
llama3 = Llama3(model_name=str(model_name.value),
                model_type="causal",
                train=False)

In [None]:
# Start the interactive chat loop with a system prompt; type "exit" or "quit" to stop.
llama3.chatbot("You are an expert in data engineering and work with General Data Protection Regulation")