In [12]:
import os
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm.auto import tqdm

import json

import time

# `tic` and `toc` are both plain aliases of time.time, for MATLAB-style timing:
#   start = tic(); ...; elapsed = toc() - start
tic = toc = time.time

from datasets import load_dataset

# Load the "nbertagnolli/counsel-chat" dataset via `datasets.load_dataset`.
dataset = load_dataset("nbertagnolli/counsel-chat")
# Device string the sampling helpers use when moving tokenized prompts (`.to(device)`).
device = "cuda"

Found cached dataset csv (/n/home04/yidachen/.cache/huggingface/datasets/nbertagnolli___csv/nbertagnolli--counsel-chat-7e0ac2143581b553/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)


  0%|          | 0/1 [00:00<?, ?it/s]

### Load Model and Tokenizer

In [2]:
# `use_auth_token=True` makes both calls authenticate with the locally stored
# Hugging Face token.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf", use_auth_token=True)
# Load the weights directly in fp16 rather than materialising a default-precision
# copy in host memory and casting it afterwards with `.half()`.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-chat-hf",
    use_auth_token=True,
    torch_dtype=torch.float16,
)
# Move to the device declared in the setup cell, so the model and the tokenized
# prompts (which the helpers send to `device`) always end up in the same place.
model.to(device)
model.eval();

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### Helper

