In [1]:
import numpy as np
import torch
from torch.utils import data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence
import pickle

In [2]:
class Seq2Seq(nn.Module):
    """LSTM encoder-decoder that maps a source token sequence to target logits.

    The encoder's final (h, c) state is used as the decoder's initial state,
    so ``enc_hidden_size`` must equal ``dec_hidden_size`` (both LSTMs have
    3 layers).

    Args:
        src_voc_size: number of rows in the source embedding table.
        trg_voc_size: number of rows in the target embedding table and
            width of the output logits.
        src_embedding_size: source embedding dimension.
        trg_embedding_size: target embedding dimension.
        enc_hidden_size: encoder LSTM hidden size.
        dec_hidden_size: decoder LSTM hidden size.
    """

    def __init__(self,
                 src_voc_size=9000,
                 trg_voc_size=9000,
                 src_embedding_size=256,
                 trg_embedding_size=256,
                 enc_hidden_size=200,
                 dec_hidden_size=200):

        super(Seq2Seq, self).__init__()
        self.trg_embedding_size = trg_embedding_size
        self.dec_hidden_size = dec_hidden_size

        self.src_embedder = nn.Embedding(src_voc_size, src_embedding_size)
        self.encoder = nn.LSTM(src_embedding_size, enc_hidden_size, 3,
                               batch_first=True, dropout=0.5)

        self.trg_embedder = nn.Embedding(trg_voc_size, trg_embedding_size)
        self.decoder = nn.LSTM(trg_embedding_size, dec_hidden_size, 3,
                               batch_first=True, dropout=0.5)
        self.cls = nn.Linear(dec_hidden_size, trg_voc_size)

    def forward(self, source, target, feed_previous=False, max_len=25):
        """Compute target-vocabulary logits.

        Args:
            source: LongTensor of source token ids, shape (batch, src_len).
            target: LongTensor of target token ids, shape (batch, trg_len).
                Only read when ``feed_previous`` is False.
            feed_previous: if True, decode greedily by feeding each step's
                argmax back in as the next input (inference); otherwise use
                teacher forcing on ``target`` (training).
            max_len: number of greedy decoding steps (must be >= 1); ignored
                when ``feed_previous`` is False.

        Returns:
            feed_previous=True: logits of shape (max_len * batch, trg_voc_size),
            laid out step-major (all batch rows of step 0, then step 1, ...).
            feed_previous=False: logits of shape (batch * trg_len, trg_voc_size),
            laid out batch-major.
        """
        batch_size = source.size()[0]
        src_em = self.src_embedder(source)

        _, enc_state = self.encoder(src_em)

        # An all-zero embedding stands in for the start-of-sequence token.
        GO = Variable(torch.zeros(batch_size, 1, self.trg_embedding_size))

        if feed_previous:  # test phase
            logits_ = []
            inputs = GO
            h = enc_state
            for _step in range(max_len):
                output, h = self.decoder(inputs, h)
                # (batch, vocab_size) for this single time step
                logits = self.cls(output.contiguous().view(-1, self.dec_hidden_size))
                logits_.append(logits)

                # max() keeps the reduced dim on old PyTorch and drops it on
                # newer releases; the explicit view gives (batch, 1) either
                # way so the embedded input keeps its time axis.
                predicted = logits.max(1)[1].view(batch_size, 1)
                inputs = self.trg_embedder(predicted)

            return torch.cat(logits_, 0)

        # train phase: shift the target right by one step behind GO
        trg_em = self.trg_embedder(target)
        dec_in = torch.cat([GO, trg_em[:, :-1, :]], 1)
        outputs, _ = self.decoder(dec_in, enc_state)
        outputs = outputs.contiguous().view(-1, self.dec_hidden_size)
        return self.cls(outputs)

In [3]:
# Two independent networks with the default architecture; their weights are
# expected to be filled in from checkpoints afterwards.
female, male = Seq2Seq(), Seq2Seq()

In [4]:
# The callable map_location keeps every storage on the CPU, so the checkpoints
# load on a CPU-only machine regardless of the device they were saved from.
female.load_state_dict(
    torch.load('model_female_cpu.pth', map_location=lambda storage, loc: storage))
