In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator

import spacy
import numpy as np

import random
import math
import time

In [2]:
SEED = 1234

# Seed every random number generator this notebook relies on (Python's
# `random`, NumPy, PyTorch CPU, PyTorch current CUDA device) so runs repeat.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# Ask cuDNN for deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [3]:
# One-time setup (run in a shell): python -m spacy download de && python -m spacy download en

In [4]:
# Load the spaCy pipelines whose tokenizers are used below. The full package
# names work on spaCy 2.x (where `python -m spacy download de` installs
# `de_core_news_sm`) and on spaCy 3.x, which removed the short 'de'/'en'
# shortcut names.
spacy_de = spacy.load('de_core_news_sm')
spacy_en = spacy.load('en_core_web_sm')

In [5]:

def _token_texts(nlp, text):
    """Run `nlp`'s tokenizer over `text` and return the token strings in order."""
    return list(token.text for token in nlp.tokenizer(text))


def tokenize_de(text):
    """Split a German sentence into a list of token strings."""
    return _token_texts(spacy_de, text)


def tokenize_en(text):
    """Split an English sentence into a list of token strings."""
    return _token_texts(spacy_en, text)

In [6]:

# Source (German) and target (English) fields differ only in their tokenizer:
# both wrap every sentence in <sos> ... <eos> and lowercase all tokens.
SRC, TRG = (
    Field(tokenize=tokenizer, init_token='<sos>', eos_token='<eos>', lower=True)
    for tokenizer in (tokenize_de, tokenize_en)
)



In [7]:
# Multi30k German->English: the extension order selects source then target,
# matching the (SRC, TRG) field order.
train_data, valid_data, test_data = Multi30k.splits(
    exts=('.de', '.en'),
    fields=(SRC, TRG),
)



In [8]:
# Build both vocabularies from the training split only, keeping tokens that
# occur at least twice.
for _field in (SRC, TRG):
    _field.build_vocab(train_data, min_freq=2)
del _field

In [9]:
BATCH_SIZE = 96

# One BucketIterator per split (train, valid, test), all with the same batch
# size. No `device` is passed here; the training loop moves each batch to the
# model's device itself.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
)



In [10]:
class Encoder(nn.Module):
    """GRU encoder (bidirectional by default) for a sequence-to-sequence model.

    Embeds the source tokens, runs them through a GRU and projects the final
    hidden state(s) of the last layer to the decoder's hidden size.

    Args:
        input_dim: size of the source vocabulary.
        emb_dim: embedding size.
        enc_hid_dim: GRU hidden size per direction.
        dec_hid_dim: size of the decoder state produced by ``self.fc``.
        dropout: dropout probability on the embeddings (and between GRU
            layers when ``num_layers > 1``).
    """

    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.num_layers = 1
        self.bidirectional = True
        num_directions = 2 if self.bidirectional else 1
        self.embedding = nn.Embedding(input_dim, emb_dim)
        # nn.GRU applies dropout only *between* stacked layers; with a single
        # layer it does nothing and PyTorch emits a UserWarning, so pass 0 then.
        self.rnn = nn.GRU(
            emb_dim,
            enc_hid_dim,
            num_layers=self.num_layers,
            dropout=dropout if self.num_layers > 1 else 0.0,
            bidirectional=self.bidirectional,
        )
        self.fc = nn.Linear(enc_hid_dim * num_directions, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode a batch of source sequences.

        Args:
            src: token indices, shape ``[src_len, batch_size]``.

        Returns:
            enc_output: ``[src_len, batch_size, num_directions * enc_hid_dim]``.
            s: initial decoder state, ``[batch_size, dec_hid_dim]``.
        """
        # nn.Embedding accepts any index shape, so no transpose is needed.
        # embedded = [src_len, batch_size, emb_dim]
        embedded = self.dropout(self.embedding(src))
        # enc_hidden = [num_layers * num_directions, batch_size, enc_hid_dim]
        enc_output, enc_hidden = self.rnn(embedded)
        if self.bidirectional:
            # Last layer: index -2 is the forward direction, -1 the backward
            # one. They are concatenated backward-first; self.fc is trained on
            # this ordering.
            hidden = torch.cat((enc_hidden[-1, :, :], enc_hidden[-2, :, :]), dim=1)
        else:
            hidden = enc_hidden[-1, :, :]
        # s = [batch_size, dec_hid_dim]
        s = torch.tanh(self.fc(hidden))
        return enc_output, s
        

In [11]:
class Decoder(nn.Module):
    """Single-step GRU decoder without attention.

    Args:
        output_dim: size of the target vocabulary (number of output logits).
        emb_dim: embedding size.
        enc_hid_dim: accepted for interface symmetry with ``Encoder``; unused.
        dec_hid_dim: GRU hidden size.
        dropout: dropout probability on the embeddings (and between GRU
            layers when ``num_layers > 1``).
    """

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.num_layers = 1
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        # Inter-layer GRU dropout does nothing for a single layer (and makes
        # PyTorch warn), so only forward it when layers are stacked.
        self.rnn = nn.GRU(
            emb_dim,
            dec_hid_dim,
            num_layers=self.num_layers,
            dropout=dropout if self.num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(dec_hid_dim, output_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, dec_input, s):
        """Decode one time step.

        Args:
            dec_input: previous token indices, shape ``[batch_size]``.
            s: hidden state, shape ``[num_layers, batch_size, dec_hid_dim]``.

        Returns:
            logits ``[batch_size, output_dim]`` and the new hidden state
            ``[num_layers, batch_size, dec_hid_dim]``.
        """
        # Add a length-1 time axis: [1, batch_size] -> [1, batch_size, emb_dim]
        embedded = self.drop(self.embedding(dec_input.unsqueeze(0)))
        # dec_output = [1, batch_size, dec_hid_dim]
        dec_output, dec_hidden = self.rnn(embedded, s)
        # Drop the time axis and project to vocabulary logits.
        logits = self.fc(dec_output.squeeze(0))
        return logits, dec_hidden

In [12]:
class Seq2seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder over the target length.

    ``outputs[0]`` is left as zeros because nothing predicts the leading
    ``target[0]`` (``<sos>``). For ``t >= 1``, ``outputs[t]`` holds the logits
    that predict ``target[t]``. The loss should therefore compare
    ``outputs[1:]`` with ``target[1:]``.
    """

    def __init__(self, encoder, decoder):
        super(Seq2seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, target, teacher_forcing_ratio=0.5):
        """Run the model over one batch.

        Args:
            src: source token indices, ``[src_len, batch_size]``.
            target: target token indices, ``[trg_len, batch_size]``; row 0 is
                the first decoder input.
            teacher_forcing_ratio: per-step probability of feeding the
                ground-truth token instead of the model's own argmax.

        Returns:
            Logits of shape ``[trg_len, batch_size, trg_vocab_size]``.
        """
        batch_size = src.shape[1]
        trg_len = target.shape[0]
        trg_vocab_size = self.decoder.output_dim

        # s = [batch_size, dec_hid_dim]; this attention-free decoder does not
        # use the per-step encoder outputs.
        _enc_out, s = self.encoder(src)
        # Tile to [num_layers, batch_size, dec_hid_dim] for the decoder GRU.
        s = s.unsqueeze(0).repeat(self.decoder.num_layers, 1, 1)

        # Allocate on the inputs' device instead of assuming CUDA.
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device=src.device)
        dec_input = target[0, :]
        for t in range(1, trg_len):
            dec_out, s = self.decoder(dec_input, s)
            outputs[t] = dec_out
            # Next input: ground-truth target[t] (teacher forcing) or the
            # model's own most likely token.
            teacher_force = random.random() < teacher_forcing_ratio
            dec_input = target[t, :] if teacher_force else dec_out.argmax(1)
        return outputs

