In [None]:
import sys

# Prepend the parent directory to the import path so the project package `core` resolves.
sys.path[:0] = ["../"]

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
# default_exp model.encoder

In [None]:
#export

import torch.nn as nn
import torch
from core.model.scene_graph.scene_graph import SceneGraph
from torchvision.models import resnet34

In [None]:
from core.dataloader import GQNDataset_pdisco, collate_boxes
from torch.utils.data import Dataset, DataLoader

In [None]:
import os

# List file describing the training split. Override with the GQN_TRAIN_LIST environment
# variable instead of editing the notebook when running on another machine.
TRAIN_LIST_PATH = os.environ.get(
    "GQN_TRAIN_LIST", "/home/mprabhud/dataset/clevr_veggies/npys/be_lt.txt"
)
BATCH_SIZE = 5

train_dataset = GQNDataset_pdisco(root_dir=TRAIN_LIST_PATH)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_boxes)

Initialised..... 27495  files...


In [None]:
# Grab a single batch for interactive testing. The collate function yields
# (query feed dict, key feed dict, metadata). `next` raises StopIteration on an empty
# loader, rather than leaving the names silently undefined.
feed_dict_q, feed_dict_k, metadata = next(iter(train_loader))

In [None]:
# Move each view's images to the GPU; each feed dict keeps its own images.
feed_dict_q["images"] = feed_dict_q["images"].cuda()
feed_dict_k["images"] = feed_dict_k["images"].cuda()

In [None]:
#export

