# Sequence to Sequence ChatBot Inference

Inference on seq2seq chatbot, Using CPU

In [1]:
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from jieba import cut
from p3self.lprint import lprint
from multiprocessing import Pool
from collections import Counter
import pandas as pd
import numpy as np

In [2]:
from constants import *
CUDA = False

In [3]:
class s2s_data(Dataset):
    def __init__(self,load_io, vocab_in, vocab_out, seq_addr, build_seq=False,
                 build_vocab = False,):
        """
        vocab_in,vocab_out are csv file addresses
        """
        self.load_io=load_io
        self.vocab_in = vocab_in
        self.vocab_out = vocab_out
        self.seq_addr = seq_addr
        
        print("[Loading the sequence data]")
        
        if build_seq:
            self.i,self.o = self.load_io()
            np.save(self.seq_addr,[self.i,self.o])
        else:
            [self.i,self.o] = np.load(self.seq_addr).tolist()
        print("[Sequence data loaded]")
            
        assert len(self.i)==len(self.o),"input seq length mush match output seq length"
        
        self.N = len(self.i)
        print("Length of sequence:\t",self.N)
        
        if build_vocab:
            self.vocab_i = self.build_vocab(self.i)
            self.vocab_o = self.build_vocab(self.o)
            
            self.vocab_i.to_csv(self.vocab_in)
            self.vocab_o.to_csv(self.vocab_out)
            
            self.print_vocab_info()
        else:
            self.vocab_i = pd.read_csv(self.vocab_in).fillna("")
            self.vocab_o = pd.read_csv(self.vocab_out).fillna("")
                  
            self.print_vocab_info()
        
        print("building mapping dicts")
        self.i_char2idx,self.i_idx2char = self.get_mapping(self.vocab_i)
        self.o_char2idx,self.o_idx2char = self.get_mapping(self.vocab_o)
        
    def __len__(self):
        return self.N
    
    def __getitem__(self,idx):
        return self.seq2idx(self.i[idx],self.mapfunc_i),self.seq2idx(self.o[idx],self.mapfunc_o)
    
    def get_full_token(self,list_of_tokens):
        """
        From a list of list of tokens, to a long list of tokens, duplicate tokens included
        """
        return (" ".join(list_of_tokens)).split(" ")
    
    def get_mapping(self,vocab_df):
        char2idx=dict(zip(vocab_df["token"],vocab_df["idx"]))
        idx2char=dict(zip(vocab_df["idx"],vocab_df["token"]))
        return char2idx,idx2char
    
    def seq2idx(self,x,mapfunc):
        return np.vectorize(mapfunc)(x.split(" ")).tolist()
    
    def mapfunc_i(self,x):
        try:
            return self.i_char2idx[x]
        except:
            return 2
        
    def mapfunc_o(self,x):
        try:
            return self.o_char2idx[x]
        except:
            return 2
        
    def get_token_count_dict(self,full_token):
        """count the token to a list"""
        return Counter(full_token)
    
    def build_vocab(self,seq_list):
        ct_dict = self.get_token_count_dict(self.get_full_token(seq_list))
        ct_dict["SOS_TOKEN"] = 9e9
        ct_dict["EOS_TOKEN"] = 8e9
        ct_dict[" "] = 7e9
        tk,ct = list(ct_dict.keys()),list(ct_dict.values())
        
        token_df=pd.DataFrame({"token":tk,"count":ct}).sort_values(by="count",ascending=False)
        return token_df.reset_index().drop("index",axis=1).reset_index().rename(columns={"index":"idx"}).fillna("")
    
    def print_vocab_info(self):
        self.vocab_size_i = len(self.vocab_i)
        self.vocab_size_o = len(self.vocab_o)
        
        print("[in seq vocab address]: %s,\t%s total lines"%(self.vocab_in,self.vocab_size_i))
        print("[out seq vocab address]: %s,\t%s total lines"%(self.vocab_out,self.vocab_size_o))
            
        print("Input sequence vocab samples:")
        print(self.vocab_i.sample(5))
        print("Output sequence vocab samples:")
        print(self.vocab_o.sample(5))

In [4]:
def load_empty():
    return list(range(5)),list(range(5))

