In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch import nn

In [4]:
class HierarchicalMemory(nn.Module):
    """Two-level learned memory that is read with soft attention.

    The memory stores ``num_categories`` category keys and, for every category,
    ``num_attributes_per_category`` attribute keys. A query row is the concatenation
    of a category part (the first ``categories_key_size`` features) and an attribute
    part (the remaining ``attributes_key_size`` features). The category part attends
    over the category keys; the attribute part attends over the attribute keys of
    every category. The read-out is the attribute-level result of each category,
    mixed with the category attention weights.

    Args:
        num_categories: Number of category slots.
        num_attributes_per_category: Number of attribute slots under each category.
        categories_key_size: Feature size of a category key (and of the category
            part of the query).
        attributes_key_size: Feature size of an attribute key (and of the attribute
            part of the query).
        value_size: Feature size of the separate value table. When ``None`` no value
            table is created and the attribute keys double as values.
        dropout: Dropout probability applied to the query in ``forward``.
    """

    def __init__(self,
                 num_categories=4,
                 num_attributes_per_category=5,
                 categories_key_size=256,
                 attributes_key_size=256,
                 value_size=512,
                 dropout=0.1,
                ):
        super().__init__()

        self.num_categories = num_categories
        self.num_attributes_per_category = num_attributes_per_category

        self.categories_key_size = categories_key_size
        self.attributes_key_size = attributes_key_size
        self.value_size = value_size

        # torch.empty replaces the legacy torch.FloatTensor(sizes) constructor. The
        # storage is uninitialised here and filled by _init_params() below.
        self.categories = nn.Parameter(torch.empty(num_categories, categories_key_size))
        self.attributes = nn.Parameter(
            torch.empty(num_categories, num_attributes_per_category, attributes_key_size)
        )

        if value_size is not None:
            self.values = nn.Parameter(
                torch.empty(num_categories, num_attributes_per_category, value_size)
            )
        else:
            # Register the optional parameter explicitly; self.values still reads as None.
            self.register_parameter('values', None)

        self.dropout = nn.Dropout(dropout)
        self._init_params()

    def _init_params(self):
        """Fill every memory table in place with samples from a standard normal."""
        nn.init.normal_(self.categories)
        nn.init.normal_(self.attributes)
        if self.values is not None:
            nn.init.normal_(self.values)

    def forward(self, query):
        """Read from the memory.

        Args:
            query: Tensor of shape
                ``(batch, categories_key_size + attributes_key_size)``.

        Returns:
            A tuple ``(values, category_values, category_attention, attribute_attention)``:
            ``values`` is ``(batch, value_size)``, or ``(batch, attributes_key_size)``
            when there is no value table; ``category_values`` is
            ``(batch, categories_key_size)``; ``category_attention`` is
            ``(batch, num_categories)``; ``attribute_attention`` is
            ``(batch, num_categories, num_attributes_per_category)``.

        Raises:
            ValueError: If ``query`` is not 2-D or its width does not equal
                ``categories_key_size + attributes_key_size``.
        """
        expected_width = self.categories_key_size + self.attributes_key_size
        if query.dim() != 2 or query.size(1) != expected_width:
            raise ValueError(
                'query must have shape (batch, {}), got {}'.format(
                    expected_width, tuple(query.shape))
            )

        query = self.dropout(query)

        # Category level: shared with get_category_attention so the two cannot drift apart.
        category_attention = self.get_category_attention(query)

        # Attribute level: (C, A, K) @ (K, B) -> (C, A, B), then batch-first (B, C, A),
        # normalised over the attributes of each category.
        attribute_query = query[:, self.categories_key_size:]
        attribute_scores = torch.matmul(self.attributes, attribute_query.t())
        attribute_attention = nn.functional.softmax(attribute_scores.permute(2, 0, 1), dim=2)

        category_values = torch.matmul(category_attention, self.categories)

        # Without a value table the attribute keys are read back as values.
        memory = self.values if self.values is not None else self.attributes
        # (B, C, A, 1) * (C, A, V) summed over A -> (B, C, V)
        attribute_values = (attribute_attention.unsqueeze(3) * memory).sum(dim=2)

        # Mix the per-category read-outs: (B, 1, C) @ (B, C, V) -> (B, 1, V) -> (B, V)
        values = torch.matmul(category_attention.unsqueeze(1), attribute_values).squeeze(1)

        return values, category_values, category_attention, attribute_attention

    def get_category_attention(self, query):
        """Return softmax attention over categories, shape ``(batch, num_categories)``.

        Only the first ``categories_key_size`` columns of ``query`` are used. No dropout
        is applied here; ``forward`` applies it before calling this method.
        """
        category_query = query[:, :self.categories_key_size]
        category_scores = torch.matmul(category_query, self.categories.t())
        return nn.functional.softmax(category_scores, dim=1)

In [6]:
# Smoke test: a tiny memory with no separate value table (value_size=None).
hm = HierarchicalMemory(
    num_categories=3,
    num_attributes_per_category=4,
    categories_key_size=5,
    attributes_key_size=6,
    value_size=None,
)
# Keep handles on the learned key tables for ad-hoc inspection.
attributes = hm.attributes
categories = hm.categories
# Query width is 5 + 6 = 11; show the size of the first tensor returned by the module.
hm(torch.ones(2, 11))[0].size()

torch.Size([2, 6])

In [4]:
from mac import HierarchicalMemory

ModuleNotFoundError: No module named 'mac'

In [None]:
# Scratch cell: draw 8 samples from a standard normal (no seed is set in the visible cells).
torch.randn(8)

