In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.insert(0,"../")

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
#default_exp model.encoder

In [None]:
#export

import torch.nn as nn
import torch
from core.model.scene_graph.scene_graph import SceneGraph
from torchvision.models import resnet34

In [None]:
from core.dataloader import CLEVR_train, collate_boxes
from torch.utils.data import Dataset, DataLoader

In [None]:
train_dataset = CLEVR_train(root_dir='/home/mprabhud/dataset/clevr_lang/npys/ab_5t.txt')
train_loader = DataLoader(train_dataset, batch_size=5, shuffle=True, collate_fn=collate_boxes)

Initialised..... 234  files...


In [None]:
for b in train_loader:
    feed_dict_q, feed_dict_k, metadata = b
    break

In [None]:
# Move each view's own images to the GPU (q previously received k's images by mistake,
# which made the two views identical while their boxes differed).
feed_dict_q["images"] = feed_dict_q["images"].cuda()
feed_dict_k["images"] = feed_dict_k["images"].cuda()

In [None]:
#export

class Encoder(nn.Module):
    """Image encoder producing per-object node embeddings and pairwise spatial embeddings.

    A ResNet-34 backbone truncated after ``layer3`` extracts a feature map, which a
    ``SceneGraph`` head turns into node and spatial embeddings. In 'spatial' mode the
    spatial embeddings can additionally be conditioned on a relative camera viewpoint
    and passed through ``spatial_viewpoint_transformation``.
    """

    # Channel count and total stride of ResNet-34 after ``layer3`` (the last stage kept).
    BACKBONE_CHANNELS = 256
    BACKBONE_STRIDE = 16

    def __init__(self, dim=256, pose_dim=7):
        """
        Input:
            dim : final number of dimensions of the node and spatial embeddings
            pose_dim : length of the relative-viewpoint vector concatenated onto each
                       spatial embedding before the viewpoint MLPs (7 by default)

        Returns:
            Initialises a model which has node embeddings and spatial embeddings
        """
        super().__init__()

        self.dim = dim
        self.pose_dim = pose_dim
        # NOTE: newer torchvision versions deprecate `pretrained=` in favour of `weights=`.
        self.resnet = resnet34(pretrained=True)
        # Drop layer4, avgpool and fc -> 256-channel feature map at stride 16.
        self.feature_extractor = nn.Sequential(*list(self.resnet.children())[:-3])

        # feature_dim is the channel count of the backbone feature map, which does not
        # depend on `dim` (passing self.dim only worked for dim == 256).
        self.scene_graph = SceneGraph(feature_dim=self.BACKBONE_CHANNELS,
                                      output_dims=[self.dim, self.dim],
                                      downsample_rate=self.BACKBONE_STRIDE)

        # The MLPs consume [pose, embedding] vectors: 7 + 256 = 263 with the defaults.
        in_features = self.pose_dim + self.dim
        self.node_viewpoint_transformation = nn.Sequential(nn.Linear(in_features, 512),
                                                           nn.ReLU(),
                                                           nn.Linear(512, self.dim))

        self.spatial_viewpoint_transformation = nn.Sequential(nn.Linear(in_features, 512),
                                                              nn.ReLU(),
                                                              nn.Linear(512, self.dim))

    def merge_pose_with_scene_embeddings(self,
                                         scene_embeddings,
                                         view=None):
        '''
        Concatenate each batch element's viewpoint vector onto all of its spatial embeddings.

        Input
            scene_embeddings: output of the scene_graph module, one entry per batch element.
                              Each entry is unpacked as (node, spatial); spatial has shape
                              [num_obj_x, num_obj_y, D]. Entries must support item
                              assignment (e.g. lists) because entry[1] is replaced in place.
            view : tensor of shape [batch, 1, P] (or [batch, P]) containing the relative
                   egomotion between the two camera viewpoints
        Output
            scene_embeddings: the same list, with entry[1] replaced by a tensor of shape
                              [num_obj_x, num_obj_y, P + D] (pose first, then features)
        Raises
            ValueError: if `view` is None.
        '''
        if view is None:
            raise ValueError("merge_pose_with_scene_embeddings requires a `view` tensor")

        for batch_ind, (_, spatial_embeddings) in enumerate(scene_embeddings):
            num_obj_x, num_obj_y = spatial_embeddings.shape[:2]
            # Align device/dtype so e.g. a CPU/float64 pose can join GPU/float32 features.
            pose = view[batch_ind].reshape(1, 1, -1).to(device=spatial_embeddings.device,
                                                        dtype=spatial_embeddings.dtype)
            # Broadcast the pose to every (subject, object) pair, then prepend it.
            view_spatial = pose.expand(num_obj_x, num_obj_y, -1)
            scene_embeddings[batch_ind][1] = torch.cat((view_spatial, spatial_embeddings), dim=2)

        return scene_embeddings

    def do_viewpoint_transformation(self,
                                    scene_embeddings,
                                    transform_node=True,
                                    transform_spatial=False):
        '''
        Apply `spatial_viewpoint_transformation` to each batch element's spatial embeddings.

        Input:
            scene_embeddings: output of merge_pose_with_scene_embeddings; entry[1] of every
                              element is a pose-augmented spatial embedding tensor
            transform_node, transform_spatial: accepted for compatibility but currently
                              ignored (see note below)
        Output:
            scene_embeddings: the same list with entry[1] replaced by the transformed embeddings
        '''
        # NOTE(review): the flags are not honoured -- spatial embeddings are always
        # transformed and node embeddings never are. Nodes carry no pose (only entry[1]
        # is augmented), so node_viewpoint_transformation could not be applied to them.
        for ind, (_, spatial_embeddings) in enumerate(scene_embeddings):
            scene_embeddings[ind][1] = self.spatial_viewpoint_transformation(spatial_embeddings)

        return scene_embeddings

    def forward(self,
                feed_dict,
                mode="node",
                rel_viewpoint=None):
        """
        Input:
            feed_dict: dictionary of batched data with keys
                "images"        : image tensor fed to the ResNet backbone
                "objects_boxes" : object bounding boxes, forwarded to SceneGraph
                "objects"       : object counts per image (presumably; TODO confirm),
                                  forwarded to SceneGraph
            mode: 'node' or 'spatial' depending on which feature to extract; forwarded
                  to SceneGraph
            rel_viewpoint: optional [batch, 1, pose_dim] relative-viewpoint tensor, only
                  used in 'spatial' mode
        Returns:
            The SceneGraph output, one entry per batch element. In 'spatial' mode with a
            rel_viewpoint, the spatial embeddings are pose-conditioned and transformed.
        """
        image_features = self.feature_extractor(feed_dict["images"])
        outputs = self.scene_graph(image_features, feed_dict["objects_boxes"], feed_dict["objects"], mode=mode)

        if mode == "spatial" and rel_viewpoint is not None:
            outputs = self.merge_pose_with_scene_embeddings(outputs, rel_viewpoint)
            outputs = self.do_viewpoint_transformation(outputs)

        # Always return the outputs (previously an unrecognised mode fell through to None).
        return outputs

