<a href="https://colab.research.google.com/github/zhangguanheng66/tutorials/blob/tutorials_with_new_torchtext/tutorials_with_new_torchtext/Annotated_transformer_translation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%shell
# Colab setup: delete the preinstalled torch* packages (the glob also matches
# torchvision/torchtext) and install the nightly CUDA 10.1 wheels instead.
# NOTE(review): the python3.6 dist-packages path is specific to the runtime this
# was written for — confirm it matches the current interpreter before running.
rm -r /usr/local/lib/python3.6/dist-packages/torch*
pip install --pre torch torchvision torchtext -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html

In [None]:
import torchtext
# Confirm which torchtext build the notebook is running against.
print(f"{torchtext.__version__}")

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
%matplotlib inline

In [None]:
from torch.nn import Transformer, Embedding

class Generator(nn.Module):
    """Map decoder features to log-probabilities over the vocabulary.

    Args:
        d_model: size of the incoming feature dimension (last axis).
        vocab: number of output classes.
    """
    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Linear projection followed by log-softmax over the last axis."""
        logits = self.proj(x)
        return torch.log_softmax(logits, dim=-1)
        
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    Args:
        d_model: embedding size (last dimension of the input); may be odd.
        dropout: dropout probability applied after adding the encodings.
        max_len: longest sequence the precomputed table supports.

    The table is stored as a (1, max_len, d_model) buffer and ``forward``
    slices it with ``x.size(1)``, so ``x`` is expected batch-first:
    (batch, seq_len, d_model). Sequence-first callers must transpose first.
    """
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        # Explicit float dtypes instead of relying on int -> float promotion.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 frequencies fit there.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model): broadcasts over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add encodings for the first ``x.size(1)`` positions and apply dropout."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