In [44]:
def sample_k_next_words(model, tokenizer, prompt, K=1, top_k=50, temperature=1.0):
    """
    Sample up to K next tokens for a prompt by calling the model's forward pass directly.

    Despite the function name, K counts tokens, not whitespace-separated words.

    Args:
    - model: Causal language model. Calling it on a batch of token ids must return
      an object with a `.logits` tensor indexed as (batch, position, vocab), and it
      must expose a `.device` attribute (Hugging Face models do).
    - tokenizer: Tokenizer providing `encode(..., return_tensors="pt")` and `decode`.
      Its `eos_token_id` is used as the stop token; if absent, id 2 is assumed.
    - prompt (str): The input text prompt.
    - K (int): Maximum number of tokens to sample. Sampling stops early at EOS.
    - top_k (int): Number of highest logits kept for sampling. Values above the
      vocabulary size are clamped; values <= 0 disable the filter.
    - temperature (float): Strictly positive divisor applied to the logits.

    Returns:
    - str: The decoded continuation only (the prompt is excluded), with every run
      of whitespace collapsed to a single space. Empty if no token was generated.

    Raises:
    - ValueError: If `temperature` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    model.eval()

    # Put the prompt on the model's own device instead of relying on a notebook global.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    # Remember where the prompt ends so only newly generated tokens are decoded,
    # including the case where nothing is generated at all.
    prompt_len = input_ids.shape[-1]

    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    if eos_token_id is None:
        eos_token_id = 2  # previous hard-coded value

    def top_k_logits(logits, k):
        # Keep the k largest logits and mask the rest with -inf, so masked tokens
        # get exactly zero probability regardless of the logit scale.
        k = min(k, logits.size(-1))
        if k <= 0:
            return logits
        kth_value = torch.topk(logits, k)[0][..., -1, None]
        return logits.masked_fill(logits < kth_value, float("-inf"))

    with torch.no_grad():
        for _ in range(K):
            # Full forward pass over the whole sequence; only the last position
            # is needed to predict the next token.
            logits = model(input_ids).logits[0, -1, :] / temperature
            logits = top_k_logits(logits, top_k)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1).unsqueeze(0)
            if next_token_id.item() == eos_token_id:
                break
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)

    generated_text = tokenizer.decode(input_ids[0, prompt_len:], skip_special_tokens=True)

    # Normalise whitespace (this also collapses newlines into single spaces).
    return ' '.join(generated_text.split())


def llama_v2_prompt(
    messages: list[dict],
    system_prompt=None
):
    """
    Render a chat history as a single Llama-2 chat prompt string.

    Args:
    - messages (list[dict]): Chat turns, each a dict with "role" and "content".
      If the first turn is not a "system" turn, one is prepended. The input
      list is not modified.
    - system_prompt (str, optional): System text for the prepended system turn.
      A built-in default is used when this is falsy. Ignored when `messages`
      already starts with a system turn.

    Returns:
    - str: The system text wrapped in <<SYS>> markers and folded into the first
      turn, every (prompt, answer) pair rendered as
      "<s>[INST] prompt [/INST] answer </s>", and a trailing open
      "<s>[INST] prompt [/INST]" when the last turn has role "user".
    """
    inst_open, inst_close = "[INST]", "[/INST]"
    sys_open, sys_close = "<<SYS>>\n", "\n<</SYS>>\n\n"
    bos, eos = "<s>", "</s>"

    fallback_system = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
    system_text = system_prompt if system_prompt else fallback_system

    turns = messages
    if turns[0]["role"] != "system":
        turns = [{"role": "system", "content": system_text}] + turns

    # Fold the system turn into the turn that follows it; that turn keeps its role.
    system_turn, first_turn = turns[0], turns[1]
    merged_first = {
        "role": first_turn["role"],
        "content": sys_open + system_turn["content"] + sys_close + first_turn["content"],
    }
    turns = [merged_first] + turns[2:]

    # Walk the history two turns at a time: even index = prompt, odd index = answer.
    pieces = []
    for idx in range(0, len(turns) - 1, 2):
        ask = turns[idx]["content"].strip()
        reply = turns[idx + 1]["content"].strip()
        pieces.append(f"{bos}{inst_open} {ask} {inst_close} {reply} {eos}")

    # A trailing user turn is left open so the model writes the answer.
    final_turn = turns[-1]
    if final_turn["role"] == "user":
        pieces.append(f"{bos}{inst_open} {final_turn['content'].strip()} {inst_close}")

    return "".join(pieces)

In [6]:
# Peek at the first question. Index the row first and the column second, so only
# one example is read instead of the whole "questionText" column.
dataset["train"][0]["questionText"]

'I have so many issues to address. I have a history of sexual abuse, I’m a breast cancer survivor and I am a lifetime insomniac.    I have a long history of depression and I’m beginning to have anxiety. I have low self esteem but I’ve been happily married for almost 35 years.\n   I’ve never had counseling about any of this. Do I have too many issues to address in counseling?'

### Original Generation

In [92]:
# Single-turn conversation: one user message, hard-coded here as a string literal.
messages = [{"role": "user", "content": "I have so many issues to address. I have a history of sexual abuse, I’m a breast cancer survivor and I am a lifetime insomniac.    I have a long history of depression and I’m beginning to have anxiety. I have low self esteem but I’ve been happily married for almost 35 years.\nI’ve never had counseling about any of this. Do I have too many issues to address in counseling?"}]

# Baseline prompt: the second argument is used as the system prompt (therapist persona, short answers).
prompt = llama_v2_prompt(messages, "Imagine yourself as a talk therapist who will help a patient to manager their issue. Keep your answer concise (2 - 4 sentences).")
print(prompt)

<s>[INST] <<SYS>>
Imagine yourself as a talk therapist who will help a patient to manager their issue. Keep your answer concise (2 - 4 sentences).
<</SYS>>

I have so many issues to address. I have a history of sexual abuse, I’m a breast cancer survivor and I am a lifetime insomniac.    I have a long history of depression and I’m beginning to have anxiety. I have low self esteem but I’ve been happily married for almost 35 years.
I’ve never had counseling about any of this. Do I have too many issues to address in counseling? [/INST]


In [93]:
# Baseline generation from `prompt`: with top_k=1 only the highest logit is kept,
# so this is effectively greedy decoding, capped at 400 tokens.
sample_k_next_words(model, tokenizer, prompt, K=400, top_k=1, temperature=1.0)

"As your talk therapist, I understand that you have a lot on your plate and it can feel overwhelming to address all of these issues at once. However, it's important to remember that you don't have to tackle them all in one session or even one year. We can work together to prioritize the issues that are most pressing for you and develop a plan to address them in a way that feels manageable for you. It's okay to take things one step at a time and to focus on one issue before moving on to the next."

In [94]:
# Same user message as in the baseline cell.
messages = [{"role": "user", "content": "I have so many issues to address. I have a history of sexual abuse, I’m a breast cancer survivor and I am a lifetime insomniac.    I have a long history of depression and I’m beginning to have anxiety. I have low self esteem but I’ve been happily married for almost 35 years.\nI’ve never had counseling about any of this. Do I have too many issues to address in counseling?"}]

# Conditioned prompt: the baseline system prompt plus an added instruction to avoid sycophancy and stay objective.
cond_prompt = llama_v2_prompt(messages, "Imagine yourself as a talk therapist who will help a patient to manager their issue. Keep your answer concise (2 - 4 sentences). Do not be sycophantic with the patient's opinion in your response! Try to be as objective as possible when answering them.")

### ICL (in-context learning): conditioning through the system prompt

In [95]:
# Same decoding settings as the baseline run, but starting from `cond_prompt`.
sample_k_next_words(model, tokenizer, cond_prompt, K=400, top_k=1, temperature=1.0)

"As your talk therapist, I understand that you have a lot of challenges to face, and it's completely normal to feel overwhelmed. However, it's important to remember that you don't have to tackle all of these issues at once. In our therapy sessions, we can work together to prioritize your concerns and address them one at a time, taking into account their interconnectedness and impact on your overall well-being. Let's start by exploring how your past experiences of sexual abuse and breast cancer diagnosis have affected your self-esteem and mental health, and how these issues may be contributing to your current anxiety and depression."

In [75]:
def sample_k_next_words_with_cond(model, tokenizer, prompt, cond_prompt, gamma, K=1, top_k=50, temperature=1.0):
    """
    Sample up to K next tokens using classifier-free-guidance style logit mixing.

    At every step the model is run on both the plain prompt and the conditioned
    prompt, each extended with the tokens generated so far, and the next-token
    logits are combined as

        guided = logits + gamma * (cond_logits - logits)

    before temperature scaling, top-k filtering and sampling. Despite the
    function name, K counts tokens, not whitespace-separated words.

    Args:
    - model: Causal language model. Calling it on a batch of token ids must return
      an object with a `.logits` tensor indexed as (batch, position, vocab), and it
      must expose a `.device` attribute (Hugging Face models do).
    - tokenizer: Tokenizer providing `encode(..., return_tensors="pt")` and `decode`.
      Its `eos_token_id` is used as the stop token; if absent, id 2 is assumed.
    - prompt (str): The unconditioned input prompt.
    - cond_prompt (str): The conditioned input prompt (e.g. with extra instructions).
    - gamma (float): Guidance strength. 0 uses only the `prompt` logits, 1 reduces
      to the `cond_prompt` logits (up to floating-point rounding), and values
      above 1 extrapolate past the conditioned logits, away from the plain ones.
    - K (int): Maximum number of tokens to sample. Sampling stops early at EOS.
    - top_k (int): Number of highest guided logits kept for sampling. Values above
      the vocabulary size are clamped; values <= 0 disable the filter.
    - temperature (float): Strictly positive divisor applied to the guided logits.

    Returns:
    - str: The decoded continuation only (both prompts are excluded), with every
      run of whitespace collapsed to a single space. Empty if no token was generated.

    Raises:
    - ValueError: If `temperature` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    model.eval()

    # Put both prompts on the model's own device instead of relying on a notebook global.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    cond_input_ids = tokenizer.encode(cond_prompt, return_tensors="pt").to(model.device)
    # Remember where the plain prompt ends so only newly generated tokens are
    # decoded, including the case where nothing is generated at all.
    prompt_len = input_ids.shape[-1]

    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    if eos_token_id is None:
        eos_token_id = 2  # previous hard-coded value

    def top_k_logits(logits, k):
        # Keep the k largest logits and mask the rest with -inf, so masked tokens
        # get exactly zero probability regardless of the logit scale.
        k = min(k, logits.size(-1))
        if k <= 0:
            return logits
        kth_value = torch.topk(logits, k)[0][..., -1, None]
        return logits.masked_fill(logits < kth_value, float("-inf"))

    with torch.no_grad():
        for _ in range(K):
            # Only the last position of each sequence predicts the next token.
            logits = model(input_ids).logits[0, -1, :]
            cond_logits = model(cond_input_ids).logits[0, -1, :]
            # Classifier-free guidance: w|c = w + gamma * (w|c - w).
            guided = logits + gamma * (cond_logits - logits)
            guided = guided / temperature
            guided = top_k_logits(guided, top_k)
            probs = torch.nn.functional.softmax(guided, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1).unsqueeze(0)
            if next_token_id.item() == eos_token_id:
                break
            # Both branches are extended with the same sampled token so they stay in sync.
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)
            cond_input_ids = torch.cat([cond_input_ids, next_token_id], dim=-1)

    generated_text = tokenizer.decode(input_ids[0, prompt_len:], skip_special_tokens=True)

    # Normalise whitespace (this also collapses newlines into single spaces).
    return ' '.join(generated_text.split())

