<a href="https://colab.research.google.com/github/yscope75/CS2225.CH2001020/blob/master/Image_captioning_Master_courses.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# Update torchtext version
!pip3 install torchvision==0.8.0
!pip3 install torchtext==0.8.0



In [5]:
import torchvision.datasets as dset
import torchvision.datasets.utils as dset_utils
import torchvision.transforms as transforms
import os
import time
import math

In [6]:
import torch
from torch import nn
import torchvision
from torchsummary import summary
import json 
import matplotlib.pyplot as plt
import collections
from PIL import Image
import torchtext
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import torch.optim as optim
import random

In [7]:
data_folder = os.path.join(os.path.abspath('.') + '/coco/')
# Download and unzip annotations
if not os.path.exists(data_folder):
  dset_utils.download_and_extract_archive(url='http://images.cocodataset.org/annotations/annotations_trainval2014.zip',
                                          download_root=data_folder,
                                          extract_root=data_folder,
                                          filename='captions.zip')
  # Download and unzion images
  dset_utils.download_and_extract_archive(url='http://images.cocodataset.org/zips/train2014.zip',
                                          download_root=data_folder,
                                          extract_root=data_folder,
                                          filename='train2014.zip')

In [8]:
# Get path to file and delete zip file
captions_train = os.path.join(data_folder, 'annotations/captions_train2014.json')
images_train = os.path.join(data_folder, 'train2014/')
captions_zip = os.path.join(data_folder, 'captions.zip')
images_zip = os.path.join(data_folder, 'train2014.zip')
if os.path.exists(captions_zip):
  os.remove(captions_zip)
if os.path.exists(images_zip):
  os.remove(images_zip)

In [9]:
# Setup device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [10]:
image_preprocessor = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
 
coco_cap = dset.CocoCaptions(images_train, captions_train, transform=image_preprocessor)

loading annotations into memory...
Done (t=0.70s)
creating index...
index created!


In [11]:
train_slice = int(len(coco_cap)*0.6)
val_slice = int(len(coco_cap)*0.2)
train_data, val_data, test_data = torch.utils.data.random_split(coco_cap,
                                                                [train_slice, 
                                                                 val_slice,
                                                                 len(coco_cap)-train_slice-val_slice])


In [12]:
# create dataloader for images in coco
# image_loader = DataLoader(coco_cap, batch_size=10)
# Loop throught images and save encoded features 
# encoder = Encoder()
# encoder.cuda()
# encoded_features = []
# for id_batch, (img, caps) in enumerate(image_loader):
#   encoded_batch = encoder(img.to(device))
#   encoded_batch.to('cpu')
#   encoded_features.append(encoded_batch)

In [13]:
train_captions = []

for img, caps in coco_cap:
  train_captions.extend(caps)

In [14]:
# Building vocabulary for annotations
tokenizer = get_tokenizer("basic_english")
def build_vocab(sentences, tokenizer):
  counter = Counter()
  for sen in sentences:
    counter.update(tokenizer(sen))
  return Vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

en_vocab = build_vocab(train_captions, tokenizer)

In [15]:
# Building data loader 
# Define params
BATCH_SIZE = 24
PAD_IDX = en_vocab['<pad>']
BOS_IDX = en_vocab['<bos>']
EOS_IDX = en_vocab['<eos>']

# process batch_data
def batch_process(batch_data):
  img_batch, cap_batch = [], []
  for img, caps in batch_data:
    img_batch.extend(img.repeat(len(caps),1,1,1))
    for cap in caps:
      sen_ids = torch.tensor([en_vocab[token] for token in tokenizer(cap)])
      cap_batch.append(torch.cat([torch.tensor([BOS_IDX]), 
                                 sen_ids, 
                                 torch.tensor([EOS_IDX])], dim=0))
  cap_batch = pad_sequence(cap_batch, batch_first=True, padding_value=PAD_IDX)

  return torch.stack(img_batch), cap_batch

train_iter = DataLoader(train_data, batch_size=BATCH_SIZE, collate_fn=batch_process)
valid_iter = DataLoader(val_data, batch_size=BATCH_SIZE, collate_fn=batch_process)
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE, collate_fn=batch_process)