In [5]:
class inf_s2s(s2s_data):
    def __init__(self,vocab_in, vocab_out):
        super(inf_s2s,self).__init__(load_empty, vocab_in, vocab_out, seq_addr="/data/chat/empty.npy", build_seq=True,
                 build_vocab = False,)
        
    def feed_encoder(self,x):
        if CN_SEG:
            x_list = list(cut(x))
        else:
            x_list = list(str(x))
        arr = np.array(self.seq2idx(" ".join(x_list),self.mapfunc_o))
        return torch.LongTensor(arr).unsqueeze(0)
        

In [6]:
inf=inf_s2s(vocab_in = DICT_IN,
         vocab_out = DICT_OUT,)
# inf=inf_s2s(vocab_in = DICT_IN,
#          vocab_out = "/data/dict/chat_vocab_char_out.csv",)

[Loading the sequence data]
[Sequence data loaded]
Length of sequence:	 5
[in seq vocab address]: /data/dict/chat_vocab_in.csv,	5748 total lines
[out seq vocab address]: /data/dict/chat_vocab_out.csv,	5635 total lines
Input sequence vocab samples:
      Unnamed: 0   idx  count token
3232        3232  3232    9.0     捧
2156        2156  2156   42.0     胞
947          947   947  283.0     顺
3022        3022  3022   12.0     仨
3776        3776  3776    5.0     挚
Output sequence vocab samples:
      Unnamed: 0   idx   count token
673          673   673   880.0     嘴
2273        2273  2273    68.0     噎
3273        3273  3273    15.0     讶
569          569   569  1085.0     7
5258        5258  5258     1.0     怆
building mapping dicts


In [7]:
inf.feed_encoder("很高兴认识你")

tensor([[  67,  185,  795,  280,  332,    6]])

In [8]:
from models import EncoderRNN_GRU as EncoderRNN
from models import DecoderRNN_GRU as DecoderRNN

encoder = EncoderRNN(inf.vocab_size_i,HIDDEN_SIZE,n_layers = NB_LAYER)
decoder = DecoderRNN(HIDDEN_SIZE,inf.vocab_size_o,n_layers = NB_LAYER)

encoder.cuda_ = False
decoder.cuda_ = False

def load_s2s(version):
    encoder.load_state_dict(torch.load("/data/weights/enc_%s.pkl"%(version)))
    decoder.load_state_dict(torch.load("/data/weights/dec_%s.pkl"%(version)))

If the following encounter error (it's because the trainning process is saving), try run it again.

In [75]:
load_s2s(VERSION)

In [76]:
def answer(question):
    encoder_hidden = encoder.initHidden(1)
    last_idx= decoder.initInput(1)
    encoder_output,encoder_hidden = encoder(question,encoder_hidden)

    decoder_hidden = encoder_hidden
    
    i = 0
    output=list()
    while i < MAX_LEN:
        if (i>0 and last_idx.item() == SOS_TOKEN):
            break
        decoder_input,decoder_hidden = decoder(last_idx,decoder_hidden)
        last_idx = torch.max(decoder_input,dim=-1)[1]
        output.append(last_idx.item())
        i += 1
    output_char = " ".join(np.vectorize(lambda x:inf.o_idx2char[x])(output).tolist())
    print(output_char)
    print("length:\t",len(output))

In [77]:
answer(inf.feed_encoder("为什么事情会这样"))

那 有 呢 呢 到 律 法 规 了 了 ， 还 不 然 了 不 到 的 着 打
length:	 20


In [78]:
answer(inf.feed_encoder("你叫什么名字"))

哎 呦 呦 哎 … … … ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~
length:	 20


In [79]:
answer(inf.feed_encoder("你喜欢不喜欢我"))

=   = SOS_TOKEN
length:	 5


In [80]:
answer(inf.feed_encoder("你把我灌醉"))

有 个 人 醒 呢 ， 主 人 醒 爱 你 的 那 个 快 ~ ~ ~ ~ ~
length:	 20


In [81]:
answer(inf.feed_encoder("等你等到我心碎"))

n 你 么 么 SOS_TOKEN
length:	 5


In [82]:
answer(inf.feed_encoder("唱首歌给我听"))

=   = SOS_TOKEN
length:	 5


In [83]:
answer(inf.feed_encoder("你是男的还是女的"))

你 有 呢 呢 呢 d 呢 纸 为 SOS_TOKEN
length:	 10


In [84]:
answer(inf.feed_encoder("去你大爷的"))

哎 呦 不 要 哎 哎 啊 切 糕 糕 糕 糕 糕 糕 糕 厉 厉 厉 厉 ！
length:	 20
