<a href="https://colab.research.google.com/github/yscope75/CS2225.CH2001020/blob/master/Image_captioning_Master_courses.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torchvision.datasets as dset
import torchvision.datasets.utils as dset_utils
import torchvision.transforms as transforms
import os

In [2]:
import torch
from torch import nn
import torchvision
from torchsummary import summary

In [3]:
# Absolute paths of the COCO 2014 annotation and training-image folders under the
# working directory. Components are joined with os.path.join (instead of string
# concatenation) so the platform separator is used; the trailing '' keeps a trailing
# separator in case later cells concatenate file names onto these strings.
annotation_folder = os.path.join(os.path.abspath('.'), 'coco', 'annotations', '')
image_folder = os.path.join(os.path.abspath('.'), 'coco', 'train2014', '')
# Download and unzip the 2014 train/val annotations archive
dset_utils.download_and_extract_archive(url='http://images.cocodataset.org/annotations/annotations_trainval2014.zip',
                                        download_root=annotation_folder,
                                        extract_root=annotation_folder,
                                        filename='captions.zip')
# Download and unzip the 2014 training images
dset_utils.download_and_extract_archive(url='http://images.cocodataset.org/zips/train2014.zip',
                                        download_root=image_folder,
                                        extract_root=image_folder,
                                        filename='train2014.zip')

Downloading http://images.cocodataset.org/annotations/annotations_trainval2014.zip to /content/coco/annotations/captions.zip


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/coco/annotations/captions.zip to /content/coco/annotations/
Downloading http://images.cocodataset.org/zips/train2014.zip to /content/coco/train2014/train2014.zip


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/coco/train2014/train2014.zip to /content/coco/train2014/


In [4]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# Definition of main model
# Begin with Encoder
class Encoder(nn.Module):
  """
    Encode an image batch with a ResNet-152 pretrained on ImageNet.

    The last two children of the torchvision ResNet (the global average pool and the
    fully connected classifier) are dropped, so the last convolutional feature map is
    kept. It is pooled to a fixed spatial grid and returned as a sequence of
    per-location feature vectors for the attention decoder.
  """
  def __init__(self, encoded_image_size=8):
    """
      args:
      - encoded_image_size: side length of the grid the feature map is pooled to;
        forward returns encoded_image_size**2 locations (default 8 -> 64 locations)
    """
    super(Encoder, self).__init__()
    resnet152 = torchvision.models.resnet152(pretrained=True)
    # remove the last two layers and keep the last CNN output
    modules = list(resnet152.children())[:-2]
    self.res_encoder = nn.Sequential(*modules)  # (batch_size, 2048, H/32, W/32)
    # normalize variable input size to a fixed grid using adaptive pooling, so the
    # number of attention locations does not depend on the input resolution
    self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
    # Flatten the spatial grid to (batch_size, 2048, encoded_image_size**2)
    self.flat_embed = nn.Flatten(start_dim=2)

  def forward(self, X_in):
    """
      The forward pass of encoder
      args:
      - X_in: input image batch of size (batch, 3, height, width)
      return: encoded images of size (batch, encoded_image_size**2, embed_size)
    """
    e_out = self.res_encoder(X_in)
    # (batch_size, embed_size, grid, grid)
    e_out = self.adaptive_pool(e_out)
    # Flatten output to (batch_size, embed_size, grid*grid)
    e_out = self.flat_embed(e_out)
    # Locations first for attention: (batch_size, grid*grid, embed_size)
    e_out = e_out.permute(0, 2, 1)
    return e_out
    
  

In [7]:
class BahdanauAttention(nn.Module):
  """
    Additive (Bahdanau) attention over encoded image locations, used by the decoder
    to build a context vector for generating the next word.
  """
  def __init__(self, encoder_dim, hidden_size, attention_size):
    """
      args:
      - encoder_dim: feature size of each encoded image location (2048 for ResNet-152)
      - hidden_size: size of the hidden state of the decoder RNN
      - attention_size: size of the intermediate attention projection
    """
    super(BahdanauAttention, self).__init__()
    self.W1 = nn.Linear(encoder_dim, attention_size)   # projects encoded features
    self.W2 = nn.Linear(hidden_size, attention_size)   # projects decoder hidden state
    self.V = nn.Linear(attention_size, 1)              # one score per location

  def forward(self, encoded_feature, hidden):
    """
      args:
      - encoded_feature: (batch_size, num_locations, encoder_dim)
      - hidden: decoder hidden state, (batch_size, hidden_size)
      return:
      - context_vector: (batch_size, encoder_dim), attention-weighted sum of locations
      - attention_weights: (batch_size, num_locations, 1), softmax over locations
    """
    # add a location axis (batch_size, 1, hidden_size) so the hidden projection
    # broadcasts over every location instead of against the batch axis
    hidden_with_time = torch.unsqueeze(hidden, dim=1)
    # (batch_size, num_locations, attention_size)
    attention_on_hidden = torch.tanh(self.W1(encoded_feature) + self.W2(hidden_with_time))
    # attention score per location (batch_size, num_locations, 1)
    score = self.V(attention_on_hidden)
    # normalize scores over the locations axis
    attention_weights = torch.softmax(score, dim=1)
    context_vector = torch.sum(attention_weights * encoded_feature, dim=1)

    return context_vector, attention_weights
    
    

