# Region Grafts: Datalaoder-Based Relation Augmentation

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from os import environ

environ['DATA_DIR_VG_RCNN'] = '/home/zhanwen/datasets'

In [3]:
from maskrcnn_benchmark.modeling.detector import build_detection_model

# Walk through the dataset. For each label, write down the list of the index of images that contain them.

In [4]:
from torch import manual_seed as torch_manual_seed
import random
import numpy as np

from torch.cuda import max_memory_allocated, set_device, manual_seed_all
from torch.backends import cudnn

def setup_seed(seed):
    torch_manual_seed(seed)
    manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.deterministic = True
    
setup_seed(1234)

In [5]:
# Set args
# PRETRAINED
from maskrcnn_benchmark.config import cfg

MODEL_NAME = '44663493_vctree_baseline_predcls_4GPU_riv_1_copied'
CONFIG_FILE = '/home/zhanwen/gsc/checkpoints/44663493_vctree_baseline_predcls_4GPU_riv_1_copied/config.yml'
PROJECT_DIR = '/home/zhanwen/gsc'
SEED=1234
BATCH_SIZE=1

cfg.merge_from_file(CONFIG_FILE)
cfg.SOLVER.IMS_PER_BATCH = BATCH_SIZE
cfg.DATALOADER.NUM_WORKERS = 8
cfg.GLOVE_DIR = f'{PROJECT_DIR}/datasets/vg/'
cfg.MODEL.PRETRAINED_DETECTOR_CKPT = f'{PROJECT_DIR}/checkpoints/pretrained_faster_rcnn/model_final.pth'
cfg.OUTPUT_DIR = f'{PROJECT_DIR}/checkpoints/{MODEL_NAME}'
cfg.PATHS_DATA = f'{PROJECT_DIR}/maskrcnn_benchmark/data/datasets'
cfg.OUTPUT_DIR = '/home/zhanwen/gsc/checkpoints/44663493_vctree_baseline_predcls_4GPU_riv_1_copied'
cfg.MODEL.WEIGHT = f'{PROJECT_DIR}/checkpoints/{MODEL_NAME}/model_0014000.pth'
cfg.PATHS_CATALOG = '/home/zhanwen/gsc/maskrcnn_benchmark/config/paths_catalog.py'

cfg.freeze()

In [8]:
from maskrcnn_benchmark.data import get_dataset_statistics
result = get_dataset_statistics(cfg, return_lookup=True)

dataset_name=VG_stanford_filtered_with_attribute_train
split:  train
root_classes_count:  {}
mean root class number:  0.0
sum root class number:  0
leaf_classes_count:  {}
mean leaf class number:  0.0
sum leaf class number:  0
all_classes_count:  {}
mean all class number:  0.0
sum all class number:  0
number images:  57723
get visual genome statistics!!!!!!!!!!!!!!!!!!


100%|███████████████████████████████████| 57723/57723 [00:06<00:00, 8493.31it/s]


In [9]:
result.keys()

dict_keys(['fg_matrix', 'pred_dist', 'obj_classes', 'rel_classes', 'att_classes', 'stats', 'obj2examples', 'rel2examples'])

In [10]:
obj2examples = result['obj2examples']
rel2examples = result['rel2examples']
stats = result['stats']


In [11]:
# row = [ex_ind, o1_idx, o1] + list(gt_box_o1) + [o2_idx, o2] + list(gt_box_o2) + [gtr]
names = [
    'example_idx',
    'subj_obj_idx_local',
    'subj_obj_category_idx',
    'subj_gtbox_1',
    'subj_gtbox_2',
    'subj_gtbox_3',
    'subj_gtbox_4',
    'obj_obj_idx_local',
    'obj_obj_category_idx',
    'obj_gtbox_1',
    'obj_gtbox_2',
    'obj_gtbox_3',
    'obj_gtbox_4',
    'rel_local_idx',
    'rel_category_idx',
]

In [12]:
len(stats)

405860

In [13]:
from pandas import DataFrame
df_stats = DataFrame(stats, columns=names)
del stats

In [14]:
df_stats

Unnamed: 0,example_idx,subj_obj_idx_local,subj_obj_category_idx,subj_gtbox_1,subj_gtbox_2,subj_gtbox_3,subj_gtbox_4,obj_obj_idx_local,obj_obj_category_idx,obj_gtbox_1,obj_gtbox_2,obj_gtbox_3,obj_gtbox_4,rel_local_idx,rel_category_idx
0,0,12,77,231,313,290,397,10,111,67,197,414,767,0,31
1,0,1,20,30,31,457,767,0,3,35,461,371,618,1,20
2,0,13,78,31,32,507,764,5,58,249,462,369,575,2,20
3,0,13,78,31,32,507,764,7,97,381,430,542,507,3,21
4,0,11,115,519,192,669,388,8,99,594,208,614,382,4,50
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
405855,57722,6,58,584,334,677,431,2,78,584,149,951,557,5,30
405856,57722,9,61,211,126,314,262,1,78,1,88,468,578,6,30
405857,57722,10,61,715,151,838,291,2,78,584,149,951,557,7,30
405858,57722,7,58,621,503,660,555,2,78,584,149,951,557,8,30


