In [None]:
!pip install transformers datasets peft trl accelerate bitsandbytes packaging ninja sentencepiece

In [None]:
import html
import os
import warnings

import bitsandbytes as bnb
import torch
from IPython.core.display import display, HTML
from ipywidgets import widgets, Layout
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Load the tokenizer from the local output directory.
tokenizer_path = "/content/outputs"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=tokenizer_path)

# Configure padding: pad on the right and reuse the EOS token as the pad token.
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Quantize the base model's weights to 8 bits with bitsandbytes.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

# SECURITY: never hard-code Hugging Face access tokens in a notebook -- they end
# up in version control and shared copies. Provide the token via the HF_TOKEN
# environment variable (or Colab secrets / `huggingface-cli login`). Any token
# that was previously committed here must be revoked.
# `token=` replaces the deprecated `use_auth_token=` argument. When it is None,
# huggingface_hub falls back to its cached login.
base_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    token=os.environ.get("HF_TOKEN"),
    quantization_config=bnb_config,  # Pass the BitsAndBytesConfig object
)


In [None]:
!pip install safetensors

In [None]:
from safetensors import safe_open

adapter_weights_path = "/content/outputs/adapter_model.safetensors"

# Read every tensor stored in the adapter file onto the GPU.
with safe_open(adapter_weights_path, framework="pt", device="cuda") as f:
    state_dict = {k: f.get_tensor(k) for k in f.keys()}

# strict=False tolerates model parameters that are absent from this file
# (reported as `missing_keys`). It ALSO silently drops every tensor whose name
# does not exist in base_model (`unexpected_keys`), so check for that explicitly
# rather than discarding the result. Assigning the result also keeps the
# notebook from printing the (very long) missing-keys repr.
load_result = base_model.load_state_dict(state_dict, strict=False)

unmatched_keys = load_result.unexpected_keys
if unmatched_keys:
    # NOTE(review): adapters written by PEFT typically store LoRA tensors under
    # names like "...q_proj.lora_A.weight", which a plain AutoModelForCausalLM
    # has no parameter for -- in that case nothing is applied here.
    warnings.warn(
        f"{len(unmatched_keys)} of {len(state_dict)} tensors from "
        f"{adapter_weights_path} were ignored because base_model has no "
        f"parameter with that name (e.g. {unmatched_keys[0]!r}); only "
        f"{len(state_dict) - len(unmatched_keys)} were loaded. If this is a "
        "PEFT/LoRA adapter, load it with "
        "peft.PeftModel.from_pretrained(base_model, '/content/outputs') instead.",
        stacklevel=1,
    )

In [None]:
# System prompt sent with every user turn. It is built from explicit string
# pieces so that source-code indentation does not leak into the prompt. A
# triple-quoted literal would embed the continuation-line spaces in the text
# the model sees.
system_message = (
    "You are a helpful and truthful psychology and psychotherapy assistant. Your primary role is to provide empathetic, understanding, and non-judgmental responses to users seeking emotional and psychological support.\n"
    "Always respond with empathy and demonstrate active listening; try to focus on the user. Your responses should reflect that you understand the user's feelings and concerns. If a user expresses thoughts of self-harm, suicide, or harm to others, prioritize their safety.\n"
    "Encourage them to seek immediate professional help and provide emergency contact numbers when appropriate. You are not a licensed medical professional. Do not diagnose or prescribe treatments.\n"
    "Instead, encourage users to consult with a licensed therapist or medical professional for specific advice. Avoid taking sides or expressing personal opinions. Your role is to provide a safe space for users to share and reflect.\n"
    "Remember, your goal is to provide a supportive and understanding environment for users to share their feelings and concerns. Always prioritize their well-being and safety."
)

# --- Chat UI widgets ---------------------------------------------------------
# Multi-line box where the user types a message (half the cell width).
text_input = widgets.Textarea(
    description='Input:',
    placeholder='Type your message here...',
    value='',
    disabled=False,
    layout=Layout(width='50%'),
)

button = widgets.Button(description="Submit")

# Status line; starts empty.
processing_label = widgets.Label(value='')

