#### This repo allows you to train a ParlAI chatbot with the Chinese T5 model released by UER.
#### You can also replace the Chinese T5 model with your own T5 model simply by changing the tokenizer and model paths.
#### The code shown in this notebook is basically copied from the original T5 repo.
#### I made some minor changes in order to make it work with the Chinese T5 model.
#### I omit the details of how to prepare the training data for brevity.
#### Please see my previous repo "parlai_chinese_chatbot_by_gpt2" for more information about data preparation.
#### If you have any questions, please email me at studyouwei@gmail.com

In [None]:
#### Load all the necessary libraries
import os
import re
from typing import Any, Dict, Optional, Tuple

import pandas as pd
import torch
from transformers import BertTokenizer, T5ForConditionalGeneration

from parlai.agents.hugging_face.dict import HuggingFaceDictionaryAgent
from parlai.agents.hugging_face.t5 import (ParlaiT5Decoder, ParlaiT5Encoder,
                                           ParlaiT5Model, T5Agent)
from parlai.core.agents import register_agent
from parlai.core.opt import Opt
from parlai.core.teachers import DialogTeacher, register_teacher
from parlai.core.torch_generator_agent import TorchGeneratorModel
from parlai.scripts.train_model import TrainModel

In [None]:
# `T5Stack` is only used to annotate the encoder/decoder wrappers; if this
# module path does not exist in the installed transformers, use `object`.
try:
    from transformers.models.t5.modeling_t5 import T5Stack
except ModuleNotFoundError:
    T5Stack = object

# Turn off HuggingFace tokenizers parallelism. Remove this setting if it
# causes unexpected problems in your environment.
os.environ.update(TOKENIZERS_PARALLELISM="false")


In [None]:
#### define your own 'parlai' teacher which is responsible for feeding your model training samples


data = pd.read_csv('train_data.csv')


@register_teacher("my_teacher")
class MyTeacher(DialogTeacher):
    """
    Teacher that feeds dialogue turns from ``train_data.csv`` to the model.

    Columns read by ``setup_data``: ``txt`` (input text), ``label`` (target
    text) and ``start`` (truthy when the row begins a new episode).

    Note: every datatype (train/valid/test) iterates the same module-level
    ``data`` frame; the ``datafile`` value is only echoed in the log line.
    """

    def __init__(self, opt, shared=None):
        # DialogTeacher hands opt['datafile'] to setup_data; derive it from
        # the datatype (e.g. "train:stream" -> "train.txt"). setup_data below
        # does not open this file.
        opt['datafile'] = opt['datatype'].split(':')[0] + ".txt"
        super().__init__(opt, shared)

    def setup_data(self, datafile):
        """
        Yield ``((text, label), new_episode)`` pairs.

        Rows whose ``txt`` or ``label`` is not a string (e.g. NaN from an
        empty CSV cell) are skipped. If a skipped row was flagged as an
        episode start, the flag is carried over to the next yielded row.
        This keeps the following turns from being appended to the previous
        episode.
        """
        print(f" ~~ Loading from {datafile} ~~ ")

        pending_start = False
        for _, diag in data.iterrows():
            text = diag['txt']
            labels = diag['label']
            # Remember an episode boundary even if this row gets skipped.
            pending_start = pending_start or bool(diag['start'])
            if isinstance(text, str) and isinstance(labels, str):
                yield (text, labels), pending_start
                pending_start = False

