In [2]:
import os, time, re, random, glob, json, jieba, copy
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
    TextGenerationPipeline
)

In [4]:
device="cuda:0" if torch.cuda.is_available() else "cpu"
from sys import platform
if platform == "linux" or platform == "linux2":
    # linux
    root = "/mnt/private-pa002-vol141056-prd/Data"
elif platform == "darwin":
    # OS X
    root = "/Users/zeyesun/Documents/Data"
elif platform == "win32":
    # Windows...
    root = "D:\\Data"

In [5]:
CLEAN_TEXT_PATTERN = re.compile(r"[\r\n]")

def clean_text(text):
    return CLEAN_TEXT_PATTERN.sub("", text)

In [6]:
# import sentencepiece
# model_file = os.path.join(root, "models", "pangu-350M", "vocab.model")
# sp = sentencepiece.SentencePieceProcessor()
# sp.Load(model_file=model_file)
# for i in range(10):
#     print(sp.id_to_piece(i))

In [8]:
model_name_or_path = os.path.join(root, "models", "pangu-350M")
# model_name_or_path = os.path.join(root, "Data", "models", "pangu-2.6B")
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
tokenizer.add_special_tokens({
    'unk_token': "<unk>", 
    'eos_token': "<eot>", 
    'pad_token': "<pad>", 
    "sep_token": "<sep>"
})

Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


0

# Data Processing

### weibo_summary_comments_json

In [92]:
t = time.time()
fi = os.path.join(root, "raw", "weibo_summary_comments_json.json")
fo = os.path.join(root, "chatgpt", "processed", "weibo_summary_comments.jsonl")
ct = 0
with open(fo, "w", encoding="utf-8") as w:
    with open(fi, "r", encoding="utf-8") as r:
        while True:
            line = r.readline()
            if not line:
                break
            
            item = json.loads(line.strip("\n"))
            article = item['article'].replace(" ", "")
            abstract = item['abstract'].replace(" ", "")
            prompt = f"新闻内容：{article}{tokenizer.sep_token}摘要：{abstract}{tokenizer.sep_token}评论："
            answers = [
                {
                    "answer": k.replace(" ", ""), 
                    "score": int(v)
                } for (k, v) in sorted(item['comments'], key=lambda x: (int(x[1]), len(x[0])), reverse=True)
            ]
            w.write(json.dumps({"prompt": prompt, "answers": answers}, ensure_ascii=False)+'\n')
            ct += 1
print(f"length: {ct}, time taken: {time.time()-t} s")

length: 894732, time taken: 68.99229574203491 s


### couplets

In [None]:
t1 = time.time()
fi = os.path.join(root, "raw", "couplets.txt")
fo = os.path.join(root, "chatgpt", "processed", "couplets.jsonl")
l2 = []
nexts = dict()
with open(fi, "r", encoding="utf-8") as r:
    while True:
        line = r.readline()
        if not line:
            break
        line = line.strip("\n")
        idx = len(line) // 2
        prompt = line[:idx]
        answer = line[idx+1:]
        answers = [{"answer": answer, "score": 1}]
        l2.append({"prompt": f"上联：{prompt}{tokenizer.sep_token}下联：", "answers": answers})
        length = len(answer)
        if length not in nexts:
            nexts[length] = list()
        nexts[length].append(answer)
t2 = time.time()
print(f"length: {len(l2)}, # different lengths: {len(nexts)}, time taken: {t2-t1} s")
with open(fo, "w", encoding="utf-8") as w:
    for i, l in tqdm(enumerate(l2), desc="Processing Couplets"):
        answer = l['answers'][0]
        length = len(answer['answer'])
        # 上下联长度一样
        nexts_tmp = set(nexts[length])
        nexts_tmp.remove(answer['answer'])
        nexts_tmp = set(nexts[length]).difference(set([answer['answer']]))
#         nexts_tmp.remove(answer['answer'])
        false_answers_1 = [{"answer": fa, "score": 0} for fa in random.sample(nexts_tmp, 2)]
        # 上下联长度不一样
        keys = set(nexts.keys())
        keys.remove(length)
        false_answers_2 = [{"answer": random.choice(nexts[key]), "score": -1} for key in random.sample(keys, 2)]
        answers = [answer] + false_answers_1 + false_answers_2