### Sanity check: is the $\gamma = 1$ output equivalent to ICL, $P(W|C)$?

In [76]:
# With gamma=1 the guidance formula reduces to the conditioned logits, so this
# should reproduce plain decoding from `cond_prompt`.
sample_k_next_words_with_cond(model, tokenizer, prompt, cond_prompt, gamma=1, K=400, top_k=1, temperature=1.0)

"As your talk therapist, I understand that you have a lot of challenges to face, and it's completely normal to feel overwhelmed. However, it's important to remember that you don't have to tackle all of these issues at once. In our therapy sessions, we can work together to prioritize your concerns and address them one at a time, taking into account their interconnectedness and impact on your overall well-being. Let's start by exploring how your past experiences of sexual abuse and breast cancer diagnosis have affected your self-esteem and mental health, and how these issues may be contributing to your current anxiety and depression."

In [74]:
# Reference for the sanity check: plain decoding from `cond_prompt` with the same settings.
sample_k_next_words(model, tokenizer, cond_prompt, K=400, top_k=1, temperature=1.0)

"As your talk therapist, I understand that you have a lot of challenges to face, and it's completely normal to feel overwhelmed. However, it's important to remember that you don't have to tackle all of these issues at once. In our therapy sessions, we can work together to prioritize your concerns and address them one at a time, taking into account their interconnectedness and impact on your overall well-being. Let's start by exploring how your past experiences of sexual abuse and breast cancer diagnosis have affected your self-esteem and mental health, and how these issues may be contributing to your current anxiety and depression."