In [16]:
# Definition of main model
# Begin with Encoder
class Encoder(nn.Module):
  """
    Encode image input using pre-trained Resnet152 model on imagenet
  """
  def __init__(self):
    super(Encoder, self).__init__()
    resnet152 = torchvision.models.resnet152(pretrained=True)
    # remove the last two layers and keep the last CNN output 
    modules = list(resnet152.children())[:-2] 
    self.res_encoder = nn.Sequential(*modules) # last output (batch_size, 2048, 10, 10)
    # Flatten feature vector to (batch_size, 2048, 100)
    self.flat_embed = nn.Flatten(start_dim=2)
    

  def forward(self, X_in):
    """
      The forward pass of encoder 
      args:
      - X_in: input data batch of size (batch, 3, Height, weight)
      return: encoded images of size (batch, embed_size, 100)
    """
    e_out = self.res_encoder(X_in) 
    # Flatten output vector to (batch_size, embedding_size, 100)
    e_out = self.flat_embed(e_out)
    # Change shape of output encoded to (batch_size, 100, embedding_dim)
    e_out = e_out.permute(0, 2, 1)
    return e_out
    
  

In [17]:
class BahdanauAttention(nn.Module):
  """
    Define attention mechanism module on encoded image for genrating text
  """
  def __init__(self, encoder_dim, hidden_size, attention_size):
    """
      args: 
      - encoder_dim: size of encoded image (batch_size, 100, embedding_size(2048))
      - hidden_size: size of hidden unit in decoder RNN 
      - attention_size: 
    """
    super(BahdanauAttention, self).__init__()
    self.W1 = nn.Linear(encoder_dim, attention_size)   # size
    self.W2 = nn.Linear(hidden_size, attention_size)
    self.V = nn.Linear(attention_size, 1)
    
  def forward(self, encoded_feature, hidden):
    """
      args:
      - encoded_feature: 
    """
    # expand time dimension for hidden layer in decoder (batch_size, 1, hidden_size)
    # Compute attenntion for hidden (batch_size, 100, attention_size)
    attention_on_hidden = torch.tanh(self.W1(encoded_feature) + self.W2(hidden))
    # attention score on attention (batch, 100, 1)
    score = self.V(attention_on_hidden)
    # compute attention weights 
    attention_weights = torch.softmax(score, dim=1)
    context_vector = attention_weights*encoded_feature
    context_vector = torch.sum(context_vector, dim=1)
    
    return context_vector, attention_weights
    
    

In [18]:
torch.manual_seed(3)
class DecoderWithAttention(nn.Module):
  def __init__(self,
               embedding_dim,
               hidden_size,
               vocab_size,
               encoded_dim,
               attention: nn.Module,
               pretrained_embed=None):
    super(DecoderWithAttention, self).__init__()
    self.hidden_size = hidden_size
    self.vocab_size = vocab_size
    self.embedding_dim = embedding_dim
    self.encoded_dim = encoded_dim
    self.embedding = self.init_embedding(pretrained_embed)
    self.gru_in = self.embedding_dim + self.encoded_dim
    self.gru = nn.GRU(input_size=self.gru_in,
                      hidden_size=self.hidden_size,
                      batch_first=True)
    self.fc1 = nn.Linear(self.hidden_size, self.hidden_size)
    self.fc2 = nn.Linear(self.hidden_size, self.vocab_size)
    
    self.attention = attention
    
  def init_embedding(self, weight):
    """
      if pretrained embedding exists then load from pretrained
      else load from new one
    """
    embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
    embedding.weight.data.uniform_(-0.1, 0.1)
    if weight is not None:
      embedding = nn.Embedding.from_pretrained(weight)
    
    return embedding
 
  def forward(self, encoded_features, x, hidden):
    
    # get necessary size
    batch_size = encoded_features.size(0)
    # compute context vector and attention weights 
    # context_vector: (batch_size, visual_embedding_size(2048))
    context_vector, atten_weights = self.attention(encoded_features, hidden.unsqueeze(1))
    # embed token x to vector
    # x: (batch, embedding_size)
    x = self.embedding(x).unsqueeze(1)
    # Concatinate context vector to input 
    x = torch.cat((x, context_vector.unsqueeze(1)), dim=-1)
    # output size: (batch_size, sequence_len, hidden_size)
    output, hn = self.gru(x, hidden.unsqueeze(0))
    output = output.squeeze(1)
    x = self.fc1(output)
    # Change size to (batch_size*sequence_len, hidden_size)
    x = x.view(-1, x.size()[-1])
    x = self.fc2(x)
 
    return x, hn.squeeze(), atten_weights

