# Visualizing Better Multi-Modal Fused Embeddings of Images and Questions

In [1]:
import json
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from transformers import AutoImageProcessor, AutoModel
from transformers import BertModel, BertTokenizer

In [2]:
class VQADataset(Dataset):
    """Map-style dataset of (image tokens, question tokens) pairs read from a VQA questions file.

    Args:
        dataset: which split to load: ``'v2'`` (VQA v2 val2014 questions with COCO
            images) or ``'abs'`` (VQA abstract-scene train2015 questions).
        image_encoder: Hugging Face checkpoint name used to load the image processor.
        text_encoder: Hugging Face checkpoint name used to load the BERT tokenizer.

    Raises:
        ValueError: if ``dataset`` is not one of the supported split keys.
    """

    # Per-split file layout. ``image_root`` is the image filename prefix (it already
    # contains some of the leading zeros), ``id_width`` is how far the numeric image
    # id is zero-padded after that prefix, and ``ext`` is the image file extension.
    _SPLITS = {
        'v2': {
            'image_root': 'data/vqa-v2/val2014/val2014/COCO_val2014_000000',
            'questions': 'data/vqa-v2/v2_OpenEnded_mscoco_val2014_questions.json',
            'id_width': 6,
            'ext': '.jpg',
        },
        'abs': {
            'image_root': 'data/vqa-abstract/img_train/abstract_v002_train2015_0000000',
            'questions': 'data/vqa-abstract/questions_train/OpenEnded_abstract_v002_train2015_questions.json',
            'id_width': 5,
            # NOTE(review): abstract-scene images ship as .png (the original code
            # used .jpg here) — confirm against the local image folder.
            'ext': '.png',
        },
    }

    def __init__(self, dataset, image_encoder, text_encoder):
        if dataset not in self._SPLITS:
            # Previously an unknown key fell through and failed later with an
            # UnboundLocalError on ``data_path``.
            raise ValueError(
                f"Unknown dataset {dataset!r}; expected one of {sorted(self._SPLITS)}"
            )
        split = self._SPLITS[dataset]

        with open(split['questions'], 'r') as file:
            data = json.load(file)

        self.dataset = dataset
        self.image_root = split['image_root']
        # Each entry is read for its 'image_id' and 'question' keys (see get_data).
        self.data = data['questions']

        self.i_processor = AutoImageProcessor.from_pretrained(image_encoder)
        self.q_tokenizer = BertTokenizer.from_pretrained(text_encoder)

    def __len__(self):
        """Return the number of questions in the loaded split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(i_tokens, q_tokens)`` for question ``idx`` with the leading batch axis removed.

        NOTE: questions are tokenized without padding, so batching more than one
        item needs a padding collate function (the notebook uses ``batch_size=1``).
        """
        image, question = self.get_data(self.data[idx])

        i_tokens = self.i_processor(images=image, return_tensors='pt')
        q_tokens = self.q_tokenizer(question, return_tensors='pt')

        # One image / one question is processed at a time, so the returned tensors
        # carry a leading batch axis; drop it so DataLoader's collate adds the
        # real batch axis instead of nesting a second one.
        i_tokens['pixel_values'] = i_tokens['pixel_values'].squeeze(0)
        for key, value in q_tokens.items():
            q_tokens[key] = value.squeeze(0)

        return i_tokens, q_tokens

    def _image_path(self, image_id):
        """Build the on-disk image path for ``image_id`` in the current split."""
        split = self._SPLITS[self.dataset]
        padded_id = str(image_id).zfill(split['id_width'])
        return f"{self.image_root}{padded_id}{split['ext']}"

    def get_data(self, data):
        """Load the RGB image and question text for one question record.

        Args:
            data: a question record with ``'image_id'`` and ``'question'`` keys.

        Returns:
            ``(PIL.Image in RGB mode, question string)``.
        """
        # Image.open is lazy and keeps the file open; convert() materializes a new
        # image, so the source handle can be closed right away.
        with Image.open(self._image_path(data['image_id'])) as img:
            image = img.convert('RGB')

        return image, data['question']


