# Joint Text and Image Representations using COLA

I trained a custom model to produce image and text embeddings using COLA: https://arxiv.org/abs/2010.10915

Those embeddings can be used as inputs to the recommendation system.

This is the inference script where we probe the embeddings to better understand what they contain.

The first part visualizes the embeddings with t-SNE; the second part uses those embeddings to find the images most similar to a query image.


This notebook was made by @CVxTz. If you like it, you might also like some similar projects here: https://medium.com/@CVxTz

In [None]:

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from pathlib import Path

import os
# List every file shipped with the embeddings dataset (e.g. the model
# checkpoint and tokenizer referenced below) so it is easy to see what exists.
for root, _subdirs, files in os.walk('../input/image-text-embeddings'):
    for path in (os.path.join(root, name) for name in files):
        print(path)

# Locations of the competition data, the trained COLA checkpoint and its
# tokenizer.
BASE_PATH = Path("/kaggle/input/h-and-m-personalized-fashion-recommendations/")
MODEL_PATH = Path("/kaggle/input/image-text-embeddings/ssl_resnet18_1337.ckpt")
TOKENIZER_PATH = Path("/kaggle/input/image-text-embeddings/tokenizer.json")


#### Text Utils

In [None]:
from tokenizers import Tokenizer


TOKENIZER = Tokenizer.from_file(str(TOKENIZER_PATH))

# Vocabulary ids of the special tokens (token_to_id yields None for a token
# the vocabulary does not contain).
CLS_IDX, PAD_IDX, SEP_IDX = (
    TOKENIZER.token_to_id(special) for special in ("[CLS]", "[PAD]", "[SEP]")
)


def tokenize(text: str):
    """Encode *text* with the module-level TOKENIZER and return its token ids."""
    return TOKENIZER.encode(text).ids


def pad_list(
    list_integers, context_size: int = 90, pad_val: int = PAD_IDX, mode="right"
):
    """
    Truncate or pad a sequence of token ids to exactly ``context_size`` items.

    :param list_integers: sliceable sequence of token ids (list, tuple, ...).
    :param context_size: target length; longer inputs keep their first
        ``context_size`` items.
    :param pad_val: id used for padding; defaults to the tokenizer's ``[PAD]`` id.
    :param mode: ``"left"`` prepends the padding; any other value appends it.
    :return: a new list of length ``context_size`` (for positive sizes).
    """
    # list() so tuples / arrays are concatenated below instead of failing
    # (tuple + list) or being broadcast-added (ndarray + list).
    truncated = list(list_integers[:context_size])
    # Empty when the input already fills the context.
    padding = [pad_val] * (context_size - len(truncated))

    if mode == "left":
        return padding + truncated
    return truncated + padding


#### Image utils

In [None]:
import random
from pathlib import Path

import cv2
import numpy as np

# Turn off OpenCV's internal threading (0 = run its functions sequentially)
# and its OpenCL backend.
# NOTE(review): presumably to avoid thread oversubscription / conflicts with
# PyTorch — confirm.
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

import albumentations as A

SIZE = 128  # side length, in pixels, of the square model input
SCALE = 255.0  # maximum pixel value, passed to Normalize as max_pixel_value


# Aspect-preserving resize: scale so the longest side equals SIZE, then pad
# the shorter side up to SIZE, so every output is SIZE x SIZE.
RESIZE = A.Compose(
    transforms=[
        A.LongestMaxSize(p=1.0, max_size=SIZE),
        A.PadIfNeeded(p=1.0, min_width=SIZE, min_height=SIZE),
    ]
)
# Per-channel standardisation with the ImageNet RGB mean/std, applied after
# scaling pixels by 1 / SCALE.
NORMALIZE = A.Normalize(
    max_pixel_value=SCALE,
    std=[0.229, 0.224, 0.225],
    mean=[0.485, 0.456, 0.406],
)

def read_image(image_path: Path) -> np.ndarray:
    """Load an image file as an RGB array.

    :param image_path: path of the image to read.
    :return: the decoded image as an (H, W, 3) uint8 array in RGB order.
    :raises OSError: if OpenCV cannot read the file (missing, unreadable or an
        unsupported format); ``cv2.imread`` signals this by returning None.
    """
    bgr_image = cv2.imread(str(image_path))
    if bgr_image is None:
        raise OSError(f"Could not read image: {image_path}")

    # cvtColor yields a contiguous array, unlike the negative-stride view from
    # bgr_image[:, :, ::-1], which torch.from_numpy rejects.
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)