class Encoder(nn.Module):
    """Encode images into per-object (node) and pairwise (spatial) embeddings.

    A truncated ResNet-34 produces a feature map, ``SceneGraph`` turns it into node and
    spatial embeddings for every batch element and, when a relative viewpoint is given in
    spatial mode, the pose vector is concatenated onto each spatial embedding and projected
    back to ``dim`` dimensions by a small MLP.
    """

    def __init__(self, dim=256, pose_dim=7, hidden_dim=512):
        """
        Input:
            dim : final number of dimensions of the node and spatial embeddings
            pose_dim : length of the relative-viewpoint vector concatenated onto each
                       spatial embedding (the viewpoint MLPs take dim + pose_dim inputs)
            hidden_dim : width of the hidden layer of the viewpoint MLPs

        Returns:
            Initialises a model which has node embeddings and spatial embeddings
        """
        super().__init__()

        self.dim = dim
        self.pose_dim = pose_dim
        self.resnet = resnet34(pretrained=True)
        # Keep all but the last three children of the backbone. For torchvision's ResNet-34
        # this drops layer4/avgpool/fc, leaving a stride-16 map (matches downsample_rate).
        self.feature_extractor = nn.Sequential(*list(self.resnet.children())[:-3])

        # NOTE(review): feature_dim is tied to `dim`; the truncated backbone's channel count
        # does not change with `dim`, so a non-default `dim` may need a separate value here.
        self.scene_graph = SceneGraph(feature_dim=self.dim,
                                      output_dims=[self.dim, self.dim],
                                      downsample_rate=16)

        viewpoint_in_dim = self.dim + pose_dim
        self.node_viewpoint_transformation = nn.Sequential(nn.Linear(viewpoint_in_dim, hidden_dim),
                                                           nn.ReLU(),
                                                           nn.Linear(hidden_dim, self.dim))

        self.spatial_viewpoint_transformation = nn.Sequential(nn.Linear(viewpoint_in_dim, hidden_dim),
                                                              nn.ReLU(),
                                                              nn.Linear(hidden_dim, self.dim))

    def merge_pose_with_scene_embeddings(self,
                                         scene_embeddings,
                                         view=None):
        '''
        Concatenate the relative-viewpoint vector onto every spatial embedding.

        Input
            scene_embeddings: output of the scene_graph module; a list with one
                              [node_embeddings, spatial_embeddings] entry per batch element,
                              where spatial embeddings are [num_obj_x, num_obj_y, feature].
                              The list entries are modified in place.
            view : a tensor of size [batch, 1, pose_dim] containing the relative egomotion
                   between the two camera viewpoints
        Output
            scene_embeddings: the same list, each spatial embedding replaced by
                              [num_obj_x, num_obj_y, pose_dim + feature] (pose first)
        Raises
            ValueError: if `view` is None
        '''
        if view is None:
            raise ValueError("merge_pose_with_scene_embeddings requires a `view` tensor")

        for batch_ind, (_, spatial_embeddings) in enumerate(scene_embeddings):
            num_obj_x, num_obj_y = spatial_embeddings.shape[:2]

            # view[batch_ind] is [1, pose_dim]; match the embeddings' device/dtype (the
            # viewpoint may come from CPU-side metadata) and tile it over the object grid.
            view_spatial = (view[batch_ind]
                            .to(spatial_embeddings)
                            .unsqueeze(0)
                            .repeat(num_obj_x, num_obj_y, 1))
            # Pose goes first, then the visual features.
            scene_embeddings[batch_ind][1] = torch.cat((view_spatial, spatial_embeddings), dim=2)

            # TODO(Saksham): add assertion tests for the merged shapes.

        return scene_embeddings

    def do_viewpoint_transformation(self,
                                    scene_embeddings,
                                    transform_node=True,
                                    transform_spatial=False):
        '''
        Project pose-augmented spatial embeddings back to `dim` dimensions.

        Input:
            scene_embeddings: output of merge_pose_with_scene_embeddings; a list with one
                              [node_embeddings, spatial_embeddings] entry per batch element.
                              The list entries are modified in place.
            transform_node, transform_spatial: currently ignored. Spatial embeddings are
                              always transformed and node embeddings never are (only the
                              spatial embeddings carry the pose vector).
                              NOTE(review): kept for signature compatibility — either honour
                              or remove these flags.
        Output:
            scene_embeddings: the same list with every spatial embedding passed through
                              spatial_viewpoint_transformation
        '''
        for entry in scene_embeddings:
            entry[1] = self.spatial_viewpoint_transformation(entry[1])

        return scene_embeddings

    def forward(self,
                feed_dict,
                mode="node",
                rel_viewpoint=None):
        """
        Input:
            feed_dict: dictionary of batched data. This method reads "images" (first
                dimension is the batch), "objects_boxes" and "objects"; the latter two are
                passed straight to SceneGraph.
            mode: should be either 'node' or 'spatial' depending on which feature you want
                to extract
            rel_viewpoint: optional [batch, 1, pose_dim] relative-viewpoint tensor; used
                only when mode == 'spatial'
        Returns:
            (outputs, features): `outputs` is the per-sample scene-graph list. `features`
            is the node embeddings of all samples concatenated along dim 0 in 'node'
            mode; otherwise it is the spatial embeddings of each sample reshaped to
            [-1, dim] and concatenated along dim 0.
        """
        num_batch = feed_dict["images"].shape[0]

        image_features = self.feature_extractor(feed_dict["images"])
        outputs = self.scene_graph(image_features, feed_dict["objects_boxes"], feed_dict["objects"], mode=mode)

        if mode == "node":
            # Flatten the node embeddings of the whole batch into one tensor.
            node_features = torch.cat([outputs[num][0] for num in range(num_batch)], dim=0)
            return outputs, node_features

        if mode == "spatial" and rel_viewpoint is not None:
            outputs = self.merge_pose_with_scene_embeddings(outputs, rel_viewpoint)
            outputs = self.do_viewpoint_transformation(outputs)

        # Flatten the spatial embeddings of every sample and stack them.
        spatial_features = torch.cat(
            [outputs[num][1].reshape(-1, self.dim) for num in range(num_batch)], dim=0
        )

        return outputs, spatial_features

In [None]:
# Build the encoder with 256-dimensional node and spatial embeddings.
encoder = Encoder(dim=256)

In [None]:
# Module.cuda() moves the parameters and returns the same module; bind it back to `encoder`.
encoder = encoder.cuda()

## Testing the Encoder

In [None]:
# Run the encoder tests below on the key-view batch.
feed_dict_ = feed_dict_k

##### **Mode** : Node Embeddings