In [3]:
class VLM(nn.Module):
    """Fuse frozen image and question encoder outputs into one embedding per sample.

    Pipeline (see ``forward``):
      1. Encode the image and the question with the pretrained encoders (no grad).
      2. Cross-attend image tokens to the question and question tokens to the image.
      3. Pool each attended sequence to one vector with an attention layer whose
         query is a zero vector.
      4. Fuse the two pooled vectors according to ``fusion_mode``.
      5. Pool the fused result once more into a ``(batch, 1, D)`` embedding.

    Args:
        image_encoder: Hugging Face checkpoint name loaded with ``AutoModel``.
        text_encoder: Hugging Face checkpoint name loaded with ``BertModel``.
        fusion_mode: one of ``'cat'``, ``'cat_v2'``, ``'mult'``, ``'add'``.
        embed_dim: width of every attention layer; assumed to equal both encoders'
            hidden size (768 for the base checkpoints) — TODO confirm for others.
        attn_heads: number of heads in every attention layer; must divide the
            layer width.

    Raises:
        ValueError: for an unknown ``fusion_mode`` or an odd ``embed_dim`` with
            ``'cat_v2'``.
    """

    FUSION_MODES = ('cat', 'cat_v2', 'mult', 'add')

    def __init__(self, image_encoder, text_encoder, fusion_mode, embed_dim=768, attn_heads=4):
        super().__init__()

        # Validate before loading the (large) pretrained encoders.
        if fusion_mode not in self.FUSION_MODES:
            raise ValueError(
                f"Unknown fusion_mode {fusion_mode!r}; expected one of {self.FUSION_MODES}"
            )
        if fusion_mode == 'cat_v2' and embed_dim % 2:
            raise ValueError("fusion_mode 'cat_v2' requires an even embed_dim")

        self.fusion_mode = fusion_mode
        self.embed_dim = embed_dim

        self.i_encoder = AutoModel.from_pretrained(image_encoder)
        self.q_encoder = BertModel.from_pretrained(text_encoder)

        self._build_attention(attn_heads)

    def _build_attention(self, attn_heads):
        """Create the cross-attention, per-modality pooling and fused pooling layers."""
        dim = self.embed_dim
        self.attn_q2i = nn.MultiheadAttention(dim, attn_heads, batch_first=True)
        self.attn_i2q = nn.MultiheadAttention(dim, attn_heads, batch_first=True)

        self.attn_i = nn.MultiheadAttention(dim, attn_heads, batch_first=True)
        self.attn_q = nn.MultiheadAttention(dim, attn_heads, batch_first=True)

        # 'cat' joins the two pooled vectors along channels, doubling the width.
        fused_dim = 2 * dim if self.fusion_mode == 'cat' else dim
        self.attn_e = nn.MultiheadAttention(fused_dim, attn_heads, batch_first=True)

    @staticmethod
    def _pool(attn, keys, key_padding_mask=None):
        """Attend one zero query over ``keys``; returns ``(batch, 1, attn.embed_dim)``."""
        # The query width follows the layer (fixes 'cat', whose layer is 2x wide),
        # and new_zeros keeps it on the same device/dtype as the keys.
        query = keys.new_zeros((keys.shape[0], 1, attn.embed_dim))
        pooled, _ = attn(query, keys, keys, key_padding_mask=key_padding_mask)
        return pooled

    def _fuse(self, i_pooled, q_pooled):
        """Combine the two ``(batch, 1, embed_dim)`` pooled vectors per ``fusion_mode``."""
        if self.fusion_mode == 'cat':
            return torch.cat((i_pooled, q_pooled), dim=-1)  # along channels
        if self.fusion_mode == 'cat_v2':
            half = self.embed_dim // 2
            # Split each vector into two half-width tokens, then place image and
            # question halves side by side: (batch, 2, embed_dim).
            i_halves = i_pooled.reshape(i_pooled.shape[0], -1, half)
            q_halves = q_pooled.reshape(q_pooled.shape[0], -1, half)
            return torch.cat((i_halves, q_halves), dim=-1)
        if self.fusion_mode == 'mult':
            return i_pooled * q_pooled
        if self.fusion_mode == 'add':
            return i_pooled + q_pooled
        raise ValueError(f"Unknown fusion_mode {self.fusion_mode!r}")

    def forward(self, i_tokens, q_tokens):
        """Compute the fused embedding for a batch.

        Args:
            i_tokens: mapping passed as keyword arguments to the image encoder.
            q_tokens: mapping passed as keyword arguments to the text encoder; if it
                contains ``'attention_mask'``, positions where it is 0 are excluded
                from every attention over question tokens.

        Returns:
            Tensor of shape ``(batch, 1, D)``, where ``D`` is ``2 * embed_dim`` for
            ``'cat'`` and ``embed_dim`` otherwise.
        """
        # The encoders are used as fixed feature extractors: no graph through them.
        with torch.no_grad():
            i_embeddings = self.i_encoder(**i_tokens).last_hidden_state
            q_embeddings = self.q_encoder(**q_tokens).last_hidden_state

        # True marks question positions (padding) that attention must ignore.
        q_mask = q_tokens.get('attention_mask')
        q_padding = None if q_mask is None else q_mask == 0

        i_attended, _ = self.attn_q2i(
            i_embeddings, q_embeddings, q_embeddings, key_padding_mask=q_padding
        )
        q_attended, _ = self.attn_i2q(q_embeddings, i_embeddings, i_embeddings)

        i_pooled = self._pool(self.attn_i, i_attended)
        q_pooled = self._pool(self.attn_q, q_attended, q_padding)

        embedding = self._fuse(i_pooled, q_pooled)
        return self._pool(self.attn_e, embedding)

In [4]:
# The same checkpoint names drive both the dataset preprocessing and the model encoders.
image_encoder = 'facebook/dinov2-base'
text_encoder = 'bert-base-uncased'

v2_dataset = VQADataset('v2', image_encoder, text_encoder)
v2_dataloader = DataLoader(v2_dataset, shuffle=False, batch_size=1)

model = VLM(image_encoder, text_encoder, fusion_mode='cat_v2')

# Smoke run: push only the first batch through the model, then stop.
for batch in v2_dataloader:
    i_tokens, q_tokens = batch
    embedding = model(i_tokens, q_tokens)
    break

torch.Size([1, 2, 768])
torch.Size([1, 1, 768])
