In [1]:
import numpy as np
import torch
from torch.utils import data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence
import pickle

#  Data Handling

load data

In [2]:
ans = np.load("./datasets/m_idx_a_1.npy")
ques = np.load("./datasets/m_idx_q_1.npy")

with open('./datasets/metadata_1.pkl', 'rb') as f:
        metadata = pickle.load(f)

用torch的utils.data建立一來自numpy的dataset

In [3]:
class MaleDataset(data.Dataset): 
    def __init__(self,ques,ans):
        self.ques = ques
        self.ans = ans
    
    def __getitem__(self, index):
        ques_tensor = torch.from_numpy(self.ques[index]).long()
        ans_tensor = torch.from_numpy(self.ans[index]).long()
        
        return ques_tensor , ans_tensor
    
    def __len__(self):
        return 78119

將idx解析成文字

# Define Model

In [4]:
batch_size = 32

male_dataset = MaleDataset(ques,ans)
train_loader = torch.utils.data.DataLoader(dataset=male_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

In [5]:
class Seq2Seq(nn.Module):
    def __init__(self,
                 src_voc_size=9000,
                 trg_voc_size=9000,
                 src_embedding_size=256,
                 trg_embedding_size=256,
                 enc_hidden_size=200,
                 dec_hidden_size=200):
        
        super(Seq2Seq, self).__init__()
        self.trg_embedding_size = trg_embedding_size
        self.dec_hidden_size = dec_hidden_size
        
        self.src_embedder = nn.Embedding(src_voc_size , src_embedding_size)
        self.encoder = nn.LSTM(src_embedding_size ,enc_hidden_size,3, batch_first=True,dropout=0.5)
        
        self.trg_embedder = nn.Embedding(trg_voc_size , trg_embedding_size)
        self.decoder = nn.LSTM(trg_embedding_size ,dec_hidden_size,3, batch_first=True,dropout=0.5)
        self.cls = nn.Linear(dec_hidden_size , trg_voc_size)
    
    def forward(self,source,target,feed_previous=False):
        batch_size = source.size()[0]
        src_em = self.src_embedder(source)
        trg_em = self.trg_embedder(target)
        
        _ , enc_state = self.encoder(src_em)
        
        GO = Variable(torch.zeros(batch_size,1,self.trg_embedding_size)).cuda()
        
        if feed_previous: #test phase
            logits_ = []
            inputs = GO
            h = enc_state
            for i in range(25):
                output , h = self.decoder(inputs,h)
                logits = self.cls(output.view(-1, self.dec_hidden_size))  # (1, vocab_size)
                logits_.append(logits)
                
                predicted = logits.max(1)[1]
                inputs = self.trg_embedder(predicted)
                    
            return torch.cat(logits_,0)
            
        else: #train phase
            dec_in = torch.cat([GO,trg_em[:,:-1,:]],1)
            outputs , _ = self.decoder(dec_in,enc_state)
            outputs = outputs.contiguous().view(-1,self.dec_hidden_size)
            logits = self.cls(outputs)
        
            return logits

In [6]:
model = Seq2Seq().cuda()

In [7]:
model.train()

Seq2Seq (
  (src_embedder): Embedding(9000, 256)
  (encoder): LSTM(256, 200, num_layers=3, batch_first=True, dropout=0.5)
  (trg_embedder): Embedding(9000, 256)
  (decoder): LSTM(256, 200, num_layers=3, batch_first=True, dropout=0.5)
  (cls): Linear (200 -> 9000)
)

# Training

In [8]:
train_op = optim.Adam(model.parameters() ,lr=3e-4)

In [9]:
epochs = 200
loss_hist = []
loss_ = 3
model.train()
for epoch in range(epochs):
    epoch_mean_loss = []

    for i , (q,a) in enumerate(train_loader):
        q = Variable(q).cuda()
        a = Variable(a).cuda()
   
        logits = model(q,a,feed_previous=False)
        _,predict = logits.max(1)
        
        loss = F.cross_entropy(logits ,a.view(-1))
        train_op.zero_grad()
        loss.backward()
        train_op.step()
        
        epoch_mean_loss.append(loss.data[0])
    
    loss_ = np.mean(epoch_mean_loss)
    loss_hist.append(loss_)
    if epoch % 10 == 0  or epoch == epochs-1:
        print "epoch:%s , loss:%s" % (epoch , loss_ )
    if epoch % 50 == 0 or epoch == epochs-1:
        torch.save(model.state_dict() , 'pth/model_male_epo%s.pth'%epoch) #save model
        
np.save('loss_male_epo%s.npy'%epochs,loss_hist)

epoch:0 , loss:2.1570013063
epoch:10 , loss:1.65180309672
epoch:20 , loss:1.54403085176
epoch:30 , loss:1.46688195653
epoch:40 , loss:1.40535816604
epoch:50 , loss:1.35423740828
epoch:60 , loss:1.31112265833
epoch:70 , loss:1.27307238745
epoch:80 , loss:1.2397340821
epoch:90 , loss:1.2094951676
epoch:100 , loss:1.18277289263
epoch:110 , loss:1.15926431267
epoch:120 , loss:1.13700090819
epoch:130 , loss:1.11663240521
epoch:140 , loss:1.09736820416
epoch:150 , loss:1.08017611609
epoch:160 , loss:1.06460089849
epoch:170 , loss:1.04964912494
epoch:180 , loss:1.03527451408
epoch:190 , loss:1.02238349497
epoch:199 , loss:1.01203030824


In [10]:
_,predict = logits.max(1)

# Print Result

In [11]:
class Vocab:
    def __init__(self,idx2word,word2idx):
        self.idx2word = idx2word
        self.word2idx = word2idx
        self.max_len = 25
        self.eos_idx = 8002
        self.EN_WHITELIST  = '0123456789abcdefghijklmnopqrstuvwxyz '             
            
    '''
    idx -> word with EOS
    '''        
    def decode_line(self,sentence_idx,remove_pad=True,remove_eos=True):  #sentence_idx: 1d_matrix     
        sentence = []
        for w in sentence_idx:
            if remove_eos and w==self.eos_idx:
                continue
            if remove_pad and w==0 : 
                continue
            sentence.append(self.idx2word[w])
            #if w==self.eos_idx:
            #    break
        sentence = ' '.join(sentence)
        return sentence
    
    def decode(self,sentence_idxs,remove_pad=True,remove_eos=True): #sentence_idxs: 2d_matrix 
        sentences = []
        for s in sentence_idxs: 
            sentences.append(self.decode_line(s,
                                              remove_pad=remove_pad,
                                              remove_eos=remove_eos))
        return sentences
            
    '''
    word -> idx with EOS
    '''
    def encode_line(self,sentence):  #sentence: 1d_matrix
        sentence = sentence.lower()
        s_list = ''.join([ ch for ch in sentence if ch in self.EN_WHITELIST ]).split()
        sentence_idx = []
        for w in s_list:
            sentence_idx.append(self.word2idx[w])
        n = len(sentence_idx)
        if  n > self.max_len:
            sentence_idx = sentence_idx[:self.max_len] 
        elif n < self.max_len:
            sentence_idx = sentence_idx + [self.eos_idx] + [0]*(self.max_len-n-1)  
        return sentence_idx
    
    def encode(self,sentences): #sentences: 2d_matrix   
        sentence_idxs = []
        for s in sentences: 
            sentence_idxs.append(self.encode_line(s))
        return np.array(sentence_idxs)
    
    def print_QA(self, ques , pred_ans, strd_ans):
        n = len(ques)
        for i in range(n):
            idxs = [ ques[i],  pred_ans[i] , strd_ans[i]]
            sents = vocab.decode(idxs)
            print('\nQ      :'+sents[0])  
            print('A      :'+sents[2])
            print('pred A :'+sents[1]) 
            
    def print_QA_1(self, ques , pred_ans_train, pred_ans_test, strd_ans):
        n = len(ques)
        for i in range(n):
            idxs = [ ques[i],  pred_ans_train[i], pred_ans_test[i] , strd_ans[i]]
            sents = vocab.decode(idxs)
            print('\nQ      :'+sents[0])  
            print('A      :'+sents[3])
            print('train A:'+sents[1])    
            print('test A :'+sents[2]) 
            
    def print_QA_2(self, ques , ans):
        n = len(ques)
        for i in range(n):
            idxs = [ ques[i], ans[i]]
            sents = vocab.decode(idxs)
            print('\nQ      :'+sents[0])  
            print('A      :'+sents[1])


In [12]:
predict.size()

torch.Size([175, 1])

In [16]:
n = 175/25
n

7

In [17]:
vocab = Vocab(metadata['idx2w'] , metadata['w2idx'])

### Try train corpus

In [18]:
pred_ans = predict.cpu().view(-1,n).data.numpy().T #predicted answer in train phase
strd_ans = a.cpu().view(-1,n).data.numpy().T #standard answer
ques     = q.cpu().view(-1,n).data.numpy().T #quenstions
vocab.print_QA(ques , pred_ans, strd_ans)


Q      :why suggs as doctor i know
A      :unk i another over i in but he you and eighteen
pred A :hey i unk milk i in but he unk

Q      :are what they was he
A      :get cant theres one could a i did had my party
pred A :unk really its one calvin a i was did navy of

Q      :you is can mrs joe request
A      :out million a unk get shock arthur had and to fucking
pred A :out tell a charge say hey arthur dont something corner

Q      :trouble it use unk planning shirt get
A      :of you car us professor ok enough i at this watching
pred A :of out unk to professor to he go wheres following

Q      :this what you ready unk in
A      :cannon you out empty the of be reason told least is and
pred A :here you up unk with of ill unk didnt the is

Q      :out river as and i there
A      :know us next i this with to you rick my im
pred A :know me unk i her gonna to you rick my

Q      :on certainly long waiting when didnt
A      :baby probably street pay thing you yeah believe so him vacation 

### Try test corpus

In [19]:
model.eval()
o = model(q,a,feed_previous=True) #logits
_,predict_test = o.max(1)
#vocab.decode(predict_test.cpu().view(-1,10).data.numpy().T,remove_eos=False,remove_pad=False)
pred_ans_test = predict_test.cpu().view(-1,n).data.numpy().T #predicted answer in test phase
vocab.print_QA_1(ques , pred_ans, pred_ans_test, strd_ans)


Q      :why suggs as doctor i know
A      :unk i another over i in but he you and eighteen
train A:hey i unk milk i in but he unk
test A :unk unk

Q      :are what they was he
A      :get cant theres one could a i did had my party
train A:unk really its one calvin a i was did navy of
test A :you know i dont know what youre talking about

Q      :you is can mrs joe request
A      :out million a unk get shock arthur had and to fucking
train A:out tell a charge say hey arthur dont something corner
test A :its a unk

Q      :trouble it use unk planning shirt get
A      :of you car us professor ok enough i at this watching
train A:of out unk to professor to he go wheres following
test A :oh i message i was just asking exchange

Q      :this what you ready unk in
A      :cannon you out empty the of be reason told least is and
train A:here you up unk with of ill unk didnt the is
test A :hey okay

Q      :out river as and i there
A      :know us next i this with to you rick my im
train A:know

# Try Chatting

In [20]:
lines = []
lines.append( 'you can do it'  )
lines.append( 'how are you'    )
lines.append( 'fuck you'  )
lines.append( 'jesus christ you scared the shit out of me'  )
lines.append( 'youre terrible'  )
lines.append( 'is something wrong' )
lines.append( 'nobodys gonna get inside' )
lines.append( 'im sorry'  )
lines.append( 'shut up'  )
N = len(lines)
lines = vocab.encode(lines)
q_o = Variable(torch.from_numpy(lines).long()).cuda()
#vocab.decode(vocab.encode(lines))

In [21]:
model.eval()
o = model(q_o,a[:N],feed_previous=True)
_,predict_o = o.max(1)
#vocab.decode(predict_o.cpu().view(-1,3).data.numpy().T)
pred_ans_o = predict_o.cpu().view(-1,N).data.numpy().T #predicted answer 
vocab.print_QA_2(lines, pred_ans_o)


Q      :you can do it
A      :its a good idea

Q      :how are you
A      :fine again you

Q      :fuck you
A      :i love you

Q      :jesus christ you scared the shit out of me
A      :you know what place of suit are graff

Q      :youre terrible
A      :im not control about the unk

Q      :is something wrong
A      :what is it

Q      :nobodys gonna get inside
A      :sure sir

Q      :im sorry
A      :what do you mean

Q      :shut up
A      :whos the dallas


# Save

In [None]:
epoch

In [None]:
#np.save('loss_female_prob_feed_epo141.npy',epoch_mean_loss)

In [None]:
#torch.save(model.state_dict() , 'model_female_prob_feed_epo141.pth') #save model

# Back up

In [None]:
a

In [None]:
x = Variable(torch.rand(3,25)*200).long()
x

In [None]:
_,p = model(x,q,feed_previous=True).max(1)

In [None]:
vocab.decode(p.view(3,25).data.numpy())