class AnnotatedTransformer(nn.Module):
  """Encoder-decoder translation model built on ``torch.nn.Transformer``.

  Token ids are embedded, scaled by sqrt(d_model) and given sinusoidal
  positional encodings before entering ``self.transformer``. ``forward``
  returns the decoder output; callers apply ``self.generator`` to get
  log-probabilities over the target vocabulary.

  Args:
    src_vocab: source vocabulary size.
    tgt_vocab: target vocabulary size.
    N: number of encoder layers and of decoder layers.
    d_model: embedding / model width.
    d_ff: feed-forward hidden size inside each layer.
    h: number of attention heads.
    dropout: dropout used by the transformer and the positional encodings.
  """
  def __init__(self, src_vocab, tgt_vocab,
               N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    super(AnnotatedTransformer, self).__init__()
    self.transformer = Transformer(d_model=d_model, nhead=h,
                                   num_encoder_layers=N, num_decoder_layers=N,
                                   dim_feedforward=d_ff, dropout=dropout)
    self.d_model = d_model
    self.src_embed = Embedding(src_vocab, d_model)
    self.src_pos_embed = PositionalEncoding(d_model, dropout)
    self.tgt_embed = Embedding(tgt_vocab, d_model)
    self.tgt_pos_embed = PositionalEncoding(d_model, dropout)
    self.generator = Generator(d_model, tgt_vocab)

  def _embed(self, tokens, embed, pos_embed):
    """Embed sequence-first ids (seq_len, batch) and add positional encodings.

    The nn.Transformer above is built without ``batch_first``, so it works on
    (seq_len, batch, d_model). PositionalEncoding selects positions with
    ``x.size(1)``, i.e. it treats dim 1 as the sequence axis. Transposing around
    it makes the encoding vary along the sequence rather than along the batch.
    NOTE(review): a checkpoint trained with the previous (batch-indexed)
    encodings may need retraining.
    """
    x = embed(tokens) * math.sqrt(self.d_model)
    return pos_embed(x.transpose(0, 1)).transpose(0, 1)

  def forward(self, src, tgt, src_mask, tgt_mask):
    """Run the transformer on (src_len, batch) / (tgt_len, batch) token ids.

    The masks are passed to ``nn.Transformer`` unchanged. Returns the decoder
    output of shape (tgt_len, batch, d_model).
    """
    src = self._embed(src, self.src_embed, self.src_pos_embed)
    tgt = self._embed(tgt, self.tgt_embed, self.tgt_pos_embed)
    return self.transformer(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)

In [None]:
def run_epoch(data_iter, model, criterion, optimizer=None):
    """Run one pass over ``data_iter``; train if ``optimizer`` is given, else evaluate.

    Args:
        data_iter: iterable of ``(src, src_mask, tgt, tgt_mask)`` batches.
            ``src`` is also used as the label tensor for the loss.
        model: called as ``model(src, tgt, src_mask, tgt_mask)``; its output is
            passed through ``model.generator``.
        criterion: loss over ``(log_probs (n, vocab), labels (n,))``. The
            per-token reporting assumes it averages over tokens (mean
            reduction, the ``nn.CrossEntropyLoss`` default).
        optimizer: when not None, parameters are updated after every batch.

    Returns:
        Average per-token loss over the pass (0.0 if ``data_iter`` is empty).
    """
    training = optimizer is not None
    # Derive the mode from the call: evaluation (no optimizer) must run with
    # dropout disabled even though the caller already chose a mode.
    model.train(training)
    start = time.time()
    total_tokens = 0
    total_loss = 0.0
    tokens = 0

    # Skip autograd bookkeeping when only evaluating.
    with torch.set_grad_enabled(training):
        for i, (src, src_mask, tgt, tgt_mask) in enumerate(data_iter):
            ntokens = src.size(0) * src.size(1)
            out = model.generator(model(src, tgt, src_mask, tgt_mask))
            # Mean loss per token for this batch.
            batch_loss = criterion(out.contiguous().view(-1, out.size(-1)),
                                   src.contiguous().view(-1))

            if training:
                optimizer.zero_grad()
                # Same gradient scale as before (mean loss / batch size) so the
                # existing learning rate keeps its effect.
                (batch_loss / src.size(1)).backward()
                optimizer.step()

            # Weight by token count so the return value is a true per-token mean.
            total_loss += batch_loss.item() * ntokens
            total_tokens += ntokens
            tokens += ntokens
            if i % 50 == 49:
                elapsed = time.time() - start
                print("Epoch Step: %d Loss: %12.9f Tokens per Sec: %f" %
                      (i, batch_loss.item(), tokens / elapsed))
                start = time.time()
                tokens = 0
    return total_loss / total_tokens if total_tokens else 0.0

In [None]:
# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def data_gen(V, batch, num_samples, num_words):
    """Build ``num_samples`` random (src, tgt) pairs for the toy copy task.

    Each pair comes from one (num_words, batch) tensor of ids drawn from
    [1, V) whose first row is overwritten with 0. ``src`` drops that first row
    and ``tgt`` drops the last one, so ``tgt`` is ``src`` delayed by one step
    behind a leading 0. Both are (num_words - 1, batch) views of that tensor.
    """
    def _make_pair():
        ids = torch.randint(1, V, size=(num_words, batch))
        ids[0, :] = 0  # leading 0 row: becomes tgt's first decoder input
        return ids[1:, :], ids[:-1, :]

    return [_make_pair() for _ in range(num_samples)]

# Toy copy task: V symbols, nsamples batches of `batch` sequences, each built
# from nwords ids (nwords - 1 after the src/tgt shift in data_gen).
V, batch, nsamples, nwords = 11, 64, 300, 15
model = AnnotatedTransformer(V, V, N=2)
model = model.to(device)
# Causal-mask factory: bound method of the model's nn.Transformer.
subsequent_mask = model.transformer.generate_square_subsequent_mask
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=4.0)
# Multiply the learning rate by 0.8 on every scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)

def genereate_batch(raw_data):
  """Collate DataLoader samples into ``(src, src_mask, tgt, tgt_mask)`` on ``device``.

  Each sample is a (src, tgt) pair of (seq_len, batch) token tensors. Samples
  are concatenated along the batch axis (dim 1), so a DataLoader batch_size > 1
  keeps every sample instead of only the first. For a single sample the result
  equals that sample. Both masks come from ``subsequent_mask`` (causal).
  """
  src = torch.cat([s for s, _ in raw_data], dim=1)
  tgt = torch.cat([t for _, t in raw_data], dim=1)
  src_mask = subsequent_mask(src.size(0)).to(device)
  tgt_mask = subsequent_mask(tgt.size(0)).to(device)
  return src.to(device), src_mask, tgt.to(device), tgt_mask

