In [None]:
!pip install -q faiss-gpu

In [None]:
!pip install -q translate
!pip install -q underthesea==1.3.5a3
!pip install -q underthesea[deep]
!pip install -q pyvi
!pip install -q langdetect
!pip install -q googletrans==3.1.0a0

In [4]:
# Gather every .jpg keyframe under the Kaggle input mount, sorted by full path,
# and number them 0..N-1. Myfaiss maps FAISS result ids back to paths through
# `id2img_fps`, so this ordering presumably has to match the order in which the
# feature .bin was built — NOTE(review): confirm against the indexing script.
lst_keyframes = sorted(
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk('/kaggle/input/')
    for filename in filenames
    if filename.endswith('.jpg')
)

id2img_fps = dict(enumerate(lst_keyframes))

# Model

In [None]:
from transformers import CLIPModel, CLIPImageProcessor, CLIPTokenizer

In [None]:
# Candidate CLIP-family checkpoints as (Hugging Face model id, short tag).
# The tag names the prebuilt feature file `<tag>.bin` loaded in the Inference
# section; that index must have been built with the same checkpoint.
# Kept under its own name so the list is not clobbered when `model` is later
# bound to the loaded CLIPModel.
MODEL_CHOICES = [
    ("openai/clip-vit-base-patch32", 'clipB32'),
    ("facebook/metaclip-b16-fullcc2.5b", 'metaB16'),
    ('facebook/metaclip-l14-fullcc2.5b', 'metaL14'),
    ('facebook/metaclip-h14-fullcc2.5b', 'metaH14')
]
MODEL_INDEX = 0  # which entry of MODEL_CHOICES to use

model_name, bin_name = MODEL_CHOICES[MODEL_INDEX]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
# Fetch the pretrained CLIP weights plus the image preprocessor and text
# tokenizer that belong to the checkpoint selected in the previous cell.
image_processor = CLIPImageProcessor.from_pretrained(model_name)
text_processor = CLIPTokenizer.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

In [None]:
# Move the model's parameters onto the device chosen earlier (GPU when available).
model = model.to(device=device)

# Classes

## Translation

In [None]:
class Translation():
    """Callable wrapper over two translation backends: ``googletrans`` and ``translate``.

    Args:
        from_lang: source language code (default ``'vi'``).
        to_lang: target language code (default ``'en'``).
        mode: backend selector. ``'google'`` or ``'googletrans'`` use
            ``googletrans.Translator``; ``'translate'`` uses ``translate.Translator``.

    Raises:
        ValueError: if ``mode`` is not one of the supported selectors.
    """

    _GOOGLE_MODES = ('google', 'googletrans')
    _TRANSLATE_MODES = ('translate',)

    def __init__(self, from_lang='vi', to_lang='en', mode='google'):
        # Exact matching replaces the old substring tests (`mode in 'googletrans'`),
        # under which e.g. mode='a' built a googletrans client while __call__ took
        # the `translate` branch, and unknown modes left `self.translator` unset.
        if mode in self._GOOGLE_MODES:
            self.__mode = 'googletrans'
        elif mode in self._TRANSLATE_MODES:
            self.__mode = 'translate'
        else:
            raise ValueError(
                f"Unsupported translation mode {mode!r}; expected one of "
                f"{self._GOOGLE_MODES + self._TRANSLATE_MODES}"
            )
        self.__from_lang = from_lang
        self.__to_lang = to_lang

        if self.__mode == 'googletrans':
            self.translator = googletrans.Translator()
        else:
            self.translator = translate.Translator(from_lang=from_lang, to_lang=to_lang)

    def preprocessing(self, text):
        """Normalize text before translation (currently: lowercase only)."""
        return text.lower()

    def __call__(self, text):
        """Preprocess and translate ``text``; returns the translated string."""
        text = self.preprocessing(text)
        if self.__mode == 'translate':
            return self.translator.translate(text)
        # Honour the configured source language instead of letting googletrans auto-detect.
        return self.translator.translate(text, dest=self.__to_lang, src=self.__from_lang).text

## Text preprocessing

