In [1]:
!pip install accelerate
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

Collecting accelerate
  Downloading accelerate-0.30.1-py3-none-any.whl (302 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.6/302.6 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.10.0->accelerate)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.w

In [2]:
import numpy as np
import torch

# `pkg_resources.packaging` is not exposed by setuptools 70+, so prefer the standalone
# `packaging` distribution (same module, same `packaging.version` API) and only fall
# back to the old vendored copy when it is not installed.
try:
    import packaging.version
except ImportError:
    from pkg_resources import packaging

print("Torch version:", torch.__version__)


Torch version: 2.2.1+cu121


In [3]:
# Standard library
import json
import os
import pickle

# Third-party
import clip
import cv2
import numpy as np
import torch
from PIL import Image

# `pkg_resources.packaging` is not exposed by setuptools 70+, so prefer the standalone
# `packaging` distribution and only fall back to the old vendored copy.
try:
    import packaging.version
except ImportError:
    from pkg_resources import packaging

In [4]:
# Sanity-check the CLIP install: list the bundled checkpoints once, then show how a
# word is tokenized (the tokenizer pads every text to a fixed-length row).
print("Available CLIP models:", clip.available_models())
clip.tokenize('hello')

Torch version: 2.2.1+cu121


tensor([[49406,  3306, 49407,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]], dtype=torch.int32)

