In [None]:
import os, time, re, random, glob, json, jieba, copy
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForMultipleChoice,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
    TextGenerationPipeline
)

In [None]:
# Compute device plus the platform-specific data root that every path below is built from.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
from sys import platform
if platform.startswith("linux"):
    # Linux ("linux", or "linux2" on very old interpreters)
    root = "/mnt/private-pa002-vol726121-prd/Data"
    # root = "/root/autodl-tmp/Data"
elif platform == "darwin":
    # macOS
    root = "/Users/zeyesun/Documents/Data"
elif platform == "win32":
    # Windows
    root = "D:\\Data"
else:
    # Fail fast: otherwise `root` stays undefined and every later cell dies with a NameError.
    raise RuntimeError(f"Unsupported platform {platform!r}; set `root` to the data directory manually.")

In [None]:
CLEAN_TEXT_PATTERN = re.compile(r"[\r\n]")


def clean_text(text):
    """Return ``text`` with every carriage return and line feed removed."""
    pieces = CLEAN_TEXT_PATTERN.split(text)
    return "".join(pieces)

In [None]:
# Base model directory under <root>/models; change MODEL_DIR_NAME to switch models.
# Known alternatives: pangu-350M | pangu-2.6B | pangu-13B | glm-335M-chinese | glm-10B-chinese | chatglm-6B
MODEL_DIR_NAME = "glm-335M-chinese"
model_name_or_path = os.path.join(root, "models", MODEL_DIR_NAME)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_cache=False,
    trust_remote_code=True,
)

# To inspect the raw pangu SentencePiece vocabulary instead:
#   import sentencepiece
#   sp = sentencepiece.SentencePieceProcessor()
#   sp.Load(model_file=os.path.join(root, "models", "pangu-350M", "vocab.model"))
#   for i in range(10):
#       print(sp.id_to_piece(i))

# SFT Prediction

In [None]:
# Answer marker that follows the prompt; generated text is later split on it.
# (An earlier variant used "模型回答：".)
prefix = "答:"

In [None]:
# GLM checkpoints load as seq2seq models; everything else (pangu) as a causal LM.
is_glm_model = "glm" in model_name_or_path
if not is_glm_model:
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, use_cache=False)
    # Match the embedding matrix to the tokenizer's base vocabulary size.
    model.resize_token_embeddings(tokenizer.vocab_size)
else:
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True)
model.to(device)
print(model.device)

In [None]:
# Load fine-tuned (SFT) weights, possibly sharded across several pytorch_model*.bin files.
# The SFT output directory is named after the base model, so derive it from model_name_or_path
# instead of hard-coding "pangu-350M" (which cannot load into the GLM model selected above).
# Step numbers differ per run, e.g. pangu-350M -> external_checkpoint-12000, pangu-2.6B -> external_checkpoint-9000.
sft_checkpoint_dir = "external_checkpoint-12000"
checkpoint_files = os.path.join(root, "chatgpt", "output", "sft", os.path.basename(model_name_or_path),
                                sft_checkpoint_dir, "pytorch_model*.bin")
checkpoints = sorted(glob.glob(checkpoint_files))
if not checkpoints:
    # An empty match used to produce load_state_dict({}) and a wall of "missing keys".
    raise FileNotFoundError(f"No checkpoint files match {checkpoint_files}")
st = dict()
for checkpoint in checkpoints:
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints
    # (consider weights_only=True on torch versions that support it).
    st.update(torch.load(checkpoint, map_location="cpu"))
model.load_state_dict(st)

In [None]:
# text_generator = TextGenerationPipeline(model, tokenizer, device=device)