In [None]:
encoder = Encoder()

In [None]:
encoder = encoder.cuda()  # rebind the same name (was misspelled `enoder`)

## Testing the Encoder

##### **Mode** : Node Embeddings

In [None]:
feed_dict_ = feed_dict_k

In [None]:
node_outputs_ = encoder(feed_dict_, mode="node")

In [None]:
rel_viewpoint_ = metadata["rel_viewpoint"]
rel_viewpoint_.shape

torch.Size([5, 1, 7])

In [None]:
image_features_ = encoder.feature_extractor(feed_dict_["images"])
scene_graph_output = encoder.scene_graph(image_features_, feed_dict_["objects_boxes"], feed_dict_["objects"], mode="node")

In [None]:
batch_ind = 3

In [None]:
len(scene_graph_output), scene_graph_output[batch_ind][0].shape

(5, torch.Size([2, 256]))

##### **Mode** : Spatial Embeddings

In [None]:
spatial_outputs_ = encoder(feed_dict_, mode="spatial", rel_viewpoint= rel_viewpoint_ )

Ading viewpoint information to spatial features
0
Adding pose to spatial embeddings
1
Adding pose to spatial embeddings
2
Adding pose to spatial embeddings
3
Adding pose to spatial embeddings
4
Adding pose to spatial embeddings
viewpoint transform on spatial embeddings
viewpoint transform on spatial embeddings
viewpoint transform on spatial embeddings
viewpoint transform on spatial embeddings
viewpoint transform on spatial embeddings


