# Set up

Use APTM ipykernel

## Define Path

In [2]:
import os, sys
from pathlib import Path
ROOT_PATH = Path('../../paper_clones/APTM').resolve()
sys.path.append(str(ROOT_PATH))
IMAGE_PATH = Path('../../DATASET/CUHK-PEDES/imgs').resolve()
sys.path.append(str(IMAGE_PATH))
MODEL_PATH = ROOT_PATH/'MODEL'/'ft_cuhk'/'checkpoint_best.pth'
ANNO_PATH = ROOT_PATH/'data/finetune'

## Import libarires

In [3]:
import torch; print(torch.cuda.is_available())
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import yaml
import torch.nn.functional as F
from easydict import EasyDict as edict
from torchinfo import summary
import argparse
import os
import math
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
from prettytable import PrettyTable
import torch
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import PIL
import torch.backends.cudnn as cudnn
import torch.distributed as dist
#### LOCAL MODULES
from models import aptm, bert, model_retrieval, swin_transformer, tokenization_bert
from models.model_retrieval import APTM_Retrieval
from models.tokenization_bert import BertTokenizer
import utils
from dataset import create_dataset, create_sampler, create_loader
from dataset.re_dataset import TextMaskingGenerator
from scheduler import create_scheduler
from optim import create_optimizer
from trains import train, train_attr
from train_pa100ks import train_pa100k, train_pa100k_only_img_classifier
from reTools import evaluation, mAP
from reTools import evaluation_attr, itm_eval_attr
from reTools import evaluation_attr_only_img_classifier, itm_eval_attr_only_img_classifier

  from .autonotebook import tqdm as notebook_tqdm


True
/home/jovyan/workspace/BA-PRE_THESIS/REPORT


2023-10-29 11:48:55.585492: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
#FOR VECTOR DATABASE
import pymilvus
from pymilvus import MilvusClient, Collection
from pymilvus import connections, utility
from pymilvus import Collection, DataType, FieldSchema, CollectionSchema
from collections import defaultdict
from tqdm import tqdm
import torch
import numpy as np
import pandas as pd
import json
from matplotlib import pyplot as plt
from PIL import Image
import gradio as gr

In [5]:
#ZILLIZ SERVICE
MILVUS_URI = "https://in01-e512c229cfb7739.aws-us-west-2.vectordb.zillizcloud.com:19531"
TOKEN = "db_admin:HelloKitty!"
USER = "db_admin"
PASSWORD = "HelloKitty!"
pymilvus.connections.connect(
    "default",
    uri=MILVUS_URI,
    user=USER,
    password=PASSWORD
)

## Load model config for finetune on CUHK-PEDES

In [6]:
config_reader = lambda file: yaml.load(open(file, 'r'), Loader=yaml.Loader)
config_path = ROOT_PATH / "configs"
retrieval_cuhk_config = config_reader(config_path / "Retrieval_cuhk.yaml")
retrieval_pa100k_config = config_reader(config_path / "Retrieval_pa100k.yaml")
config = retrieval_cuhk_config
config['vision_config'] = config_path / 'config_swinB_384.json'
config['text_encoder'] = 'bert-base-uncased'
config['text_config'] = config_path / 'config_bert.json'

## Load dataset config

In [7]:
config['image_root'] = IMAGE_PATH
config['test_file'] = ANNO_PATH/'cuhk_test.json'
config['val_file'] = ANNO_PATH/'cuhk_val.json'
config['train_file'] = [ANNO_PATH/'cuhk_train.json']

# Define class APTM_Inferencer

