In [None]:
import sys
# Add the parent directory to the module search path.
# Presumably this is what makes the `core` package imported below resolvable — confirm against the repo layout.
sys.path.append("../")

In [None]:
import warnings
# Suppress every warning for the rest of the session (no category filter is given,
# so deprecation notices are hidden as well).
warnings.filterwarnings("ignore")

In [None]:
# default_exp model.encoder

In [None]:
#export

import torch.nn as nn
import torch
from core.scene_graph.scene_graph import SceneGraph
from torchvision.models import resnet34

In [None]:
from core.dataloader import GQNDataset_pdisco, collate_boxes
from torch.utils.data import Dataset, DataLoader

In [None]:
# Build the training dataset and wrap it in a shuffling loader that yields batches of 5,
# assembled by the project's `collate_boxes` function.
# NOTE(review): the path is machine-specific; it looks like a text file listing samples — confirm.
train_dataset = GQNDataset_pdisco(
    root_dir='/home/mprabhud/dataset/clevr_veggies/npys/be_lt.txt',
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=5,
    shuffle=True,
    collate_fn=collate_boxes,
)

Initialised..... 27495  files...


In [None]:
# Pull exactly one batch from the loader. `next(iter(...))` raises StopIteration right here
# if the loader is empty, rather than silently leaving the three names unbound.
feed_dict_q, feed_dict_k, metadata = next(iter(train_loader))

In [None]:
# Move the image tensors of both feed dicts to the GPU.
# Each dict keeps its own images: the query dict must not be overwritten with the key images.
feed_dict_q["images"] = feed_dict_q["images"].cuda()
feed_dict_k["images"] = feed_dict_k["images"].cuda()

In [None]:
#export

