In [None]:
import os
import cv2
import gc
import numpy as np
import pandas as pd
import itertools
from tqdm import tqdm
#from tqdm.autonotebook import tqdm # if first is not working
import albumentations as A
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torch
from torch import nn
import torch.nn.functional as F
import timm
import pprint

# for hubert-base-cc
from transformers import AutoModel, AutoTokenizer

# for sentiment-hts2-hubert-hungarian
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# for xlm-roberta-base, distilbert-base-multilingual-cased
from transformers import AutoTokenizer, AutoModelForMaskedLM

import pprint

In [None]:
# Shared pretty-printer: wide lines, 4-space indent, dict keys kept in insertion order.
# (depth, stream and compact are left at their library defaults.)
pp = pprint.PrettyPrinter(indent=4, width=200, sort_dicts=False)

In [None]:
# Locations of the image folder and of the caption files.
image_path = "images_30k"
captions_path = "captions"
file="captions_30k_hu"
csv_name = f"{file}.csv"

# The raw captions file is '|'-delimited; its three columns are renamed below.
df = pd.read_csv(f"{captions_path}/{file}.txt", delimiter="|")
df.columns = ['image', 'caption_number', 'caption']

# Every consecutive group of 5 rows gets the same id (assumes 5 captions per image,
# stored contiguously). Integer division also copes with a trailing partial group,
# which the previous list-based construction rejected with a length mismatch.
ids = (np.arange(len(df)) // 5).tolist()
df['id'] = ids

# make_train_valid_test_dfs() reads f"{captions_path}/{csv_name}" and needs the 'id'
# column, so make sure that file exists on a fresh run. An existing file is left alone.
_csv_path = f"{captions_path}/{csv_name}"
if not os.path.exists(_csv_path):
    df.to_csv(_csv_path, index=False)

df

In [None]:
class CFG:
    """Static configuration namespace; values are read as class attributes (CFG.<name>)."""

    # ---- run mode and data locations -------------------------------------------
    debug = False
    image_path = image_path
    captions_path = captions_path

    # ---- hardware / data loading -----------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    batch_size = 32
    num_workers = 0 # it was 4 on colab, should be 0 on our pc-s

    # ---- optimisation settings (not referenced by the code visible in this notebook)
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1 # for server running it could be 10, 30 or even 50
    patience_val = 10
    factor = 0.8
    epochs = 1000 # could be 1000

    # ---- image encoder -----------------------------------------------------------
    model_name = 'tf_efficientnetv2_b1'
    image_embedding = 1280  # for efficientnetv2b1
    size = 240  # input image side length, for efficientnetv2b1

    # ---- text encoder ------------------------------------------------------------
    # The position in this list decides which transformers Auto* class TextEncoder loads.
    text_model_name = ["SZTAKI-HLT/hubert-base-cc", "xlm-roberta-base", "NYTK/sentiment-hts2-hubert-hungarian", "distilbert-base-multilingual-cased"]
    text_encoder_model = text_model_name[1]
    text_tokenizer = text_encoder_model
    text_embedding = 768
    max_length = 200

    # ---- shared encoder switches ---------------------------------------------------
    pretrained = True   # for both image encoder and text encoder
    trainable = False   # for both image encoder and text encoder
    temperature = 1.0

    # ---- projection head; used for both image and text encoders --------------------
    num_projection_layers = 1
    projection_dim = 256 
    dropout = 0.1

In [None]:
class CLIPDataset(torch.utils.data.Dataset):
    """Dataset of (image, tokenized caption) pairs.

    Each item is a dict holding the tokenizer outputs for one caption (as tensors)
    plus an 'image' entry with the transformed image as a float tensor whose
    channel axis has been moved to the front.
    """

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length; so, if there are
        multiple captions for each image, the image_filenames must have repetitive
        file names.

        tokenizer is called once on all captions (padding, truncation to
        CFG.max_length); transforms is called as transforms(image=...)['image'].

        Raises:
            ValueError: if image_filenames and captions differ in length.
        """
        # Materialise once so that one-shot iterables are not consumed twice.
        self.captions = list(captions)
        if len(image_filenames) != len(self.captions):
            raise ValueError(
                f"image_filenames ({len(image_filenames)}) and captions "
                f"({len(self.captions)}) must have the same length"
            )

        self.image_filenames = image_filenames
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return the tokenized caption and the preprocessed image at position idx.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        path = f"{CFG.image_path}/{self.image_filenames[idx]}"
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        # Move the last (channel) axis to the front and convert to float.
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        #item['caption'] = self.captions[idx]   # it is not needed, just to help visualize

        return item


    def __len__(self):
        """Number of caption rows (one sample per caption)."""
        return len(self.captions)



def get_transforms(mode="train"):
    """Build the albumentations preprocessing pipeline.

    `mode` is accepted for API symmetry, but the train and non-train pipelines
    are the same: resize to CFG.size x CFG.size, then normalize.
    """
    steps = [
        A.Resize(CFG.size, CFG.size, always_apply=True),
        A.Normalize(max_pixel_value=255.0, always_apply=True),
    ]
    return A.Compose(steps)

In [None]:
class ImageEncoder(nn.Module):
    """Wraps a timm backbone that maps an image batch to one feature vector per image."""

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        # num_classes=0 with global_pool="avg" is passed so the backbone is created
        # without a classification layer (see the timm create_model documentation).
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        # Freeze or unfreeze every backbone parameter in one go.
        for param in self.model.parameters():
            param.requires_grad = trainable

    def forward(self, x):
        """Run the backbone on x and return its output unchanged."""
        features = self.model(x)
        return features

In [None]:
class TextEncoder(nn.Module):
    """Wraps a Hugging Face text model and returns the first-token hidden state.

    The transformers class used to load the weights depends on the position of
    `model_name` in CFG.text_model_name. `pretrained` is accepted but not used:
    weights are always loaded with from_pretrained.
    """

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()

        print(model_name)
        # Raises ValueError if model_name is not one of the configured models.
        model_idx = CFG.text_model_name.index(model_name)
        if model_idx == 0:
            self.model = AutoModel.from_pretrained(model_name) # hubert-base-cc
        elif model_idx in (1, 3):
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)  # xlm-roberta-base or distilbert-base-multilingual-cased
        elif model_idx == 2:
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name) # sentiment-hts2-hubert-hungarian
        else:
            raise ValueError(f"No loader configured for text model: {model_name}")

        # Remember how THIS instance exposes its hidden states, so forward() does not
        # depend on the global CFG.text_encoder_model (which may name a different
        # model than the one passed in here).
        self._has_last_hidden_state = (model_idx == 0)

        for p in self.model.parameters():
            p.requires_grad = trainable

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Return the final-layer hidden state at position self.target_token_idx."""
        if self._has_last_hidden_state:
            # The bare AutoModel output exposes last_hidden_state directly.
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
            last_hidden_state = output.last_hidden_state
        else:
            # The task-head models are asked for all hidden states; take the last one.
            output = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
            last_hidden_state = output.hidden_states[-1]
        return last_hidden_state[:, self.target_token_idx, :]

In [None]:
class ProjectionHead(nn.Module):
    """Maps encoder features of size embedding_dim to projection_dim.

    forward: linear projection -> GELU -> linear -> dropout, added back onto the
    projection (residual connection), followed by layer normalization.
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        # Attribute names are part of the checkpoint's state_dict keys; keep them stable.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)
    
    def forward(self, x):
        """Project x and refine it through the residual branch."""
        residual = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(residual)))
        return self.layer_norm(refined + residual)

In [None]:
class CLIPModel(nn.Module):
    """Image encoder + text encoder, each followed by its own projection head.

    forward(batch) returns the scalar contrastive loss for the batch, where batch
    is a dict with "image", "input_ids" and "attention_mask" entries.
    """

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        # Encoder features for both modalities
        img_feats = self.image_encoder(batch["image"])
        txt_feats = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )

        # Project both into the shared embedding space (same dimension)
        img_emb = self.image_projection(img_feats)
        txt_emb = self.text_projection(txt_feats)

        # Text-vs-image similarity logits, scaled by the temperature
        logits = (txt_emb @ img_emb.T) / self.temperature

        # Soft targets: softmax over the average of the within-modality similarities
        img_sim = img_emb @ img_emb.T
        txt_sim = txt_emb @ txt_emb.T
        targets = F.softmax((img_sim + txt_sim) / 2 * self.temperature, dim=-1)

        # Symmetric loss: over the rows (texts) and over the columns (images)
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        per_sample = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
        return per_sample.mean()


def cross_entropy(preds, targets, reduction='none'):
    """Cross entropy between logits `preds` and (soft) target distributions `targets`.

    Log-softmax is taken over the last dimension of `preds`; the weighted negative
    log-probabilities are summed over dimension 1.

    Args:
        preds: logits.
        targets: target weights, same shape as `preds`.
        reduction: "none" returns one loss per row, "mean" returns their mean.

    Raises:
        ValueError: for any other `reduction` (previously this silently returned None).
    """
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    raise ValueError(f"Unsupported reduction: {reduction!r} (expected 'none' or 'mean')")

In [None]:
def make_train_valid_test_dfs():
    """Split the captions CSV into train / valid / test frames by image id.

    Ids are shuffled with a fixed seed (42) and cut 60% / 20% / 20%, so all rows
    sharing an id land in the same split. In debug mode only ids 0..99 are used.

    Returns:
        (train_dataframe, valid_dataframe, test_dataframe), each with a fresh index.
    """
    captions_df = pd.read_csv(f"{CFG.captions_path}/{csv_name}")

    if CFG.debug:
        n_ids = 100
    else:
        n_ids = captions_df["id"].max() + 1

    shuffled_ids = np.arange(0, n_ids)
    np.random.seed(42)
    np.random.shuffle(shuffled_ids)

    # Cut points for the 60 / 20 / 20 split
    train_end = int(len(shuffled_ids) * 0.6)
    valid_end = int(len(shuffled_ids) * 0.8)
    id_groups = (
        shuffled_ids[:train_end],
        shuffled_ids[train_end:valid_end],
        shuffled_ids[valid_end:],
    )

    def _rows_for(id_subset):
        return captions_df[captions_df["id"].isin(id_subset)].reset_index(drop=True)

    return tuple(_rows_for(group) for group in id_groups)

In [None]:
def build_loaders(dataframe, tokenizer, mode):
    """Wrap the 'image' / 'caption' columns of `dataframe` in a DataLoader.

    Batches are shuffled only when mode == "train"; batch size and worker count
    come from CFG.
    """
    dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=get_transforms(mode=mode),
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=(mode == "train"),
    )

In [None]:
def get_model_embeddings(test_df, model_path):
    """Load a trained CLIPModel and embed every row of `test_df`.

    The model and the tokenizer are built from CFG, so CFG must describe the same
    architecture that produced the checkpoint at `model_path`.

    Args:
        test_df: frame with 'image' and 'caption' columns.
        model_path: path to a state_dict saved with torch.save.

    Returns:
        (model, image_embeddings, text_embeddings): the model in eval mode and two
        tensors with one projected embedding per row of test_df, in row order.
    """
    tokenizer = AutoTokenizer.from_pretrained(CFG.text_tokenizer)

    test_loader = build_loaders(test_df, tokenizer, mode="test")
    
    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(model_path, map_location=CFG.device))
    model.eval()
    
    test_image_embeddings = []
    test_text_embeddings = []

    with torch.no_grad():
        for batch in tqdm(test_loader):
            # img embeddings
            image_features = model.image_encoder(batch["image"].to(CFG.device))
            image_embeddings = model.image_projection(image_features)
            test_image_embeddings.append(image_embeddings)
            # text embeddings; the token tensors must live on the same device as the
            # model, otherwise the forward pass fails when running on CUDA
            text_features = model.text_encoder(
                input_ids=batch["input_ids"].to(CFG.device),
                attention_mask=batch["attention_mask"].to(CFG.device),
            )
            text_embeddings = model.text_projection(text_features)
            test_text_embeddings.append(text_embeddings)
    # torch.cat joins the per-batch tensors into one tensor per modality
    return model, torch.cat(test_image_embeddings), torch.cat(test_text_embeddings)

In [None]:
# Only the third returned frame (the test split) is kept; the first two are discarded.
_, _, test_df = make_train_valid_test_dfs()

In [None]:
# Optional: uncomment the next line to write the test split to a CSV file.
#test_df.to_csv(f"test_df_from_30k.csv", index=False)
test_df

In [None]:
# One dict per test row (keys are the frame's column names), in row order.
captions_w_fnames = test_df.to_dict(orient='records')

In [None]:
#model_hubert_base_cc, image_embeddings_hubert_base_cc, text_embeddings_hubert_base_cc = get_model_embeddings(test_df, "best_models/hubert-base-cc.pt")

In [None]:
#model_sentiment_hts2_hubert_hungarian, image_embeddings_sentiment_hts2_hubert_hungarian, text_embeddings_sentiment_hts2_hubert_hungarian = get_model_embeddings(test_df, "best_models/sentiment-hts2-hubert-hungarian.pt")

In [None]:
# Embed the whole test split with the xlm-roberta-base checkpoint.
# get_model_embeddings builds the model and tokenizer from CFG, so
# CFG.text_encoder_model must be the xlm-roberta-base entry for this checkpoint to load.
model_xlm_roberta_base, image_embeddings_xlm_roberta_base, text_embeddings_xlm_roberta_base = get_model_embeddings(test_df, "best_models/xlm-roberta-base.pt")

In [None]:
#model_distilbert_base_multilingual_cased, image_embeddings_distilbert_base_multilingual_cased, text_embeddings_distilbert_base_multilingual_cased = get_model_embeddings(test_df, "best_models/distilbert-base-multilingual-cased.pt")

In [None]:
# Tokenizers are expensive to load; keep one instance per tokenizer name.
_TOKENIZER_CACHE = {}


def _get_cached_tokenizer(name):
    """Load the tokenizer called `name` once and reuse it on later calls."""
    if name not in _TOKENIZER_CACHE:
        _TOKENIZER_CACHE[name] = AutoTokenizer.from_pretrained(name)
    return _TOKENIZER_CACHE[name]


def find_matches_from_text_to_img(model, image_embeddings, caption, image_filenames, n=5):
    """Return the file names of the images most similar to `caption`.

    Args:
        model: model exposing `text_encoder` and `text_projection`.
        image_embeddings: projected image embeddings, one row per entry of
            `image_filenames` (rows may repeat when an image has several captions).
        caption: query text.
        image_filenames: sequence of file names aligned with `image_embeddings`.
        n: number of distinct images to return.

    Returns:
        Up to `n` distinct file names, best match first. Candidates are taken from
        the `n * 5` highest-scoring rows (5 rows per image are assumed), and repeated
        file names are skipped instead of relying on duplicates being adjacent.
    """
    # Previously the tokenizer was re-loaded on every call, i.e. once per caption
    # in the evaluation loop.
    tokenizer = _get_cached_tokenizer(CFG.text_tokenizer)

    # Truncate the query the same way CLIPDataset truncates captions.
    encoded_query = tokenizer([caption], truncation=True, max_length=CFG.max_length)
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embedding = model.text_projection(text_features)

    # Cosine similarity = dot product of L2-normalized embeddings.
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embedding_n = F.normalize(text_embedding, p=2, dim=-1)
    scores = (text_embedding_n @ image_embeddings_n.T).squeeze(0)

    # Never ask topk for more candidates than there are embeddings.
    k = min(n * 5, scores.numel())
    _, indices = torch.topk(scores, k)

    match = []
    seen = set()
    for idx in indices.tolist():
        fname = image_filenames[idx]
        if fname in seen:
            continue
        seen.add(fname)
        match.append(fname)
        if len(match) == n:
            break

    return match

In [None]:
#find_matches_from_text_to_img(model_distilbert_base_multilingual_cased, 
#                              image_embeddings_distilbert_base_multilingual_cased, 
#                              'Egy fénykép egy lóról.', 
#                              image_filenames=test_df['image'].values, n=9)

In [None]:
# Top-1 text -> image retrieval: for every test caption, check whether the best
# matching image is the image the caption belongs to.

print("len(test_df['caption'].values)",len(test_df['caption'].values))
test_imgs = test_df['image'].values.tolist()
cnt = 0   # captions processed so far
res = 0   # captions whose top-1 image was the expected one
for cnt, e in enumerate(captions_w_fnames, start=1):
    print(cnt, end='\r')
    pred_img = find_matches_from_text_to_img(
        model_xlm_roberta_base,
        image_embeddings_xlm_roberta_base,
        e['caption'],
        image_filenames=test_imgs,
        n=1,
    )[0]
    res += int(pred_img == e['image'])

print(res ,'/', cnt)

In [None]:
def append_value(dict_obj, key, value):
    """Store `value` under `key`, collecting repeated keys into a list.

    The first value for a key is stored as-is. On a later call for the same key
    the stored value is wrapped in a list (unless it already is one) and the new
    value is appended to it. The dict is modified in place.
    """
    # New key: plain assignment, nothing to merge.
    if key not in dict_obj:
        dict_obj[key] = value
        return

    # Existing key: make sure the slot holds a list, then grow it.
    current = dict_obj[key]
    if not isinstance(current, list):
        current = [current]
        dict_obj[key] = current
    current.append(value)

In [None]:
# we need a dict with an image name, and 5 captions
d_imgs_w_texts = {} 
for e in captions_w_fnames:
    append_value(d_imgs_w_texts, e['image'], e['caption'])

In [None]:
# Show only the first few entries; printing the whole mapping floods the output.
pp.pprint(dict(itertools.islice(d_imgs_w_texts.items(), 5)))

In [None]:
def find_matches_from_img_to_texts(model, text_embeddings, image_fn, captions, n=5):
    """Return the captions most similar to the image stored as `image_fn`.

    Args:
        model: model exposing `image_encoder` and `image_projection`.
        text_embeddings: projected caption embeddings, one row per entry of `captions`.
        image_fn: image file name, relative to CFG.image_path.
        captions: sequence of captions aligned with `text_embeddings`.
        n: number of captions to return.

    Returns:
        Up to `n` captions, best match first.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    transforms = get_transforms(mode='test')
    path = f"{CFG.image_path}/{image_fn}"
    image = cv2.imread(path)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f"Could not read image: {path}")
    # The transform pipeline resizes to CFG.size as well; this explicit resize is
    # kept so the preprocessing matches what the reported results were produced with.
    image = cv2.resize(image, (CFG.size, CFG.size))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    image = transforms(image=image)['image']
    image = torch.tensor(image).permute(2, 0, 1).float()

    # Inference only: do not build an autograd graph for the query image.
    with torch.no_grad():
        image_features = model.image_encoder(image.unsqueeze(0).to(CFG.device))
        image_embedding = model.image_projection(image_features)

    # Cosine similarity = dot product of L2-normalized embeddings.
    image_embedding_n = F.normalize(image_embedding, p=2, dim=-1)   # just one image
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)   # all captions
    scores = (image_embedding_n @ text_embeddings_n.T).squeeze(0)

    # Never ask topk for more results than there are captions.
    k = min(n, scores.numel())
    _, indices = torch.topk(scores, k)

    # Plain Python ints index lists and numpy arrays alike, whatever device the
    # index tensor lives on.
    matches = [captions[idx] for idx in indices.tolist()]

    return matches