In [None]:
class Text_Preprocessing():
    """Vietnamese text preprocessing helpers (pyvi + underthesea).

    Args:
        stopwords_path: UTF-8 text file with one stopword per line.
    """
    def __init__(self, stopwords_path='./dict/vietnamese-stopwords-dash.txt'):
        # utf-8-sig drops a leading BOM; rstrip('\r\n') also removes CR from CRLF
        # files. The old replace('\n', '') left both attached, so those entries
        # could never match a token.
        with open(stopwords_path, 'r', encoding='utf-8-sig') as f:
            self.stop_words = [line.rstrip('\r\n') for line in f]

    def find_substring(self, string1, string2):
        """Return the longest common substring of the two strings, whitespace-stripped."""
        match = SequenceMatcher(None, string1, string2, autojunk=False).find_longest_match(0, len(string1), 0, len(string2))
        return string1[match.a:match.a + match.size].strip()

    def remove_stopwords(self, text):
        """Word-segment ``text`` with ViTokenizer and drop tokens listed in ``self.stop_words``."""
        # Set lookup per call: O(1) membership instead of scanning the list for every
        # token, while still reflecting any later edits to self.stop_words.
        stop_words = set(self.stop_words)
        tokens = ViTokenizer.tokenize(text).split()
        return " ".join(w for w in tokens if w not in stop_words)

    def lowercasing(self, text):
        """Return ``text`` lowercased."""
        return text.lower()

    def uppercasing(self, text):
        """Return ``text`` uppercased."""
        return text.upper()

    def add_accents(self, text):
        """Restore Vietnamese diacritics via pyvi's ViUtils.add_accents."""
        return ViUtils.add_accents(str(text))

    def remove_accents(self, text):
        """Strip Vietnamese diacritics via pyvi's ViUtils.remove_accents."""
        return ViUtils.remove_accents(str(text))

    def sentence_segment(self, text):
        """Split ``text`` into sentences with underthesea.sent_tokenize."""
        return underthesea.sent_tokenize(text)

    def text_norm(self, text):
        """Normalize ``text`` with underthesea.text_normalize."""
        return underthesea.text_normalize(text)

    def text_classify(self, text):
        """Return underthesea.classify's categories for ``text``."""
        return underthesea.classify(text)

    def sentiment_analysis(self, text):
        """Return underthesea.sentiment's label for ``text``."""
        return underthesea.sentiment(text)

    def __call__(self, text):
        """Lowercase, drop stopwords, normalize, then return the text's categories.

        The accent add/remove helpers are available but intentionally not part of
        this pipeline.
        """
        text = self.lowercasing(text)
        text = self.remove_stopwords(text)
        text = self.text_norm(text)
        return self.text_classify(text)

## Myfaiss

