# Build Model

In [1]:
import torch
from torch import nn
from torchvision import models

In [None]:
# Torch Module using GloVe Embeddings

class CaptionEncoder(nn.Module):
    """Encode token-index captions with pretrained GloVe vectors and an LSTM.

    The embedding table is loaded from disk and wrapped with
    ``nn.Embedding.from_pretrained`` (which keeps the vectors frozen by
    default). The embedded sequence is then run through a single-layer,
    batch-first LSTM whose hidden size is ``embed_size``.

    Args:
        embed_size (int): hidden size of the LSTM, i.e. the size of every
            vector in the returned sequence.
        glove_path (str): file handed to ``torch.load``. It is assumed to
            contain the 2-D embedding matrix ``[vocab_size, glove_dim]`` as a
            float tensor -- TODO confirm against the script that produced it.
    """

    def __init__(self, embed_size, glove_path='glove.twitter.27B.pt'):
        super(CaptionEncoder, self).__init__()

        self.embed_size = embed_size

        # NOTE(review): torch.load unpickles the file, so only point this at
        # a checkpoint you trust.
        glove_weights = torch.load(glove_path)
        self.embed = nn.Embedding.from_pretrained(glove_weights)

        # Take the LSTM input width from the loaded table instead of
        # hard-coding it: a fixed 300 breaks as soon as the GloVe file has a
        # different vector size (the Twitter release ships 25/50/100/200-d).
        self.lstm = nn.LSTM(
            input_size=self.embed.embedding_dim,
            hidden_size=embed_size,
            num_layers=1,
            batch_first=True,
        )

    def forward(self, x):
        """Embed the token indices and return the LSTM output sequence.

        Args:
            x (torch.Tensor): integer token indices, batch first
                (``[bs, seq_len]``).

        Returns:
            torch.Tensor: per-step LSTM outputs of shape
            ``[bs, seq_len, embed_size]``; the final hidden/cell states are
            discarded.
        """
        x = self.embed(x)
        x, _ = self.lstm(x)
        return x

In [None]:
# Torch module for image encoding using Inception v3

class ImageEncoder(nn.Module):
    """Encode images into ``embed_size``-dimensional vectors with Inception v3.

    The pretrained torchvision ``inception_v3`` network is used as the
    backbone; its final fully connected layer is replaced by a newly
    initialised linear layer that outputs ``embed_size`` features.

    Args:
        embed_size (int): size of the returned image embedding.
        dropout (float): dropout probability applied to the embedding.
            Defaults to ``0.0``, which leaves the embedding unchanged, so
            existing single-argument callers are unaffected.
    """

    def __init__(self, embed_size, dropout=0.0):
        super(ImageEncoder, self).__init__()

        self.embed_size = embed_size

        # NOTE(review): `pretrained=` is deprecated in newer torchvision in
        # favour of `weights=`; kept as-is for compatibility with the
        # torchvision version this notebook was written against.
        self.inception = models.inception_v3(pretrained=True)
        # Reuse the width of the layer being replaced rather than a magic 2048.
        self.inception.fc = nn.Linear(
            in_features=self.inception.fc.in_features,
            out_features=embed_size,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the image embedding for a batch of images.

        Args:
            x (torch.Tensor): image batch in the layout expected by
                torchvision's Inception v3 (documented there as
                ``[bs, 3, 299, 299]``).

        Returns:
            torch.Tensor: embeddings of shape ``[bs, embed_size]``.
        """
        x = self.inception(x)
        # In training mode torchvision's Inception v3 returns an
        # ``InceptionOutputs`` namedtuple (main output, auxiliary output).
        # Only the main output went through the replaced ``fc`` layer, so
        # that is the embedding; unwrapping keeps the return type a tensor
        # in both train and eval mode.
        if isinstance(x, tuple):
            x = x[0]
        return self.dropout(x)

In [None]:
class ImageLabelEncoder(nn.Module):
    """Joint encoder for an image and its text label.

    Each input is passed through its own encoder, the two results are
    concatenated along dimension 1 and projected back down to ``emb_dim``
    by a linear layer followed by dropout.

    NOTE(review): ``LabelEncoder`` must be defined or imported before this
    class is instantiated.
    """

    def __init__(self, num_tokens, emb_dim=256, dropout=0.2):
        """Build the two sub-encoders and the fusion layers.

        Args:
            num_tokens: vocabulary size, forwarded to ``LabelEncoder``
            emb_dim (int): size of the fused output embedding; also passed
                to both sub-encoders
            dropout (float): dropout probability, passed to both
                sub-encoders and used after the fusion layer
        """
        super(ImageLabelEncoder, self).__init__()

        self.image_encoder = ImageEncoder(emb_dim, dropout)
        self.label_encoder = LabelEncoder(num_tokens, emb_dim, dropout)

        # The fusion layer consumes both embeddings side by side.
        fused_width = emb_dim * 2
        self.linear = nn.Linear(in_features=fused_width, out_features=emb_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, images, labels):
        """Fuse the image and label encodings into one embedding.

        Args:
            images (torch.Tensor): image batch, in whatever layout
                ``self.image_encoder`` accepts
            labels (torch.Tensor): label batch, in whatever layout
                ``self.label_encoder`` accepts

        Returns:
            torch.Tensor: fused embedding whose last dimension is
            ``emb_dim``. For the concatenation and projection to work, each
            sub-encoder has to return a ``[bs, emb_dim]`` tensor.
        """
        per_modality = (
            self.image_encoder(images),
            self.label_encoder(labels),
        )
        fused = torch.cat(per_modality, dim=1)
        projected = self.linear(fused)
        return self.dropout(projected)