# In[12]:
import torch
import torch.nn as nn

class EncoderCNN(nn.Module):
    '''Encoder assembled from a declarative list of layer specifications.

    Each ``(layer_type, layer_params)`` spec is turned into the matching
    ``torch.nn`` module; ``forward`` applies the modules in list order.
    '''

    # Maps each accepted layer-type string to the torch.nn class it builds.
    LAYER_TYPES = {
        'conv1d': nn.Conv1d,
        'conv2d': nn.Conv2d,
        'maxpool1d': nn.MaxPool1d,
        'maxpool2d': nn.MaxPool2d,
        'avgpool1d': nn.AvgPool1d,
        'avgpool2d': nn.AvgPool2d,
        'linear': nn.Linear,
        'dropout': nn.Dropout,
        'relu': nn.ReLU,
        'flatten': nn.Flatten,
    }

    def __init__(self, layers, hparams):
        '''
        Args:
            layers: Description of all layers in the Encoder: [(layer_type, {layer_params})]
                - layer types - ['conv1d', 'conv2d', 'maxpool1d', 'maxpool2d', 'avgpool1d',
                  'avgpool2d', 'linear', 'dropout', 'relu', 'flatten']
                - layer_params - dict of keyword arguments passed to the layer's constructor

            hparams: Hyperparameters for the model; stored unchanged as ``self.hparams``.

        Raises:
            ValueError: if a layer type is not one of the supported types.
        '''
        super().__init__()
        # nn.Module has no save_hyperparameters() (that is a PyTorch Lightning
        # LightningModule method), so keep the hyperparameters as a plain attribute.
        self.hparams = hparams
        self.layers = nn.ModuleList()

        for layer_type, layer_params in layers:
            layer_cls = self.LAYER_TYPES.get(layer_type)
            if layer_cls is None:
                raise ValueError(f'Invalid layer type: {layer_type}')
            self.layers.append(layer_cls(**layer_params))

    def forward(self, input):
        '''Apply every configured layer to ``input`` in order and return the result.'''
        for layer in self.layers:
            input = layer(input)
        return input
    
class DecoderRNN(nn.Module):
    '''LSTM decoder fed with a context vector concatenated to the previous token embedding.'''

    def __init__(self, vocabulary_size, embedding_size, input_size):
        '''
        Args:
            vocabulary_size: Size of the vocabulary (rows of the embedding table and
                output features of ``self.output``)
            embedding_size: Size of the embedding vector; also used as the LSTM hidden size
            input_size: Size of the context vector concatenated with each token embedding
        '''
        super().__init__()
        self.vocabulary_size = vocabulary_size
        self.embedding = nn.Embedding(vocabulary_size, embedding_size)
        self.embedding_size = embedding_size
        # LSTM input = [context vector ; previous-token embedding].
        self.lstm = nn.LSTM(input_size + embedding_size, embedding_size)
        # Projection from LSTM hidden size to vocabulary size.
        # NOTE(review): forward() does not apply this layer — confirm callers do.
        self.output = nn.Linear(embedding_size, vocabulary_size)

    def forward(self, input, prev_token_embedding, hidden=None):
        '''
        Run the LSTM on the concatenation of ``input`` and ``prev_token_embedding``.

        Args:
            input: Input context vector; last axis has size ``input_size``
            prev_token_embedding: Embedding of the previous token; last axis has size
                ``embedding_size``
            hidden: ``(h, c)`` state of the LSTM, or None for the LSTM's default
                zero initial state

        Returns:
            The ``(output, (h, c))`` tuple produced by ``nn.LSTM``.
        '''
        # Join along the last (feature) axis. For 2-D tensors this equals the
        # previous dim=1; for 3-D (seq, batch, feat) tensors dim=1 would have
        # concatenated along the batch axis instead.
        # NOTE(review): the LSTM is not batch_first, so a 2-D input is treated as
        # an unbatched (seq, feat) sequence — confirm this is the intended layout.
        features = torch.cat((input, prev_token_embedding), dim=-1)
        # nn.LSTM accepts hx=None and then starts from zero (h, c).
        return self.lstm(features, hidden)

    def output_to_embedding(self, output):
        '''
        Embed the arg-max index of each score vector in ``output``.

        Args:
            output: Scores whose last axis is searched with argmax

        Returns:
            Embeddings of shape ``output.shape[:-1] + (embedding_size,)``.
        '''
        # Reduce over the last axis only: a bare argmax() flattens the tensor and
        # would return a single index for an entire batch.
        return self.embedding(torch.argmax(output, dim=-1))

    def token_idx_to_embedding(self, token_idx):
        '''
        Args:
            token_idx: Index (or tensor of indices) of the token

        Returns:
            The embedding row(s) for ``token_idx``.
        '''
        return self.embedding(token_idx)

    def training_step(self, batch, batch_idx):
        # Placeholder: not implemented yet.
        pass

# In[13]:
hparams = {
    "lr": 0.001,
    "batch_size": 32,
    "epochs": 10,
}

# Channel progression of the conv stack: conv i maps channel_seq[i] -> channel_seq[i + 1].
channel_seq = [3, 32, 64, 128, 256, 512]
num_conv_pool = 5

# num_conv_pool x (5x5 conv, 2x2 max-pool), followed by a 3x3 average pool.
enc_layers = []
for i in range(num_conv_pool):
    enc_layers.append(('conv2d', {'in_channels': channel_seq[i], 'out_channels': channel_seq[i + 1], 'kernel_size': 5}))
    enc_layers.append(('maxpool2d', {'kernel_size': 2}))
enc_layers.append(('avgpool2d', {'kernel_size': (3, 3)}))

enc = EncoderCNN(enc_layers, hparams)

vocab = getvocab()
# DecoderRNN's constructor takes vocabulary_size (an int); the former
# `vocabulary=vocab` keyword does not exist and raised TypeError.
# NOTE(review): assumes getvocab() returns a sized collection of tokens — confirm.
dec = DecoderRNN(vocabulary_size=len(vocab), embedding_size=512, input_size=512)