# Sequence to Sequence ChatBot Inference

Inference on seq2seq chatbot, Using CPU

In [1]:
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from jieba import cut
from p3self.lprint import lprint
from multiprocessing import Pool
from collections import Counter
import pandas as pd
import numpy as np

In [2]:
BS = 512# Batch size

VOCAB_SEQ_IN = 3000
VOCAB_SEQ_OUT = 3000

SOS_TOKEN = 0
EOS_TOKEN = 1

LR = 5e-3
HIDDEN_SIZE = 512
MAX_LEN = 20

VERSION = "0.0.2"
# "0.0.1" chars
# "0.0.2" token

CUDA = False

CN_SEG = True

if CN_SEG:
    DICT_IN = "/data/dict/chat_vocab_in.csv"
    DICT_OUT = "/data/dict/chat_vocab_out.csv"
    SEQ_DIR = "/data/chat/xhj_seq.npy"
else:
    DICT_IN = "/data/dict/chat_vocab_in.csv"
    DICT_OUT = "/data/dict/chat_vocab_out.csv"
    SEQ_DIR = "/data/chat/xhj_seq_char.npy"

In [3]:
class s2s_data(Dataset):
    def __init__(self,load_io, vocab_in, vocab_out, seq_addr, build_seq=False,
                 build_vocab = False,):
        """
        vocab_in,vocab_out are csv file addresses
        """
        self.load_io=load_io
        self.vocab_in = vocab_in
        self.vocab_out = vocab_out
        self.seq_addr = seq_addr
        
        print("[Loading the sequence data]")
        
        if build_seq:
            self.i,self.o = self.load_io()
            np.save(self.seq_addr,[self.i,self.o])
        else:
            [self.i,self.o] = np.load(self.seq_addr).tolist()
        print("[Sequence data loaded]")
            
        assert len(self.i)==len(self.o),"input seq length mush match output seq length"
        
        self.N = len(self.i)
        print("Length of sequence:\t",self.N)
        
        if build_vocab:
            self.vocab_i = self.build_vocab(self.i)
            self.vocab_o = self.build_vocab(self.o)
            
            self.vocab_i.to_csv(self.vocab_in)
            self.vocab_o.to_csv(self.vocab_out)
            
            self.print_vocab_info()
        else:
            self.vocab_i = pd.read_csv(self.vocab_in).fillna("")
            self.vocab_o = pd.read_csv(self.vocab_out).fillna("")
                  
            self.print_vocab_info()
        
        print("building mapping dicts")
        self.i_char2idx,self.i_idx2char = self.get_mapping(self.vocab_i)
        self.o_char2idx,self.o_idx2char = self.get_mapping(self.vocab_o)
        
    def __len__(self):
        return self.N
    
    def __getitem__(self,idx):
        return self.seq2idx(self.i[idx],self.mapfunc_i),self.seq2idx(self.o[idx],self.mapfunc_o)
    
    def get_full_token(self,list_of_tokens):
        """
        From a list of list of tokens, to a long list of tokens, duplicate tokens included
        """
        return (" ".join(list_of_tokens)).split(" ")
    
    def get_mapping(self,vocab_df):
        char2idx=dict(zip(vocab_df["token"],vocab_df["idx"]))
        idx2char=dict(zip(vocab_df["idx"],vocab_df["token"]))
        return char2idx,idx2char
    
    def seq2idx(self,x,mapfunc):
        return np.vectorize(mapfunc)(x.split(" ")).tolist()
    
    def mapfunc_i(self,x):
        try:
            return self.i_char2idx[x]
        except:
            return 2
        
    def mapfunc_o(self,x):
        try:
            return self.o_char2idx[x]
        except:
            return 2
        
    def get_token_count_dict(self,full_token):
        """count the token to a list"""
        return Counter(full_token)
    
    def build_vocab(self,seq_list):
        ct_dict = self.get_token_count_dict(self.get_full_token(seq_list))
        ct_dict["SOS_TOKEN"] = 9e9
        ct_dict["EOS_TOKEN"] = 8e9
        ct_dict[" "] = 7e9
        tk,ct = list(ct_dict.keys()),list(ct_dict.values())
        
        token_df=pd.DataFrame({"token":tk,"count":ct}).sort_values(by="count",ascending=False)
        return token_df.reset_index().drop("index",axis=1).reset_index().rename(columns={"index":"idx"}).fillna("")
    
    def print_vocab_info(self):
        self.vocab_size_i = len(self.vocab_i)
        self.vocab_size_o = len(self.vocab_o)
        
        print("[in seq vocab address]: %s,\t%s total lines"%(self.vocab_in,self.vocab_size_i))
        print("[out seq vocab address]: %s,\t%s total lines"%(self.vocab_out,self.vocab_size_o))
            
        print("Input sequence vocab samples:")
        print(self.vocab_i.sample(5))
        print("Output sequence vocab samples:")
        print(self.vocab_o.sample(5))