# 1. Strategy='Same'


## Algorithm:
### 1. Group by rel_category_idx
### 2. Select bottom_k=30 rel_category_idx
### 3. For each relation in bottom_k=30 rel_category_idx, 3 potential methods:
#### 1. exchange subj n times, keep obj.
#### 2. exchange obj n times, keep subj.
#### 3. exchange both subj and obj n times, excluding original combination
#### heuristic: compute visually similar ones, not just any visual features.
#### let's not do the heuristic for now.

In [15]:
# obj2examples_list = []
# for i in range(len(obj2examples)):
#     obj2examples_list.append(obj2examples[i])

In [16]:
# rel2examples_list = []
# for i in range(len(rel2examples)):
#     rel2examples_list.append(rel2examples[i])

# TODO: build two lookup tables. One is for objs. The other is for relations.

In [17]:
df_stats.head()

Unnamed: 0,example_idx,subj_obj_idx_local,subj_obj_category_idx,subj_gtbox_1,subj_gtbox_2,subj_gtbox_3,subj_gtbox_4,obj_obj_idx_local,obj_obj_category_idx,obj_gtbox_1,obj_gtbox_2,obj_gtbox_3,obj_gtbox_4,rel_local_idx,rel_category_idx
0,0,12,77,231,313,290,397,10,111,67,197,414,767,0,31
1,0,1,20,30,31,457,767,0,3,35,461,371,618,1,20
2,0,13,78,31,32,507,764,5,58,249,462,369,575,2,20
3,0,13,78,31,32,507,764,7,97,381,430,542,507,3,21
4,0,11,115,519,192,669,388,8,99,594,208,614,382,4,50


In [18]:
# # df.groupby(['rel_category_idx'])['C'].describe()[['count', 'mean']]
# rels_bottom_k = set(df_stats.groupby(['rel_category_idx']).count()['example_idx'].nsmallest(30).index.tolist())


In [19]:
# len(df_stats)

In [20]:
# df_stats = df_stats.query("rel_category_idx.isin(@rels_bottom_k).values")


In [21]:
# exchange subj
from tqdm import tqdm
from numpy import array as np_array, asarray as np_asarray
from pandas import DataFrame
from torchvision.transforms.functional import resize
from PIL.Image import fromarray
from maskrcnn_benchmark.data import get_dataset_statistics

# NOTE: PIL (W, H) => np (H, W)

class GraftAugmenter:
#     def __init__(self, df_objs, df_objsdf_triplets):
    def __init__(self, cfg):
#         self.df_objs = df_objs
#         self.df_triplets = df_triplets
#         self.groupby = defaultdict(np_array)
        result, datasets = get_dataset_statistics(cfg, return_lookup=True, return_datasets=True)
        assert len(datasets) == 1
        self.dataset = datasets[0]
        stats = result['stats']
        del result
        df_stats = DataFrame(stats, columns=names)
        del stats
        rels_bottom_k = set(df_stats.groupby(['rel_category_idx']).count()['example_idx'].nsmallest(30).index.tolist())
        self.df_stats = df_stats
        self.df_stats_bottom_k = df_stats.query("rel_category_idx.isin(@rels_bottom_k).values")

#     def exchange_subj(self, row_id):
#         '''
#         Takes triplet id, outputs a new row with 
#         '''
#         # TODO: should cache these
#         df_stats[].sample

#     def swap(self, idx_og, obj, image1, box1, image2, box2, n, subj_or_obj):
#     def swap(self, idx_og, obj, image1, box1, image2, box2, n, subj_or_obj):
    def swap(self, idx_og, subj_or_obj):
        # TODO: need relations. Keep ones that don't overlap or is that particular one. Drop ones that do.