class Encoder(nn.Module):
    """
    Image encoder: truncated ResNet-34 feature extractor followed by a SceneGraph module,
    with optional conditioning of the resulting embeddings on a relative camera viewpoint.

    Input:
        dim : final number of dimensions of the node and spatial embeddings
        pose_dim : length of the viewpoint vector that is concatenated in front of each
                   embedding before the viewpoint MLPs (7 for [batch, 1, 7] viewpoints)
        verbose : when True, print a progress message per batch element while merging /
                  transforming (debug aid; off by default)

    Returns:
        Initialises a model which has node embeddings and spatial embeddings
    """

    def __init__(self, dim = 256, pose_dim = 7, verbose = False):
        super().__init__()

        self.dim = dim
        self.pose_dim = pose_dim
        self.verbose = verbose

        self.resnet = resnet34(pretrained=True)
        # Keep every child of the ResNet except the last three; the truncated network is
        # what forward() actually runs on the images.
        self.feature_extractor = nn.Sequential(*list(self.resnet.children())[:-3])

        # NOTE(review): feature_dim presumably has to match the channel count produced by
        # feature_extractor — confirm before instantiating with dim != 256.
        self.scene_graph = SceneGraph(feature_dim=self.dim,
                                 output_dims=[self.dim,self.dim],
                                 downsample_rate=16)

        # The MLP input is the pose vector concatenated with an embedding, so its width is
        # pose_dim + dim (263 for the defaults) rather than a hard-coded constant.
        merged_dim = self.pose_dim + self.dim

        self.node_viewpoint_transformation = nn.Sequential(nn.Linear(merged_dim,512),
                                                nn.ReLU(),
                                                nn.Linear(512,self.dim))

        self.spatial_viewpoint_transformation = nn.Sequential(nn.Linear(merged_dim,512),
                                                        nn.ReLU(),
                                                        nn.Linear(512,self.dim))

    def merge_pose_with_scene_embeddings(self,
                                     scene_embeddings,
                                     view=None,
                                     transform_node=True,
                                     transform_spatial=False):
        '''
        Concatenate the relative viewpoint in front of the selected embeddings.

        Input
            scene_embeddings: output of scene_graph module. A list of [node, spatial] pairs,
                              one pair per batch element
            view : a tensor of size [batch, 1, 7] containing information of relative egomotion
                   between the two camera viewpoints
            transform_node and transform_spatial: boolean flags selecting whether the node /
                              spatial embeddings get the pose concatenated
        Output
            A new list of [node, spatial] pairs; selected entries are concatenated with the
            pose vector along their last dimension, the others are passed through unchanged.
            The input list is not modified.
        Raises
            NotImplementedError: if view is None
        '''

        if view is None:
            raise NotImplementedError("Wrong Implementation")

        merged = []
        for batch_ind,(node_embeddings, spatial_embeddings) in enumerate(scene_embeddings):

            if transform_node:
                if self.verbose:
                    print("Node: Pose with Node Concat :  Batch Ind: {}".format(batch_ind))
                num_objects = node_embeddings.shape[0]
                # Broadcast view to one row per object: [1, pose] -> [num_objects, pose]
                view_visual = view[batch_ind].repeat(num_objects,1)
                # Pose first, then the visual embedding
                node_embeddings = torch.cat((view_visual,node_embeddings), dim=1)

            if transform_spatial:
                num_obj_x = spatial_embeddings.shape[0]
                num_obj_y = spatial_embeddings.shape[1]

                # Broadcast view to one vector per object pair: [1, 1, pose] -> [x, y, pose]
                view_spatial = view[batch_ind].unsqueeze(0).repeat(num_obj_x, num_obj_y, 1)
                # Pose first, then the spatial embedding
                spatial_embeddings = torch.cat((view_spatial,spatial_embeddings), dim=2)

            # Collect into a fresh container so that calling this twice on the same
            # scene-graph output does not prepend the pose twice.
            merged.append([node_embeddings, spatial_embeddings])

        return merged

    def do_viewpoint_transformation(self,
                                    scene_embeddings,
                                    transform_node=True,
                                    transform_spatial=False):

        '''
        Apply the viewpoint MLPs to pose-concatenated embeddings.

        Input:
            scene_embeddings: output of scene_graph module concatenated with pose. A list of
                              [node, spatial] pairs, one pair per batch element
            transform_node and transform_spatial: boolean flags selecting whether the node /
                              spatial embeddings are passed through their MLP
        Output:
            A new list of [node, spatial] pairs with the selected entries viewpoint
            transformed. The input list is not modified.
        '''

        transformed = []
        for ind,(node_embeddings, spatial_embeddings) in enumerate(scene_embeddings):
            if transform_node:
                if self.verbose:
                    print("Node: Transformation:  Batch Ind: {}".format(ind))
                # Do viewpoint transformation on visual embeddings
                node_embeddings = self.node_viewpoint_transformation(node_embeddings)

            if transform_spatial:
                # Do viewpoint transformation on spatial embeddings
                spatial_embeddings = self.spatial_viewpoint_transformation(spatial_embeddings)

            transformed.append([node_embeddings, spatial_embeddings])

        return transformed


    def forward(self,feed_dict, rel_viewpoint=None):
        """
        Input:
            feed_dict: a dictionary of batched inputs. The keys read here are
                       "images" (image tensor, batch on dim 0),
                       "objects_boxes" and "objects" (both forwarded to the scene graph module).
            rel_viewpoint: optional relative viewpoint tensor; when given, the node embeddings
                       are concatenated with it and passed through the node viewpoint MLP
                       (spatial embeddings are left untouched).
        Returns:
            outputs: per-batch-element [node, spatial] embeddings
            node_features: node embeddings of all batch elements concatenated along dim 0
            spatial_features: always None (not implemented yet)
        """
        num_batch = feed_dict["images"].shape[0]

        image_features = self.feature_extractor(feed_dict["images"])
        outputs = self.scene_graph(image_features, feed_dict["objects_boxes"], feed_dict["objects"])

        if rel_viewpoint is not None:
            if self.verbose:
                print("Viewpoint Transformation of Node feature vectors")
            outputs = self.merge_pose_with_scene_embeddings(outputs,rel_viewpoint, True, False)
            outputs = self.do_viewpoint_transformation(outputs, True, False)

        # Stack the node embeddings of every batch element in a single concatenation.
        node_features = torch.cat([outputs[num][0] for num in range(num_batch)], dim=0)

        # To be implemented
        spatial_features = None

        return outputs, node_features, spatial_features

In [None]:
# Instantiate with the default embedding size (dim=256); the constructor requests
# pretrained ResNet-34 weights.
encoder = Encoder()