In [4]:
def load_empty():
    return list(range(5)),list(range(5))

In [5]:
class inf_s2s(s2s_data):
    def __init__(self,vocab_in, vocab_out):
        super(inf_s2s,self).__init__(load_empty, vocab_in, vocab_out, seq_addr="/data/chat/empty.npy", build_seq=True,
                 build_vocab = False,)
        
    def feed_encoder(self,x):
        if CN_SEG:
            x_list = list(cut(x))
        else:
            x_list = list(str(x))
        arr = np.array(self.seq2idx(" ".join(x_list),self.mapfunc_o))
        return torch.LongTensor(arr).unsqueeze(0)
        

In [6]:
inf=inf_s2s(vocab_in = DICT_IN,
         vocab_out = DICT_OUT,)
# inf=inf_s2s(vocab_in = DICT_IN,
#          vocab_out = "/data/dict/chat_vocab_char_out.csv",)

[Loading the sequence data]
[Sequence data loaded]
Length of sequence:	 5
[in seq vocab address]: /data/dict/chat_vocab_in.csv,	62596 total lines
[out seq vocab address]: /data/dict/chat_vocab_out.csv,	55508 total lines
Input sequence vocab samples:
       Unnamed: 0    idx  count token
18424       18424  18424    3.0   张无忌
32355       32355  32355    2.0    奶有
61283       61283  61283    1.0   时志诚
21932       21932  21932    3.0    拍个
55102       55102  55102    1.0   人生观
Output sequence vocab samples:
       Unnamed: 0    idx  count token
31898       31898  31898    2.0   父皇母
17166       17166  17166    6.0   凑合着
42525       42525  42525    1.0   十五号
35713       35713  35713    2.0  静下心来
4319         4319   4319   43.0    洋洋
building mapping dicts


In [7]:
inf.feed_encoder("很高兴认识你")

Building prefix dict from the default dictionary ...
Dumping model to file cache /tmp/jieba.cache
Loading model cost 0.891 seconds.
Prefix dict has been built succesfully.


tensor([[  52,  684,  211,    7]])

In [8]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, num_layers=n_layers)
        
    def forward(self, input_, hidden):
        output, hidden = self.gru(self.embedding(input_), hidden)
        return output, hidden

    # TODO: other inits
    def initHidden(self, batch_size):
        en_hidden = torch.zeros(1, batch_size, self.hidden_size)
        if CUDA:
            en_hidden = en_hidden.cuda()
        return en_hidden

In [9]:
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, n_layers=1):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, num_layers=n_layers)
        # TODO use transpose of embedding
        self.out = nn.Linear(hidden_size, output_size)
        self.sm = nn.LogSoftmax(dim=-1)
        
    def forward(self, input_, hidden):
        emb = self.embedding(input_).unsqueeze(1)
        # NB: Removed relu
        res, hidden = self.gru(emb, hidden)
        output = self.sm(self.out(res[:,0]))
        return output, hidden
    
    def initInput(self,batch_size):
        decoder_input = torch.LongTensor([SOS_TOKEN]*batch_size)
        if CUDA:
            decoder_input = decoder_input.cuda()
        return decoder_input

In [10]:
encoder = EncoderRNN(inf.vocab_size_i,HIDDEN_SIZE)
decoder = DecoderRNN(HIDDEN_SIZE,inf.vocab_size_o)

In [11]:
def load_s2s(version):
    encoder.load_state_dict(torch.load("/data/weights/enc_%s.pkl"%(version)))
    decoder.load_state_dict(torch.load("/data/weights/dec_%s.pkl"%(version)))

If the following encounter error (it's because the trainning process is saving), try run it again.

In [13]:
load_s2s(VERSION)

In [14]:
def answer(question):
    encoder_hidden = encoder.initHidden(1)
    last_idx= decoder.initInput(1)
    print(encoder_hidden.size(),last_idx.size())
    encoder_output,encoder_hidden = encoder(question,encoder_hidden)

    decoder_hidden = encoder_hidden
    
    i = 0
    output=list()
    while i < MAX_LEN:
        if (i>0 and last_idx.item() == SOS_TOKEN):
            break
        decoder_input,decoder_hidden = decoder(last_idx,decoder_hidden)
        last_idx = torch.max(decoder_input,dim=-1)[1]
        output.append(last_idx.item())
        i += 1
    output_char = " ".join(np.vectorize(lambda x:inf.o_idx2char[x])(output).tolist())
    print(output_char)
    print("length:\t",len(output))

In [18]:
answer(inf.feed_encoder("很高兴认识你，哈哈哈哈"))

torch.Size([1, 1, 512]) torch.Size([1])
SOS_TOKEN
length:	 1
