pytorch에서 제공하는 TransformerEncoder, TransformerDecoder를 이용해서 transformer 구현한 예제


In [1]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
from torch.nn.modules import LayerNorm
from torch.nn.init import xavier_uniform_
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from torchtext.data import Dataset
import math
import os
import io
from torchtext.utils import download_from_url
from torchtext.vocab import Vocab
from collections import Counter
import tarfile
import numpy as np
import random

In [2]:
MODEL_DIM = 512
NUM_HEADS = 8
NUM_LAYERS = 6
BATCH_SIZE = 64
NUM_EPOCHS = 8
SEQ_LEN = 50

In [3]:
class PositionEmbed(nn.Module):
  def __init__(self, d_model, max_seq_length):
    super().__init__()
    self.pos_embed = nn.Embedding(max_seq_length, d_model)

  def forward(self, inputs):
    positions = torch.arange(inputs.size(1), device=inputs.device, dtype=inputs.dtype).expand(inputs.size(0), inputs.size(1)).contiguous() + 1
    pos_mask = inputs.eq(1)
    positions.masked_fill_(pos_mask, 0)
    pos_embed = self.pos_embed(positions)

    return pos_embed


In [4]:
inputs = torch.randint(0, 100, (1, 10))
pos_emb = PositionEmbed(MODEL_DIM, 100)
embed = pos_emb(inputs)
embed

tensor([[[ 2.0831,  1.5158, -1.0150,  ..., -0.2912,  0.3344, -0.6084],
         [-0.7239,  0.3984, -0.2242,  ...,  1.2703,  1.4499,  1.5684],
         [-0.3978, -0.8903,  0.7149,  ...,  1.2819,  0.4226,  1.5349],
         ...,
         [ 0.0064, -0.3452, -0.1066,  ..., -2.5091, -0.8501,  0.2440],
         [ 2.0557, -0.6697,  0.9876,  ..., -1.8837, -0.1613, -0.4974],
         [-0.0996, -0.0515,  0.5646,  ..., -1.1756, -1.2993, -2.7639]]],
       grad_fn=<EmbeddingBackward>)

