# The Annotated Transformer

Experimental code written (or copied) from [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html#embeddings-and-softmax) to fully understand the architecture of the Transformer.

In [None]:
# Commands to install necessary packages
#!pip install numpy matplotlib spacy torchtext seaborn
#!conda install pytorch torchvision -c pytorch -y
#!pip install pixiedust

## Preliminaries

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
%matplotlib inline

## EncoderDecoder Module Architecture

In [None]:
class EncoderDecoder(nn.Module):
    """Generic encoder-decoder wrapper that wires together the Transformer parts.

    Args:
        encoder: called as encoder(embedded_src, src_mask).
        decoder: called as decoder(memory, src_mask, embedded_tgt, tgt_mask).
        generator: output projection; stored here but not used by forward().
        src_embed: callable mapping source token ids to embeddings.
        tgt_embed: callable mapping target token ids to embeddings.
    """

    def __init__(self, encoder, decoder, generator, src_embed, tgt_embed):
        super(EncoderDecoder, self).__init__()
        # Registration order is kept stable: it defines parameters() order.
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed

    def forward(self, src, src_mask, tgt, tgt_mask):
        """Encode the source, then decode the target against that memory."""
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """Embed the source tokens and run them through the encoder."""
        embedded_src = self.src_embed(src)
        return self.encoder(embedded_src, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed the target tokens and run the decoder over the encoder memory."""
        embedded_tgt = self.tgt_embed(tgt)
        return self.decoder(memory, src_mask, embedded_tgt, tgt_mask)
    
class Generator(nn.Module):
    """Project decoder features to log-probabilities over the vocabulary.

    Args:
        d_module: size of the incoming feature dimension (d_model).
        vocab: vocabulary size, i.e. number of output classes.
    """

    def __init__(self, d_module, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_module, vocab)

    def forward(self, x):
        """Return log-softmax over the last dimension of the projected features."""
        logits = self.proj(x)
        return F.log_softmax(logits, dim=-1)

## The Encoder Module

It is composed of a stack of N identical EncoderLayers (N = 6 in the paper) followed by a final layer normalization.

Note that this implementation differs slightly from the original paper: layer normalization is applied to the *input* of each sub-layer (pre-norm) instead of after the residual addition (post-norm). Consequently, each encoder layer starts with a normalization and ends with a residual addition, which is why one final normalization is applied after the last layer.

In [None]:
def clone(module, N):
    """Return an nn.ModuleList with N independent deep copies of `module`.

    Each copy has its own parameters; nothing is shared with `module` or with
    the other copies.
    """
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas

In [None]:
class Encoder(nn.Module):
    """Stack of N cloned encoder layers followed by a final LayerNorm.

    Args:
        layer: prototype encoder layer; it must expose `.size` (the feature size).
        N: number of layers in the stack.
    """

    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clone(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, src_embed, src_mask):
        """Run the embedded source through every layer, then normalize."""
        hidden = src_embed
        for enc_layer in self.layers:
            hidden = enc_layer(hidden, src_mask)
        # Each layer ends with an un-normalized residual sum (pre-norm design),
        # so the stack output needs one more normalization.
        return self.norm(hidden)

class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learned gain and bias.

    Computes a2 * (x - mean) / (std + eps) + b2. Unlike nn.LayerNorm, it uses
    the unbiased standard deviation (torch.std's default) and adds eps to the
    std rather than to the variance.

    Args:
        feature_size: size of the last dimension of the input.
        eps: small constant added to the std for numerical stability.
    """

    def __init__(self, feature_size, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a2 = nn.Parameter(torch.ones(feature_size))   # gain
        self.b2 = nn.Parameter(torch.zeros(feature_size))  # bias
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        denom = x.std(dim=-1, keepdim=True) + self.eps
        return self.a2 * centered / denom + self.b2

In [None]:
class EncoderLayer(nn.Module):
    """One pre-norm encoder layer: self-attention then feed-forward.

    Each of the two sub-layers normalizes its input, applies dropout to its
    output, and adds that back to the residual stream:
        x = x + dropout(self_attn(norm(x)))
        x = x + dropout(feed_forward(norm(x)))

    Args:
        d_model: feature size (also exposed as `.size` for Encoder).
        self_attn: attention module called as self_attn(q, k, v, mask).
        feed_forward: position-wise feed-forward module.
        dropout: dropout probability for the sub-layer outputs.
    """

    def __init__(self, d_model, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout)
        self.norm_attn_input = LayerNorm(d_model)
        self.norm_fc_input = LayerNorm(d_model)
        self.size = d_model

    def forward(self, x, mask):
        # Sub-layer 1: self-attention over the normalized input.
        h = self.norm_attn_input(x)
        x = x + self.dropout(self.self_attn(h, h, h, mask))
        # Sub-layer 2: feed-forward over the normalized residual stream.
        h = self.norm_fc_input(x)
        return x + self.dropout(self.feed_forward(h))

## The Decoder Module

In [None]:
# The decoder stacks N cloned DecoderLayers and normalizes the final output.
class Decoder(nn.Module):
    """Stack of N cloned decoder layers followed by a final LayerNorm.

    Args:
        layer: prototype decoder layer; it must expose `.size`.
        N: number of layers in the stack.
    """

    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clone(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, memory, src_mask, tgt, tgt_mask):
        """Feed the embedded target through each layer against the encoder memory."""
        hidden = tgt
        for dec_layer in self.layers:
            hidden = dec_layer(memory, src_mask, hidden, tgt_mask)
        return self.norm(hidden)

In [None]:
# In addition to the two sub-layers in each encoder layer, the decoder layer
# has a third sub-layer, which performs multi-headed attention over the
# encoder output (memory).
class DecoderLayer(nn.Module):
    """One pre-norm decoder layer: masked self-attention, source attention, feed-forward.

    Each sub-layer normalizes its input with its own LayerNorm, applies dropout
    to its output, and adds that back to the residual stream.

    Args:
        d_model: feature size (also exposed as `.size` for Decoder).
        src_attn: attention module attending from the target to `memory`.
        self_attn: attention module attending within the target.
        feed_forward: position-wise feed-forward module.
        dropout: dropout probability for the sub-layer outputs.
    """

    def __init__(self, d_model, src_attn, self_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.norm_self_attn_input = LayerNorm(d_model)
        self.norm_src_attn_input = LayerNorm(d_model)
        self.norm_fc_input = LayerNorm(d_model)
        self.src_attn = src_attn
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout)
        self.size = d_model

    def forward(self, memory, src_mask, tgt, tgt_mask):
        """Apply the three sub-layers to the embedded target `tgt`.

        Args:
            memory: encoder output used as keys/values for source attention.
            src_mask: mask for the source positions.
            tgt: target representations (residual stream input).
            tgt_mask: mask for target self-attention (padding + future positions).
        """
        x = tgt

        # 1. Masked self-attention over the target.
        norm_x = self.norm_self_attn_input(x)
        x = x + self.dropout(self.self_attn(norm_x, norm_x, norm_x, tgt_mask))

        # 2. Attention over the encoder memory, normalized by the norm whose
        #    name matches this sub-layer.
        norm_x = self.norm_src_attn_input(x)
        x = x + self.dropout(self.src_attn(norm_x, memory, memory, src_mask))

        # 3. Position-wise feed-forward.
        norm_x = self.norm_fc_input(x)
        return x + self.dropout(self.feed_forward(norm_x))

In [None]:
# Builds the causal attention mask used on the decoder side.
def subsequent_mask(size):
    """Return a (1, size, size) bool mask where entry [0, i, j] is True iff j <= i.

    True marks positions that may be attended to, so each position can see
    itself and earlier positions but never later ones.
    """
    future = torch.triu(torch.ones((1, size, size), dtype=torch.uint8), diagonal=1)
    return future == 0
# Visualize the causal mask for a length-20 sequence: row i is True (allowed)
# only for columns j <= i, i.e. each position sees itself and earlier ones.
plt.figure(figsize=(5,5))
plt.imshow(subsequent_mask(20)[0])

## The Attention Module

In [None]:
def attention(Q, K, V, mask=None, dropout=None):
    """Scaled dot-product attention.

    Args:
        Q, K, V: query/key/value tensors; the last dim of Q and K is the key size.
        mask: optional tensor broadcastable to the score matrix; positions where
            mask == 0 are blocked (their scores are set to -1e9).
        dropout: optional module applied to the attention probabilities.

    Returns:
        (weighted sum of V, attention probabilities).
    """
    scale = math.sqrt(Q.size(-1))
    scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    context = torch.matmul(weights, V)
    return context, weights

In [None]:
class MultiHeadedAttention(nn.Module):
    """Multi-headed attention: attend to different positions from different subspaces.

    Q, K and V are each linearly projected, split into `head_count` heads of
    size model_size // head_count, attended per head with `attention`, then
    concatenated and passed through a final linear projection.

    Args:
        head_count: number of heads; must divide model_size.
        model_size: feature size of the inputs and output.
        dropout: dropout probability applied to the attention probabilities.

    Raises:
        ValueError: if model_size is not divisible by head_count.
    """

    def __init__(self, head_count, model_size, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if model_size % head_count != 0:
            raise ValueError("model_size (%d) must be divisible by head_count (%d)"
                             % (model_size, head_count))
        # Three input projections (Q, K, V) plus one output projection.
        self.linears = clone(nn.Linear(model_size, model_size), 4)
        self.head_count = head_count
        self.head_size = model_size // head_count
        self.dropout = nn.Dropout(p=dropout)
        # Attention probabilities from the most recent forward(); None until then.
        self.attn = None

    def forward(self, Q, K, V, mask=None):
        if mask is not None:
            # Add a head axis so the same mask is used for all heads.
            mask = mask.unsqueeze(1)
        batch_size = Q.size(0)
        # Project and split into heads: (batch, head_count, seq, head_size).
        q_trans, k_trans, v_trans = [
            lin(x).view(batch_size, -1, self.head_count, self.head_size).transpose(1, 2)
            for lin, x in zip(self.linears, (Q, K, V))
        ]
        x, self.attn = attention(q_trans, k_trans, v_trans, mask=mask, dropout=self.dropout)
        # Merge heads back: (batch, seq, head_count * head_size).
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.head_count * self.head_size)
        return self.linears[-1](x)

## Position-wise Feed-Forward Network

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Two linear layers with a ReLU (and dropout) in between.

    The same weights are applied independently at every position:
        lt2(dropout(relu(lt1(x))))

    Args:
        d_model: input/output feature size.
        d_ff: hidden feature size.
        dropout: dropout probability on the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.lt1 = nn.Linear(d_model, d_ff)
        self.lt2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.lt1(x))
        hidden = self.dropout(hidden)
        return self.lt2(hidden)

## Word Embedding

In [None]:
class Embeddings(nn.Module):
    """Learned token embedding of size d_model, scaled by sqrt(d_model).

    Args:
        vocab: vocabulary size.
        d_model: embedding dimension; stored as `.d_model` for later lookups.
    """

    def __init__(self, vocab, d_model):
        super(Embeddings, self).__init__()
        self.embeddings = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.embeddings(x) * scale

## Position Embedding

In [None]:
class PositionEmbedding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Because sin/cos at position p + k are linear combinations of sin/cos at
    position p, with coefficients that depend only on k, the encoding at p + k
    is a fixed linear transformation of the encoding at p.

    Args:
        d_model: feature size; odd values are supported (the extra column is a sine).
        max_len: longest sequence length that can be encoded.
        dropout: dropout probability applied after the addition.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionEmbedding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(-1 * torch.arange(0, d_model, 2, dtype=torch.float) * math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column;
        # slicing div_term keeps the shapes consistent (a no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Shape (1, max_len, d_model) so it broadcasts over the batch dimension.
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add encodings for the first x.size(1) positions (x.size(1) <= max_len)."""
        # Registered buffers never require grad, so no Variable wrapper is needed.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

## The Full Model

In [None]:
def make_model(src_vocab, tgt_vocab, d_model=512, d_ff=2048, head_count=8, dropout=0.1, layer_count=6):
    """Construct a full encoder-decoder Transformer.

    Args:
        src_vocab: source vocabulary size.
        tgt_vocab: target vocabulary size.
        d_model: model/embedding dimension.
        d_ff: hidden size of the position-wise feed-forward layers.
        head_count: number of attention heads (must divide d_model).
        dropout: dropout probability used throughout the model, including
            after the positional encoding.
        layer_count: number of encoder layers and of decoder layers.

    Returns:
        An EncoderDecoder whose parameters with more than one dimension are
        Xavier-uniform initialized.
    """
    mh_attn = MultiHeadedAttention(head_count, d_model, dropout=dropout)
    feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout=dropout)
    # Honor the `dropout` argument here as well (it was hard-coded to 0.1).
    pe = PositionEmbedding(d_model, dropout=dropout)
    # The prototypes are deep-copied (here and by clone() inside Encoder/Decoder),
    # so no weights are shared between layers.
    encoder_layer = EncoderLayer(d_model, mh_attn, feed_forward, dropout=dropout)
    encoder = Encoder(encoder_layer, layer_count)
    decoder_layer = DecoderLayer(d_model, copy.deepcopy(mh_attn), copy.deepcopy(mh_attn), copy.deepcopy(feed_forward), dropout)
    decoder = Decoder(decoder_layer, layer_count)
    generator = Generator(d_model, tgt_vocab)
    src_embed = nn.Sequential(Embeddings(src_vocab, d_model), pe)
    tgt_embed = nn.Sequential(Embeddings(tgt_vocab, d_model), copy.deepcopy(pe))
    model = EncoderDecoder(encoder, decoder, generator, src_embed, tgt_embed)

    # Xavier-uniform init for weight matrices / embedding tables; 1-D
    # parameters (biases, LayerNorm gain/bias) keep their defaults.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model

In [None]:
# Smoke test: build a tiny model (vocab 10 -> 10, d_model=16, d_ff=32) to
# check that all the pieces fit together.
tmp_model = make_model(10, 10, 16, 32)

## Training

### Batching & Masking

Each word in a target training sentence may attend only to itself and the words before it, so a target mask is created. All words except the last one become the decoder input (`self.tgt`), and all words except the first one become the labels (`self.tgt_y`).

In [None]:
class Batch:
    """Hold a batch of token ids together with the masks needed for training.

    Args:
        src: source token ids, batch-major (batch, src_len).
        tgt: optional target token ids, batch-major (batch, tgt_len). When
            given, it is split into decoder input `tgt` (all but the last
            token) and labels `tgt_y` (all but the first token).
        pad: padding token id.

    Attributes:
        src_mask: (batch, 1, src_len) bool mask, True at non-pad positions.
        tgt_mask: (batch, tgt_len-1, tgt_len-1) mask combining padding and causality.
        ntokens: number of non-pad label tokens, as a CPU FloatTensor.
    """

    def __init__(self, src, tgt=None, pad=0):
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if tgt is None:
            return
        self.tgt = tgt[:, :-1]
        self.tgt_mask = self.make_std_mask(self.tgt, pad)
        self.tgt_y = tgt[:, 1:]
        self.ntokens = (self.tgt_y != pad).sum().type(torch.FloatTensor)

    def make_std_mask(self, tgt, pad):
        """Mask out padding and future positions of the decoder input."""
        not_pad = (tgt != pad).unsqueeze(-2)
        causal = subsequent_mask(tgt.size(-1)).type_as(not_pad)
        return not_pad & causal

### A Generic Training Loop

Given a batch data iterator, a model and a loss function, run a training loop and report loss (and other information) every 50 steps.

In [None]:
def run_epoch(batch_iter, model, loss_compute, log_interval=50):
    """Run one pass over `batch_iter` and return the average loss per token.

    Args:
        batch_iter: iterable of Batch-like objects (src, src_mask, tgt,
            tgt_mask, tgt_y, ntokens).
        model: called as model(src, src_mask, tgt, tgt_mask).
        loss_compute: called as loss_compute(out, tgt_y, ntokens); must
            return the (un-normalized) total loss of the batch.
        log_interval: print loss and throughput every `log_interval` steps.

    Returns:
        total loss / total number of tokens.

    Raises:
        ZeroDivisionError: if `batch_iter` yields no batches.
    """
    total_loss = 0
    tokens = 0
    total_tokens = 0
    start = time.time()
    for i, batch in enumerate(batch_iter):
        # Call the module rather than .forward() so nn.Module hooks run.
        out = model(batch.src, batch.src_mask, batch.tgt, batch.tgt_mask)
        loss = loss_compute(out, batch.tgt_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % log_interval == 0:
            elapsed = time.time() - start
            # A coarse clock can report 0 s for a fast step; avoid dividing by zero.
            rate = tokens / elapsed if elapsed > 0 else float("inf")
            print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                    (i, loss / batch.ntokens, rate))
            start = time.time()
            tokens = 0
    return total_loss / total_tokens

### the Batch Size Function

The training uses the [torchtext package](https://github.com/pytorch/text) to access datasets and preprocess them. For dynamic batching, torchtext needs a function that computes the (padded) size of a batch.

In [None]:
global max_src_in_batch, max_tgt_in_batch
def my_batch_size_fn(new, count, sofar):
    """Return the padded size of the batch after adding example `new`.

    `count` is the number of examples in the batch including `new`; a value
    of 1 starts a new batch. The running maximum source/target lengths are
    kept in module globals between calls. The target length gets +2 to
    leave room for the BOS/EOS tokens the target field adds.
    """
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)
    # Every example is padded to the longest one in the batch.
    padded_src = count * max_src_in_batch
    padded_tgt = count * max_tgt_in_batch
    return max(padded_src, padded_tgt)

### the Optimizer

The Adam optimizer with $\beta_1=0.9$, $\beta_2=0.98$ and $\epsilon=10^{-9}$ is used with a varying learning rate. The learning rate increases linearly for the first $warmup\_step$ steps and afterwards decreases proportionally to the inverse square root of the step number, $1/\sqrt{step\_number}$. Additionally, the learning rate is inversely proportional to $\sqrt{d_{model}}$ and scaled by a constant $factor$, i.e.

$lrate = factor \times d_{model}^{-1/2} \times \min\left\{{step\_number}^{-1/2},\frac{step\_number}{{warmup\_step}^{3/2}}\right\}$

In [None]:
class NoamOpt:
    """Optimizer wrapper that sets the learning rate from the Noam schedule.

        lrate = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)

    Args:
        model_size: model dimension d_model.
        factor: constant scale of the schedule.
        warmup: number of linear warm-up steps.
        optimizer: wrapped torch optimizer; its per-group 'lr' is overwritten
            on every step().
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        """Advance the schedule, update every param group's lr, then step the optimizer."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return `lrate` for `step` (defaults to the current step count).

        For step <= 0 this returns 0.0, the limit of the warm-up branch;
        evaluating the formula directly would compute 0 ** -0.5 and raise.
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
        
def get_std_opt(model):
    """Return the standard NoamOpt setup (factor 2, 4000 warm-up steps).

    Wraps Adam(betas=(0.9, 0.98), eps=1e-9); the initial lr of 0 is replaced
    by the schedule on the first step. d_model is read from the model's
    source Embeddings module.
    """
    d_model = model.src_embed[0].d_model
    adam = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(d_model, 2, 4000, adam)

### the Label Smoothing

We use label smoothing along with the KL divergence for the loss function. 

In [None]:
class LabelSmoothing(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    For each row, the true class gets 1 - smoothing, the padding class gets 0,
    and the remaining `smoothing` mass is spread evenly over the other
    size - 2 classes. Rows whose target is the padding index are zeroed
    entirely, so they contribute nothing to the loss.

    Args:
        size: vocabulary size (number of classes).
        padding_idx: index of the padding token.
        smoothing: amount of probability mass moved away from the true class.
    """

    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.size = size
        self.smoothing = smoothing
        self.padding_index = padding_idx
        # Last target distribution built by forward() (kept for inspection).
        self.true_dist = None

    def forward(self, x, target):
        """Return the summed KL loss.

        Args:
            x: log-probabilities, shape (N, size).
            target: class indices, shape (N,).

        Raises:
            ValueError: if x.size(1) != size.
        """
        if x.size(1) != self.size:
            raise ValueError("expected %d classes, got %d" % (self.size, x.size(1)))
        true_dist = x.detach().clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.detach().unsqueeze(1), 1.0 - self.smoothing)
        true_dist[:, self.padding_index] = 0
        # Zero every row whose target is padding. A broadcast boolean mask
        # handles zero, one or many such rows alike; nonzero().squeeze()
        # collapsed to a 0-dim index when exactly one row was padding.
        pad_rows = (target.detach() == self.padding_index).unsqueeze(1)
        true_dist.masked_fill_(pad_rows, 0.0)
        self.true_dist = true_dist
        return self.criterion(x, true_dist)

### Synthetic Data

In [None]:
def data_gen(V, batch, nbatches, seq_len=10):
    """Generate random data for a src-tgt copy task.

    Yields `nbatches` Batch objects, each with `batch` sequences of `seq_len`
    token ids drawn uniformly from [1, V). Position 0 is forced to 1 (used as
    the start symbol), and 0 is never generated, so it is free to act as
    padding. Source and target are the same tensor: the model learns to copy.
    """
    for _ in range(nbatches):
        # Explicit int64: numpy's default integer is 32-bit on some platforms
        # (e.g. Windows), and nn.Embedding requires Long indices.
        data = torch.from_numpy(np.random.randint(1, V, size=(batch, seq_len), dtype=np.int64))
        data[:, 0] = 1
        yield Batch(data, data, 0)

### Loss Computation

In [None]:
class SimpleLossCompute:
    """Compute the loss of one batch and, when an optimizer is given, train on it.

    Args:
        generator: maps decoder output to log-probabilities.
        criterion: loss over (N, vocab) predictions and (N,) targets.
        opt: NoamOpt-like object (step() and .optimizer.zero_grad()); pass
            None for evaluation, in which case no gradients are computed.
    """

    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y, norm):
        """Return the batch's total loss (normalized loss * norm)."""
        x = self.generator(x)
        loss = self.criterion(x.contiguous().view(-1, x.size(-1)),
                              y.contiguous().view(-1)) / norm
        # Only backpropagate when training. Calling backward() in evaluation
        # accumulated gradients that leaked into the next optimizer step,
        # since zero_grad() only runs after a step.
        if self.opt is not None:
            loss.backward()
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return loss.item() * norm

### Greedy Decoding

In [None]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Greedily decode up to `max_len` tokens for a single source sequence.

    Starts from `start_symbol` and repeatedly appends the most probable next
    token. Returns a (1, max_len) tensor of token ids with src's dtype/device,
    including the start symbol; it does not stop early on an end token.
    """
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src)
    for _ in range(max_len - 1):
        causal = subsequent_mask(ys.size(1)).type_as(src)
        out = model.decode(memory, src_mask, ys, causal)
        log_probs = model.generator(out[:, -1])
        _, best = torch.max(log_probs, dim=1)
        next_token = best[0]
        step = torch.ones(1, 1).type_as(src).fill_(next_token)
        ys = torch.cat([ys, step], dim=1)
    return ys

### Try a Simple Copy Task

In [None]:
# Train the simple copy task: the model learns to reproduce its random input
# sequence. Set TRY_COPY_TASK = False to skip the work in this cell.
TRY_COPY_TASK = True
if TRY_COPY_TASK:
    # Vocabulary size; data_gen draws tokens from 1..V-1 and 0 is padding.
    V = 11
    criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)
    model = make_model(V, V, layer_count=2)
    # Noam schedule with factor 1 and 400 warm-up steps (get_std_opt uses
    # factor 2 and 4000) for this small problem.
    model_opt = NoamOpt(model.src_embed[0].d_model, 1, 400,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

    for epoch in range(10):
        # Train on 20 batches of 30 sequences ...
        model.train()
        run_epoch(data_gen(V, 30, 20), model, 
                  SimpleLossCompute(model.generator, criterion, model_opt))
        # ... then print the average per-token loss on 5 fresh batches.
        model.eval()
        print(run_epoch(data_gen(V, 30, 5), model, 
                        SimpleLossCompute(model.generator, criterion, None)))

    # Greedily decode a fixed sequence; a trained model should echo it back.
    model.eval()
    src = Variable(torch.LongTensor([[1,2,3,4,5,6,7,8,9,10]]) )
    # No padding: every one of the 10 source positions may be attended to.
    src_mask = Variable(torch.ones(1, 1, 10) )
    print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1))

### Data Loading

In [None]:
from torchtext import data, datasets
import spacy

#!python -m spacy download en
#!python -m spacy download de

### Tight batching

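Tight batching reduces padding and therefore speeds up training. During training, we take a large pool of examples from torchtext (`batch_size * 100`), sort it by example length (source, then target), split it into batches of similarly sized examples, and shuffle the order of those batches. During validation, examples are only sorted within each batch.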

In [None]:
class TightBatchingIterator(data.Iterator):
    """
    torchtext Iterator that builds "tight" batches of similarly sized
    examples to reduce padding and improve training speed.

    All constructor arguments are forwarded unchanged to `data.Iterator`;
    only `create_batches` is overridden.
    """
    def __init__(self, dataset, batch_size, sort_key=None, device=None,
                 batch_size_fn=None, train=True,
                 repeat=False, shuffle=None, sort=None,
                 sort_within_batch=None):
        super(TightBatchingIterator, self).__init__(dataset, batch_size, \
                sort_key=sort_key, device=device, \
                batch_size_fn=batch_size_fn, train=train, \
                repeat=repeat, shuffle=shuffle, sort=sort, \
                sort_within_batch=sort_within_batch)
        # Filled in by create_batches() (presumably invoked by the torchtext
        # base class at the start of each epoch -- confirm for your version).
        self.batches = None

    def create_batches(self):
        """
        Training: split the data into pools of `batch_size * 100`, sort each
        pool by `sort_key`, cut it into batches with `batch_size_fn`, and
        yield those batches in shuffled order (lazily, via a generator).
        Non-training: batch the data in its original order and only sort
        the examples inside each batch.
        """
        if self.train:
            def pool(data_in, random_shuffler):
                # No batch_size_fn is passed for the pool, so its size is
                # measured with torchtext's default (presumably example count).
                for macro_batch in data.batch(data_in, self.batch_size * 100):
                    one_batch_iterator = data.batch(sorted(macro_batch, key=self.sort_key), \
                                           self.batch_size, self.batch_size_fn)
                    # Shuffle batch order so training does not see lengths in sorted order.
                    for one_batch in random_shuffler(list(one_batch_iterator)):
                        yield one_batch
            self.batches = pool(self.data(), self.random_shuffler)
        else:
            self.batches = []
            for one_batch in data.batch(self.data(), self.batch_size, self.batch_size_fn):
                self.batches.append(sorted(one_batch, key=self.sort_key))

def rebatch(padding_idx, batch):
    """Convert a sequence-major torchtext batch into a batch-major Batch.

    torchtext yields (seq_len, batch) tensors; Batch expects (batch, seq_len).
    """
    return Batch(batch.src.transpose(0, 1), batch.trg.transpose(0, 1), padding_idx)

### GPU/CPU Training

In [None]:
class MultiGPULossCompute:
    """
    Compute the loss (and optionally train) with the generator and criterion
    replicated across several GPUs.

    The decoder output is scattered across `devices` and processed in column
    chunks of `chunk_size` positions, so the full (batch, seq, vocab)
    log-probability tensor never has to be materialized at once. Gradients
    w.r.t. the decoder output are collected per chunk and back-propagated
    through the rest of the model in a single backward call.

    Args:
        generator: output projection module (replicated on every call).
        criterion: loss module, replicated once at construction time.
        opt: NoamOpt-like optimizer wrapper; None means evaluation only.
        devices: list of CUDA device ids (required; devices[0] gathers results).
        chunk_size: number of sequence positions processed per chunk.
    """
    def __init__(self, generator, criterion, opt=None, devices=None, chunk_size=5):
        self.devices = devices
        self.generator = generator
        # The criterion has no trainable state, so replicate it once here.
        self.criterion = nn.parallel.replicate(criterion, devices=devices)
        self.opt = opt
        self.chunk_size = chunk_size

    def __call__(self, out, target, normalize):
        """Return the batch's total loss (normalized loss * normalize)."""
        total = 0.0
        # The generator's weights change between steps, so replicate it per call.
        generator = nn.parallel.replicate(self.generator, devices=self.devices)
        out_scatter = nn.parallel.scatter(out, target_gpus=self.devices)
        out_grad = [[] for _ in out_scatter]
        target_scatter = nn.parallel.scatter(target, target_gpus=self.devices)

        chunk_size = self.chunk_size
        for i in range(0, out_scatter[0].size(1), chunk_size):
            # Detached leaf slices so each chunk's gradient can be collected.
            out_column = [[o[:, i:i + chunk_size].detach().requires_grad_(self.opt is not None)]
                          for o in out_scatter]
            gen = nn.parallel.parallel_apply(generator, out_column)
            pred_label = [(g.contiguous().view(-1, g.size(-1)),
                           t[:, i:i + chunk_size].contiguous().view(-1))
                          for g, t in zip(gen, target_scatter)]
            losses = nn.parallel.parallel_apply(self.criterion, pred_label)
            # Per-device losses are 0-dim; gather returns them as a vector.
            chunk_loss = nn.parallel.gather(losses, target_device=self.devices[0])
            chunk_loss = chunk_loss.sum() / normalize
            total += chunk_loss.item()

            if self.opt is not None:
                chunk_loss.backward()
                for j in range(len(losses)):
                    out_grad[j].append(out_column[j][0].grad.detach().clone())

        # Back-propagate the collected output gradients through the transformer.
        if self.opt is not None:
            out_grad = [torch.cat(og, dim=1) for og in out_grad]
            grad = nn.parallel.gather(out_grad, target_device=self.devices[0])
            out.backward(gradient=grad)
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return total * normalize

### The Training Class

In [None]:
from torchtext import data, datasets
import spacy

class IWSLTTrainer:
    """Train a Transformer translation model on IWSLT using CPU or GPU(s).

    Args:
        use_gpu: if True, move the model and criterion to CUDA, wrap the model
            in nn.DataParallel and compute the loss with MultiGPULossCompute.
        devices: device ids for nn.DataParallel / MultiGPULossCompute (GPU only).
        from_lang: spaCy model name and IWSLT extension of the source language.
        to_lang: spaCy model name and IWSLT extension of the target language.
    """
    def __init__(self, use_gpu=False, devices=None, from_lang="de", to_lang="en"):
        self.use_gpu = use_gpu
        self.devices = devices

        spacy_from = spacy.load(from_lang)
        spacy_to = spacy.load(to_lang)
        tokenizer_from = lambda x: [tok.text for tok in spacy_from.tokenizer(x)]
        tokenizer_to = lambda x: [tok.text for tok in spacy_to.tokenizer(x)]

        bos_word = '<s>'
        eos_word = '</s>'
        blank_word = '<blank>'
        src_field = data.Field(tokenize=tokenizer_from, pad_token=blank_word)
        tgt_field = data.Field(tokenize=tokenizer_to, init_token=bos_word,
                               eos_token=eos_word, pad_token=blank_word)
        # Keep only sentence pairs where both sides have at most max_len tokens.
        max_len = 100
        self.train, self.val, self.test = datasets.IWSLT.splits(
            exts=('.' + from_lang, '.' + to_lang), fields=(src_field, tgt_field),
            filter_pred=lambda x: len(vars(x)['src']) <= max_len and
                                  len(vars(x)['trg']) <= max_len)

        min_freq = 2
        src_field.build_vocab(self.train.src, min_freq=min_freq)
        tgt_field.build_vocab(self.train.trg, min_freq=min_freq)
        self.pad_idx = tgt_field.vocab.stoi[blank_word]
        self.tgt_field = tgt_field
        self.src_field = src_field

        self.model = make_model(len(src_field.vocab), len(tgt_field.vocab), layer_count=6)
        if use_gpu:
            self.model.cuda()
        self.criterion = LabelSmoothing(size=len(tgt_field.vocab),
                                        padding_idx=self.pad_idx, smoothing=0.1)
        if use_gpu:
            self.criterion.cuda()

        batch_size = 12000
        self.train_iter = TightBatchingIterator(self.train, batch_size=batch_size, device=None,
                            repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=my_batch_size_fn, train=True)
        self.valid_iter = TightBatchingIterator(self.val, batch_size=batch_size, device=None,
                            repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=my_batch_size_fn, train=False)
        if use_gpu:
            self.model_par = nn.DataParallel(self.model, device_ids=devices)

        self.model_opt = NoamOpt(self.model.src_embed[0].d_model, 1, 2000,
            torch.optim.Adam(self.model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

    def _loss_compute(self, opt):
        """Build the loss computer for this device setup (opt=None for evaluation)."""
        if self.use_gpu:
            return MultiGPULossCompute(self.model.generator, self.criterion,
                                       devices=self.devices, opt=opt)
        # MultiGPULossCompute cannot run without CUDA devices, so the CPU
        # path must use SimpleLossCompute for validation as well as training.
        return SimpleLossCompute(self.model.generator, self.criterion, opt=opt)

    def run_training(self, epoch_count, log_interval=50):
        '''Train for `epoch_count` epochs, printing the validation loss after each.'''
        model = self.model_par if self.use_gpu else self.model
        for _ in range(epoch_count):
            model.train()
            run_epoch((rebatch(self.pad_idx, b) for b in self.train_iter),
                      model, self._loss_compute(self.model_opt),
                      log_interval=log_interval)
            model.eval()
            loss = run_epoch((rebatch(self.pad_idx, b) for b in self.valid_iter),
                             model, self._loss_compute(None),
                             log_interval=log_interval)
            print(loss)

    def run_validating(self):
        """
        Greedily translate the first sentence of the first validation batch
        and print it next to its reference translation.
        """
        # Decode deterministically (no dropout) and without tracking gradients.
        self.model.eval()
        # The iterators are built with device=None, so move the source to
        # wherever the model's parameters live.
        device = next(self.model.parameters()).device
        with torch.no_grad():
            for batch in self.valid_iter:
                src = batch.src.transpose(0, 1)[:1].to(device)
                src_mask = (src != self.src_field.vocab.stoi["<blank>"]).unsqueeze(-2)
                out = greedy_decode(self.model, src, src_mask,
                                    max_len=60, start_symbol=self.tgt_field.vocab.stoi["<s>"])
                print("Translation:", end="\t")
                # Skip position 0 (the start symbol); stop at end-of-sentence.
                for j in range(1, out.size(1)):
                    sym = self.tgt_field.vocab.itos[out[0, j].item()]
                    if sym == "</s>":
                        break
                    print(sym, end=" ")
                print()
                print("Target:", end="\t")
                # batch.trg is sequence-major: column 0 is the first sentence.
                for j in range(1, batch.trg.size(0)):
                    sym = self.tgt_field.vocab.itos[batch.trg[j, 0].item()]
                    if sym == "</s>":
                        break
                    print(sym, end=" ")
                print()
                break

### Training

In [None]:
def train_on_cpu():
    """Run IWSLT training on the CPU for debugging: 10 epochs, logging every batch."""
    trainer = IWSLTTrainer()
    trainer.run_training(10, log_interval=1)

# NOTE(review): running this cell loads IWSLT through torchtext (presumably
# downloading it on first use) and trains a 6-layer model on the CPU, which
# is very slow.
train_on_cpu()