<a href="https://colab.research.google.com/github/zyren123/LLM/blob/main/baichuan_13b.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

安装依赖：

In [None]:
# Install the runtime dependencies for this notebook (one pip call per package).
# NOTE(review): torch and transformers are not installed here; presumably the
# Colab image already provides them — confirm the installed transformers
# supports BitsAndBytesConfig 4-bit loading.
# accelerate: needed by transformers for device_map="auto".
!pip install accelerate
# bitsandbytes: backend for the 4-bit NF4 quantization (BitsAndBytesConfig).
!pip install bitsandbytes
# colorama / cpm_kernels: not imported by the code in this notebook; presumably
# leftovers or needed by the checkpoint's remote code — verify before removing.
!pip install colorama
!pip install cpm_kernels
# sentencepiece: presumably backs the slow tokenizer (use_fast=False) — confirm.
!pip install sentencepiece
# streamlit: not imported by the code in this notebook — verify whether needed.
!pip install streamlit
# transformers_stream_generator: presumably imported by the checkpoint's remote code.
!pip install transformers_stream_generator
# gradio: web chat UI. Unpinned, so this installs the latest release; the UI
# code must use that release's API.
!pip install gradio
# mdtex2html: Markdown/LaTeX -> HTML rendering of chat messages.
!pip install mdtex2html

Collecting accelerate
  Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m265.7/265.7 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate
Successfully installed accelerate-0.25.0
Collecting bitsandbytes
  Downloading bitsandbytes-0.41.3.post2-py3-none-any.whl (92.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.6/92.6 MB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[?25h

聊天：

In [None]:
import gradio as gr
import mdtex2html
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.generation.utils import GenerationConfig

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hugging Face Hub repository of the checkpoint to load.
MODEL_PATH = 'sharpbai/Baichuan-13B-Chat'

# Sampling overrides applied on top of the checkpoint's generation config.
# NOTE: MAX_LENGTH is assigned to `max_new_tokens`, so it caps generated tokens
# only, not prompt + completion.
MAX_LENGTH = 2048
TOP_P = 0.85
TEMPERATURE = 0.05
# When True, predict() consumes model.chat(..., stream=True) and refreshes the
# UI after every chunk.
STREAM = True


# 4-bit NF4 weights with nested ("double") quantization of the quantization
# constants; compute is done in bfloat16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Slow (non-fast) tokenizer. trust_remote_code lets the checkpoint's own Python
# code run; it presumably supplies the `chat` method that predict() calls.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    use_fast=False,
    trust_remote_code=True,
)

# device_map="auto" lets accelerate place the quantized layers on the
# available devices.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    quantization_config=nf4_config,
    trust_remote_code=True,
)

# Start from the checkpoint's own generation config, then apply the overrides.
model.generation_config = GenerationConfig.from_pretrained(MODEL_PATH)
model.generation_config.max_new_tokens = MAX_LENGTH
model.generation_config.top_p = TOP_P
model.generation_config.temperature = TEMPERATURE


