# Natural language image search with a Dual Encoder
[source](https://keras.io/examples/nlp/nl_image_search/#implement-the-dual-encoder)

## Background

**Contrastive Learning**

* 데이터 instance 간의 유사성을 학습하여 데이터에 대한 좋은 특징들을 추출하는 방법
    * Contrastive learning은 representation space에서 어떤 것들이 비슷하니 가까이 있어야하고, 어떤 것들이 다르니 멀리 있어야 한다는 것을 학습하는 과정을 통해 rich representation을 얻는 방법
    * 예를 들어 classification task라고 생각했을때, 같은 강아지 이미지들은 representation이 가까이에 위치하고, 강아지와 고양이 이미지 사이에는 거리가 있어야 boundary가 잘 구분 됨
    *  positve & negative example의 관계를 학습
* 지도/비지도 학습 모두 적용 가능 
    * 레이블 정보를 사용하지 않는 비지도 학습으로 데이터의 특징(representation)을 잘 학습하면, 지도 학습 모델에 준하는 성능을 냄

**CLIP(Contrastive Language-Image Pre-training)**
* https://openai.com/blog/clip/
* CLIP models will need to learn to recognize a wide variety of visual concepts in images and associate them with their names

**Keras Example Description**
* The idea is to **train a vision encoder and a text encoder jointly** to project the **representation of images and their captions** into the same embedding space, such that the caption embeddings are **located near** the embeddings of the images they describe.
* **positive example** : correct image-text pair
* **negative example** : incorrect image-text pair
* **logits** : caption과 image embedding의 dot_similarity
* **targets** : binary label [참고](https://kmhana.tistory.com/17)

## Setup

In [1]:
import os
import collections
import json
import numpy as np
import torch
import torch.utils.data as data
import torch.nn as nn
import os
import pickle
from torchvision import transforms
from PIL import Image
from pycocotools.coco import COCO
from tqdm import tqdm

import transformers
from transformers import BertTokenizer

## Prepare the data

We will use the MS-COCO dataset to train our dual encoder model. MS-COCO contains over 82,000 images, each of which has at least 5 different caption annotations. The dataset is usually used for image captioning tasks, but we can repurpose the image-caption pairs to train our dual encoder model for image search.

Download and extract the data

First, let's download the dataset, which consists of two compressed folders: one with images, and the other—with associated image captions. Note that the compressed images folder is 13GB in size.

In [2]:
root_dir = "datasets"
annotations_dir = os.path.join(root_dir, "annotations")
images_dir = os.path.join(root_dir, "train2014")
annotation_file = os.path.join(annotations_dir, "captions_train2014.json")

with open(annotation_file, "r") as f:
    annotations = json.load(f)["annotations"]

image_path_to_caption = collections.defaultdict(list)
for element in annotations:
    caption = f"{element['caption'].lower().rstrip('.')}"
    image_path = images_dir + "/COCO_train2014_" + "%012d.jpg" % (element["image_id"])
    image_path_to_caption[image_path].append(caption)

image_paths = list(image_path_to_caption.keys())
print(f"Number of images: {len(image_paths)}")

Number of images: 82783


In [3]:
image_path_to_caption

defaultdict(list,
            {'datasets/train2014/COCO_train2014_000000318556.jpg': ['a very clean and well decorated empty bathroom',
              'a blue and white bathroom with butterfly themed wall tiles',
              'a bathroom with a border of butterflies and blue paint on the walls above it',
              'an angled view of a beautifully decorated bathroom',
              'a clock that blends in with the wall hangs in a bathroom. '],
             'datasets/train2014/COCO_train2014_000000116100.jpg': ['a panoramic view of a kitchen and all of its appliances',
              'a panoramic photo of a kitchen and dining room',
              'a wide angle view of the kitchen work area',
              'multiple photos of a brown and white kitchen. ',
              'a kitchen that has a checkered patterned floor and white cabinets'],
             'datasets/train2014/COCO_train2014_000000379340.jpg': ['a graffiti-ed stop sign across the street from a red car ',
              'a vand

## Create Dataset and DataLoader

In [4]:
# Original Configuration

train_size = 30000
valid_size = 5000
batch_size = 256
# captions_per_image = 2

# For experimental implementation
# train_size = 15000
# valid_size = 300
# batch_size = 128

train_image_paths = image_paths[:train_size]
valid_image_paths = image_paths[-valid_size:]

#### TODO : Dataset, DataLoader 수정

* COCO dataset에는 image 하나 당, catption 5개 존재
* Keras 예제에서는 image 한개당 caption 2개를 각각 짝지어서 30000개의 이미지에 대해, 총 60000개의 image-caption pairs를 학습데이터로 사용하고 있음
* 현재는 간단하게 1000장의 이미지에 대해 1개의 caption을 짝지어 1000개의 example 사용

[코드 부연설명]
* `Image.ANTIALIAS` : 이미지를 resizing 할 때, 조금 더 부드럽게 처리 [사진링크](https://www.computerhope.com/jargon/a/antialias.htm)
> In digital signal processing, spatial anti-aliasing is a technique for minimizing the distortion artifacts known as aliasing when representing a high-resolution image at a lower resolution. Anti-aliasing is used in digital photography, computer graphics, digital audio, and many other applications.
* `pycocotools`의 `COCO` Class
    * json 파일에서 우리가 주로 이용하게 될 데이터는 Images와 Annotations
    * image 와 annotation에 대한 작업을 수월하게 해주는 API

In [5]:
# 참조 코드 : https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/data_loader.py

class CocoDataset(data.Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
    def __init__(self, filenames, json, transform=None):
        """Set the path for images, captions and vocabulary wrapper.
        
        Args:
            filenames: image filenames list.
            json: coco annotation file path.
            transform: image transformer.
        """
        self.filenames = filenames
        self.coco = COCO(json)
        self.ids = list(self.coco.anns.keys())
        self.transform = transform

    def __getitem__(self, index):
        """Returns one data pair (image and encoded caption)."""
        coco = self.coco
        ann_id = self.ids[index] # 주석 번호
        caption = coco.anns[ann_id]['caption'] # caption text
        img_id = coco.anns[ann_id]['image_id'] # image id (cf. image_id가 같아도 ann_id가 달라 caption 구분 가능)
        path = self.filenames[index]

        image = Image.open(os.path.join(path)).convert('RGB')
        # resize image
        image = image.resize([299, 299], Image.ANTIALIAS) 
        if self.transform is not None:
            image = self.transform(image)
        target = caption.lower().rstrip('.') # 마침표 제거
        return image, target

    def __len__(self):
        return len(self.filenames)


def get_loader(filenames, json, transform, batch_size, shuffle):
    """Returns torch.utils.data.DataLoader for custom coco dataset."""
    # COCO caption dataset
    coco = CocoDataset(filenames=filenames,
                       json=json,
                       transform=transform)
    
    # Data loader for COCO dataset
    # This will return (images, captions) for each iteration.
    # images: a tensor of shape (batch_size, 3, 299, 299).
    # captions: a tuple of shape (batch_size, 1).
    data_loader = torch.utils.data.DataLoader(dataset=coco, 
                                              batch_size=batch_size,
                                              shuffle=shuffle)
    return data_loader

In [6]:
# Image preprocessing, normalization for the pretrained resnet
transform = transforms.Compose([ 
    transforms.ToTensor(), 
    transforms.Normalize((0.485, 0.456, 0.406),    # mean
                         (0.229, 0.224, 0.225))])  # std

# Create train loader
coco_train_dataloader = get_loader(filenames=train_image_paths, json='datasets/annotations/captions_train2014.json',
                             transform=transform, batch_size=batch_size, shuffle=True)

# Create valid loader
coco_valid_dataloader = get_loader(filenames=valid_image_paths, json='datasets/annotations/captions_train2014.json',
                             transform=transform, batch_size=batch_size, shuffle=False)

loading annotations into memory...
Done (t=0.69s)
creating index...
index created!
loading annotations into memory...
Done (t=0.63s)
creating index...
index created!


In [7]:
# Example

coco = COCO('datasets/annotations/captions_train2014.json')
# pycocotools COCO 모듈을 이용해 annotation의 key값을 읽으면 id 번호가 출력
print(list(coco.anns.keys())[0:5])

# caption indexing
coco.anns[48]['caption']

loading annotations into memory...
Done (t=0.63s)
creating index...
index created!
[48, 67, 126, 148, 173]


'A very clean and well decorated empty bathroom'

In [8]:
# cf. values로는 다음과 같은 포멧으로 json 파일이 읽힘
print(list(coco.anns.values())[0:5])

[{'image_id': 318556, 'id': 48, 'caption': 'A very clean and well decorated empty bathroom'}, {'image_id': 116100, 'id': 67, 'caption': 'A panoramic view of a kitchen and all of its appliances.'}, {'image_id': 318556, 'id': 126, 'caption': 'A blue and white bathroom with butterfly themed wall tiles.'}, {'image_id': 116100, 'id': 148, 'caption': 'A panoramic photo of a kitchen and dining room'}, {'image_id': 379340, 'id': 173, 'caption': 'A graffiti-ed stop sign across the street from a red car '}]


In [9]:
# check data loader output format
iter(coco_train_dataloader).next()[0].shape

torch.Size([256, 3, 299, 299])

In [10]:
iter(coco_train_dataloader).next()[1]

('a cat is sitting outside on the sidewalk',
 'a photograph of a bathroom inside a home',
 'a white toilet sitting next to a bathroom sink',
 'a large kitchen with wooden cabinets, appliances, stove and refrigerator',
 'bathroom with a pedestal wash basin and rounded shower',
 'someone has a paperweight shaped like a fisherman riding a bike on the desk in front of their computer',
 'there are two toilets in this room which have no toilet lids or toilet seats. ',
 'a chef at a japanese restaurant lighting a tower of onion rings',
 'a large open kitchen with stainless steel refrigerator',
 'a wall rack holding pots, pans, and oven pot holders',
 'woman taking picture and surfer sitting on surf board on beach',
 'the interior of a kitchen with stainless steel appliances',
 'a bunch of men are sitting around a table',
 'a woman is walking near a red bus coming down the road',
 'a soup kitchen full of soup bowls and food ',
 'the kitchen has a big screen tv and a lot of counter space',
 'a 

___
[코드 부연설명]

* dataloader에서는 현재 위와 같이 caption text 그대로를 출력하도록 해서 TextEncoder단에서 BERT pre-trained 모델에 적용할 때, preprocessing과정인 tokenization도 하도록 함
* 이유
    1. huggingface transformers에서 pretrained 모델에 input으로 들어가는 포멧이 정해져 있음 [링크](https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel)
        
        * preprocess를 dataloader 단에서 하고, encoding된 text를 바로 pre-trained 모델에 넘겨주려 했으나 구현 상의 오류 발생(헤결 X)
 
        ```Python
        from transformers import BertTokenizer, BertModel
        import torch

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertModel.from_pretrained('bert-base-uncased')

        inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        outputs = model(**inputs)

        last_hidden_states = outputs.last_hidden_state
        ```
   
   
    2. Keras 예제에서도 각 Encoder에서 pretrained 모델에 구현된 preprocess 모듈을 사용하여 그때 그때 처리하는 것으로 구현 되어 있음
    
* `last_hidden_state` 의 dim=768 (bert-base 기준) [참고링크](https://pysnacks.com/bert-text-classification-with-fine-tuning/#how-to-fine-tune-bert-for-text-classification)
    * BERT output으로 CLS token embedding을 사용함

In [None]:
# !pip install ipywidgets
# !jupyter nbextension enable --py widgetsnbextension

In [11]:
# Example

from transformers import BertTokenizer, BertModel
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# 'tf': Return TensorFlow tf.constant objects.

# 'pt': Return PyTorch torch.Tensor objects.

# 'np': Return Numpy np.ndarray objects.
    
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state

print(inputs)
print(last_hidden_states.shape) # batch_size, seq_length, hidden_dim

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=570.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


{'input_ids': tensor([[  101,  7592,  1010,  2026,  3899,  2003, 10140,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
torch.Size([1, 8, 768])


## Implement the projection head
* 각 모델의 output 을 같은 공간 차원의 representation으로 표현하기 위해 projection을 해주는 역할

In [12]:
# Example
X = torch.randn((8,1000))
linear_layer = nn.Linear(1000,256)
output = linear_layer(X)
print(output.shape)

torch.Size([8, 256])


In [13]:
class ProjectEmbeddingsBlcok(nn.Module):
    def __init__(self, in_planes, num_projection_layers, projection_dims, dropout_rate=0.0):
        """Project each encoder's output into the same embedding space.
        
        Args:
            in_planes: input size. cf) the size of each model's FCN layer output. 
            num_projection_layers: number of layers to stack.
            projection_dims: size of embedding space. 
            dropout_rate: drop out rate.
        """
        super().__init__()
        self.num_projection_layers = num_projection_layers
        self.embedding = nn.Linear(in_planes, projection_dims)
        self.gelu = nn.GELU()
        self.dense = nn.Linear(projection_dims, projection_dims)
        self.dropout = nn.Dropout(dropout_rate)
        self.layernorm = nn.LayerNorm(projection_dims)
        
    def forward(self, x):
        # x.shape -> [batch_size, 1000] (Xception-Imagenet), [batch_size, 768] (BERT)
        # projected_embeddings.shape -> [8, 256], same!!
        projected_embeddings = self.embedding(x) 
        for _ in range(self.num_projection_layers):
            out = self.gelu(projected_embeddings)
            out = self.dense(out)
            out = self.dropout(out)
            # residual connection
            out = projected_embeddings + out 
            projected_embeddings = self.layernorm(out)
        return projected_embeddings

## Implement the vision encoder

* pre-trained된 Xception 모델을 그대로 사용 (`trainable=False`)
    * 참고. 케라스 코드
    ```Python
        xception = keras.applications.Xception(
                include_top=False, weights="imagenet", pooling="avg"
            )
        ```
* Xception 모델 : Inception 모델을 발전시킴 [참고](https://sotudy.tistory.com/14)
    * Imagenet으로 학습된 weight를 가짐. 1000개의 class
* `timm` module : Pytorch Image Models ([timm](https://fastai.github.io/timmdocs/#How-to-use), [github](https://github.com/rwightman/pytorch-image-models#models)) 
    * `timm` is a deep-learning library created by Ross Wightman and is a collection of SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations and also training/validating scripts with ability to reproduce ImageNet training results.

In [16]:
# Example
import timm

avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models), avail_pretrained_models

(452,
 ['adv_inception_v3',
  'cait_m36_384',
  'cait_m48_448',
  'cait_s24_224',
  'cait_s24_384',
  'cait_s36_384',
  'cait_xs24_384',
  'cait_xxs24_224',
  'cait_xxs24_384',
  'cait_xxs36_224',
  'cait_xxs36_384',
  'coat_lite_mini',
  'coat_lite_small',
  'coat_lite_tiny',
  'coat_mini',
  'coat_tiny',
  'convit_base',
  'convit_small',
  'convit_tiny',
  'cspdarknet53',
  'cspresnet50',
  'cspresnext50',
  'deit_base_distilled_patch16_224',
  'deit_base_distilled_patch16_384',
  'deit_base_patch16_224',
  'deit_base_patch16_384',
  'deit_small_distilled_patch16_224',
  'deit_small_patch16_224',
  'deit_tiny_distilled_patch16_224',
  'deit_tiny_patch16_224',
  'densenet121',
  'densenet161',
  'densenet169',
  'densenet201',
  'densenetblur121d',
  'dla34',
  'dla46_c',
  'dla46x_c',
  'dla60',
  'dla60_res2net',
  'dla60_res2next',
  'dla60x',
  'dla60x_c',
  'dla102',
  'dla102x',
  'dla102x2',
  'dla169',
  'dm_nfnet_f0',
  'dm_nfnet_f1',
  'dm_nfnet_f2',
  'dm_nfnet_f3',
  'dm_

In [17]:
all_densenet_models = timm.list_models('*densenet*')
all_densenet_models

['densenet121',
 'densenet121d',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenet264',
 'densenet264d_iabn',
 'densenetblur121d',
 'tv_densenet121']

In [18]:
class VisionEncoder(nn.Module):
    def __init__(self, in_planes, num_projection_layers, projection_dims, dropout_rate, trainable=False):
        """Image Encoder Blcok Using pre-trained Xception model.
        
        Args:
            in_planes: input size. cf) the size of each model's FCN layer output. 
            num_projection_layers: number of layers to stack.
            projection_dims: size of embedding space. 
            dropout_rate: drop out rate.
            trainable: True forfreezing pre-trained model. 
        """
        super().__init__()
        self.proj_embeddings = ProjectEmbeddingsBlcok(in_planes, num_projection_layers, projection_dims, dropout_rate)
        # Load the pre-trained Xception model to be used as the base encoder.
        self.xception = timm.create_model('xception', pretrained=True)
        # Set the trainability of the base encoder.
        if trainable:
            for param in self.xception.parameters():
                param.requires_grad = True
        else:
            for param in self.xception.parameters():
                param.requires_grad = False
        
    def forward(self, x):
        # Generate the embeddings for the images using the xception model.
        out = self.xception(x) # (8, 1000)
        out = self.proj_embeddings(out)
        return out

## Implement the text encoder
* transformers의 pre-trained bert-base 모델 사용 (`trainable=False`)
* pooled_output -> CLS token embedding

In [19]:
from transformers import BertModel

class TextEncoder(nn.Module):
    def __init__(self, in_planes, num_projection_layers, projection_dims, dropout_rate, trainable=False):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # Pretrained model on English
        self.proj_embeddings = ProjectEmbeddingsBlcok(in_planes, num_projection_layers, projection_dims, dropout_rate)
        # Load the BERT preprocessing module.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Set the trainability of the base encoder.
        if trainable:
            for param in self.bert.parameters():
                param.requires_grad = True
        else:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, x):
        # Generate embeddings for the preprocessed text using the BERT model.
        out = self.tokenizer(x, padding = "max_length", truncation=True, max_length=256, return_tensors='pt')
        out = self.bert(**out).last_hidden_state # extract last hidden state

        # pooled output -> CLS representation 부분만 추출 (8, 256, 768) -> (8, 768)
        # https://www.kaggle.com/kernelk/imdb-sa-bert-fine-tuning-gpu-92-acc
        out = out[:,:1,:].reshape(-1, out.shape[-1])
        out = self.proj_embeddings(out)
        return out     

## Train the dual encoder model

#### Notice
* Keras Example :  Note that training the model with 60,000 image-caption pairs, with a batch size of 256, takes around 12 minutes per epoch using a V100 GPU accelerator. If 2 GPUs are available, the epoch takes around 8 minutes.
* Current Implementation : 1,000 image-caption pairs, batch size of 8, CPU.
* Hyper-parameter : learning rate 빼고 Keras 설정과 동일하게 학습 진행

In [20]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('current device :', device)
# Build the models
vision_encoder = VisionEncoder(in_planes=1000, num_projection_layers=1, projection_dims=256, dropout_rate=0.1)
text_encoder = TextEncoder(in_planes=768, num_projection_layers=1, projection_dims=256, dropout_rate=0.1)

# Loss and optimizer
params = list(vision_encoder.parameters()) + list(text_encoder.parameters())
optimizer = torch.optim.AdamW(params, lr=0.0005, weight_decay=0.001)

# Train the models
num_epochs = 10  # In practice, train for at least 30 epochs
log_step = 10
save_step = 1000
model_path = 'models/'

total_step = len(coco_train_dataloader)

current device : cuda


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth" to /home/subinkim/.cache/torch/hub/checkpoints/xception-43020ad28.pth
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertF

#### loss calculation
* To calculate the loss, we compute the pairwise dot-product similarity between each `caption_i` and `images_j` in the batch as the predictions(**logits**). 
* The `target similarity`(**targets**) between `caption_i` and `image_j` is computed as the **average** of the (dot-product similarity between `caption_i` and `caption_j`) and (the dot-product similarity between `image_i` and `image_j`). 
    *  caption 간의 correlation, image 간의 correlation의 평균값
        * Q. 이것의 의미? 
        * A. 당연히 자기 자신과의 correlation이 1 로 가장 높음. --> 평균값을 취해 softmax를 하면, 당연히 본 순서의 index 대로 값이 가장 클 수밖에 없음 (positive example pair가 가장 유사도 높음!)
* Then, we use crossentropy to compute the loss between the **targets** and the **logits**.
* Q. similarity 구할 때, temperature(0.05)의 의미?
    * softmax값 극단적으로 1-0으로 만들기 위해? --> binary에 근사하여 targets 로 사용
    * exponential을 쓰기 때문에, 아주 큰 숫자 값에 대해서는, 값이 조금만 차이나도 y=exp(x)의 결과가 매우 크게 나타남
    * softmax에 들어가는 ` (captions_similarity + images_similarity) / (2 * temperature)` 출력 결과
    ```Python
    tensor([[4393.1533, 4074.8098, 1520.8954, 1167.3274, 4091.5378, 1153.2083,
         2225.4690, 4181.8931],
        [4074.8098, 4360.6641, 1295.5466,  985.3911, 4093.9922,  917.9226,
         1908.7693, 3965.3970],
        [1520.8954, 1295.5466, 4363.8037, 4145.8589, 1130.6233, 4079.5464,
         3581.9004, 1500.7577],
        [1167.3274,  985.3911, 4145.8589, 4377.1626,  836.4001, 4152.3042,
         3463.3833, 1195.7922],
        [4091.5378, 4093.9922, 1130.6233,  836.4001, 4348.4033,  841.7334,
         1886.1943, 4057.7800],
        [1153.2083,  917.9226, 4079.5464, 4152.3042,  841.7334, 4378.2651,
         3496.7959, 1198.8381],
        [2225.4690, 1908.7693, 3581.9004, 3463.3833, 1886.1943, 3496.7959,
         4361.3433, 2310.0374],
        [4181.8931, 3965.3970, 1500.7576, 1195.7922, 4057.7803, 1198.8381,
         2310.0374, 4394.0605]], grad_fn=<DivBackward0>)
         ```

In [21]:
# Example
m = nn.Softmax(dim=1)
input = torch.tensor([[9875.1265, 5795.5768, 6854.2154]])
output = m(input)
print(output)

tensor([[1., 0., 0.]])


___

In [22]:
def compute_loss(image_embeddings, caption_embeddings, temperature):

    ##########
    # logits
    ##########
    # logits[i][j] is the dot_similarity(caption_i, image_j).
    logits = (
        torch.matmul(caption_embeddings, image_embeddings.T)
        / temperature
    )

    ##########
    # targets
    ##########
    # images_similarity[i][j] is the dot_similarity(image_i, image_j).
    images_similarity = (
        torch.torch.matmul(image_embeddings, image_embeddings.T)
        / temperature 
    )
    # captions_similarity[i][j] is the dot_similarity(caption_i, caption_j).
    captions_similarity = (
        torch.torch.matmul(caption_embeddings, caption_embeddings.T)
        / temperature
    )
    # targets[i][j] = avarage dot_similarity(caption_i, caption_j) and dot_similarity(image_i, image_j).
    softmax = nn.Softmax(dim=1) # 행의 합 1
    targets = softmax(
        (captions_similarity + images_similarity) / (2 * temperature)
    )

    ##########
    # loss 
    ##########    
    # Compute the loss for the captions using crossentropy
    criterion = nn.CrossEntropyLoss()
    targets = torch.where(targets == 1)[0]
    targets = targets.reshape(-1)
    
    # 개념 적인 이해
    # caption x image 여서 각 행이 한 caption에 대한 image와의 유사도
    captions_loss = criterion(logits, targets)

    # Compute the loss for the images using crossentropy
    # image x caption 여서 각 행이 한 image에 대한 caption과의 유사도
    images_loss = criterion(logits.T, targets.T)
    
    # Return the mean of the loss over the batch.
    return (captions_loss + images_loss) / 2

In [23]:
for epoch in range(num_epochs):
    for i, (images, captions) in enumerate(coco_train_dataloader): # collate_fn

        # Set mini-batch dataset
        images = images
        captions = captions

        # Forward, backward and optimize
        vision_features = vision_encoder(images) # torch.Size([8, 256]) 
        text_features = text_encoder(captions) # torch.Size([8, 256])
                
        loss = compute_loss(image_embeddings=vision_features, caption_embeddings=text_features, temperature=0.05)
        vision_encoder.zero_grad()
        text_encoder.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Print log info
        if i % log_step == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch, num_epochs, i, total_step, loss.item())) 

        # Save the model checkpoints
#         if (i+1) % save_step == 0:
    torch.save(vision_encoder.state_dict(), os.path.join(
        model_path, 'vision_encoder-{}-{}.ckpt'.format(epoch+1, i+1)))
    torch.save(text_encoder.state_dict(), os.path.join(
        model_path, 'text_encoder-{}-{}.ckpt'.format(epoch+1, i+1)))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch [0/10], Step [0/118], Loss: 598.4899
Epoch [0/10], Step [10/118], Loss: 599.0793
Epoch [0/10], Step [20/118], Loss: 514.1037
Epoch [0/10], Step [30/118], Loss: 590.0313
Epoch [0/10], Step [40/118], Loss: 442.0206
Epoch [0/10], Step [50/118], Loss: 452.0010
Epoch [0/10], Step [60/118], Loss: 303.3466
Epoch [0/10], Step [70/118], Loss: 303.4099
Epoch [0/10], Step [80/118], Loss: 274.5798
Epoch [0/10], Step [90/118], Loss: 215.6099
Epoch [0/10], Step [100/118], Loss: 213.8963
Epoch [0/10], Step [110/118], Loss: 217.2624
Epoch [1/10], Step [0/118], Loss: 150.4297
Epoch [1/10], Step [10/118], Loss: 169.7791
Epoch [1/10], Step [20/118], Loss: 146.2765
Epoch [1/10], Step [30/118], Loss: 159.9283
Epoch [1/10], Step [40/118], Loss: 143.0205
Epoch [1/10], Step [50/118], Loss: 137.9877
Epoch [1/10], Step [60/118], Loss: 151.1434
Epoch [1/10], Step [70/118], Loss: 109.3960
Epoch [1/10], Step [80/118], Loss: 130.6513
Epoch [1/10], Step [90/118], Loss: 157.2462
Epoch [1/10], Step [100/118], Lo

In [24]:
for epoch in range(num_epochs):
    for i, (images, captions) in enumerate(coco_train_dataloader): # collate_fn

        # Set mini-batch dataset
        images = images
        captions = captions

        # Forward, backward and optimize
        vision_features = vision_encoder(images) # torch.Size([8, 256]) 
        text_features = text_encoder(captions) # torch.Size([8, 256])
                
        loss = compute_loss(image_embeddings=vision_features, caption_embeddings=text_features, temperature=0.05)
        vision_encoder.zero_grad()
        text_encoder.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Print log info
        if i % log_step == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch, num_epochs, i, total_step, loss.item())) 

        # Save the model checkpoints
#         if (i+1) % save_step == 0:
    torch.save(vision_encoder.state_dict(), os.path.join(
        model_path, 'vision_encoder-{}-{}.ckpt'.format(epoch+1+10, i+1)))
    torch.save(text_encoder.state_dict(), os.path.join(
        model_path, 'text_encoder-{}-{}.ckpt'.format(epoch+1+10, i+1)))

Epoch [0/10], Step [0/118], Loss: 19.6893
Epoch [0/10], Step [10/118], Loss: 22.4893
Epoch [0/10], Step [20/118], Loss: 18.0451
Epoch [0/10], Step [30/118], Loss: 16.3588
Epoch [0/10], Step [40/118], Loss: 17.3910
Epoch [0/10], Step [50/118], Loss: 19.6444
Epoch [0/10], Step [60/118], Loss: 18.8600
Epoch [0/10], Step [70/118], Loss: 14.3356
Epoch [0/10], Step [80/118], Loss: 17.9814
Epoch [0/10], Step [90/118], Loss: 18.7041
Epoch [0/10], Step [100/118], Loss: 20.8581
Epoch [0/10], Step [110/118], Loss: 16.6678
Epoch [1/10], Step [0/118], Loss: 14.7673
Epoch [1/10], Step [10/118], Loss: 16.2103
Epoch [1/10], Step [20/118], Loss: 15.5074
Epoch [1/10], Step [30/118], Loss: 19.3236
Epoch [1/10], Step [40/118], Loss: 13.9958
Epoch [1/10], Step [50/118], Loss: 16.2889
Epoch [1/10], Step [60/118], Loss: 15.6037
Epoch [1/10], Step [70/118], Loss: 19.6922
Epoch [1/10], Step [80/118], Loss: 18.2435
Epoch [1/10], Step [90/118], Loss: 20.1453
Epoch [1/10], Step [100/118], Loss: 15.3459
Epoch [1/1

___

# Inference

## Search for images using natural language queries

* inference 과정
1) 이미지를 모두 embedding 시킴
2) query text를 embedding 시킴
3) query와 모든 이미지들 간의 similarity 계산 후 top k 개 의 유사한 이미지 retrieval

In [None]:
print("Loading vision and text encoders...")
vision_encoder = keras.models.load_model("vision_encoder")
text_encoder = keras.models.load_model("text_encoder")
print("Models are loaded.")


def read_image(image_path):
    image_array = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)
    return tf.image.resize(image_array, (299, 299))


print(f"Generating embeddings for {len(image_paths)} images...")
image_embeddings = vision_encoder.predict(
    tf.data.Dataset.from_tensor_slices(image_paths).map(read_image).batch(batch_size),
    verbose=1,
)
print(f"Image embeddings shape: {image_embeddings.shape}.")

#### Retrieve relevent images

In [None]:
def find_matches(image_embeddings, queries, k=9, normalize=True):
    # Get the embedding for the query.
    query_embedding = text_encoder(tf.convert_to_tensor(queries))
    # Normalize the query and the image embeddings.
    if normalize:
        image_embeddings = tf.math.l2_normalize(image_embeddings, axis=1)
        query_embedding = tf.math.l2_normalize(query_embedding, axis=1)
    # Compute the dot product between the query and the image embeddings.
    dot_similarity = tf.matmul(query_embedding, image_embeddings, transpose_b=True)
    # Retrieve top k indices.
    results = tf.math.top_k(dot_similarity, k).indices.numpy()
    # Return matching image paths.
    return [[image_paths[idx] for idx in indices] for indices in results]

In [None]:
query = "a family standing next to the ocean on a sandy beach with a surf board"
matches = find_matches(image_embeddings, [query], normalize=True)[0]

plt.figure(figsize=(20, 20))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(mpimg.imread(matches[i]))
    plt.axis("off")