In [12]:
import numpy as np

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f

In [55]:
class Encoder(nn.Module):
    """Token-embedding layer followed by a multi-layer, batch-first LSTM."""

    def __init__(self, num_embeddings, embedding_dim, hidden_dim, num_layers, padding_dim=0):
        super().__init__()
        # `padding_dim` is forwarded to nn.Embedding as its padding index.
        self.embedding = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=padding_dim,
        )
        self.encoder = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, input):
        '''
        Embed a batch of token ids and run it through the LSTM.

        input => 2d tensor of token ids, (num_sentences * num_words).
        This method does not reverse anything; the caller is expected to
        pass each sentence already reversed.

        Returns the tuple (output, h, c):

        output => per-step features, (num_sentences * num_words * hidden_dim)
        h => final hidden state, (num_layers * num_sentences * hidden_dim)
        c => final cell state, (num_layers * num_sentences * hidden_dim)
        '''
        output, final_state = self.encoder(self.embedding(input))
        h, c = final_state
        return output, h, c

In [54]:
# Smoke test: push a random batch of 2 sentences x 5 token ids through a
# 3-layer encoder and display the shapes of the three returned tensors.
enc = Encoder(num_embeddings=10, embedding_dim=10, hidden_dim=10, num_layers=3)
input = torch.LongTensor(np.random.randint(low=0, high=10, size=(2, 5)))
a, b, c = enc(input)
(a.shape, b.shape, c.shape)

(torch.Size([2, 5, 10]), torch.Size([3, 2, 10]), torch.Size([3, 2, 10]))

In [161]:
class Decoder(nn.Module):
    """LSTM decoder that turns target-side token ids into per-step vocabulary probabilities.

    Args:
        num_embeddings: vocabulary size; also the width of the prediction layer's output.
        embedding_dim: size of each token embedding (the LSTM input size).
        hidden_dim: LSTM hidden size (the prediction layer's input size).
        num_layers: number of stacked LSTM layers.
        padding_dim: forwarded to nn.Embedding as its padding index.
    """

    def __init__(self, num_embeddings, embedding_dim, hidden_dim, num_layers, padding_dim=0):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_dim)
        self.decoder = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        # The projection consumes the LSTM's per-step output, whose last
        # dimension is hidden_dim (not embedding_dim).
        self.word_predictor = nn.Linear(hidden_dim, num_embeddings)

    def forward(self, output, h, c, input=None):
        '''
        Decode with teacher forcing.

        Args:
            output: accepted for interface compatibility; not used by this method.
            h: initial hidden state for the LSTM, (num_layers * no_sentences * hidden_dim)
            c: initial cell state for the LSTM, (num_layers * no_sentences * hidden_dim)
            input: token ids to feed at each step, (no_sentences * no_words).

        Returns:
            pred: (no_sentences * no_words * no_embeddings), softmax-normalised
            over the last dimension (plain probabilities, not log-probabilities).

        Raises:
            NotImplementedError: if input is None (decoding without teacher
            forcing is not implemented).

        NOTE(review): if the training loss expects log-probabilities
        (e.g. nn.NLLLoss), take the log downstream or switch this to
        log_softmax -- confirm against the training loop.
        '''
        if input is None:
            raise NotImplementedError

        # training with teacher forcing
        embeddings = self.embedding(input)
        yts, _ = self.decoder(embeddings, (h, c))
        # nn.Linear acts on the last dimension, so no flatten/unflatten is needed.
        pred = self.word_predictor(yts)
        return torch.softmax(pred, dim=-1)