In [None]:
# Generate an answer for a single smoke-test prompt with the loaded model.
max_length = 512
max_length_generation = 50
num_return_sequences = 1
do_sample = False  # greedy decoding; top_p / temperature only take effect when sampling
top_p = 0.8
temperature = 0.8
prompt = "写一篇歌颂祖国的文章"
# prompt_processed = prompt + tokenizer.sep_token + prefix
# prompt = """阅读文章：《战国无双3》（）是由光荣和ω-force开发的战国无双系列的正统第三续作。本作以三大故事为主轴，分别是以武田信玄等人为主的《关东三国志》，织田信长等人为主的《战国三杰》，石田三成等人为主的《关原的年轻武者》，丰富游戏内的剧情。此部份专门介绍角色，欲知武器情报、奥义字或擅长攻击类型等，请至战国无双系列1.由于乡里大辅先生因故去世，不得不寻找其他声优接手。从猛将传 and Z开始。2.战国无双 编年史的原创男女主角亦有专属声优。此模式是任天堂游戏谜之村雨城改编的新增模式。本作中共有20张战场地图（不含村雨城），后来发行的猛将传再新增3张战场地图。但游戏内战役数量繁多，部分地图会有兼用的状况，战役虚实则是以光荣发行的2本「战国无双3 人物真书」内容为主，以下是相关介绍。（注：前方加☆者为猛将传新增关卡及地图。）合并本篇和猛将传的内容，村雨城模式剔除，战国史模式可直接游玩。主打两大模式「战史演武」&「争霸演武」。系列作品外传作品\n问：《战国无双3》是由哪两个公司合作开发的？\n答："""

generation_kwargs = dict(
    max_new_tokens=max_length_generation,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=do_sample,
    num_return_sequences=num_return_sequences,
)
if do_sample:
    # generate() ignores these under greedy decoding (and warns), so pass them only when sampling.
    generation_kwargs.update(top_p=top_p, temperature=temperature)

if "glm" in model_name_or_path:
    # GLM generates at a [gMASK] token; append it without mutating `prompt`, so re-running
    # this cell never stacks several masks.
    inputs = tokenizer(prompt + "[gMASK]", return_tensors="pt")
    inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=max_length + max_length_generation)
    generation_kwargs["eos_token_id"] = tokenizer.eop_token_id
else:
    inputs = tokenizer(prompt, add_special_tokens=False, return_token_type_ids=False, return_tensors="pt")
inputs = inputs.to(device)
outputs = model.generate(**inputs, **generation_kwargs)
results = tokenizer.batch_decode(outputs, skip_special_tokens=True)
# Keep the text after the answer marker; if the marker is absent (the bare prompt above has
# none), keep the whole decoded text instead of raising IndexError.
results = [result.split(prefix, maxsplit=1)[-1] for result in results]

In [None]:
# Load evaluation prompts and reference answers from a processed jsonl file.
# f = os.path.join(root, "raw", "baike_qa_train.json")
# Available: [baike_qa.jsonl, chinese_classical.jsonl, chinese_poetry.jsonl, couplets.jsonl, weibo_summary_comments.jsonl, zhidao.jsonl]
f = os.path.join(root, "chatgpt", "processed", "baike_qa.jsonl")
prompts = []
prompts_processed = []
labels = []
with open(f, "r", encoding="utf-8") as r:
    for line in r:
        line = line.strip("\n")
        if not line:
            continue  # tolerate blank lines (e.g. a trailing newline) instead of failing in json.loads
        item = json.loads(line)
        prompt = item['prompt']
        prompts.append(prompt)
        # Prompts in the processed files are used as-is (no extra prefix appended).
        prompts_processed.append(prompt)
        labels.append(item['answers'][0]['answer'])
# Shuffle all three parallel lists with ONE shared permutation so prompts_processed[k],
# prompts[k] and labels[k] still describe the same example (shuffling prompts_processed
# alone silently misaligned it from its labels).
order = list(range(len(prompts)))
random.shuffle(order)
prompts = [prompts[k] for k in order]
prompts_processed = [prompts_processed[k] for k in order]
labels = [labels[k] for k in order]
print(len(prompts_processed))

In [None]:
# Sample several continuations for one shuffled prompt (index `i`) using top-k sampling.
i = 79
num_return_sequences = 2
max_length = 512
max_length_generation = 100
top_k = 50
top_p = 0.8  # not passed below: top-k sampling is used instead of top-p
temperature = 1.0
t1 = time.time()
prompt = prompts_processed[i]
inputs = tokenizer(prompt, add_special_tokens=False, return_token_type_ids=False, return_tensors="pt").to(device)
sampling_kwargs = {
    "max_new_tokens": max_length_generation,
    "pad_token_id": tokenizer.pad_token_id,
    "do_sample": True,
    "num_return_sequences": num_return_sequences,
    "top_k": top_k,
    "temperature": temperature,
}
outputs = model.generate(**inputs, **sampling_kwargs)
results = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for decoded in results:
    print(decoded)

