#INTRODUCTION
The goal of this project is to create a model that can receive a text description and then select from a library of images the image that most closely aligns with that description.  The model architecture is a dual encoder that is trained to projects images and their descriptions onto the same space and at the same location.  To serve as the base models of the dual encoder, we will use Inception V3 to encode the image data and BERT to encode the text data.  On top of these base models will sit two small feedforward networks that will project the encoded images and text onto the common space where they will have the same dimensions.  Let us call these two feedforward networks the "projection heads".

Because training for this dual encoder is so time consuming, we will speed up the process by breaking the training process into two phases.
##TRAINING PHASES
###Phase 1:
During phase 1, we will circumvent the base models completely, dramatically reducing the iteration time as we tune the model.  We will use the base models just once to encode the entire dataset, and then we will train the projection heads using just the pre-encoded data.

###Phase 2:
During phase 2, we will use the original data to fine tune the entire dual encoder, including the base models.

##Note
This project draws inspiration and some design features from the keras project "Natural language image search with a Dual Encoder", located at https://keras.io/examples/vision/nl_image_search/.

#Setup
We will use PyTorch to build the dual encoder, and we will use Google Colab to give us access to a GPU.

In [1]:
import os
import zipfile
import requests
from tqdm import tqdm
import json
import shutil
import importlib.util
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torchvision.io import read_image
from torchvision import transforms as torchvision_transforms
import json
from torch import nn
from torch.utils.data import DataLoader, random_split
if not importlib.util.find_spec('transformers'):
  !pip install transformers
from transformers import BertTokenizer, BertModel
from torchvision.models import inception_v3, Inception_V3_Weights
import h5py
from itertools import chain
import pickle

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


#Download the Data
We begin by downloading the MS COCO dataset, which contains both the image and caption data.  Note that the captions describe the images, and that they are stored in the "annotations" file.  The data is stored in zip files that we download for free.

In [2]:
# Download the full MS-COCO dataset (unless using dev data)

def download_and_extract(url):
        # Function that downloads and extracts files given a url

        os.makedirs('data', exist_ok=True)

        # Extract the filename from the URL
        filename = os.path.join('data', url.split("/")[-1])

        # Download the file if it doesn't exist
        if not os.path.exists(filename):
            response = requests.get(url, stream=True)
            total_size = int(response.headers.get('content-length', 0))

            with open(filename, 'wb') as file, tqdm(
                desc=filename,
                total=total_size,
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=1024):
                    size = file.write(data)
                    bar.update(size)

        # Extract the zip file
        with zipfile.ZipFile(filename, 'r') as zip_ref:
            zip_ref.extractall('data')
            print(f"Extracted all contents to data")

def download_ms_coco():
    # Function to download the MS-COCO dataset (2017 version)

    URLS = {
        "train_images": "http://images.cocodataset.org/zips/train2014.zip",
        "val_images": "http://images.cocodataset.org/zips/val2014.zip",
        "annotations": "http://images.cocodataset.org/annotations/annotations_trainval2014.zip"
    }

    # Download, extract images and annotations
    for url in URLS.values():
        download_and_extract(url)

    print("MS-COCO dataset downloaded and extracted successfully!")

download_ms_coco()

data/train2014.zip: 100%|██████████| 12.6G/12.6G [06:54<00:00, 32.6MB/s]


Extracted all contents to data


data/val2014.zip: 100%|██████████| 6.19G/6.19G [03:14<00:00, 34.2MB/s]


Extracted all contents to data


data/annotations_trainval2014.zip: 100%|██████████| 241M/241M [00:07<00:00, 34.9MB/s]


Extracted all contents to data
MS-COCO dataset downloaded and extracted successfully!


#Prepare the Data for Model Ingestion
We will generate PyTorch Datasets and Dataloaders using the original image and caption data.  Note that these datasets will only be used during phase 2 training.  However, we still need to create them now so that we can generate the pre-encoded datasets that will be used during phase 1.

In [3]:
# Prepare the image transforms

class resize_and_pad_image:
    # Class to ensure that all images are the same size
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        h, w = image.size()[-2:]
        size_param = int(min(w/h, h/w) * self.size)
        image = torchvision_transforms.functional.resize(img=image, size=size_param)
        h_new, w_new = image.size()[-2:]
        dw = (self.size-w_new)//2
        dh = (self.size-h_new)//2
        rw = (self.size-w_new)%2
        rh = (self.size-h_new)%2
        return torchvision_transforms.functional.pad(image, padding=(dw, dh, dw+rw, dh+rh))

