schedul sampling

In [1]:
import numpy as np
import torch
from torch.utils import data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence
import pickle
import random

#  Data Handling

load data

In [2]:
ans = np.load("./datasets/f_idx_a_1.npy")
ques = np.load("./datasets/f_idx_q_1.npy")

with open('./datasets/metadata_1.pkl', 'rb') as f:
        metadata = pickle.load(f)

用torch的utils.data建立一來自numpy的dataset

In [3]:
class FemaleDataset(data.Dataset): 
    def __init__(self,ques,ans):
        self.ques = ques
        self.ans = ans
    
    def __getitem__(self, index):
        ques_tensor = torch.from_numpy(self.ques[index]).long()
        ans_tensor = torch.from_numpy(self.ans[index]).long()
        
        return ques_tensor , ans_tensor
    
    def __len__(self):
        return 33589

將idx解析成文字

# Define Model

In [4]:
batch_size = 32

female_dataset = FemaleDataset(ques,ans)
train_loader = torch.utils.data.DataLoader(dataset=female_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

In [5]:
class Seq2Seq(nn.Module):
    def __init__(self,
                 src_voc_size=9000,
                 trg_voc_size=9000,
                 src_embedding_size=256,
                 trg_embedding_size=256,
                 enc_hidden_size=200,
                 dec_hidden_size=200):
        
        super(Seq2Seq, self).__init__()
        self.trg_embedding_size = trg_embedding_size
        self.dec_hidden_size = dec_hidden_size
        
        self.src_embedder = nn.Embedding(src_voc_size , src_embedding_size)
        self.encoder = nn.LSTM(src_embedding_size ,enc_hidden_size,3, batch_first=True,dropout=0.5)
        
        self.trg_embedder = nn.Embedding(trg_voc_size , trg_embedding_size)
        self.decoder = nn.LSTM(trg_embedding_size ,dec_hidden_size,3, batch_first=True,dropout=0.5)
        self.cls = nn.Linear(dec_hidden_size , trg_voc_size)
    
    def forward(self,source,target,feed_previous=False):
        batch_size = source.size()[0]
        src_em = self.src_embedder(source)
        trg_em = self.trg_embedder(target)
        
        _ , enc_state = self.encoder(src_em)
        
        GO = Variable(torch.zeros(batch_size,1,self.trg_embedding_size)).cuda()
        
        if feed_previous: #test phase
            logits_ = []
            inputs = GO
            h = enc_state
            for i in range(25):
                output , h = self.decoder(inputs,h)
                logits = self.cls(output.view(-1, self.dec_hidden_size))  # (1, vocab_size)
                logits_.append(logits)
                
                predicted = logits.max(1)[1]
                inputs = self.trg_embedder(predicted)
                    
            return torch.cat(logits_,0)
            
        else: #train phase -- schedule sampling
            '''
            dec_in = torch.cat([GO,trg_em[:,:-1,:]],1)  # trg_em.shape=(batch_size, time step, trg_embedding_size )
            outputs , _ = self.decoder(dec_in,enc_state)
            outputs = outputs.contiguous().view(-1,self.dec_hidden_size)
            logits = self.cls(outputs)        
            return logits  
            '''
            logits_ = []
            dec_in = torch.cat([GO,trg_em[:,:-1,:]],1)
            
            h = enc_state
            for i in range(25):
                inputs = torch.unsqueeze(dec_in[:,i,:],1)
                
                if i < 5:
                    inputs =  torch.unsqueeze(dec_in[:,i,:],1)
                else:
                    if random.random() < 0.5 : 
                        inputs =  torch.unsqueeze(dec_in[:,i,:],1)  #usual training policy 
                    else:
                        inputs = self.trg_embedder(predicted)  #schedule sampling
                
                output , h = self.decoder(inputs,h)
                logits = self.cls(output.view(-1, self.dec_hidden_size))  # (1, vocab_size)
                logits_.append(logits)
                predicted = logits.max(1)[1]  # predicted.shape=(batch_size, time step=1)
                
            return torch.cat(logits_,0)

In [6]:
model = Seq2Seq().cuda()

In [7]:
model.train()

Seq2Seq (
  (src_embedder): Embedding(9000, 256)
  (encoder): LSTM(256, 200, num_layers=3, batch_first=True, dropout=0.5)
  (trg_embedder): Embedding(9000, 256)
  (decoder): LSTM(256, 200, num_layers=3, batch_first=True, dropout=0.5)
  (cls): Linear (200 -> 9000)
)

# Training

In [8]:
train_op = optim.Adam(model.parameters() ,lr=3e-4)

In [18]:
epochs = 200
loss_hist = []
loss_ = 3
model.train()
for epoch in range(epochs):
    epoch_mean_loss = []

    for i , (q,a) in enumerate(train_loader):
        q = Variable(q).cuda()
        a = Variable(a).cuda()
   
        logits = model(q,a,feed_previous=False)
        _,predict = logits.max(1)
        
        loss = F.cross_entropy(logits ,torch.transpose(a,0,1).contiguous().view(-1))
        train_op.zero_grad()
        loss.backward()
        train_op.step()
        
        epoch_mean_loss.append(loss.data[0])
    
    loss_ = np.mean(epoch_mean_loss)
    loss_hist.append(loss_)
    if epoch % 10 == 0  or epoch == epochs-1:
        print "epoch:%s , loss:%s" % (epoch , loss_ )
    if epoch % 50 == 0 or epoch == epochs-1:
        torch.save(model.state_dict() , 'pth/model_female_sche_sampling_epo%s.pth'%epoch) #save model
        
np.save('loss_female_sche_sampling_epo%s.npy'%epochs,loss_hist)

epoch:0 , loss:2.13266061488
epoch:10 , loss:1.89427414065
epoch:20 , loss:1.78954414686
epoch:30 , loss:1.70985880046
epoch:40 , loss:1.63794595003
epoch:50 , loss:1.58399377034
epoch:60 , loss:1.53486273612
epoch:70 , loss:1.48763798061
epoch:80 , loss:1.44421974727
epoch:90 , loss:1.40060243976
epoch:100 , loss:1.361112371
epoch:110 , loss:1.3247071609
epoch:120 , loss:1.28968011345
epoch:130 , loss:1.25642937155
epoch:140 , loss:1.22447708692
epoch:150 , loss:1.19096727683
epoch:160 , loss:1.16370675944


KeyboardInterrupt: 

In [19]:
_,predict = logits.max(1)

In [20]:
class Vocab:
    def __init__(self,idx2word,word2idx):
        self.idx2word = idx2word
        self.word2idx = word2idx
        self.max_len = 25
        self.eos_idx = 8002
        self.EN_WHITELIST  = '0123456789abcdefghijklmnopqrstuvwxyz '             
            
    '''
    idx -> word with EOS
    '''        
    def decode_line(self,sentence_idx,remove_pad=True,remove_eos=True):  #sentence_idx: 1d_matrix     
        sentence = []
        for w in sentence_idx:
            if remove_eos and w==self.eos_idx:
                continue
            if remove_pad and w==0 : 
                continue
            sentence.append(self.idx2word[w])
            #if w==self.eos_idx:
            #    break
        sentence = ' '.join(sentence)
        return sentence
    
    def decode(self,sentence_idxs,remove_pad=True,remove_eos=True): #sentence_idxs: 2d_matrix 
        sentences = []
        for s in sentence_idxs: 
            sentences.append(self.decode_line(s,
                                              remove_pad=remove_pad,
                                              remove_eos=remove_eos))
        return sentences
            
    '''
    word -> idx with EOS
    '''
    def encode_line(self,sentence):  #sentence: 1d_matrix
        sentence = sentence.lower()
        s_list = ''.join([ ch for ch in sentence if ch in self.EN_WHITELIST ]).split()
        sentence_idx = []
        for w in s_list:
            sentence_idx.append(self.word2idx[w])
        n = len(sentence_idx)
        if  n > self.max_len:
            sentence_idx = sentence_idx[:self.max_len] 
        elif n < self.max_len:
            sentence_idx = sentence_idx + [self.eos_idx] + [0]*(self.max_len-n-1)  
        return sentence_idx
    
    def encode(self,sentences): #sentences: 2d_matrix   
        sentence_idxs = []
        for s in sentences: 
            sentence_idxs.append(self.encode_line(s))
        return np.array(sentence_idxs)
    
    def print_QA(self, ques , pred_ans, strd_ans):
        n = len(ques)
        for i in range(n):
            idxs = [ ques[i],  pred_ans[i] , strd_ans[i]]
            sents = vocab.decode(idxs,remove_eos=True,remove_pad=True)
            print('\nQ      :'+sents[0])  
            print('A      :'+sents[2])
            print('pred A :'+sents[1]) 
            
    def print_QA_1(self, ques , pred_ans_train, pred_ans_test, strd_ans):
        n = len(ques)
        for i in range(n):
            idxs = [ ques[i],  pred_ans_train[i], pred_ans_test[i] , strd_ans[i]]
            sents = vocab.decode(idxs,remove_eos=True,remove_pad=True)
            print('\nQ      :'+sents[0])  
            print('A      :'+sents[3])
            print('train A:'+sents[1])    
            print('test A :'+sents[2]) 
            
    def print_QA_2(self, ques , ans):
        n = len(ques)
        for i in range(n):
            idxs = [ ques[i], ans[i]]
            sents = vocab.decode(idxs)
            print('\nQ      :'+sents[0])  
            print('A      :'+sents[1])


In [21]:
n = 800/25
n

32

In [22]:
vocab = Vocab(metadata['idx2w'] , metadata['w2idx'])

### Try train corpus

In [23]:
pred_ans = predict.cpu().view(-1,n).data.numpy().T #predicted answer in train phase
strd_ans = a.cpu().view(-1,n).data.numpy().T #standard answer
ques     = q.cpu().view(-1,n).data.numpy().T #quenstions
vocab.print_QA(ques , pred_ans, strd_ans)


Q      :hey thought i to dont something my it
A      :i do even of live what crazy it theyre take
pred A :i really not

Q      :fool it care you hate you cotton has
A      :together is want it are in i does unk off
pred A :i said to halloween i i kronos was the awful i i i i i unk i was a unk to you

Q      :this was about should you want well told a
A      :it jack it you that know a unk call
pred A :thanks i

Q      :come sense that to okay me unk its
A      :sort it sure house the shut no me
pred A :schools moneys unk woman isnt a neither

Q      :from it i do i i all him
A      :other why would hes i hang honey if
pred A :i dont not being disrespect woman

Q      :did pretty about would would right
A      :eat should look dead need seen free its you
pred A :you you i

Q      :not bound it like meet i
A      :and i stupid to a right i the need
pred A :no i i

Q      :the with you somebody have what
A      :theyve trust on know unk and was police anything the
pred A :in keeping the 

### Try test corpus

In [24]:
model.eval()
o = model(q,a,feed_previous=True) #logits
_,predict_test = o.max(1)
#vocab.decode(predict_test.cpu().view(-1,10).data.numpy().T,remove_eos=False,remove_pad=False)
pred_ans_test = predict_test.cpu().view(-1,n).data.numpy().T #predicted answer in test phase
vocab.print_QA_1(ques , pred_ans, pred_ans_test, strd_ans)


Q      :hey thought i to dont something my it
A      :i do even of live what crazy it theyre take
train A:i really not
test A :i know what you mean

Q      :fool it care you hate you cotton has
A      :together is want it are in i does unk off
train A:i said to halloween i i kronos was the awful i i i i i unk i was a unk to you
test A :why do you want to to do me

Q      :this was about should you want well told a
A      :it jack it you that know a unk call
train A:thanks i
test A :i dont know

Q      :come sense that to okay me unk its
A      :sort it sure house the shut no me
train A:schools moneys unk woman isnt a neither
test A :a precious cheap

Q      :from it i do i i all him
A      :other why would hes i hang honey if
train A:i dont not being disrespect woman
test A :i dont know

Q      :did pretty about would would right
A      :eat should look dead need seen free its you
train A:you you i
test A :no you dont want to make poets credit unk unk you you you

Q      :not bound it

# Try Chatting

In [25]:
lines = []
lines.append( 'you can do it'  )
lines.append( 'how are you'    )
lines.append( 'fuck you'  )
lines.append( 'jesus christ you scared the shit out of me'  )
lines.append( 'youre terrible'  )
lines.append( 'is something wrong' )
lines.append( 'nobodys gonna get inside' )
lines.append( 'im sorry'  )
lines.append( 'shut up'  )
N = len(lines)
lines = vocab.encode(lines)
q_o = Variable(torch.from_numpy(lines).long()).cuda()
#vocab.decode(vocab.encode(lines))

In [26]:
model.eval()
o = model(q_o,q_o,feed_previous=True)
_,predict_o = o.max(1)
#vocab.decode(predict_o.cpu().view(-1,3).data.numpy().T)
pred_ans_o = predict_o.cpu().view(-1,N).data.numpy().T #predicted answer 
vocab.print_QA_2(lines, pred_ans_o)


Q      :you can do it
A      :yeah i know

Q      :how are you
A      :im fine

Q      :fuck you
A      :i cant help it michigan is a here

Q      :jesus christ you scared the shit out of me
A      :well its a nice kids thats a

Q      :youre terrible
A      :i know but i thought to you it i i you you you you you

Q      :is something wrong
A      :i know

Q      :nobodys gonna get inside
A      :i was unk to unk unk tries and meet her dumb

Q      :im sorry
A      :i know

Q      :shut up
A      :i have to find the and i didnt want to to be unk


# Save

In [27]:
epoch

164

In [28]:
torch.save(model.state_dict() , 'pth/model_female_sche_sampling_epo%s.pth'%epoch) #save model     
np.save('loss_female_sche_sampling_epo%s.npy'%epochs,loss_hist)

# Back up

In [None]:
a

In [None]:
x = Variable(torch.rand(3,25)*200).long()
x

In [None]:
_,p = model(x,q,feed_previous=True).max(1)

In [None]:
vocab.decode(p.view(3,25).data.numpy())