In [8]:
class APTMInference(APTM_Retrieval):
    def __init__(self, config):
        super().__init__(config)
        self.config_dict = config
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.cuhk_norm = transforms.Normalize((0.38901278, 0.3651612, 0.34836376), (0.24344306, 0.23738699, 0.23368555))
        self.test_transform = transforms.Compose([
            transforms.Resize((config['h'], config['w']), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.cuhk_norm,
        ])

    def encode_image(self, image, device='cuda'):
        self.vision_encoder.to(device)
        self.vision_proj.to(device)
        # output embedding sequence, dim = 1024
        image_embs_sequence = self.vision_encoder(image.to(device))
        # global feature from CLS
        image_cls_feature = self.vision_proj(image_embs_sequence[:, 0, :])
        image_cls_feature = F.normalize(image_cls_feature, dim=-1)
        
        # self.vision_encoder.to('cpu')
        # self.vision_proj.to('cpu')
        
        return {
            "sequence": image_embs_sequence,
            "cls_feature": image_cls_feature
        }

    def tokenize_text(self, texts):
        text_model_input = self.tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self.config_dict["max_tokens"],
            return_tensors="pt",
        )
        return text_model_input

    def encode_text(self, text: str, device='cuda'):
        self.text_encoder.bert.to(device)
        self.text_proj.to(device)
        # tokenization
        text_model_input = self.tokenize_text(text).to(device)
        text_attention_mask = text_model_input.attention_mask
        # output embedding sequence, dim = 768
        text_embs_sequence = self.text_encoder.bert(
            text_model_input.input_ids,
            attention_mask=text_model_input.attention_mask,
            return_dict=True,
            mode="text",
        ).last_hidden_state
        # cls embedding
        text_cls_feature = self.text_proj(text_embs_sequence[:, 0, :])
        text_cls_feature = F.normalize(text_cls_feature, dim=-1)
        ########################## Get out of GPU
        # self.text_encoder.bert.to('cpu')
        self.text_proj.to('cpu')
        return {
            "sequence": text_embs_sequence,
            "attention_mask": text_attention_mask,
            "cls_feature": text_cls_feature,
        }

    def encode_cross_modal(self, image_sequence, text_sequence, text_attention_mask, device='cuda'):
        # image attention mask
        image_attention_mask = torch.ones(
            image_sequence.size()[:-1], dtype=torch.long
        ).to(device)
        # output embedding
        cross_sequence = self.text_encoder.bert.to(device)(
            encoder_embeds=text_sequence,
            attention_mask=text_attention_mask,
            encoder_hidden_states=image_sequence,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
            mode="fusion",
        ).last_hidden_state
        # cls token
        cross_cls_feature = cross_sequence[:, 0, :]

        # self.text_encoder.bert.to('cpu')
        return {"sequence": cross_sequence, "cls_feature": cross_cls_feature}

    def check_matching(self, cross_cls_feature, device='cuda'):
        logits = self.itm_head.to(device)(cross_cls_feature.to(device))
        probs = torch.nn.functional.softmax(logits, dim=-1)
        class_res = probs.argmax(dim=-1)
        # self.itm_head.to('cpu')
        return {'logits': logits, 'probs': probs, 'class': class_res}
    
    def read_image(self, img_path):
        img = PIL.Image.open(img_path).convert("RGB")
        return self.test_transform(img)

#     def end2end_filter(self, image_sequence_dict, text_dict):
#         image_infer_result = self.encode_image(image)
#         text_infer_result = 

In [9]:
model = APTMInference(retrieval_cuhk_config)
model.load_pretrained(
    MODEL_PATH, config, is_eval=False
)

use_swin
### Loading pretrained vision encoder
### Loading pretrained text encoder
load checkpoint from /home/jovyan/workspace/BA-PRE_THESIS/paper_clones/APTM/MODEL/ft_cuhk/checkpoint_best.pth
missing_keys:  []
vision_encoder missing_keys:  []
unexpected_keys:  []


# Create dataset, dataloader

In [10]:
config['batch_size_test'] = 16
train_dataset, val_dataset, test_dataset = create_dataset('re_cuhk', config)

In [11]:
## samplers = None: not use distributed mode
train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset], 
                                                      samplers=[None, None, None], 
                                                      batch_size=[config['batch_size_train']] + [
                                                          config['batch_size_test']] * 2,
                                                      num_workers=[4, 4, 4], is_trains=[True, False, False],
                                                      collate_fns=[None, None, None])