# One pre-built (src, tgt) batch per DataLoader step (default batch_size=1).
train_data = torch.utils.data.DataLoader(
    data_gen(V, batch, nsamples, nwords), collate_fn=genereate_batch)
test_data = torch.utils.data.DataLoader(
    data_gen(V, batch, 5, nwords), collate_fn=genereate_batch)

for epoch in range(5):
    # Train, then decay the learning rate once per epoch.
    model.train()
    run_epoch(train_data, model, criterion, optimizer)
    scheduler.step()

    # Evaluate on the held-out batches and report the returned loss.
    model.eval()
    print(run_epoch(test_data, model, criterion))

In [None]:
# Decode the copy-task demo on the CPU.
model = model.cpu()
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Greedily decode one target sequence per column of ``src``.

    Args:
        model: called as ``model(src, ys, src_mask, ys_mask)`` and exposing a
            ``generator`` head (as AnnotatedTransformer does).
        src: (src_len, batch) source ids, sequence-first.
        src_mask: source mask passed straight to the model.
        max_len: output length including the start symbol (for max_len >= 1).
        start_symbol: id used as the first decoder input.

    Returns:
        (max_len, batch) long tensor of decoded ids starting with start_symbol.
    """
    ys = torch.full((1, src.size(1)), start_symbol, dtype=torch.long, device=src.device)
    with torch.no_grad():
        for _ in range(max_len - 1):
            ys_mask = subsequent_mask(ys.size(0)).to(src.device)
            out = model(src, ys, src_mask, ys_mask)
            # Log-probs for the last decoded position: (batch, vocab).
            prob = model.generator(out[-1, :])
            # One argmax per batch column; indexing prob[-1] would copy the last
            # column's choice into every column.
            next_word = prob.argmax(-1)
            ys = torch.cat([ys, next_word.unsqueeze(0).to(ys.dtype)], dim=0)
    return ys

# Decode the sequence 1..10 as a (10, 1) sequence-first batch of one.
model.eval()
src = torch.arange(1, 11, dtype=torch.long).unsqueeze(1)
src_mask = subsequent_mask(src.size(0))
print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1))

In [None]:
%%shell
# Download the spaCy English and German models used by the "spacy" tokenizers.
# NOTE(review): the `en` / `de` shortcut names exist only in older spaCy
# releases; newer spaCy expects full package names (e.g. en_core_web_sm,
# de_core_news_sm) — confirm against the installed version.
python -m spacy download en
python -m spacy download de

In [None]:
%%shell
# Fetch the pretrained artifacts used later in the notebook: the EN/DE vocab
# files, the checkpoint loaded with torch.load("annotated_torch_transformer.pt"),
# and iwslt.pt, which is used with the annotated_transformer classes.
wget https://pytorch.s3.amazonaws.com/models/text/annotated_transformer_en_vocab.txt
wget https://pytorch.s3.amazonaws.com/models/text/annotated_transformer_de_vocab.txt
wget https://pytorch.s3.amazonaws.com/models/text/annotated_torch_transformer.pt
wget https://s3.amazonaws.com/opennmt-models/iwslt.pt

In [None]:
import torch
import torchtext
from torchtext.experimental.datasets.raw import IWSLT
from torchtext.experimental.vocab import load_vocab_from_file
from torchtext.data.utils import get_tokenizer
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

# Build each vocabulary inside a `with` block so its file handle is closed
# afterwards (assumes load_vocab_from_file has finished with the file when it
# returns), and pair it with the matching spaCy tokenizer.
with open('annotated_transformer_de_vocab.txt', 'r') as f:
  de_vocab = load_vocab_from_file(f)
de_tokenizer = get_tokenizer("spacy", language='de')

with open('annotated_transformer_en_vocab.txt', 'r') as f:
  en_vocab = load_vocab_from_file(f)
en_tokenizer = get_tokenizer("spacy", language='en')

def data_process(de_raw, en_raw):
  """Tokenize a German/English sentence pair and map the tokens to vocab ids.

  Returns ``(de_ids, en_ids)``, German first.
  """
  de_ids = de_vocab(de_tokenizer(de_raw))
  en_ids = en_vocab(en_tokenizer(en_raw))
  return de_ids, en_ids

raw_train_iter, raw_valid_iter, raw_test_iter = IWSLT()
# Numericalize every (German, English) training pair up front.
processed_data = [data_process(de_sentence, en_sentence)
                  for de_sentence, en_sentence in raw_train_iter]
# Id of '<pad>' in the English vocabulary, used as the padding value.
pad_idx = en_vocab(['<pad>'])[0]

In [None]:
# Use a pretrained torch.nn.Transformer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# Recent PyTorch releases default to weights_only=True, which rejects fully
# pickled models like this one; confirm the installed version.
model = torch.load("annotated_torch_transformer.pt", map_location=device)
model = model.to(device)
# Inference only: disable dropout.
model.eval()

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Greedily decode one target sequence per column of ``src``.

    Args:
        model: called as ``model(src, ys, src_mask, ys_mask)`` and exposing a
            ``generator`` head.
        src: (src_len, batch) source ids, sequence-first.
        src_mask: source mask passed straight to the model.
        max_len: output length including the start symbol (for max_len >= 1).
        start_symbol: id used as the first decoder input.

    Returns:
        (max_len, batch) long tensor of decoded ids starting with start_symbol.
    """
    ys = torch.full((1, src.size(1)), start_symbol, dtype=torch.long, device=src.device)
    with torch.no_grad():
        for _ in range(max_len - 1):
            ys_mask = subsequent_mask(ys.size(0)).to(src.device)
            out = model(src, ys, src_mask, ys_mask)
            # Log-probs for the last decoded position: (batch, vocab).
            prob = model.generator(out[-1, :])
            # One argmax per batch column; indexing prob[-1] would copy the last
            # column's choice into every column.
            next_word = prob.argmax(-1)
            ys = torch.cat([ys, next_word.unsqueeze(0).to(ys.dtype)], dim=0)
    return ys