In [19]:
class Captioning(nn.Module):
  def __init__(self,
               encoder: nn.Module,
               decoder: nn.Module,
               device: torch.device):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.device = device
    
  def forward(self, 
              img_src: torch.Tensor,
              trg: torch.Tensor,
              teacher_forc_rate: float=0.5):
    batch_size = img_src.shape[0]
    max_len = trg.shape[1]

    trg_vocab_size = self.decoder.vocab_size

    outputs = torch.zeros(batch_size, max_len, trg_vocab_size).to(self.device)
    encoder_out = self.encoder(img_src)
    hidden = self.init_hidden(batch_size).to(device)
    output = trg[:, 0]
    for t in range(1, max_len):
      output, hidden, _ = self.decoder(encoder_out, output, hidden)
      outputs[:,t,:] = output
      teacher_force = random.random() < teacher_forc_rate 
      top = output.max(1)[1]
      output = (trg[:,t] if teacher_force else top)
      
    return outputs

  def init_hidden(self, batch_size):
    w = torch.empty(batch_size, self.decoder.hidden_size)
    return nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))

In [20]:
# Define training parameters 
OUTPUT_DIM = len(en_vocab)
EMB_DIM = 256
HIDDEN_DIM = 512
ATTEN_DIM = 512
ENCODE_DIM = 2048

In [21]:
encoder = Encoder()
attention = BahdanauAttention(ENCODE_DIM, HIDDEN_DIM, ATTEN_DIM)
decoder = DecoderWithAttention(EMB_DIM, HIDDEN_DIM, OUTPUT_DIM, ENCODE_DIM, attention)
model = Captioning(encoder, decoder, device).to(device)

In [22]:
optimizer = optim.Adam(model.parameters())
for param in encoder.parameters():
  param.requires_grad = False

In [23]:
PAD_IDX = en_vocab.stoi['<pad>']
loss_func = nn.CrossEntropyLoss(ignore_index=PAD_IDX)


In [24]:
def train(model: nn.Module,
          iterator: torch.utils.data.DataLoader,
          optimizer: optim.Optimizer,
          loss_func: nn.Module):

  model.train()

  epoch_loss = 0

  for _, (img, trg) in enumerate(iterator):
      img = img.to(device)
      trg = trg.to(device)

      optimizer.zero_grad()

      output = model(img, trg)

      output = output[:,1:,:].contiguous().view(-1, output.shape[-1])
      trg = trg[:,1:].contiguous().view(-1)

      loss = loss_func(output, trg)

      loss.backward()

      optimizer.step()

      epoch_loss += loss.item()

  return epoch_loss / len(iterator)

def evaluate(model: nn.Module,
             iterator: torch.utils.data.DataLoader,
             loss_func: nn.Module):

  model.eval()

  epoch_loss = 0

  with torch.no_grad():

      for _, (src, trg) in enumerate(iterator):
          src, trg = src.to(device), trg.to(device)

          output = model(src, trg, 0) 

          output = output[1:].contiguous().view(-1, output.shape[-1])
          trg = trg[1:].contiguous().view(-1)

          loss = loss_func(output, trg)

          epoch_loss += loss.item()

  return epoch_loss / len(iterator)

def epoch_time(start_time: int,
               end_time: int):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs


In [25]:
# define checkpoint info
CKP_PATH = 'model.pt'


In [None]:
N_EPOCHS = 10
 
best_valid_loss = float('inf')
 
for epoch in range(N_EPOCHS):
 
    start_time = time.time()
 
    train_loss = train(model, train_iter, optimizer, loss_func)
    valid_loss = evaluate(model, valid_iter, loss_func)
    if epoch % 4 == 0:
      torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': train_loss,
        }, CKP_PATH)
    end_time = time.time()
 
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
 
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
 
test_loss = evaluate(model, test_iter, loss_func)
 
print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

Epoch: 01 | Time: 49m 45s
	Train Loss: 4.178 | Train PPL:  65.212
	 Val. Loss: 4.930 |  Val. PPL: 138.439


In [None]:
def predict()

In [52]:
torch.cuda.empty_cache()