## 加载模型

In [1]:
import json

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from fastchat.model import load_model, get_conversation_template, add_model_args

In [2]:
# Load the checkpoint and its tokenizer with FastChat, spread over `num_of_gpu` GPUs.
model_path = "/workspace/huggingface_hub/vicuna_demo/"
num_of_gpu = 3
memory_of_gpu = "80GiB"  # per-GPU memory budget, passed as max_gpu_memory
cpu_offloading = False
load_8bit = False
# Pass every option by keyword: newer FastChat releases add parameters (e.g.
# `dtype`) between `max_gpu_memory` and `load_8bit`, so positional arguments
# would silently land on the wrong parameters.
model, tokenizer = load_model(
    model_path,
    device="cuda",
    num_gpus=num_of_gpu,
    max_gpu_memory=memory_of_gpu,
    load_8bit=load_8bit,
    cpu_offloading=cpu_offloading,
)

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

## 加载对话模板

In [39]:
# Build the conversation template FastChat selects for this model path; the
# printed prompts below show its "USER: ... ASSISTANT:" layout.
conv = get_conversation_template(model_path)

## 开始对话

In [40]:
# First user turn. Appending a `None` assistant message leaves the rendered
# prompt ending in "ASSISTANT:" (see the output), i.e. an open slot for the model.
msg = "It is a good day today!"
conv.append_message(conv.roles[0], msg)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
print(prompt)

A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: It is a good day today! ASSISTANT:


In [41]:
# Tokenize the rendered prompt as a batch containing a single sequence.
input_ids = tokenizer([prompt])["input_ids"]

# Sampling settings for this turn.
temperature, repetition_penalty, max_new_tokens = 0.7, 1.0, 512

output_ids = model.generate(
    torch.as_tensor(input_ids).cuda(),
    max_new_tokens=max_new_tokens,
    do_sample=True,
    temperature=temperature,
    repetition_penalty=repetition_penalty,
)

# The returned sequence begins with the prompt tokens; slice them off so only
# the model's reply is decoded.
output_ids = output_ids[0][len(input_ids[0]):]
outputs = tokenizer.decode(
    output_ids,
    spaces_between_special_tokens=False,
    skip_special_tokens=True,
)

print(f"{conv.roles[0]}: {msg}\n{conv.roles[1]}: {outputs}")

USER: It is a good day today!
ASSISTANT: That's great to hear! Is there anything specific you'd like help with?


In [42]:
# The reply generated above has not been written back into `conv` yet, so the
# prompt still ends with an empty "ASSISTANT:" slot.
prompt = conv.get_prompt()
print(prompt)

A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: It is a good day today! ASSISTANT:


## 支持多轮对话

In [43]:
# Fill the open assistant slot with the generated reply. The rendered prompt now
# holds the complete first exchange, closed by "</s>", ready for the next turn.
conv.update_last_message(outputs)
prompt = conv.get_prompt()
print(prompt)

A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: It is a good day today! ASSISTANT: That's great to hear! Is there anything specific you'd like help with?</s>


In [44]:
def talk(conv, message, temperature = 0.7, repetition_penalty = 1.0, max_new_tokens = 512):
    """Run one chat turn: send `message`, print the exchange and store the reply in `conv`.

    Uses the notebook-level `model` and `tokenizer` loaded in the first cells.

    Args:
        conv: FastChat conversation object, mutated in place. A user message and
            an assistant message are appended, and the assistant one is filled
            with the generated reply.
        message: user text for this turn.
        temperature: sampling temperature forwarded to `model.generate`.
        repetition_penalty: forwarded to `model.generate`.
        max_new_tokens: cap on generated tokens, forwarded to `model.generate`.

    Raises:
        Whatever tokenizing, generating or decoding raises (for example a CUDA
        out-of-memory error, or KeyboardInterrupt when the kernel is stopped).
        Before re-raising, the two messages appended for this turn are removed,
        so `conv` is left exactly as it was and can still be used.
    """
    n_messages_before = len(conv.messages)
    conv.append_message(conv.roles[0], message)
    conv.append_message(conv.roles[1], None)
    try:
        prompt = conv.get_prompt()
        input_ids = tokenizer([prompt]).input_ids

        output_ids = model.generate(
            torch.as_tensor(input_ids).cuda(),
            do_sample=True,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            max_new_tokens=max_new_tokens,
        )

        # The generated sequence starts with the prompt tokens; keep only the reply.
        output_ids = output_ids[0][len(input_ids[0]) :]
        outputs = tokenizer.decode(
            output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
        )
    except BaseException:
        # Roll back this turn. Otherwise the dangling user message and the empty
        # assistant slot would be rendered into every later prompt.
        del conv.messages[n_messages_before:]
        raise

    print(f"{conv.roles[0]}: {message}")
    print(f"{conv.roles[1]}: {outputs}")
    conv.update_last_message(outputs)

