In [5]:
# Reload edited project modules (e.g. `mac`, imported below) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [1]:
import torch
from torch import nn

In [2]:
class HierarchicalMemory(nn.Module):
    """Two-level (category -> attribute) soft key-value memory.

    The (dropped-out) query is split in two:

    * the first ``categories_key_size`` features attend over the ``C`` category
      keys, giving one weight per category;
    * the remaining ``attributes_key_size`` features attend, independently
      inside every category, over that category's ``A`` attribute keys.

    Each category is read out as the attribute-weighted sum of its values.
    Those per-category read-outs are then mixed with the category weights.

    Args:
        num_categories: number of categories ``C``.
        num_attributes_per_category: number of attributes ``A`` per category.
        categories_key_size: size ``Kc`` of each category key.
        attributes_key_size: size ``Ka`` of each attribute key.
        value_size: size ``V`` of each stored value. If ``None``, no value
            table is allocated and the attribute keys themselves are read out
            (the output then has ``Ka`` features).
        dropout: dropout probability applied to the incoming query.

    Shapes:
        query:  ``(B, Kc + Ka)``
        output: ``(B, V)``, or ``(B, Ka)`` when ``value_size is None``.
    """

    def __init__(self,
                 num_categories,
                 num_attributes_per_category,
                 categories_key_size=64,
                 attributes_key_size=64,
                 value_size=64,
                 dropout=0.1,
                ):
        super().__init__()

        self.num_categories = num_categories
        self.num_attributes_per_category = num_attributes_per_category

        self.categories_key_size = categories_key_size
        self.attributes_key_size = attributes_key_size
        self.value_size = value_size

        # torch.empty replaces the deprecated torch.FloatTensor(*sizes)
        # constructor; contents are filled by _init_params() below.
        self.categories = nn.Parameter(torch.empty(num_categories, categories_key_size))
        self.attributes = nn.Parameter(
            torch.empty(num_categories, num_attributes_per_category, attributes_key_size))

        if value_size is not None:
            self.values = nn.Parameter(
                torch.empty(num_categories, num_attributes_per_category, value_size))
        else:
            # Register an explicitly absent parameter; `self.values` reads as None.
            self.register_parameter('values', None)

        self.dropout = nn.Dropout(dropout)
        self._init_params()

    def _init_params(self):
        """Fill every memory tensor with samples from a standard normal."""
        nn.init.normal_(self.categories)
        nn.init.normal_(self.attributes)
        if self.values is not None:
            nn.init.normal_(self.values)

    def forward(self, query):
        """Read from the memory.

        Args:
            query: tensor of shape ``(B, categories_key_size + attributes_key_size)``.

        Returns:
            Tensor of shape ``(B, value_size)``, or ``(B, attributes_key_size)``
            when the module was built with ``value_size=None``.

        Raises:
            ValueError: if ``query`` is not 2-D with the expected feature size.
        """
        expected = self.categories_key_size + self.attributes_key_size
        if query.dim() != 2 or query.size(1) != expected:
            raise ValueError(
                f'query must have shape (batch, {expected}), got {tuple(query.shape)}')

        query = self.dropout(query)
        category_query = query[:, :self.categories_key_size]   # (B, Kc)
        attribute_query = query[:, self.categories_key_size:]  # (B, Ka)

        # Category weights: (B, Kc) x (C, Kc)^T -> (B, C), normalised over C.
        category_weights = nn.functional.softmax(
            torch.matmul(category_query, self.categories.t()), dim=1)

        # Attribute weights per (batch, category): (B, C, A), normalised over A.
        # einsum keeps the batch axis first; a plain matmul of the (B, Ka)
        # query against the (C, Ka, A) batch would broadcast to (C, B, A).
        attribute_weights = nn.functional.softmax(
            torch.einsum('bk,cak->bca', attribute_query, self.attributes), dim=2)

        # Without a value table the attribute keys are read out directly.
        memory = self.values if self.values is not None else self.attributes  # (C, A, D)

        # Per-category read-out: (B, C, A) x (C, A, D) -> (B, C, D).
        category_readout = torch.einsum('bca,cad->bcd', attribute_weights, memory)

        # Mix the categories: (B, 1, C) x (B, C, D) -> (B, 1, D) -> (B, D).
        return torch.matmul(category_weights.unsqueeze(1), category_readout).squeeze(1)

In [34]:
from mac import HierarchicalMemory

In [56]:
# Scratch: eight draws from N(0, 1). Unseeded, so the displayed values differ on every run.
torch.randn(size=(8,))

tensor([ 0.9812, -1.2954,  0.1503,  0.1214,  2.4526, -0.2100,  0.6897, -0.3006])

In [57]:
# Scratch check of a broadcast weighted read-out:
# weights (B=2, C=5, A=6), softmax-normalised over A, times values (C=5, A=6, V=8),
# summed over A -> (B, C, V) = (2, 5, 8).
(
    nn.functional.softmax(torch.randn(2, 5, 6), dim=2)[..., None]  # (B, C, A, 1)
    * torch.randn(5, 6, 8)                                         # (C, A, V)
).sum(dim=2)                                                       # -> (B, C, V)