# Older TextGenerationPipeline-based evaluation over prompts[i:j], kept for reference:
#   results = text_generator(prompts_processed[i:j], max_length=200, num_return_sequences=num_return_sequences,
#                            do_sample=True, top_k=50, temperature=10.0)
#   print(f"Finished prediction, time taken: {time.time()-t1}")
#   for prompt, res, label in zip(prompts[i:j], results[:(j-i)], labels[i:j]):
#       print(f"prompt: {prompt}\nlabel: {label}")
#       for k in range(num_return_sequences):
#           model_answer = res[k]['generated_text'].split(prefix)[1].replace("<eot>", "").replace("<pad>", "")
#           print(f"model answer-{k}: {model_answer}")
#       print("\n\n")

# Split a torch checkpoint into multiple shard files plus an index json

In [None]:
# Reload a full (unsharded) SFT checkpoint so its state dict can be re-saved as shards below.
checkpoint = os.path.join(root, "chatgpt", "output", "sft", "pangu-2.6B", "checkpoint-42782")
# Choose the model class from the checkpoint actually being loaded. Testing model_name_or_path
# (GLM above) would load this pangu checkpoint with the seq2seq class.
checkpoint_model = os.path.basename(os.path.dirname(checkpoint))
if "glm" in checkpoint_model:
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, trust_remote_code=True)
else:
    model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True, use_cache=False)
st = model.state_dict()

In [None]:
# Split the state dict `st` into up to 10 roughly equal shards inside `checkpoint`, plus a
# pytorch_model.bin.index.json mapping every parameter name to the shard that holds it.
keys = list(st.keys())
# Real parameter payload in bytes; sys.getsizeof(st) only measured the dict object itself.
total_size = sum(tensor.numel() * tensor.element_size() for tensor in st.values())
print(total_size)
n = max(1, min(10, len(keys)))  # never write empty shards
m = {"metadata": {"total_size": total_size}, "weight_map": dict()}
# divmod spreads the remainder over the first shards; slicing with `len(keys) // n` alone
# silently dropped the last len(keys) % n parameters.
span, extra = divmod(len(keys), n)
start = 0
for i in range(n):
    end = start + span + (1 if i < extra else 0)
    fn = f"pytorch_model-{i+1}-of-{n}.bin"
    shard = {key: st[key] for key in keys[start:end]}
    for key in shard:
        m["weight_map"][key] = fn
    torch.save(shard, os.path.join(checkpoint, fn))
    start = end
f = os.path.join(checkpoint, "pytorch_model.bin.index.json")
with open(f, "w", encoding="utf-8") as w:  # close (and flush) the index deterministically
    json.dump(m, w, ensure_ascii=False)

# Data Processing

### weibo_summary_comments_json

In [None]:
# Convert the Weibo article/abstract/comments dump into prompt/answers jsonl; comments become
# graded answers ordered by (score, length) descending.
t = time.time()
fi = os.path.join(root, "raw", "weibo_summary_comments_json.json")
fo = os.path.join(root, "chatgpt", "processed", "weibo_summary_comments.jsonl")
ct = 0
# Open the input first so a missing source does not truncate an existing output file.
with open(fi, "r", encoding="utf-8") as r, open(fo, "w", encoding="utf-8") as w:
    for line in r:
        line = line.strip("\n")
        if not line:
            continue  # tolerate blank lines instead of failing in json.loads
        item = json.loads(line)
        article = item['article'].replace(" ", "")
        abstract = item['abstract'].replace(" ", "")
        prompt = f"新闻内容：{article}{tokenizer.sep_token}摘要：{abstract}{tokenizer.sep_token}评论："
        # Each comment is a (text, score) pair; rank by score, then by text length.
        answers = [
            {"answer": k.replace(" ", ""), "score": int(v)}
            for (k, v) in sorted(item['comments'], key=lambda x: (int(x[1]), len(x[0])), reverse=True)
        ]
        w.write(json.dumps({"prompt": prompt, "answers": answers}, ensure_ascii=False)+'\n')
        ct += 1
print(f"length: {ct}, time taken: {time.time()-t} s")

### couplets

