In [None]:
import os

import torch.backends.cudnn as cudnn
import torch.cuda as cuda

from jacinle.cli.argument import JacArgumentParser
from jacinle.logging import get_logger, set_output_file
from jacinle.utils.imp import load_source
from jacinle.utils.tqdm import tqdm_pbar

from jactorch.cli import escape_desc_name, ensure_path, dump_metainfo
from jactorch.cuda.copy import async_copy_to
from jactorch.train import TrainerEnv
from jactorch.utils.meta import as_float

from nscl.datasets import get_available_datasets, initialize_dataset, get_dataset_builder

In [None]:
# Build the argument parser and an empty namespace that later cells fill in by hand.
# ``__doc__`` is None when this code runs as a module without a docstring, so
# fall back to an empty description instead of failing on ``None.strip()``.
parser = JacArgumentParser(description=(__doc__ or "").strip())
# Pass an explicit empty argv so the kernel's own command line is never parsed.
args = parser.parse_args([])

In [None]:
# Locations of the CLEVR validation split.
# A leading "~" is not expanded by Python's file APIs, so resolve it to the
# current user's home directory up front.
args.data_dir = os.path.expanduser("~/CLEVR_v1.0/val")
args.data_image_root = os.path.expanduser("~/CLEVR_v1.0/val/images/")
args.data_vocab_json = os.path.expanduser("~/CLEVR_v1.0/val/vocab.json")
args.data_scenes_json = os.path.expanduser("~/CLEVR_v1.0/val/scenes.json")
args.data_questions_json = os.path.expanduser("~/CLEVR_v1.0/val/CLEVR_val_questions.json")

In [None]:
# Path of the experiment description file; it is loaded with ``load_source`` in a later cell.
args.desc = "experiments/clevr/desc_nscl_derender.py"
# Name of the dataset passed to ``initialize_dataset`` / ``get_dataset_builder`` in a later cell.
args.dataset = "clevr"

In [None]:
# Prepare the dataset selected by ``args.dataset`` and fetch its builder callable.
# NOTE(review): presumably the initialisation has to happen before the builder
# lookup — confirm against nscl.datasets before reordering.
initialize_dataset(args.dataset)
build_dataset = get_dataset_builder(args.dataset)

In [None]:
# Load the experiment description referenced by ``args.desc`` and keep a handle
# on its ``configs`` attribute; ``desc.make_model`` is used in a later cell.
# NOTE(review): presumably ``load_source`` imports the file as a Python module — confirm.
desc = load_source(args.desc)
configs = desc.configs

In [None]:
# Instantiate the dataset from the image directory and the scene / question
# annotation files configured above (positional order kept as in the builder call).
dataset = build_dataset(
    args,
    configs,
    args.data_image_root,
    args.data_scenes_json,
    args.data_questions_json,
)

In [None]:
# Shuffled loader with a single worker that drops the last incomplete batch.
# The first positional argument (2) is presumably the batch size — TODO confirm.
train_dataloader = dataset.make_dataloader(2, shuffle=True, drop_last=True, nr_workers=1)

In [None]:
# Create an iterator over the loader so single batches can be pulled on demand.
train_iter = iter(train_dataloader)

In [None]:
# Pull one batch; re-running this cell advances the iterator to the next batch
# and raises StopIteration once the iterator is exhausted.
feed_dict = next(train_iter)

In [None]:
# Build the model described by the experiment file, handing it the vocabulary
# of the underlying (unwrapped) dataset.
model = desc.make_model(
    args,
    dataset.unwrapped.vocab,
)

In [None]:
from jacinle.utils.container import GView

# Wrap the batch in a GView. The guard keeps the cell idempotent: re-running it
# no longer nests one GView inside another.
if not isinstance(feed_dict, GView):
    feed_dict = GView(feed_dict)

In [None]:
import sys

# Make the parent directory importable. Only add it once so that re-running
# this cell does not keep growing sys.path with duplicate entries.
if "../" not in sys.path:
    sys.path.append("../")

In [None]:
# default_exp scene_graph.scene_graph