In [57]:
# Broadcasting check for a weighted read-out: softmax weights of shape (2, 5, 6) are
# expanded to (2, 5, 6, 1), multiplied with a (5, 6, 8) table (broadcast over the leading
# dimension) and summed over dim 2, leaving a (2, 5, 8) result.
(nn.functional.softmax(torch.randn(2, 5, 6), dim=2).unsqueeze(3) *  torch.randn(5, 6, 8)).sum(dim=2)

tensor([[[-0.4126,  0.6920,  0.2098, -0.2235, -0.3082,  0.2059,  0.4995,
           0.1840],
         [-0.8027, -0.1491, -0.2894, -0.7281,  0.5694, -0.1165, -0.1300,
          -0.0618],
         [-0.5812,  1.0318, -0.1973, -0.1821, -0.3681, -0.3152, -0.9683,
          -0.3389],
         [ 0.3253,  0.0450, -0.0862, -0.0870, -0.3185,  0.1147,  0.0950,
          -0.5925],
         [ 0.4874, -0.6049, -0.0925,  0.0021,  0.3434,  0.9950, -0.0963,
           0.3307]],

        [[-0.4825, -0.1253, -0.1234, -0.1229,  0.3158, -0.5058, -0.2277,
           0.6234],
         [-1.0779, -0.6746, -0.4625,  0.3930, -0.2796, -0.1995, -0.3210,
          -0.5292],
         [-0.5685,  1.2189,  0.1967,  0.2939, -0.1047,  0.0913, -1.1220,
          -0.4336],
         [ 0.2339, -0.0343,  0.0202, -0.9019, -0.5439,  0.3890,  0.2416,
          -0.6860],
         [ 0.7302, -0.9310,  0.3548,  0.2108, -0.2683,  0.9025, -0.1845,
           0.1122]]])

torch.Size([2, 11])
torch.Size([2, 3])
torch.Size([3, 4, 6])
torch.Size([3, 4, 2])
torch.Size([2, 3, 4])
torch.Size([2, 3])
torch.Size([2, 3, 6])


tensor([[ 0.1142, -0.0093,  1.2074,  0.8915, -0.6113,  0.2692],
        [-0.0981,  1.2242, -0.0969,  0.5737,  0.4226, -0.1635]],
       grad_fn=<SqueezeBackward1>)

In [7]:
import sys

# Put the relative 'code/' directory first on the module search path. The path is resolved
# against the current working directory, so the notebook must be run from the directory that
# contains it. Presumably the project modules imported later live there; verify.
sys.path.insert(0, 'code/')

In [8]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchsummaryX import summary


from mac import MACNetwork
from utils import load_vocab
from datasets import ClevrDataset, collate_fn

In [9]:
from config import cfg_from_file, __C, cfg

# Load settings from the YAML file (path is relative to the working directory).
cfg_from_file('cfg/clevr_train_mac.yml')
# Override the device settings on the config object.
# NOTE(review): presumably this forces CPU execution; confirm how CUDA and GPU_ID are
# consumed in config.py and by the model.
__C.CUDA = False
__C.GPU_ID = '-1'
# NOTE(review): the overrides are written to __C but cfg is what gets passed on; this
# assumes both names refer to the same object, so confirm in config.py.
vocab = load_vocab(cfg)
# cfg.TRAIN.RECV_OBJECTS = False

  yaml_cfg = edict(yaml.load(f))


In [10]:
import os

# Root directory of the CLEVR data. Set the CLEVR_DATA_DIR environment variable to point
# at a different location; the fallback is the original hard-coded path, so existing
# setups behave exactly as before.
CLEVR_DATA_DIR = os.environ.get(
    'CLEVR_DATA_DIR',
    '/Users/sebamenabar/Documents/datasets/CLEVR/data',
)

# Validation split of the dataset, read from CLEVR_DATA_DIR.
ds = ClevrDataset(
    data_dir=CLEVR_DATA_DIR,
    # Alternative inputs, currently disabled:
    # img_dir='/Users/sebamenabar/Documents/datasets/CLEVR/CLEVR_v1.0/images/',
    # scenes_json='/Users/sebamenabar/Documents/TAIA/individual/sm/data/clevr/train/scenes.json',
    # raw_image=True,
    split='val',
)

In [11]:
# Shuffled batches of two for quick interactive inspection; the last partial batch is kept.
loader = DataLoader(
    dataset=ds,
    batch_size=2,
    shuffle=True,
    num_workers=2,
    drop_last=False,
    collate_fn=collate_fn,
)

In [12]:
# Fetch one batch. The loader shuffles, so a different batch comes back on each run.
b = next(iter(loader))

In [13]:
# Build the network from the loaded config and vocabulary.
# NOTE(review): max_step=1 presumably limits the model to a single reasoning step; confirm in mac.py.
model = MACNetwork(cfg=cfg, max_step=1, vocab=vocab)
# Run one forward pass on the fetched batch and display the raw output tensor.
model(b['image'], b['question'], b['question_length'])

tensor([[-0.0121, -0.5630,  0.8687,  0.2026, -0.2280,  0.3494, -0.1031, -0.7273,
         -0.9819, -0.1984,  0.6235,  0.2583,  0.3696,  1.0336, -1.3561, -1.3453,
         -0.2055, -1.2427, -0.1769, -1.0707,  0.0057, -1.0505, -0.5859,  0.4879,
         -2.1596, -0.9108, -0.0133, -1.5043],
        [-0.8363,  0.7582,  0.8435, -0.2574, -0.2468,  0.0294,  0.2364, -1.3489,
         -1.3964,  0.2389, -0.0677, -1.0775,  0.7257,  1.1218, -1.3847, -0.9917,
          1.0029, -0.7607, -0.6859,  0.2610, -0.2687, -0.7053, -1.5134,  0.3882,
         -1.4377,  0.6135,  0.2812, -0.3642]], grad_fn=<AddmmBackward>)