Commit
Support ChatGLM3, version upgrade (#44)
* Update download_model.py

* Create start_offline_cmd.bat

* Update download_model.py

* Update download_model.py

* Update download_model.py

* Update requirements.txt

* Create test_models.py

* Update app.py

* Update test_models.py

* 1

* Create chatglm3_predictor.py

* Update chatglm3_predictor.py

* Update requirements.txt

* Update app.py

* Update chatglm3_predictor.py

* Update base.py

* Update chatglm3_predictor.py

* Update chatglm3_predictor.py

* Update download_model.py

* Update chatglm3_predictor.py

* Update base.py

* Update base.py

* Update base.py

* Update app.py

* Update chatglm3_predictor.py

* Update app.py
ypwhs committed Dec 18, 2023
1 parent eefba1b commit 02fe5f2
Showing 11 changed files with 2,175 additions and 17 deletions.
22 changes: 14 additions & 8 deletions app.py
@@ -8,9 +8,12 @@
 print('Done'.center(64, '-'))
 
 # Load the model
-model_name = 'THUDM/chatglm2-6b'
+model_name = 'THUDM/chatglm3-6b'
 
-if 'chatglm2' in model_name.lower():
+if 'chatglm3' in model_name.lower():
+    from predictors.chatglm3_predictor import ChatGLM3
+    predictor = ChatGLM3(model_name)
+elif 'chatglm2' in model_name.lower():
     from predictors.chatglm2_predictor import ChatGLM2
     predictor = ChatGLM2(model_name)
 elif 'chatglm' in model_name.lower():
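
Note on the dispatch above: every ChatGLM model id contains the substring 'chatglm', so the branches must run from most to least specific, or the generic branch would shadow the newer models. A standalone sketch of the routing (illustrative only, not part of the commit):

for name in ('THUDM/chatglm3-6b', 'THUDM/chatglm2-6b', 'THUDM/chatglm-6b'):
    n = name.lower()
    if 'chatglm3' in n:
        print(name, '-> ChatGLM3 predictor')
    elif 'chatglm2' in n:
        print(name, '-> ChatGLM2 predictor')
    elif 'chatglm' in n:
        print(name, '-> ChatGLM predictor')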
@@ -31,7 +34,10 @@
 
 
 def revise(history, latest_message):
-    history[-1] = (history[-1][0], latest_message)
+    if isinstance(history[-1], tuple):
+        history[-1] = (history[-1][0], latest_message)
+    elif isinstance(history[-1], dict):
+        history[-1]['content'] = latest_message
     return history, ''
 
 
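The revise() change exists because ChatGLM3 keeps chat history as role/content dicts rather than the (user, bot) tuples used by the older predictors; the dict keys below are inferred from the isinstance branch, so treat them as an assumption. A standalone sketch of both cases:

def revise(history, latest_message):
    if isinstance(history[-1], tuple):
        history[-1] = (history[-1][0], latest_message)
    elif isinstance(history[-1], dict):
        history[-1]['content'] = latest_message
    return history, ''

# Tuple-style history (ChatGLM / ChatGLM2)
print(revise([('Hi', 'Hello!')], 'Hi there!')[0])
# Dict-style history (ChatGLM3; 'content' key assumed)
print(revise([{'role': 'assistant', 'content': 'Hello!'}], 'Hi there!')[0])
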
@@ -76,21 +82,21 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate):
     """)
     with gr.Row():
         with gr.Column(scale=4):
-            chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=850)
+            chatbot = gr.Chatbot(elem_id="chat-box", show_label=False, height=850)
         with gr.Column(scale=1):
             with gr.Row():
                 max_length = gr.Slider(32, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
                 top_p = gr.Slider(0.01, 1, value=0.7, step=0.01, label="Top P", interactive=True)
                 temperature = gr.Slider(0.01, 5, value=0.95, step=0.01, label="Temperature", interactive=True)
             with gr.Row():
-                query = gr.Textbox(show_label=False, placeholder="Prompts", lines=4).style(container=False)
+                query = gr.Textbox(show_label=False, placeholder="Prompts", lines=4)
                 generate_button = gr.Button("生成")
             with gr.Row():
                 continue_message = gr.Textbox(
-                    show_label=False, placeholder="Continue message", lines=2).style(container=False)
+                    show_label=False, placeholder="Continue message", lines=2)
                 continue_btn = gr.Button("续写")
                 revise_message = gr.Textbox(
-                    show_label=False, placeholder="Revise message", lines=2).style(container=False)
+                    show_label=False, placeholder="Revise message", lines=2)
                 revise_btn = gr.Button("修订")
                 revoke_btn = gr.Button("撤回")
                 regenerate_btn = gr.Button("重新生成")
@@ -114,5 +120,5 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate):
                           outputs=[chatbot, query, continue_message])
     interrupt_btn.click(interrupt, inputs=[allow_generate])
 
-demo.queue(concurrency_count=4).launch(server_name='0.0.0.0', server_port=7860, share=False, inbrowser=False)
+demo.queue().launch(server_name='0.0.0.0', server_port=7860, share=False, inbrowser=False)
 demo.close()
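
The .style() removals and the bare demo.queue() call track Gradio's 4.x API, where styling keyword arguments moved into component constructors and queue(concurrency_count=...) was dropped. A minimal sketch of the updated pattern, assuming Gradio >= 4.0:

import gradio as gr

with gr.Blocks() as demo:
    # Gradio 4.x: height is a constructor kwarg; .style() no longer exists
    chatbot = gr.Chatbot(elem_id="chat-box", show_label=False, height=850)
    query = gr.Textbox(show_label=False, placeholder="Prompts", lines=4)

# Gradio 4.x: concurrency is set per event or via default_concurrency_limit,
# not through queue(concurrency_count=...)
demo.queue().launch(server_name='0.0.0.0', server_port=7860)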
61 changes: 61 additions & 0 deletions chatglm3/configuration_chatglm.py
@@ -0,0 +1,61 @@
+from transformers import PretrainedConfig
+
+
+class ChatGLMConfig(PretrainedConfig):
+    model_type = "chatglm"
+    def __init__(
+        self,
+        num_layers=28,
+        padded_vocab_size=65024,
+        hidden_size=4096,
+        ffn_hidden_size=13696,
+        kv_channels=128,
+        num_attention_heads=32,
+        seq_length=2048,
+        hidden_dropout=0.0,
+        classifier_dropout=None,
+        attention_dropout=0.0,
+        layernorm_epsilon=1e-5,
+        rmsnorm=True,
+        apply_residual_connection_post_layernorm=False,
+        post_layer_norm=True,
+        add_bias_linear=False,
+        add_qkv_bias=False,
+        bias_dropout_fusion=True,
+        multi_query_attention=False,
+        multi_query_group_num=1,
+        apply_query_key_layer_scaling=True,
+        attention_softmax_in_fp32=True,
+        fp32_residual_connection=False,
+        quantization_bit=0,
+        pre_seq_len=None,
+        prefix_projection=False,
+        **kwargs
+    ):
+        self.num_layers = num_layers
+        self.vocab_size = padded_vocab_size
+        self.padded_vocab_size = padded_vocab_size
+        self.hidden_size = hidden_size
+        self.ffn_hidden_size = ffn_hidden_size
+        self.kv_channels = kv_channels
+        self.num_attention_heads = num_attention_heads
+        self.seq_length = seq_length
+        self.hidden_dropout = hidden_dropout
+        self.classifier_dropout = classifier_dropout
+        self.attention_dropout = attention_dropout
+        self.layernorm_epsilon = layernorm_epsilon
+        self.rmsnorm = rmsnorm
+        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
+        self.post_layer_norm = post_layer_norm
+        self.add_bias_linear = add_bias_linear
+        self.add_qkv_bias = add_qkv_bias
+        self.bias_dropout_fusion = bias_dropout_fusion
+        self.multi_query_attention = multi_query_attention
+        self.multi_query_group_num = multi_query_group_num
+        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+        self.fp32_residual_connection = fp32_residual_connection
+        self.quantization_bit = quantization_bit
+        self.pre_seq_len = pre_seq_len
+        self.prefix_projection = prefix_projection
+        super().__init__(**kwargs)
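
A quick standalone check that the new config round-trips its arguments; the multi-query overrides are illustrative values, not necessarily the shipped chatglm3-6b defaults:

from chatglm3.configuration_chatglm import ChatGLMConfig

config = ChatGLMConfig(multi_query_attention=True, multi_query_group_num=2)
print(config.model_type)             # 'chatglm'
print(config.vocab_size)             # 65024, mirrored from padded_vocab_size
print(config.multi_query_group_num)  # 2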