In [None]:
#export

import torch
import torch.nn as nn
from torchvision.ops import RoIAlign
from torchvision.models import resnet34

from .utils import *

class SceneGraph(nn.Module):
    """Pool per-object and per-object-pair embeddings out of a convolutional feature map.

    For every image in the batch the module RoI-aligns three views of the
    feature map (raw features on the object box, context features on a
    full-image box, relation features on the union box of each object pair),
    fuses them with 1x1 convolutions and projects the result to unit-length
    embedding vectors.
    """

    def __init__(self, feature_dim=256, output_dims=(256, 256), downsample_rate=16):
        """Create the pooling, fusion and projection layers.

        Args:
            feature_dim: number of channels of the incoming feature map.
            output_dims: two-element sequence; ``output_dims[0]`` is the size of an
                object embedding, ``output_dims[1]`` the size of a relation embedding.
            downsample_rate: factor between box coordinates and feature-map cells;
                boxes are scaled by ``1 / downsample_rate`` inside the RoI pooling.
        """
        super().__init__()
        self.pool_size = 7
        self.feature_dim = feature_dim
        # Store a private list so neither a shared default nor the caller's
        # sequence can be mutated through this attribute.
        self.output_dims = list(output_dims)
        self.downsample_rate = downsample_rate

        # Third argument is RoIAlign's sampling_ratio; a non-positive value lets
        # torchvision choose the number of sampling points adaptively.
        spatial_scale = 1.0 / self.downsample_rate
        self.object_roi_pool = RoIAlign(self.pool_size, spatial_scale, -1)
        self.context_roi_pool = RoIAlign(self.pool_size, spatial_scale, -1)
        self.relation_roi_pool = RoIAlign(self.pool_size, spatial_scale, -1)

        self.context_feature_extract = nn.Conv2d(feature_dim, feature_dim, 1)
        # Three chunks of feature_dim // 2 channels: plain, subject-gated, object-gated.
        self.relation_feature_extract = nn.Conv2d(feature_dim, feature_dim // 2 * 3, 1)

        self.object_feature_fuse = nn.Conv2d(feature_dim * 2, output_dims[0], 1)
        self.relation_feature_fuse = nn.Conv2d(feature_dim // 2 * 3 + output_dims[0] * 2, output_dims[1], 1)

        self.object_feature_fc = nn.Sequential(nn.ReLU(True), nn.Linear(output_dims[0] * self.pool_size ** 2, output_dims[0]))
        self.relation_feature_fc = nn.Sequential(nn.ReLU(True), nn.Linear(output_dims[1] * self.pool_size ** 2, output_dims[1]))

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-normal initialise every conv / linear weight and zero its bias."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                m.bias.data.zero_()

    def forward(self, image_features, objects, objects_length):
        """Compute object and relation embeddings for every image in the batch.

        Args:
            image_features: feature map of shape ``(batch, feature_dim, H, W)``.
            objects: ``(total_objects, 4)`` tensor holding the boxes of all images
                stacked along dim 0, in the coordinate frame that is
                ``downsample_rate`` times larger than the feature map
                (presumably ``x1, y1, x2, y2`` as RoIAlign expects — confirm
                against the data loader).
            objects_length: per-image number of boxes; entry ``i`` says how many
                consecutive rows of ``objects`` belong to image ``i``.

        Returns:
            A list with one ``[object_features, relation_features]`` pair per image,
            of shapes ``(n, output_dims[0])`` and ``(n, n, output_dims[1])`` for an
            image with ``n`` boxes; both are L2-normalised along the last dim.
        """
        object_features = image_features
        context_features = self.context_feature_extract(image_features)
        relation_features = self.relation_feature_extract(image_features)

        # The "full image" extent is the same for every image of the batch.
        image_h = image_features.size(2) * self.downsample_rate
        image_w = image_features.size(3) * self.downsample_rate

        outputs = list()
        objects_index = 0
        for i in range(image_features.size(0)):
            # Read the count once; int() accepts tensor scalars and plain ints alike.
            nr_objects = int(objects_length[i])
            box = objects[objects_index:objects_index + nr_objects]
            objects_index += nr_objects
            n = box.size(0)

            # Box bookkeeping carries no gradient.
            with torch.no_grad():
                # RoIAlign expects rows of (batch_index, box coordinates).
                batch_ind = torch.full((n, 1), i, dtype=box.dtype, device=box.device)

                # generate a "full-image" bounding box for every object
                image_box = box.new_tensor([0, 0, image_w, image_h]).repeat(n, 1)

                # meshgrid to obtain the subject and object bounding boxes of all n * n pairs
                sub_id, obj_id = meshgrid(torch.arange(n, dtype=torch.int64, device=box.device), dim=0)
                sub_id, obj_id = sub_id.contiguous().view(-1), obj_id.contiguous().view(-1)
                sub_box, obj_box = meshgrid(box, dim=0)
                sub_box = sub_box.contiguous().view(n ** 2, 4)
                obj_box = obj_box.contiguous().view(n ** 2, 4)

                # union box of every (subject, object) pair
                union_box = generate_union_box(sub_box, obj_box)
                rel_batch_ind = torch.full((union_box.size(0), 1), i, dtype=box.dtype, device=box.device)

                # intersection maps used to gate pooled features below
                box_context_imap = generate_intersection_map(box, image_box, self.pool_size)
                sub_union_imap = generate_intersection_map(sub_box, union_box, self.pool_size)
                obj_union_imap = generate_intersection_map(obj_box, union_box, self.pool_size)

            # Object branch: raw object RoI + context (half plain, half gated by the object's map).
            this_context_features = self.context_roi_pool(context_features, torch.cat([batch_ind, image_box], dim=-1))
            x, y = this_context_features.chunk(2, dim=1)
            this_object_features = self.object_feature_fuse(torch.cat([
                self.object_roi_pool(object_features, torch.cat([batch_ind, box], dim=-1)),
                x, y * box_context_imap
            ], dim=1))

            # Relation branch: both endpoints' object features + union-box features
            # (plain, gated by the subject's map, gated by the object's map).
            this_relation_features = self.relation_roi_pool(relation_features, torch.cat([rel_batch_ind, union_box], dim=-1))
            x, y, z = this_relation_features.chunk(3, dim=1)
            this_relation_features = self.relation_feature_fuse(torch.cat([
                this_object_features[sub_id], this_object_features[obj_id],
                x, y * sub_union_imap, z * obj_union_imap
            ], dim=1))

            outputs.append([
                self._norm(self.object_feature_fc(this_object_features.view(n, -1))),
                self._norm(self.relation_feature_fc(this_relation_features.view(n * n, -1)).view(n, n, -1))
            ])

        return outputs

    def _norm(self, x, eps=1e-12):
        """L2-normalise ``x`` along its last dim; ``eps`` keeps all-zero rows finite instead of NaN."""
        return x / x.norm(2, dim=-1, keepdim=True).clamp_min(eps)

In [None]:
# Scene-graph head for a 256-channel feature map whose cells are 16x coarser
# than the box coordinates; object and relation embeddings are both 256-d.
scene_graph = SceneGraph(feature_dim=256, output_dims=[256, 256], downsample_rate=16)

In [None]:
# Use a pretrained ResNet-34 as the backbone and keep everything except its
# last three child modules.
# NOTE(review): for torchvision's ResNet these are expected to be layer4,
# avgpool and fc, leaving a 256-channel, stride-16 map — confirm.
resnet = resnet34(pretrained=True)
backbone_stages = list(resnet.children())
feature_extractor = nn.Sequential(*backbone_stages[:-3])

In [None]:
# Run the backbone on the batch images, then feed the feature map together
# with the stacked object boxes and per-image box counts to the scene graph.
image_features = feature_extractor(feed_dict["image"])
outputs = scene_graph(
    image_features,
    feed_dict["objects"],
    feed_dict["objects_length"],
)

In [None]:
outputs[1].shape

In [None]:
outputs[1][1].shape

In [None]:
outputs[1][2].shape