# Training the Image Captioning Model

This notebook shows how to train the image captioning model for free on Colab. I synced my local files to a folder in Google Drive so that changes to the code are auto-reloaded in the notebook, but you could also clone the repo from GitHub.




## Set up notebook and imports

In [1]:
!pip install "pytorch_lightning==1.1.5" tokenizers wandb --quiet

In [2]:
%load_ext autoreload
%autoreload 2

# Connecting to Google Drive will allow us to load the Kaggle dataset and
# checkpoint models during training

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import os
import sys

import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb

In [4]:
# Log metrics to https://wandb.ai/
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mjustinreppert[0m (use `wandb login --relogin` to force relogin)


True

In [5]:
# For Reproducibility
pl.seed_everything(42)

Global seed set to 42


42

In [6]:
# Set up Google Drive paths ('MyDrive' is Colab's space-free alias for 'My Drive',
# which keeps the path safe to pass to shell commands such as cp)
GOOGLE_DRIVE_PATH_FOR_PROJECT = 'captioning'  # Location of project files in Google Drive
GOOGLE_DRIVE_PATH = os.path.join('drive', 'MyDrive', GOOGLE_DRIVE_PATH_FOR_PROJECT)
GOOGLE_DRIVE_FLICKR_PATH = os.path.join(GOOGLE_DRIVE_PATH, 'flickr30k.zip')
print(os.listdir(GOOGLE_DRIVE_PATH))
sys.path.append(GOOGLE_DRIVE_PATH)

['.ipynb_checkpoints', '__pycache__', 'flickr30k.zip', '.gitignore', 'README.md', 'setup.cfg', 'requirements.txt', 'LICENSE', 'tests', 'project', 'setup.py', 'checkpoints']


In [7]:
# Import project modules from Google Drive
from project.captioners import CaptioningRNN
from project.datasets import load_flickr_csv, load_coco_captions_json, CombinedDataModule, sample_minibatch
from project.utils import sample_predictions

In [8]:
if torch.cuda.is_available():  # call the function; the bare attribute is always truthy
  print('Good to go!')
else:
  print('Select Runtime -> Change Runtime Type -> Hardware Accelerator -> GPU')

Good to go!


## Download Training Data

In [9]:
# Copy Flickr dataset from Google Drive and unzip; this will take a few minutes
if not os.path.exists('flickr30k_images'):
    !cp $GOOGLE_DRIVE_FLICKR_PATH .
    !unzip 'flickr30k.zip' > /dev/null
!ls flickr30k_images

flickr30k_images  results.csv
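
The copy-and-unzip step can also be written in plain Python, which sidesteps shell quoting if the Drive path ever contains a space. This is only a sketch and not part of the project: `fetch_and_extract` is a hypothetical helper, and it assumes the archive unpacks into a top-level folder (here `flickr30k_images`).

```python
import os
import shutil
import zipfile


def fetch_and_extract(zip_src, extracted_dir, work_dir="."):
    """Copy a zip archive into work_dir and unpack it, unless already unpacked.

    Hypothetical helper mirroring the `cp` + `unzip` cell above. Returns True
    if extraction ran and False if `extracted_dir` was already present.
    """
    target = os.path.join(work_dir, extracted_dir)
    if os.path.exists(target):
        return False  # already unpacked in an earlier run

    local_zip = os.path.join(work_dir, os.path.basename(zip_src))
    if os.path.abspath(zip_src) != os.path.abspath(local_zip):
        shutil.copy(zip_src, local_zip)  # same role as `cp <drive path> .`

    with zipfile.ZipFile(local_zip) as archive:
        archive.extractall(work_dir)  # same role as `unzip flickr30k.zip`
    return True
```

In this notebook the call would look like `fetch_and_extract(GOOGLE_DRIVE_FLICKR_PATH, 'flickr30k_images')`.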


In [10]:
# Download the COCO captions dataset and unzip
DOWNLOAD_COCO = True
if DOWNLOAD_COCO:
    if not os.path.exists('coco2014_train'):
        !curl -L -C - http://images.cocodataset.org/zips/train2014.zip -o train2014.zip
        !unzip 'train2014.zip' -d coco2014_train > /dev/null
        !curl -L -C - http://images.cocodataset.org/annotations/annotations_trainval2014.zip -o annotations_trainval2014.zip
        !unzip 'annotations_trainval2014.zip' -d coco_labels > /dev/null
    !ls

annotations_trainval2014.zip		flickr30k.zip
coco2014_train				lightning_logs
coco_labels				sample_data
drive					train2014.zip
flickr30k_images			vocab.txt
flickr30k_tokenizer			wandb
flickr30k_tokenizertokenizer-vocab.txt


In [11]:
# Preview COCO dataset
if DOWNLOAD_COCO:
    COCO_JSON_PATH = 'coco_labels/annotations/captions_train2014.json'
    COCO_IMG_DIR = 'coco2014_train/train2014'

    captions_df = load_coco_captions_json(COCO_JSON_PATH, COCO_IMG_DIR)
captions_df.head()

| | path | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|---|
| 0 | coco2014_train/train2014/COCO_train2014_000000... | a restaurant has modern wooden tables and chairs | a long restaurant table with rattan rounded ba... | a long table with a plant on top of it surroun... | a long table with a flower arrangement in the ... | a table is adorned with wooden chairs with blu... |
| 1 | coco2014_train/train2014/COCO_train2014_000000... | a man preparing desserts in a kitchen covered ... | a chef is preparing and decorating many small ... | a baker prepares various types of baked goods | a close up of a person grabbing a pastry in a ... | close up of a hand touching various pastries |
| 2 | coco2014_train/train2014/COCO_train2014_000000... | a big red telephone booth that a man is standi... | a person standing inside of a phone booth | this is an image of a man in a phone booth | a man is standing in a red phone booth | a man using a phone in a phone booth |
| 3 | coco2014_train/train2014/COCO_train2014_000000... | the kitchen is full of spices on the rack | a kitchen with counter oven and other accesso... | a small kitchen that utilizes all of its space | this small kitchen has pots pans and spices o... | a very small kitchen with a stove and a shelf ... |
| 4 | coco2014_train/train2014/COCO_train2014_000000... | a child and woman are cooking in the kitchen | a woman glances at a young girls cooking on th... | a young girl and a woman preparing food in a k... | a young person and an older person in a kitchen | two women cooking on stove in a kitchen together |
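
For readers wondering how the COCO annotations end up in this one-row-per-image layout, here is a rough sketch. The COCO captions JSON stores an `images` list (with `id` and `file_name`) and an `annotations` list (with `image_id` and `caption`). The function below is hypothetical and is not the project's `load_coco_captions_json`, whose implementation isn't shown here. It only lowercases captions, whereas the preview above suggests the real loader also strips punctuation.

```python
import os
from collections import defaultdict


def group_coco_captions(coco, img_dir, captions_per_image=5):
    """Return one row per image: {'path': ..., 0: caption, 1: caption, ...}.

    `coco` is the parsed captions JSON (a dict with 'images' and 'annotations').
    Assumption for this sketch: images with fewer than `captions_per_image`
    captions are skipped and any extra captions are dropped.
    """
    # Collect every caption under the id of the image it describes
    captions = defaultdict(list)
    for ann in coco["annotations"]:
        captions[ann["image_id"]].append(ann["caption"].strip().lower())

    rows = []
    for image in coco["images"]:
        caps = captions[image["id"]]
        if len(caps) < captions_per_image:
            continue
        row = {"path": os.path.join(img_dir, image["file_name"])}
        row.update(enumerate(caps[:captions_per_image]))  # columns 0..n-1
        rows.append(row)
    return rows
```

Passing the result to `pandas.DataFrame(rows)` would give a table with a `path` column followed by numbered caption columns, similar to the preview.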


In [12]:
# Preview Flickr dataset
FLICKR30K_IMG_DIR = 'flickr30k_images/flickr30k_images'
FLICKR30K_CSV_PATH = 'flickr30k_images/results.csv'

captions_df = load_flickr_csv(FLICKR30K_CSV_PATH, FLICKR30K_IMG_DIR)
captions_df.head()

| | path | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|---|
| 0 | flickr30k_images/flickr30k_images/1000092795.jpg | two young guys with shaggy hair look at their ... | two young white males are outside near many ... | two men in green shirts are standing in a yard | a man in a blue shirt standing in a garden | two friends enjoy time spent together |
| 1 | flickr30k_images/flickr30k_images/10002456.jpg | several men in hard hats are operating a giant... | workers look down from up above on a piece of ... | two men working on a machine wearing hard hats | four men on top of a tall structure | three men on a large rig |
| 2 | flickr30k_images/flickr30k_images/1000268201.jpg | a child in a pink dress is climbing up a set o... | a little girl in a pink dress going into a woo... | a little girl climbing the stairs to her playh... | a little girl climbing into a wooden playhouse | a girl going into a wooden building |
| 3 | flickr30k_images/flickr30k_images/1000344755.jpg | someone in a blue shirt and hat is standing on... | a man in a blue shirt is standing on a ladder ... | a man on a ladder cleans the window of a tall ... | man in blue shirt and jeans on ladder cleaning... | a man on a ladder cleans a window |
| 4 | flickr30k_images/flickr30k_images/1000366164.jpg | two men one in a gray shirt one in a black... | two guy cooking and joking around with the cam... | two men in a kitchen cooking food on a stove | two men are at the stove preparing food | two men are cooking a meal |


## Overfit tiny dataset as a sanity check

In [13]:
# # Sanity overfitting check: First, make a tiny dataset
# tiny_dataset = CombinedDataModule(
#     flickr_csv=FLICKR30K_CSV_PATH,
#     flickr_dir=FLICKR30K_IMG_DIR,
#     coco_json=COCO_JSON_PATH,
#     coco_dir=COCO_IMG_DIR,
#     num_workers=2,
#     batch_size=4,
#     dev_set=12,
#     val_size=4,
#     test_size=4,
#     transform='normalize',
#     target_transform='tokenize',
# )
# tiny_dataset.setup()

# # Check out the images and captions
# minibatch = next(iter(tiny_dataset.train_dataloader()))
# sample_minibatch(minibatch, tiny_dataset.tokenizer, remove_special_tokens=False)

In [14]:
# overfitting_config = CaptioningRNN.default_config()
# overfitting_config["rnn_dropout"] = False

# overfitting_model = CaptioningRNN(tiny_dataset, overfitting_config)
# overfitting_model.hparams

In [15]:
# overfitting_trainer = pl.Trainer(
#     gpus=1,
#     max_epochs=250,
#     num_sanity_val_steps=1,
#     progress_bar_refresh_rate=50,
#     check_val_every_n_epoch=250,
#     )

# overfitting_trainer.fit(overfitting_model)

In [16]:
# overfitting_model.inference_beam_width = 1
# overfitting_model.inference_beam_alpha = 0.
# sample_predictions(minibatch, overfitting_model)

## Set up training dataset

In [17]:
# Download vocab file; using the same vocab will allow us to interpret model
# results in the future:
!curl -L -C - "https://raw.githubusercontent.com/reppertj/image-captioning/master/tokenizer/vocab.txt" -o "vocab.txt"

** Resuming transfer from byte position 32221
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0


In [18]:
# Create dataset, passing in the vocab file
fdm = CombinedDataModule(
    flickr_csv=FLICKR30K_CSV_PATH,
    flickr_dir=FLICKR30K_IMG_DIR,
    coco_json=COCO_JSON_PATH,
    coco_dir=COCO_IMG_DIR,
    pretrained_vocab='vocab.txt',
    num_workers=2,
    batch_size=128,
)

In [19]:
# Set up the dataset
fdm.setup()

## Set up model

In [20]:
# Begin with the default model config
config = CaptioningRNN.default_config()

In [21]:
# Tweak some of the default parameters
config['optimizer'] = 'adamw'
config['batch_size'] = 256
config['fc_init'] = 'kaiming'
config['encoder_init'] = 'kaiming'
config['label_smoothing_epsilon'] = 0.1
config['inference_beam_alpha'] = 0.9
config['num_rnn_layers'] = 4
config['hidden_size'] = 1024
config['wordvec_dim'] = 512
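
To give a feel for what `label_smoothing_epsilon = 0.1` does, here is a small worked sketch of one common label smoothing formulation: the one-hot target is mixed with a uniform distribution over the vocabulary. Whether `CaptioningRNN` uses exactly this variant is an assumption on my part, and both function names below are made up for illustration.

```python
import math


def smoothed_targets(target_index, vocab_size, epsilon=0.1):
    """Mix a one-hot target with a uniform distribution over the vocabulary.

    Assumed variant: every token gets epsilon / vocab_size, and the true
    token gets an extra (1 - epsilon) on top of that.
    """
    uniform = epsilon / vocab_size
    dist = [uniform] * vocab_size
    dist[target_index] += 1.0 - epsilon
    return dist


def cross_entropy(log_probs, target_dist):
    """Cross-entropy between a target distribution and predicted log-probabilities."""
    return -sum(t * lp for t, lp in zip(target_dist, log_probs))


# Toy vocabulary of 4 tokens where the correct token is index 1
targets = smoothed_targets(1, vocab_size=4, epsilon=0.1)
uniform_log_probs = [math.log(0.25)] * 4
loss = cross_entropy(uniform_log_probs, targets)
```

With smoothing, the model is never pushed to assign probability 1.0 to a single token, which tends to discourage overconfident predictions.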

In [22]:
# Create the model
model = CaptioningRNN(fdm, config)

In [23]:
# Double check the model's hyperparameters
model.hparams

"batch_size":              256
"encoder_init":            kaiming
"fc_init":                 kaiming
"hidden_size":             1024
"image_encoder":           resnext50
"inference_beam_alpha":    0.9
"inference_beam_width":    10
"label_smoothing_epsilon": 0.1
"learning_rate":           0.0003
"max_length":              25
"momentum":                0.9
"num_rnn_layers":          4
"num_rnns":                1
"optimizer":               adamw
"rnn_bidirectional":       False
"rnn_dropout":             0.1
"rnn_init":                None
"rnn_nonlinearity":        None
"rnn_type":                attention
"scheduler":               plateau
"wd_embedder_init":        xavier
"wordvec_dim":             512

## Set up trainer and train

In [24]:
# Create a callback to save the best models during training
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filename='attn-lstm-5000-{epoch}-{val_loss:.2f}',
    dirpath=os.path.join(GOOGLE_DRIVE_PATH, 'checkpoints'),
    monitor='val_loss',
    mode='min',
    save_top_k=10
    )

wandb_logger = WandbLogger(entity='collaborativeml', project='attention_lstm', log_model=True)



In [None]:
# Set up trainer and train!
trainer = pl.Trainer(
    gpus=1,
    max_epochs=1000,
    num_sanity_val_steps=1,
    progress_bar_refresh_rate=50,
    auto_scale_batch_size='binsearch',
    auto_lr_find=True,
    callbacks=[checkpoint_callback],
    logger=wandb_logger,
    benchmark=True,
    gradient_clip_val=0.5,
    )

trainer.fit(model)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



  | Name            | Type                  | Params
----------------------------------------------------------
0 | val_bleu        | CorpusBleu            | 0     
1 | test_bleu       | CorpusBleu            | 0     
2 | image_extractor | ImageFeatureExtractor | 25.1 M
3 | word_embedder   | WordEmbedder          | 2.6 M 
4 | decoder         | ParallelAttentionLSTM | 11.5 M
5 | fc_scorer       | ParallelFCScorer      | 5.1 M 
----------------------------------------------------------
21.3 M    Trainable params
23.0 M    Non-trainable params
44.3 M    Total params
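
A note on `gradient_clip_val=0.5`: as far as I understand it, Lightning clips gradients by their global L2 norm, in the same way as `torch.nn.utils.clip_grad_norm_`. The sketch below reproduces that arithmetic on plain Python lists so the effect is easy to see. `clip_by_global_norm` is a hypothetical name and this is not Lightning's own code.

```python
import math


def clip_by_global_norm(grads, max_norm=0.5, eps=1e-6):
    """Rescale a list of gradient vectors so their combined L2 norm is at most max_norm.

    Returns (clipped_grads, original_norm). Gradients whose norm is already
    within the limit are returned unchanged.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = max_norm / (total_norm + eps)
    if scale >= 1.0:
        return [list(vec) for vec in grads], total_norm
    return [[g * scale for g in vec] for vec in grads], total_norm


# Two "parameters" whose gradients have a combined norm of 5.0
clipped, norm_before = clip_by_global_norm([[3.0], [4.0]], max_norm=0.5)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))
```

All gradients are scaled by the same factor, so the update direction is preserved and only its length is capped.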



If Colab interrupts training, we'll be able to restore from a checkpoint and continue training in another session.
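
To pick which checkpoint to restore, one option is to read the validation loss back out of the file names. The sketch below assumes Lightning expands the `filename` template used earlier into names like `attn-lstm-5000-epoch=7-val_loss=2.54.ckpt`. The helper `best_checkpoint` is hypothetical and not part of the project.

```python
import os
import re

# Matches the "val_loss=2.54" part of a checkpoint file name
VAL_LOSS_PATTERN = re.compile(r"val_loss=(\d+(?:\.\d+)?)")


def best_checkpoint(checkpoint_dir):
    """Return the path of the .ckpt file with the lowest val_loss in its name.

    Returns None if the directory holds no matching checkpoint files.
    """
    best_path, best_loss = None, float("inf")
    for name in os.listdir(checkpoint_dir):
        match = VAL_LOSS_PATTERN.search(name)
        if not name.endswith(".ckpt") or match is None:
            continue  # skip unrelated files
        loss = float(match.group(1))
        if loss < best_loss:
            best_path, best_loss = os.path.join(checkpoint_dir, name), loss
    return best_path
```

In a new session, the returned path could be passed as `resume_from_checkpoint` when constructing `pl.Trainer`. That argument exists in the 1.1.x release pinned at the top of this notebook, but later Lightning versions changed how resuming works, so check the docs for the version you install.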