# Sequence to Sequence ChatBot Inference

Inference on seq2seq chatbot, Using CPU

In [1]:
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from jieba import cut
from p3self.lprint import lprint
from multiprocessing import Pool
from collections import Counter
import pandas as pd
import numpy as np

In [16]:
BS = 256# Batch size

VOCAB_SEQ_IN = 3000
VOCAB_SEQ_OUT = 3000

SOS_TOKEN = 0
EOS_TOKEN = 1

LR = 5e-3
HIDDEN_SIZE = 256
MAX_LEN = 20

VERSION = "0.0.1"

CUDA = False

In [3]:
class s2s_data(Dataset):
    def __init__(self,load_io, vocab_in, vocab_out, seq_addr, build_seq=False,
                 build_vocab = False,):
        """
        vocab_in,vocab_out are csv file addresses
        """
        self.load_io=load_io
        self.vocab_in = vocab_in
        self.vocab_out = vocab_out
        self.seq_addr = seq_addr
        
        print("[Loading the sequence data]")
        
        if build_seq:
            self.i,self.o = self.load_io()
            np.save(self.seq_addr,[self.i,self.o])
        else:
            [self.i,self.o] = np.load(self.seq_addr).tolist()
        print("[Sequence data loaded]")
            
        assert len(self.i)==len(self.o),"input seq length mush match output seq length"
        
        self.N = len(self.i)
        print("Length of sequence:\t",self.N)
        
        if build_vocab:
            self.vocab_i = self.build_vocab(self.i)
            self.vocab_o = self.build_vocab(self.o)
            
            self.vocab_i.to_csv(self.vocab_in)
            self.vocab_o.to_csv(self.vocab_out)
            
            self.print_vocab_info()
        else:
            self.vocab_i = pd.read_csv(self.vocab_in).fillna("")
            self.vocab_o = pd.read_csv(self.vocab_out).fillna("")
                  
            self.print_vocab_info()
        
        print("building mapping dicts")
        self.i_char2idx,self.i_idx2char = self.get_mapping(self.vocab_i)
        self.o_char2idx,self.o_idx2char = self.get_mapping(self.vocab_o)
        
    def __len__(self):
        return self.N
    
    def __getitem__(self,idx):
        return self.seq2idx(self.i[idx],self.i_char2idx),self.seq2idx(self.o[idx],self.o_char2idx)
    
    def get_full_token(self,list_of_tokens):
        """
        From a list of list of tokens, to a long list of tokens, duplicate tokens included
        """
        return (" ".join(list_of_tokens)).split(" ")
    
    def get_mapping(self,vocab_df):
        char2idx=dict(zip(vocab_df["token"],vocab_df["idx"]))
        idx2char=dict(zip(vocab_df["idx"],vocab_df["token"]))
        return char2idx,idx2char
    
    def seq2idx(self,x,mapdict):
        return np.vectorize(lambda i:mapdict[i])(x.split(" ")).tolist()
    
    def get_token_count_dict(self,full_token):
        """count the token to a list"""
        return Counter(full_token)
    
    def build_vocab(self,seq_list):
        ct_dict = self.get_token_count_dict(self.get_full_token(seq_list))
        ct_dict["SOS_TOKEN"] = 9e9
        ct_dict["EOS_TOKEN"] = 8e9
        tk,ct = list(ct_dict.keys()),list(ct_dict.values())
        
        token_df=pd.DataFrame({"token":tk,"count":ct}).sort_values(by="count",ascending=False)
        return token_df.reset_index().drop("index",axis=1).reset_index().rename(columns={"index":"idx"}).fillna("")
    
    def print_vocab_info(self):
        self.vocab_size_i = len(self.vocab_i)
        self.vocab_size_o = len(self.vocab_o)
        
        print("[in seq vocab address]: %s,\t%s total lines"%(self.vocab_in,self.vocab_size_i))
        print("[out seq vocab address]: %s,\t%s total lines"%(self.vocab_out,self.vocab_size_o))
            
        print("Input sequence vocab samples:")
        print(self.vocab_i.sample(5))
        print("Output sequence vocab samples:")
        print(self.vocab_o.sample(5))

In [4]:
def load_empty():
    return list(range(5)),list(range(5))

