In [None]:

import os
import re
import pandas as pd

from parlai.scripts.train_model import TrainModel
from parlai.agents.hugging_face.gpt2 import Gpt2Agent
from transformers import BertTokenizer, GPT2LMHeadModel
from parlai.agents.transformer.modules import TransformerGeneratorModel
from parlai.core.agents import register_agent
from parlai.agents.transformer.generator import GeneratorAgent
from parlai.agents.hugging_face.dict import HuggingFaceDictionaryAgent, Gpt2DictionaryAgent
from transformers import AutoModel
from parlai.core.teachers import register_teacher, DialogTeacher
from parlai.scripts.interactive_web import InteractiveWeb


In [None]:
# Assume you have chat logs between two people (call them A and B).
# Put the logs in a CSV file (Excel works too), one question/answer pair per row.
# The file has 3 fields: text, label, start
# text:  what A said (text data)
# label: B's reply to A (text data)
# start: whether this row is the first turn of a conversation, True or False
# Example (utterances are Chinese because the model below is GPT2-Chinese):
# text                    label    start
# 你好                    你好      True
# 你叫什么？               张三      False
# 太巧了，我也叫张三        不会吧    False

data = pd.read_csv('your_data.csv')

# Tell ParlAI how to read your data.
@register_teacher("my_teacher")
class MyTeacher(DialogTeacher):
    """
    Teacher registered as ``my_teacher`` that serves the rows of the
    module-level ``data`` DataFrame as (text, label) dialogue turns.

    The DataFrame is expected to have the columns ``text``, ``label`` and
    ``start`` described next to the ``pd.read_csv`` call.
    """

    def __init__(self, opt, shared=None):
        """
        Set ``opt['datafile']`` from the datatype prefix, then defer to DialogTeacher.

        :param opt: ParlAI options dict; ``opt['datatype']`` is read and
            ``opt['datafile']`` is written.
        :param shared: forwarded unchanged to ``DialogTeacher.__init__``.
        """
        # e.g. datatype "train:ordered" -> "train.txt". setup_data() below only
        # prints this value; the actual rows come from `data`.
        opt['datafile'] = opt['datatype'].split(':')[0] + ".txt"
        super().__init__(opt, shared)

    @staticmethod
    def _is_episode_start(value):
        """
        Interpret one ``start`` cell as a boolean.

        A column that pandas could not parse as bool arrives as strings, and a
        plain truthiness check would treat the string "False" as True.
        """
        if isinstance(value, str):
            return value.strip().lower() in ('true', '1', 'yes')
        return bool(value)

    def setup_data(self, datafile):
        """
        Yield ``((text, label), new_episode)`` for every usable row of ``data``.

        Rows whose ``text`` or ``label`` is not a string (e.g. empty cells read
        as NaN) are skipped.

        :param datafile: value of ``opt['datafile']``; only used in the log line.
        """
        print(f" ~~ Loading from {datafile} ~~ ")

        # If the row that opens a conversation is skipped, remember that the
        # next usable row must start a new episode; otherwise two separate
        # conversations would be merged into one.
        pending_start = False
        for _, diag in data.iterrows():
            # The documented column name is 'text' (the old code read 'txt',
            # which raises KeyError on a file that follows the schema).
            text = diag['text']
            labels = diag['label']
            pending_start = pending_start or self._is_episode_start(diag['start'])
            if isinstance(text, str) and isinstance(labels, str):
                yield (text, labels), pending_start
                pending_start = False
                


In [None]:

# Because we use GPT2-Chinese, tell the dictionary which tokenizer to load.
class MyDictionaryAgent(Gpt2DictionaryAgent):
    """Gpt2DictionaryAgent whose tokenizer is loaded from the GPT2-Chinese checkpoint."""

    def get_tokenizer(self, opt):
        """
        Build and return the Hugging Face tokenizer for this dictionary.

        ``opt`` is accepted to match the parent signature but is not read here.
        """
        tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
        return tokenizer

# Tell ParlAI that we are using GPT2-Chinese.
@register_agent('chinese_gpt2')
class ChineseGPT(Gpt2Agent):
    """Gpt2Agent registered as ``chinese_gpt2`` that uses MyDictionaryAgent as its dictionary."""

    @staticmethod
    def dictionary_class():
        """
        Name the dictionary class this agent should be built with.
        """
        agent_dictionary = MyDictionaryAgent
        return agent_dictionary

# Start training. A GPU is required. With only a few tens of thousands of rows,
# expect roughly 3-6 hours; with ten million rows or more it may take about a week.
TrainModel.main(
    # agent registered above with @register_agent, plus the checkpoint name
    model='chinese_gpt2',
    model_name='uer/gpt2-chinese-cluecorpussmall',
    model_file='./model',
    # teacher registered above with @register_teacher
    task='my_teacher',
    # optimisation settings
    lr=1e-5,
    optimizer='adam',
    warmup_updates=100,
    batchsize=8,
    fp16=True,
    num_epochs=3,
    fp16_impl='mem_efficient',
)


In [None]:
### Once the model is trained, serve it from a web server.

InteractiveWeb.main(
    model='chinese_gpt2',
    model_file='./model',
    # decoding settings
    inference='beam',
    beam_size=3,
    # Optional settings, currently disabled:
    # beam_min_length=4,
    # temperature=0.5,
    beam_block_ngram=4,
    beam_context_block_ngram=4,
    # NOTE(review): '0.0.0.0' listens on every network interface — confirm
    # this is intended for the machine the notebook runs on.
    host='0.0.0.0',
)