In [13]:
# Use the GPU when one is available instead of hard-coding .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model hyper-parameters.
EMB_DIM = 10
ENC_HID_DIM = 128
DEC_HID_DIM = 128
DROPOUT = 0

encode = Encoder(len(SRC.vocab), EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DROPOUT)
decode = Decoder(len(TRG.vocab), EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DROPOUT)
# Module.to() moves the registered sub-modules (encoder and decoder) as well.
model = Seq2seq(encode, decode).to(device)

In [14]:

# Exclude <pad> target positions from the loss. Otherwise the padded tails of
# shorter sentences in a batch would also be scored and back-propagated.
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [17]:
# Smoke test: stop after MAX_BATCHES batches; set to None to train a full epoch.
MAX_BATCHES = 1

model.train()  # make sure dropout layers are active while training
train_device = next(model.parameters()).device

epoch_loss = 0
n_batches = 0
for batch in train_iterator:
    src = batch.src.to(train_device)  # [src_len, batch_size]
    trg = batch.trg.to(train_device)  # [trg_len, batch_size]
    pred = model(src, trg)            # [trg_len, batch_size, trg_vocab_size]
    pred_dim = pred.shape[-1]
    # Position 0 is <sos> and has no prediction, so drop it before flattening:
    # trg  -> [(trg_len - 1) * batch_size]
    # pred -> [(trg_len - 1) * batch_size, pred_dim]
    trg = trg[1:].reshape(-1)
    pred = pred[1:].reshape(-1, pred_dim)

    loss = criterion(pred, trg)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()
    n_batches += 1
    if MAX_BATCHES is not None and n_batches >= MAX_BATCHES:
        break

# Mean loss over the batches that were run.
epoch_loss / max(n_batches, 1)

In [18]:
# Small random 3x3 matrix for experimenting with Tensor.repeat.
a = torch.randn(3, 3)

In [19]:
# repeat(2, 1) tiles `a` twice along dim 0 and once along dim 1, giving a
# (6, 3) tensor of two stacked copies of `a`. Seq2seq.forward uses repeat in
# the same way to tile the encoder state across decoder layers.
print(a.repeat(2, 1))

tensor([[-1.1824, -0.0137,  0.9211],
        [-0.8134,  0.6702, -0.4294],
        [-0.2126, -0.8452, -1.4732],
        [-1.1824, -0.0137,  0.9211],
        [-0.8134,  0.6702, -0.4294],
        [-0.2126, -0.8452, -1.4732]])
