In [None]:
pip install transformers datasets peft trl accelerate bitsandbytes packaging ninja sentencepiece

In [None]:
pip install flash-attn --no-build-isolation

In [None]:
import random
import gc
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import numpy as np
import pandas as pd
import transformers
import accelerate
import bitsandbytes as bnb
from datasets import load_dataset, concatenate_datasets
import torch

In [None]:
# Load the counseling-conversations dataset with all of its splits; later cells
# index the result by split name (dataset['train']).
# The second positional argument of load_dataset is the config *name*, not the
# split, so 'train' must not be passed there.
dataset = load_dataset('Amod/mental_health_counseling_conversations')

In [None]:
# Display the loaded dataset object to inspect its structure.
dataset

In [None]:
# 4-bit NF4 quantization with double quantization; compute dtype is bfloat16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"

# Never hardcode credentials in a notebook: read the Hugging Face token from the
# HF_TOKEN environment variable. When it is unset, None is passed and the library
# falls back to whatever credentials it finds on its own (e.g. a cached login).
hf_token = os.environ.get("HF_TOKEN") or None

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)
# Reuse the EOS token for padding and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# `attn_implementation="flash_attention_2"` replaces the deprecated
# `use_flash_attention_2=True` flag.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=hf_token,
    quantization_config=nf4_config,
    attn_implementation="flash_attention_2",
)


In [None]:
# LoRA adapter settings for causal language modelling; no bias terms are trained.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)

In [None]:
# Two-step wrap: ready the quantized base model for k-bit training, then attach
# the LoRA adapters described by peft_config.
# NOTE(review): this cell rebinds `model`, so re-running it would wrap an
# already-wrapped model; reload the base model before re-running.
kbit_ready_model = prepare_model_for_kbit_training(model)
model = get_peft_model(kbit_ready_model, peft_config)

In [None]:
# Trainer hyperparameters. Checkpoints are saved and evaluation is run once per
# epoch (the two strategies must match for load_best_model_at_end to work).
args = TrainingArguments(
    output_dir="outputs",
    logging_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_checkpointing=False,
    optim="paged_adamw_8bit",
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=1e-4,
    tf32=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    # The plain "constant" schedule ignores warmup entirely, which made
    # warmup_ratio above a silent no-op; "constant_with_warmup" keeps the
    # learning rate constant after actually applying the configured warmup.
    lr_scheduler_type="constant_with_warmup",
    load_best_model_at_end=True,
    evaluation_strategy='epoch'
)


In [None]:
# Display the dataset object again before inspecting individual rows.
dataset

In [None]:
# Peek at the 'Context' field of the first training row.
dataset['train'][0]['Context']

In [None]:
# Peek at the 'Response' field of the first training row.
dataset['train'][0]['Response']

In [None]:
# System prompt prepended to every training example.
# Built from adjacent string literals so that no source-code indentation or line
# breaks leak into the prompt (the previous triple-quoted literal embedded a
# newline plus 18 spaces before each continuation line, and contained the typo
# "and and").
system_message = (
    "You are a helpful and truthful psychology and psychotherapy assistant. "
    "Your primary role is to provide empathetic, understanding, and non-judgmental responses to users seeking emotional and psychological support. "
    "Always respond with empathy and demonstrate active listening; try to focus on the user. "
    "Your responses should reflect that you understand the user's feelings and concerns. "
    "If a user expresses thoughts of self-harm, suicide, or harm to others, prioritize their safety. "
    "Encourage them to seek immediate professional help and provide emergency contact numbers when appropriate. "
    "You are not a licensed medical professional. Do not diagnose or prescribe treatments. "
    "Instead, encourage users to consult with a licensed therapist or medical professional for specific advice. "
    "Avoid taking sides or expressing personal opinions. "
    "Your role is to provide a safe space for users to share and reflect. "
    "Remember, your goal is to provide a supportive and understanding environment for users to share their feelings and concerns. "
    "Always prioritize their well-being and safety."
)


def format_mistral(entry):
    """Render one dataset row as a single instruction-tuning training string.

    Args:
        entry: Mapping with a 'Context' key (the user's message) and a
            'Response' key (the counselor's reply).

    Returns:
        The string "<s>[INST] <<SYS>>{system prompt}<</SYS>>{Context} [/INST]  {Response}  </s>".

    Raises:
        KeyError: If 'Context' or 'Response' is missing from ``entry``.
    """
    # NOTE(review): the template already contains a literal "<s>"; if the tokenizer
    # also prepends BOS when this text is tokenized, samples start with two BOS
    # tokens — confirm against the trainer's tokenization settings.
    return f"<s>[INST] <<SYS>>{system_message}<</SYS>>{entry['Context']} [/INST]  {entry['Response']}  </s>"

In [None]:
# Display the first row of the 'train' split.
dataset['train'][0]

In [None]:
# Sequential hold-out split of the 'train' split: rows [0, 3000) are used for
# training and rows [3000, 3512) for validation.
TRAIN_END = 3000
VAL_END = 3512

all_rows = dataset['train']
train = all_rows.select(range(0, TRAIN_END))
val = all_rows.select(range(TRAIN_END, VAL_END))

In [None]:
# Display the training subset to check what was selected.
train

In [None]:
# Sanity check: render the first training row through the prompt template.
first_row = dataset['train'][0]
print(format_mistral(first_row))

In [None]:
# Sequence length handed to the trainer.
max_seq_length = 2048

# Supervised fine-tuning trainer: rows are rendered with format_mistral and
# packing is enabled.
trainer = SFTTrainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    peft_config=peft_config,
    # data
    train_dataset=train,
    eval_dataset=val,
    formatting_func=format_mistral,
    # sequence construction
    packing=True,
    max_seq_length=max_seq_length,
)


In [None]:
# Free what memory we can before training starts: run Python's garbage collector,
# then ask PyTorch to release its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Run the fine-tuning loop with the configuration built above.
trainer.train()

# Save the trained model; no path is given here, so the trainer chooses the
# destination (presumably args.output_dir, "outputs" — confirm).
trainer.save_model()