### $\gamma = 2$

In [77]:
# gamma=2: extrapolates past the conditioned logits, away from the unconditioned ones.
sample_k_next_words_with_cond(model, tokenizer, prompt, cond_prompt, gamma=2, K=400, top_k=1, temperature=1.0)

"Thank you for sharing your concerns with me. It's understandable that you have a lot on your plate, and seeking support can be beneficial in addressing these issues. While it may seem overwhelming, it's important to remember that therapy is a safe and supportive space for you to explore and work through these challenges. With the right support and strategies, it is possible to make progress and find relief from these issues. Let's work together to identify your priorities and develop a plan that works for you."

### $\gamma = 3$

In [78]:
# gamma=3: stronger extrapolation along the (conditioned - unconditioned) logit direction.
sample_k_next_words_with_cond(model, tokenizer, prompt, cond_prompt, gamma=3, K=400, top_k=1, temperature=1.0)

"Thank you for sharing your concerns with me. It sounds like you have a lot on your plate and it's understandable that you might feel overwhelmed. While it's true that addressing multiple issues can be challenging, it's not uncommon for individuals to address multiple issues in therapy. In fact, many people find that addressing one issue can often lead to progress in other areas of their lives. However, it's important to keep in mind that we will need to prioritize and focus on one or two issues at a time to ensure that we make meaningful progress. What would you like to focus on first?"

### $\gamma=4$

In [79]:
# gamma=4: the strongest guidance setting tried in this notebook.
sample_k_next_words_with_cond(model, tokenizer, prompt, cond_prompt, gamma=4, K=400, top_k=1, temperature=1.0)

"Thank you for sharing your concerns with me. Based on the issues you've mentioned, it sounds like you have a lot to address in therapy. Dealing with a history of sexual abuse, breast cancer, insomnia, depression, and anxiety can be challenging and may be interconnected in complex ways. While happily married for a long time can be a positive aspect of your life, low self-esteem can still impact your well-being and relationships. In therapy, we can work together to identify and address these issues, and develop strategies to improve your overall mental health and well-being."