In [None]:
#### define your tokenizer for the Chinese t5 model
class MyDictionaryAgent(HuggingFaceDictionaryAgent):
    """
    ParlAI dictionary backed by the BertTokenizer of the Chinese T5 model.

    Special tokens are mapped onto the tokenizer's own tokens: CLS as start,
    SEP as end, PAD as null and UNK as unknown.
    """

    # HuggingFace hub id (or local path) of the tokenizer. Override in a
    # subclass to use a different Chinese T5 checkpoint.
    tokenizer_name_or_path = "uer/t5-small-chinese-cluecorpussmall"

    def get_tokenizer(self, opt):
        """
        Load the BertTokenizer from ``tokenizer_name_or_path``.

        ``opt`` is not used.
        """
        return BertTokenizer.from_pretrained(self.tokenizer_name_or_path)

    @property
    def add_special_tokens(self) -> bool:
        """
        Whether to add special tokens when tokenizing.
        """
        return True

    @property
    def skip_decode_special_tokens(self) -> bool:
        """
        Whether to skip special tokens when decoding token IDs back to text.
        """
        return True

    def override_special_tokens(self, opt):
        """
        Point ParlAI's start/end/null/unk tokens and indices at the
        tokenizer's CLS/SEP/PAD/UNK tokens.
        """
        self.start_token = self.hf_tokenizer.cls_token
        self.end_token = self.hf_tokenizer.sep_token
        self.null_token = self.hf_tokenizer.pad_token
        self.unk_token = self.hf_tokenizer.unk_token

        self._unk_token_idx = self.hf_tokenizer.unk_token_id
        self.start_idx = self.hf_tokenizer.cls_token_id
        self.end_idx = self.hf_tokenizer.sep_token_id
        self.null_idx = self.hf_tokenizer.pad_token_id



In [None]:
#### Wrap t5 models, including its encoder and decoder, in Parlai format

class MyParlaiT5Model(TorchGeneratorModel):
    """
    Wrap the HuggingFace Chinese T5 model in ParlAI's TorchGeneratorModel API.

    The pad token doubles as the decoder start token (``start_idx`` is set
    to ``pad_idx``); the end index comes from the dictionary's end token.
    """

    # HuggingFace hub id (or local path) of the pretrained weights. Override
    # in a subclass to load a different T5 checkpoint.
    pretrained_name_or_path = "uer/t5-small-chinese-cluecorpussmall"

    def __init__(self, opt, dictionary):
        self.pad_idx = dictionary[dictionary.null_token]
        # Decoding starts from the pad token.
        self.start_idx = self.pad_idx
        self.end_idx = dictionary[dictionary.end_token]
        super().__init__(self.pad_idx, self.start_idx, self.end_idx)
        self.t5 = T5ForConditionalGeneration.from_pretrained(
            self.pretrained_name_or_path)

        self.encoder = ParlaiT5Encoder(
            opt, self.t5.get_encoder(), self.pad_idx)
        self.decoder = ParlaiT5Decoder(
            opt, self.t5.get_decoder(), self.pad_idx)
        self.paralleled = not opt['t5_model_parallel']

    def _get_initial_forced_decoder_input(self, bsz: int, inputs: torch.LongTensor):
        """
        Return initial input to the decoder.

        :param bsz:
            batchsize
        :param inputs:
            inputs to decode
        :return initial_input:
            ``inputs`` with the START token (``start_idx``) prepended to
            every row.
        """
        inputs = torch.cat([self.START.detach().expand(bsz, 1), inputs], 1)
        return inputs

    def reorder_encoder_states(self, encoder_states, indices):
        """
        Reorder the encoder states.

        Selects rows ``indices`` (along dim 0) from both the encoder output
        and its mask. See ``TorchGeneratorModel.reorder_encoder_states``.
        """
        enc, mask = encoder_states
        if not torch.is_tensor(indices):
            indices = torch.LongTensor(indices).to(enc.device)
        enc = torch.index_select(enc, 0, indices)
        mask = torch.index_select(mask, 0, indices)
        return enc, mask

    def reorder_decoder_incremental_state(
        self, incremental_state: Dict[int, dict], inds: torch.Tensor
    ) -> Dict[int, dict]:
        """
        Return an empty incremental state.

        This wrapper does not keep HF past key/values, so there is nothing
        to reorder. (Original note: not *quite* sure how to reconcile this
        with HF.)
        """
        return {}

    def output(self, tensor):
        """
        Compute output logits from decoder hidden states.

        Mirrors ``T5ForConditionalGeneration.forward``. The hidden states are
        rescaled by ``model_dim ** -0.5`` only when input and output
        embeddings are tied; untied checkpoints must not be rescaled. Configs
        without the ``tie_word_embeddings`` attribute keep the old,
        always-rescale behaviour.
        """
        if getattr(self.t5.config, 'tie_word_embeddings', True):
            tensor = tensor * (self.t5.model_dim ** -0.5)
        return self.t5.lm_head(tensor)


