In [17]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

In [18]:
class LSTMCell(nn.Module):
    """A single LSTM step built from four independent linear gate layers.

    Every gate reads the concatenation ``[h, x]`` along the last dimension,
    so each layer maps ``hidden_dim + embed_dim`` features to ``hidden_dim``.
    """

    def __init__(self, embed_dim, hidden_dim):
        """Create the gate layers and initialise their parameters.

        Args:
            embed_dim: Size of the last dimension of the step input ``x``.
            hidden_dim: Size of the last dimension of ``h`` and ``c``.
        """
        super().__init__()
        gate_in_features = hidden_dim + embed_dim
        self.w_f = nn.Linear(gate_in_features, hidden_dim)  # forget gate
        self.w_i = nn.Linear(gate_in_features, hidden_dim)  # input gate
        self.w_c = nn.Linear(gate_in_features, hidden_dim)  # candidate cell state
        self.w_o = nn.Linear(gate_in_features, hidden_dim)  # output gate
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform initialise every gate weight and zero every gate bias."""
        for gate in (self.w_f, self.w_i, self.w_c, self.w_o):
            nn.init.xavier_uniform_(gate.weight)
            nn.init.zeros_(gate.bias)

    def forward(self, x, h, c):
        """Advance the cell by one step.

        Args:
            x: Step input whose last dimension is ``embed_dim``.
            h: Previous hidden state whose last dimension is ``hidden_dim``.
            c: Previous cell state, same shape as ``h``.

        Returns:
            Tuple ``(h, c)`` holding the new hidden state and new cell state.
        """
        gate_input = torch.cat((h, x), dim=-1)

        forget_gate = torch.sigmoid(self.w_f(gate_input))
        input_gate = torch.sigmoid(self.w_i(gate_input))
        output_gate = torch.sigmoid(self.w_o(gate_input))
        candidate = torch.tanh(self.w_c(gate_input))

        # Keep part of the old cell state and blend in the gated candidate.
        next_c = forget_gate * c + input_gate * candidate
        # The exposed hidden state is the squashed cell state, gated by the output gate.
        next_h = output_gate * torch.tanh(next_c)
        return next_h, next_c

In [24]:
class LSTMDecoder(nn.Module):
    """Token decoder: embedding -> LSTM cell -> linear projection to vocabulary logits."""

    def __init__(self, embed_dim, hidden_dim, vocab_size):
        """Build the embedding table, the recurrent cell and the output projection.

        Args:
            embed_dim: Width of each token embedding.
            hidden_dim: Width of the LSTM hidden and cell states.
            vocab_size: Number of rows in the embedding table and number of
                logits produced per step.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = LSTMCell(embed_dim, hidden_dim)
        self.ff = nn.Linear(hidden_dim, vocab_size)

    def forward(self, X, h, c):
        """Run the cell over ``X`` one step at a time and collect the logits.

        Args:
            X: Iterable of token-index tensors; iteration runs over its first
                dimension, so that dimension acts as the time axis.
            h: Initial hidden state handed to the LSTM cell.
            c: Initial cell state handed to the LSTM cell.

        Returns:
            Tensor of per-step logits stacked along a new leading dimension.
        """
        hidden, cell_state = h, c
        step_logits = []
        for token in X:
            hidden, cell_state = self.lstm(self.embedding(token), hidden, cell_state)
            step_logits.append(self.ff(hidden))
        return torch.stack(step_logits)

In [None]:
class ResnetEncoder(nn.Module):
    """Frozen ImageNet-pretrained ResNet-50 backbone followed by a trainable projection.

    The backbone's last child module (the classification layer in torchvision's
    ResNet) is dropped, and the remaining features are projected from 2048 to
    ``hidden_dim`` by ``self.fc``.

    The backbone is kept permanently in eval mode. Setting
    ``requires_grad=False`` only stops gradient updates; BatchNorm layers in
    training mode would still normalise with batch statistics and overwrite
    their pretrained running mean/variance on every forward pass, silently
    changing the "frozen" feature extractor.
    """

    def __init__(self,hidden_dim):
        """Load the pretrained backbone, freeze it and create the projection layer.

        Args:
            hidden_dim: Output width of the projection layer.
        """
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        for param in backbone.parameters():
            param.requires_grad = False

        # Keep everything except the last child module.
        self.base_model = nn.Sequential(*list(backbone.children())[:-1])
        # Freeze BatchNorm statistics as well as the weights.
        self.base_model.eval()
        self.fc = nn.Linear(2048,hidden_dim)

    def train(self, mode=True):
        """Set the training mode of this module while keeping the backbone in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        # A parent's .train() recurses into children; undo it for the frozen backbone.
        self.base_model.eval()
        return self

    def forward(self,x):
        """Encode images into ``hidden_dim``-wide feature vectors.

        Args:
            x: Image batch passed straight to the backbone.

        Returns:
            Projected features of shape ``(batch, hidden_dim)``. When the batch
            size is 1 the leading dimension is squeezed away, giving
            ``(hidden_dim,)``.
        """
        x = self.base_model(x)
        x = torch.flatten(x, 1)  # (batch, 2048)
        x = self.fc(x)
        # Drops the batch dimension only when it has size 1.
        return x.squeeze(0)

In [25]:
class ImageCaptionModel(nn.Module):
    """Image-captioning network: a ResNet image encoder feeding an LSTM decoder.

    The encoder output is used as the decoder's initial hidden state and the
    initial cell state is all zeros.
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size):
        """Build the encoder and the decoder.

        Args:
            embed_dim: Token embedding width passed to the decoder.
            hidden_dim: Encoder output width and decoder state width.
            vocab_size: Vocabulary size passed to the decoder.
        """
        super().__init__()
        self.embed_dim, self.hidden_dim = embed_dim, hidden_dim
        self.encoder = ResnetEncoder(hidden_dim)
        self.decoder = LSTMDecoder(embed_dim, hidden_dim, vocab_size)

    def forward(self, img, text_encode):
        """Encode the image and decode logits for the given token sequence.

        Args:
            img: Image tensor passed to the encoder. The original author's note
                gives its leading dimension as 1 - confirm against the caller.
            text_encode: Token indices passed to the decoder.

        Returns:
            Whatever the decoder returns for ``text_encode``.
        """
        initial_hidden = self.encoder(img)
        # Zero cell state with the same shape, dtype and device as the hidden state.
        initial_cell = torch.zeros_like(initial_hidden)
        return self.decoder(text_encode, initial_hidden, initial_cell)
