In [1]:
import torch
from torch import nn
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torch.utils.data import DataLoader, Dataset
from torchtext.vocab import vocab
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import torch.optim as optim
import random
import time
import math
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

In [2]:
SEED = 1234

# Seed every RNG this notebook imports so that "Restart Kernel -> Run All"
# is reproducible. NumPy was previously left unseeded.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x10714dbf0>

In [3]:
# Pick the best available accelerator instead of hard-coding Apple MPS, so the
# notebook also runs on CUDA machines and on plain CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

max_lines = None  # None -> use every line of the corpus
batch_size = 64

In [4]:
# Language codes of the parallel corpus: lang1 is the source, lang2 the target.
lang1 = 'cn'
lang2 = 'en'

# spaCy-backed tokenizers obtained through torchtext.
spacy_cn_tokenizer = get_tokenizer('spacy', language='zh_core_web_md')
spacy_en_tokenizer = get_tokenizer('spacy', language='en_core_web_md')


def cn_tokenizer(sentence):
    """Split a Chinese sentence into a list of tokens (tokens are left unchanged)."""
    return list(spacy_cn_tokenizer(sentence))


def en_tokenizer(sentence):
    """Split an English sentence into a list of lower-cased tokens."""
    lowered = []
    for token in spacy_en_tokenizer(sentence):
        lowered.append(token.lower())
    return lowered


# Lookup table from language code to its tokenizer function.
tokenizers = dict(cn=cn_tokenizer, en=en_tokenizer)

def build_vocab(lang1='cn', lang2='en'):
    """Build one vocabulary per language from the tab-separated parallel corpus.

    Reads ``data/<lang1>-<lang2>.txt``, where every line is
    ``<lang1 sentence>\\t<lang2 sentence>``, tokenizes both sides with the
    matching entry of ``tokenizers`` and counts token frequencies.

    Args:
        lang1: language code of the first column (also used in the file name).
        lang2: language code of the second column.

    Returns:
        ``[vocab1, vocab2]``. Both vocabularies start with the specials
        ``<unk>``, ``<pad>``, ``<sos>``, ``<eos>`` (in that order) and map
        unknown tokens to ``<unk>``.

    Raises:
        ValueError: if a line does not contain exactly one tab separator.
    """
    counter1 = Counter()
    counter2 = Counter()
    # Use a context manager so the file handle is closed deterministically.
    with open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8') as corpus_file:
        lines = corpus_file.read().strip().split('\n')
    for line in lines:
        l1, l2 = line.split('\t')
        counter1.update(tokenizers[lang1](l1))
        counter2.update(tokenizers[lang2](l2))

    # The order of the specials fixes their indices: <unk>=0, <pad>=1, <sos>=2, <eos>=3.
    specials = ['<unk>', '<pad>', '<sos>', '<eos>']
    vocab1 = vocab(counter1, min_freq=1, specials=specials)
    vocab2 = vocab(counter2, min_freq=1, specials=specials)
    vocab1.set_default_index(vocab1["<unk>"])
    vocab2.set_default_index(vocab2["<unk>"])
    return [vocab1, vocab2]

# Build both vocabularies once; their sizes define the embedding and projection shapes.
lang1_vocab, lang2_vocab = build_vocab(lang1, lang2)
lang1_vocab_size = len(lang1_vocab)
lang2_vocab_size = len(lang2_vocab)
# Label fixed: it used to read 'lang1_vocab_vocab_size'.
print('lang1_vocab_size', lang1_vocab_size)
print('lang2_vocab_size', lang2_vocab_size)

lang1_vocab_vocab_size 11280
lang2_vocab_size 7123


In [5]:
# Persist both vocabularies so they can be reloaded outside this notebook.
for _vocab_obj, _vocab_path in (
    (lang1_vocab, 'python/vocab/cn_vocab.pt'),
    (lang2_vocab, 'python/vocab/en_vocab.pt'),
):
    torch.save(_vocab_obj, _vocab_path)