In [None]:
len(spatial_outputs_), spatial_outputs_[batch_ind][0].shape

(5, torch.Size([2, 256]))

In [None]:
len(spatial_outputs_), spatial_outputs_[batch_ind][1].shape

(5, torch.Size([2, 2, 256]))

## Matching the Nodes

In [None]:
feed_dict_k_ = feed_dict_k
feed_dict_q_ = feed_dict_q

In [None]:
output_k_ = encoder(feed_dict_k_, mode="node")

In [None]:
output_q_ = encoder(feed_dict_q_, mode="node")

In [None]:
from sklearn.neighbors import NearestNeighbors

In [None]:
def _greedy_nearest_pairs(pool_k, pool_q):
    """Greedy one-to-one matching: for each row of pool_k (in order), return the index
    of the nearest (Euclidean) row of pool_q that has not already been taken."""
    with torch.no_grad():
        order = torch.cdist(pool_k.detach(), pool_q.detach()).argsort(dim=1)
    taken = set()
    pairs = []
    for candidates in order.tolist():
        # Every row lists all q indices, so an untaken one always exists.
        match = next(j for j in candidates if j not in taken)
        taken.add(match)
        pairs.append(match)
    return pairs


def pair_embeddings(output_k, output_q, mode = "node"):
    """Reorder the objects of ``output_q`` so they line up with the objects of ``output_k``.

    For each batch element, k-view objects are greedily matched, in order, to their
    nearest not-yet-used q-view object by Euclidean distance between node embeddings.
    The q-view node embeddings are permuted accordingly. In 'spatial' mode, the
    pairwise spatial embeddings are also permuted along both object axes.

    Input:
        output_k, output_q: Encoder outputs, one entry per batch element supporting item
            assignment (e.g. lists). entry[0] holds node embeddings [num_obj, D]; in
            'spatial' mode entry[1] holds spatial embeddings [num_obj, num_obj, D'].
        mode: 'node' or 'spatial'.
    Returns:
        (output_k, output_q): the same list objects, updated in place. Node embeddings
        of both views are flattened to 2-D; results stay on the inputs' device.
    Raises:
        ValueError: unknown mode, mismatched batch sizes, or mismatched object counts.
    """
    if mode not in ("node", "spatial"):
        raise ValueError("Mode should be either node or spatial")

    if len(output_k) != len(output_q):
        raise ValueError(f"batch size mismatch: {len(output_k)} vs {len(output_q)}")

    for batch_ind, (entry_k, entry_q) in enumerate(zip(output_k, output_q)):
        num_obj = entry_k[0].shape[0]
        if entry_q[0].shape[0] != num_obj:
            raise ValueError(f"batch element {batch_ind}: object count mismatch "
                             f"({num_obj} vs {entry_q[0].shape[0]})")

        if mode == "spatial":
            for spatial in (entry_k[1], entry_q[1]):
                if tuple(spatial.shape[:2]) != (num_obj, num_obj):
                    raise ValueError(f"batch element {batch_ind}: spatial embeddings of shape "
                                     f"{tuple(spatial.shape)} do not match {num_obj} objects")

        # Flatten node features to [rows, D] (generalises the former hard-coded 256).
        nodes_k = entry_k[0].reshape(-1, entry_k[0].shape[-1])
        nodes_q = entry_q[0].reshape(-1, entry_q[0].shape[-1])
        entry_k[0] = nodes_k

        if num_obj == 0:
            entry_q[0] = nodes_q
            continue

        pairs = _greedy_nearest_pairs(nodes_k, nodes_q)
        # Index-based permutation keeps the device and the autograd graph intact.
        index = torch.tensor(pairs, dtype=torch.long, device=nodes_q.device)
        entry_q[0] = nodes_q[index]

        if mode == "spatial":
            spatial_index = index.to(entry_q[1].device)
            # new[a][b] = old[pairs[a]][pairs[b]]
            entry_q[1] = entry_q[1][spatial_index][:, spatial_index]

    return output_k, output_q