len(test_dataset), len(test_loader)

(3074, 193)

# Class DatabaseBuilder (Zilliz-Milvus)

__If haven't built database, run these code__

__Else go to class VectorSearcher__

In [12]:
class DatabaseBuilder:
    def __init__(self, collection_name, data, drop=False):
        # if drop: drop existing database
        check_collection = utility.has_collection(collection_name) if drop else None
        drop_result = utility.drop_collection(collection_name) if check_collection else None
        self.data = data
        self.model_id = model
        self.vector_name = 'vector'
        self.schema = self.build_schema()
        self.index_params = {
            "index_type": "AUTOINDEX",
            "metric_type": "IP",
            "params": {}
        }

        self.collection = Collection(
            name=collection_name,
            schema=self.schema
        )

    def build_index(self):
        self.collection.flush()
        self.collection.create_index(
            field_name=self.vector_name,
            index_params=self.index_params
        )
        self.collection.load()

    def build_schema(self, vector_dim=256):
        doc_id = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, description="doc_id")
        p_ID = FieldSchema(name="p_ID", dtype=DataType.INT16, description="p_ID")
        image_path = FieldSchema(name="image_path", dtype=DataType.VARCHAR, description="image_path", max_length=100)
        vector = FieldSchema(name=self.vector_name, dtype=DataType.FLOAT_VECTOR, dim=vector_dim)
        schema = CollectionSchema(
            fields=[doc_id, p_ID, image_path, vector],
            auto_id=False,
            description="Demo collection"
        )
        return schema

    def make_entities(self, start_index, num_batch = 30000):
        tempdata = self.data[start_index: start_index + num_batch]
        entities = [tempdata['id'].astype('int64'), tempdata['p_ID'].astype('int16'), tempdata['image_path']]
        vectors = tempdata['vector']
        entities.append(vectors)
        entities = [x.to_numpy() for x in entities]
        return entities

    def build_database(self):
        cur_index = 0
        num_batch = 10000
        while cur_index < len(self.data):
            entities = self.make_entities(cur_index, num_batch)
            self.collection.insert(entities)
            cur_index += num_batch
        self.build_index()

## Inference to get image embeddings

__If have built database, don't run these code__

In [15]:
person_ID = [x['image_id'] for x in test_dataset.ann]
image_path = [str(IMAGE_PATH/x['image']) for x in test_dataset.ann]
vector_matrix = torch.empty(len(image_path), 256)
print(vector_matrix.shape)
for image, index in test_loader:
    img_emb = model.encode_image(image)['cls_feature'].cpu().detach()
    vector_matrix[index] = img_emb
dataset_dict = {'vector': list(vector_matrix.numpy()), 'p_ID': person_ID, 'image_path': image_path, 'id': list(range(0, len(person_ID)))}
data_table = pd.DataFrame(dataset_dict)
data_table.head(5)

torch.Size([3074, 256])


Unnamed: 0,vector,p_ID,image_path,id
0,"[0.0050568427, 0.067113854, -0.06266234, 0.029...",12004,/home/jovyan/workspace/BA-PRE_THESIS/DATASET/C...,0
1,"[0.024172317, 0.06336201, -0.005884793, 0.0051...",12004,/home/jovyan/workspace/BA-PRE_THESIS/DATASET/C...,1
2,"[-0.0137637295, -0.015836067, -0.057155482, 0....",12004,/home/jovyan/workspace/BA-PRE_THESIS/DATASET/C...,2
3,"[0.021615053, 0.09207693, 0.05455772, -0.04414...",12005,/home/jovyan/workspace/BA-PRE_THESIS/DATASET/C...,3
4,"[-0.028524911, 0.024824591, 0.051545624, 0.027...",12005,/home/jovyan/workspace/BA-PRE_THESIS/DATASET/C...,4