In [None]:
# Build couplet (上联 -> 下联) preference data. Each input line holds the two halves split by one
# middle character. Per pair we keep the true 下联 (score 1), two other 下联 of the same length
# (score 0) and two 下联 of different lengths (score -1).
t1 = time.time()
fi = os.path.join(root, "raw", "couplets.txt")
fo = os.path.join(root, "chatgpt", "processed", "couplets.jsonl")
l2 = []
nexts = dict()  # answer length -> every 下联 of that length (duplicates kept)
with open(fi, "r", encoding="utf-8") as r:
    for line in r:
        line = line.strip("\n")
        if not line:
            continue  # skip blank lines instead of emitting an empty pair
        idx = len(line) // 2
        prompt = line[:idx]
        answer = line[idx+1:]
        answers = [{"answer": answer, "score": 1}]
        l2.append({"prompt": f"上联：{prompt}{tokenizer.sep_token}下联：", "answers": answers})
        nexts.setdefault(len(answer), []).append(answer)
t2 = time.time()
print(f"length: {len(l2)}, # different lengths: {len(nexts)}, time taken: {t2-t1} s")
# Precompute once: distinct answers per length (rebuilding a set per sample was O(N^2)) and the
# list of lengths. Lists, not sets: random.sample() rejects sets on Python 3.11+.
distinct_nexts = {length: list(dict.fromkeys(values)) for length, values in nexts.items()}
lengths = list(nexts)
with open(fo, "w", encoding="utf-8") as w:
    for l in tqdm(l2, desc="Processing Couplets"):
        answer = l['answers'][0]
        length = len(answer['answer'])
        # Fresh list per sample: the old code extended the stale `answers` left over from the
        # first loop, so every record accumulated all previous negatives.
        answers = [answer]
        # Same length: the pool holds each value once, so drawing one extra and dropping the true
        # answer leaves a uniform sample of the others (and no crash when fewer than 2 exist).
        pool = distinct_nexts[length]
        same_len = [fa for fa in random.sample(pool, min(3, len(pool))) if fa != answer['answer']][:2]
        answers.extend({"answer": fa, "score": 0} for fa in same_len)
        # Different length
        other_lengths = [k for k in random.sample(lengths, min(3, len(lengths))) if k != length][:2]
        answers.extend({"answer": random.choice(nexts[key]), "score": -1} for key in other_lengths)
        w.write(json.dumps({"prompt": l['prompt'], "answers": answers}, ensure_ascii=False)+'\n')
print(f"length: {len(l2)}, time taken: {time.time()-t2} s")

### zhidao

In [None]:
# Convert the zhidao CSV dumps into prompt/answers jsonl. Rows are sorted so answers to the same
# title are adjacent (`is_best` descending) and grouped into one record; `is_best` is the score.
t = time.time()
fp = os.path.join(root, "raw", "zhidao", "*.csv")
fo = os.path.join(root, "chatgpt", "processed", "zhidao.jsonl")
ct = 0  # records written across ALL files (was reset per file, so only the last file was reported)
with open(fo, "w", encoding="utf-8") as w:
    for fi in sorted(glob.glob(fp)):
        df = pd.read_csv(fi).sort_values(by=["title", "is_best"], ascending=False)
        prev_title = None
        prev_prompt = None
        answers = []
        for _, val in df.iterrows():
            if isinstance(val['question'], str) and val['question'] != val['title']:
                prompt = f"问题：{val['title']}{tokenizer.sep_token}内容：{val['question']}{tokenizer.sep_token}回答："
            else:
                prompt = f"问题：{val['title']}{tokenizer.sep_token}回答："
            if prev_title is not None and prev_title == val['title']:
                answers.append({"answer": val['reply'], "score": val['is_best']})
            else:
                if prev_title is not None:
                    w.write(json.dumps({"prompt": prev_prompt, "answers": answers}, ensure_ascii=False)+'\n')
                    ct += 1
                answers = [{"answer": val['reply'], "score": val['is_best']}]
            prev_prompt = prompt
            prev_title = val['title']
        # Flush this file's last group. An empty file has nothing to flush (previously it wrote a
        # null-prompt record, or crashed on an undefined `answers`).
        if prev_title is not None:
            w.write(json.dumps({"prompt": prev_prompt, "answers": answers}, ensure_ascii=False)+'\n')
            ct += 1
        print(f"finished processing {os.path.basename(fi)}")
