# In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence

# In [2]:
class EncoderCNN(nn.Module):
    """Encode images into fixed-size embeddings with a pretrained ResNet-152.

    The ResNet trunk (every child module except the last) is run under
    ``torch.no_grad()``, so no gradient flows into it. Its flattened output is
    projected to ``embed_size`` by a trainable linear layer followed by
    ``BatchNorm1d``.

    NOTE(review): ``no_grad`` only blocks gradients. If the module is in train
    mode, the trunk's BatchNorm layers still update their running statistics.
    Call ``self.resnet.eval()`` if the backbone is meant to be fully frozen.
    """

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = self._load_pretrained_resnet152()
        # Drop the final child module (presumably the ``fc`` classifier head,
        # whose ``in_features`` sizes the projection below).
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.batch_norm = nn.BatchNorm1d(embed_size, momentum=0.01)

    @staticmethod
    def _load_pretrained_resnet152():
        """Return an ImageNet-pretrained ResNet-152 without deprecated kwargs.

        ``pretrained=True`` was deprecated in torchvision 0.13 in favour of
        ``weights=``. ``IMAGENET1K_V1`` is the checkpoint the legacy flag maps
        to. It is used here instead of ``DEFAULT`` (V2) so the loaded weights
        are unchanged.
        """
        weights_enum = getattr(models, "ResNet152_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy flag.
            return models.resnet152(pretrained=True)
        return models.resnet152(weights=weights_enum.IMAGENET1K_V1)

    def forward(self, images):
        """Map a batch of images to a ``(batch, embed_size)`` feature tensor.

        Args:
            images: image batch fed to the ResNet trunk. Assumed to be
                ``(batch, 3, H, W)`` and normalised as the pretrained weights
                expect. TODO confirm against the data loader.

        Returns:
            Tensor of shape ``(batch, embed_size)``. In training mode
            ``BatchNorm1d`` needs more than one sample per batch.
        """
        with torch.no_grad():
            features = self.resnet(images)
        # Flatten everything after the batch dimension before projecting.
        features = features.reshape(features.shape[0], -1)
        features = self.batch_norm(self.linear(features))
        return features

# In [4]:
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is fed as the first time step. Every later step consumes
    the embedding of a caption token: the ground-truth token in ``forward``,
    and the previously predicted token in ``sample``.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, max_seq_length=20):
        super().__init__()
        # Submodules are created in this order so that, under a fixed seed,
        # parameter initialisation consumes the RNG exactly as before.
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seq_length = max_seq_length

    def forward(self, features, captions, lengths):
        """Teacher-forced pass returning vocabulary logits for packed steps.

        Args:
            features: ``(batch, embed_size)`` image features. This shape is
                required for the concatenation with the caption embeddings.
            captions: ``(batch, T)`` token ids.
            lengths: valid length of each sequence after the image step is
                prepended. With ``pack_padded_sequence``'s defaults these must
                be sorted in decreasing order.

        Returns:
            ``(sum(lengths), vocab_size)`` logits in packed (time-major) order.
        """
        token_vecs = self.embed(captions)
        lstm_in = torch.cat([features.unsqueeze(1), token_vecs], dim=1)
        packed_in = pack_padded_sequence(lstm_in, lengths, batch_first=True)
        packed_out, _ = self.lstm(packed_in)
        # ``.data`` holds the flattened, non-padded hidden states.
        return self.linear(packed_out.data)

    def sample(self, features, states=None):
        """Greedily decode ``self.max_seq_length`` token ids per image.

        Args:
            features: ``(batch, embed_size)`` image features.
            states: optional initial LSTM ``(h, c)`` state. ``None`` means zeros.

        Returns:
            ``(batch, max_seq_length)`` tensor of predicted token ids.
        """
        step_input = features.unsqueeze(1)
        chosen = []
        for _ in range(self.max_seq_length):
            step_out, states = self.lstm(step_input, states)
            logits = self.linear(step_out.squeeze(1))
            best = logits.max(dim=1).indices
            chosen.append(best)
            # The predicted token becomes the next step's input.
            step_input = self.embed(best).unsqueeze(1)
        return torch.stack(chosen, dim=1)
        