In [1]:
import numpy as np
import lmdb
import os
import pickle
import torch
from tqdm import tqdm
import json
# from mmf.utils.file_io import PathManager
from iopath.common.file_io import PathManager as pm

PathManager = pm()

In [2]:
class PaddedFasterRCNNFeatureReader:
    def __init__(self, max_loc):
        self.max_loc = max_loc
        self.first = True
        self.take_item = False

    def _load(self, image_feat_path):
        image_info = {}
        image_info["features"] = load_feat(image_feat_path)

        info_path = "{}_info.npy".format(image_feat_path.split(".npy")[0])
        if PathManager.exists(info_path):
            image_info.update(load_feat(info_path).item())

        return image_info

    def read(self, image_feat_path):
        image_info = self._load(image_feat_path)
        if self.first:
            self.first = False
            if (
                image_info["features"].size == 1
                and "image_feat" in image_info["features"].item()
            ):
                self.take_item = True

        image_feature = image_info["features"]
#         print(image_info["features"].size)
        if self.take_item:
            item = image_info["features"].item()
            if "image_text" in item:
                image_info["image_text"] = item["image_text"]
                image_info["is_ocr"] = item["image_bbox_source"]
                image_feature = item["image_feat"]

            if "info" in item:
                if "image_text" in item["info"]:
                    image_info.update(item["info"])
                image_feature = item["feature"]

        # Handle case of features with class probs
        if (
            image_info["features"].size == 1
            and "features" in image_info["features"].item()
        ):
            item = image_info["features"].item()
            image_feature = item["features"]
            image_info["image_height"] = item["image_height"]
            image_info["image_width"] = item["image_width"]

            # Resize these to self.max_loc
            image_loc, _ = image_feature.shape
            image_info["cls_prob"] = np.zeros(
                (self.max_loc, item["cls_prob"].shape[1]), dtype=np.float32
            )
            image_info["cls_prob"][0:image_loc,] = item["cls_prob"][: self.max_loc, :]
            image_info["bbox"] = np.zeros(
                (self.max_loc, item["bbox"].shape[1]), dtype=np.float32
            )
            image_info["bbox"][0:image_loc,] = item["bbox"][: self.max_loc, :]
            image_info["num_boxes"] = item["num_boxes"]

        # Handle the case of ResNet152 features
        if len(image_feature.shape) > 2:
            shape = image_feature.shape
            image_feature = image_feature.reshape(-1, shape[-1])

        image_loc, image_dim = image_feature.shape
        tmp_image_feat = np.zeros((self.max_loc, image_dim), dtype=np.float32)
        tmp_image_feat[0:image_loc,] = image_feature[: self.max_loc, :]  # noqa
        image_feature = torch.from_numpy(tmp_image_feat)

        del image_info["features"]
        image_info["max_features"] = torch.tensor(image_loc, dtype=torch.long)
        return image_feature, image_info


class LMDBFeatureReader(PaddedFasterRCNNFeatureReader):
    def __init__(self, max_loc, base_path):
        super().__init__(max_loc)
        self.db_path = base_path

        if not PathManager.exists(self.db_path):
            raise RuntimeError(
                "{} path specified for LMDB features doesn't exists.".format(
                    self.db_path
                )
            )
        self.env = None

    def _init_db(self):
        self.env = lmdb.open(
            self.db_path,
            subdir=os.path.isdir(self.db_path),
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        print("initialize_db")
        with self.env.begin(write=False, buffers=True) as txn:
            #print("fuck")
            self.image_ids = pickle.loads(txn.get(b"keys"))
            self.image_id_indices = {
                self.image_ids[i]: i for i in range(0, len(self.image_ids))
            }
            #print(self.image_id_indices)

    def _load(self, image_file_path):
        #print("env is: {}".format(self.env))
        if self.env is None:
            self._init_db()
        split = os.path.relpath(image_file_path, self.db_path).split(".npy")[0]

        try:
            image_id = int(split.split("_")[-1])
            # Try fetching to see if it actually exists otherwise fall back to
            # default
            img_id_idx = self.image_id_indices[str(image_id).encode()]
        except (ValueError, KeyError):
            # The image id is complex or involves folder, use it directly
            image_id = str(split).encode()
            img_id_idx = self.image_id_indices[image_id]

        with self.env.begin(write=False, buffers=True) as txn:
            image_info = pickle.loads(txn.get(self.image_ids[img_id_idx]))

        return image_info

In [3]:
# data_path = "/fsx/zmykevin/data/mmf_data/datasets/cc/defaults/annotations/train_vinvl_2.npy"
# val_annotation = np.load(data_path, allow_pickle=True)
ann_data_path = "/home/zmykevin/fb_intern/data/mmf_data/datasets/cc/defaults/annotations/train_vinvl_bookcorpus_retrieved_sorted_all_nps_9.npy"
annotations = np.load(ann_data_path, allow_pickle=True)
# annotations = []
# with open(ann_data_path, "r") as f:
#     for line in f:
#         annotations.append(json.loads(line))
#print(annotations[0])

In [4]:
base_path = "/home/zmykevin/fb_intern/data/mmf_data/datasets/cc/defaults/features/lmdbs/cc_vinvl_train_10.lmdb"
#Load the Image Feature
feature_reader = LMDBFeatureReader(100, base_path)

In [5]:
with open("/home/zmykevin/fb_intern/data/mingyang_data/CC/cc_objects_id.json", "r") as f:
    object_ids = json.load(f)

In [None]:
bad_count = 0
vinvl_annotations = []
for x in tqdm(annotations):
    try:
        img_id = x['image_id']
        img_feat_path = "/home/zmykevin/fb_intern/data/mmf_data/datasets/cc/defaults/features/lmdbs/cc_vinvl_train_9.lmdb/cc_{}".format(img_id)
        
#         img_feat_path = "/fsx/zmykevin/data/mmf_data/datasets/flickr30k/defaults/features/vinvl_detectron.lmdb/flick30k_{}.npy".format(img_id)
        #read the img_id
        feat, info= feature_reader.read(img_feat_path)
        y = x.copy()
        current_object_ids = [object_ids[x] for x in y["objects"]]
        y["objects_ids"] = current_object_ids
        vinvl_annotations.append(y)

    except:
        continue
#         print("this img_id does not work: {}".format(x["image_id"]))

    #break
#print(bad_count)
print(len(vinvl_annotations))

  0%|                                                                                                                                                                                                   | 1/575457 [00:00<20:29:52,  7.80it/s]

initialize_db


 49%|██████████████████████████████████████████████████████████████████████████████████████████████▎                                                                                                 | 282555/575457 [01:03<11:59, 407.27it/s]

In [20]:
with open("/home/zmykevin/fb_intern/data/mmf_data/datasets/cc/defaults/annotations/train_vinvl_bc_retrieved_10.npy", "wb") as f:
    np.save(f, vinvl_annotations)

In [43]:
print(len(vinvl_annotations))
#print(info)

with open("/fsx/zmykevin/data/mmf_data/datasets/visual_entailment/defaults/annotations/snli_ve_vinvl_test.jsonl", "w") as f:
    for ann in vinvl_annotations:
        json.dump(ann, f)
        f.write('\n')

17188


In [58]:
print(type(feature_reader.image_ids))

AttributeError: 'LMDBFeatureReader' object has no attribute 'image_ids'