def collate_func(data_batch):
  """Pad a list of ``(de_ids, en_ids)`` pairs into two sequence-first LongTensors.

  Returns ``(de_batch, en_batch)``, each (max_len_in_batch, batch) because
  pad_sequence is used with its default ``batch_first=False``. German
  sequences are padded with the German vocab's '<pad>' id; English ones with
  ``pad_idx``, which is taken from the English vocab.
  """
  de_pad_idx = de_vocab(['<pad>'])[0]
  de_seqs = [torch.tensor(de_item, dtype=torch.long) for de_item, _ in data_batch]
  en_seqs = [torch.tensor(en_item, dtype=torch.long) for _, en_item in data_batch]
  de_batch = pad_sequence(de_seqs, padding_value=de_pad_idx)
  en_batch = pad_sequence(en_seqs, padding_value=pad_idx)
  return de_batch, en_batch

train_dataloader = DataLoader(processed_data, batch_size=8, shuffle=False, collate_fn=collate_func)
count = 0
for (de_data, en_data) in train_dataloader:
  # collate_func returns sequence-first (seq_len, batch) tensors. The first
  # sentence is therefore column 0, and slicing with :1 keeps it as a
  # (seq_len, 1) batch of one. Row 0 would instead be the first token of every
  # sentence in the batch.
  src, tgt = de_data[:, :1].to(device), en_data[:, :1].to(device)
  src_mask = subsequent_mask(src.size(0)).to(device)
  out = greedy_decode(model, src, src_mask, max_len=60,
                      start_symbol=en_vocab(['<s>'])[0])
  print('-------------------------------', count)
  count += 1
  print("Source:", " ".join(de_vocab.lookup_tokens(src.t()[0].tolist())))
  print("Target:", " ".join(en_vocab.lookup_tokens(tgt.t()[0].tolist())))
  print("Translation:", " ".join(en_vocab.lookup_tokens(out.t()[0].tolist())))

  # Only show the first three batches.
  if count > 2:
    break