In [None]:
class TranslationDataset(Dataset):
    """In-memory parallel corpus of index tensors.

    Every item is a ``(lang1_indices, lang2_indices)`` pair of 1-D ``long``
    tensors created on ``device``. No ``<sos>``/``<eos>`` markers are added here.

    Args:
        lang1: language code of the first column (also used in the file name).
        lang2: language code of the second column.
        max_lines: maximum number of sentence pairs to keep, or ``None`` for all.
        device: device on which the index tensors are created.
    """
    def __init__(self, lang1, lang2, max_lines, device=device):
        self.lang1 = lang1
        self.lang2 = lang2
        self.device = device
        self.data = []
        self.untokenized_data = []
        # Context manager so the corpus file is closed once it has been read.
        with open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8') as corpus_file:
            lines = corpus_file.read().strip().split('\n')
        # Original note: "use longer data" - presumably the file is ordered
        # short-to-long, so reversing keeps the longer pairs when truncating.
        lines.reverse()
        if max_lines is not None:
            # Truncate up front instead of parsing lines that would be discarded.
            lines = lines[:max(max_lines, 0)]
        for line in lines:
            l1, l2 = line.split('\t')
            l1_tokens = self.tokenize_sentence(l1, True)
            l2_tokens = self.tokenize_sentence(l2, False)
            self.data.append((l1_tokens, l2_tokens))
            self.untokenized_data.append((l1, l2))

        self.len = len(self.data)

    def tokenize_sentence(self, sentence, is_lang1):
        """Tokenize ``sentence`` and map its tokens to vocabulary indices.

        Args:
            sentence: raw sentence string.
            is_lang1: ``True`` to use the lang1 tokenizer/vocabulary,
                ``False`` for lang2.

        Returns:
            1-D ``torch.long`` tensor of token indices on ``self.device``.
        """
        # The tokenizer is looked up in the module-level ``tokenizers`` table;
        # the former ``lang1_tokenizer``/``lang2_tokenizer`` names were never defined.
        if is_lang1:
            token_vocab, tokenizer = lang1_vocab, tokenizers[self.lang1]
        else:
            token_vocab, tokenizer = lang2_vocab, tokenizers[self.lang2]
        indexes = [token_vocab[token] for token in tokenizer(sentence)]
        return torch.tensor(indexes, dtype=torch.long, device=self.device)

    def __getitem__(self, index):
        """Return the ``(lang1_indices, lang2_indices)`` pair at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of sentence pairs held by the dataset."""
        return self.len

In [None]:

# Load and index the whole corpus (or its first `max_lines` pairs after reversal).
dataset = TranslationDataset(lang1=lang1, lang2=lang2, max_lines=max_lines, device=device)
print('length of dataset', len(dataset))

In [None]:
# 80/10/10 train/validation/test split, made reproducible with a dedicated seeded generator.
split_generator = torch.Generator().manual_seed(SEED)
train_set, val_set, test_set = torch.utils.data.random_split(
    dataset,
    [0.8, 0.1, 0.1],
    generator=split_generator,
)

def collate_fn(batch):
    """Pad a list of (source, target) index tensors into two batch tensors.

    Args:
        batch: list of ``(source_indices, target_indices)`` pairs.

    Returns:
        ``(source, target)``, both batch-first, padded with the ``<pad>`` index
        of the respective vocabulary and moved to ``device``.
    """
    source_seqs = []
    target_seqs = []
    for example in batch:
        source_seqs.append(example[0])
        target_seqs.append(example[1])

    # pad_sequence appends padding after each sequence up to the longest one in the batch.
    padded_source = pad_sequence(source_seqs, batch_first=True, padding_value=lang1_vocab['<pad>'])
    padded_target = pad_sequence(target_seqs, batch_first=True, padding_value=lang2_vocab['<pad>'])
    return padded_source.to(device), padded_target.to(device)

# Options shared by the three loaders; only the training loader is shuffled.
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

# Number of batches per loader.
print('length of train_loader', len(train_loader))
print('length of val_loader', len(val_loader))
print('length of test_loader', len(test_loader))


