In [2]:
# code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor
import torch
import numpy as np
import torch.nn as nn
import torch.utils.data as Data
import torch.nn.functional as F
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)
# S: symbol that marks the start of the decoder input
# E: symbol that marks the end of the encoder input and of the decoder output
# ?: padding symbol used to fill a sequence that is shorter than n_step

cpu


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Vocabulary: 'S', 'E' and '?' come first, followed by the lowercase alphabet.
letter = list('SE?abcdefghijklmnopqrstuvwxyz')
letter2idx = {symbol: index for index, symbol in enumerate(letter)}

# Training pairs, each written as [source word, target word].
seq_data = [
    ['man', 'women'],
    ['black', 'white'],
    ['king', 'queen'],
    ['girl', 'boy'],
    ['up', 'down'],
    ['high', 'low'],
]

# Seq2Seq parameters
n_step = max(len(word) for pair in seq_data for word in pair)  # length of the longest word (= 5)
n_class = len(letter2idx)  # number of classes: one per vocabulary symbol
batch_size = 1

In [5]:
def make_data(seq_data):
    """Encode (source, target) word pairs as tensors for the seq2seq model.

    Each word is right-padded with '?' up to ``n_step`` characters. Words that
    are already longer than ``n_step`` are left as they are (no truncation).
    The caller's ``seq_data`` is NOT modified: padding is applied to local
    copies, so the function can be called repeatedly and also accepts tuples.

    Args:
        seq_data: iterable of pairs; element 0 is the source word and
            element 1 the target word. Every character must be a key of
            ``letter2idx``.

    Returns:
        A tuple ``(enc_input_all, dec_input_all, dec_output_all)``:
        * enc_input_all  - float tensor, one-hot of ``source + 'E'``,
          shape [N, n_step+1, n_class] when every word fits in ``n_step``
        * dec_input_all  - float tensor, one-hot of ``'S' + target``,
          same shape as above
        * dec_output_all - long tensor of class indices of ``target + 'E'``,
          shape [N, n_step+1] (not one-hot)

    Raises:
        KeyError: if a word contains a character that is not in ``letter2idx``.
    """
    enc_input_all, dec_input_all, dec_output_all = [], [], []
    one_hot = np.eye(n_class)  # row i is the one-hot vector of class i

    for seq in seq_data:
        # Pad local copies instead of writing back into the caller's data.
        source = seq[0] + '?' * (n_step - len(seq[0]))  # e.g. 'man??'
        target = seq[1] + '?' * (n_step - len(seq[1]))  # e.g. 'women'

        enc_input = [letter2idx[ch] for ch in source + 'E']   # 'm', 'a', 'n', '?', '?', 'E'
        dec_input = [letter2idx[ch] for ch in 'S' + target]   # 'S', 'w', 'o', 'm', 'e', 'n'
        dec_output = [letter2idx[ch] for ch in target + 'E']  # 'w', 'o', 'm', 'e', 'n', 'E'

        enc_input_all.append(one_hot[enc_input])
        dec_input_all.append(one_hot[dec_input])
        dec_output_all.append(dec_output)  # kept as class indices, not one-hot

    # Stack into a single ndarray first: building a tensor directly from a
    # Python list of ndarrays is very slow and triggers a PyTorch warning.
    return (
        torch.tensor(np.array(enc_input_all), dtype=torch.float32),
        torch.tensor(np.array(dec_input_all), dtype=torch.float32),
        torch.tensor(np.array(dec_output_all), dtype=torch.long),
    )

# Shapes of the tensors built from the six training pairs:
#   enc_input_all:  [6, n_step+1 (because of 'E'), n_class]
#   dec_input_all:  [6, n_step+1 (because of 'S'), n_class]
#   dec_output_all: [6, n_step+1 (because of 'E')]
enc_input_all, dec_input_all, dec_output_all = make_data(seq_data)
dec_output_all  # last expression of the cell, so the target indices are displayed

tensor([[25, 17, 15,  7, 16,  1],
        [25, 10, 11, 22,  7,  1],
        [19, 23,  7,  7, 16,  1],
        [ 4, 17, 27,  2,  2,  1],
        [ 6, 17, 25, 16,  2,  1],
        [14, 17, 25,  2,  2,  1]])

In [6]:
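# During training the decoder is fed the ground-truth target (dec_input_all); this scheme is called teacher forcing. (Open question: can teacher forcing only be used in supervised learning?)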
class TranslateDataSet(Data.Dataset):
    """Map-style dataset yielding aligned (encoder input, decoder input, decoder target) triples."""

    def __init__(self, enc_input_all, dec_input_all, dec_output_all):
        """Store the three collections; they are indexed in parallel by sample position."""
        self.enc_input_all, self.dec_input_all, self.dec_output_all = (
            enc_input_all,
            dec_input_all,
            dec_output_all,
        )

    def __len__(self):
        """Return the number of samples, measured on the encoder inputs."""
        n_samples = len(self.enc_input_all)
        return n_samples

    def __getitem__(self, idx):
        """Return the ``idx``-th (encoder input, decoder input, decoder target) tuple."""
        sample = (
            self.enc_input_all[idx],
            self.dec_input_all[idx],
            self.dec_output_all[idx],
        )
        return sample

# Shuffled batch iterator over the training triples.
train_set = TranslateDataSet(enc_input_all, dec_input_all, dec_output_all)
loader = Data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

In [7]:
from myNet import MultiHeadAttention
# Attention hyper-parameters.
num_hiddens, num_heads = 50, 1
# Queries, keys and values are all one-hot letter vectors, so their feature
# size is the vocabulary size; derive it from n_class instead of repeating 29.
key_size = query_size = value_size = n_class
dropout = 0.5  # dropout probability handed to the attention module

model = MultiHeadAttention(
    key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=True
).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [8]:
# Training: every epoch sums the loss over all batches of the loader and then
# takes a single optimizer step on that accumulated loss.
# Explicitly enter training mode so dropout is active even when this cell is
# re-run after the model has been switched to eval mode further down.
model.train()
for epoch in range(10000):
    optimizer.zero_grad()
    loss = 0
    for enc_input_batch, dec_input_batch, dec_output_batch in loader:
        enc_input_batch = enc_input_batch.to(device)
        dec_input_batch = dec_input_batch.to(device)
        dec_output_batch = dec_output_batch.to(device)
        # enc_input_batch  : [batch_size, n_step+1, n_class]
        # dec_input_batch  : [batch_size, n_step+1, n_class]
        # dec_output_batch : [batch_size, n_step+1], not one-hot
        # Decoder input is passed first and the encoder input twice
        # (presumably as query, key and value - confirm against myNet).
        pred = model(dec_input_batch, enc_input_batch, enc_input_batch, None)
        for i in range(len(dec_output_batch)):
            # pred[i] : [n_step+1, n_class]
            # dec_output_batch[i] : [n_step+1]
            loss += criterion(pred[i], dec_output_batch[i])

    if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))

    loss.backward()
    optimizer.step()

Epoch: 1000 cost = 16.208218
Epoch: 2000 cost = 14.647017
Epoch: 3000 cost = 12.937181
Epoch: 4000 cost = 12.469214
Epoch: 5000 cost = 10.680637
Epoch: 6000 cost = 10.202083
Epoch: 7000 cost = 10.212413
Epoch: 8000 cost = 8.516916
Epoch: 9000 cost = 6.614962
Epoch: 10000 cost = 6.697484


In [9]:
# Persist only the learned parameters (state_dict) rather than pickling the
# whole module object: a pickled module is tied to the exact import path of
# its class and is rejected by the safe `weights_only` loading mode of newer
# PyTorch releases.
# NOTE(review): the file "mha" now holds a state_dict, not a module object.
torch.save(model.state_dict(), "mha")
# map_location keeps the round trip working when the checkpoint was written
# on a different device than the one in use now.
model.load_state_dict(torch.load("mha", map_location=device));

In [10]:

# Switch to evaluation mode (disables the Dropout layer) before running the
# test cell; eval() returns the module, so the cell displays its structure.
model.eval()

MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=29, out_features=50, bias=True)
  (W_k): Linear(in_features=29, out_features=50, bias=True)
  (W_v): Linear(in_features=29, out_features=50, bias=True)
  (W_o): Linear(in_features=50, out_features=29, bias=True)
)

In [12]:
# Test
def translate(word):
    """Predict the target word for ``word`` with the trained attention model.

    The decoder is fed 'S' followed by padding only, so the whole output is
    predicted in one forward pass. The prediction is cut at the first end
    symbol 'E' (when one is produced) and padding symbols '?' are removed.

    Relies on the module-level ``model``, ``device``, ``make_data``,
    ``n_step`` and ``letter``. The model's train/eval mode is not changed
    here; call ``model.eval()`` beforehand to disable dropout.

    Args:
        word: source word made of vocabulary characters.

    Returns:
        The decoded string without the 'E' and '?' symbols.
    """
    # Dummy target made of padding; make_data prepends 'S' for the decoder input.
    enc_input, dec_input, _ = make_data([[word, '?' * n_step]])
    enc_input, dec_input = enc_input.to(device), dec_input.to(device)

    with torch.no_grad():  # inference only, no autograd graph needed
        output = model(dec_input, enc_input, enc_input, None)
    # output is batch-first like in the training loop:
    # [1, n_step+1, n_class] - TODO confirm against myNet

    predict = output.argmax(dim=2)  # most likely class per position, [1, n_step+1]
    decoded = [letter[i] for i in predict[0].tolist()]

    # 'E' marks the end of the word; keep only what precedes it. The guard
    # avoids the ValueError of list.index when no 'E' was predicted.
    if 'E' in decoded:
        decoded = decoded[:decoded.index('E')]

    return ''.join(decoded).replace('?', '')

print('test')
# Try a few source words, including one ('mans') that is not in the training pairs.
for word in ('man', 'mans', 'king', 'black', 'up'):
    print(word + ' ->', translate(word))

test


In [4]:
class Att(nn.Module):
    """Additive (concatenation-based) attention of decoder positions over encoder positions.

    NOTE(review): the previous ``forward`` could not run. It read
    ``self.w_k`` / ``self.w_v`` / ``self.w_q`` although the layers are named
    ``W_k`` / ``W_v`` / ``W_q``, called ``F.softmax`` without ``dim`` and
    passed a 4-D tensor to ``torch.bmm``. This version keeps the existing
    layers and their roles and adds one scoring layer so that every
    (decoder position, encoder position) pair gets a scalar score. It assumes
    additive attention was the intent - please confirm.

    Args:
        encoder_input_size: feature size of the encoder input.
        decoder_input_size: feature size of the decoder input.
        num_hiddens: size of the projected key/value/query vectors.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, encoder_input_size, decoder_input_size, num_hiddens, **kwargs):
        super(Att, self).__init__(**kwargs)
        self.encoder_input_size = encoder_input_size
        self.W_k = nn.Linear(encoder_input_size, num_hiddens)
        self.W_v = nn.Linear(encoder_input_size, num_hiddens)
        self.W_q = nn.Linear(decoder_input_size, num_hiddens)
        self.attn = nn.Linear(num_hiddens * 2, num_hiddens)
        # Reduces the hidden energy of a (query, key) pair to one scalar score.
        self.score = nn.Linear(num_hiddens, 1, bias=False)

    def forward(self, encoder_input, decoder_input):
        """Attend from every decoder position to all encoder positions.

        Args:
            encoder_input: tensor [batch, enc_len, encoder_input_size].
            decoder_input: tensor [batch, dec_len, decoder_input_size].

        Returns:
            Tensor [batch, dec_len, num_hiddens]: the attention context with a
            softmax applied over its last dimension (as in the original code).
        """
        k = self.W_k(encoder_input)  # [batch, enc_len, num_hiddens]
        v = self.W_v(encoder_input)  # [batch, enc_len, num_hiddens]
        q = self.W_q(decoder_input)  # [batch, dec_len, num_hiddens]

        enc_len, dec_len = k.size(1), q.size(1)
        # Pair every query with every key: both become [batch, dec_len, enc_len, num_hiddens].
        k_pairs = k.unsqueeze(1).expand(-1, dec_len, -1, -1)
        q_pairs = q.unsqueeze(2).expand(-1, -1, enc_len, -1)

        energy = torch.tanh(self.attn(torch.cat([k_pairs, q_pairs], dim=-1)))
        scores = self.score(energy).squeeze(-1)   # [batch, dec_len, enc_len]
        attn_weights = F.softmax(scores, dim=-1)  # normalise over encoder positions

        context = torch.bmm(attn_weights, v)      # [batch, dec_len, num_hiddens]
        return F.softmax(context, dim=-1)