In [1]:
import json
import os
import warnings

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from peft import PeftConfig, PeftModel

from utils import save_conversation, load_conversation

from fastchat.utils import get_gpu_memory, is_partial_stop, is_sentence_complete, get_context_length
from fastchat.conversation import get_conv_template, register_conv_template, Conversation, SeparatorStyle
from fastchat.serve.inference import generate_stream

warnings.filterwarnings('ignore')

[2023-10-30 04:38:54,055] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
try:
    register_conv_template(
        Conversation(
            name="ZenAI",
            system_message="Your name is ZenAI and you're a therapist. Please have a conversation with your patient and provide them with a helpful response to their concerns.",
            roles=("USER", "ASSISTANT"),
            sep_style=SeparatorStyle.ADD_COLON_TWO,
            sep=" ",
            sep2="</s>",
        )
    )
except AssertionError:
    pass

In [3]:
def load_model(model_path, num_gpus, max_gpu_memory=None):
    
    kwargs = {"torch_dtype": torch.float16}
    if num_gpus != 1:
        kwargs["device_map"] = "auto"
        if max_gpu_memory is None:
            kwargs[
                "device_map"
            ] = "sequential"  # This is important for not the same VRAM sizes
            available_gpu_memory = get_gpu_memory(num_gpus)
            kwargs["max_memory"] = {
                i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                for i in range(num_gpus)
            }
        else:
            kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
        
        config = PeftConfig.from_pretrained(model_path)
        base_model_path = config.base_model_name_or_path
        tokenizer = AutoTokenizer.from_pretrained(
            base_model_path, use_fast=False
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_path,
            low_cpu_mem_usage=True,
            **kwargs,
        )
        model = PeftModel.from_pretrained(base_model, model_path)
        
        return model, tokenizer, base_model

In [4]:
use_vicuna = False
num_gpus = 4
max_gpu_memory = "12GiB"
model_path = "/home/jupyter/therapy-bot/models/zenai_sample/"
if use_vicuna:
    _, tokenizer, model = load_model(model_path, num_gpus, max_gpu_memory)
else:
    model, tokenizer, _ = load_model(model_path, num_gpus, max_gpu_memory)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



In [32]:
def chat_loop(
    model, tokenizer, username,
    device, num_gpus, max_gpu_memory,
    conv_template, system_msg,
    temperature, repetition_penalty, max_new_tokens,
    chatio,
    dtype=torch.float16,
    judge_sent_end=True
):
    conv_path = f"saved_conversations/{username}.json"
    context_len = get_context_length(model.config)

    def new_chat():
        conv = get_conv_template(conv_template)
        if system_msg is not None:
            conv.set_system_message(system_msg)
        return conv

    def reload_conv(conv):
        for message in conv.messages[conv.offset:]:
            chatio.prompt_for_output(message[0])
            chatio.print_output(message[1])

    if os.path.exists(conv_path):
        conv = load_conversation(conv_path)
        reload_conv(conv)
    else:
        conv = new_chat()

    while True:

        inp = chatio.prompt_for_input(conv.roles[0])
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        gen_params = {
            "prompt": prompt,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens,
            "stop": conv.stop_str,
            "stop_token_ids": conv.stop_token_ids,
            "echo": False,
        }

        chatio.prompt_for_output(conv.roles[1])
        output_stream = generate_stream(
            model,
            tokenizer,
            gen_params,
            device,
            context_len=context_len,
            judge_sent_end=judge_sent_end,
        )
        outputs = chatio.stream_output(output_stream)
        conv.update_last_message(outputs.strip())
        save_conversation(conv, conv_path)

In [33]:
class SimpleChatIO:

    def prompt_for_input(self, role) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="")

    def stream_output(self, output_stream):
        pre = 0
        for outputs in output_stream:
            output_text = outputs["text"]
            output_text = output_text.strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ")
                pre = now
        print(" ".join(output_text[pre:]))
        return " ".join(output_text)

    def print_output(self, text: str):
        print(text)

In [None]:
chatio = SimpleChatIO()

chat_loop(
    model, tokenizer, "manish",
    device="cuda",
    num_gpus=4,
    max_gpu_memory="12GiB",
    conv_template="ZenAI",
    system_msg="Your name is ZenAI and you're a therapist. Please have a conversation with your patient and provide them with a helpful response to their concerns.",
    temperature=0.7,
    repetition_penalty=1.0,
    max_new_tokens=512,
    chatio=chatio
)

USER: Hello!
ASSISTANT: Hello! How can I help you today?
USER: I'm Manish. Who are you?
ASSISTANT: My name is ZenAI, and I'm a therapist developed by researchers from the ZenAI team..
USER: Hello!
ASSISTANT: Hello! How can I help you today?
USER: Who am I?
ASSISTANT: You are Manish, a new patient who recently joined the ZenAI team.
USER: Hola
ASSISTANT: Hola! ¿En qué puedo ayudarte hoy?
USER: I don't know Spanish
ASSISTANT: That's okay, I can communicate with you in English. How can I assist you today?
USER: helo
ASSISTANT: Hello! Is there anything I can help you with?


USER:  Hello


ASSISTANT: Hello! Is there anything on your mind that you would like to discuss?