In [None]:
# Use the pretrained Transformer model from Annotated Transformer tutorial
from annotated_transformer import EncoderDecoder, Encoder, EncoderLayer, MultiHeadedAttention, \
  PositionwiseFeedForward, SublayerConnection, LayerNorm, Decoder, \
  DecoderLayer, Embeddings
  
def subsequent_mask(size):
    """Return a (1, size, size) bool mask that is True on and below the diagonal.

    Position i may attend to positions <= i; later positions are masked out.
    """
    future = torch.triu(torch.ones(size, size, dtype=torch.uint8), diagonal=1)
    return (future == 0).unsqueeze(0)

# Decode on the CPU. map_location lets a checkpoint saved on a GPU load here.
# NOTE(review): torch.load unpickles arbitrary objects (here classes from the
# annotated_transformer module) — only load trusted files.
device = torch.device('cpu')
model = torch.load("iwslt.pt", map_location=device)
model = model.to(device)
# Inference only: disable dropout.
model.eval()

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Greedily decode with an encode/decode/generator model (batch-first ids).

    Args:
        model: object exposing ``encode(src, src_mask)``,
            ``decode(memory, src_mask, ys, ys_mask)`` and ``generator``.
        src: (batch, src_len) source ids.
        src_mask: source mask passed straight through to encode/decode.
        max_len: output length including the start symbol (for max_len >= 1).
        start_symbol: id used as the first decoder input.

    Returns:
        (batch, max_len) tensor of decoded ids with the same dtype as ``src``.
    """
    # Inference only: no autograd graph. Variable wrappers are unnecessary here
    # since tensors are accepted directly.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        ys = torch.ones(src.size(0), 1).fill_(start_symbol).type_as(src)
        for _ in range(max_len - 1):
            ys_mask = subsequent_mask(ys.size(1)).type_as(src)
            out = model.decode(memory, src_mask, ys, ys_mask)
            # Log-probs for the last decoded position: (batch, vocab).
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            # Append one id per sentence as a new column.
            ys = torch.cat([ys, next_word.unsqueeze(1).type_as(src)], dim=1)
    return ys

def collate_func(data_batch):
  """Pad a list of ``(de_ids, en_ids)`` pairs into two sequence-first LongTensors.

  Returns ``(de_batch, en_batch)``, each (max_len_in_batch, batch) because
  pad_sequence is used with its default ``batch_first=False``. German
  sequences are padded with the German vocab's '<pad>' id, the same id the
  source mask is compared against. English ones use ``pad_idx``, which is
  taken from the English vocab.
  """
  de_pad_idx = de_vocab(['<pad>'])[0]
  de_seqs = [torch.tensor(de_item, dtype=torch.long) for de_item, _ in data_batch]
  en_seqs = [torch.tensor(en_item, dtype=torch.long) for _, en_item in data_batch]
  de_batch = pad_sequence(de_seqs, padding_value=de_pad_idx)
  en_batch = pad_sequence(en_seqs, padding_value=pad_idx)
  return de_batch, en_batch

train_dataloader = DataLoader(processed_data, batch_size=8, shuffle=False,
                              collate_fn=collate_func)
count = 0
for (de_data, en_data) in train_dataloader:
  # Switch to batch-first (batch, seq_len), then keep the first sentence as a
  # batch of one: (1, seq_len).
  de_data = de_data.transpose(0, 1)
  en_data = en_data.transpose(0, 1)
  src = de_data[0].unsqueeze(0).to(device)
  tgt = en_data[0].unsqueeze(0).to(device)
  # (1, 1, src_len) mask: True where the source token is not German '<pad>'.
  src_mask = (src != de_vocab(['<pad>'])[0]).unsqueeze(0).to(device)
  out = greedy_decode(model, src, src_mask, max_len=60,
                      start_symbol=en_vocab(['<s>'])[0])
  print('-------------------------------', count)
  count += 1
  source_text = " ".join(de_vocab.lookup_tokens(src[0].tolist()))
  print("Source:", source_text)
  target_text = " ".join(en_vocab.lookup_tokens(tgt[0].tolist()))
  print("Target:", target_text)
  translation_text = " ".join(en_vocab.lookup_tokens(out[0].tolist()))
  print("Translation:", translation_text)

  # Only show the first three batches.
  if count > 2:
    break