In [5]:
class inf_s2s(s2s_data):
    def __init__(self,vocab_in, vocab_out):
        super(inf_s2s,self).__init__(load_empty, vocab_in, vocab_out, seq_addr="/data/chat/empty.npy", build_seq=True,
                 build_vocab = False,)
        
    def feed_encoder(self,x):
        arr = np.array(self.seq2idx(" ".join(list(str(x))),self.i_char2idx))
        return torch.LongTensor(arr).unsqueeze(0)
        

In [6]:
inf=inf_s2s(vocab_in = "/data/dict/chat_vocab_char_in.csv",
         vocab_out = "/data/dict/chat_vocab_char_out.csv",)

[Loading the sequence data]
[Sequence data loaded]
Length of sequence:	 5
[in seq vocab address]: /data/dict/chat_vocab_char_in.csv,	5747 total lines
[out seq vocab address]: /data/dict/chat_vocab_char_out.csv,	5634 total lines
Input sequence vocab samples:
      Unnamed: 0   idx  count token
955          955   955  278.0     短
5473        5473  5473    1.0     鑼
5175        5175  5175    1.0     淸
2480        2480  2480   26.0     い
1153        1153  1153  200.0     嗨
Output sequence vocab samples:
      Unnamed: 0   idx   count token
340          340   340  2099.0     害
3987        3987  3987     6.0     淄
4682        4682  4682     2.0     囔
3409        3409  3409    13.0     棠
5180        5180  5180     1.0     檢
building mapping dicts


In [7]:
inf.feed_encoder("很高兴认识你")

tensor([[ 104,  214,  610,  154,  192,    2]])

In [8]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, num_layers=n_layers)
        
    def forward(self, input_, hidden):
        output, hidden = self.gru(self.embedding(input_), hidden)
        return output, hidden

    # TODO: other inits
    def initHidden(self, batch_size):
        en_hidden = torch.zeros(1, batch_size, self.hidden_size)
        if CUDA:
            en_hidden = en_hidden.cuda()
        return en_hidden

In [46]:
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, n_layers=1):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, num_layers=n_layers)
        # TODO use transpose of embedding
        self.out = nn.Linear(hidden_size, output_size)
        self.sm = nn.LogSoftmax(dim=-1)
        
    def forward(self, input_, hidden):
        emb = self.embedding(input_).unsqueeze(1)
        # NB: Removed relu
        res, hidden = self.gru(emb, hidden)
        output = self.sm(self.out(res[:,0]))
        return output, hidden
    
    def initInput(self,batch_size):
        decoder_input = torch.LongTensor([SOS_TOKEN]*batch_size)
        if CUDA:
            decoder_input = decoder_input.cuda()
        return decoder_input

In [47]:
encoder = EncoderRNN(inf.vocab_size_i,HIDDEN_SIZE)
decoder = DecoderRNN(HIDDEN_SIZE,inf.vocab_size_o)

In [48]:
import os

def load_s2s(version):
    encoder.load_state_dict(torch.load("/data/weights/enc_%s.pkl"%(version)))
    decoder.load_state_dict(torch.load("/data/weights/dec_%s.pkl"%(version)))

In [94]:
load_s2s(VERSION)

In [99]:
def answer(question):
    encoder_hidden = encoder.initHidden(1)
    decoder_input = decoder.initInput(1)
    print(encoder_hidden.size(),decoder_input.size())
    encoder_output,encoder_hidden = encoder(question,encoder_hidden)

    decoder_hidden = encoder_hidden
    
    i = 0
    output=list()
    last_idx = 9e9
    while (i < MAX_LEN and last_idx != EOS_TOKEN):
        last_idx,decoder_hidden = decoder(decoder_input,decoder_hidden)
        last_idx = torch.max(decoder_input,dim=-1)[1].item()
        output.append(last_idx)
        i += 1
    print(output)
    output_char = "".join(np.vectorize(lambda x:inf.o_idx2char[x])(output).tolist())
    print(output_char)

In [100]:
answer(inf.feed_encoder("你叫什么名字"))

torch.Size([1, 1, 256]) torch.Size([1])
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
SOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKENSOS_TOKEN