#         answers = sorted(answers, key=lambda x: x['score'], reverse=True)
        w.write(json.dumps({"prompt": l['prompt'], "answers": answers}, ensure_ascii=False)+'\n')
#         if i % 1000 == 0:
#             print(f"{i} samples processed, time taken: {time.time()-t2} s")
print(f"length: {len(l2)}, time taken: {time.time()-t2} s")

length: 774491, # different lengths: 32, time taken: 3.667919635772705 s


since Python 3.9 and will be removed in a subsequent version.
  false_answers_1 = [{"answer": fa, "score": 0} for fa in random.sample(nexts_tmp, 2)]
since Python 3.9 and will be removed in a subsequent version.
  false_answers_2 = [{"answer": random.choice(nexts[key]), "score": -1} for key in random.sample(keys, 2)]
Processing Couplets: 986it [00:57, 12.84it/s]

### zhidao

In [101]:
t = time.time()
fp = os.path.join(root, "raw", "zhidao", "*.csv")
fo = os.path.join(root, "chatgpt", "processed", "zhidao.jsonl")
ct = 0
with open(fo, "w", encoding="utf-8") as w:
    for fi in glob.glob(fp):
        ct = 0
        df = pd.read_csv(fi).sort_values(by=["title", "is_best"], ascending=False)
        prev_title = None
        prev_prompt = None
        for _, val in df.iterrows():
            if isinstance(val['question'], str) and val['question'] != val['title']:
                prompt = f"问题：{val['title']}{tokenizer.sep_token}内容：{val['question']}{tokenizer.sep_token}回答："
            else:
                prompt = f"问题：{val['title']}{tokenizer.sep_token}回答："
            if prev_title is not None and prev_title == val['title']:
                answers.append({"answer": val['reply'], "score": val['is_best']})
            else:
                if prev_title is not None:
#                     l3.append({"prompt": prev_prompt, "answers": copy.deepcopy(answers)})
                    w.write(json.dumps({"prompt": prev_prompt, "answers": answers}, ensure_ascii=False)+'\n')
                answers = [{"answer": val['reply'], "score": val['is_best']}]
            prev_prompt = prompt
            prev_title = val['title']
            ct += 1
#         l3.append({"prompt": prev_prompt, "answers": copy.deepcopy(answers)})
        w.write(json.dumps({"prompt": prev_prompt, "answers": answers}, ensure_ascii=False)+'\n')
        print(f"finished processing {os.path.basename(fi)}")
print(f"length: {ct}, time taken: {time.time()-t} s")

finished processing financezhidao_filter.csv
finished processing liantongzhidao_filter.csv
finished processing touzizhidao_filter.csv
finished processing nonghangzhidao_filter.csv
finished processing baoxianzhidao_filter.csv
finished processing anhuidianxinzhidao_filter.csv
finished processing lawzhidao_filter.csv
length: 36368, time taken: 127.75226378440857 s


### JDData

In [101]:
from html.parser import HTMLParser
class MyHTMLParser(HTMLParser):
    def __init__(self):
        super().__init__()
       #Initializing lists
        self.start_tags = list()
        self.end_tags = list()
        self.start_end_tags = list()
        self.data_list = list()
    #HTML Parser Methods
    def handle_starttag(self, startTag, attrs):
        self.start_tags.append(startTag)
    def handle_endtag(self, endTag):
        self.end_tags.append(endTag)
    def handle_startendtag(self,startendTag, attrs):
        self.start_end_tags.append(startendTag)
    def handle_data(self, data):
        self.data_list.append(data)
        
t = time.time()
fi = os.path.join(root, "raw", "JDData", "*.data*")
# fo = os.path.join(root, "chatgpt", "processed", "zhidao.jsonl")
ct = 0
with open(fo, "w", encoding="utf-8") as w:
    for fi in glob.glob(fp):
        ct = 0
        with open(fi, "r", encoding="gbk") as r:
            line = r.readline()
            items = line.strip("\n").split("\t")
            parser = MyHTMLParser()
            parser.feed(items[1])
            for t, d in zip(parser.start_tags, parser.data_list):
                print(f"{t}: {d}")
#                 prompt = f"问题：{val['title']}{tokenizer.sep_token}内容：{val['question']}{tokenizer.sep_token}回答："
#                 answers.append({"answer": val['reply'], "score": val['is_best']})
            ct += 1