In [5]:
class Transformer(nn.Module):
    def __init__(self, vocab_size_src, vocab_size_trg, d_model, nhead, num_layers, dim_feedforward, dropout, max_seq_length=100):
        super().__init__()
        self.d_model = d_model
        self.seq_len = max_seq_length
        self.embed_src = nn.Embedding(vocab_size_src, d_model)
        self.embed_trg = nn.Embedding(vocab_size_trg, d_model)
        self.pos_embed = PositionEmbed(d_model, max_seq_length * 2)

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout=0.1, activation="relu")      
        encoder_norm = LayerNorm(d_model)
        self.encoder = TransformerEncoder(encoder_layer, num_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout=0.1, activation='relu')
        decoder_norm = LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_layers, decoder_norm)

        self.fc = nn.Sequential(nn.Linear(d_model, vocab_size_trg),
                                nn.LogSoftmax(dim=-1))

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""

        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

    def rearrange(self, x):
      return torch.transpose(x, 0, 1)
    
    def get_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):
        src = self.rearrange(src)  # (batch_size, seq_len) => (seq_len, batch_size)
        tgt = self.rearrange(tgt)  # (batch_size, seq_len) => (seq_len, batch_size)
        if src.size(1) != tgt.size(1):
          raise RuntimeError("the batch number of src and tgt must be equal")

        src = self.embed_src(src) + self.pos_embed(src)
        tgt = self.embed_trg(tgt) + self.pos_embed(tgt)
        src_mask = self.get_mask(src.size()[0]).to(src.device)
        tgt_mask = self.get_mask(tgt.size()[0]).to(tgt.device)
        if src_key_padding_mask is not None:
          memory_key_padding_mask = src_key_padding_mask.clone()
        else:
          memory_key_padding_mask = None

        self.encoder_output = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        decoder_output = self.decoder(tgt, self.encoder_output, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
        decoder_output = self.fc(decoder_output)
        decoder_output = self.rearrange(decoder_output)  # (seq_len, batch_size, vocab_size) = > (batch_size, seq_len, vocab_size)
        return decoder_output

    def predict(self, src, src_key_padding_mask=None):
      src = self.rearrange(src)  # (batch_size, seq_len) => (seq_len, batch_size)
      tgt = torch.ones([1, src.size()[1]], dtype=torch.long).to(src.device)

      src = self.embed_src(src) + self.pos_embed(src)
      tgt = self.embed_trg(tgt) + self.pos_embed(tgt)
      src_mask = self.get_mask(src.size()[0]).to(src.device)
      if src_key_padding_mask is not None:
        memory_key_padding_mask = src_key_padding_mask.clone()
      else:
        memory_key_padding_mask = None

      self.encoder_output = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
      decoder_input = tgt
      for i in range(self.seq_len):
        tgt_mask = self.get_mask(decoder_input.size()[0]).to(tgt.device)
        decoder_output = self.decoder(decoder_input, self.encoder_output, tgt_mask=tgt_mask, memory_key_padding_mask=memory_key_padding_mask)  # (seq_len, batch_size, d_model)
        decoder_output = self.fc(decoder_output)
        decoder_max = torch.argmax(decoder_output, dim=-1)
        decoder_max = self.embed_trg(decoder_max) + self.pos_embed(decoder_max)
        decoder_input = torch.cat([decoder_input, decoder_max[-1:, :, :]])
    
      decoder_output = self.rearrange(decoder_output)  # (seq_len, batch_size, vocab_size) = > (batch_size, batch_size, vocab_size)
      return decoder_output


In [6]:
def download(url, out_path=None):
  filename = os.path.basename(url)
  if out_path is None:
    path = '.data'
  else:
    path = out_path
  zpath = os.path.join(path, filename)
  if not os.path.isfile(zpath):
      if not os.path.exists(os.path.dirname(zpath)):
          os.makedirs(os.path.dirname(zpath))
      print('downloading {}'.format(filename))
      download_from_url(url, zpath)
  zroot, ext = os.path.splitext(zpath)
  _, ext_inner = os.path.splitext(zroot)
  if ext == '.zip':
      with zipfile.ZipFile(zpath, 'r') as zfile:
          print('extracting')
          zfile.extractall(path)
  # tarfile cannot handle bare .gz files
  elif ext == '.tgz' or ext == '.gz' and ext_inner == '.tar':
      with tarfile.open(zpath, 'r:gz') as tar:
          dirs = [member for member in tar.getmembers()]
          tar.extractall(path=path, members=dirs)
  elif ext == '.gz':
      with gzip.open(zpath, 'rb') as gz:
          with open(zroot, 'wb') as uncompressed:
              shutil.copyfileobj(gz, uncompressed)

  return os.path.dirname(zpath)


In [7]:
train_url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz'
path = download(train_url)
src_corpus  = os.path.join(path, 'train.de')
trg_corpus  = os.path.join(path, 'train.en')

downloading training.tar.gz


training.tar.gz: 100%|██████████| 1.21M/1.21M [00:04<00:00, 249kB/s]


In [8]:
val_url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
path = download(val_url)
val_src_corpus  = os.path.join(path, 'val.de')
val_trg_corpus  = os.path.join(path, 'val.en')

downloading validation.tar.gz


validation.tar.gz: 100%|██████████| 46.3k/46.3k [00:00<00:00, 80.9kB/s]


In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [10]:
class CorpusDataset(torch.utils.data.Dataset):
    """Defines a dataset for machine translation."""
    def __init__(self, src_path, trg_path, seq_len):
      super().__init__()
      self.src_lines = self.get_lines(src_path)
      self.trg_lines = self.get_lines(trg_path)

      assert len(self.src_lines) == len(self.trg_lines)

      self.src_vocab = self.build_vocab(self.src_lines)
      self.trg_vocab = self.build_vocab(self.trg_lines)
      self.seq_len = seq_len
      self.sos = self.src_vocab.stoi['<sos>']
      self.eos = self.src_vocab.stoi['<eos>']

    def __len__(self):
      return len(self.src_lines)

    def __getitem__(self, item):
      src_sent = self.src_lines[item]
      trg_sent = self.trg_lines[item]

      words = src_sent.split()
      tokens = self.word2tokens(self.src_vocab, words) 
      trg_words  = trg_sent.split()
      trg_tokens = self.word2tokens(self.trg_vocab, trg_words, self.eos)
      trg_tokens = [self.sos] + trg_tokens
      return (torch.tensor(tokens).to(device), torch.tensor(trg_tokens).to(device))

    def build_vocab(self, lines):
      counter = Counter()
      for line in lines:
        tokens = line.split()
        counter.update(tokens)
      vocab = Vocab(counter, min_freq=1, specials=['<pad>', '<sos>', '<eos>', '<unk>'])

      return vocab

    def get_lines(self, file_path):
      lines = open(file_path, mode='r', encoding='utf-8').readlines()
      lines = [line.strip() for line in lines ]

      return lines

    def word2tokens(self, vocab, words, eos=None):
      tokens  = []
      for word in words:
        token = vocab.stoi[word]
        tokens.append(token)

      if eos is not None:
        tokens = tokens + [eos]
      padding = [vocab.stoi['<pad>'] for _ in range(self.seq_len - len(tokens))]
      tokens = tokens + padding
      return tokens

    def tokens2word(self, vocab, tokens):
      words = []
      for token in tokens:
        try:
          word = vocab.itos[token]
        except Exception:
          word = '<unk>'
        words.append(word)

      return words


In [11]:
def tokens2words(tokens, dataset, vocab):
  words = dataset.tokens2word(vocab, tokens)
  for i, word in enumerate(words):
    if word == '<eos>':
      words = words[:i+1]
      break
  words = ' '.join(words) 
  words = words.replace('<pad>', '')
  return words

In [12]:
dataset = CorpusDataset(src_corpus, trg_corpus, seq_len=SEQ_LEN)
dataloader = DataLoader(dataset,batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
for (src, tgt) in dataloader:
  print(src[0])
  print('src=', tokens2words(src[0], dataset, dataset.src_vocab))
  print(tgt[0])
  print('tgt=', tokens2words(tgt[0], dataset, dataset.trg_vocab))
  break

tensor([    4,    10,    28,    44,     5, 15975,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
       device='cuda:0')
src= Ein Mann steht neben einem Obstkarren.                                            
tensor([  1,   5,  11,   9,  32,  44,   4, 650, 714,   2,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0], device='cuda:0')
tgt= <sos> A man is standing by a fruit cart. <eos>


In [13]:
val_dataset = CorpusDataset(val_src_corpus, val_trg_corpus, seq_len=SEQ_LEN)
val_dataloader = DataLoader(val_dataset,batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
for src, tgt in val_dataloader:
  print(src[0])
  print(tgt[0])
  break

tensor([  5,  34,   8,   4, 275, 380,  73, 116, 487,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0], device='cuda:0')
tensor([   1,    5,   36,    6,    4,  288, 1454,   52, 2322,    2,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0], device='cuda:0')


In [14]:
VOCAB_SIZE_SRC = len(dataset.src_vocab)
VOCAB_SIZE_TRG = len(dataset.trg_vocab)
VOCAB_SIZE_SRC, VOCAB_SIZE_TRG

(24893, 15460)

In [15]:
model = Transformer(vocab_size_src=VOCAB_SIZE_SRC,
                    vocab_size_trg=VOCAB_SIZE_TRG,
                    d_model = MODEL_DIM,
                    nhead = NUM_HEADS,
                    num_layers = NUM_LAYERS,
                    dim_feedforward = 2048,
                    dropout = 0.1,
                    max_seq_length=SEQ_LEN)

In [16]:
PAD_IDX = dataset.src_vocab.stoi['<pad>']
criterion = nn.NLLLoss()

In [17]:
for src, tgt in dataloader:
  src_key_padding_mask = (src == PAD_IDX).to(src.device)
  print(src[0])
  print(src_key_padding_mask[0])
  
  break

tensor([  19, 1541,    8, 2197,  107,    9,   23,  305,   52,   27,  412,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0], device='cuda:0')
tensor([False, False, False, False, False, False, False, False, False, False,
        False,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
       device='cuda:0')


In [18]:
model.eval()
if torch.cuda.is_available():
  model.cuda()
with torch.no_grad():
  for src, tgt in dataloader:
    src_key_padding_mask = (src == PAD_IDX).to(src.device)
    outputs = model.predict(src, src_key_padding_mask=src_key_padding_mask)
    print(outputs[0])
    break

tensor([[ -9.5506,  -9.6070,  -9.4486,  ...,  -9.2964, -10.1216,  -9.5001],
        [ -9.6494,  -9.3505,  -9.3675,  ...,  -9.2720, -10.0462,  -9.3168],
        [ -9.6610,  -9.2874,  -9.3991,  ...,  -9.0884, -10.2051,  -9.1475],
        ...,
        [ -9.7248,  -9.0629,  -9.2523,  ...,  -9.1900, -10.0133,  -9.0814],
        [ -9.7250,  -9.0620,  -9.2519,  ...,  -9.1894, -10.0123,  -9.0812],
        [ -9.7253,  -9.0611,  -9.2515,  ...,  -9.1888, -10.0114,  -9.0810]],
       device='cuda:0')


In [19]:
def train():
  torch.manual_seed(2147483647)
  if torch.cuda.is_available():
    model.cuda()
  optimizer = optim.Adam(model.parameters(), lr=0.0001)
  model.train()
  for epoch in range(NUM_EPOCHS):
    for step , (src, tgt) in enumerate(dataloader):
      optimizer.zero_grad()
      # src_key_padding_mask = (src == PAD_IDX).to(src.device)
      # tgt_key_padding_mask = (tgt[:, :-1] == PAD_IDX).to(tgt.device)
      src_key_padding_mask=None
      tgt_key_padding_mask=None
      outputs = model(src, tgt[:, :-1], src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask)
      outputs = torch.transpose(outputs, 1, 2)  # (batch, seq_len, vocab_size)  => (batch, vocab_size, seq_len)
      loss = criterion(outputs, tgt[:, 1:])
      print('epoch: {}, step: {}, loss: {:.3f}'.format(epoch + 1, step + 1, loss))
      loss.backward()
      nn.utils.clip_grad_norm_(model.parameters(), 5)
      optimizer.step()

In [20]:
train()

epoch: 1, step: 1, loss: 9.879
epoch: 1, step: 2, loss: 6.006
epoch: 1, step: 3, loss: 3.969
epoch: 1, step: 4, loss: 3.130
epoch: 1, step: 5, loss: 2.961
epoch: 1, step: 6, loss: 3.091
epoch: 1, step: 7, loss: 2.987
epoch: 1, step: 8, loss: 2.846
epoch: 1, step: 9, loss: 2.771
epoch: 1, step: 10, loss: 2.817
epoch: 1, step: 11, loss: 2.874
epoch: 1, step: 12, loss: 2.700
epoch: 1, step: 13, loss: 2.785
epoch: 1, step: 14, loss: 2.759
epoch: 1, step: 15, loss: 2.596
epoch: 1, step: 16, loss: 2.677
epoch: 1, step: 17, loss: 2.622
epoch: 1, step: 18, loss: 2.585
epoch: 1, step: 19, loss: 2.482
epoch: 1, step: 20, loss: 2.426
epoch: 1, step: 21, loss: 2.263
epoch: 1, step: 22, loss: 2.268
epoch: 1, step: 23, loss: 2.076
epoch: 1, step: 24, loss: 2.122
epoch: 1, step: 25, loss: 2.086
epoch: 1, step: 26, loss: 2.097
epoch: 1, step: 27, loss: 2.081
epoch: 1, step: 28, loss: 1.981
epoch: 1, step: 29, loss: 2.038
epoch: 1, step: 30, loss: 2.069
epoch: 1, step: 31, loss: 2.011
epoch: 1, step: 3

In [21]:
def eval():
  model.eval()
  with torch.no_grad():
    for step , (src, tgt) in enumerate(val_dataloader):
      outputs = model.predict(src)
      outputs = torch.transpose(outputs, 1, 2)  # (batch, seq_len, vocab_size)  => (batch, vocab_size, seq_len)
      loss = criterion(outputs, tgt[:, 1:])
      print(step, loss)


In [22]:
eval()

0 tensor(2.8666, device='cuda:0')
1 tensor(2.7830, device='cuda:0')
2 tensor(3.1528, device='cuda:0')
3 tensor(2.8343, device='cuda:0')
4 tensor(2.9242, device='cuda:0')
5 tensor(2.8963, device='cuda:0')
6 tensor(3.1037, device='cuda:0')
7 tensor(2.8479, device='cuda:0')
8 tensor(3.0067, device='cuda:0')
9 tensor(2.8770, device='cuda:0')
10 tensor(3.0304, device='cuda:0')
11 tensor(2.8544, device='cuda:0')
12 tensor(2.8019, device='cuda:0')
13 tensor(2.9889, device='cuda:0')
14 tensor(2.8037, device='cuda:0')


In [23]:
test_dataset = CorpusDataset(val_src_corpus, val_trg_corpus, seq_len=SEQ_LEN)
test_dataloader = DataLoader(test_dataset,batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [24]:
def predict():
  model.eval()
  with torch.no_grad():
    for step , (src, tgt) in enumerate(test_dataloader):
      if step > 10:
        break
      #outputs = model(src, tgt[:, :-1])
      outputs = model.predict(src)
      tokens = torch.argmax(outputs, dim=-1)
      sents = []
      for t in tokens:
        sent = tokens2words(t, test_dataset, test_dataset.trg_vocab)
        sents.append(sent)
      for sent, tgt in zip(sents, tgt.cpu().numpy()):
        print(step)
        print('predition: ', sent)
        target = tokens2words(tgt[1:], test_dataset, test_dataset.trg_vocab)
        print('target   : ', target)
        print()
        break


In [25]:
predict()


0
predition:  A man in a An and an white The at the middle yellow guitar. camera. all <eos>
target   :  A young man in a green sweatshirt reads a newspaper on the beach. <eos>

1
predition:  A suit and a shorts, woman men up a side <eos>
target   :  A bride and a groom at their wedding kissing <eos>

2
predition:  A blond during is ceiling. in the making of a brick <eos>
target   :  A brown, black, and white dog barking up a tree. <eos>

3
predition:  A from Twp is floor Cami looking a shirt and shirt and an single in the wall <eos>
target   :  There are multiple people going over a loop on an inverted roller coaster. <eos>

4
predition:  A black sitting men in front of a performing ready in front of a fallen stage <eos>
target   :  Two men converse near a wall with graffiti on it. <eos>

5
predition:  A blond during is field. a barefoot while guitar. is The street of a circular <eos>
target   :  A yellow dog carries a ball in its mouth on the beach. <eos>

6
predition:  A baseball dur