In [5]:
class RelevanceEvaluator:
    """Score how well texts describe an image using a CLIP model.

    Images are accepted either as a file path (read with OpenCV) or as an
    already-loaded ``np.ndarray`` in OpenCV channel order (BGR, BGRA or
    grayscale). Scores are cosine similarities between CLIP embeddings.
    """

    def __init__(self, checkpoint="ViT-B/32"):
        """Load the CLIP ``checkpoint`` on the GPU when available, else on the CPU."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.preprocess = clip.load(checkpoint, device=self.device)

    @staticmethod
    def _to_rgb(image):
        """Convert an OpenCV-ordered array (gray, BGR or BGRA) to an RGB array.

        Raises:
            ValueError: if the array is not HxW, HxWx1, HxWx3 or HxWx4.
        """
        if image.ndim == 2:
            channels = 1
        elif image.ndim == 3:
            channels = image.shape[2]
        else:
            channels = None
        if channels == 1:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if channels == 3:
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if channels == 4:
            # e.g. PNGs loaded with cv2.IMREAD_UNCHANGED keep their alpha channel.
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        raise ValueError(
            f"Unsupported image shape {image.shape}; expected HxW, HxWx1, HxWx3 or HxWx4"
        )

    def extract_image_features(self, image_path):
        """Return the CLIP embedding of a single image as a float tensor of shape (1, D).

        Args:
            image_path: path to an image file, or an ``np.ndarray`` holding an
                image in OpenCV channel order.

        Raises:
            FileNotFoundError: if ``image_path`` is a path OpenCV cannot read.
            ValueError: if the image array has an unsupported shape.
        """
        if isinstance(image_path, np.ndarray):
            image = image_path
        else:
            image = cv2.imread(os.fspath(image_path))
            # cv2.imread signals a missing/unreadable file by returning None, not raising.
            if image is None:
                raise FileNotFoundError(f"cv2 could not read image: {image_path!r}")
        image = Image.fromarray(self._to_rgb(image))
        # Add a batch dimension of 1 before encoding.
        image = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image).float()
        return image_features

    def extract_text_features(self, texts):
        """Return CLIP embeddings for one or more texts, shape (number of texts, D).

        Args:
            texts: a single string, an ``np.ndarray`` of values, or any iterable of
                values; non-string values are converted with ``str()``.

        Raises:
            ValueError: if no texts are given.
        """
        if isinstance(texts, str):
            texts = [texts]
        elif isinstance(texts, np.ndarray):
            texts = [str(text) for text in texts.ravel().tolist()]
        else:
            texts = [str(text) for text in texts]
        if not texts:
            raise ValueError("texts must contain at least one entry")

        text_tokens = clip.tokenize(texts).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_tokens).float()
        return text_features

    def measure_similarity(self, image_path, text):
        """Cosine similarity between each text and the image.

        Args:
            image_path: image path or OpenCV-ordered ``np.ndarray`` (see
                ``extract_image_features``).
            text: one string or several texts (see ``extract_text_features``).

        Returns:
            ``np.ndarray`` of shape (number of texts, 1), dtype float32.
        """
        text_features = self.extract_text_features(text)
        image_features = self.extract_image_features(image_path)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return (text_features @ image_features.T).cpu().numpy()

# Main ReCap

In [6]:
class ReCap:
    """Caption correction: swap caption tokens that CLIP finds less relevant to their
    image segment than a word from a generic reference vocabulary.

    The filtration steps are placeholders. Before ``run()``, callers must assign
    ``candidate_tokens`` as ``{"token_k": (text, label, group)}`` and
    ``candidate_segments`` as ``{"segment_k": (image, label)}``. A token is only
    scored against segments with the same ``label``. ``group`` is "f" (compared with
    ``feature_reference``) or "obj" (compared with ``object_reference``). The
    corrected caption is left in ``T_hat``.
    """

    def __init__(self, image, caption, clip_checkpoint="ViT-B/32", generic=False):
        """Load the CLIP scorer and initialise empty pipeline state.

        Args:
            image: the full image; stored only, the pipeline reads ``candidate_segments``.
            caption: the caption to correct.
            clip_checkpoint: CLIP checkpoint name passed to ``RelevanceEvaluator``.
            generic: if True, every substitute is replaced by the word "unknown".
        """
        self.ClipModel = RelevanceEvaluator(clip_checkpoint)
        # Reference vocabularies that caption tokens are compared against.
        self.object_reference = ['human', 'animal', 'machine', 'insect', 'tree', 'building', 'plant', 'food', 'tool', 'house']
        self.feature_reference = ['red', 'orange', 'yellow', 'green', 'blue', 'indigo', 'violet', 'white', 'black', 'purple']
        self.object_reference_clip_score = dict()
        self.feature_reference_clip_score = dict()
        self.T_hat = None
        self.image = image
        self.caption = caption
        self.candidate_segments = None
        self.candidate_tokens = None
        self.candidate_clip_score = dict()
        self.target_tokens = dict()
        self.generic = generic

    def run(self):
        """Run filtration, relevance evaluation and substitution; result in ``self.T_hat``.

        Returns early with a message if candidate tokens/segments were not assigned.
        """
        self.filteration()
        if not (self.candidate_segments and self.candidate_tokens):
            print("Assign candidate lists!")
            return
        self.relevance_evaluation()
        self.substitution()

    def filteration(self):
        """Run the (placeholder) text and image filtration steps."""
        self.text_filteration()
        self.image_filteration()

    def text_filteration(self):
        """Placeholder: remind the caller to assign ``candidate_tokens``.

        Expected form: T_c = {'token_0': (t_0, l_0, g_0), ...}
        """
        print("Assign candidate_tokens in form of T_c = {'token_0':{t_0, l_0, g_0},.... }")

    def image_filteration(self):
        """Placeholder: remind the caller to assign ``candidate_segments``.

        Expected form: I_c = {'segment_0': (i_0, l_0), ...}
        """
        print("Assign candidate_segments in form of I_c = {'segment_0':{i_0, l_0},.... }")

    def relevance_evaluation(self):
        """Score candidates and references with CLIP, then select the target tokens.

        Results from a previous ``run()`` are cleared first so they cannot leak into
        this one.
        """
        self.candidate_clip_score = dict()
        self.object_reference_clip_score = dict()
        self.feature_reference_clip_score = dict()
        self.target_tokens = dict()
        self.calculate_candidate_clip_score()
        self.calculate_reference_clip_score()
        self.generate_target_tokens()

    def calculate_reference_clip_score(self):
        """Score every segment against every reference word.

        Fills ``object_reference_clip_score["segment_{i}_obj_{j}"]`` and
        ``feature_reference_clip_score["segment_{i}_f_{j}"]`` with (1, 1) score arrays.
        Each vocabulary is scored in one batched call, so a segment is encoded once
        per vocabulary instead of once per word.
        """
        for seg_key, (segment, _label) in self.candidate_segments.items():
            obj_scores = self.ClipModel.measure_similarity(segment, self.object_reference)
            for obj_id in range(len(self.object_reference)):
                # Slice rather than index to keep the (1, 1) shape of a single-text score.
                self.object_reference_clip_score[f"{seg_key}_obj_{obj_id}"] = obj_scores[obj_id:obj_id + 1]
            f_scores = self.ClipModel.measure_similarity(segment, self.feature_reference)
            for f_id in range(len(self.feature_reference)):
                self.feature_reference_clip_score[f"{seg_key}_f_{f_id}"] = f_scores[f_id:f_id + 1]

    def calculate_candidate_clip_score(self):
        """Score each token against each segment that carries the same label.

        Fills ``candidate_clip_score["segment_{i}_token_{t}"]``.
        """
        for seg_key, (segment, i_label) in self.candidate_segments.items():
            for tok_key, (token, t_label, _group) in self.candidate_tokens.items():
                if i_label == t_label:
                    self.candidate_clip_score[f"{seg_key}_{tok_key}"] = self.ClipModel.measure_similarity(segment, token)

    def generate_target_tokens(self):
        """Mark tokens whose best same-group reference word beats the token itself.

        For each scored (segment, token) pair, take the highest-scoring reference
        entry for that segment from the token's group. If it scores higher than the
        token, store ``target_tokens["token_{t}"] = [ref_id, ref_score]``. If a token
        was scored against several segments, the last qualifying segment is kept.
        Tokens whose group is neither "f" nor "obj" are skipped.
        """
        for c_id, c_score in self.candidate_clip_score.items():
            # "segment_{i}_token_{t}" -> ["segment", i, "token", t]
            _, i_id, _, t_id = c_id.split("_")
            g_k = self.candidate_tokens[f"token_{t_id}"][-1]
            if g_k == "f":
                reference_clip_score = self.feature_reference_clip_score
            elif g_k == "obj":
                reference_clip_score = self.object_reference_clip_score
            else:
                continue  # no reference vocabulary for this group
            # Trailing "_" so that segment_1 does not also match segment_10, segment_11, ...
            prefix = f"segment_{i_id}_"
            ref_id, ref_thr = None, None
            for item_ref_id, ref_score in reference_clip_score.items():
                if item_ref_id.startswith(prefix) and (ref_thr is None or ref_score > ref_thr):
                    ref_id, ref_thr = item_ref_id, ref_score
            if ref_id is not None and ref_thr > c_score:
                self.target_tokens[f"token_{t_id}"] = [ref_id, ref_thr]

    @staticmethod
    def _replace_whole_words(text, replacements):
        """Replace whole-word occurrences of each key of ``replacements`` in one pass.

        A single pass means a substitute inserted for one token is never rewritten
        by another replacement. Empty keys are ignored.
        """
        import re

        keys = sorted((key for key in replacements if key), key=len, reverse=True)
        if not keys:
            return text
        # Lookarounds instead of \b so tokens that start/end with punctuation (e.g. "#")
        # still match when surrounded by spaces or string boundaries.
        pattern = r'(?<!\w)(?:' + '|'.join(map(re.escape, keys)) + r')(?!\w)'
        return re.sub(pattern, lambda match: replacements[match.group(0)], text)

    def substitution(self):
        """Build ``T_hat`` by replacing each target token with its reference word.

        Replacements are whole-word and are always applied to the original caption,
        so red->green followed by green->blue does not turn both words blue. When
        the same token text is targeted twice, the first substitute wins.
        """
        print('========================= SUBSTITUTION START ========================= ')
        self.T_hat = self.caption
        print(f'{self.T_hat=}')
        replacements = {}
        for target_token, ref in self.target_tokens.items():
            print(f'{target_token=} {ref=}')
            # ref[0] has the form "segment_{i}_{f|obj}_{j}"; j indexes the reference list.
            ref_index = int(ref[0].rsplit('_', 1)[1])
            token = self.candidate_tokens[target_token][0]
            g = self.candidate_tokens[target_token][2]
            if g == "f":
                substitute = self.feature_reference[ref_index]
            elif g == "obj":
                substitute = self.object_reference[ref_index]
            else:
                print(f"Skipping {target_token}: unknown group {g!r}")
                continue

            print(substitute)
            if self.generic:
                substitute = "unknown"
            replacements.setdefault(token, substitute)
            self.T_hat = self._replace_whole_words(self.caption, replacements)
            print(f'{self.T_hat=}')
        print('========================= SUBSTITUTION END =========================')

In [8]:
# Example 1: the file names indicate a purple deer in front of a pink tree,
# while the caption says "gray deer" / "green tree".
deer_segment = cv2.imread('/content/purple_deer.png')
tree_segment = cv2.imread('/content/pink_tree.png')
# cv2.imread returns None instead of raising for a missing/unreadable file; fail here,
# before the CLIP model is loaded, rather than deep inside ReCap.run().
assert deer_segment is not None, "could not read /content/purple_deer.png"
assert tree_segment is not None, "could not read /content/pink_tree.png"

image = '/content/pink_tree_purple_deer.png'
text = 'a gray deer in front of a green tree'
recap_instance = ReCap(image=image, caption=text, clip_checkpoint="ViT-B/32", generic=False)

# (text, label, group): a token is only scored against segments with the same label;
# group 'f' is compared with the feature (colour) words, 'obj' with the object words.
recap_instance.candidate_tokens = {
    'token_0': ('gray', 'fg', 'f'),
    'token_1': ('deer', 'fg', 'obj'),
    'token_2': ('green', 'bg', 'f'),
    'token_3': ('tree', 'bg', 'obj')
}
recap_instance.candidate_segments = {
    'segment_0': (deer_segment, 'fg'),
    'segment_1': (tree_segment, 'bg')
}

recap_instance.run()

Assign candidate_tokens in form of T_c = {'token_0':{t_0, l_0, g_0},.... }
Assign candidate_segments in form of I_c = {'segment_0':{i_0, l_0},.... }
self.T_hat='a gray deer in front of a green tree'
target_token='token_0' ref=['segment_0_f_9', array([[0.29162014]], dtype=float32)]
purple
self.T_hat='a purple deer in front of a green tree'
target_token='token_2' ref=['segment_1_f_9', array([[0.23500487]], dtype=float32)]
purple
self.T_hat='a purple deer in front of a purple tree'


In [9]:
# Example 2: the file names indicate a green lion in front of a yellow house,
# while the caption says "blue dinosaur" / "red tree".
lion_segment = cv2.imread('/content/green_lion.png')
house_segment = cv2.imread('/content/yellow_house.png')
# cv2.imread returns None instead of raising for a missing/unreadable file; fail here,
# before the CLIP model is loaded, rather than deep inside ReCap.run().
assert lion_segment is not None, "could not read /content/green_lion.png"
assert house_segment is not None, "could not read /content/yellow_house.png"

image = '/content/yellow_house_green_lion.png'
text = 'a blue dinosaur in front of a red tree'
recap_instance = ReCap(image=image, caption=text, clip_checkpoint="ViT-B/32", generic=False)

# (text, label, group): a token is only scored against segments with the same label;
# group 'f' is compared with the feature (colour) words, 'obj' with the object words.
recap_instance.candidate_tokens = {
    'token_0': ('blue', 'fg', 'f'),
    'token_1': ('dinosaur', 'fg', 'obj'),
    'token_2': ('red', 'bg', 'f'),
    'token_3': ('tree', 'bg', 'obj')
}
recap_instance.candidate_segments = {
    'segment_0': (lion_segment, 'fg'),
    'segment_1': (house_segment, 'bg')
}

recap_instance.run()

Assign candidate_tokens in form of T_c = {'token_0':{t_0, l_0, g_0},.... }
Assign candidate_segments in form of I_c = {'segment_0':{i_0, l_0},.... }
self.T_hat='a blue dinosaur in front of a red tree'
target_token='token_0' ref=['segment_0_f_3', array([[0.27266407]], dtype=float32)]
green
self.T_hat='a green dinosaur in front of a red tree'
target_token='token_1' ref=['segment_0_obj_1', array([[0.24704775]], dtype=float32)]
animal
self.T_hat='a green animal in front of a red tree'
target_token='token_2' ref=['segment_1_f_2', array([[0.26942572]], dtype=float32)]
yellow
self.T_hat='a green animal in front of a yellow tree'
target_token='token_3' ref=['segment_1_obj_9', array([[0.26037657]], dtype=float32)]
house
self.T_hat='a green animal in front of a yellow house'


In [11]:
# Example 3: a caption that does not describe the image (the file names indicate an
# orange elephant and a pink tree); each caption word is assigned a group by hand.
elephant_segment = cv2.imread('/content/orange_elephant.png')
tree_segment = cv2.imread('/content/pink_tree.png')
# cv2.imread returns None instead of raising for a missing/unreadable file; fail here,
# before the CLIP model is loaded, rather than deep inside ReCap.run().
assert elephant_segment is not None, "could not read /content/orange_elephant.png"
assert tree_segment is not None, "could not read /content/pink_tree.png"

image = '/content/pink_tree_orange_elephant.png'
text = 'digital art selected for the #'
recap_instance = ReCap(image=image, caption=text, clip_checkpoint="ViT-B/32", generic=False)

# (text, label, group): a token is only scored against segments with the same label;
# group 'f' is compared with the feature (colour) words, 'obj' with the object words.
recap_instance.candidate_tokens = {
    'token_0': ('digital', 'fg', 'f'),
    'token_1': ('art', 'fg', 'obj'),
    'token_2': ('selected', 'bg', 'f'),
    'token_3': ('#', 'bg', 'obj')
}
recap_instance.candidate_segments = {
    'segment_0': (elephant_segment, 'fg'),
    'segment_1': (tree_segment, 'bg')
}

recap_instance.run()

Assign candidate_tokens in form of T_c = {'token_0':{t_0, l_0, g_0},.... }
Assign candidate_segments in form of I_c = {'segment_0':{i_0, l_0},.... }
self.T_hat='digital art selected for the #'
target_token='token_0' ref=['segment_0_f_1', array([[0.28036323]], dtype=float32)]
orange
self.T_hat='orange art selected for the #'
target_token='token_1' ref=['segment_0_obj_1', array([[0.25569206]], dtype=float32)]
animal
self.T_hat='orange animal selected for the #'
target_token='token_2' ref=['segment_1_f_9', array([[0.23500487]], dtype=float32)]
purple
self.T_hat='orange animal purple for the #'
target_token='token_3' ref=['segment_1_obj_4', array([[0.2891208]], dtype=float32)]
tree
self.T_hat='orange animal purple for the tree'