#         image3 = 
#         relations = 
    #     self.df_objs[obj='']
        dataset = self.dataset
        df_stats = self.df_stats
        df_stats_bottom_k = self.df_stats_bottom_k
        img_og, target_og, index_og = dataset[idx_og]
        print(f'img_og.size={img_og.size}')
        img_og_np = np_asarray(img_og)
        print(f'img_og_np.shape={img_og_np.shape}')
        row_og = df_stats_bottom_k.iloc[idx_og]
        if subj_or_obj == 'subj':
            obj_idx_local_og = row_og['subj_obj_idx_local']
            obj_category_idx_og = row_og['subj_obj_category_idx']
            gtbox_1_og = row_og['subj_gtbox_1']
            gtbox_2_og = row_og['subj_gtbox_2']
            gtbox_3_og = row_og['subj_gtbox_3']
            gtbox_4_og = row_og['subj_gtbox_4']
        else:
            obj_idx_local_og = row_og['obj_obj_idx_local']
            obj_category_idx_og = row_og['obj_obj_category_idx']
            gtbox_1_og = row_og['obj_gtbox_1']
            gtbox_2_og = row_og['obj_gtbox_2']
            gtbox_3_og = row_og['obj_gtbox_3']
            gtbox_4_og = row_og['obj_gtbox_4']
        rel_local_idx_og = row_og['rel_local_idx']
            
#         obj_idx_local = row['subj_obj_idx_local']
#         obj_category_idx = row['subj_obj_category_idx']
#         gtbox_1 = row_og['subj_gtbox_1']
#         gtbox_2 = row_og['subj_gtbox_2']
#         gtbox_3 = row_og['subj_gtbox_3']
#         gtbox_4 = row_og['subj_gtbox_4']

#         rows = df_stats_bottom_k[(df_stats_bottom_k['subj_obj_category_idx'] == obj_category_idx_og) | (df_stats_bottom_k['obj_obj_category_idx'] == obj_category_idx_og)]
        rows = df_stats[(df_stats['subj_obj_category_idx'] == obj_category_idx_og) | (df_stats['obj_obj_category_idx'] == obj_category_idx_og)]
#         candidates = rows.sample(n=n, replace=False)
        row_new = rows.sample(n=1, ignore_index=True)
        row_new = row_new.iloc[0]
    
        idx_new = row_new['example_idx']
        img_new, target_new, index_new = dataset[idx_new]
        print(f'img_new.size={img_new.size}')
        
        subj_or_obj_new = 'subj' if row_new['subj_obj_category_idx'] == obj_category_idx_og else 'obj'
        
        
        if subj_or_obj_new == 'subj':
            obj_idx_local_new = row_new['subj_obj_idx_local']
            obj_category_idx_new = row_new['subj_obj_category_idx']
            gtbox_1_new = row_new['subj_gtbox_1']
            gtbox_2_new = row_new['subj_gtbox_2']
            gtbox_3_new = row_new['subj_gtbox_3']
            gtbox_4_new = row_new['subj_gtbox_4']
        else:
            obj_idx_local_new = row_new['obj_obj_idx_local']
            obj_category_idx_new = row_new['obj_obj_category_idx']
            gtbox_1_new = row_new['obj_gtbox_1']
            gtbox_2_new = row_new['obj_gtbox_2']
            gtbox_3_new = row_new['obj_gtbox_3']
            gtbox_4_new = row_new['obj_gtbox_4']
        rel_local_idx_new = row_new['rel_local_idx']
        
#         assert roi_og is an instance of PIL or Tensor (whichever one is consistent)
        img_new_np = np_asarray(img_new)
        print(f'img_new_np.shape={img_new_np.shape}')
#         roi_new = img_new[:, gtbox_1_new:gtbox_3_new, gtbox_2_new:gtbox_4_new]
#         roi_new = img_new.crop(box=[gtbox_1_new, gtbox_2_new, gtbox_3_new, gtbox_4_new])
#         img_og = img_og.paste(roi_new, box=[gtbox_1_og, gtbox_2_og, gtbox_3_og, gtbox_4_og])
#         try:
#             roi_og = img_og[:, gtbox_1_og:gtbox_3_og, gtbox_2_og:gtbox_4_og]
#         except:
#             import pdb; pdb.set_trace()
#         roi_og.paste(im[, box, mask])
#         roi_new_np = img_new_np[:, gtbox_1_new:gtbox_3_new, gtbox_2_new:gtbox_4_new]
#         roi_new_np = img_new_np[gtbox_1_new:gtbox_3_new, gtbox_2_new:gtbox_4_new, :]
        roi_new_np = img_new_np[gtbox_2_new:gtbox_4_new, gtbox_1_new:gtbox_3_new, :]
        print(f'roi_new_np.shape={roi_new_np.shape}')
        try:
            roi_new = fromarray(roi_new_np)
            print(f'roi_new.size={roi_new.size}')
        except:
            import pdb; pdb.set_trace()
        roi_new_resized = resize(roi_new, [gtbox_4_og-gtbox_2_og, gtbox_3_og-gtbox_1_og]) # takes H,W
        print(f'roi_new_resized.size={roi_new_resized.size}')