#         l3.append({"prompt": prev_prompt, "answers": copy.deepcopy(answers)})
#         w.write(json.dumps({"prompt": prev_prompt, "answers": answers}, ensure_ascii=False)+'\n')
        print(f"finished processing {os.path.basename(fi)}")
print(f"length: {ct}, time taken: {time.time()-t} s")

finished processing financezhidao_filter.csv
finished processing liantongzhidao_filter.csv
finished processing touzizhidao_filter.csv
finished processing nonghangzhidao_filter.csv
finished processing baoxianzhidao_filter.csv
finished processing anhuidianxinzhidao_filter.csv
finished processing lawzhidao_filter.csv
length: 36368, time taken: 127.75226378440857 s


### yf_amazon

In [153]:
t = time.time()
fi = os.path.join(root, "raw", "yf_amazon", "products.csv")
dfp = pd.read_csv(fi)
fi = os.path.join(root, "raw", "yf_amazon", "ratings.csv")
dfr = pd.read_csv(fi)
fi = os.path.join(root, "raw", "yf_amazon", "categories.csv")
dfc = pd.read_csv(fi)

In [168]:
dfp.columns
# dfp['name'].unique().tolist()

Index(['productId', 'name', 'catIds', 'cate_id_1'], dtype='object')

In [167]:
dfp['cate_id_1'] = dfp['catIds'].apply(lambda x: x.split(",")[0])
for cid1 in dfp['cate_id_1'].unique():
    print(dfc[dfc['catId']==int(cid1)]['category'])

832    图书音像
Name: category, dtype: object
911    母婴/玩具
Name: category, dtype: object
1111    运动户外
Name: category, dtype: object
486    钟表/首饰/眼镜/礼品
Name: category, dtype: object
1057    电脑/办公
Name: category, dtype: object
571    家具/家装/建材
Name: category, dtype: object
933    家居生活
Name: category, dtype: object
67    食品/保健
Name: category, dtype: object
57    其他
Name: category, dtype: object
1128    手机/数码
Name: category, dtype: object
916    美妆个护
Name: category, dtype: object
222    家用电器
Name: category, dtype: object
802    服饰服装
Name: category, dtype: object
518    鞋类箱包
Name: category, dtype: object
539    机票/充值/票务/虚拟
Name: category, dtype: object


### dmsc

In [169]:
t = time.time()
fi = os.path.join(root, "raw", "dmsc", "movies.csv")
dfm = pd.read_csv(fi)
print(dfm.shape)
fi = os.path.join(root, "raw", "dmsc", "ratings.csv")
dfr = pd.read_csv(fi)
print(dfr.shape)

(28, 3)
(2125056, 6)


In [172]:
dfr.groupby("movieId", 'rating').count()['comment']

movieId
0      54153
1      83692
2      64410
3      46233
4      44366
5     133393
6      30475
7     109162
8      23739
9      79962
10     91452
11     96620
12     85677
13     26797
14     35093
15     68359
16     78281
17    120200
18    113687
19     83173
20     39802
21     73882
22     88903
23     41152
24    102876
25     58746
26    113260
27    137511
Name: comment, dtype: int64

### Chinese Classical-Modern

In [None]:
t1 = time.time()
fp = os.path.join(root, "raw", "Classical-Modern", "bitext", "*")
fo = os.path.join(root, "chatgpt", "processed", "chinese_classical.jsonl")
l3 = []
dicts = dict()
for fi in glob.glob(fp):
    name = os.path.basename(fi)
    dicts[name] = {"古文": [], "现代文": []}
    with open(fi, "r", encoding="utf-8") as r:
        for i, line in enumerate(r):
            line = line.strip("\n")
            if line.startswith("古文"):
                p1 = line[3:]
                dicts[name]['古文'].append(p1)
            elif line.startswith("现代文"):
                p2 = line[4:]
                dicts[name]['现代文'].append(p2)
            elif p1 is not None and p2 is not None:
                pair = [("古文", p1), ("现代文", p2)]
                random.shuffle(pair)
                prompt = f"{pair[0][0]}：{pair[0][1]}{tokenizer.sep_token}{pair[1][0]}："
                answers = [{"answer": pair[1][1], "score": 1}]
                l3.append({"prompt": prompt, "answers": answers, "name": name})
                p1 = None
                p2 = None