print(f"length: {ct}, time taken: {time.time()-t} s")

### JDData

In [None]:
from html.parser import HTMLParser
class MyHTMLParser(HTMLParser):
    """HTMLParser that records everything it sees, in document order.

    Attributes:
        start_tags: names of opening tags such as ``<p>``.
        end_tags: names of closing tags such as ``</p>``.
        start_end_tags: names of self-closing tags such as ``<br/>``.
        data_list: text chunks found between tags.
    """

    def __init__(self):
        super().__init__()
        self.start_tags = []
        self.end_tags = []
        self.start_end_tags = []
        self.data_list = []

    def handle_starttag(self, startTag, attrs):
        # Attributes are ignored; only the tag name is kept.
        self.start_tags.append(startTag)

    def handle_endtag(self, endTag):
        self.end_tags.append(endTag)

    def handle_startendtag(self, startendTag, attrs):
        # Overriding this keeps <x/> out of start_tags/end_tags (the base class would
        # otherwise report it as a start tag followed by an end tag).
        self.start_end_tags.append(startendTag)

    def handle_data(self, data):
        self.data_list.append(data)
        
# Exploratory peek at the JD dumps: parse the HTML in the second tab-separated field of each
# file's first line and print (tag, text) pairs. No jsonl output is produced yet, so no output
# file is opened (the old cell reopened zhidao.jsonl for writing and truncated it).
t = time.time()
jd_pattern = os.path.join(root, "raw", "JDData", "*.data*")
# TODO: once the format is understood, build prompt/answers records and write them to a
#       dedicated file such as os.path.join(root, "chatgpt", "processed", "jd.jsonl").
ct = 0
for fi in sorted(glob.glob(jd_pattern)):  # was glob.glob(fp): the zhidao CSV pattern
    with open(fi, "r", encoding="gbk") as r:
        line = r.readline()
    items = line.strip("\n").split("\t")
    if len(items) < 2:
        print(f"skipping {os.path.basename(fi)}: first line has no second tab-separated field")
        continue
    parser = MyHTMLParser()
    parser.feed(items[1])
    # Distinct loop names: reusing `t` here clobbered the timer read by the final print.
    # NOTE(review): start_tags and data_list are independent lists, so this zip pairs them by
    # position only.
    for tag, data in zip(parser.start_tags, parser.data_list):
        print(f"{tag}: {data}")
    ct += 1
    print(f"finished processing {os.path.basename(fi)}")
print(f"length: {ct}, time taken: {time.time()-t} s")

### yf_amazon

In [None]:
# Load the yf_amazon product, rating and category tables.
t = time.time()
yf_amazon_dir = os.path.join(root, "raw", "yf_amazon")
dfp, dfr, dfc = [
    pd.read_csv(os.path.join(yf_amazon_dir, table_file))
    for table_file in ("products.csv", "ratings.csv", "categories.csv")
]
fi = os.path.join(yf_amazon_dir, "categories.csv")

In [None]:
# Show the product table's columns (rendered as this cell's output).
dfp.columns
# To list product names instead: dfp['name'].unique().tolist()

In [None]:
# Top-level category id = first entry of the comma-separated `catIds`; print the matching
# category name(s) for each distinct top-level id.
dfp['cate_id_1'] = dfp['catIds'].map(lambda ids: ids.split(",", 1)[0])
for top_cat_id in dfp['cate_id_1'].unique():
    print(dfc.loc[dfc['catId'] == int(top_cat_id), 'category'])

### dmsc

In [None]:
# Load the dmsc movie and rating tables and report their shapes.
t = time.time()
dmsc_dir = os.path.join(root, "raw", "dmsc")
fi = os.path.join(dmsc_dir, "movies.csv")
dfm = pd.read_csv(fi)
print(dfm.shape)
fi = os.path.join(dmsc_dir, "ratings.csv")
dfr = pd.read_csv(fi)
print(dfr.shape)

In [None]:
# Number of comments per (movie, rating) pair. Both grouping keys must be passed as ONE list:
# the second positional argument of DataFrame.groupby is not another key.
dfr.groupby(["movieId", "rating"])['comment'].count()

### Chinese Classical-Modern

