In [None]:
import sys
sys.path.append("../")

In [None]:
# default_exp model.model

In [None]:
#model

import torch
import torch.nn as nn
from random import sample
import torch.nn as nn
import torch

from core.model.encoder import Encoder

In [None]:
#model

class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722

    The query encoder is trained by gradient descent; the key encoder is a
    momentum (exponential moving average) copy of it. Normalized key
    embeddings are stored column-wise in a fixed-size circular queue and are
    used as negatives for the contrastive logits. When `cluster_result` is
    passed to `forward`, additional prototype-level logits are produced.
    """
    def __init__(self, base_encoder=None, dim=256, r=16384, m=0.999, T=0.1, mlp=False):
        """
        base_encoder: currently unused (both encoders are built as `Encoder(dim=dim)`);
            kept so existing callers keep working
        dim: feature dimension (default: 256)
        r: queue size; number of negative samples/prototypes (default: 16384)
        m: momentum for updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.1)
        mlp: currently unused; kept so existing callers keep working
        """
        super(MoCo, self).__init__()

        self.r = r
        self.m = m
        self.T = T

        # create the encoders
        self.encoder_q = Encoder(dim=dim)
        self.encoder_k = Encoder(dim=dim)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue: one unit-norm column per stored key
        self.register_buffer("queue", torch.randn(dim, r))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        # write position of the next key inside the circular queue
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Write `keys` (one key per row) into the circular queue starting at the
        current pointer, overwriting the oldest entries, and advance the pointer.

        The batch size does not have to divide the queue size: a batch that
        runs past the end of the queue wraps around to the beginning. If the
        batch holds more keys than the queue can store, only the last `r`
        keys are kept.
        """
        batch_size = keys.shape[0]
        if batch_size == 0:
            return

        # a batch larger than the queue would overwrite itself; keep the tail
        if batch_size > self.r:
            keys = keys[-self.r:]
            batch_size = self.r

        ptr = int(self.queue_ptr)
        end = ptr + batch_size

        # replace the keys at ptr (dequeue and enqueue)
        if end <= self.r:
            self.queue[:, ptr:end] = keys.T
        else:
            # wrap around: fill up to the end of the queue, then restart at column 0
            first = self.r - ptr
            self.queue[:, ptr:] = keys[:first].T
            self.queue[:, :end - self.r] = keys[first:].T

        self.queue_ptr[0] = end % self.r  # move pointer

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***

        Returns the shuffled slice of the gathered batch for this rank and the
        index needed to undo the shuffle.
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index, placed on the same device as the gathered batch
        idx_shuffle = torch.randperm(batch_size_all).to(x_gather.device)

        # broadcast to all gpus so every rank uses rank 0's permutation
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, feed_dict_q, feed_dict_k=None, is_eval=False, cluster_result=None, index=None):
        """
        Input:
            feed_dict_q: a batch of query images and bounding boxes
            feed_dict_k: a batch of key images and bounding boxes
            is_eval: return momentum embeddings (used for clustering)
            cluster_result: cluster assignments, centroids, and density
                (dict with keys 'im2cluster', 'centroids', 'density', each a
                sequence with one entry per clustering)
            index: indices for training samples
        Output:
            if is_eval: the normalized key-encoder embedding of feed_dict_q
            otherwise: logits, targets, proto_logits, proto_targets
                (the last two are None when cluster_result is None)

        Both encoders are called with a feed dict and return a 3-tuple; only
        the second element is used as the embedding here.
        """

        if is_eval:
            _, k, __ = self.encoder_k(feed_dict_q)
            k = nn.functional.normalize(k, dim=1)
            return k

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            _, k, __ = self.encoder_k(feed_dict_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

        # compute query features
        _, q, __ = self.encoder_q(feed_dict_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: Nxr
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+r)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators (column 0 is always the positive);
        # created on the logits' device so the model also runs off-GPU
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        # prototypical contrast
        if cluster_result is None:
            return logits, labels, None, None

        proto_labels = []
        proto_logits = []
        for im2cluster, prototypes, density in zip(cluster_result['im2cluster'],
                                                   cluster_result['centroids'],
                                                   cluster_result['density']):
            # get positive prototypes
            pos_proto_id = im2cluster[index]
            pos_prototypes = prototypes[pos_proto_id]

            # candidate negatives: every cluster id 0..max (inclusive) that is
            # not a positive for this batch
            num_clusters = int(im2cluster.max()) + 1
            neg_candidates = sorted(set(range(num_clusters)) - set(pos_proto_id.tolist()))
            # sample up to r negative prototypes; random.sample needs a
            # sequence (sets are rejected on Python >= 3.11) and cannot draw
            # more items than are available
            neg_proto_id = sample(neg_candidates, min(self.r, len(neg_candidates)))
            neg_prototypes = prototypes[neg_proto_id]

            proto_selected = torch.cat([pos_prototypes, neg_prototypes], dim=0)

            # compute prototypical logits
            logits_proto = torch.mm(q, proto_selected.t())

            # targets for prototype assignment: row i matches positive prototype i
            labels_proto = torch.arange(q.size(0), dtype=torch.long, device=q.device)

            # scaling temperatures for the selected prototypes
            neg_ids = torch.tensor(neg_proto_id, dtype=torch.long, device=pos_proto_id.device)
            temp_proto = density[torch.cat([pos_proto_id, neg_ids], dim=0)]
            logits_proto /= temp_proto

            proto_labels.append(labels_proto)
            proto_logits.append(logits_proto)
        return logits, labels, proto_logits, proto_labels


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Collect `tensor` from every process in the default process group and
    return the per-process copies concatenated along dimension 0.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()

    # one receive buffer per rank, shaped like the local tensor
    gathered = []
    for _ in range(world_size):
        gathered.append(torch.ones_like(tensor))

    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)


