<a href="https://colab.research.google.com/github/tchtinku/TorchScript/blob/main/NeurIPS_2018_PyTorch1_0_NMT_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
#This will install a preview version of PyTorch 1.0
#This version is necessary for some features such as torch.jit.save to work
#This may take a few minutes
!pip install https://download.pytorch.org/whl/nightly/cu90/torch_nightly-1.0.0.dev20181128-cp36-cp36m-linux_x86_64.whl

[31mERROR: torch_nightly-1.0.0.dev20181128-cp36-cp36m-linux_x86_64.whl is not a supported wheel on this platform.[0m[31m
[0m

In [2]:
import torch
print(torch.__version__)

2.3.0+cu121


In [4]:
#Fetch IWSLT 2014 German-English data
import urllib.request
url = "https://download.pytorch.org/models/translate/iwslt14/data.tar.gz"
local_archive_name = "data.tar.gz"
urllib.request.urlretrieve(url, local_archive_name)

#Extract Files
!tar xvzf data.tar.gz

data/
data/valid.tok.bpe.en
data/valid.tok.de
data/train.tok.en
data/train.tok.de
data/valid.tok.bpe.de
data/valid.tok.en
data/test.tok.de
data/train.tok.bpe.en
data/test.tok.bpe.en
data/train.tok.bpe.de
data/test.tok.bpe.de
data/test.tok.en


In [None]:
#simple class to induce a vocabulary from a text file
class Dictionary:

  def __init__(self):
    self.pad_index = 0
    self.eos_index = 1
    self.unk_index = 2
    self.token_indices = {
        "<pad>": self.pad_index,
        "<eos>": self.eos_index,
        "<unk>": self.unk_index,
    }
    self.tokens = ["<pad>", "<eos>", "<unk>"]

    @staticmethod
    def induce_from_file(filename, max_size=50000):
      from collections import Counter
      text = open(filename).read()
      token_counts = Counter(text.split())

      d = Dictionary()
      for token, _ in token_counts.most_common(max_size):
        d.token_indices[token] = len(d.token_indices)
        d.tokens.append(token)

      return d

      def get_index(self, token):
        return self.token_indices.get(token, self.unk_index)

      def size(self):
        return len(self.token_indices)

      def get_token(self, index):
        if index > len(self.tokens):
          return "<unk>"
        return self.tokens[index]

In [None]:
src_dict = Dictionary.induce_from_file("data/train.tok.de")
print("Loaded source vocabulary of size: ", src_dict.size())
trg_dict = Dictionary.induce_from_file("data/train.tok.en")
print("Loaded target vocabulary of size: ", trg_dict.size())

In [None]:
from torch.nn.utils.rnn import (
    pack_padded_sequence,
    pad_packed_sequence,
)

class LstmEncoder(torch.nn.module):
  def __init__(self, embed_dim, hidden_dim, vocab_size):
    super().__init__()

    self.embed_dim = embed_dim
    self.hidden_dim = hidden_dim
    self.vocab_size = vocab_size

    self.embed_tokens = torch.nn.Embedding(vocab_size, embed_dim)
    torch.nn.init.uniform_(self.embed_tokens.weight, -0.1, 0.1)

    #hidden_dim is combined output dim from both directions
    self.lstm = torch.nn.LSTM(
        input_size = embed_dim,
        hidden_size = hidden_dim // 2,
        bidirectional = True
    )

  def forward(self, src_tokens, src_lengths):
    embeddings = self.embed_tokens(src_tokens)

    #Generate packed seq to deal with varying source seq length
    #packed_input is of type PackedSequence, which consists of:
    #element [0]: a tensor, the packed data, and
    #element [1]: a list of integers, the batch size for each step
    packed_input = pack_padded_sequence(embeddings, src_lengths)

    packed_output, (_, _) = self.lstm(packed_input)

    # [max_seqlen, batch_size, hidden_dim]
    unpacked_output, _ = pad_packed_sequence(packed_output)

    return unpacked_output

In [None]:
def attention(decoder_state, encoder_outputs):
  """
  decoder_state: trg_len x bsz x dim
  encoder_outputs: src_len x bsz x dim
  """
  #bsz x trg_len x dim
  decoder_state_t = decoder_state.transpose(0, 1)
  #bsz x dim x src_len
  encoder_outputs_t = encoder_outputs.permute(1, 2, 0)
  #bsz x trg_len x src_len
  dot_product = torch.bmm(decoder_state_t, encoder_outputs_t)
  #Note: including invalid (padded) positions for code simplicity
  norm_dot_product = torch.softmax(dot_product, dim=2)
  # bsz x src_len x dim
  encoder_outputs_tt = encoder_outputs.transpose(0, 1)
  # bsz x trg_len x dim
  context = torch.bmm(norm_dot_product, encoder_outputs_tt)
  # trg_len x bsz x dim
  return context.transpose(0, 1)

In [None]:
class LstmDecoder(torch.nn.Module):
  def __init__(self, embed_dim, hidden_dim, vocab_size):
    super().__init__()

    self.embed_dim = embed_dim
    self.hidden_dim = hidden_dim
    self.vocab_size = vocab_size

    self.embed_tokens = torch.nn.Embedding(vocab_size, embed_dim)
    torch.nn.init.uniform_(self.embed_tokens.weight, -0.1, 0.1)

    self.lstm = torch.nn.LSTM(
        input_size=embed_dim,
        hidden_size=hidden_dim,
    )

    self.output_projection = torch.nn.Linear(2 * hidden_dim, vocab_size)

  def forward(self, input_tokens, encoder_out, prev_state=None):
    seqlen, bsz = input_tokens.size()
    x = self.embed_tokens(input_tokens)

    if prev_state is None:
      h_prev = torch.zeros([1, bsz, self.hidden_dim]).type_as(x)
      c_prev = torch.zeros([1, bsz, elf.hidden_dim]).type_as(x)
    else:
      h_prev, c_prev = prev_state

    x, (h_next, c_next) = self.lstm(x, (h_prev, c_prev))

    encoder_text = attention(x, encoder_out)
    x = torch.cat([x, encoder_context], dim=2)

    logits = self.output_projection(x)

    return logits, (h_next, c_next)



In [None]:
class LstmSeq2Seq(torch.nn.Module):
  def __init__(
      self,
      encoder_embed_dim,
      decoder_embed_dim,
      hidden_dim,
      src_dict,
      trg_dict,
  ):
  super().__init__()
  self.src_dict = src_dict
  self.trg_dct = trg_dict
  self.encoder = LstmEncoder(
      embed_dim = encoder_embed_dim,
      hidden_dim = hidden_dim,
      vocab_size = src_dict.size(),
  )
  self.decoder = LstmDecoder(
      embed_dim=decoder_embed_dim,
      hidden_dim=hidden_dim,
      vocab_size=trg_dict.size(),
  )

  def forward(self, src_tokens, src_lengths, prev_output_tokens):
    encoder_out = self.encoder(src_tokens, src_lengths)
    decoder_out = self.decoder(prev_output_tokens, encoder_out)
    return decoder_out

In [None]:
import numpy as np

class Corpus():
  def __init__(self, src_path, trg_path, src_dict, trg_dict):
    self.src_dict = src_dict
    self.trg_dict = trg_dict


    self.src_inds = []
    for line in open(src_path):
      inds = []
      for token in line.split():
        inds.append(src_dict.get_index(token))
      self.src_inds.append(inds)


      self.trg_inds = []
      for line in open(trg_path):
        inds = []
        for token in line.split():
          inds.append(trg_dict.eos_index)
          self.trg_inds.append(inds)

      self.batches = None

  def pad_batch(self, pairs):
    """
    Input pairs is list of 2-tuples (src, trg) where each element is a list of indices.
    Output paddedis a list of 3-tuples (src, trg, src_length), which also includes the original length of the source sentence