In [None]:
# Build Classical <-> Modern Chinese translation preference data from the bitext files.
# Each file alternates "古文：..." / "现代文：..." lines; any other line (presumably a blank
# separator) closes the current pair. The translation direction is randomised per pair.
t1 = time.time()
fp = os.path.join(root, "raw", "Classical-Modern", "bitext", "*")
fo = os.path.join(root, "chatgpt", "processed", "chinese_classical.jsonl")
l3 = []
dicts = dict()  # file name -> {"古文": [...], "现代文": [...]} of every sentence seen
for fi in sorted(glob.glob(fp)):
    name = os.path.basename(fi)
    dicts[name] = {"古文": [], "现代文": []}
    # Reset per file: a half-read pair must not leak into the next file, and the first
    # separator line must not hit an undefined p1/p2.
    p1 = p2 = None
    with open(fi, "r", encoding="utf-8") as r:
        lines = r.read().split("\n")
    # The appended "" acts as a final separator so a pair at EOF without a trailing blank
    # line is not lost.
    for line in lines + [""]:
        if line.startswith("古文"):
            p1 = line[3:]
            dicts[name]['古文'].append(p1)
        elif line.startswith("现代文"):
            p2 = line[4:]
            dicts[name]['现代文'].append(p2)
        elif p1 is not None and p2 is not None:
            pair = [("古文", p1), ("现代文", p2)]
            random.shuffle(pair)
            prompt = f"{pair[0][0]}：{pair[0][1]}{tokenizer.sep_token}{pair[1][0]}："
            answers = [{"answer": pair[1][1], "score": 1}]
            l3.append({"prompt": prompt, "answers": answers, "name": name})
            p1 = p2 = None
t2 = time.time()
print(f"length: {len(l3)}, # different names: {len(dicts)}, time taken: {t2-t1} s")
# Precompute distinct candidates per (file, side) once: rebuilding a set per sample was O(N^2),
# and random.sample() rejects sets on Python 3.11+.
distinct = {name: {side: list(dict.fromkeys(texts)) for side, texts in sides.items()}
            for name, sides in dicts.items()}
names = list(dicts)
with open(fo, "w", encoding="utf-8") as w:
    for l in tqdm(l3, desc="Processing Chinese Classical-Modern"):
        name = l['name']
        prompt = l['prompt']
        answer = l['answers'][0]['answer']
        answer_type = '现代文' if prompt.startswith("古文") else '古文'
        # Fresh list per sample: the old code kept extending the stale `answers` from the loop
        # above (which was also l3[-1]'s list), so records accumulated every previous negative.
        answers = [l['answers'][0]]
        # Same file, other sentence (score 0): draw one extra and drop the true answer.
        pool = distinct[name][answer_type]
        negatives = [fa for fa in random.sample(pool, min(3, len(pool))) if fa != answer][:2]
        answers.extend({"answer": fa, "score": 0} for fa in negatives)
        # Other files (score -1); skip files with no text on the answer side.
        other_names = [key for key in random.sample(names, min(3, len(names))) if key != name][:2]
        answers.extend({"answer": random.choice(dicts[key][answer_type]), "score": -1}
                       for key in other_names if dicts[key][answer_type])
        w.write(json.dumps({"prompt": prompt, "answers": answers}, ensure_ascii=False)+'\n')
# Report with this cell's own timer and the real record count (was a stale `t` and `i`).
print(f"length: {len(l3)}, time taken: {time.time()-t2} s")

### Chinese Poetry

In [None]:
import opencc
converter = opencc.OpenCC('t2s.json')
t1 = time.time()
poetry_root = os.path.join(root, "raw", "chinese-poetry")
fp = [
    # Four Books and Five Classics (四书五经)
    os.path.join(poetry_root, "lunyu", "lunyu.json"),
    # os.path.join(poetry_root, "mengxue", "*.json"),
    os.path.join(poetry_root, "sishuwujing", "*.json"),
    # Ancient-style poetry (古体诗)
    os.path.join(poetry_root, "caocaoshiji", "caocao.json"),
    os.path.join(poetry_root, "shijing", "shijing.json"),
    # Chu Ci (楚辞)
    os.path.join(poetry_root, "chuci", "chuci.json"),
    # Shi poems (诗)
    os.path.join(poetry_root, "shi", "poet*.json"),
    # Ci lyrics (词)
    os.path.join(poetry_root, "ci", "ci*.json"),
    os.path.join(poetry_root, "nalanxingde", "*.json"),
    os.path.join(poetry_root, "wudai", "huajianji", "*juan.json"),
    os.path.join(poetry_root, "wudai", "nantang", "poetrys.json"),
    # Qu (曲)
    os.path.join(poetry_root, "yuanqu", "yuanqu.json"),
]
fs = [each for f in fp for each in glob.glob(f)]