In [45]:
# Second turn via `talk`: it appends the message, generates and prints the
# reply, and records the reply in `conv`.
talk(conv, "Would you like to have a walk in the park?")

USER: Would you like to have a walk in the park?
ASSISTANT: Absolutely! I enjoy spending time outdoors when the weather is pleasant. Where would you like to go?


In [46]:
# Show the accumulated multi-turn prompt; each finished turn ends with "</s>".
print(conv.get_prompt())

A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: It is a good day today! ASSISTANT: That's great to hear! Is there anything specific you'd like help with?</s>USER: Would you like to have a walk in the park? ASSISTANT: Absolutely! I enjoy spending time outdoors when the weather is pleasant. Where would you like to go?</s>


In [47]:
# Non-English input: ask in Chinese "Can you speak Chinese?".
talk(conv, "你会说中文吗？")

USER: 你会说中文吗？
ASSISTANT: Yes, I know some Chinese. Would you like me to speak Chinese with you?


In [48]:
# Inspect the prompt again after the Chinese turn was added.
print(conv.get_prompt())

A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: It is a good day today! ASSISTANT: That's great to hear! Is there anything specific you'd like help with?</s>USER: Would you like to have a walk in the park? ASSISTANT: Absolutely! I enjoy spending time outdoors when the weather is pleasant. Where would you like to go?</s>USER: 你会说中文吗？ ASSISTANT: Yes, I know some Chinese. Would you like me to speak Chinese with you?</s>


In [49]:
# Structured-output probe: ask for the person name as JSON. The message starts
# and ends with a newline, with a blank line before the example sentence.
message = (
    "\n"
    "Please extract person name from the following sentence and output it in json format:\n"
    "\n"
    "John Doe lives in New York and works at Google.\n"
)
talk(conv, message)

USER: 
Please extract person name from the following sentence and output it in json format:

John Doe lives in New York and works at Google.

ASSISTANT: Sure! Here's the extracted person name in JSON format: 
{
 "name": "John Doe"



In [50]:
# Same extraction task without asking for JSON; the sentence is inlined after "\n\n".
talk(conv, "Please extract person name from the following sentence: \n\n John Doe lives in New York and works at Google.")

USER: Please extract person name from the following sentence: 

 John Doe lives in New York and works at Google.
ASSISTANT: The extracted person name is "John Doe".


## 通过Gradio进行封装

In [52]:
import gradio as gr


class ConversationWrapper:
    """One independent chat session: a FastChat conversation plus sampling settings.

    Replies are generated with the notebook-level `model` and `tokenizer`.
    """

    def __init__(self, temperature=0.7, repetition_penalty=1.0, max_new_tokens=512):
        # Sampling settings live on the instance rather than being read from
        # notebook globals, so this cell does not depend on earlier cells' state.
        self.conv = get_conversation_template(model_path)
        self.temperature = temperature
        self.repetition_penalty = repetition_penalty
        self.max_new_tokens = max_new_tokens

    def talk(self, message):
        """Add `message` as a user turn, generate the reply, record it and return it.

        If tokenizing, generating or decoding raises, the two messages appended
        for this turn are removed before re-raising, so the session stays usable.
        """
        conv = self.conv  # always this session's conversation, never a notebook global
        n_messages_before = len(conv.messages)
        conv.append_message(conv.roles[0], message)
        conv.append_message(conv.roles[1], None)
        try:
            input_ids = tokenizer([conv.get_prompt()]).input_ids

            output_ids = model.generate(
                torch.as_tensor(input_ids).cuda(),
                do_sample=True,
                temperature=self.temperature,
                repetition_penalty=self.repetition_penalty,
                max_new_tokens=self.max_new_tokens,
            )

            # Keep only the newly generated tokens (drop the echoed prompt).
            output_ids = output_ids[0][len(input_ids[0]) :]
            outputs = tokenizer.decode(
                output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
            )
        except BaseException:
            del conv.messages[n_messages_before:]
            raise
        conv.update_last_message(outputs)
        return outputs


def predict(message, session=None):
    """Gradio submit callback: answer `message` within the caller's own session.

    Args:
        message: text typed into the textbox.
        session: per-browser-session state. It is ``None`` on the first message;
            afterwards it is the dict returned by the previous call.

    Returns:
        A tuple of (list of (user, bot) pairs for gr.Chatbot, updated session state).
    """
    if session is None:
        # Give each browser session its own conversation. A single shared wrapper
        # would mix turns from different users and would survive page reloads
        # while the displayed history was reset.
        session = {"chat": ConversationWrapper(), "history": []}
    response = session["chat"].talk(message)
    session["history"].append((message, response))
    return list(session["history"]), session


with gr.Blocks(css="#chatbot{height:350px} .overflow-y-auto{height:500px}") as demo:
    chatbot = gr.Chatbot(elem_id="chatbot")
    state = gr.State(None)  # filled lazily by `predict` with this session's data

    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(container=False)

    txt.submit(predict, [txt, state], [chatbot, state])

demo.launch()



Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


