In [2]:
import torch
import os
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    set_seed,
)
from arguments import ModelArguments
import argparse


In [6]:
# Inference settings for this notebook:
# - model_name_or_path: directory holding the base ChatGLM weights/config.
# - ptuning_checkpoint: directory whose pytorch_model.bin holds the trained
#   prefix-encoder weights (loaded further down).
# - pre_seq_len: copied onto the model config; presumably must match the value
#   used when the checkpoint was trained — TODO confirm.
# - quantization_bit: bit width passed to model.quantize().
model_args = ModelArguments(
    model_name_or_path="pretrain_model",
    ptuning_checkpoint="./ckpt/ptuningv2",
    pre_seq_len=128,
    quantization_bit=8,
)


In [7]:
# The checkpoint ships custom modeling code, so both the config and the
# tokenizer loaders need trust_remote_code=True.
config = AutoConfig.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=True,
)
# Copy the P-Tuning settings onto the config before the model is built from it,
# so the instantiated model carries a prefix encoder (its weights are loaded
# from the P-Tuning checkpoint afterwards).
config.pre_seq_len, config.prefix_projection = (
    model_args.pre_seq_len,
    model_args.prefix_projection,
)

tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=True,
)

Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


In [8]:
# Build the model from the base weights, then overlay the trained
# prefix-encoder weights from the P-Tuning checkpoint. The base checkpoint does
# not contain the prefix encoder (it is reported as newly initialized).
model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)

# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# any machine. The model is only moved to the GPU later (model.cuda()), so the
# tensors end up in the same place as before.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust
# (newer torch versions accept weights_only=True to restrict this).
prefix_state_dict = torch.load(
    os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"),
    map_location="cpu",
)

# Keep only the prefix-encoder entries and strip their module path so the keys
# match the state_dict of the prefix_encoder submodule itself.
PREFIX_ENCODER_KEY = "transformer.prefix_encoder."
new_prefix_state_dict = {
    key[len(PREFIX_ENCODER_KEY):]: tensor
    for key, tensor in prefix_state_dict.items()
    if key.startswith(PREFIX_ENCODER_KEY)
}
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
Loading checkpoint shards: 100%|██████████| 8/8 [00:10<00:00,  1.32s/it]
Some weights of ChatGLMForConditionalGeneration were not initialized from the model checkpoint at pretrain_model and are newly initialized: ['transformer.prefix_encoder.embedding.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [9]:
# Quantize the weights to `quantization_bit` bits. When quantization_bit is
# None (e.g. if the field is left unset), skip this step instead of calling
# model.quantize(None).
if model_args.quantization_bit is not None:
    print(f"Quantized to {model_args.quantization_bit} bit")
    model = model.quantize(model_args.quantization_bit)

Quantized to 8 bit


In [10]:
# Cast the whole model to fp16, then cast the prefix encoder back to fp32
# (presumably to keep the small trained prefix in full precision — confirm),
# and finally move everything to the GPU in eval mode.
model = model.half()
model.transformer.prefix_encoder.float()
model = model.cuda().eval()

In [17]:
# Generation samples (do_sample=True), so seed all RNGs first; re-running this
# cell then gives a repeatable response on the same hardware/software stack.
SEED = 42
set_seed(SEED)

question = "咽痛，咳嗽，多痰"  # "sore throat, cough, lots of phlegm"
response, history = model.chat(
    tokenizer,
    question,
    history=[],
    max_length=2048,
    eos_token_id=config.eos_token_id,
    do_sample=True,
    top_p=0.7,
    temperature=1,
)
print(response)


你好，这种情况应该是咽喉炎症引起的，可以口服阿奇霉素，利咽颗粒试试，保持室内空气流通，平时饮食中加点水果蔬菜，另外可以口服蓝芩口服液调理一下。慢慢会改善的，祝健康