def _t2s_quotes(text):
    """Convert traditional to simplified Chinese and 「」 corner brackets to “” quotes."""
    return converter.convert(text).replace("「", "“").replace("」", "”")


def _parse_poetry_entry(rel_path, line):
    """Return (author, genre, title, contents) for one JSON entry, or None to skip it.

    Checks run in order on the path relative to the chinese-poetry root, so specific names
    ("caocao", "shijing", "chuci") win over the generic "shi" / "ci" directories.
    """
    if "lunyu" in rel_path:
        return "孔子", "经书", line['chapter'], "".join(line['paragraphs'])
    if "daxue" in rel_path:
        return "曾子", "经书", "大学", _t2s_quotes("".join(line['paragraphs']))
    if "mengzi" in rel_path:
        return "孟子", "经书", converter.convert(line['chapter']), _t2s_quotes("".join(line['paragraphs']))
    if "zhongyong" in rel_path:
        return "孔伋", "经书", "中庸", _t2s_quotes("".join(line['paragraphs']))
    if "caocao" in rel_path:
        return "曹操", "古体诗", line['title'], "".join(line['paragraphs'])
    if "shijing" in rel_path:
        title = line['chapter'] + "-" + line['section'] + "-" + line['title']
        return "诗经", "古体诗", title, "".join(line['content'])
    if "chuci" in rel_path:
        return line['author'], "楚辞", line['section'] + "-" + line['title'], "".join(line['content'])
    if "nalanxingde" in rel_path:
        return line['author'], "词", line['title'], "".join(line['para'])
    # The directory is "huajianji"; the old check for "huajianci" never matched, so these
    # entries silently reused the previous entry's author/genre/title/contents.
    if "huajianji" in rel_path:
        return line['author'], "词", line['title'], "".join(line['paragraphs'])
    if "nantang" in rel_path:
        return line['author'], "词", line['title'], "".join(line['paragraphs'])
    if "yuanqu" in rel_path:
        return line['author'], "曲", line['title'], "".join(line['paragraphs'])
    if "shi" in rel_path:
        if len(line['paragraphs']) <= 0:
            return None
        # NOTE(review): every first line that is not 12 characters long is labelled 七言诗.
        genre = "五言诗" if len(line['paragraphs'][0]) == 12 else "七言诗"
        return (converter.convert(line['author']), genre, converter.convert(line['title']),
                converter.convert("".join(line['paragraphs'])))
    if "ci" in rel_path:
        return line['author'], "词", line['rhythmic'], "".join(line['paragraphs'])
    # Unknown collection: skip rather than reuse the previous entry's fields.
    return None


def _choice_excluding(pool, excluded):
    """Pick uniformly among the distinct values of ``pool`` other than ``excluded``; None if none exist."""
    for candidate in random.sample(pool, min(2, len(pool))):
        if candidate != excluded:
            return candidate
    return None


l5 = []
dicts = dict()  # genre -> author -> title -> contents
skipped = 0
for fi in fs:
    rel_path = os.path.relpath(fi, poetry_root)
    with open(fi, "r", encoding="utf-8") as fh:  # close each file instead of leaking the handle
        lines = json.load(fh)
    if isinstance(lines, dict):
        lines = [lines]
    for line in lines:
        parsed = _parse_poetry_entry(rel_path, line)
        if parsed is None:
            skipped += 1
            continue
        author, genre, title, contents = parsed
        quantifier = "篇" if genre in ["经书", "楚辞"] else "首"
        prompt = f"以{author}的风格，写一{quantifier}{genre}，题为{title}{tokenizer.sep_token}"
        answers = [{"answer": contents, "score": 1}]
        l5.append({"prompt": prompt, "answers": answers, "genre": genre, "title": title, "author": author})
        dicts.setdefault(genre, dict()).setdefault(author, dict())[title] = contents