image_transform = torchvision_transforms.Compose([
    resize_and_pad_image(size=299),
    torchvision_transforms.ConvertImageDtype(torch.float32),  # Convert to float tensor
    torchvision_transforms.Lambda(lambda x: x / 255),        # Scale pixel values
    torchvision_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


# Define the images & captions dataset in PyTorch

class image_and_captions_dateset(torch.utils.data.Dataset):
    """
    Class to provide a custom dataset for the images and captions

    Attributes
    ----------
    annotations_file : str
        The path to the annotations file
    img_dir : str
        The path to the directory containing the images
    image_transform : torchvision.transforms.Compose
        The transform to apply to the images

    Methods
    -------
    __len__()
        Returns the number of samples in the dataset
    __getitem__(idx)
        Returns the transformed image, the caption, and the sample Id at index idx
    """
    def __init__(self, annotations_file, img_dir, image_transform):

        with open(annotations_file, 'r') as f:
            self.annotations = json.load(f)['annotations']
        self.img_dir = img_dir
        self.image_transform = image_transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        sample_id = self.annotations[idx]['image_id']
        img_path = os.path.join(self.img_dir, 'COCO_train2014_%012d.jpg'%sample_id)
        image = read_image(img_path)
        if image.shape[0] == 1: image = image.tile((3,1,1))
        caption = self.annotations[idx]['caption']
        image = self.image_transform(image)
        return image, caption, sample_id

##Create the PyTorch Datasets
Out dataset will contain transformed images and untransformed text captions.  The text will be tokenized during the dataloading process.

Note that I limit that size of the training dataset because the full dataset is too large for my Google Cloud compute resources.

In [4]:
annotations_file = "./data/annotations/captions_train2014.json"
img_dir = './data/train2014'

# Generate the PyTorch dataset
coco_dataset = image_and_captions_dateset(annotations_file, img_dir, image_transform)


# Split the dataset.  We won't use the whole dataset because it is too large
splits = (0.4, 0.1, 0.1)

train_size = int(splits[0] * len(coco_dataset))
val_size = int(splits[1] * len(coco_dataset))
test_size = int(splits[2] * len(coco_dataset))
extra = len(coco_dataset) - train_size - val_size - test_size
train_dataset, val_dataset, test_dataset, _ = random_split(coco_dataset, [train_size, val_size, test_size, extra])

##Create the PyTorch Dataloaders

We tokenize during the dataloading process so that the length of the samples within each batch can be padded to the max sample length within the batch, rather than forcing all samples in all batches to be padded to the max sample length within the entire dataset.

Since we are using BERT as our text encoder model, we must tokenize our captions using the BERT Tokenizer

In [5]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class collate_fn():
  """
  Class to provide a custom collate function that modifies the batching process.

  Attributes
  ----------
  tokenizer : BertTokenizer
      The tokenizer to use for tokenizing the captions.

  Methods
  -------
  __call__(batch)
      Takes in a batch of captions and returns a batch of tokenized captions
      padded to the max caption length within the batch.
  """
  def __init__(self, tokenizer):
    self.tokenizer = tokenizer
  def __call__(self, batch):
    [images, captions, sample_ids] = list(zip(*batch))
    embedded_captions = self.tokenizer(list(captions), padding='longest', truncation=True, return_tensors='pt')
    images = torch.stack(list(images))
    return (images, embedded_captions, captions, sample_ids)

# Create DataLoader for each dataset
train_dataloader = DataLoader(train_dataset, batch_size = 16, collate_fn=collate_fn(tokenizer), shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size = 16, collate_fn=collate_fn(tokenizer), shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size = 16, collate_fn=collate_fn(tokenizer), shuffle=False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

#Create the Elements of the Dual Encoder
The elements are:


1.   Base Model 1 - Inception V3 for image data
2.   Base Model 2 - BERT for text data
3.   Projection Head 1 for image data
4.   Projection Head 2 for text data



In [6]:
# Define the Projection Head class that will be used to transform the outputs of both the text and image models into the same embedding space.

class projection_head(nn.Module):
    """
    Module to provide the projection head that will be used to project the outputs of both the text and image models into the same embedding space.

    Attributes
    ----------
    in_features : int
        The number of features outputted by the base model.
    projection_dims : int
        The number of features in the text/image common embedding space.
    dropout_rate : float
        The dropout rate to use
    """
    def __init__(self, in_features, projection_dims, dropout_rate):
        super().__init__()
        self.lin1 = nn.Linear(in_features, projection_dims)
        self.gelu = nn.GELU()
        self.layernorm1 = nn.LayerNorm(projection_dims)
        self.lin2 = nn.Linear(projection_dims, projection_dims)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.layernorm2 = nn.LayerNorm(projection_dims)
    def forward(self, x):
        x1 = self.lin1(x)
        x = self.gelu(x1)
        x = self.layernorm1(x)
        x = self.lin2(x)
        x = self.dropout(x)
        x = self.layernorm2(x)
        return x


# Load in the BERT model and the Inception model, and create projection heads for each

image_embedder = inception_v3(weights=Inception_V3_Weights.DEFAULT)
image_projection_head = projection_head(in_features=image_embedder.fc.in_features, projection_dims=256, dropout_rate=0.1)
image_AuxLogits_projection_head = projection_head(in_features=image_embedder.AuxLogits.fc.in_features, projection_dims=256, dropout_rate=0.1)
image_embedder.fc = torch.nn.Identity()
image_embedder.AuxLogits.fc = torch.nn.Identity()
caption_embedder = BertModel.from_pretrained('bert-base-uncased')
caption_projection_head = projection_head(in_features=caption_embedder.pooler.dense.out_features, projection_dims=256, dropout_rate=0.1)

# Move all elements of the dual encoder to the GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
image_embedder.to(device)
caption_embedder.to(device)
image_projection_head.to(device)
caption_projection_head.to(device)

Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth
100%|██████████| 104M/104M [00:02<00:00, 46.5MB/s]


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

projection_head(
  (lin1): Linear(in_features=768, out_features=256, bias=True)
  (gelu): GELU(approximate='none')
  (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (lin2): Linear(in_features=256, out_features=256, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
)

#Generate the Base Embeddings for Phase 1
Now, we use our base models to pre-embed both our images and captions.  Both models will be run in inference mode.  Later, we will use these base-embeddings to create the dataset that we will use to train our projection heads during Phase 1 of training.  We will generate both a file containing the embeddings and a file containing the original captions and sample IDs.  This second file will allow us to map the base-embeddings to the original dataset, which we need to do in order to evaluate our model performance after Phase 1 training.

In [7]:
def generate_base_embeddings(image_embedder, text_embedder, dataloader, base_embeddings_file, base_model_embeddings_metadata_file):
    """
    Function to generate and save the base embeddings for the entire dataset.
    The embeddings are collected in two lists and then saved into an HDF5 file and a pickle file.

    Inputs
    -------
    image_embedder : torch.nn.Module
        The image model to use for embedding the images.
    text_embedder : torch.nn.Module
        The text model to use for embedding the captions.
    dataloader : torch.utils.data.DataLoader
        The dataloader to use for generating the original dataset's data.
    base_embeddings_file : str
        The path to the file to store the base embeddings.
    base_model_embeddings_metadata_file : str
        The path to the file to store the original captions and sample ids.

    Outputs
    -------
    h5py file is stored at base_embeddings_file
    """

    image_embedder.eval(); text_embedder.eval()

    with torch.no_grad():

        # Create a temporary file to store the embeddings
        with h5py.File('temporary_file.h5', 'w') as h5f:

            # Create a dataset within the HDF5 file to store embeddings
            image_embeddings_h5 = h5f.create_dataset("image_embeddings", (len(dataloader.dataset), 2048), dtype='float32')
            caption_embeddings_h5 = h5f.create_dataset("caption_embeddings", (len(dataloader.dataset), 768), dtype='float32')

            # Lists to store the data as it is embedded in batches
            caption_list = []
            sample_id_list = []

            for batch, (images, tokenized_captions, captions, sample_ids) in enumerate(tqdm(dataloader, desc='Embedding Data')):
                images = images.to(device)
                for item in tokenized_captions:
                  tokenized_captions[item] = tokenized_captions[item].to(device)
                images_out = image_embedder(images)
                captions_out = caption_embedder(**tokenized_captions).pooler_output
                image_embeddings_h5[batch*16:(batch*16+len(images_out))] = images_out.cpu().numpy()
                caption_embeddings_h5[batch*16:(batch*16+len(images_out))] = captions_out.cpu().numpy()
                caption_list = caption_list + list(captions)
                sample_id_list = sample_id_list + list(sample_ids)

    !mv temporary_file.h5 {base_embeddings_file}

    with open('temporary_file.pkl', "wb") as file:
        pickle.dump((caption_list, sample_id_list), file)
    !mv temporary_file.pkl {base_model_embeddings_metadata_file}


# Establish the target location at which to save the base-embeddings

os.makedirs('/content/drive/My Drive/Projects/text-to-image-search/data', exist_ok=True)
base_embeddings_file = '/content/drive/My Drive/Projects/text-to-image-search/data/base_embeddings.h5'
base_model_embeddings_metadata_file = "/content/drive/My Drive/Projects/text-to-image-search/data/base_embeddings_metadata.pkl"

# Generate the base-embeddings
if not os.path.isfile(base_embeddings_file):
      generate_base_embeddings(image_embedder, caption_embedder, train_dataloader, base_embeddings_file, base_model_embeddings_metadata_file)

# Load the base-embeddings, as well as the captions and sample IDs
data = h5py.File(base_embeddings_file, 'r')
image_embeddings = torch.tensor(data['image_embeddings'], dtype=torch.float)
caption_embeddings = torch.tensor(data['caption_embeddings'], dtype=torch.float)

with open(base_model_embeddings_metadata_file, "rb") as file:
    (caption_list, sample_id_list) = pickle.load(file)

  image_embeddings = torch.tensor(data['image_embeddings'], dtype=torch.float)


#Generate the Phase 1 Training Dataset
Now that we have the base-embeddings, we can use them to create a dataset that we will use for Phase 1 training of just the projection heads.  We will also split off some of the base-embeddings to use in a validation dataset.

In [8]:
# Generate the dataset for the pre-embedded data

class Embeddings_Dataset(torch.utils.data.Dataset):
    def __init__(self, image_embeddings, caption_embeddings, caption, sample_id):

        self.image_embeddings = image_embeddings
        self.caption_embeddings = caption_embeddings
        self.caption = caption
        self.sample_id = sample_id

    def __len__(self):
        return len(self.caption_embeddings)

    def __getitem__(self, idx):
        return self.image_embeddings[idx], self.caption_embeddings[idx], self.caption[idx], self.sample_id[idx]


validation_cutoff = round(len(image_embeddings)*0.05)
embeddings_train_dataset = Embeddings_Dataset(image_embeddings[:-validation_cutoff], caption_embeddings[:-validation_cutoff], caption_list[:-validation_cutoff], sample_id_list[:-validation_cutoff])
embeddings_train_dataloader = DataLoader(embeddings_train_dataset, batch_size=16, shuffle=True)
embeddings_val_dataset = Embeddings_Dataset(image_embeddings[-validation_cutoff:], caption_embeddings[-validation_cutoff:], caption_list[-validation_cutoff:], sample_id_list[-validation_cutoff:])
embeddings_val_dataloader = DataLoader(embeddings_val_dataset, batch_size=16, shuffle=True)

# Functions and Classes for Phase 1 Training
Before we begin Phase 1 training of the projection heads, we need to create the training functions and define the loss function.  We use a custom loss function that allows us to compare all potential pairings and images and captions within an individual batch. Because our projection heads output a layer-normalized vector for each sample, we can simply use a matrix multiplication to check for projection similarity.  Matrix multiplication is advantageous because it is a quick and easy operation.

In [9]:
# Create the custom loss class to utilize during training

from torch import matmul
import torch.nn.functional as F

class custom_loss(nn.Module):
    def forward(self, image_out, caption_out):
        predicted_sim = matmul(image_out, torch.transpose(caption_out, 0, 1))
        image_sim = matmul(image_out, torch.transpose(image_out, 0, 1))
        caption_sim = matmul(caption_out, torch.transpose(caption_out, 0, 1))
        targets = F.softmax((image_sim+caption_sim)/2, dim=0)
        return (F.cross_entropy(predicted_sim, targets) + F.cross_entropy(torch.transpose(predicted_sim, 0, 1), targets)) / 2


# Fuctions for training

def train_projection_head_one_epoch(image_embedder, text_embedder, dataloader, loss_fn, optimizer):
    image_embedder.train(); text_embedder.train()
    for batch, (images, captions, _, _) in enumerate(dataloader):
        images = images.to(device)
        captions = captions.to(device)
        images_out = image_embedder(images)
        captions_out = text_embedder(captions)
        loss = loss_fn(images_out, captions_out)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

def validate_projection_head(image_embedder, text_embedder, dataloader, loss_fn):
    image_embedder.eval(); text_embedder.eval()
    with torch.no_grad():
        val_loss = 0
        for batch, (images, captions, _, _) in enumerate(dataloader):
            images = images.to(device)
            captions = captions.to(device)
            images_out = image_embedder(images)
            captions_out = text_embedder(captions)
            val_loss += loss_fn(images_out, captions_out)
        return val_loss / len(dataloader)

def train_projection_head(image_embedder, text_embedder, train_dataloader, val_dataloader, epochs, loss_fn, optimizer):
    for epoch in range(epochs):
        # print('Epoch', epoch)
        val_loss = validate_projection_head(image_embedder, text_embedder, val_dataloader, loss_fn)
        # print('Val_loss:', val_loss)
        train_projection_head_one_epoch(image_embedder, text_embedder, train_dataloader, loss_fn, optimizer)
        scheduler.step(val_loss)

    return image_embedder, text_embedder

# Phase 1 Training: Train the Projection Heads

In [10]:
# Train the projection heads

image_projection_head.to(device)
caption_projection_head.to(device)

loss_fn = custom_loss()
optimizer = torch.optim.AdamW(chain(image_projection_head.parameters(), caption_projection_head.parameters()), lr=0.2, weight_decay=0.001)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=2e-2, step_size_up=200, mode='triangular')

image_projection_head, caption_projection_head = train_projection_head(image_projection_head,
                                                                       caption_projection_head,
                                                                       embeddings_train_dataloader,
                                                                       embeddings_val_dataloader,
                                                                       100,
                                                                       loss_fn,
                                                                       optimizer)

os.makedirs('/content/drive/My Drive/Projects/text-to-image-search/models', exist_ok=True)
torch.save(caption_projection_head, '/content/drive/My Drive/Projects/text-to-image-search/models/caption_projection_head.pt')
torch.save(image_projection_head, '/content/drive/My Drive/Projects/text-to-image-search/models/image_projection_head.pt')
print('Saved!')

Saved!


## Load in the Projection Heads

In [11]:
image_projection_head = torch.load('/content/drive/My Drive/Projects/text-to-image-search/models/image_projection_head.pt')
caption_projection_head = torch.load('/content/drive/My Drive/Projects/text-to-image-search/models/caption_projection_head.pt')

# Evaluate the Performance
In order to evaluate the performance of Phase 1 training, we will generate projections for the entire validation dataset, and then check each projected caption to see which image projections are closest to it.  If the correct image is within the 5% of closest images, we will call it a success and count it towards our accuracy score.  Finally, we will print out some examples of image / caption pairs whose projections were closest to each other.

In [19]:
def generate_projections(image_model, caption_model, dataloader):
    """
    Function to generate (but not save) the projects for an entire dataset.
    The projections are collected in two large torch.Tensors, and projection metadata is collected into two lists.

    Inputs
    -------
    image_projection_head : torch.nn.Module
        The projection head to use for projecting the images.
    caption_projection_head : torch.nn.Module
        The projection to use for projecting the captions.
    embeddings_val_dataloader : torch.utils.data.DataLoader
        The dataloader to use for generating the base-embeddings data.

    Outputs
    -------
    image_proj_embeddings : torch.Tensor
        The Torch Tensor containing the image projections.
    caption_proj_embeddings : torch.Tensor
        The Torch Tensor containing the caption projections.
    captions_list : list
        The list of captions.
    sample_ids_list : list
        The list of sample ids.
    """

    image_model.to(device); caption_model.to(device)
    image_model.eval(); caption_model.eval()
    with torch.no_grad():
        image_proj_embeddings = torch.empty((len(dataloader.dataset), 256))
        caption_proj_embeddings = torch.empty((len(dataloader.dataset), 256))
        captions_list = []
        sample_ids_list = []
        for batch, (images_batch, captions_batch, captions, sample_ids) in enumerate(tqdm(dataloader, desc='Projecting Data')):
            images_batch = images_batch.to(device)
            if isinstance(captions_batch, dict):
                for item in captions_batch:
                  captions_batch[item] = captions_batch[item].to(device)
            else:
                captions_batch = captions_batch.to(device)
            image_proj_embeddings[batch*16:batch*16+len(images_batch), :] = image_model(images_batch)
            caption_proj_embeddings[batch*16:batch*16+len(images_batch), :] = caption_model(captions_batch)
            captions_list = captions_list + list(captions)
            sample_ids_list = sample_ids_list + list(sample_ids)
    return image_proj_embeddings, caption_proj_embeddings, captions_list, sample_ids_list


image_proj_embeddings, caption_proj_embeddings, captions_list, sample_ids_list = \
    generate_projections(image_projection_head, caption_projection_head, embeddings_val_dataloader)


def performance_eval(image_proj_embeddings, caption_proj_embeddings, num_to_print, img_dir):

    # Array to store the image that most closely aligns with each caption's projection
    selected_images = np.empty(len(caption_proj_embeddings)).astype('int')

    k = round(len(image_proj_embeddings)*.05)
    counter = 0
    for i, caption in enumerate(caption_proj_embeddings):
      predicted_sim = matmul(caption, torch.transpose(image_proj_embeddings, 0, 1))
      selected_images[i] = torch.argmax(predicted_sim)
      if i in torch.topk(predicted_sim, k=k)[1]:
        counter += 1

    print('Accuracy of', counter/len(caption_proj_embeddings))

    # Print some captions and corresponding chosen images
    for i in range(num_to_print):
        print('Caption:', captions_list[i])
        print('Correct Image:', sample_ids_list[i])
        img_path = os.path.join(img_dir, 'COCO_train2014_%012d.jpg'%sample_ids_list[i])
        image = read_image(img_path)
        plt.imshow(image.permute(1, 2, 0)); plt.show()
        print('Selected Image:', sample_ids_list[selected_images[i]])
        selected_image_id = sample_ids_list[selected_images[i]]
        img_path = os.path.join(img_dir, 'COCO_train2014_%012d.jpg'%selected_image_id)
        image = read_image(img_path)
        plt.imshow(image.permute(1, 2, 0)); plt.show()

performance_eval(image_proj_embeddings, caption_proj_embeddings, 5, img_dir)

Output hidden; open in https://colab.research.google.com to view.

#Phase 1 Results: 50% Success Rate
While 50% might seem low, we can see from the selected images that the model has definitely learned a lot, and does connect some conceptual elements from the text and images.  There is still room for improvement, so let's see if we can fine tune the model and do even better.

# Phase 2: Fine Tune Base Models Using Original Dataset
At this point, we have projection heads that work reasonably well.  Now we can fine tune the base models and attempt to get even better performance.  We will build out the full dual encoder and continue training using the original datasets.  To create the image side the dual encoder, we will simply replace the top layer of the Inception V3 image_embedder with the image projection head.  To create the text side of the dual encoder, we will need to create a custom Module to extract the pooler outputs from the BERT model and feed them to the caption projection head.

In [20]:
# Create the image side of the dual encoder.  It will still be called the image_embedder.

image_embedder.fc = image_projection_head
image_embedder.AuxLogits.fc = image_AuxLogits_projection_head


# Create the text side of the dual encoder.  We will call this the text_embedder

class CaptionEmbedderWithProjectionHead(nn.Module):
  def __init__(self, caption_embedder, caption_projection_head):
    super().__init__()
    self.base_model = caption_embedder
    self.projection_head = caption_projection_head

  def forward(self, x):
    base_model_output = self.base_model(**x)
    return self.projection_head(base_model_output.pooler_output)

text_embedder = CaptionEmbedderWithProjectionHead(caption_embedder=caption_embedder, caption_projection_head=caption_projection_head)



## Functions for Phase 2 Training

In [24]:
def train_full_model_one_epoch(image_embedder, text_embedder, dataloader, loss_fn, optimizer):
    image_embedder.train(); text_embedder.train()
    for batch, (images, captions, _, _) in enumerate(tqdm(dataloader, desc='Epoch Progress')):
        images = images.to(device)
        for item in captions:
          captions[item] = captions[item].to(device)
        images_out = image_embedder(images)
        captions_out = text_embedder(captions)
        loss1 = loss_fn(images_out.logits, captions_out)
        loss2 = loss_fn(images_out.aux_logits, captions_out)
        loss = loss1 + 0.4 * loss2
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

def validate_full_model(image_embedder, text_embedder, dataloader, loss_fn):
    image_embedder.eval(); text_embedder.eval()
    with torch.no_grad():
        val_loss = 0
        for batch, (images, captions, _, _) in enumerate(tqdm(dataloader, desc='Calculating Val Loss')):
            images = images.to(device)
            for item in captions:
              captions[item] = captions[item].to(device)
            images_out = image_embedder(images)
            captions_out = text_embedder(captions)
            val_loss += loss_fn(images_out, captions_out)
        return val_loss / len(dataloader)

def train_full_model(image_embedder, text_embedder, train_dataloader, val_dataloader, epochs, loss_fn, optimizer):
    for epoch in range(epochs):
        print('Epoch', epoch)
        val_loss = validate_full_model(image_embedder, text_embedder, val_dataloader, loss_fn)
        print('Val_loss:', val_loss)
        train_full_model_one_epoch(image_embedder, text_embedder, train_dataloader, loss_fn, optimizer)
        scheduler.step(val_loss)
    val_loss = validate_full_model(image_embedder, text_embedder, val_dataloader, loss_fn)
    print('Val_loss:', val_loss)

    return image_embedder, text_embedder

#Phase 2 Training: Fine Tune the Full Dual Encoder
Because the dual encoder is so large, training an epoch is very slow.  Therefore, we will just train for 1 epoch.

In [25]:
# Train the full models

image_embedder.to(device)
text_embedder.to(device)

loss_fn = custom_loss()
optimizer = torch.optim.AdamW(chain(image_embedder.parameters(), text_embedder.parameters()), lr=0.0001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

image_embedder, text_embedder = train_full_model(image_embedder, text_embedder, train_dataloader, val_dataloader, 1, loss_fn, optimizer)

Epoch 0


Calculating Val Loss: 100%|██████████| 2589/2589 [10:21<00:00,  4.17it/s]


Val_loss: tensor(1.5703, device='cuda:0')


  return F.conv2d(input, weight, bias, self.stride,
Epoch Progress: 100%|██████████| 10353/10353 [50:28<00:00,  3.42it/s]
Calculating Val Loss: 100%|██████████| 2589/2589 [10:08<00:00,  4.26it/s]


Val_loss: tensor(0.9090, device='cuda:0')


In [27]:
torch.save(image_embedder, '/content/drive/My Drive/Projects/text-to-image-search/models/image_embedder.pt')
torch.save(text_embedder, '/content/drive/My Drive/Projects/text-to-image-search/models/text_embedder.pt')
print('Saved!')

Saved!


# Evaluate Performance
We will be using the entire test dataset, which is 42K samples long, to evaluate performance.

In [34]:
image_proj_embeddings, caption_proj_embeddings, captions_list, sample_ids_list = \
    generate_projections(image_embedder, text_embedder, test_dataloader)

performance_eval(image_proj_embeddings, caption_proj_embeddings, 5, img_dir)





# selected_images = np.empty(len(caption_embeddings)).astype('int')

# counter = 0
# for i, caption in enumerate(caption_embeddings):
#   # prediction_dist = torch.cdist(torch.unsqueeze(caption, 0), image_embeddings)[0]
#   # selected_images[i] = torch.argmin(prediction_dist)
#   predicted_sim = matmul(caption, torch.transpose(image_embeddings, 0, 1))
#   selected_images[i] = torch.argmax(predicted_sim)
#   if i in torch.topk(predicted_sim, k=10)[1]:
#     counter += 1

# print('Accuracy of', counter/len(caption_embeddings))

# for i in range(5):
#     print(captions[i])
#     selected_image_id = sample_ids[selected_images[i]]
#     img_path = os.path.join(img_dir, 'COCO_train2014_%012d.jpg'%selected_image_id)
#     image = read_image(img_path)
#     plt.imshow(image.permute(1, 2, 0)); plt.show()

Output hidden; open in https://colab.research.google.com to view.

#Final Results: 78% Success Rate
Fine tuning pushed the success rate up from 50% to 78% in only 1 epoch of training.  Additional epochs may have driven that rate up to be higher.  Already, the dual encoder is doing a very good job of returning images that contain the correct objects.