def resize(image: np.ndarray) -> np.ndarray:
    """Scale *image* so its longest side is SIZE, then pad it to SIZE x SIZE."""
    return RESIZE(image=image)["image"]


def normalize(image: np.ndarray) -> np.ndarray:
    """Standardise *image* channel-wise with the NORMALIZE transform."""
    return NORMALIZE(image=image)["image"]


def preprocess(image: np.ndarray) -> np.ndarray:
    """Inference preprocessing: resize/pad to SIZE x SIZE, then normalize."""
    resized = resize(image)
    return normalize(resized)


#### Model

In [None]:
import math

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torchvision import models
from transformers import get_cosine_schedule_with_warmup


class PositionalEncoding(torch.nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence tensor.

    Adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html.
    Unlike the tutorial, inputs are batch-first ``[batch_size, seq_len, dim]``
    and the encoding is divided by ``sqrt(d_model)`` before being added.

    :param d_model: embedding dimension (odd values are supported).
    :param dropout: dropout probability applied after adding the encoding.
    :param max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )  # [ceil(d_model / 2)]
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even ones, so
        # trim div_term to match instead of failing on a shape mismatch.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # The buffer name "pe" is part of the state_dict; keep it stable so
        # saved checkpoints still load.
        self.register_buffer("pe", pe)

        self.d_model = d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim], with
                seq_len <= max_len.

        Returns:
            Tensor of the same shape with the scaled encodings added, after
            dropout.
        """
        x = x + self.pe[:, : x.size(1)] / math.sqrt(self.d_model)

        return self.dropout(x)


class Cola(LightningModule):
    """Image + text dual encoder used to produce the COLA embeddings.

    Images go through a ResNet-18 whose ``fc`` layer is replaced to output
    ``d_model`` features. Text goes through a token embedding, sinusoidal
    positional encoding and a 4-layer Transformer encoder.

    The submodule attribute names (``model``, ``item_embeddings``,
    ``pos_encoder``, ``encoder``, ``layer_norm``, ``linear``, ``do``) determine
    the ``state_dict`` keys, so they must match the checkpoint being loaded.

    :param lr: stored on the instance but not used by any method shown here
        (presumably consumed by training code — confirm).
    :param use_pretrained: forwarded to ``models.resnet18(pretrained=...)``.
    :param dropout: dropout probability for the positional encoding, the
        Transformer layers and the image-side ``Dropout``.
    :param d_model: embedding size shared by both encoders.
    :param n_vocab: size of the token embedding table; token ids must be
        smaller than this.
    :param smoothing: stored but unused here (presumably a label-smoothing
        factor for training — confirm).
    """

    def __init__(
        self,
        lr=0.001,
        use_pretrained=False,
        dropout=0.2,
        d_model=128,
        n_vocab=30_000,
        smoothing=0.1,
    ):
        super().__init__()
        # Float probability; the Dropout module itself is self.do below.
        self.dropout = dropout

        self.lr = lr
        self.d_model = d_model
        self.n_vocab = n_vocab
        self.smoothing = smoothing

        # Vision
        # ResNet-18 backbone with its classification head swapped for a
        # projection to d_model features.
        self.model = models.resnet18(pretrained=use_pretrained)
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, self.d_model)

        # Text
        # Token embeddings -> positional encoding -> 4-layer, 4-head
        # Transformer encoder on batch-first [batch, seq, d_model] tensors.
        self.item_embeddings = torch.nn.Embedding(self.n_vocab, self.d_model)
        self.pos_encoder = PositionalEncoding(
            d_model=self.d_model, dropout=self.dropout
        )
        encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=self.d_model, nhead=4, dropout=self.dropout, batch_first=True
        )
        self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=4)

        # Normalizations
        # layer_norm and do are used by encode_image; linear is the bias-free
        # projection applied to image embeddings in forward().
        self.layer_norm = torch.nn.LayerNorm(normalized_shape=self.d_model)
        self.linear = torch.nn.Linear(self.d_model, self.d_model, bias=False)
        self.do = torch.nn.Dropout(p=self.dropout)

        # Lightning: record the __init__ arguments in self.hparams.
        self.save_hyperparameters()

    def encode_image(self, x):
        """Embed a batch of images.

        :param x: channels-last image batch, permuted with (0, 3, 1, 2) below;
            presumably a float tensor [batch, height, width, 3] such as
            ``preprocess(...)`` output with a batch dim added — confirm.
        :return: [batch, d_model] tensor: backbone -> dropout -> LayerNorm ->
            tanh, so values lie in (-1, 1).
        """
        # NHWC -> NCHW, the layout the torchvision backbone expects.
        x = x.permute(0, 3, 1, 2)
        x = self.do(self.model(x))
        x = torch.tanh(self.layer_norm(x))

        return x

    def encode_text(self, x):
        """Embed a batch of token-id sequences.

        :param x: [batch, seq_len] tensor of token ids (the encoder is built
            with batch_first=True).
        :return: [batch, d_model] hidden state at position 0. Position 0 is
            presumably the [CLS] token if the tokenizer adds one — confirm.
            No padding mask is passed, so PAD positions take part in attention.
        """
        x = self.item_embeddings(x)
        x = self.pos_encoder(x)
        x = self.encoder(x)

        return x[:, 0, :]

    def forward(self, x):
        """Encode an ``(image, text)`` pair of batches.

        :param x: tuple ``(image, text)`` accepted by encode_image / encode_text.
        :return: ``(encoded_image_w, encoded_text)``. The image embedding is
            additionally passed through the bias-free ``self.linear``.
            NOTE(review): presumably a bilinear similarity term for the COLA
            objective — confirm against the training code.
        """
        image, text = x

        encoded_image = self.encode_image(image)

        encoded_image_w = self.linear(encoded_image)

        encoded_text = self.encode_text(text)

        return encoded_image_w, encoded_text


#### Predict embeddings

In [None]:
import matplotlib.pyplot as plt

import torch
from tqdm import tqdm

from sklearn.manifold import TSNE

In [None]:
df = pd.read_csv(
    BASE_PATH / "articles.csv",
    nrows=None,
    dtype={
        "article_id": str,
    },
)

# Sample before building the derived columns so the per-row work below only
# touches the rows we keep. The fixed seed makes the sample, and therefore the
# t-SNE plots and search results below, reproducible across runs. Resetting
# the index keeps row labels equal to positions in the embedding arrays.
df = df.sample(n=min(5000, len(df)), random_state=0).reset_index(drop=True)

# Text fed to the text encoder: these columns joined with single spaces.
# astype(str) matches calling str() on each value, so missing values become
# the literal "nan" (presumably matching how the training text was built —
# confirm before changing).
text_columns = [
    "prod_name",
    "product_type_name",
    "product_group_name",
    "graphical_appearance_name",
    "colour_group_name",
    "perceived_colour_value_name",
    "index_name",
    "section_name",
    "detail_desc",
]
df["text"] = df[text_columns].astype(str).apply(" ".join, axis=1)

# Images live at images/<first three characters of article_id>/<article_id>.jpg.
df["image_path"] = df.article_id.apply(
    lambda x: BASE_PATH / "images" / x[:3] / f"{x}.jpg"
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture hyper-parameters come from Cola's defaults; they must match
# the checkpoint's, or load_state_dict fails on shape mismatches.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(MODEL_PATH, map_location=device)
model = Cola(lr=1e-4)
model.load_state_dict(checkpoint["state_dict"])

# Move to the target device and switch dropout off for inference.
model = model.to(device).eval()

text_embeddings = []
image_embeddings = []
n_missing_images = 0

# No gradients are needed anywhere in this loop.
with torch.no_grad():
    for image_path, text in tqdm(
        zip(df.image_path.values, df.text.values),
        total=len(df),
    ):
        if image_path.is_file():
            image = read_image(image_path)
        else:
            # Blank placeholder with the same layout and dtype as
            # read_image's output ((H, W, 3) uint8), sized to the model input.
            image = np.zeros((SIZE, SIZE, 3), dtype=np.uint8)
            n_missing_images += 1

        # preprocess keeps the channels-last layout; add a batch dimension.
        image_t = torch.from_numpy(preprocess(image)).unsqueeze(0).to(device)
        text_t = torch.tensor(
            pad_list(tokenize(text)), dtype=torch.long, device=device
        ).unsqueeze(0)

        text_embed = model.encode_text(text_t)
        image_embed = model.encode_image(image_t)

        text_embeddings.append(text_embed.squeeze().cpu().tolist())
        image_embeddings.append(image_embed.squeeze().cpu().tolist())

if n_missing_images:
    print(
        f"{n_missing_images} of {len(df)} articles have no image file; "
        "a blank image was embedded for them."
    )

text_embeddings = np.array(text_embeddings)
image_embeddings = np.array(image_embeddings)



#### TSNE Representations

In [None]:

import inspect

# scikit-learn 1.5 renamed TSNE's n_iter to max_iter and deprecated n_iter
# for removal, so pass whichever name the installed version accepts.
tsne_iter_param = (
    "max_iter" if "max_iter" in inspect.signature(TSNE).parameters else "n_iter"
)
tsne = TSNE(
    n_components=2,
    init="random",
    random_state=0,
    learning_rate="auto",
    **{tsne_iter_param: 300},
)

Y = tsne.fit_transform(image_embeddings)

# One colour per index_name; rows of Y follow the row order of df.
labels = df.index_name.to_numpy()
plt.figure(figsize=(12, 12))

for index_name in df.index_name.unique():
    mask = labels == index_name
    plt.scatter(Y[mask, 0], Y[mask, 1], label=index_name, s=3)

plt.title("Cola Image embeddings by index_name")
plt.legend()
plt.show()

In [None]:

import inspect

# scikit-learn 1.5 renamed TSNE's n_iter to max_iter and deprecated n_iter
# for removal, so pass whichever name the installed version accepts.
tsne_iter_param = (
    "max_iter" if "max_iter" in inspect.signature(TSNE).parameters else "n_iter"
)
tsne = TSNE(
    n_components=2,
    init="random",
    random_state=0,
    learning_rate="auto",
    **{tsne_iter_param: 300},
)

Y = tsne.fit_transform(text_embeddings)

# One colour per index_name; rows of Y follow the row order of df.
labels = df.index_name.to_numpy()
plt.figure(figsize=(12, 12))

for index_name in df.index_name.unique():
    mask = labels == index_name
    plt.scatter(Y[mask, 0], Y[mask, 1], label=index_name, s=3)

plt.title("Cola Text embeddings by index_name")
plt.legend()
plt.show()

#### Search

In [None]:
index = 10

# Rank by cosine similarity. L2-normalising the rows stops embeddings with a
# larger norm from dominating the ranking. It also puts the query itself
# (similarity 1) at the top, up to ties with near-duplicates, so it appears
# in the grid below.
norms = np.linalg.norm(image_embeddings, axis=1, keepdims=True)
unit_embeddings = image_embeddings / np.maximum(norms, 1e-12)
most_similar = np.argsort(-(unit_embeddings @ unit_embeddings[index]))[:9].tolist()

_, axs = plt.subplots(3, 3, figsize=(12, 12))
for i, ax in zip(most_similar, axs.flatten()):
    path = df.image_path.values[i]
    if path.is_file():
        ax.imshow(read_image(path))
    else:
        # This article was embedded from a blank placeholder image.
        ax.text(0.5, 0.5, "No image file", ha="center", va="center")
    ax.axis('off')
    ax.title.set_text("Query image" if i == index else "Result Image")
plt.show()

In [None]:
index = 100

# Rank by cosine similarity. L2-normalising the rows stops embeddings with a
# larger norm from dominating the ranking. It also puts the query itself
# (similarity 1) at the top, up to ties with near-duplicates, so it appears
# in the grid below.
norms = np.linalg.norm(image_embeddings, axis=1, keepdims=True)
unit_embeddings = image_embeddings / np.maximum(norms, 1e-12)
most_similar = np.argsort(-(unit_embeddings @ unit_embeddings[index]))[:9].tolist()

_, axs = plt.subplots(3, 3, figsize=(12, 12))
for i, ax in zip(most_similar, axs.flatten()):
    path = df.image_path.values[i]
    if path.is_file():
        ax.imshow(read_image(path))
    else:
        # This article was embedded from a blank placeholder image.
        ax.text(0.5, 0.5, "No image file", ha="center", va="center")
    ax.axis('off')
    ax.title.set_text("Query image" if i == index else "Result Image")
plt.show()

That's it, folks! We can see that these custom embeddings are well suited to the domain of the images and texts (fashion). I'll try them in a recommender system and share the results with you soon.