t2 = time.time()
print(f"length: {len(l5)}, # different genres: {len(dicts)}, skipped entries: {skipped}, time taken: {t2-t1} s")
# Precompute key lists once: rebuilding sets per sample was quadratic for prolific authors, and
# choosing from a set's iteration order is not reproducible across runs.
genres = list(dicts)
authors_by_genre = {g: list(authors) for g, authors in dicts.items()}
titles_by_author = {(g, a): list(titles) for g, authors in dicts.items() for a, titles in authors.items()}
fo = os.path.join(root, "chatgpt", "processed", "chinese_poetry.jsonl")
with open(fo, "w", encoding="utf-8") as w:
    for l in tqdm(l5, desc="Processing Chinese Poetry"):
        genre, author, title = l['genre'], l['author'], l['title']
        # Copy so that re-running this loop does not keep appending to the records in l5.
        answers = list(l['answers'])
        # Another work by the same author: score 0
        other_title = _choice_excluding(titles_by_author[(genre, author)], title)
        if other_title is not None:
            answers.append({"answer": dicts[genre][author][other_title], "score": 0})
        # A work by another author of the same genre: score -1 (skipped for single-author genres)
        other_author = _choice_excluding(authors_by_genre[genre], author)
        if other_author is not None:
            picked_title = random.choice(titles_by_author[(genre, other_author)])
            answers.append({"answer": dicts[genre][other_author][picked_title], "score": -1})
        # A work from a different genre: score -2
        other_genre = _choice_excluding(genres, genre)
        if other_genre is not None:
            picked_author = random.choice(authors_by_genre[other_genre])
            picked_title = random.choice(titles_by_author[(other_genre, picked_author)])
            answers.append({"answer": dicts[other_genre][picked_author][picked_title], "score": -2})
        w.write(json.dumps({"prompt": l['prompt'], "answers": answers}, ensure_ascii=False)+'\n')
print(f"length: {len(l5)}, time taken: {time.time()-t2} s")

### baike_qa_2019

In [None]:
# Convert the baike_qa2019 dumps into prompt/answers jsonl with one gold answer per question.
fs = glob.glob(os.path.join(root, "raw", "baike_qa2019", "baike_qa_*.json"))
fo = os.path.join(root, "chatgpt", "processed", "baike_qa.jsonl")
ct = 0
with open(fo, "w", encoding="utf-8") as w:
    for f in fs:
        with open(f, "r", encoding="utf-8") as r:
            for line in r:
                item = json.loads(line.strip("\n"))
                # Use whichever of title / desc is longer as the question text.
                title, desc = item['title'], item['desc']
                question = clean_text(desc if len(title) <= len(desc) else title)
                record = {
                    "prompt": f"{question}{tokenizer.sep_token}回答：",
                    "answers": [{"answer": clean_text(item['answer']), "score": 1}],
                }
                w.write(json.dumps(record, ensure_ascii=False) + '\n')
                ct += 1
print(ct)
# For length statistics: collect len(prompt) / len(answer) per record and print
# np.percentile(lengths, np.arange(90, 101)).

### rm-static

In [None]:
# Peek at the rm-static test split (a single parquet file).
rm_static_dir = os.path.join(root, "raw", "rm-static", "data")
fi = os.path.join(rm_static_dir, "test-00000-of-00001-bf4c733542e35fcb.parquet")
df = pd.read_parquet(fi)
print(f"{df.shape}")
df.iloc[:2]

# ChatGLM

In [None]:
# model.chat(...) in the next cell is ChatGLM-specific, so load ChatGLM (model AND tokenizer)
# explicitly instead of whatever model_name_or_path selected earlier (GLM/pangu have no .chat).
chatglm_path = model_name_or_path if "chatglm" in model_name_or_path else os.path.join(root, "models", "chatglm-6B")
tokenizer = AutoTokenizer.from_pretrained(chatglm_path, trust_remote_code=True)
model = AutoModel.from_pretrained(chatglm_path, trust_remote_code=True).half().cuda()  # fp16 weights on the GPU

In [None]:
# Multi-turn smoke test: each turn's history is fed back so later questions see earlier answers.
texts = ["你好", "晚上睡不着应该怎么办"]
history = []
for turn in texts:
    response, history = model.chat(tokenizer, turn, history=history)
    print(f"问: {turn}\n答:{response}\n")