male.load_state_dict(
    torch.load('model_male_cpu.pth', map_location=lambda storage, loc: storage))

In [12]:
# Third model with the same architecture, restored from its own checkpoint.
female_sche = Seq2Seq()
# The callable map_location keeps every storage on the CPU, so the checkpoint
# loads on a CPU-only machine regardless of the device it was saved from.
female_sche.load_state_dict(
    torch.load('model_female_sche_samplling_cpu.pth',
               map_location=lambda storage, loc: storage))

In [5]:
class Vocab:
    """Converts between sentences and fixed-length sequences of word indices.

    Args:
        idx2word: index -> word lookup (indexed with integer word ids).
        word2idx: word -> index mapping.
        max_len: length every encoded sentence is truncated/padded to.
        eos_idx: index appended after the last word of a short sentence.

    Index 0 is used as padding.
    """

    def __init__(self, idx2word, word2idx, max_len=25, eos_idx=8002):
        self.idx2word = idx2word
        self.word2idx = word2idx
        self.max_len = max_len
        self.eos_idx = eos_idx
        # Characters kept by encode_line; everything else is dropped.
        self.EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz '

    def decode_line(self, sentence_idx, remove_pad=True, remove_eos=True):
        """Turn a 1-D sequence of word indices into a space-joined sentence.

        EOS indices (when ``remove_eos``) and padding zeros (when
        ``remove_pad``) are skipped, not treated as a stop signal: any word
        appearing after them is still emitted.
        """
        sentence = []
        for w in sentence_idx:
            if remove_eos and w == self.eos_idx:
                continue
            if remove_pad and w == 0:
                continue
            sentence.append(self.idx2word[w])
        return ' '.join(sentence)

    def decode(self, sentence_idxs, remove_pad=True, remove_eos=True):
        """Decode a 2-D collection of index rows into a list of sentences."""
        return [self.decode_line(s, remove_pad=remove_pad, remove_eos=remove_eos)
                for s in sentence_idxs]

    def encode_line(self, sentence):
        """Turn a sentence into a list of at most ``max_len`` word indices.

        The sentence is lower-cased, stripped of non-whitelisted characters
        and split on whitespace. Sentences shorter than ``max_len`` get an
        EOS index followed by zero padding up to ``max_len``; longer ones are
        truncated (no EOS); sentences of exactly ``max_len`` words are
        returned as is.

        Words missing from ``word2idx`` map to the index of ``'unk'`` when the
        vocabulary has such an entry (assumed to be the unknown-word token --
        confirm against the metadata); otherwise the lookup raises KeyError
        as before.
        """
        sentence = sentence.lower()
        words = ''.join([ch for ch in sentence if ch in self.EN_WHITELIST]).split()

        unk_idx = self.word2idx.get('unk')
        sentence_idx = []
        for w in words:
            if w not in self.word2idx and unk_idx is not None:
                sentence_idx.append(unk_idx)
            else:
                sentence_idx.append(self.word2idx[w])

        n = len(sentence_idx)
        if n > self.max_len:
            sentence_idx = sentence_idx[:self.max_len]
        elif n < self.max_len:
            sentence_idx = sentence_idx + [self.eos_idx] + [0] * (self.max_len - n - 1)
        return sentence_idx

    def encode(self, sentences):
        """Encode an iterable of sentences into a 2-D numpy array of indices."""
        return np.array([self.encode_line(s) for s in sentences])

    def print_QA(self, ques, pred_ans, strd_ans):
        """Print each question with its reference answer and predicted answer."""
        n = len(ques)
        for i in range(n):
            idxs = [ques[i], pred_ans[i], strd_ans[i]]
            # Use this instance rather than a module-level ``vocab`` global.
            sents = self.decode(idxs, remove_eos=True, remove_pad=True)
            print('\nQ      :'+sents[0])
            print('A      :'+sents[2])
            print('pred A :'+sents[1])

    def print_QA_1(self, ques, pred_ans_train, pred_ans_test, strd_ans):
        """Print each question with its reference and train/test-mode answers."""
        n = len(ques)
        for i in range(n):
            idxs = [ques[i], pred_ans_train[i], pred_ans_test[i], strd_ans[i]]
            sents = self.decode(idxs, remove_eos=True, remove_pad=True)
            print('\nQ      :'+sents[0])
            print('A      :'+sents[3])
            print('train A:'+sents[1])
            print('test A :'+sents[2])

    def print_QA_2(self, ques, ans):
        """Print each question followed by a single answer."""
        n = len(ques)
        for i in range(n):
            idxs = [ques[i], ans[i]]
            sents = self.decode(idxs)
            print('\nQ      :'+sents[0])
            print('A      :'+sents[1])

