In [None]:
class Decoder(nn.Module):
    """LSTM decoder that emits one token distribution per step.

    The input at every step is the encoder context concatenated with the
    embedding of the previous token (the start token at step 0).

    Below, M is the batch size.

    Args:
        context_size: Feature size of the context vector
            [context shape: (1, M, context_size)].
        vocab: Vocabulary object. It must expose ``N`` (vocabulary size) and
            ``wd_to_id`` (token -> index mapping that contains START_TOKEN).
        n_layers: Number of stacked LSTM layers (defaults to 1).
        hidden_size: Size of the hidden and cell state vectors
            [shape: (n_layers, M, hidden_size)].
        embed_size: Size of the token embedding vectors.
        max_length: Number of decoding steps, i.e. maximum formula length.
    """

    def __init__(self, context_size, vocab, n_layers=1, hidden_size=512,
                 embed_size=512, max_length=100):
        super().__init__()
        self.context_size = context_size
        self.vocab = vocab
        self.vocab_size = vocab.N
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.max_length = max_length

        self.input_size = context_size + embed_size
        self.embed = nn.Embedding(self.vocab_size, embed_size)

        # nn.LSTM (not nn.LSTMCell): LSTMCell has no notion of layers -- its
        # third positional argument is `bias` -- and it only accepts 1-D/2-D
        # input, whereas forward() works with (1, M, features) tensors.
        self.lstm = nn.LSTM(self.input_size, hidden_size, n_layers)
        self.linear = nn.Linear(hidden_size, self.vocab_size)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, context, target_tensor=None):
        """Decode ``max_length`` steps starting from the start token.

        Args:
            context: Context vector from the encoder, shape
                (1, M, context_size).
            target_tensor: Optional ground-truth token indices, shape
                (1, M, max_length). When given, decoding runs in teacher
                forcing mode: the input of step i + 1 is the ground-truth
                token i. When None, the input of step i + 1 is the argmax
                prediction of step i.

        Returns:
            Tensor of shape (max_length, M, vocab_size) holding the softmax
            distribution over the vocabulary for every step.
        """
        batch_size = context.shape[1]

        # Without proj_size, nn.LSTM's cell state has the same shape as the
        # hidden state. new_zeros keeps both on the context's device/dtype.
        hidden = context.new_zeros(
            (self.n_layers, batch_size, self.hidden_size))
        cell = context.new_zeros(
            (self.n_layers, batch_size, self.hidden_size))

        # One start-token id per batch element, so the embedding comes out
        # as (1, M, embed_size) for any batch size.
        start_ids = torch.full(
            (1, batch_size),
            self.vocab.wd_to_id[START_TOKEN],
            dtype=torch.long,
            device=context.device,
        )
        step_input = torch.cat([context, self.embed(start_ids)], dim=2)

        outputs = []
        last_step = self.max_length - 1
        for i in range(self.max_length):
            output, (hidden, cell) = self.lstm(step_input, (hidden, cell))
            output = self.softmax(self.linear(output))
            outputs.append(output)

            if i == last_step:
                # No further step consumes an input; skip the extra work.
                break

            if target_tensor is not None:
                # Teacher forcing: feed the ground-truth token.
                next_ids = target_tensor[0, :, i].reshape(1, batch_size)
            else:
                # Free running: feed the most likely predicted token.
                next_ids = torch.argmax(output, dim=2)
            step_input = torch.cat([context, self.embed(next_ids)], dim=2)

        if not outputs:
            # max_length == 0: return an empty, correctly shaped tensor.
            return context.new_zeros((0, batch_size, self.vocab_size))
        return torch.cat(outputs, dim=0)