In [None]:
class Myfaiss:
    """FAISS-backed keyframe retrieval by image id or by (possibly Vietnamese) text.

    Args:
        bin_file: path to a serialized FAISS index, read with ``faiss.read_index``.
        id2img_fps: dict mapping ids stored in the index to image file paths.
        device: torch device the text model runs on.
        model: CLIP-style model exposing ``get_text_features``.
        text_processor: tokenizer used to encode text queries. Optional for
            backward compatibility; when omitted, the notebook-level
            ``text_processor`` is used at query time.
    """

    def __init__(self, bin_file: str, id2img_fps, device, model, text_processor=None):
        self.index = self.load_bin_file(bin_file)
        self.id2img_fps = id2img_fps
        self.device = device
        self.model = model
        self.text_processor = text_processor
        self._translator = None  # created lazily on the first Vietnamese query

    def load_bin_file(self, bin_file: str):
        """Read and return the FAISS index stored at ``bin_file``."""
        return faiss.read_index(bin_file)

    @staticmethod
    def _grid_shape(n_images):
        """Return (rows, columns) of a near-square grid holding ``n_images`` (at least 1x1)."""
        columns = max(1, int(math.sqrt(n_images)))
        rows = max(1, math.ceil(n_images / columns))
        return rows, columns

    def show_images(self, image_paths):
        """Plot the images in a grid, each titled with its file name minus extension.

        ``None`` entries (ids missing from ``id2img_fps``, e.g. FAISS's ``-1``
        padding) are skipped; nothing is drawn when no paths remain.
        """
        paths = [p for p in image_paths if p is not None]
        if not paths:
            return
        rows, columns = self._grid_shape(len(paths))
        fig = plt.figure(figsize=(15, 10))
        # Iterate over the images themselves: the grid can have more cells than
        # images (e.g. 5 images -> 3x2), which made the old index loop overrun.
        for position, path in enumerate(paths, start=1):
            ax = fig.add_subplot(rows, columns, position)
            ax.set_title(path.split('/')[-1].split('.')[0], fontsize=10)
            ax.imshow(plt.imread(path))
            ax.axis("off")
        plt.show()

    def _lookup_paths(self, idx_image):
        """Map FAISS result ids to (infos_query, image_paths); unknown ids map to None."""
        infos_query = [self.id2img_fps.get(int(i)) for i in idx_image]
        return infos_query, list(infos_query)

    def image_search(self, id_query, k, bin_file=None):
        """Find the ``k`` nearest neighbours of the vector stored at index id ``id_query``.

        ``bin_file`` is unused and kept only so existing positional calls still work.

        Returns:
            (scores, idx_image, infos_query, image_paths) where ``idx_image`` is the
            flattened id array and the two lists hold the mapped paths.
        """
        query_feats = self.index.reconstruct(id_query).reshape(1, -1)
        scores, idx_image = self.index.search(query_feats, k=k)
        idx_image = idx_image.flatten()
        infos_query, image_paths = self._lookup_paths(idx_image)
        return scores, idx_image, infos_query, image_paths

    def _translate_if_vietnamese(self, text):
        """Return ``text`` translated to English when langdetect labels it 'vi'."""
        # Blank input is passed through untouched: langdetect's detect() raises on it.
        if not text.strip() or detect(text) != 'vi':
            return text
        # Reuse one Translation (and its network client) across queries.
        if getattr(self, '_translator', None) is None:
            self._translator = Translation()
        return self._translator(text)

    def text_search(self, text, k):
        """Encode ``text`` with the model and return its ``k`` nearest keyframes.

        Returns:
            (scores, idx_image, infos_query, image_paths), as in ``image_search``.
        """
        text = self._translate_if_vietnamese(text)

        ###### TEXT FEATURE EXTRACTION ######
        tokenizer = getattr(self, 'text_processor', None)
        if tokenizer is None:
            tokenizer = text_processor  # notebook-level tokenizer (backward compatibility)
        # truncation=True clips queries longer than the tokenizer's max length
        # instead of failing inside the text encoder.
        inputs = tokenizer([text], return_tensors="pt", truncation=True).to(self.device)
        with torch.no_grad():  # inference only: skip building an autograd graph
            text_features = self.model.get_text_features(**inputs)
        # NOTE(review): features are not L2-normalized; if the index holds normalized
        # image embeddings under inner product, normalize here — confirm vs. the .bin build.
        text_features = text_features.cpu().numpy().astype(np.float32)

        ###### SEARCHING ######
        scores, idx_image = self.index.search(text_features, k=k)
        idx_image = idx_image.flatten()

        ###### MAP IDS TO KEYFRAME PATHS ######
        infos_query, image_paths = self._lookup_paths(idx_image)
        return scores, idx_image, infos_query, image_paths

# Inference

In [None]:
# Prebuilt FAISS index for the selected checkpoint (e.g. clipB32.bin).
# NOTE(review): `root_features` is not defined in the cells shown here — confirm
# an earlier cell sets it before this one runs.
bin_file = os.path.join(root_features, bin_name + '.bin')
faiss_test = Myfaiss(bin_file=bin_file, id2img_fps=id2img_fps, device=device, model=model)

In [None]:
TOP_K = 1  # number of nearest keyframes to retrieve

# Vietnamese query, roughly: "A person is giving an interview. The wall behind
# this person is hung with many shark jaws."
text = 'Một người đang trả lời phỏng vấn. Bức tường phía sau người này được treo rất nhiều hàm răng cá mập.'

scores, idx_image, infos_query, image_paths = faiss_test.text_search(text=text, k=TOP_K)
faiss_test.show_images(image_paths)
print(scores)