In [None]:
torch.manual_seed(seed=3)  # fix torch's global RNG so weight initialisation is repeatable
class DecoderWithAttention(nn.Module):
  """
    GRU caption decoder that attends over encoded image locations at every step.

    Each call to forward decodes one time step: it computes an attention context from
    the previous hidden state, concatenates it with the embedded input token, advances
    the GRU by one step from that hidden state and projects to vocabulary logits.
  """
  def __init__(self,
               embedding_dim,
               hidden_size,
               vocab_size,
               encoded_dim,
               pretrained_embed=None,
               attention_size=None):
    """
      args:
      - embedding_dim: size of the word embeddings
      - hidden_size: size of the GRU hidden state
      - vocab_size: number of tokens in the vocabulary
      - encoded_dim: feature size of each encoded image location
      - pretrained_embed: optional (vocab_size, embedding_dim) weight tensor, loaded
        with nn.Embedding.from_pretrained
      - attention_size: size of the attention projection; defaults to hidden_size
    """
    super(DecoderWithAttention, self).__init__()
    self.hidden_size = hidden_size
    self.vocab_size = vocab_size
    self.embedding_dim = embedding_dim
    self.encoded_dim = encoded_dim
    self.embedding = self.init_embedding(pretrained_embed)
    # the GRU consumes [word embedding ; attention context] at every step
    self.gru_in = self.embedding_dim + self.encoded_dim
    self.gru = nn.GRU(input_size=self.gru_in,
                      hidden_size=self.hidden_size,
                      batch_first=True)
    self.fc1 = nn.Linear(self.hidden_size, self.hidden_size)
    self.fc2 = nn.Linear(self.hidden_size, self.vocab_size)

    if attention_size is None:
      attention_size = self.hidden_size
    self.attention = BahdanauAttention(self.encoded_dim, self.hidden_size, attention_size)

  def init_embedding(self, weight):
    """
      Build the token embedding layer.

      If `weight` is given it is loaded with nn.Embedding.from_pretrained; otherwise a
      fresh embedding is initialised uniformly in [-0.1, 0.1].
      raises ValueError if `weight` is not (vocab_size, embedding_dim), because the GRU
      input size is derived from embedding_dim.
    """
    if weight is not None:
      expected = (self.vocab_size, self.embedding_dim)
      if tuple(weight.shape) != expected:
        raise ValueError(
            f"pretrained embedding has shape {tuple(weight.shape)}, expected {expected}")
      return nn.Embedding.from_pretrained(weight)

    embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
    embedding.weight.data.uniform_(-0.1, 0.1)
    return embedding

  def forward(self, encoded_features, x, hidden):
    """
      Decode one time step.
      args:
      - encoded_features: (batch_size, num_locations, encoded_dim)
      - x: input token ids, (batch_size,) or (batch_size, 1)
      - hidden: previous hidden state, (batch_size, hidden_size); use
        init_hidden(batch_size) for the first step
      return:
      - logits: (batch_size, vocab_size)
      - hidden: new hidden state, (batch_size, hidden_size), to pass to the next step
      - atten_weights: (batch_size, num_locations, 1)
    """
    # context_vector: (batch_size, encoded_dim)
    context_vector, atten_weights = self.attention(encoded_features, hidden)
    if x.dim() == 1:
      x = x.unsqueeze(1)  # (batch_size, 1): a single time step
    # embed tokens: (batch_size, 1, embedding_dim)
    x = self.embedding(x)
    # Concatenate the context vector (given a time axis) to the embedded token
    x = torch.cat((x, context_vector.unsqueeze(1)), dim=-1)
    # continue from the previous hidden state; nn.GRU expects (num_layers, batch, hidden)
    output, hn = self.gru(x, hidden.unsqueeze(0))
    x = self.fc1(output)
    # Change size to (batch_size*sequence_len, hidden_size)
    x = x.reshape(-1, x.size(-1))
    x = self.fc2(x)

    return x, hn.squeeze(0), atten_weights

  def init_hidden(self, batch_size):
    """
      Return an all-zero initial hidden state of shape (batch_size, hidden_size) on
      the same device and dtype as the decoder's parameters.
    """
    return torch.zeros(batch_size, self.hidden_size,
                       device=self.fc1.weight.device, dtype=self.fc1.weight.dtype)