In [116]:
person_ID = [x['image_id'] for x in val_dataset.ann]
image_path = [str(IMAGE_PATH/x['image']) for x in val_dataset.ann]
vector_matrix = torch.empty(len(image_path), 256)
for image, index in val_loader:
    img_emb = model.encode_image(image)['cls_feature'].cpu().detach()
    vector_matrix[index] = img_emb
dataset_dict_val = {'vector': list(vector_matrix.numpy()), 'p_ID': person_ID, 'image_path': image_path, 'id': list(range(100000,100000 + len(person_ID)))}
data_table_val = pd.DataFrame(dataset_dict_val)
data_table_val.head(5)

Unnamed: 0,vector,p_ID,image_path,id
0,"[-0.008985316, 0.030980144, -0.08232519, -0.05...",11004,/home/jovyan/workspace/BA-PRE_THESIS/DATASET/C...,100000
1,"[-0.042043097, -0.013589532, -0.02761598, -0.0...",11004,/home/jovyan/workspace/BA-PRE_THESIS/DATASET/C...,100001
2,"[-0.02434203, 0.054291096, -0.08493884, -0.073...",11004,/home/jovyan/workspace/BA-PRE_THESIS/DATASET/C...,100002
3,"[0.0067673596, 0.018824464, -0.038445998, -0.1...",11004,/home/jovyan/workspace/BA-PRE_THESIS/DATASET/C...,100003
4,"[-0.02271693, 0.0028380656, -0.12827235, -0.02...",11005,/home/jovyan/workspace/BA-PRE_THESIS/DATASET/C...,100004


## Build database & Insert

__IF HAVEN'T BUILD DATABASE, CHANGE REBUILD_FLAD TO TRUE__

In [44]:
demo_collection = 'DEMO_APTM_CUHK_2910'

In [45]:
REBUILD_FLAG = True

In [117]:
if REBUILD_FLAG: #
    db_builder = DatabaseBuilder(data=data_table, collection_name=demo_collection, drop=True) #CREATE DATABASE BUILDER
    db_builder.build_database() #INSERT TEST

In [118]:
if REBUILD_FLAG:
    db_builder = DatabaseBuilder(data=data_table_val, collection_name=demo_collection, drop=False)
    db_builder.build_database() #INSERT VAL

# Class VectorSearcher

In [103]:
import PIL
class VectorSearcher:
    def __init__(self, model, client=None, collection_name='DEMO_APTM_CUHK'):
        self.model = model
        if client is None:
            self.client = MilvusClient(
                uri=MILVUS_URI,
                token=TOKEN
            )
        else:
            self.client = client
        self.collection_name = collection_name

    def search(self, texts):
        query_infer_dict = self.model.encode_text(texts)
        query_emb = query_infer_dict['cls_feature']
        #search with batchsize = len(texts)
        res = self.client.search(
            collection_name = self.collection_name,
            data = query_emb.detach().cpu().numpy(),
            limit=100,
            output_fields = ['p_ID', 'image_path', 'id', 'vector'] 
        )
        return {'text_infer_dict': query_infer_dict, 'result': res}

    def filter(self, text, temp_search_result=None, topk=50):
        # this function is applied for 1 text
        # if has found candidates, don't search again
        if temp_search_result is None:
            res = self.search(text)
        else:
            res = temp_search_result
        # Get text query embeddings
        text_infer_dict = res['text_infer_dict']
        text_sequence = text_infer_dict['sequence']
        text_attention_mask = text_infer_dict['attention_mask']
        # Get candidates image path
        result = res['result'][0][:topk]
        image_paths = [x['entity']['image_path'] for x in result]
        final_result_list = []
        probs_result = []
        # Infer by cross encoder
        for img_path in image_paths:
            # Get image sequence again
            img = self.model.read_image(img_path).reshape(1, 3, 384, 128)
            img_infer_dict = self.model.encode_image(img)
            img_sequence = img_infer_dict['sequence']
            # Get cross embeddings 
            cross_emb_cls = self.model.encode_cross_modal(img_sequence, text_sequence, text_attention_mask)['cls_feature']
            cross_dict_result = self.model.check_matching(cross_emb_cls)
            prob = cross_dict_result['probs'][0][1] # batchsize 1 -> index 0 -> class 1 = Match
            # if prob < 0.2:
            #     continue
            final_result_list.append(img_path)
            probs_result.append(prob.cpu().detach().numpy())
        sort_index = np.argsort(np.array(probs_result))[::-1]
        final_result_list = np.array(final_result_list)[sort_index]
        probs_result = np.array(probs_result)[sort_index]
        return {
            "image_path": final_result_list,
            "probs": probs_result
        }