In [None]:
# Node mode: `output_` is the per-sample scene-graph output and `node_features_` holds the
# node embeddings of all samples concatenated along dim 0.
output_, node_features_ = encoder(feed_dict_, mode="node")

In [None]:
# Relative egomotion between the two camera viewpoints, one vector per sample
# (shape shown as [batch, 1, 7]).
rel_viewpoint_ = metadata["rel_viewpoint"]
rel_viewpoint_.shape

torch.Size([5, 1, 7])

In [None]:
# Re-run the two stages of Encoder.forward by hand to inspect the raw scene-graph output.
images_ = feed_dict_["images"]
image_features_ = encoder.feature_extractor(images_)
scene_graph_output = encoder.scene_graph(
    image_features_,
    feed_dict_["objects_boxes"],
    feed_dict_["objects"],
    mode="node",
)

In [None]:
# Sample index inspected by the cells below; must be smaller than the batch size (5).
batch_ind = 3

In [None]:
# Number of samples in the scene-graph output and the node-embedding shape of sample `batch_ind`.
len(scene_graph_output), scene_graph_output[batch_ind][0].shape

(5, torch.Size([3, 256]))

In [None]:
# Node embeddings of every sample, concatenated along dim 0 by Encoder.forward.
node_features_.shape

torch.Size([15, 256])

##### **Mode** : Spatial Embeddings

In [None]:
# Spatial mode with a relative viewpoint: the pose is concatenated onto the spatial
# embeddings, projected back by the viewpoint MLP, then flattened to [-1, dim] per sample.
output_, spatial_features_ = encoder(feed_dict_, mode="spatial", rel_viewpoint= rel_viewpoint_ )

Ading viewpoint information to spatial features
0
Adding pose to spatial embeddings
1
Adding pose to spatial embeddings
2
Adding pose to spatial embeddings
3
Adding pose to spatial embeddings
4
Adding pose to spatial embeddings
viewpoint transform on spatial embeddings
viewpoint transform on spatial embeddings
viewpoint transform on spatial embeddings
viewpoint transform on spatial embeddings
viewpoint transform on spatial embeddings


In [None]:
# Re-flatten the per-sample spatial embeddings by hand and cross-check Encoder.forward's result.
spatial_features = torch.cat([sample[1].reshape(-1, encoder.dim) for sample in output_], dim=0)
assert torch.equal(spatial_features, spatial_features_)

In [None]:
# Flattened spatial embeddings of the whole batch.
spatial_features.shape

torch.Size([45, 256])

### Unit Tests

In [None]:
# merge_pose_with_scene_embeddings assigns into the list it receives, so pass shallow copies;
# otherwise re-running this cell would concatenate the pose onto `scene_graph_output` twice.
# NOTE(review): scene_graph_output was computed with mode="node" — confirm it carries spatial embeddings.
scene_graph_copy = [list(sample) for sample in scene_graph_output]
merged_pose_with_scene = encoder.merge_pose_with_scene_embeddings(scene_graph_copy, rel_viewpoint_)
merged_pose_with_scene[batch_ind][0].shape, merged_pose_with_scene[batch_ind][1].shape

In [None]:
# do_viewpoint_transformation assigns into the list it receives, so pass shallow copies;
# otherwise re-running this cell would push already-projected embeddings through the MLP again.
view_transformed_scene = encoder.do_viewpoint_transformation(
    [list(sample) for sample in merged_pose_with_scene], transform_node=True, transform_spatial=False
)
view_transformed_scene[batch_ind][0].shape, view_transformed_scene[batch_ind][1].shape

In [None]:
# forward returns (per-sample outputs, flattened features); the second argument is `mode`,
# so request node features explicitly.
outputs = encoder(feed_dict_, mode="node")

In [None]:
# Encoder.forward returns exactly two values: the per-sample outputs and the flattened features.
outputs, node_features = outputs

In [None]:
# Shape of the node features returned by the encoder call above.
node_features.shape

In [None]:
# NOTE(review): this discards the `spatial_features` computed earlier and nothing below reads it;
# looks like a leftover — confirm and remove.
spatial_features=None

In [None]:
# Object entries of the batch used for the encoder tests (`feed_dict` was never defined).
feed_dict_["objects"]