In [None]:
# Sanity check: show the first sample and the tensor shapes of one training batch.
for batch in train_loader:
    source_batch, target_batch = batch[0], batch[1]
    print(source_batch[0], target_batch[0])
    print(source_batch.shape, target_batch.shape)
    break

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence-first input.

    Args:
        d_model: embedding size (odd sizes are supported).
        dropout: dropout probability applied after the addition.
        max_len: longest sequence length the table is precomputed for.
        device: device on which the table is initially created; the table is a
            registered buffer, so it follows the module on ``.to(...)``.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000, device=device):
        super(PositionalEncoding, self).__init__()
        self.device = device
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)).to(device)
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model, device=device)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns;
        # slicing keeps the assignment valid (no-op for even d_model).
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Stored as [max_len, 1, d_model] so it broadcasts over the batch dimension.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]

        Returns a tensor of the same shape with positions added and dropout applied.
        Raises ValueError if seq_len exceeds the precomputed max_len.
        '''
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                'sequence length %d exceeds positional-encoding max_len %d' % (seq_len, self.pe.size(0))
            )
        x = x + self.pe[:seq_len, :]
        return self.dropout(x)


In [None]:
def get_attn_pad_mask(seq_q, seq_k, pad_idx=1):
    '''
    Build the attention mask that hides padding keys.

    seq_q: [batch_size, len_q] token indices (only its length is used)
    seq_k: [batch_size, len_k] token indices whose padding must be hidden
    pad_idx: index of the <pad> token. The vocabularies in this notebook are
        built with specials ['<unk>', '<pad>', '<sos>', '<eos>'], so <pad> is
        index 1. The previous code compared against 0, which is <unk>, and
        therefore never masked any padding.

    len_q and len_k may differ (e.g. tgt_len vs src_len).

    Returns a bool tensor [batch_size, len_q, len_k] where True marks a key
    position that holds padding and must be masked out.
    '''
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.eq(pad_idx).unsqueeze(1)  # [batch_size, 1, len_k], True is masked
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

def get_attn_subsequence_mask(seq, device):
    '''
    Build the causal ("no peeking ahead") mask for decoder self-attention.

    seq: [batch_size, tgt_len]
    device: device on which the mask is created

    Returns a float tensor [batch_size, tgt_len, tgt_len] holding 1 strictly
    above the main diagonal (future positions) and 0 elsewhere.
    '''
    n_batch, tgt_len = seq.size(0), seq.size(1)
    all_ones = torch.ones(n_batch, tgt_len, tgt_len, device=device)
    return all_ones.triu(diagonal=1)  # keep only the strict upper triangle

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention with a boolean mask (True = hide)."""

    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, len_q, len_k], True where attention is forbidden

        Returns (context, attn):
            context: [batch_size, n_heads, len_q, d_v]
            attn:    [batch_size, n_heads, len_q, len_k]
        '''
        scale = math.sqrt(self.d_k)
        raw_scores = torch.matmul(Q, K.transpose(-1, -2)) / scale  # [batch_size, n_heads, len_q, len_k]
        # Masked positions get a very negative score so softmax gives them ~0 weight.
        masked_scores = raw_scores.masked_fill(attn_mask, -1e9)
        attn = torch.softmax(masked_scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn

class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by an output projection, residual add and LayerNorm.

    Args:
        d_model: model (embedding) size.
        d_k: per-head size of queries and keys.
        d_v: per-head size of values (may differ from d_k).
        n_heads: number of attention heads.
    """
    def __init__(self, d_model, d_k, d_v, n_heads):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.scaled_dot_attn = ScaledDotProductAttention(d_k)
        self.layer_norm = nn.LayerNorm(d_model)
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, len_q, len_k], True where attention is forbidden

        Returns (output, attn):
            output: [batch_size, len_q, d_model]
            attn:   [batch_size, n_heads, len_q, len_k]
        '''
        residual, batch_size = input_Q, input_Q.size(0)
        # The projection matrices of all heads are fused into a single linear map
        # and the result is split per head afterwards (an implementation trick).
        # B: batch_size, S: seq_len, D: dim
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, Head, W) -trans-> (B, Head, S, W)

        Q = self.W_Q(input_Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)  # [batch_size, n_heads, len_v(=len_k), d_v]

        # Same mask for every head: [batch_size, n_heads, len_q, len_k]
        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = self.scaled_dot_attn(Q, K, V, attn_mask)
        # Heads are concatenated along the value dimension, so the width is
        # n_heads * d_v (it used to be n_heads * d_k, which broke for d_k != d_v).
        context = context.transpose(1, 2).reshape(batch_size, -1, self.n_heads * self.d_v)  # [batch_size, len_q, n_heads * d_v]
        output = self.fc(context)  # [batch_size, len_q, d_model]
        return self.layer_norm(output + residual), attn

In [None]:
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Linear, then residual add and LayerNorm."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        widen = nn.Linear(d_model, d_ff, bias=False)
        activation = nn.ReLU()
        narrow = nn.Linear(d_ff, d_model, bias=False)
        self.fc = nn.Sequential(widen, activation, narrow)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        '''
        inputs: [batch_size, seq_len, d_model]

        Returns a tensor of the same shape.
        '''
        transformed = self.fc(inputs)
        # Residual connection around the feed-forward stack, then normalization.
        return self.layer_norm(transformed + inputs)


In [None]:

class EncoderLayer(nn.Module):
    """One encoder block: multi-head self-attention followed by the position-wise feed-forward net."""

    def __init__(self, d_model, d_k, d_v, n_heads, d_ff):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]

        Returns (enc_outputs, attn):
            enc_outputs: [batch_size, src_len, d_model]
            attn:        [batch_size, n_heads, src_len, src_len]
        '''
        # Self-attention: the same tensor serves as queries, keys and values.
        attended, attn = self.enc_self_attn(
            enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask
        )
        return self.pos_ffn(attended), attn


class Encoder(nn.Module):
    """Transformer encoder: token embedding + positional encoding + a stack of EncoderLayers.

    Args:
        d_model: embedding size.
        n_layers: number of encoder layers.
        src_vocab_size: size of the source vocabulary.
        d_k, d_v: per-head key/query and value sizes.
        n_heads: number of attention heads.
        d_ff: hidden size of the feed-forward sub-layer.
        device: device handed to the positional encoding for its initial buffer.
    """
    def __init__(self, d_model, n_layers, src_vocab_size, d_k, d_v, n_heads, d_ff, device=device):
        super(Encoder, self).__init__()
        self.device = device
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        # Forward the requested device instead of silently falling back to the
        # PositionalEncoding default.
        self.pos_emb = PositionalEncoding(d_model, device=device)
        self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, n_heads, d_ff) for _ in range(n_layers)])

    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len]

        Returns (outputs, enc_self_attns):
            outputs:        [batch_size, src_len, d_model]
            enc_self_attns: list (one per layer) of [batch_size, n_heads, src_len, src_len]
        '''
        embeded = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]
        # PositionalEncoding is sequence-first, hence the transposes around it.
        posed_embeded = self.pos_emb(embeded.transpose(0, 1)).transpose(0, 1)  # [batch_size, src_len, d_model]
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]

        enc_self_attns = []
        outputs = posed_embeded
        for layer in self.layers:
            # outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len]
            outputs, enc_self_attn = layer(outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return outputs, enc_self_attns

In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention, then feed-forward."""

    def __init__(self, d_model, d_k, d_v, n_heads, d_ff):
        super().__init__()
        self.dec_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads)
        self.dec_enc_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, dec_inputs, encoder_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        '''
        dec_inputs: [batch_size, tgt_len, d_model]
        encoder_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]

        Returns (dec_outputs, dec_self_attn, dec_enc_attn):
            dec_outputs:   [batch_size, tgt_len, d_model]
            dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
            dec_enc_attn:  [batch_size, n_heads, tgt_len, src_len]
        '''
        # 1) The target sequence attends to itself.
        self_attended, self_weights = self.dec_self_attn(
            dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask
        )
        # 2) Queries come from the decoder, keys and values from the encoder output.
        cross_attended, cross_weights = self.dec_enc_attn(
            self_attended, encoder_outputs, encoder_outputs, dec_enc_attn_mask
        )
        # 3) Position-wise feed-forward.
        return self.pos_ffn(cross_attended), self_weights, cross_weights