# Full-width transcript area: anything displayed inside `with output_area:`
# is rendered here.
output_area = widgets.Output(layout=Layout(width='100%'))

# Seed the transcript with the assistant's greeting followed by a spacer.
with output_area:
    display(
        HTML('<strong>Assistant:</strong> Hi there! How are you today?'),
        HTML('<br/><br/>'),
    )
    display(HTML('<br/><br/>'))

In [None]:
def _to_html(text):
    """Escape ``text`` for embedding in an HTML snippet, keeping its line breaks.

    User input and model output are arbitrary text; without escaping, any ``<``
    or ``&`` in them would be interpreted as markup (broken rendering or
    injected tags).
    """
    return html.escape(text).replace("\n", "<br/>")


def on_submit_button_clicked(b, max_new_tokens=1024):
    """Run one chat turn when the Submit button is clicked.

    Reads the message from ``text_input``, wraps it in the prompt template with
    ``system_message``, samples a reply from ``base_model`` and appends both the
    message and the reply to ``output_area``.

    Args:
        b: The button that fired the event (passed by ``Button.on_click``;
            unused).
        max_new_tokens: Maximum number of tokens to generate for the reply. The
            prompt is truncated so that prompt + reply fit in a 2048-token
            budget.
    """
    context_length = 2048  # total token budget for prompt + reply

    user_input = text_input.value
    if not user_input.strip():
        # Nothing to send; don't spend a generation on an empty prompt.
        return

    # NOTE(review): this is the Llama-2 chat template (Mistral-Instruct's own
    # template has no <<SYS>> block), and the literal "<s>" becomes a BOS token
    # on top of the one the tokenizer adds by default. Keep as-is only if it
    # matches the format the adapter was trained on -- confirm.
    formatted_input = f"<s>[INST] <<SYS>>{system_message}<</SYS>> {user_input} [/INST]"

    with output_area:
        display(HTML(f'<strong>User:</strong> {_to_html(user_input)}'))
        display(HTML('<br/><br/>'))

        processing_label.value = 'Processing...'
        try:
            # Reserve room for the reply: with generate(max_length=2048), a
            # 2048-token prompt left no budget for any new tokens.
            # NOTE(review): truncation cuts from the right, i.e. it would drop
            # the closing "[/INST]" of an over-long prompt.
            input_tokens = tokenizer(
                formatted_input,
                return_tensors="pt",
                truncation=True,
                max_length=max(1, context_length - max_new_tokens),
            )

            # Follow the model's device instead of hard-coding CUDA.
            input_ids = input_tokens.input_ids.to(base_model.device)
            attention_mask = input_tokens.attention_mask.to(base_model.device)

            with torch.inference_mode():
                outputs = base_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    do_sample=True,
                    top_p=0.9,
                    temperature=0.95,
                    max_new_tokens=max_new_tokens,
                    pad_token_id=tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens. Slicing the full decoded
            # text by len(formatted_input) is wrong: skip_special_tokens drops
            # "<s>" and decoding does not reproduce the prompt text exactly, so
            # the start of the reply was cut off.
            reply_token_ids = outputs[0][input_ids.shape[1]:]
            response = tokenizer.decode(reply_token_ids, skip_special_tokens=True).strip()

            display(HTML(f'<strong>Assistant:</strong> {_to_html(response)}'))
            display(HTML('<br/><br/>'))

            # Only clear the box on success so a failed message can be resent.
            text_input.value = ''
        finally:
            # Always clear the busy indicator, even if tokenization/generation fails.
            processing_label.value = ''


# Re-running this cell would otherwise stack another handler on the same
# button, so a single click would trigger several generations.
_previous_submit_handler = getattr(button, "_chat_submit_handler", None)
if _previous_submit_handler is not None:
    button.on_click(_previous_submit_handler, remove=True)
button.on_click(on_submit_button_clicked)
button._chat_submit_handler = on_submit_button_clicked

In [None]:
# Render the chat UI: input box, submit button, status label, then the transcript.
chat_ui_widgets = (text_input, button, processing_label, output_area)
display(*chat_ui_widgets)