In [None]:
import warnings
warnings.filterwarnings("ignore")

import torch
from torch.utils.data import Dataset, DataLoader
from core.dataloader import GQNDataset_pdisco, collate_boxes

In [None]:
# Build the dataset from a file list and wrap it in a shuffling DataLoader
# (batch of 5, batched with the project's `collate_boxes`).
# NOTE(review): the path below is an absolute, machine-specific location;
# consider a configurable data directory so the notebook runs elsewhere.
train_dataset = GQNDataset_pdisco(root_dir='/home/mprabhud/dataset/clevr_veggies/npys/be_lt.txt')
train_loader = DataLoader(train_dataset, batch_size=5, shuffle=True, collate_fn=collate_boxes)

Initialised..... 27495  files...


In [None]:
# Pull a single batch from the loader and stop; each batch unpacks into the
# query feed dict, the key feed dict, and metadata.
# NOTE(review): if the loader is empty, the three names are silently left undefined.
for b in train_loader:
    feed_dict_q, feed_dict_k, metadata = b
    break

In [None]:
# Move the image tensors of both views onto the GPU (the other dict entries are left untouched here).
feed_dict_k["images"] = feed_dict_k["images"].cuda()
feed_dict_q["images"] = feed_dict_q["images"].cuda()

In [None]:
# Instantiate MoCo with its default hyper-parameters and place it on the GPU.
model = MoCo().cuda()

In [None]:
# Training-mode forward pass on the query/key pair (no cluster_result, so no prototype terms).
outputs = model(feed_dict_q, feed_dict_k)

In [None]:
# Display the returned tuple: (logits, labels, proto_logits, proto_labels).
outputs

(tensor([[ 4.7918,  0.3855, -0.4694,  ..., -0.1715,  0.0989,  0.2826],
         [ 6.9984,  0.5304, -0.6606,  ..., -0.8775,  0.5560,  0.1995],
         [ 6.9984,  0.5304, -0.6606,  ..., -0.8775,  0.5560,  0.1995],
         ...,
         [ 4.7955,  0.9802, -0.0634,  ..., -1.1479,  0.3556,  1.0232],
         [ 4.4394,  0.9802, -0.0634,  ..., -1.1479,  0.3556,  1.0232],
         [ 5.5887,  0.2923, -0.3802,  ..., -0.9283, -0.3126,  0.4356]],
        device='cuda:0', grad_fn=<DivBackward0>),
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0'),
 None,
 None)

In [None]:
# List the entries available in the query feed dict.
feed_dict_q.keys()

dict_keys(['images', 'objects', 'objects_boxes', 'view'])

In [None]:
# Shape of the key-view object boxes tensor.
feed_dict_k["objects_boxes"].shape

torch.Size([15, 4])

In [None]:
# Shape of the key-view image batch.
feed_dict_k["images"].shape

torch.Size([5, 3, 256, 256])

In [None]:
# Overwrite the key images with random values of shape (5, 3, 256, 384).
# NOTE(review): no .cuda() here, unlike the earlier cell — presumably this
# tensor is on the CPU; confirm that is intended before re-running the model.
feed_dict_k["images"] = torch.rand((5, 3, 256, 384))

