In [13]:
from pycocotools.coco import COCO

import torch
import torch.nn as nn

import torchvision.models as models

import numpy as np

In [9]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [17]:
class Encoder(nn.Module):
    """Image encoder: a frozen ResNet-101 backbone followed by a trainable linear projection.

    The projected feature vector is intended to seed the caption decoder's LSTM state.

    Args:
        vocab_size: vocabulary size; stored on the module but not used by the encoder itself.
        hidden_size: size of the projected feature vector (should match the decoder's hidden size).
    """

    # Channel count of ResNet-101's pooled feature map (the input size of its original fc layer).
    RESNET_FEATURE_DIM = 2048

    def __init__(self, vocab_size, hidden_size):
        super(Encoder, self).__init__()

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        # NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in favour of
        # `weights=models.ResNet101_Weights.DEFAULT`; kept for compatibility with older versions.
        resnet = models.resnet101(pretrained=True)

        # The backbone is used as a fixed feature extractor; only `self.embedding` is trained.
        # NOTE(review): requires_grad=False does not stop BatchNorm layers from updating their
        # running statistics in train mode; call `self.resnet.eval()` if that is undesired.
        for param in resnet.parameters():
            param.requires_grad = False

        # Drop the last child (the 1000-way ImageNet classifier). `children()` returns a
        # generator, so it must be materialised into a list before it can be sliced.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        # Project the pooled backbone features to the decoder's hidden size.
        self.embedding = nn.Linear(self.RESNET_FEATURE_DIM, hidden_size)

    def forward(self, inputs):
        """Encode a batch of images into feature vectors.

        Args:
            inputs: image batch; presumably (batch, 3, H, W) normalised as ResNet expects — TODO confirm.

        Returns:
            Tensor of shape (batch, hidden_size).
        """
        features = self.resnet(inputs)          # (batch, 2048, 1, 1) after the backbone's avg-pool
        features = torch.flatten(features, 1)   # (batch, 2048): nn.Linear acts on the last dim
        return self.embedding(features)

In [18]:
class Decoder(nn.Module):
    """Caption decoder: word embedding -> single-layer LSTM -> per-step vocabulary logits.

    The encoder's image feature vector is used as both the initial hidden state and the
    initial cell state of the LSTM.

    Args:
        vocab_size: number of tokens in the vocabulary (embedding rows and output logits).
        hidden_size: size of the word embeddings and of the LSTM hidden/cell states.
    """

    def __init__(self, vocab_size, hidden_size):
        super(Decoder, self).__init__()

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size)
        # batch_first=True makes the LSTM input/output (batch, seq, hidden). It does NOT affect
        # the (h0, c0) state tensors, which remain (num_layers, batch, hidden).
        self.rnn = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, batch_first=True)
        self.out = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, inputs, encoder_out):
        """Compute vocabulary logits for every position of a batch of captions.

        Args:
            inputs: (batch, seq_len) integer token indices into the vocabulary.
            encoder_out: image features, either (batch, hidden_size) (one per caption) or
                (hidden_size,) (a single vector shared by the whole batch).

        Returns:
            Tensor of shape (batch, seq_len, vocab_size) with unnormalised logits.
        """
        embedded = self.embedding(inputs)  # (batch, seq_len, hidden_size)

        if encoder_out.dim() == 1:
            # Broadcast a single feature vector across every caption in the batch.
            encoder_out = encoder_out.unsqueeze(0).expand(inputs.size(0), -1)
        # The LSTM expects its initial state as (num_layers, batch, hidden_size).
        state = encoder_out.unsqueeze(0).contiguous()

        # nn.LSTM returns (output, (h_n, c_n)); only the per-step outputs are needed here.
        rnn_out, _ = self.rnn(embedded, (state, state))
        return self.out(rnn_out)