In [None]:
class DecoderRNN(nn.Module):
    """Batch-first LSTM decoder that greedily emits one token index per step.

    The input at every step is the context vector concatenated (along the
    feature axis) with the embedding of the previous token.

    Args:
        vocab: Vocabulary object. It must expose ``N`` (vocabulary size) and
            ``get_id(token)`` (token -> index).
        context_size: Feature size of the context vector.
        hidden_size: Size of the LSTM hidden/cell state vectors.
        embed_size: Size of the token embedding vectors.
        output_size: Size of the output projection; expected to equal the
            vocabulary size, because predicted indices are embedded again.
        max_length: Number of decoding steps.
    """

    def __init__(self, vocab, context_size, hidden_size, embed_size,
                 output_size, max_length):
        super().__init__()

        self.embed_size = embed_size
        self.context_size = context_size
        self.max_length = max_length
        self.vocab = vocab
        vocab_size = vocab.N

        # Each step's input is the context concatenated with the embedding
        # of the previous token.
        input_size = context_size + embed_size

        self.embedding = nn.Embedding(vocab_size, embed_size)

        # batch_first=True because forward() builds (batch, 1, features)
        # inputs; the default layout would read the batch axis as time.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=1,
                            batch_first=True)

        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, context, target_tensor=None):
        """Decode ``max_length`` tokens starting from the start token.

        Args:
            context: Context tensor of shape (batch, 1, context_size) or
                (batch, context_size).
            target_tensor: Optional ground-truth token indices of shape
                (batch, max_length), or (max_length,) to share one sequence
                across the batch. When given (teacher forcing), the input of
                step i + 1 is ground-truth token i; otherwise it is the
                token predicted at step i.

        Returns:
            A tuple ``(decoder_outputs, decoder_hidden, None)``:
            ``decoder_outputs`` is a list of ``max_length`` int64 tensors of
            shape (batch, 1) with the predicted token indices, and
            ``decoder_hidden`` is the final ``(h, c)`` LSTM state (None when
            ``max_length`` is 0).

        Raises:
            ValueError: If ``context`` does not have one of the shapes above.
        """
        if context.dim() == 2:
            context = context.unsqueeze(1)
        if (context.dim() != 3 or context.size(1) != 1
                or context.size(2) != self.context_size):
            raise ValueError(
                "context must have shape (batch, 1, context_size) or "
                f"(batch, context_size); got {tuple(context.shape)}"
            )

        batch_size = context.size(0)
        start_id = self.vocab.get_id(START_TOKEN)
        step_tokens = torch.full((batch_size, 1), start_id,
                                 dtype=torch.int64, device=context.device)

        if target_tensor is not None and target_tensor.dim() == 1:
            # A single sequence: reuse it for every batch element.
            target_tensor = target_tensor.unsqueeze(0).expand(batch_size, -1)

        # None lets nn.LSTM start from an all-zero (h, c) state. The context
        # cannot be used as the state directly: the LSTM needs an (h, c)
        # tuple of hidden_size vectors, and the context is already part of
        # every step's input.
        decoder_hidden = None
        decoder_outputs = []

        last_step = self.max_length - 1
        for i in range(self.max_length):
            decoder_input = torch.cat(
                (context, self.embedding(step_tokens)), dim=-1)
            decoder_output, decoder_hidden = self.forward_step(
                decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)

            if i == last_step:
                break

            if target_tensor is not None:
                # Teacher forcing: target_tensor already holds vocab indices.
                step_tokens = target_tensor[:, i].unsqueeze(1)
            else:
                # Free running: feed back the predicted indices.
                step_tokens = decoder_output

        return decoder_outputs, decoder_hidden, None

    def forward_step(self, input, hidden):
        """Run one LSTM step and pick the most likely token.

        Args:
            input: Step input of shape (batch, 1, context_size + embed_size).
            hidden: ``(h, c)`` LSTM state tuple, or None for a zero state.

        Returns:
            A tuple ``(output, hidden)``: ``output`` is an int64 tensor of
            shape (batch, 1) with the argmax token indices, ``hidden`` is
            the updated ``(h, c)`` state.
        """
        lstm_out, hidden = self.lstm(input, hidden)
        logits = self.out(lstm_out)

        # NOTE: argmax is not differentiable; a training loss would need the
        # logits rather than these indices.
        output = torch.argmax(logits, dim=-1)

        return output, hidden