In [104]:
vsearch = VectorSearcher(model, collection_name=demo_collection)

# GUI

In [119]:
import gradio as gr
import PIL

In [120]:
TEMP_SEARCH_RESULT = None

In [133]:
import random, time
def generate_random():
    curr_time = round(time.time()*1000)
    random.seed(curr_time)
    k1 = random.randint(0, 2)
    # k1 = 0
    if k1 == 1:
        k = random.randint(0, len(test_dataset))
        demo_text = test_dataset.ann[k]['caption'][0]
        demo_img = PIL.Image.open(IMAGE_PATH/test_dataset.ann[k]['image']).convert('RGB').resize((124,384))
    else:
        k = random.randint(0, len(val_dataset))
        demo_text = val_dataset.ann[k]['caption'][0]
        demo_img = PIL.Image.open(IMAGE_PATH/val_dataset.ann[k]['image']).convert('RGB').resize((124,384))
    return demo_text, demo_img

In [134]:
def gui_candidate_process(texts):
    if type(texts) != list: 
        texts = [texts]
    res = vsearch.search(texts)['result'][0]
    TEMP_SEARCH_RESULT = res
    images = [PIL.Image.open(x['entity']['image_path']) for x in res]
    distance = [x['distance'] for x in res]
    return list(zip(images, distance))

In [135]:
def gui_filter_process(text):
    res = vsearch.filter([text], TEMP_SEARCH_RESULT)
    img_paths= res['image_path']
    probs = [str(x) for x in res['probs']]
    images = [PIL.Image.open(img_path) for img_path in img_paths]
    return list(zip(images, probs))

In [137]:
with gr.Blocks() as demo:
    with gr.Row():
        summit_button_0 = gr.Button(value="Click to generate random query from test set")
        random_text = gr.Textbox(label="Caption of random image")
    # with gr.Row():
        random_image = gr.Image(type="pil", height=384, width=184)
        
    with gr.Row():
        query_1 = gr.Textbox(label="Text query")
        summit_button_1 = gr.Button(value="Ranking candidates by Cosine Similarity")
    with gr.Column(variant="panel"):
        with gr.Row():
            gallery_1 = gr.Gallery(
                label="Found images", show_label=False, elem_id="gallery",
                    columns=[3], rows=[2], object_fit="contain", height="auto")
    with gr.Row():
        query_2 = gr.Textbox(label="Text query")
        summit_button_2 = gr.Button(value="Re-ranking by Image-Text Matching probabilities")
    with gr.Column(variant="panel"):
        with gr.Row():
            gallery_2 = gr.Gallery(
                label="Found images", show_label=False, elem_id="gallery",
                    columns=[3], rows=[2], object_fit="contain", height="auto")
    summit_button_0.click(generate_random, inputs=None, outputs=[random_text, random_image])
    summit_button_1.click(gui_candidate_process, inputs=query_1, outputs=gallery_1)
    summit_button_2.click(gui_filter_process, inputs=query_2, outputs=gallery_2)

demo.launch(inline=True, share=True)

Running on local URL:  http://127.0.0.1:7878
Running on public URL: https://1e79423820ce05e74c.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