In [None]:
rearranged_output_k, rearranged_output_q = pair_embeddings(output_k_, output_q_, mode = "node")

[1, 0]
[1, 0]
[1, 0]
[1, 0]
[0, 1]


In [None]:
rearranged_output_k==output_k_

True

In [None]:
rearranged_output_q==output_q_

True

## Matching the spatial embeddings

In [None]:
output_k__ = encoder(feed_dict_k_, mode="spatial")
output_q__ = encoder(feed_dict_q_, mode="spatial")

In [None]:
#_,__ = pair_embeddings(output_k_, output_q_, mode = "spatial")


#Code breaking resolve later

In [None]:
output_k__[0][1]

tensor([[[-0.0172,  0.0110, -0.0169,  ..., -0.0446,  0.1148, -0.0295],
         [-0.0199,  0.0063, -0.0343,  ..., -0.0187,  0.1133, -0.0062]],

        [[-0.0129,  0.0145,  0.0046,  ..., -0.0250,  0.1126, -0.0103],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan]]],
       device='cuda:0', grad_fn=<DivBackward0>)

## Flatten the embeddings across batch

In [None]:
def stack_features_across_batch(output_feature_list, mode="node"):
    """Concatenate per-batch-element embeddings into a single 2-D tensor.

    Input:
        output_feature_list: Encoder output, one entry per batch element; entry[0] holds
            node embeddings and entry[1] spatial embeddings.
        mode: 'node' stacks entry[0], 'spatial' stacks entry[1]. Each tensor is flattened
            to rows of its last dimension before concatenation.
    Returns:
        Tensor of shape [total_rows, D].
    Raises:
        ValueError: unknown mode, or an empty output_feature_list.
    """
    if mode == "node":
        entry_index = 0
    elif mode == "spatial":
        entry_index = 1
    else:
        raise ValueError("Training mode not defined properly. It should be either 'node' or 'spatial'." )

    if not output_feature_list:
        raise ValueError("output_feature_list is empty; nothing to stack")

    # Flatten every element the same way (previously only the first was reshaped, the
    # spatial branch referenced an undefined `outputs`, and width was fixed at 256).
    # A single cat avoids the quadratic cost of concatenating inside a loop.
    return torch.cat(
        [entry[entry_index].reshape(-1, entry[entry_index].shape[-1])
         for entry in output_feature_list],
        dim=0,
    )

In [None]:
stacked_output_k = stack_features_across_batch(rearranged_output_k)

In [None]:
stacked_output_k.shape

torch.Size([10, 256])

In [None]:
stacked_output_k[3] == rearranged_output_k[1][0][1]

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, Tr

In [None]:
stacked_output_k[3]

torch.Size([10, 256])

In [None]:
rearranged_output_k

[[tensor([[-6.8771e-02, -5.9459e-02, -4.9381e-02, -4.5075e-02,  2.6175e-02,
           -5.5583e-02, -2.8446e-02, -5.4258e-02, -1.2026e-01, -3.3940e-02,
            7.4954e-02, -5.0520e-02,  7.1580e-02, -2.2378e-02, -1.0954e-01,
            6.4361e-03, -3.7401e-02,  3.3686e-02, -7.0083e-03, -1.2406e-01,
           -2.6524e-03,  1.3030e-02,  6.4491e-04,  8.5586e-03, -5.0108e-02,
           -1.0728e-01,  8.1664e-02, -4.7497e-02,  5.5873e-04,  5.9002e-02,
            7.8673e-02,  9.3005e-02,  2.6180e-02, -6.2063e-02,  4.2883e-02,
           -7.0917e-03, -6.1418e-02, -8.3136e-02, -3.2738e-02, -9.6146e-02,
           -8.8633e-02,  8.9317e-02, -7.7961e-02,  1.0480e-02, -1.8010e-01,
           -1.2887e-01,  1.1231e-02, -2.2957e-02, -6.3313e-02,  2.9969e-02,
            3.5156e-04,  2.4619e-02,  6.7547e-02,  3.9243e-04, -5.2683e-02,
            1.1091e-01,  5.5613e-02,  9.2998e-02,  4.0873e-02,  1.7603e-01,
           -1.4293e-02,  1.5070e-01,  1.4609e-01,  5.6070e-02,  2.2137e-04,
            