In [None]:
# Image -> text retrieval: for every test image, retrieve the top_n captions and count
# how many of them are reference captions of that image.
top_n = 5                      # captions retrieved per image
cnt=0                          # images processed
res_sum = 0                    # retrieved captions that were correct, over all images
cnt_when_not_zero=0            # images with at least one correct retrieved caption
all_cnt = len(d_imgs_w_texts)*top_n   # total number of retrieved captions
test_captions = test_df['caption'].values
for k,v  in d_imgs_w_texts.items():
    # k = img_fname, v = caption(s) of that image
    # Guard against a bare string value: `text in "some caption"` would be a
    # substring test rather than a membership test.
    expected = v if isinstance(v, list) else [v]
    zs=find_matches_from_img_to_texts(model_xlm_roberta_base, text_embeddings_xlm_roberta_base, k, captions=test_captions, n=top_n)
    res = sum(1 for text in zs if text in expected)
    if res != 0 : cnt_when_not_zero+=1 
    res_sum += res
    cnt+=1

print(cnt_when_not_zero,'/',len(d_imgs_w_texts))
print(res_sum ,'/', all_cnt)

# 5 shot?

# model_distilbert_base_multilingual_cased
# 321 / 6357
# 366 / 31785

# model_sentiment_hts2_hubert_hungarian
# 677 / 6357
# 815 / 31785 

# model hubert_base_cc
# 706 / 6357
# 845 / 31785

# model_xlm_roberta_base
# 452 / 6357
# 520 / 31785