class Decoder(nn.Module):
    """Transformer decoder: token embedding + positional encoding + a stack of DecoderLayers.

    Args:
        d_model: embedding size.
        n_layers: number of decoder layers.
        tgt_vocab_size: size of the target vocabulary.
        d_k, d_v: per-head key/query and value sizes.
        n_heads: number of attention heads.
        d_ff: hidden size of the feed-forward sub-layer.
        device: device handed to the positional encoding for its initial buffer.
    """
    def __init__(self, d_model, n_layers, tgt_vocab_size, d_k, d_v, n_heads, d_ff, device=device):
        super(Decoder, self).__init__()
        self.device = device
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        # Forward the requested device instead of silently falling back to the
        # PositionalEncoding default.
        self.pos_emb = PositionalEncoding(d_model, device=device)
        self.layers = nn.ModuleList([DecoderLayer(d_model, d_k, d_v, n_heads, d_ff) for _ in range(n_layers)])

    def forward(self, decoder_inputs, encoder_inputs, encoder_outputs):
        '''
        decoder_inputs: [batch_size, tgt_len]
        encoder_inputs: [batch_size, src_len]
        encoder_outputs: [batch_size, src_len, d_model]

        Returns (decoder_outputs, dec_self_attns, dec_enc_attns):
            decoder_outputs: [batch_size, tgt_len, d_model]
            dec_self_attns:  list (one per layer) of [batch_size, n_heads, tgt_len, tgt_len]
            dec_enc_attns:   list (one per layer) of [batch_size, n_heads, tgt_len, src_len]
        '''
        embeded = self.tgt_emb(decoder_inputs)  # [batch_size, tgt_len, d_model]
        # PositionalEncoding is sequence-first, hence the transposes around it.
        pos_embeded = self.pos_emb(embeded.transpose(0, 1)).transpose(0, 1)  # [batch_size, tgt_len, d_model]
        dec_self_attn_pad_mask = get_attn_pad_mask(decoder_inputs, decoder_inputs)  # [batch_size, tgt_len, tgt_len]
        # Build the causal mask on the device of the actual inputs, so the module
        # keeps working after being moved with .to(...).
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(decoder_inputs, decoder_inputs.device)  # [batch_size, tgt_len, tgt_len]

        # Combine both masks: a position is hidden if it is padding OR lies in the future.
        dec_self_attn_mask = dec_self_attn_pad_mask | dec_self_attn_subsequence_mask.bool()  # [batch_size, tgt_len, tgt_len]

        # Mask for the encoder-decoder attention. Keys and values come from the
        # encoder, so padding positions of the *source* are hidden; the decoder
        # inputs only provide the query length for the expansion.
        dec_enc_attn_mask = get_attn_pad_mask(decoder_inputs, encoder_inputs)  # [batch_size, tgt_len, src_len]

        dec_self_attns, dec_enc_attns = [], []
        decoder_outputs = pos_embeded
        for layer in self.layers:
            # decoder_outputs: [batch_size, tgt_len, d_model]
            # dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
            # dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
            decoder_outputs, dec_self_attn, dec_enc_attn = layer(decoder_outputs, encoder_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return decoder_outputs, dec_self_attns, dec_enc_attns

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer with a linear projection onto the target vocabulary.

    Args:
        d_model: embedding size.
        n_layers: number of encoder layers and of decoder layers.
        src_vocab_size: size of the source vocabulary.
        tgt_vocab_size: size of the target vocabulary.
        d_k, d_v: per-head key/query and value sizes.
        n_heads: number of attention heads.
        d_ff: hidden size of the feed-forward sub-layers.
        device: device forwarded to the encoder and decoder.
    """
    def __init__(self, d_model, n_layers, src_vocab_size, tgt_vocab_size, d_k, d_v, n_heads, d_ff, device=device):
        super(Transformer, self).__init__()
        self.device = device
        self.encoder = Encoder(d_model, n_layers, src_vocab_size, d_k, d_v, n_heads, d_ff, device=device)
        self.decoder = Decoder(d_model, n_layers, tgt_vocab_size, d_k, d_v, n_heads, d_ff, device=device)
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)

    def forward(self, enc_inputs, dec_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]

        Returns (logits, enc_self_attns, dec_self_attns, dec_enc_attns) where
        logits is flattened to [batch_size * tgt_len, tgt_vocab_size] and the
        attention entries are per-layer lists.
        '''
        # enc_outputs: [batch_size, src_len, d_model]
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        # dec_outputs: [batch_size, tgt_len, d_model]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs)  # [batch_size, tgt_len, tgt_vocab_size]
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns

    def interface(self, enc_input, start_symbol, tgt_eos, max_len=100):
        '''
        Greedy decoding of a single source sentence.

        enc_input: [1, src_len] source token indices
        start_symbol: index of the first decoder token (<sos>)
        tgt_eos: index of the end-of-sentence token
        max_len: upper bound on the number of decoding steps; without it the
            loop would never end when the model does not predict tgt_eos

        Returns a [n_steps, 1] tensor with the arg-max prediction for every
        decoder position (n_steps <= max_len).
        '''
        was_training = self.training
        # Decode in eval mode (no dropout) and without autograd bookkeeping;
        # the previous mode is restored afterwards.
        self.eval()
        try:
            with torch.no_grad():
                enc_outputs, _ = self.encoder(enc_input)
                # Allocate on the device of the input rather than a stored device.
                dec_input = torch.empty((1, 0), dtype=enc_input.dtype, device=enc_input.device)
                next_symbol = start_symbol
                for _ in range(max_len):
                    step_token = torch.tensor([[next_symbol]], dtype=enc_input.dtype, device=enc_input.device)
                    dec_input = torch.cat([dec_input, step_token], dim=-1)
                    dec_outputs, _, _ = self.decoder(dec_input, enc_input, enc_outputs)
                    projected = self.projection(dec_outputs)
                    best_per_position = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
                    # Incremental update: only the prediction for the newest
                    # position is used as the next token; predictions for the
                    # earlier positions are ignored.
                    next_symbol = best_per_position[-1].item()
                    if next_symbol == tgt_eos:
                        break

                predict, _, _, _ = self.forward(enc_input, dec_input)
                return predict.max(1, keepdim=True)[1]
        finally:
            self.train(was_training)
            

In [None]:
# Transformer hyper-parameters.
# A larger configuration is kept here for reference only (not active):
#   d_model = 512, d_ff = 2048, d_k = d_v = 64, n_layers = 6, n_heads = 8

d_model = 128  # embedding size
d_ff = 256     # feed-forward hidden dimension
d_k = 16       # per-head dimension of K (= Q)
d_v = 16       # per-head dimension of V
n_layers = 4   # number of encoder layers and of decoder layers
n_heads = 4    # number of heads in multi-head attention

model = Transformer(
    d_model=d_model,
    n_layers=n_layers,
    src_vocab_size=lang1_vocab_size,
    tgt_vocab_size=lang2_vocab_size,
    d_k=d_k,
    d_v=d_v,
    n_heads=n_heads,
    d_ff=d_ff,
).to(device)

# Padding positions of the target do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=lang2_vocab['<pad>'])

# Adam with the default learning rate (an SGD variant was tried earlier and dropped).
optimizer = optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1.0e-6)

In [None]:
def count_parameters(model):
    """Print how many parameters of ``model`` require gradients.

    Args:
        model: any ``nn.Module``.

    Returns:
        None; the count is printed with thousands separators.
    """
    params = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            params += tensor.numel()
    print(f'The model has {params:,} trainable parameters')

# Report the number of trainable parameters of the freshly built model.
count_parameters(model)

In [None]:
def train(model, iterator, optimizer, criterion):
    """Run one training epoch with teacher forcing.

    Args:
        model: the Transformer; called as ``model(enc_inputs, dec_inputs)``.
        iterator: loader yielding ``(enc_inputs, target)`` batches of padded
            index tensors (padding appended after each sequence).
        optimizer: optimizer updating ``model``'s parameters.
        criterion: loss taking flattened logits and flattened target indices.

    Returns:
        Mean loss per batch over the epoch.
    """
    model.train()
    sos_idx = lang2_vocab['<sos>']
    eos_idx = lang2_vocab['<eos>']
    pad_idx = lang2_vocab['<pad>']
    epoch_loss = 0
    for enc_inputs, target in iterator:
        sos_column = torch.full_like(target[:, :1], sos_idx)
        pad_column = torch.full_like(target[:, :1], pad_idx)
        # Decoder input: <sos> w1 ... wn <pad>...
        dec_inputs = torch.cat([sos_column, target], dim=1)
        # Expected output: w1 ... wn <eos> <pad>...
        # <eos> must directly follow the last real token of every sequence.
        # Appending it after the padding (as before) put it at the wrong
        # position for every sequence shorter than the longest in the batch.
        dec_outputs = torch.cat([target, pad_column], dim=1)
        lengths = target.ne(pad_idx).sum(dim=1, keepdim=True)  # [batch_size, 1]
        dec_outputs.scatter_(1, lengths, eos_idx)

        optimizer.zero_grad()
        # outputs: [batch_size * tgt_len, tgt_vocab_size]
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
        loss = criterion(outputs, dec_outputs.reshape(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)

In [None]:
def evaluate(model, iterator, criterion):
    """Compute the mean teacher-forced loss over ``iterator`` without updating the model.

    Args:
        model: the Transformer; called as ``model(enc_inputs, dec_inputs)``.
        iterator: loader yielding ``(enc_inputs, target)`` batches of padded
            index tensors (padding appended after each sequence).
        criterion: loss taking flattened logits and flattened target indices.

    Returns:
        Mean loss per batch.
    """
    model.eval()
    sos_idx = lang2_vocab['<sos>']
    eos_idx = lang2_vocab['<eos>']
    pad_idx = lang2_vocab['<pad>']
    epoch_loss = 0
    with torch.no_grad():
        for enc_inputs, target in iterator:
            sos_column = torch.full_like(target[:, :1], sos_idx)
            pad_column = torch.full_like(target[:, :1], pad_idx)
            # Decoder input: <sos> w1 ... wn <pad>...
            dec_inputs = torch.cat([sos_column, target], dim=1)
            # Expected output: w1 ... wn <eos> <pad>...
            # <eos> is written right after the last real token of each
            # sequence instead of after the padding.
            dec_outputs = torch.cat([target, pad_column], dim=1)
            lengths = target.ne(pad_idx).sum(dim=1, keepdim=True)  # [batch_size, 1]
            dec_outputs.scatter_(1, lengths, eos_idx)

            outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
            loss = criterion(outputs, dec_outputs.reshape(-1))
            epoch_loss += loss.item()
    return epoch_loss / len(iterator)

In [None]:
def epoch_time(start_time, end_time):
    """Split the duration between two timestamps (in seconds) into whole minutes and seconds.

    Args:
        start_time: start timestamp in seconds.
        end_time: end timestamp in seconds.

    Returns:
        ``(minutes, seconds)`` as integers, both truncated toward zero.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

In [None]:
# Main training loop: one pass over the training set followed by a validation
# pass per epoch, with wall-clock timing of both together.
n_epochs = 20
for epoch in range(n_epochs):
    start_time = time.time()
    train_loss = train(model, train_loader, optimizer, criterion)
    valid_loss = evaluate(model, val_loader, criterion)

    # The measured time covers training and validation of this epoch.
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')

In [None]:
# Translate the first sentence of one training batch as a smoke test.
enc_inputs, _ = next(iter(train_loader))
enc_input = enc_inputs[0].view(1, -1)
# The special tokens are stored WITH angle brackets. Looking up 'sos'/'eos'
# silently fell back to the <unk> index, so decoding started from the wrong
# token and the stop condition was never met.
predict = model.interface(enc_input, start_symbol=lang2_vocab['<sos>'], tgt_eos=lang2_vocab['<eos>'])
# view(-1) keeps a 1-D tensor even for a single predicted token (squeeze()
# would yield a 0-d tensor, which cannot be iterated).
print(enc_input[0], '->', lang2_vocab.lookup_tokens(predict.view(-1).tolist()))

In [None]:
# Save only the learned parameters (the state_dict) to 'transformer.pt'.
torch.save(model.state_dict(), 'transformer.pt')