def postprocess(self, y):
    """Render chat history to HTML with mdtex2html (installed as gr.Chatbot.postprocess).

    Args:
        self: The gr.Chatbot instance (unused).
        y: Sequence of (user_message, bot_response) pairs, or None. Either
            element of a pair may be None.

    Returns:
        A new list of (html_or_None, html_or_None) tuples, or ``[]`` when ``y``
        is None. ``y`` itself is left unmodified.
    """
    if y is None:
        return []
    # Build a fresh list rather than overwriting y[i] in place. predict() yields
    # the same list object on every streamed chunk, so in-place conversion left
    # HTML in the caller's list that could be converted again on the next chunk.
    return [
        (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
        for message, response in y
    ]


# Route the output of every gr.Chatbot through the mdtex2html-based
# postprocess() defined above (patches the class, so all Chatbots are affected).
setattr(gr.Chatbot, "postprocess", postprocess)


# Characters replaced inside fenced code blocks so that Markdown/HTML shows them
# literally. str.translate is a single pass that never re-scans its own output,
# so "&" can be escaped without corrupting the entities produced for the other
# characters.
_CODE_ESCAPES = str.maketrans({
    "&": "&amp;",
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Convert chat text into the HTML-flavoured markup shown in the chat window.

    Empty lines are dropped and the remaining lines are joined with ``<br>``.
    Each line containing a triple-backtick fence toggles a code block:

    * an opening fence becomes ``<pre><code class="language-X">``, where X is
      whatever follows the last backtick on that line;
    * a closing fence becomes ``<br></code></pre>``.

    Lines inside a code block have Markdown/HTML-significant characters
    (including ``&``) replaced by entities so the code is displayed literally.

    Args:
        text: Raw message text, possibly containing fenced code blocks.

    Returns:
        The converted string, with the original newlines removed.
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                # Opening fence: the text after the last backtick is the language tag.
                language = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if fence_count % 2 == 1:
                line = line.translate(_CODE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)


def predict(input, chatbot, history):
    """Gradio generator handler: send one user message to the model and show the reply.

    Args:
        input: Raw text typed by the user.
        chatbot: List of (user_html, bot_html) pairs displayed by gr.Chatbot.
            It is mutated in place: the new turn is appended and then updated.
        history: List of {"role": ..., "content": ...} messages kept in
            gr.State and passed to ``model.chat``.

    Yields:
        ``(chatbot, history)``. In streaming mode this is yielded after every
        chunk, and once more at the end in both modes.
    """
    if not input or not input.strip():
        # Nothing to send: leave the conversation untouched.
        yield chatbot, history
        return

    shown_input = parse_text(input)
    chatbot.append((shown_input, ""))
    # Keep only the last 6 messages (3 user/assistant exchanges) as context.
    # Slicing copies, so the caller's list is not mutated.
    history = history[-6:]
    # The model receives the raw text; parse_text() output is HTML for display only.
    history.append({"role": "user", "content": input})
    response = ""  # stays defined even if the stream yields nothing
    if STREAM:
        # Each chunk is presumably the cumulative reply so far; the last one is
        # stored as the full response.
        for response in model.chat(tokenizer, history, stream=True):
            chatbot[-1] = (shown_input, parse_text(response))
            yield chatbot, history
    else:
        response = model.chat(tokenizer, history)
        chatbot[-1] = (shown_input, parse_text(response))
    # Record the assistant turn in both modes so the next request sees it.
    history.append({"role": "assistant", "content": response})
    yield chatbot, history


def reset_user_input():
    """Return a Gradio update that clears the message textbox."""
    empty_text = ""
    return gr.update(value=empty_text)


def reset_state():
    """Start a fresh conversation.

    Returns:
        A pair of new empty lists: the chatbot display and the model history.
    """
    cleared_chatbot = []
    cleared_history = []
    return cleared_chatbot, cleared_history


# Build the chat UI: message box + Submit / "重置会话" (reset conversation) buttons.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Baichuan 13B Chat</h1>""")
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            # `container` is a constructor argument: Component.style() was
            # deprecated in Gradio 3.x and removed in Gradio 4 (AttributeError).
            # Placeholder text means "type your message here".
            user_input = gr.Textbox(
                show_label=False,
                placeholder="在此输入消息",
                lines=4,
                container=False,
            )
        with gr.Column(scale=1):
            submitBtn = gr.Button("Submit", variant="primary")
            emptyBtn = gr.Button("重置会话")
    history = gr.State([])
    # Both Submit handlers receive the textbox value captured at click time, so
    # clearing the box does not affect what predict() sees.
    submitBtn.click(predict, [user_input, chatbot, history], [chatbot, history], show_progress="full")
    submitBtn.click(reset_user_input, [], [user_input])
    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress="full")

# NOTE: share=True plus server_name="0.0.0.0" exposes the app publicly with no
# authentication.
demo.queue().launch(share=True, inbrowser=True, server_name="0.0.0.0", server_port=9876)