In [1]:
import numpy as np
import json
import os
import os.path as op
import logging
from tqdm import tqdm
from numba import jit

In [44]:
class TSVFile(object):
    """Random-access reader for a tab-separated file using a line-offset index.

    The index lives next to the TSV file with the extension ``.lineidx`` and
    holds one integer per line: the position passed to ``seek`` to reach the
    start of that row. Both the TSV handle and the index are loaded lazily on
    first use.

    Args:
        tsv_file: path of the TSV file to read.
        generate_lineidx: when True and the ``.lineidx`` file does not exist,
            build it by scanning the TSV file once.
    """

    def __init__(self, tsv_file, generate_lineidx=False):
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
        self._fp = None
        self._lineidx = None
        # Remember which process opened the file. If the current pid differs
        # (e.g. after a fork), the file is re-opened in _ensure_tsv_opened.
        self.pid = None
        # Generate the lineidx file if it does not exist yet.
        if generate_lineidx and not op.isfile(self.lineidx):
            self._generate_lineidx_file(self.tsv_file, self.lineidx)

    def __del__(self):
        self.close()

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def close(self):
        """Close the underlying TSV handle if one is open (safe to call twice)."""
        # getattr guards against __del__ running on a half-constructed object.
        fp = getattr(self, '_fp', None)
        if fp is not None:
            fp.close()
            self._fp = None

    def num_rows(self):
        """Return the number of rows listed in the line index."""
        self._ensure_lineidx_loaded()
        return len(self._lineidx)

    def seek(self, idx):
        """Return row ``idx`` as a list of whitespace-stripped column strings.

        Raises whatever the index lookup raises (e.g. IndexError) after
        logging the file name and the offending index.
        """
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        try:
            pos = self._lineidx[idx]
        except Exception:
            logging.info('{}-{}'.format(self.tsv_file, idx))
            raise
        self._fp.seek(pos)
        return [s.strip() for s in self._fp.readline().split('\t')]

    def seek_first_column(self, idx):
        """Return only the first column of row ``idx`` without reading the whole row."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        pos = self._lineidx[idx]
        self._fp.seek(pos)
        return self._read_to_character(self._fp, '\t')

    def __getitem__(self, index):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    @staticmethod
    def _generate_lineidx_file(tsv_path, lineidx_path):
        """Write the byte offset of every line of ``tsv_path`` to ``lineidx_path``."""
        tmp_path = lineidx_path + '.tmp'
        offset = 0
        # Binary mode so that offsets are exact byte positions.
        with open(tsv_path, 'rb') as tsv_in, open(tmp_path, 'w') as idx_out:
            for line in tsv_in:
                idx_out.write('{}\n'.format(offset))
                offset += len(line)
        # Write to a temp file first so a partial index is never left behind.
        os.replace(tmp_path, lineidx_path)

    @staticmethod
    def _read_to_character(fp, delimiter, chunk_size=32):
        """Read from ``fp`` up to (not including) ``delimiter`` or end of file."""
        pieces = []
        while True:
            chunk = fp.read(chunk_size)
            if not chunk:
                break
            cut = chunk.find(delimiter)
            if cut >= 0:
                pieces.append(chunk[:cut])
                break
            pieces.append(chunk)
        return ''.join(pieces)

    def _ensure_lineidx_loaded(self):
        if self._lineidx is None:
            logging.info('loading lineidx: {}'.format(self.lineidx))
            with open(self.lineidx, 'r') as fp:
                # Skip blank lines (e.g. a trailing newline) instead of crashing in int().
                self._lineidx = [int(line) for line in fp if line.strip()]

    def _ensure_tsv_opened(self):
        current_pid = os.getpid()
        if self._fp is not None and self.pid == current_pid:
            return
        if self._fp is not None:
            logging.info('re-open {} because the process id changed'.format(self.tsv_file))
            # Release this process's copy of the stale handle before replacing it.
            self._fp.close()
        self._fp = open(self.tsv_file, 'r')
        self.pid = current_pid

def sort_objects(raw_objects, img_h=None, img_w=None, confidence_low=0.3, dimension_ratio=0.01):
    """Filter detected objects and rank the surviving classes.

    An object is kept when its confidence is at least ``confidence_low`` and
    its box covers at least ``dimension_ratio`` of the image area.

    Args:
        raw_objects: iterable of dicts with keys ``'class'``, ``'conf'`` and
            ``'rect'`` (four numbers; presumably [x1, y1, x2, y2] -- only the
            differences rect[2]-rect[0] and rect[3]-rect[1] are used).
        img_h: image height; required when ``raw_objects`` is non-empty.
        img_w: image width; required when ``raw_objects`` is non-empty.
        confidence_low: minimum confidence for an object to be kept.
        dimension_ratio: minimum box-area / image-area ratio to be kept.

    Returns:
        A pair ``(filtered_objects_list, filtered_sorted_objects)``:
        the class of every kept object in input order (with repetitions), and
        the distinct kept classes sorted by descending score, where
        score = 10 * (class count / total kept) + 10 * (largest area ratio).

    Raises:
        ValueError: if objects are given but the image size is missing or
            has a non-positive area.
    """
    raw_objects = list(raw_objects)
    if not raw_objects:
        return [], []
    if img_h is None or img_w is None:
        raise ValueError("img_h and img_w are required to compute object size ratios")
    image_area = img_h * img_w
    if image_area <= 0:
        raise ValueError("image area must be positive, got {}x{}".format(img_h, img_w))

    # class name -> [number of kept boxes, largest area ratio]; dicts keep
    # insertion order, which is what breaks ties in the final ranking.
    per_class = {}
    filtered_objects_list = []
    for obj in raw_objects:
        rect = obj['rect']
        area_ratio = ((rect[2] - rect[0]) * (rect[3] - rect[1])) / image_area
        if obj['conf'] < confidence_low or area_ratio < dimension_ratio:
            continue
        stats = per_class.setdefault(obj['class'], [0, area_ratio])
        stats[0] += 1
        stats[1] = max(stats[1], area_ratio)
        filtered_objects_list.append(obj['class'])

    total_counts = len(filtered_objects_list)
    scored = [
        (name, count / total_counts * 10 + max_ratio * 10)
        for name, (count, max_ratio) in per_class.items()
    ]
    # sorted() is stable, so equal scores keep first-seen order.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    filtered_sorted_objects = [name for name, _ in scored]

    return filtered_objects_list, filtered_sorted_objects

In [11]:
# Load the caption annotations (a JSON file) for the CC dataset.
annotation = "/data/home/zmykevin/vinvl_data/CC/model_0060000/annotations/0/dataset_cc.json"
# Use a context manager so the file handle is closed as soon as parsing ends.
with open(annotation, "r") as annotation_file:
    annotation_data = json.load(annotation_file)

In [45]:
# Collect, for every image id, the filtered object tags from each prediction chunk.
cc_objects_captions = {}
for chunk_id in range(12):
    print("Create the features for chunk {}".format(chunk_id))
    # Each chunk directory holds one predictions.tsv with the VinVL detections.
    chunk_dir = "/data/home/zmykevin/vinvl_data/CC/model_0060000/{}".format(chunk_id)
    predictions = TSVFile(os.path.join(chunk_dir, "predictions.tsv"))

    for row_idx in tqdm(range(predictions.num_rows())):
        row = predictions.seek(row_idx)
        # Column 0 is the image id, column 1 a JSON payload with the detections.
        img_id = int(row[0])
        detection_info = json.loads(row[1])
        ordered_tags, ranked_tags = sort_objects(
            detection_info['objects'],
            detection_info['image_h'],
            detection_info['image_w'],
        )
        cc_objects_captions[img_id] = {
            "objects": " ".join(ordered_tags),
            "objects_no_rep": " ".join(ranked_tags),
        }

  0%|          | 757/259957 [00:00<00:34, 7566.69it/s]

Create the features for chunk 0


100%|██████████| 259957/259957 [00:36<00:00, 7133.61it/s]
  0%|          | 0/259795 [00:00<?, ?it/s]

Create the features for chunk 1


100%|██████████| 259795/259795 [00:36<00:00, 7164.16it/s]
  0%|          | 0/259835 [00:00<?, ?it/s]

Create the features for chunk 2


100%|██████████| 259835/259835 [00:37<00:00, 6988.18it/s]
  0%|          | 0/259914 [00:00<?, ?it/s]

Create the features for chunk 3


100%|██████████| 259914/259914 [00:42<00:00, 6161.63it/s]
  0%|          | 0/259908 [00:00<?, ?it/s]

Create the features for chunk 4


100%|██████████| 259908/259908 [00:43<00:00, 5921.82it/s]
  0%|          | 0/259632 [00:00<?, ?it/s]

Create the features for chunk 5


100%|██████████| 259632/259632 [00:45<00:00, 5714.18it/s]
  0%|          | 0/259496 [00:00<?, ?it/s]

Create the features for chunk 6


100%|██████████| 259496/259496 [00:46<00:00, 5546.71it/s]


Create the features for chunk 7


100%|██████████| 259436/259436 [00:46<00:00, 5590.52it/s]
  0%|          | 0/259596 [00:00<?, ?it/s]

Create the features for chunk 8


100%|██████████| 259596/259596 [00:44<00:00, 5806.42it/s]


Create the features for chunk 9


100%|██████████| 259589/259589 [00:44<00:00, 5795.56it/s]
  0%|          | 0/259750 [00:00<?, ?it/s]

Create the features for chunk 10


100%|██████████| 259750/259750 [00:46<00:00, 5633.09it/s]
  0%|          | 0/259346 [00:00<?, ?it/s]

Create the features for chunk 11


100%|██████████| 259346/259346 [00:44<00:00, 5782.67it/s]


In [46]:
# Attach the first raw caption of every annotated image to its object record.
for image_entry in annotation_data['images']:
    image_key = image_entry['imgid']
    first_caption = image_entry['sentences'][0]['raw']
    record = cc_objects_captions.get(image_key)
    # Images without detections have no record and are skipped.
    if record is not None:
        record['caption'] = first_caption


In [46]:
# Persist the merged object/caption records to disk as JSON.
sorted_captions_path = "/data/home/zmykevin/vinvl_data/CC/cc_objects_captions_sorted.json"
with open(sorted_captions_path, "w") as out_file:
    json.dump(cc_objects_captions, out_file)

In [2]:
# Reload the saved records; after the JSON round trip the image-id keys are strings.
sorted_captions_path = "/data/home/zmykevin/vinvl_data/CC/cc_objects_captions_sorted.json"
with open(sorted_captions_path, "r") as in_file:
    cc_objects_captions = json.load(in_file)

print("finish loading the captions")

finish loading the captions


In [3]:
# Build three parallel lists (same order as the dict): image ids, the ranked
# object-tag string, and the caption of every record.
id_list = list(cc_objects_captions)
object_list = [record['objects_no_rep'] for record in cc_objects_captions.values()]
caption_list = [record['caption'] for record in cc_objects_captions.values()]

In [4]:
# Corpus for TF-IDF: all object-tag strings first, then all captions.
all_list = [*object_list, *caption_list]

In [5]:
from sklearn.feature_extraction.text import TfidfVectorizer
# Build a TF-IDF vectorizer with the library's default settings.
vectorizer = TfidfVectorizer()

# Fit the vocabulary on the combined corpus (object strings followed by
# captions) and transform it; row order of the result follows all_list.
tfidf_matrix = vectorizer.fit_transform(all_list)

In [6]:
# Load a previously saved tfidf_matrix instead of recomputing it (currently disabled)
# with open('/data/home/zmykevin/vinvl_data/CC/CC_tfidf.npy', 'rb') as f:
#     tfidf_matrix=np.load(f, allow_pickle=True)

In [7]:
# Sanity check: (number of documents in all_list, vocabulary size).
print(tfidf_matrix.shape)

(6232508, 46246)


In [8]:
# all_list was built as object_list followed by caption_list, so the first
# len(object_list) rows are the object documents and the rest are captions.
# Deriving the split point avoids a hard-coded row count that silently goes
# stale when the dataset changes.
num_object_rows = len(object_list)
assert tfidf_matrix.shape[0] == 2 * num_object_rows, "tfidf_matrix rows do not match object/caption lists"
object_tfidf = tfidf_matrix[:num_object_rows]
caption_tfidf = tfidf_matrix[num_object_rows:]

In [9]:
# del tfidf_matrix
from sklearn.metrics.pairwise import cosine_similarity
# from scipy import sparse
from scipy.sparse import csr_matrix
import sparse_dot_topn.sparse_dot_topn as ct


def awesome_cossim_top(A, B, ntop, lower_bound=0):
    """Multiply sparse ``A`` by sparse ``B`` via ``ct.sparse_dot_topn``.

    Args:
        A: sparse matrix with M rows.
        B: sparse matrix with N columns.
        ntop: passed to ``ct.sparse_dot_topn``; also sizes the output buffers
            at ``M * ntop`` entries (presumably the number of results kept
            per row -- confirm against the sparse_dot_topn documentation).
        lower_bound: passed through to ``ct.sparse_dot_topn`` (presumably a
            minimum value for a result to be kept -- confirm).

    Returns:
        A ``csr_matrix`` of shape (M, N) built from the buffers the extension
        fills in.
    """
    # Both operands are converted to CSR; this is free if they already are.
    left = A.tocsr()
    right = B.tocsr()
    n_rows = left.shape[0]
    n_cols = right.shape[1]

    index_type = np.int32
    capacity = n_rows * ntop

    # Output buffers, written in place by the extension call below.
    out_indptr = np.zeros(n_rows + 1, dtype=index_type)
    out_indices = np.zeros(capacity, dtype=index_type)
    out_data = np.zeros(capacity, dtype=left.dtype)

    ct.sparse_dot_topn(
        n_rows,
        n_cols,
        np.asarray(left.indptr, dtype=index_type),
        np.asarray(left.indices, dtype=index_type),
        left.data,
        np.asarray(right.indptr, dtype=index_type),
        np.asarray(right.indices, dtype=index_type),
        right.data,
        ntop,
        lower_bound,
        out_indptr,
        out_indices,
        out_data,
    )

    return csr_matrix((out_data, out_indices, out_indptr), shape=(n_rows, n_cols))

In [12]:
import time
import csv
# t1 = time.time()
# #cosine_sim = cosine_similarity(object_tfidf[:1000], caption_tfidf)
# cosine_sim = awesome_cossim_top(object_tfidf[:10], caption_tfidf.T, 10, 0.1)
# print(time.time() - t1)
def save_csv(visualization, output_path):
    """Write the retrieval preview rows to ``output_path`` as CSV.

    Args:
        visualization: iterable of dicts whose keys are the six column names
            listed in ``csv_columns`` below.
        output_path: destination file path; overwritten if it exists.

    Raises:
        ValueError: from ``csv.DictWriter`` if a row has a key that is not a
            known column.
    """
    csv_columns = ['original_id', "original_obj", "original_caption", "retrieved_1", "retrieved_2", "retrieved_3"]
    # newline="" lets the csv module control line endings itself, as its
    # documentation requires; otherwise some platforms emit blank rows.
    with open(output_path, "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
        writer.writeheader()
        writer.writerows(visualization)
    
def save_json(closest_captions, output_path):
    """Serialize ``closest_captions`` as JSON into the file at ``output_path``."""
    with open(output_path, "w") as json_file:
        json.dump(closest_captions, json_file)
    
def load_csv(csv_path):
    """Read a comma-separated file into a list of dicts keyed by the header row.

    Args:
        csv_path: path of the CSV file; its first row is taken as the header.

    Returns:
        One dict per data row mapping header name -> cell string. An empty
        file yields an empty list.

    Raises:
        IndexError: if a data row has more cells than the header.
    """
    loaded_data = []
    # newline="" is required by the csv module so that line breaks inside
    # quoted fields are returned exactly as they were written.
    with open(csv_path, "r", newline="") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        header_info = next(csv_reader, None)
        if header_info is None:
            return loaded_data
        for row in csv_reader:
            loaded_data.append({header_info[j]: cell for j, cell in enumerate(row)})
    return loaded_data

batch_size = 1000
save_checkpoint = 100000
top_k = 10  # number of nearest captions requested per object document
output_path = "/data/home/zmykevin/vinvl_data/CC"

# Resume from the results of a previous run.
with open(os.path.join(output_path, "captions_retrieved.json"), "r") as previous_json:
    closest_captions = json.load(previous_json)
visualization = load_csv(os.path.join(output_path, "captions_retrieved.csv"))
# NOTE(review): this assumes one CSV row per processed index. Rows that landed
# in error_ids were never written, so after a run with errors the resume point
# is too early by that many rows -- confirm before relying on it.
starting_index = len(visualization)
error_ids = []

num_captions = caption_tfidf.shape[0]
# Transpose and convert to CSR once; doing it inside the loop repeats the same
# conversion of the full caption matrix for every batch.
caption_tfidf_t = caption_tfidf.T.tocsr()

# Start the checkpoint counter at the resume point so a resumed run does not
# write a checkpoint after every batch until the counter catches up.
checkpoint_index = starting_index // save_checkpoint
for i in tqdm(range(starting_index, num_captions, batch_size)):
    valid_batch_size = min(batch_size, num_captions - i)
    cosine_sim = awesome_cossim_top(object_tfidf[i:i+valid_batch_size], caption_tfidf_t, top_k, -1)
    # Read each row's matches through the CSR row pointers. Flattening with
    # nonzero() and slicing index*10 assumes every row has exactly 10 hits; a
    # single shorter row shifts the matches of all later rows in the batch.
    row_ptr = cosine_sim.indptr
    col_idx = cosine_sim.indices

    for index in range(valid_batch_size):
        original_id = id_list[index + i]
        try:
            right_side = [id_list[x] for x in col_idx[row_ptr[index]:row_ptr[index + 1]]]
            current_output = {}
            current_output['original_id'] = original_id
            current_output['original_obj'] = cc_objects_captions[original_id]['objects_no_rep']
            current_output['original_caption'] = cc_objects_captions[original_id]['caption']
            current_output['retrieved_1'] = cc_objects_captions[right_side[0]]['caption']
            current_output['retrieved_2'] = cc_objects_captions[right_side[1]]['caption']
            current_output['retrieved_3'] = cc_objects_captions[right_side[2]]['caption']
            # The full list of retrieved ids goes to the JSON output.
            closest_captions[original_id] = right_side
            visualization.append(current_output)
        except (IndexError, KeyError):
            # Fewer than three matches or a missing record: note the id and go
            # on. A bare except here would also swallow KeyboardInterrupt.
            print("The bug occurs at {}".format(i+index))
            print("The error id is: {}".format(original_id))
            error_ids.append(original_id)
    if (i + valid_batch_size) // save_checkpoint > checkpoint_index:
        save_json(closest_captions, os.path.join(output_path, "captions_retrieved_backup.json"))
        save_csv(visualization, os.path.join(output_path, "captions_retrieved_backup.csv"))
        save_json(error_ids, os.path.join(output_path, "captions_retrieved_error_ids.json"))
        checkpoint_index += 1

# Store the final results as JSON and CSV.
save_json(closest_captions, os.path.join(output_path, "captions_retrieved.json"))
save_csv(visualization, os.path.join(output_path, "captions_retrieved.csv"))
save_json(error_ids, os.path.join(output_path, "captions_retrieved_error_ids.json"))

  0%|          | 4/2217 [01:50<16:59:15, 27.63s/it]


KeyboardInterrupt: 

In [None]:
# Inspect the shape of the most recently computed similarity batch.
print(cosine_sim.shape)
    

In [32]:
# Look up the image id stored at one specific row position of id_list
# (row 3110513 was picked by hand for a spot check).
retrieved_image_id = id_list[3110513]
print(retrieved_image_id)

3327912


In [33]:
# Compare the record of image "3" with the record of the retrieved image.
query_record = cc_objects_captions["3"]
print(query_record)
retrieved_record = cc_objects_captions[retrieved_image_id]
print(retrieved_record)

{'objects': 'fireplace floor ceiling shadow letter wall wall building light ledge fire room bar log brick wall wall brick wall desk light sign word letter wall bench tile', 'objects_no_rep': 'floor ceiling fireplace building room bar log desk brick word shadow ledge fire sign tile letter bench wall light', 'caption': 'interior design of modern living room with fireplace in a new house'}
{'objects': 'fire fire fire fire fire log fireplace sky fire letter fire wire word rock log cable letter rock', 'objects_no_rep': 'fireplace log fire cable letter sky rock wire word', 'caption': 'log on fire in a fireplace'}


0
{'objects': 'bus bus bus bus building bus bus building shirt man shirt person road shirt person person person person man car man man person man bus person bus door person person person person pant person bus person window window person street person window man bus shirt man roof man man shirt person person boat hat balcony van pant bus sign roof', 'objects_no_rep': 'pant street building shirt man person hat car roof sign door van window road balcony boat bus', 'caption': 'a very typical bus station'}
1
{'objects': 'drum drum hair microphone dress man drum person woman face shoe shirt person leg hand drum wall shoe woman skirt girl shirt leg poster dress woman leg person arm person picture person tripod neck top', 'objects_no_rep': 'man neck wall woman shirt skirt face top shoe arm person poster hand girl dress tripod drum hair leg picture microphone', 'caption': 'sierra looked stunning in this top and this skirt while performing with person at their former university'}
2
{'objects': 

1746
{'objects': 'person person person person person person stadium shoe person person person shirt shirt field sign person grass ceiling shoe person person man person man flag person light uniform person man person sign person person man short man person seat shoe sign short man', 'objects_no_rep': 'ceiling flag field shirt person man sign stadium grass uniform seat shoe short light', 'caption': 'the team celebrate after the match .'}
1747
{'objects': 'glasses glasses man man shirt ear shirt pant clothes rug beard head hair purse watch face hand shelf hand purse face watch floor mouth shirt wrist purse ear chair nose tattoo arm sign nose shelf shirt head arm shelf shelf', 'objects_no_rep': 'man tattoo glasses mouth wrist pant shirt face watch chair beard purse arm floor hand ear nose sign shelf clothes head rug hair', 'caption': 'fashion business : we respect how all people live'}
1748
{'objects': 'suit dress suit hair grass flower hair man flower tree tree woman tree girl head tree s

{'objects': 'flag pole flag pole building sky building light roof tree sign wall word window letter letter window cloud', 'objects_no_rep': 'flag building tree word roof sign letter pole window sky wall cloud light', 'caption': 'building may move to another building'}
3391
{'objects': 'hand arm person finger arm background finger finger finger finger finger finger finger fingernail', 'objects_no_rep': 'person hand background finger fingernail arm', 'caption': 'hand presenting business people videos in his hands in the darkness'}
3392
{'objects': 'tree man boot boot man jean pant man man jean shirt scarf stroller woman wheel woman man window person woman jacket wall building jacket jacket woman jean wheel woman shoe person chair man paper wall vest sign wall person jacket man sign person wheel person suit jean jean shirt bottle car jacket sweater shoe person boot hair shirt paper woman sidewalk pant pant person shoe sweater sweater', 'objects_no_rep': 'scarf man sweater vest jean wall w

{'objects': 'building window window window window window umbrella building building umbrella sky window umbrella van building building roof window window table window sign window house window window building fence building window tree awning motorcycle window window tire trash can window window sign pole window flag window road window person building street window bench car window window road window window van window sidewalk roof house window window sign window house flag umbrella roof light building wall bike', 'objects_no_rep': 'bench pole wall motorcycle house bike street flag awning car roof sidewalk window road building tire person sign van sky light umbrella tree trash can fence table', 'caption': 'main square or the city center'}
5161
{'objects': 'uniform short man man man stripe man jersey hair jersey pant short man hair number number hand jersey man uniform shirt jersey hair short shirt short man hair hair shirt face uniform head hand stripe arm head shirt logo belt hand hand

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)