t2 = time.time()
print(f"length: {len(l3)}, # different names: {len(dicts)}, time taken: {t2-t1} s")
with open(fo, "w", encoding="utf-8") as w:
    for i, l in tqdm(enumerate(l3), desc="Processing Chinese Classical-Modern"):
        name = l['name']
        prompt = l['prompt']
        answer = l['answers'][0]['answer']
        if prompt.startswith("古文"):
            answer_type = '现代文'
        else:
            answer_type = '古文'
        samples_tmp = set(dicts[name][answer_type])
        samples_tmp.remove(answer)
        false_answers_1 = [{"answer": fa, "score": 0} for fa in random.sample(samples_tmp, 2)]
        keys = set(dicts.keys())
        keys.remove(name)
        false_answers_2 = [{"answer": random.choice(dicts[key][answer_type]), "score": -1} for key in random.sample(keys, 2)]
        answers = [answer] + false_answers_1 + false_answers_2
        w.write(json.dumps({"prompt": prompt, "answers": answers}, ensure_ascii=False)+'\n')
#         if i % 100 == 0:
#             print(f"{i} samples processed, time taken: {time.time()-t2} s")
print(f"length: {i}, time taken: {time.time()-t} s")

length: 967255, # different lengths: 27, time taken: 6.60289192199707 s


since Python 3.9 and will be removed in a subsequent version.
  false_answers_2 = [{"answer": random.choice(dicts[key][answer_type]), "score": -1} for key in random.sample(keys, 2)]
Processing Chinese Classical-Modern: 967255it [19:13:59, 13.97it/s] 

length: 967254, time taken: 69959.00833964348 s





### Chinese Poetry

In [34]:
import opencc
converter = opencc.OpenCC('t2s.json')
t1 = time.time()
fp = [
    # 四书五经
    os.path.join(root, "raw", "chinese-poetry", "lunyu", "lunyu.json"),
#     os.path.join(root, "raw", "chinese-poetry", "mengxue", "*.json"),
    os.path.join(root, "raw", "chinese-poetry", "sishuwujing", "*.json"),
    # 古体诗
    os.path.join(root, "raw", "chinese-poetry", "caocaoshiji", "caocao.json"),
    os.path.join(root, "raw", "chinese-poetry", "shijing", "shijing.json"),
    # 楚辞
    os.path.join(root, "raw", "chinese-poetry", "chuci", "chuci.json"),
    # 诗
    os.path.join(root, "raw", "chinese-poetry", "shi", "poet*.json"),
    # 词
    os.path.join(root, "raw", "chinese-poetry", "ci", "ci*.json"),
    os.path.join(root, "raw", "chinese-poetry", "nalanxingde", "*.json"),
    os.path.join(root, "raw", "chinese-poetry", "wudai", "huajianji", "*juan.json"),
    os.path.join(root, "raw", "chinese-poetry", "wudai", "nantang", "poetrys.json"),
    # 曲
    os.path.join(root, "raw", "chinese-poetry", "yuanqu", "yuanqu.json"),
]
fs = [each for f in fp for each in glob.glob(f)]