tensor([[[-0.4126,  0.6920,  0.2098, -0.2235, -0.3082,  0.2059,  0.4995,
           0.1840],
         [-0.8027, -0.1491, -0.2894, -0.7281,  0.5694, -0.1165, -0.1300,
          -0.0618],
         [-0.5812,  1.0318, -0.1973, -0.1821, -0.3681, -0.3152, -0.9683,
          -0.3389],
         [ 0.3253,  0.0450, -0.0862, -0.0870, -0.3185,  0.1147,  0.0950,
          -0.5925],
         [ 0.4874, -0.6049, -0.0925,  0.0021,  0.3434,  0.9950, -0.0963,
           0.3307]],

        [[-0.4825, -0.1253, -0.1234, -0.1229,  0.3158, -0.5058, -0.2277,
           0.6234],
         [-1.0779, -0.6746, -0.4625,  0.3930, -0.2796, -0.1995, -0.3210,
          -0.5292],
         [-0.5685,  1.2189,  0.1967,  0.2939, -0.1047,  0.0913, -1.1220,
          -0.4336],
         [ 0.2339, -0.0343,  0.0202, -0.9019, -0.5439,  0.3890,  0.2416,
          -0.6860],
         [ 0.7302, -0.9310,  0.3548,  0.2108, -0.2683,  0.9025, -0.1845,
           0.1122]]])

In [65]:
# Smoke test: 3 categories x 4 attributes, key sizes 5 + 6 = 11 query features,
# no value table (so the read-out has attributes_key_size = 6 features).
# NOTE(review): once the `from mac import HierarchicalMemory` cell has run, this name
# refers to mac's implementation, not the class defined earlier in this notebook.
hm = HierarchicalMemory(3, 4, 5, 6, value_size=None)
hm.eval()  # disable query dropout so the read-out is deterministic for a fixed input
attributes, categories = hm.attributes, hm.categories  # handles for interactive inspection
hm_out = hm(torch.ones(2, 11))
assert tuple(hm_out.shape) == (2, 6)
hm_out

torch.Size([2, 11])
torch.Size([2, 3])
torch.Size([3, 4, 6])
torch.Size([3, 4, 2])
torch.Size([2, 3, 4])
torch.Size([2, 3])
torch.Size([2, 3, 6])


tensor([[ 0.1142, -0.0093,  1.2074,  0.8915, -0.6113,  0.2692],
        [-0.0981,  1.2242, -0.0969,  0.5737,  0.4226, -0.1635]],
       grad_fn=<SqueezeBackward1>)

In [6]:
import sys

# Make the project's `code/` directory importable (the mac, utils, datasets and
# config modules imported below). The path is relative to the notebook's working
# directory. The guard keeps re-runs of this cell from stacking duplicate entries.
CODE_DIR = 'code/'
if CODE_DIR not in sys.path:
    sys.path.insert(0, CODE_DIR)

In [7]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchsummaryX import summary


from mac import MACNetwork
from utils import load_vocab
from datasets import ClevrDataset, collate_fn

In [10]:
from config import cfg_from_file, __C, cfg

# Load the MAC training config from YAML, then override CUDA / GPU_ID
# (presumably to run on CPU in this interactive session). The overrides come after
# cfg_from_file(...), presumably so the YAML values do not overwrite them.
cfg_from_file('cfg/clevr_train_mac.yml')
# NOTE(review): overrides are written through `__C` while `cfg` is what gets passed on;
# this presumably relies on both names referring to the same config object — confirm in config.py.
__C.CUDA = False
__C.GPU_ID = '-1'
# Vocabulary handed to MACNetwork further down.
vocab = load_vocab(cfg)
# Disabled override, left for reference: cfg.TRAIN.RECV_OBJECTS = False

In [14]:
import os

# Dataset location. Set the CLEVR_DATA_DIR environment variable to point at a local copy.
# The original absolute path stays as the fallback, so existing setups behave as before.
CLEVR_DATA_DIR = os.environ.get(
    'CLEVR_DATA_DIR', '/Users/sebamenabar/Documents/datasets/CLEVR/data')

ds = ClevrDataset(
    data_dir=CLEVR_DATA_DIR,
    # img_dir='/Users/sebamenabar/Documents/datasets/CLEVR/CLEVR_v1.0/images/',
    # scenes_json='/Users/sebamenabar/Documents/TAIA/individual/sm/data/clevr/train/scenes.json',
    # raw_image=True,
    split='val',
)

In [15]:
# Small shuffled loader (two samples per batch). Below it is used only to pull one
# batch for the forward-pass smoke test.
loader = DataLoader(
    dataset=ds,
    batch_size=2,
    shuffle=True,
    num_workers=2,
    drop_last=False,
    collate_fn=collate_fn,
)

In [16]:
# Pull a single (shuffled) batch for the forward-pass smoke test below. It is indexed
# there by 'image', 'question' and 'question_length'.
b = next(iter(loader))

In [68]:
# Forward-pass smoke test of the full network on the sampled batch
# (max_step=1, presumably a single MAC reasoning step).
model = MACNetwork(cfg=cfg, max_step=1, vocab=vocab)
model(
    b['image'],
    b['question'],
    b['question_length'],
)

tensor([[ 0.1677,  1.4895, -0.4643,  0.0921,  2.1208,  0.0596, -0.3440,  0.1111,
          0.5066,  0.0738, -0.8081,  0.6521,  0.7132,  0.7500, -0.7690,  0.0585,
          0.3345,  0.1306,  0.5794,  1.6670, -0.4928,  1.6715,  0.1323, -0.4607,
          0.9342,  0.8394, -0.6713,  2.0152],
        [ 0.4098,  0.7664,  0.5154, -0.2484,  1.2958,  0.8197, -0.7085, -0.8590,
          0.1597, -0.1527, -1.0427,  1.0002,  1.4161,  0.7169, -0.0661,  0.8626,
          1.4626,  0.2948,  0.0471,  0.9093, -1.0565,  0.4646,  0.7410, -0.8162,
          0.3647,  0.6035, -0.1683,  2.1071]], grad_fn=<AddmmBackward>)