#         roi_new_resized = resize(roi_new, [gtbox_3_og-gtbox_1_og, gtbox_4_og-gtbox_2_og]) # takes H,W
        roi_new_resized_np = np_asarray(roi_new_resized)
        print(f'roi_new_resized_np.shape={roi_new_resized_np.shape}')
#         img_og_np[gtbox_1_og:gtbox_3_og, gtbox_2_og:gtbox_4_og, :] = roi_new_resized_np
        try:
            img_og_np[gtbox_2_og:gtbox_4_og, gtbox_1_og:gtbox_3_og, :] = roi_new_resized_np #(1024, 683, 3) => (311, 173, 3)
#             img_og_np[gtbox_1_og:gtbox_3_og, gtbox_2_og:gtbox_4_og, :] = roi_new_resized_np.swapaxes(0,1)
        except:
            import pdb; pdb.set_trace()
        img_og_modified = fromarray(img_og_np)
        print(f'img_og_modified.size={img_og_modified.size}')
        img_og.show()
        img_og_modified.show()
        return img_og_modified, target_og, None # set index to None
    
    def run(self):
        rows_new = []
#         for rel_bottom_k in rels_bottom_k:
#         from tqdm.notebook import tqdm
        swap = self.swap
#         tqdm.pandas()
        for idx, _ in tqdm(self.df_stats_bottom_k.iterrows()):
            row_new_subj = swap(idx, 'subj')
            rows_new.append(row_new_subj)
            row_new_obj = swap(idx, 'obj')
            rows_new.append(row_new_obj)

        return rows_new
                
        # TODO: make sure H, W order is correct
#         pass
#         for _ in range(self.num_mix):
#             r = np.random.rand(1)
#             if self.beta <= 0 or r > self.prob
#                 continue

#             # generate mixed sample
#             lam = np.random.beta(self.beta, self.beta)
#             rand_index = random.choice(range(len(self)))

#             img2, lb2 = self.dataset[rand_index]
#             lb2_onehot = onehot(self.num_class, lb2)

#             bbx1, bby1, bbx2, bby2 = rand_bbox(img.size(), lam)
#             img[:, bbx1:bbx2, bby1:bby2] = img2[:, bbx1:bbx2, bby1:bby2]
#             lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (img.size()[-1] * img.size()[-2]))
#             lb_onehot = lb_onehot * lam + lb2_onehot * (1. - lam)

# Actually won't affect relations because of the resizeing. This augmentation is entirely visual

# So the problem is that we need to make a new image with both existing and new triplets. 
# We need to make sure that the new ones don't overlap with the old ones.


# self.dataset.gt_boxes[row_og['example_idx']][row_og['subj_obj_idx_local']] if subj_or_obj == 'subj' else self.dataset.gt_boxes[row_og['example_idx']][row_og['obj_obj_idx_local']]

In [22]:
graft_augmenter = GraftAugmenter(cfg)

dataset_name=VG_stanford_filtered_with_attribute_train
split:  train
root_classes_count:  {}
mean root class number:  0.0
sum root class number:  0
leaf_classes_count:  {}
mean leaf class number:  0.0
sum leaf class number:  0
all_classes_count:  {}
mean all class number:  0.0
sum all class number:  0
number images:  57723
get visual genome statistics!!!!!!!!!!!!!!!!!!


100%|███████████████████████████████████| 57723/57723 [00:06<00:00, 8372.06it/s]


In [None]:
lol = graft_augmenter.run()

0it [00:00, ?it/s]

img_og.size=(1024, 543)
img_og_np.shape=(543, 1024, 3)
img_new.size=(500, 375)
img_new_np.shape=(375, 500, 3)
roi_new_np.shape=(178, 0, 3)
> [0;32m/tmp/ipykernel_166700/3743510950.py[0m(122)[0;36mswap[0;34m()[0m
[0;32m    120 [0;31m        [0;32mexcept[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    121 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 122 [0;31m        [0mroi_new_resized[0m [0;34m=[0m [0mresize[0m[0;34m([0m[0mroi_new[0m[0;34m,[0m [0;34m[[0m[0mgtbox_4_og[0m[0;34m-[0m[0mgtbox_2_og[0m[0;34m,[0m [0mgtbox_3_og[0m[0;34m-[0m[0mgtbox_1_og[0m[0;34m][0m[0;34m)[0m [0;31m# takes H,W[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    123 [0;31m        [0mprint[0m[0;34m([0m[0;34mf'roi_new_resized.size={roi_new_resized.size}'[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    124 [0;31m[0;31m#         roi_new_

In [None]:
help(tqdm.pandas)