l5 = []
dicts = dict()
for fi in fs:
    lines = json.load(open(fi, "r", encoding="utf-8"))
    if isinstance(lines, dict):
        lines = [lines]
    for i, line in enumerate(lines):
        if "lunyu" in fi:
            author = "孔子"
            genre = "经书"
            title = line['chapter']
            contents = "".join(line['paragraphs'])
        elif "daxue" in fi:
            author = "曾子"
            genre = "经书"
            title = "大学"
            contents = converter.convert("".join(line['paragraphs'])).replace("「", "“").replace("」", "”")
        elif "mengzi" in fi:
            author = "孟子"
            genre = "经书"
            title = converter.convert(line['chapter'])
            contents = converter.convert("".join(line['paragraphs'])).replace("「", "“").replace("」", "”")
        elif "zhongyong" in fi:
            author = "孔伋"
            genre = "经书"
            title = "中庸"
            contents = converter.convert("".join(line['paragraphs'])).replace("「", "“").replace("」", "”")
        elif "caocao" in fi:
            author = "曹操"
            genre = "古体诗"
            title = line['title']
            contents = "".join(line['paragraphs'])
        elif "shijing" in fi:
            author = "诗经"
            genre = "古体诗"
            title = line['chapter'] + "-" + line['section'] + "-" + line['title']
            contents = "".join(line['content'])
        elif "chuci" in fi:
            author = line['author']
            genre = "楚辞"
            title = line['section'] + "-" + line['title']
            contents = "".join(line['content'])
        elif "nalanxingde" in fi:
            author = line['author']
            genre = "词"
            title = line['title']
            contents = "".join(line['para'])
        elif "huajianci" in fi:
            author = line['author']
            genre = "词"
            title = line['title']
            contents = "".join(line['paragraphs'])
        elif "nantang" in fi:
            author = line['author']
            genre = "词"
            title = line['title']
            contents = "".join(line['paragraphs'])
        elif "yuanqu" in fi:
            author = line['author']
            genre = "曲"
            title = line['title']
            contents = "".join(line['paragraphs'])
        elif "shi" in fi:
            if len(line['paragraphs']) <= 0:
                continue
            author = converter.convert(line['author'])
            genre = "五言诗" if len(line['paragraphs'][0]) == 12 else "七言诗"
            title = converter.convert(line['title'])
            contents = converter.convert("".join(line['paragraphs']))
        elif "ci" in fi:
            author = line['author']
            genre = "词"
            title = line['rhythmic']
            contents = "".join(line['paragraphs'])
        if genre not in dicts:
            dicts[genre] = dict()
        if author not in dicts[genre]:
            dicts[genre][author] = dict()
        quantifier = "篇" if genre in ["经书", "楚辞"] else "首"
        prompt = f"以{author}的风格，写一{quantifier}{genre}，题为{title}{tokenizer.sep_token}"
        answers = [{"answer": contents, "score": 1}]
        l5.append({"prompt": prompt, "answers": answers, "genre": genre, "title": title, "author": author})
        dicts[genre][author][title] = contents
        
t2 = time.time()
print(f"length: {len(l5)}, # different lengths: {len(dicts)}, time taken: {t2-t1} s")
fo = os.path.join(root, "chatgpt", "processed", "chinese_poetry.jsonl")
with open(fo, "w", encoding="utf-8") as w:
    for i, l in tqdm(enumerate(l5), desc="Processing Chinese Poetry"):
        genre = l['genre']
        author = l['author']
        title = l['title']
        prompt = l['prompt']
        answers = l['answers']
        # 同作者其他作品-2
        titles_tmp = set(dicts[genre][author].keys())
        titles_tmp.remove(title)
        if len(titles_tmp) > 0:
            t = random.choice(list(titles_tmp))
            answers.append({"answer": dicts[genre][author][t], "score": 0})
        # 同体裁其他作者其他作品-1
        authors_tmp = set(dicts[genre].keys())
        authors_tmp.remove(author)
        a = random.choice(list(authors_tmp))
        t = random.choice(list(dicts[genre][a].keys()))
        answers.append({"answer": dicts[genre][a][t], "score": -1})
        # 不同体裁作品-0
        genres_tmp = set(dicts.keys())
        genres_tmp.remove(genre)
        g = random.choice(list(genres_tmp))
        a = random.choice(list(dicts[g].keys()))
        t = random.choice(list(dicts[g][a].keys()))
        answers.append({"answer": dicts[g][a][t], "score": -2})
        w.write(json.dumps({"prompt": prompt, "answers": answers}, ensure_ascii=False)+'\n')
print(f"length: {i}, time taken: {time.time()-t2} s")

length: 345170, # different lengths: 7, time taken: 21.88776707649231 s


### baike_qa_2019

In [6]:
f = os.path.join(root, "raw", "baike_qa_train.json")
items = []
lens_prompt = []
lens_label = []
with open(f, "r", encoding="utf-8") as r:
    while True:
        line = r.readline()
        if not line:
            break
        item = json.loads(line.strip("\n"))
        prompt = clean_text(item['title'] if len(item['title']) > len(item['desc']) else item['desc'])
        label = clean_text(item['answer'])
        items.append(item)
        lens_prompt.append(len(prompt))
        lens_label.append(len(label))
print(len(items))
print(np.percentile(lens_prompt, np.arange(90, 101)))
print(np.percentile(lens_label, np.arange(90, 101)))