In [6]:
# NOTE: pickle.load can execute arbitrary code; only open a metadata file
# that comes from a trusted source.
with open('./metadata_1.pkl', 'rb') as f:
    metadata = pickle.load(f)

# The pickle stores the two lookup tables under the keys 'idx2w' and 'w2idx'.
vocab = Vocab(idx2word=metadata['idx2w'], word2idx=metadata['w2idx'])

In [7]:
# Sample questions used to compare the models below.
lines = [
    'you can do it',
    'how are you',
    'fuck you',
    'jesus christ you scared the shit out of me',
    'youre terrible',
    'is something wrong',
    'nobodys gonna get inside',
    'im sorry',
    'shut up',
]
N = len(lines)  # batch size, needed later to un-flatten the model output

# From here on `lines` holds the encoded index matrix, not the raw strings.
lines = vocab.encode(lines)
q_o = Variable(torch.from_numpy(lines).long())

In [9]:
# Greedy decoding with the `female` model.
female.eval()
# The target argument is a placeholder here: with feed_previous=True the
# decoder consumes its own predictions.
o = female(q_o, q_o, feed_previous=True)
predict_o = o.max(1)[1]  # argmax over the vocabulary axis
# The logits are stacked step by step, so view(-1, N) is (steps, batch);
# transpose to get one row of word indices per question.
pred_ans_o = predict_o.view(-1, N).data.numpy().T
vocab.print_QA_2(lines, pred_ans_o)


Q      :you can do it
A      :no i dont need maybe money i need to see you and i dont know why youre talking about the body

Q      :how are you
A      :im okay

Q      :fuck you
A      :fuck me fuck you

Q      :jesus christ you scared the shit out of me
A      :what things what do you mean

Q      :youre terrible
A      :i was just bunch about a word

Q      :is something wrong
A      :i want to be alone

Q      :nobodys gonna get inside
A      :thats beautiful

Q      :im sorry
A      :i thought you were saying you cant be a good mommy

Q      :shut up
A      :get out


In [10]:
# Greedy decoding with the `male` model.
male.eval()
# The target argument is a placeholder here: with feed_previous=True the
# decoder consumes its own predictions.
o = male(q_o, q_o, feed_previous=True)
predict_o = o.max(1)[1]  # argmax over the vocabulary axis
# The logits are stacked step by step, so view(-1, N) is (steps, batch);
# transpose to get one row of word indices per question.
pred_ans_o = predict_o.view(-1, N).data.numpy().T
vocab.print_QA_2(lines, pred_ans_o)


Q      :you can do it
A      :its a good idea

Q      :how are you
A      :fine again you

Q      :fuck you
A      :i love you

Q      :jesus christ you scared the shit out of me
A      :you know what place of suit are graff

Q      :youre terrible
A      :im not control about the unk

Q      :is something wrong
A      :what is it

Q      :nobodys gonna get inside
A      :sure sir

Q      :im sorry
A      :what do you mean

Q      :shut up
A      :whos the dallas


In [13]:
# Greedy decoding with the `female_sche` model.
female_sche.eval()
# The target argument is a placeholder here: with feed_previous=True the
# decoder consumes its own predictions.
o = female_sche(q_o, q_o, feed_previous=True)
predict_o = o.max(1)[1]  # argmax over the vocabulary axis
# The logits are stacked step by step, so view(-1, N) is (steps, batch);
# transpose to get one row of word indices per question.
pred_ans_o = predict_o.view(-1, N).data.numpy().T
vocab.print_QA_2(lines, pred_ans_o)


Q      :you can do it
A      :yeah i know

Q      :how are you
A      :im fine

Q      :fuck you
A      :i cant help it michigan is a here