class ParlaiT5Encoder(torch.nn.Module):
    """
    Adapt a HuggingFace T5 encoder stack to ParlAI's encoder interface.
    """

    def __init__(self, opt: Opt, encoder: T5Stack, padding_idx: Optional[int] = None):
        super().__init__()
        self.stack = encoder
        self.padding_idx = padding_idx
        # With t5_model_parallel, parallelize lazily in forward (need to
        # parallel in forward; bug in HF).
        self.paralleled = not opt['t5_model_parallel']

    def forward(
        self,
        input: torch.LongTensor,
        positions: Optional[torch.LongTensor] = None,
        segments: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, torch.BoolTensor]:
        """
        Forward pass.

        :param LongTensor[batch,seqlen] input:
            The input IDs.
        :param positions:
            Accepted for ParlAI interface compatibility; ignored.
        :param segments:
            Accepted for ParlAI interface compatibility; ignored.
        :return:
            ``(hidden_states, mask)``, where ``mask`` is True wherever
            ``input`` is not the padding index.
        """
        if not self.paralleled:
            self.stack.parallelize()
            # Parallelize once; repeating it on every forward is wasted work.
            self.paralleled = True
        mask = input != self.padding_idx
        outputs = self.stack(input, attention_mask=mask,
                             output_hidden_states=False)
        # Only the last hidden state is used; bring it back to the input's
        # device because a parallelized stack may leave it on another one.
        return outputs[0].to(input.device), mask


class ParlaiT5Decoder(torch.nn.Module):
    """
    Adapt a HuggingFace T5 decoder stack to ParlAI's decoder interface.
    """

    def __init__(self, opt: Opt, decoder: T5Stack, padding_idx: Optional[int] = None):
        super().__init__()
        self.stack = decoder
        self.padding_idx = padding_idx
        # With t5_model_parallel, parallelize lazily in forward (need to
        # parallel in forward; bug in HF).
        self.paralleled = not opt['t5_model_parallel']

    def forward(
        self, input: torch.LongTensor, encoder_state: Tuple[Any], incr_state=None
    ):
        """
        Forward pass.

        :param LongTensor[batch,seqlen] input:
            The decoder inputs (partial or full decoded token IDs).
        :param encoder_state:
            ``(encoder_output, encoder_mask)`` from the encoder forward pass.
        :param incr_state:
            Incremental decoding state. Not used: the whole ``input``
            sequence is decoded on every call and ``incr_state`` is returned
            unchanged.
        :return:
            ``(hidden_states, incr_state)``.
        """
        if not self.paralleled:
            self.stack.parallelize()
            # Parallelize once; repeating it on every forward is wasted work.
            self.paralleled = True
        encoder_output, encoder_mask = encoder_state

        mask = input != self.padding_idx
        # The first token is the pad/start token; it must still be attended.
        mask[:, 0] = True

        outputs = self.stack(
            input_ids=input,
            attention_mask=mask,
            encoder_hidden_states=encoder_output.to(input.device),
            encoder_attention_mask=encoder_mask.to(input.device),
        )
        return outputs[0].to(input.device), incr_state






In [None]:
#### define your t5 agent 
@register_agent('chinese_t5')
class ChineseT5Agent(T5Agent):
    """
    ParlAI T5 agent that plugs in the Chinese T5 model and its dictionary.
    """

    def build_model(self):
        """
        Construct the ParlAI-wrapped Chinese T5 model from the agent's opt
        and dictionary.
        """
        return MyParlaiT5Model(self.opt, self.dict)

    def build_dictionary(self):
        """
        Override TorchAgent.build_dictionary to return the Chinese T5
        dictionary instead of the default T5 one.
        """
        return MyDictionaryAgent(self.opt)

In [None]:
#### Now you are all set, just train the model!
#### CUDA is used automatically when it is available; otherwise training runs on CPU.
TrainModel.main(
    model='chinese_t5',
    model_file='./model',
    task='my_teacher',
    lr=1e-5,
    optimizer='adam',
    warmup_updates=100,
    # t5_model_parallel=True,
    text_truncate=512,
    t5_model_arch='t5-small',
    batchsize=8,
    fp16=True,
    num_epochs=3,
    no_cuda=not torch.cuda.is_available(),
    # fp16_impl='mem_efficient',
)