1425170


### rm-static

In [44]:
fi = os.path.join(root, "raw", "rm-static", "data", "test-00000-of-00001-bf4c733542e35fcb.parquet")
df = pd.read_parquet(fi)
print(df.shape)
df.head(2)

(5103, 4)


Unnamed: 0,prompt,response,chosen,rejected
0,\n\nHuman: I am trying to write a fairy tale. ...,This sounds like a really interesting modern ...,This sounds like a really interesting modern ...,And the prince and the princess both decide t...
1,\n\nHuman: What flowers should I grow to attra...,"Great, there are a lot of different kinds of ...","Great, there are a lot of different kinds of ...","In particular, it’s important to have a wide ..."


# SFT Prediction

In [15]:
prefix = "模型回答："

In [5]:
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, use_cache=False)
model.resize_token_embeddings(len(tokenizer.sp))
model.config.end_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# model.config.max_length_prompt = 200
model.to(device)
# print(model.device)

Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


GPTPanguForCausalLM(
  (transformer): GPTPanguModel(
    (wte): Embedding(40000, 2560)
    (wpe): Embedding(1024, 2560)
    (wqe): Embedding(1024, 2560)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPTPanguBlock(
        (ln_1): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (attn): GPTPanguAttention(
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (c_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTPanguMLP(
          (c_fc): Linear(in_features=2560, out_features=10240, bias=True)
          (c_proj): Linear(in_features=10240, out_featur

In [6]:
checkpoint_files = os.path.join(root, "Data", "chatgpt", "output", "sft", "pangu-350M", "checkpoint-12000", "pytorch_model*.bin")
# checkpoint_files = os.path.join(root, "Data", "output", "sft", "pangu-2.6B", "checkpoint-9000", "pytorch_model*.bin")
checkpoints = glob.glob(checkpoint_files)
st = dict()
for checkpoint in checkpoints:
    st.update(torch.load(checkpoint, map_location="cpu"))
model.load_state_dict(st)

<All keys matched successfully>

In [20]:
text_generator = TextGenerationPipeline(model, tokenizer, device=device)

The model 'GPTPanguForCausalLM' is not supported for . Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'CodeGenForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'GPT2LMHeadModel', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegatronBertForCausalLM', 'MvpForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RoFormerForCausalLM', 'Speech2Text2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaX

In [8]:
f = os.path.join(root, "Data", "raw", "baike_qa_train.json")
i = 0
prompts = []
prompts_processed = []
labels = []
with open(f, "r", encoding="utf-8") as r:
    while True:
        line = r.readline()
        if not line:
            break
        item = json.loads(line.strip("\n"))
        prompt = clean_text(item['title'] if len(item['title']) > len(item['desc']) else item['desc'])
        label = clean_text(item['answer'])
        prompt_processed = prompt + tokenizer.sep_token + prefix
        prompts.append(prompt)
        prompts_processed.append(prompt_processed)
        labels.append(label)
        i += 1
        if i > 1000:
            break

In [14]:
num_return_sequences = 5
i = 10
j = 20
t1 = time.time()
results = text_generator(prompts_processed[i:j], max_length=200, num_return_sequences=num_return_sequences,
                         do_sample=True, top_k=50, temperature=10.0)
print(f"Finished prediction, time taken: {time.time()-t1}")

for prompt, res, label in zip(prompts[i:j], results[:(j-i)], labels[i:j]):
    print(f"prompt: {prompt}\nlabel: {label}")
    for k in range(num_return_sequences):
        model_answer = res[k]['generated_text'].split(prefix)[1].replace("<eot>", "").replace("<pad>", "")
        print(f"model answer-{k}: {model_answer}")
    print("\n\n")

prompt: 有6区无尽之海众神之子工会的吗?你们会怎么了?看不见人了,解散了 
label: 我听原众神成员说 众神的副会长 卷了工会财产跑路了(几千W呢)  所以导致工会解散 哎 真是人心陷恶   
model answer-0: 我也是无尽之海的,现在工会还没解散,是不可能解散了,不过我想你应该有很多个工会吧,应该有很多人在等着你,如果你有足够的耐心的话,可以去工会看看有多少人在等着你,我有个朋友就在等着我,我说的是在无尽之海,你知道的,要是没人的话,你就去工会看看,反正现在工会也没解散,你可以去公会看看,我也是在无尽之海的,在无尽之海也有很多人在等着你,要是你能坚持的话,我想你会更好的!结舌
model answer-1: 现在工会也解散了,是吧··タタタ瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄瞄
model answer-2: 你可以去看看万盛万盛万盛万盛万盛万盛万盛万盛万盛万盛万盛万盛万盛万盛万盛万盛万盛NdF万盛NdF
model answer-3: 不解散了,我是6区无尽之海的,我想知道,如果解散的话,工会的人和工会的人一起走,会怎么样?方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒方寒
model answer-4: 有,因为工会的人太多了,服务器都快封了。タ榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫榫



prompt: 老师好，

In [21]:
prompt = "北大和清华哪个更好？"
prompt_processed = prompt + tokenizer.sep_token + prefix
num_return_sequences=5
res = text_generator(prompt_processed, max_length=200, num_return_sequences=num_return_sequences,
                         do_sample=True, top_k=50, temperature=5.0)
print(f"prompt: {prompt}")
for i in range(num_return_sequences):
    model_answer = res[i]['generated_text'].split(prefix)[1].replace("<eot>", "").replace("<pad>", "")
    # print(res)
    print(f"model answer-{i}: {model_answer}")

Building prefix dict from the default dictionary ...
I0224 13:57:47.086069 140704479725184 __init__.py:113] Building prefix dict from the default dictionary ...
Loading model from cache /var/folders/wf/12stcn_s6zq56j9h3fnkv5lm0000gn/T/jieba.cache
I0224 13:57:47.104202 140704479725184 __init__.py:133] Loading model from cache /var/folders/wf/12stcn_s6zq56j9h3fnkv5lm0000gn/T/jieba.cache
Loading model cost 0.565 seconds.
I0224 13:57:47.664287 140704479725184 __init__.py:165] Loading model cost 0.565 seconds.
Prefix dict has been built successfully.
I0224 13:57:47.666184 140704479725184 __init__.py:166] Prefix dict has been built successfully.


prompt: 内有吊车，牛腿高度怎么定？根据什么来定？生产的工艺要求？还是要考虑到其他因素？
model answer-0: 工艺性强了定在1.2的高。生产量要能到2。6车就行不生产也有1.6高度牛只等。一般要求1.0就能很明确可以计算:工艺高到多少比较科学(不能超过3的要考虑3吨以下汽车要超过1.0等要求)
model answer-1: 这是可以确定重量(1米60一米40多啊牛腿不是固定啊呵呵)当然这有很大差别呢主要工艺性决定,如要求速度很高就要设较佳速度来加工...呵呵了。
model answer-2: 牛吊(头向下弯曲至脚尖方向的过程均指同一吊轮形式而分开表示用轮、吊、牛三个量块计量不同头辐。下轮吊运时常由1对,8台1式起重机按一用一吊或两种方式排列配置完成为准
model answer-3: 先定一根长100/1*30厘米钢管和高70公厘米无缝钢管和厚4一13之间2组钢管做“架子钢骨柱基础架板基础或柱的下部混凝土(每只立柱混凝土20到1250加500就足够≥1)柱筋2和圈4根
model answer-4: 楼上有人是傻不隆科型吧.这个很不好设定啊如果把车架(如W5F和OYFL6ZXHB车架有不小)的强度用公式A2F1×hP定好车臂高/m在1和100之最小就足够大这样


# Reward Model

In [9]:
model_name_or_path = "D:\\Data\\models\\pangu-350M"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)
tokenizer.add_special_tokens({'unk_token': "<unk>",
                                  'bos_token': "<s>",
                                  'eos_token': "<eot>",
                                  'pad_token': "<pad>",
                                  "sep_token": "<sep>"})

Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


0

In [10]:
max_length = 1024
text = "你好，你是谁"
# text = "<|startoftext|>" + text + "<|endoftext|>"
res = tokenizer(text, max_length=max_length, truncation="longest_first", 
          return_tensors="pt", add_special_tokens=False)

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\SUNZEY~1\AppData\Local\Temp\jieba.cache
Loading model cost 0.501 seconds.
Prefix dict has been built successfully.


In [19]:
res.keys()
# torch.cat((res['input_ids'], res['input_ids']), axis=1)

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])