In [None]:
# Move the model to the GPU and rebind the same name (the misspelled target created a stray variable).
encoder = encoder.cuda()

In [None]:
# Forward pass on the query feed dict without a relative viewpoint, so no pose merge /
# viewpoint transformation is applied.
output, node_features, spatial_features = encoder(feed_dict_q)

## Testing the Encoder

In [None]:
# Use the key feed dict for the step-by-step checks in the following cells.
feed_dict = feed_dict_k

In [None]:
# Relative viewpoint for the batch, taken from the loader's metadata; the cell displays its shape.
rel_viewpoint = metadata["rel_viewpoint"]
rel_viewpoint.shape

torch.Size([5, 1, 7])

In [None]:
# Run the first two stages of Encoder.forward by hand: CNN features, then the scene graph module.
image_features = encoder.feature_extractor(feed_dict["images"])
scene_graph_output = encoder.scene_graph(image_features, feed_dict["objects_boxes"], feed_dict["objects"])

In [None]:
# Index of the batch element whose shapes are inspected in the following cells.
batch_ind = 3

In [None]:
# Number of batch elements, then the node and spatial embedding shapes of the selected element.
len(scene_graph_output),scene_graph_output[batch_ind][0].shape, scene_graph_output[batch_ind][1].shape

(5, torch.Size([3, 256]), torch.Size([3, 3, 256]))

In [None]:
# Concatenate the relative viewpoint with the embeddings using the method defaults
# (transform_node=True, transform_spatial=False), then display the resulting shapes.
merged_pose_with_scene = encoder.merge_pose_with_scene_embeddings(scene_graph_output, rel_viewpoint)
merged_pose_with_scene[batch_ind][0].shape, merged_pose_with_scene[batch_ind][1].shape

Node: Pose with Node Concat :  Batch Ind: 0
Node: Pose with Node Concat :  Batch Ind: 1
Node: Pose with Node Concat :  Batch Ind: 2
Node: Pose with Node Concat :  Batch Ind: 3
Node: Pose with Node Concat :  Batch Ind: 4


(torch.Size([3, 263]), torch.Size([3, 3, 256]))

In [None]:
# Pass only the node embeddings through the node viewpoint MLP (flags: node=True, spatial=False),
# then display the resulting shapes.
view_transformed_scene = encoder.do_viewpoint_transformation(merged_pose_with_scene, True, False)
view_transformed_scene[batch_ind][0].shape, view_transformed_scene[batch_ind][1].shape

Node: Transformation:  Batch Ind: 0
Node: Transformation:  Batch Ind: 1
Node: Transformation:  Batch Ind: 2
Node: Transformation:  Batch Ind: 3
Node: Transformation:  Batch Ind: 4


(torch.Size([3, 256]), torch.Size([3, 3, 256]))

In [None]:
# Full forward pass with a relative viewpoint; `outputs` holds the whole 3-tuple returned by forward().
outputs = encoder(feed_dict, rel_viewpoint)

Viewpoint Transformation of Node feature vectors
Node: Pose with Node Concat :  Batch Ind: 0
Node: Pose with Node Concat :  Batch Ind: 1
Node: Pose with Node Concat :  Batch Ind: 2
Node: Pose with Node Concat :  Batch Ind: 3
Node: Pose with Node Concat :  Batch Ind: 4
Node: Transformation:  Batch Ind: 0
Node: Transformation:  Batch Ind: 1
Node: Transformation:  Batch Ind: 2
Node: Transformation:  Batch Ind: 3
Node: Transformation:  Batch Ind: 4


In [None]:
# Unpack the forward result. Note that `outputs` is rebound from the 3-tuple to its first
# element, so this cell is not safe to run twice in a row.
outputs , node_features, spatial_features = outputs

In [None]:
# Shape of the node embeddings concatenated over all batch elements.
node_features.shape

torch.Size([15, 256])

In [None]:
# forward() already returns None for the spatial features; this assignment just restates it.
spatial_features=None

In [None]:
# Display the "objects" entry that is passed to the scene graph module
# (presumably the object count per batch element — confirm against the dataloader).
feed_dict["objects"]

tensor([3, 3, 3, 3, 3])