In [None]:
# Overwrite the query images with random values of shape (5, 3, 256, 384).
# NOTE(review): no .cuda() here, unlike the earlier cell — presumably this
# tensor is on the CPU; confirm that is intended before re-running the model.
feed_dict_q["images"] = torch.rand((5, 3, 256, 384))

In [None]:
# Inspect the key-view object boxes.
feed_dict_k["objects_boxes"]

tensor([[ 82.7502,   0.0000, 163.7342,   0.0000],
        [104.9818,   0.0000, 152.8877,   0.0000],
        [ 87.1218,   0.0000, 149.3486,   0.0000],
        [255.0000,  69.7427, 255.0000, 124.0020],
        [255.0000,  82.1422, 255.0000, 127.0757],
        [255.0000, 107.3452, 255.0000, 206.2461],
        [  0.0000,   0.0000,   0.0000,  20.7141],
        [  0.0000,   0.0000,   7.5103,  25.6727],
        [  2.0656,   0.0000,  81.2577,   5.7856],
        [255.0000,   0.0000, 255.0000,   8.3311],
        [226.3104,   0.0000, 255.0000,   0.0000],
        [255.0000,   0.0000, 255.0000,  37.2964],
        [  0.0000,   0.0000,  75.9146,  42.3128],
        [  0.0000,   0.0000,  10.0282,  46.0512],
        [  0.0000,   6.3393,   0.0000,  43.8361]], device='cuda:0')

In [None]:
# Inspect the key-view image tensor.
# NOTE(review): this prints the whole tensor; `.shape` or a small slice would keep the output readable.
feed_dict_k["images"]

tensor([[[[0.4157, 0.4275, 0.4275,  ..., 0.4314, 0.4314, 0.4392],
          [0.4235, 0.4235, 0.4275,  ..., 0.4392, 0.4392, 0.4471],
          [0.4235, 0.4235, 0.4275,  ..., 0.4392, 0.4392, 0.4471],
          ...,
          [0.4627, 0.4667, 0.4627,  ..., 0.4471, 0.4471, 0.4471],
          [0.4627, 0.4627, 0.4627,  ..., 0.4471, 0.4471, 0.4471],
          [0.4667, 0.4627, 0.4627,  ..., 0.4471, 0.4471, 0.4431]],

         [[0.4157, 0.4275, 0.4275,  ..., 0.4314, 0.4314, 0.4392],
          [0.4235, 0.4235, 0.4275,  ..., 0.4392, 0.4392, 0.4471],
          [0.4235, 0.4235, 0.4275,  ..., 0.4392, 0.4392, 0.4471],
          ...,
          [0.4627, 0.4667, 0.4588,  ..., 0.4471, 0.4471, 0.4471],
          [0.4627, 0.4627, 0.4588,  ..., 0.4471, 0.4471, 0.4471],
          [0.4627, 0.4627, 0.4627,  ..., 0.4471, 0.4471, 0.4431]],

         [[0.4157, 0.4275, 0.4275,  ..., 0.4275, 0.4314, 0.4392],
          [0.4235, 0.4235, 0.4275,  ..., 0.4353, 0.4392, 0.4471],
          [0.4235, 0.4235, 0.4275,  ..., 0

In [None]:
# Inspect the key-view object boxes again.
# NOTE(review): the values shown differ from the earlier display of the same
# entry, so the cells were presumably run against a different batch — confirm execution order.
feed_dict_k["objects_boxes"]

tensor([[135.1649,   0.0000, 161.1292,  31.9892],
        [ 82.1292,  19.1714,  98.9613,  55.6602],
        [109.6565,   0.0000, 136.7068,  40.1612],
        [ 83.2603,   0.0000, 112.0836,   0.0000],
        [149.3388,   0.0000, 195.2714,  22.0639],
        [142.0847,   0.0000, 156.1275,   0.0000],
        [  0.0000,  14.8198,   0.0000,  61.5764],
        [ 43.6216,  14.2285,  78.9726,  52.6959],
        [  0.0000,  18.6559,   2.4447,  67.6315],
        [ 44.2445, 255.0000, 214.4792, 255.0000],
        [220.5551, 255.0000, 255.0000, 255.0000],
        [  0.0000, 255.0000,  94.6047, 255.0000],
        [  0.0000,  86.0737,   0.0000, 147.6474],
        [  0.0000,  69.2935,   0.0000, 108.4176],
        [  0.0000, 116.3612,   0.0000, 179.0138]], device='cuda:0')