In [3]:
import torch
import re
import numpy as np
import pandas as pd
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling, BitsAndBytesConfig, pipeline, StoppingCriteria, StoppingCriteriaList
from sentence_transformers import SentenceTransformer
import faiss
from peft import LoraConfig, get_peft_model, TaskType
import warnings
warnings.filterwarnings("ignore")

In [4]:
torch.__version__

'2.8.0+cu126'

In [5]:
torch.cuda.is_available()

False

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [7]:
dataset = load_dataset("Amod/mental_health_counseling_conversations", split="train")
dataset

Dataset({
    features: ['Context', 'Response'],
    num_rows: 3512
})

In [8]:
dataset[0]

{'Context': "I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here.\n   I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it.\n   How can I change my feeling of being worthless to everyone?",
 'Response': "If everyone thinks you're worthless, then maybe you need to find new people to hang out with.Seriously, the social context in which a person lives is a big influence in self-esteem.Otherwise, you can go round and round trying to understand why you're not worthless, then go back to the same crowd and be knocked down again.There are many inspirational messages you can find in social media. \xa0Maybe read some of the ones which state that no person is worthless, and that everyone has a good purpose to their life.Also, since our culture is so saturated with the belief that if someone doesn't feel good about themselves that this is someh

In [9]:
def chunking(text, max_length=512):
  # split by sent endings and paragraph markers
  chunks = re.split(r'(?<=[.!?])\s+|\n\s*\n', text)

  # clean and filter chunks
  chunks = [chunk.strip() for chunk in chunks if chunk.strip()]

  # Merge small chunks
  merged_chunks = []
  current_chunk = ""

  for chunk in chunks:
      if len(current_chunk) + len(chunk) < max_length:
          if current_chunk:
              current_chunk += " " + chunk
          else:
              current_chunk = chunk
      else:
          if current_chunk:
              merged_chunks.append(current_chunk)
          current_chunk = chunk

  if current_chunk:
      merged_chunks.append(current_chunk)

  return merged_chunks

In [10]:
print(chunking(dataset[0]['Context']))

["I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here. I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it. How can I change my feeling of being worthless to everyone?"]


In [11]:
# process dataset
all_chunks = []
chunk_metadata = []
chunk_groups = []

for idx, conversation in enumerate(dataset):
  context = conversation['Context']
  response = conversation.get('Response', '')

  chunks = chunking(context)

  for chunk_idx, chunk in enumerate(chunks):
      all_chunks.append({
          "text": chunk,
          "response": response,
          "conversation_index": idx,
          "chunk_index": chunk_idx,
          "full_context": context
      })

  chunk_groups.append(chunks)

In [12]:
print(f"Created {len(all_chunks)} chunks from {len(dataset)} conversations")

Created 3934 chunks from 3512 conversations


In [13]:
# embedding
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
chunk_texts = [chunk["text"] for chunk in all_chunks]

In [14]:
chunk_texts[0]

"I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here. I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it. How can I change my feeling of being worthless to everyone?"

In [15]:
batch_size = 32
embeddings = []
for i in range(0, len(chunk_texts) + 1, batch_size):
  batch_texts = chunk_texts[i:i+batch_size]
  batch_embeddings = embedding_model.encode(batch_texts, show_progress_bar=False)
  embeddings.append(batch_embeddings)

embeddings

[array([[ 0.07201868, -0.01708577, -0.01348482, ...,  0.0093195 ,
         -0.05317944, -0.00170032],
        [ 0.07201868, -0.01708577, -0.01348482, ...,  0.0093195 ,
         -0.05317944, -0.00170032],
        [ 0.07201868, -0.01708577, -0.01348482, ...,  0.0093195 ,
         -0.05317944, -0.00170032],
        ...,
        [ 0.06321125,  0.01027208, -0.04621496, ..., -0.02750318,
         -0.01479691, -0.0275609 ],
        [ 0.06321125,  0.01027208, -0.04621496, ..., -0.02750318,
         -0.01479691, -0.0275609 ],
        [ 0.06321125,  0.01027208, -0.04621496, ..., -0.02750318,
         -0.01479691, -0.0275609 ]], dtype=float32),
 array([[ 0.06321125,  0.01027208, -0.04621496, ..., -0.02750318,
         -0.01479691, -0.0275609 ],
        [ 0.06321125,  0.01027208, -0.04621496, ..., -0.02750318,
         -0.01479691, -0.0275609 ],
        [ 0.06321125,  0.01027208, -0.04621496, ..., -0.02750318,
         -0.01479691, -0.0275609 ],
        ...,
        [ 0.06321125,  0.01027208, -0.0

In [16]:
embeddings = np.vstack(embeddings)
embeddings

array([[ 0.07201868, -0.01708577, -0.01348482, ...,  0.0093195 ,
        -0.05317944, -0.00170032],
       [ 0.07201868, -0.01708577, -0.01348482, ...,  0.0093195 ,
        -0.05317944, -0.00170032],
       [ 0.07201868, -0.01708577, -0.01348482, ...,  0.0093195 ,
        -0.05317944, -0.00170032],
       ...,
       [-0.02839057,  0.08433163, -0.00603622, ..., -0.00874444,
         0.09284265,  0.03221234],
       [ 0.09772041,  0.07784512, -0.00295214, ...,  0.01763512,
        -0.02381536, -0.01553846],
       [ 0.12001848, -0.02394911,  0.00200294, ...,  0.00614607,
        -0.04754531,  0.02329832]], dtype=float32)

In [17]:
embeddings.shape

(3934, 384)

In [18]:
# faiss index
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings.astype('float32'))

In [19]:
def retreive_relevant_chunks(query, k=5):
  query_embedding = embedding_model.encode([query])
  distances, indices = index.search(query_embedding.astype('float32'), k)

  retrieved_chunks = []
  for idx in indices[0]:
    if idx < len(all_chunks):
      retrieved_chunks.append(all_chunks[idx])
  return retrieved_chunks

In [20]:
# retrieval test
test_query = "I'm feeling anxious about my future"
retrieved = retreive_relevant_chunks(test_query, k=10)
for i, chunk in enumerate(retrieved):
    print(f"{i+1}. {chunk['text'][:100]}...")

1. I get so much anxiety, and I don’t know why. I feel like I can’t do anything by myself because I’m s...
2. I get so much anxiety, and I don’t know why. I feel like I can’t do anything by myself because I’m s...
3. I get so much anxiety, and I don’t know why. I feel like I can’t do anything by myself because I’m s...
4. I get so much anxiety, and I don’t know why. I feel like I can’t do anything by myself because I’m s...
5. I get so much anxiety, and I don’t know why. I feel like I can’t do anything by myself because I’m s...
6. I get so much anxiety, and I don’t know why. I feel like I can’t do anything by myself because I’m s...
7. I get so much anxiety, and I don’t know why. I feel like I can’t do anything by myself because I’m s...
8. I get so much anxiety, and I don’t know why. I feel like I can’t do anything by myself because I’m s...
9. I'm scared to because I don't want it to be taken away from me again. I feel like ever lesson I lear...
10. I'm scared to because I don't wan

### Base model

In [21]:
# prepare data (base model for now)
model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

In [22]:
def create_training_examples():
  examples = []
  for chunk in all_chunks:
    if chunk["response"]: # only use chunks with responses
      input_text = f"Context: {chunk['full_context']}\nChunk: {chunk['text']}\nResponse:"
      target_text = chunk["response"]
      examples.append({"input": input_text, "target": target_text})
  return examples

In [23]:
training_examples = create_training_examples()
training_examples[0]

{'input': "Context: I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here.\n   I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it.\n   How can I change my feeling of being worthless to everyone?\nChunk: I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here. I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it. How can I change my feeling of being worthless to everyone?\nResponse:",
 'target': "If everyone thinks you're worthless, then maybe you need to find new people to hang out with.Seriously, the social context in which a person lives is a big influence in self-esteem.Otherwise, you can go round and round trying to understand why you're not worthless, then go back to the same cr

In [24]:
len(training_examples)

3930

In [25]:
# tokenise data
def tokenize_func(examples):
  model_inputs = tokenizer(examples["input"], padding="max_length", truncation=True, max_length=512)
  labels = tokenizer(examples["target"], padding="max_length", truncation=True, max_length=512)
  model_inputs["labels"] = labels["input_ids"]
  return model_inputs

In [26]:
train_dataset = Dataset.from_list(training_examples)
tokenized_dataset = train_dataset.map(tokenize_func, batched=True)

Map:   0%|          | 0/3930 [00:00<?, ? examples/s]

In [27]:
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn", "c_proj"], lora_dropout=0.05, bias="none", task_type=TaskType.CAUSAL_LM)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

`torch_dtype` is deprecated! Use `dtype` instead!


trainable params: 811,008 || all params: 125,250,816 || trainable%: 0.6475


In [28]:
training_args = TrainingArguments(
    output_dir="./therapy-rag-model",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=100,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=10,
    save_strategy="epoch",
    eval_strategy="no",
    push_to_hub=False
)

In [29]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

In [30]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

In [31]:
# # fine tune the model
# trainer.train()
# print("Fine tuning complete")

In [32]:
# # save the model
# model.save_pretrained("./therapy-rag-model-final")
# tokenizer.save_pretrained("./therapy-rag-model-final")
# print("Model saved successfully")

### Speculative RAG

In [33]:
draft_model_name = "microsoft/DialoGPT-small"
draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_name)
draft_tokenizer.pad_token = draft_tokenizer.eos_token
draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config)

In [34]:
target_model_name = "microsoft/DialoGPT-medium"
target_tokenizer = AutoTokenizer.from_pretrained(target_model_name)
target_tokenizer.pad_token = target_tokenizer.eos_token
target_model = AutoModelForCausalLM.from_pretrained(target_model_name, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config)

tokenizer_config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/642 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/863M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/863M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [35]:
class TherapySpeculativeRAG:
  def __init__(self, draft_model, draft_tokenizer, target_model, target_tokenizer, retrieval_func, chunk_data):
    self.draft_model = draft_model
    self.draft_tokenizer = draft_tokenizer
    self.target_model = target_model
    self.target_tokenizer = target_tokenizer
    self.retrieve_relevant_chunks = retrieval_func
    self.chunk_data = chunk_data

    self.draft_model.eval()
    self.target_model.eval()

  def generate_response(self, query, num_drafts=3, max_new_tokens=80):
    try:
      # retrieve relevant context (more focused)
      retrieved_chunks = self.retrieve_relevant_chunks(query, k=5)  # Reduced from 10 to 5

      # use both the context and response for better understanding
      context_items = []
      for chunk in retrieved_chunks:
        if chunk.get('text') and chunk.get('response'):
          context_items.append(f"Client: {chunk['text']}\nTherapist: {chunk['response']}")

      context = "\n\n".join(context_items) if context_items else "No specific context available."

      # build a cleaner prompt with context
      prompt = f"""Based on these therapy examples:

      {context}

      Now respond to the client's current concern:

      Client: {query}
      Therapist:"""

      inputs = self.draft_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(self.draft_model.device)

      # drafter generates multiple drafts with better parameters
      with torch.no_grad():
        draft_outputs = self.draft_model.generate(
          inputs.input_ids,
          max_new_tokens=max_new_tokens,
          num_return_sequences=num_drafts,
          do_sample=True,
          temperature=0.8,
          top_p=0.9,
          repetition_penalty=1.1,
          pad_token_id=self.draft_tokenizer.eos_token_id,
        )

      draft_responses = []
      for out in draft_outputs:
        response = self.draft_tokenizer.decode(
          out[inputs.input_ids.shape[1]:],
          skip_special_tokens=True
        ).strip()

        # clean up responses
        response = self._clean_response(response)
        if response and len(response.split()) > 2:  # only keep meaningful responses
          draft_responses.append(response)

      # if no good drafts, create a fallback
      if not draft_responses:
        draft_responses = [
          "I hear that you're struggling. Could you tell me more about what you're experiencing?",
          "Thank you for sharing that with me. It sounds like you're going through a difficult time."
        ]

      # verifier selects/refines the best response with a simpler prompt
      verifier_prompt = f"""As a therapist, choose the best response for this client:

      Client: {query}

      Possible responses:
      """ + "\n".join([f"{i+1}. {resp}" for i, resp in enumerate(draft_responses)]) + """

      Best therapist response:"""

      target_inputs = self.target_tokenizer(
        verifier_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1024
      ).to(self.target_model.device)

      with torch.no_grad():
        target_outputs = self.target_model.generate(
          target_inputs.input_ids,
          max_new_tokens=max_new_tokens,
          do_sample=True,
          temperature=0.7,
          top_p=0.9,
          repetition_penalty=1.1,
          pad_token_id=self.target_tokenizer.eos_token_id,
        )

      verified_response = self.target_tokenizer.decode(
        target_outputs[0][target_inputs.input_ids.shape[1]:],
        skip_special_tokens=True
      ).strip()

      verified_response = self._clean_response(verified_response)

      # ensure we have a valid response
      if not verified_response or len(verified_response.split()) < 3:
        verified_response = draft_responses[0] if draft_responses else "I appreciate you sharing that. How has this been affecting you?"

      return {
        "query": query,
        "context": context,
        "drafts": draft_responses,
        "verified_response": verified_response
      }

    except Exception as e:
      print(f"Error in generate_response: {e}")
      # fallback response
      return {
        "query": query,
        "context": "Error in retrieval",
        "drafts": [],
        "verified_response": "I'm here to listen. Could you tell me more about what you're experiencing?"
      }

  def _clean_response(self, response):
    """Clean up therapy responses by removing unwanted patterns"""
    if not response:
      return response

    # remove any stop sequences
    stop_sequences = ["Client:", "Patient:", "Therapist:", "\n\n", "###", "<|endoftext|>"]
    for seq in stop_sequences:
      if seq in response:
        response = response.split(seq)[0].strip()

    # ensure proper sentence endings
    if response and response[-1] not in ['.', '!', '?']:
      # try to find a reasonable cutoff point
      for punct in ['.', '!', '?']:
        if punct in response:
          response = response.rsplit(punct, 1)[0] + punct
          break

    return response.strip()

In [36]:
rag_system = TherapySpeculativeRAG(
    draft_model, draft_tokenizer,
    target_model, target_tokenizer,
    retreive_relevant_chunks, all_chunks
)

In [37]:
test_queries = [
  "I've been feeling really anxious lately",
  "I'm struggling with my relationships",
  "I can't sleep at night because of my thoughts",
  "I feel lonely even when I'm with people"
]

for query in test_queries:
  print(f"\nClient: {query}")
  result = rag_system.generate_response(query)
  if len(result['drafts']) > 0:
    print(f"Drafts: {result['drafts']}")
  print(f"Verified Response: {result['verified_response']}")

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



Client: I've been feeling really anxious lately
Drafts: ["I hear that you're struggling. Could you tell me more about what you're experiencing?", "Thank you for sharing that with me. It sounds like you're going through a difficult time."]
Verified Response: I hear that you're struggling. Could you tell me more about what you're experiencing?

Client: I'm struggling with my relationships
Drafts: ["I hear that you're struggling. Could you tell me more about what you're experiencing?", "Thank you for sharing that with me. It sounds like you're going through a difficult time."]
Verified Response: I hear that you're struggling. Could you tell me more about what you're experiencing?

Client: I can't sleep at night because of my thoughts
Drafts: ["I hear that you're struggling. Could you tell me more about what you're experiencing?", "Thank you for sharing that with me. It sounds like you're going through a difficult time."]
Verified Response: I hear that you're struggling. Could you tell me

In [44]:
def evaluate_rag_system(rag_system, eval_queries, metrics=True, verbose=True):
  results = []

  for q in eval_queries:
    out = rag_system.generate_response(q)
    results.append(out)

    if verbose:
      print("\nClient:", out["query"])
      print("Context:", out["context"][:200], "..." if len(out["context"]) > 200 else "")
      print("Drafts:")
      for i, d in enumerate(out["drafts"]):
        print(f"\t{i+1}. {d}")
      print("Verified:", out["verified_response"])

  scores = {}
  if metrics:
    draft_lengths = []
    for res in results:
      if res["drafts"]:  # only process if there are drafts
        all_draft_text = " ".join(res["drafts"])
        draft_lengths.append(len(all_draft_text.split()))

    verified_lengths = [len(res["verified_response"].split()) for res in results]

    scores["avg_draft_length"] = np.mean(draft_lengths) if draft_lengths else 0
    scores["avg_verified_length"] = np.mean(verified_lengths) if verified_lengths else 0

    diversities = []
    for res in results:
      drafts = res["drafts"]
      if drafts:  # Only calculate diversity if there are drafts
        unique_count = len(set(drafts))
        diversities.append(unique_count / len(drafts))
      else:
        diversities.append(0)  # no drafts -> zero diversity

    scores["avg_draft_diversity"] = np.mean(diversities) if diversities else 0

    overlaps = []
    for res in results:
      if res["drafts"] and res["verified_response"]:  # only calculate if both exist
        vr_tokens = set(res["verified_response"].split())
        all_draft_tokens = set()
        for draft in res["drafts"]:
          all_draft_tokens.update(draft.split())

          if vr_tokens:  # avoid division by zero
            overlap_count = len(vr_tokens.intersection(all_draft_tokens))
            overlaps.append(overlap_count / len(vr_tokens))

    scores["avg_verified_overlap_with_drafts"] = np.mean(overlaps) if overlaps else 0

  return results, scores

In [45]:
results, scores = evaluate_rag_system(rag_system, test_queries)

print("\nEvaluation Metrics:")
for k, v in scores.items():
  print(f"{k}: {v:.3f}")


Client: I've been feeling really anxious lately
Context: Client: I started having anxiety three months ago. I'm new to having anxiety, and it's making me depressed.
Therapist: One of the first steps is to manage anxiety and depression symptoms are to establ ...
Drafts:
	1. I hear that you're struggling. Could you tell me more about what you're experiencing?
	2. Thank you for sharing that with me. It sounds like you're going through a difficult time.
Verified: I hear that you're struggling. Could you tell me more about what you're experiencing?

Client: I'm struggling with my relationships
Context: Client: My last relationships have ended horribly. They just up and abandoned me. One of them I have never gotten closure with over it, leaving me emotionally wrecked. I know something's wrong with me ...
Drafts:
	1. I hear that you're struggling. Could you tell me more about what you're experiencing?
	2. Thank you for sharing that with me. It sounds like you're going through a difficult tim

**Note:** The answers are repeated because of the dataset. The datset used here includes a similar response for most of the context. In process: Using a better dataset with LoRA for fine tuning