Q      :jesus christ you scared the shit out of me
A      :well its a nice kids thats a

Q      :youre terrible
A      :i know but i thought to you it i i you you you you you

Q      :is something wrong
A      :i know

Q      :nobodys gonna get inside
A      :i was unk to unk unk tries and meet her dumb

Q      :im sorry
A      :i know

Q      :shut up
A      :i have to find the and i didnt want to to be unk


In [40]:
def chat():
    """Interactive question/answer loop on stdin.

    Each input line has the form ``question`` or ``question : target`` where
    target is ``m`` (male model), ``f`` (female model) or anything else
    (female_sche model). When the line contains several colons, the text
    after the last one is the target. An empty question ends the loop.
    """
    # raw_input only exists on Python 2; input is its Python 3 equivalent.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    while True:
        # input_box
        r = read_line('Question:To whom{m,f,_}')
        if ':' in r:
            # rpartition tolerates colons inside the question itself.
            x, _sep, m = r.rpartition(':')
        else:
            x, m = r, ''
        x, m = x.strip(), m.strip()
        if x == '':
            break

        # decide model
        if m == 'm':
            model = male
        elif m == 'f':
            model = female
        else:
            model = female_sche
        # Make sure dropout is off even if the model was never switched
        # to eval mode before this call.
        model.eval()

        # print answer
        try:
            lines = vocab.encode([x])
        except KeyError as err:
            # A word missing from the vocabulary should not end the session.
            print('unknown word: %s' % err)
            continue
        q_o = Variable(torch.from_numpy(lines).long())
        o = model(q_o, q_o, feed_previous=True)
        predict_o = o.max(1)[1]
        # Batch size is 1: (steps, 1) transposed to a single row of indices.
        pred_ans_o = predict_o.view(-1, 1).data.numpy().T  # predicted answer
        vocab.print_QA_2(lines, pred_ans_o)
        print('')

    print('...end of conversation')

In [57]:
def two_gender_chatting(characters=('m', 'f'), rounds=3):
    """Let the models talk to each other, seeded by one line from stdin.

    Args:
        characters: sequence of speaker codes taken in order each round;
            ``'m'`` selects the male model, ``'f'`` the female model and any
            other value the female_sche model.
        rounds: how many times the full list of characters speaks.

    Each speaker answers the previous speaker's decoded reply (the first one
    answers the typed line) and the reply is printed with a speaker label.
    """
    # raw_input only exists on Python 2; input is its Python 3 equivalent.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    # input_box
    x = read_line('Srart conversation: ')
    print('')

    for _round in range(rounds):
        for ch in characters:
            # decide model and its fixed-width display label
            if ch == 'm':
                model, label = male, 'm     '
            elif ch == 'f':
                model, label = female, 'f     '
            else:
                model, label = female_sche, 'f_sche'
            # Make sure dropout is off even if the model was never switched
            # to eval mode before this call.
            model.eval()

            lines = vocab.encode([x])
            q_o = Variable(torch.from_numpy(lines).long())
            o = model(q_o, q_o, feed_previous=True)
            predict_o = o.max(1)[1]
            # Batch size is 1: (steps, 1) transposed to one row of indices.
            pred_ans_o = predict_o.view(-1, 1).data.numpy().T  # predicted answer
            x = vocab.decode(pred_ans_o)[0]
            print(label+' :'+x)

    print('\n...end of conversation')

In [58]:
# Start the interactive loop: type "question : m" or "question : f" to pick a
# model (anything else uses female_sche); an empty question ends the session.
chat()

Question:To whom{m,f,_}hi : m

Q      :hi
A      :hey hey minor you look great

Question:To whom{m,f,_}hi : f

Q      :hi
A      :thats good

Question:To whom{m,f,_}hi :

Q      :hi
A      :hi hi

Question:To whom{m,f,_}
...end of conversation


In [59]:
# 'f_s' is neither 'm' nor 'f', so it selects the fallback model (female_sche);
# the conversation therefore alternates female_sche and female.
two_gender_chatting(['f_s','f'])

Srart conversation: hi

f_sche :hi hi
f      :thats time
f_sche :oh yeah
f      :he didnt have anyone
f_sche :no but he said hes unk
f      :oh my god

...end of conversation
