In [8]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch import nn

In [4]:
class HierarchicalMemory(nn.Module):
    """Two-level learned memory addressed with a single concatenated query.

    The module owns three learnable tensors:

    * ``categories``  -- ``(num_categories, categories_key_size)`` category keys.
    * ``attributes``  -- ``(num_categories, num_attributes_per_category,
      attributes_key_size)`` attribute keys, one group per category.
    * ``values``      -- ``(num_categories, num_attributes_per_category,
      value_size)`` attribute values, or ``None`` when ``value_size`` is
      ``None``; in that case the attribute keys double as the values.

    A query row is split in two: the first ``categories_key_size`` columns are
    scored against the category keys, the remaining columns against every
    attribute key.

    Args:
        num_categories: Number of category slots.
        num_attributes_per_category: Number of attribute slots in each category.
        categories_key_size: Width of a category key (and of the first query part).
        attributes_key_size: Width of an attribute key (and of the second query part).
        value_size: Width of a stored value, or ``None`` to read from the
            attribute keys instead of a separate value tensor.
        dropout: Dropout probability applied to the query in ``forward``.
    """

    def __init__(self,
                 num_categories=4,
                 num_attributes_per_category=5,
                 categories_key_size=256,
                 attributes_key_size=256,
                 value_size=512,
                 dropout=0.1,
                ):
        super().__init__()

        self.num_categories = num_categories
        self.num_attributes_per_category = num_attributes_per_category

        self.categories_key_size = categories_key_size
        self.attributes_key_size = attributes_key_size
        self.value_size = value_size

        # torch.empty replaces the legacy torch.FloatTensor(*sizes) constructor.
        # float32 is pinned explicitly because FloatTensor was always float32,
        # regardless of the global default dtype. The uninitialised contents
        # are overwritten by _init_params() below.
        self.categories = nn.Parameter(
            torch.empty(num_categories, categories_key_size, dtype=torch.float32)
        )
        self.attributes = nn.Parameter(
            torch.empty(
                num_categories,
                num_attributes_per_category,
                attributes_key_size,
                dtype=torch.float32,
            )
        )

        if value_size is not None:
            self.values = nn.Parameter(
                torch.empty(
                    num_categories,
                    num_attributes_per_category,
                    value_size,
                    dtype=torch.float32,
                )
            )
        else:
            # No separate value store: forward() reads from self.attributes.
            self.values = None

        self.dropout = nn.Dropout(dropout)
        self._init_params()

    def _init_params(self):
        """Fill every learnable tensor in place using ``nn.init.normal_`` defaults."""
        nn.init.normal_(self.categories)
        nn.init.normal_(self.attributes)
        if self.values is not None:
            nn.init.normal_(self.values)

    def forward(self, query):
        """Read from the memory.

        Args:
            query: 2-D tensor whose rows are
                ``categories_key_size + attributes_key_size`` wide.

        Returns:
            A 4-tuple ``(values, category_values, category_attention,
            attribute_attention)``:

            * ``values`` -- ``(batch, value_size)``, or
              ``(batch, attributes_key_size)`` when ``self.values`` is ``None``;
              per-category attribute read-outs mixed by the category attention.
            * ``category_values`` -- ``(batch, categories_key_size)``;
              attention-weighted sum of the category keys.
            * ``category_attention`` -- ``(batch, num_categories)``; softmax
              over categories.
            * ``attribute_attention`` -- ``(batch, num_categories,
              num_attributes_per_category)``; softmax over the attributes of
              each category.
        """
        query = self.dropout(query)

        # Same computation as the public helper, applied to the dropped-out query.
        category_attention = self.get_category_attention(query)

        # (C, A, K) @ (K, B) -> (C, A, B); move batch first, then normalise over A.
        attribute_query = query[:, self.categories_key_size:]
        attribute_scores = torch.matmul(self.attributes, attribute_query.t())
        attribute_attention = nn.functional.softmax(
            attribute_scores.permute(2, 0, 1), dim=2
        )

        category_values = torch.matmul(category_attention, self.categories)

        # Fall back to the attribute keys when no value tensor was allocated.
        memory = self.attributes if self.values is None else self.values
        # (B, C, A, 1) * (C, A, V) -> (B, C, A, V); sum out the attribute axis.
        attribute_values = (attribute_attention.unsqueeze(3) * memory).sum(dim=2)

        # (B, 1, C) @ (B, C, V) -> (B, 1, V) -> (B, V)
        values = torch.matmul(category_attention.unsqueeze(1), attribute_values).squeeze(1)

        return values, category_values, category_attention, attribute_attention

    def get_category_attention(self, query):
        """Return the softmax attention over categories for ``query``.

        Only the first ``categories_key_size`` columns of ``query`` are used.
        No dropout is applied here; ``forward`` applies it before calling this
        method.

        Returns:
            Tensor of shape ``(batch, num_categories)`` whose rows sum to 1.
        """
        category_query = query[:, :self.categories_key_size]
        category_scores = torch.matmul(category_query, self.categories.t())
        return nn.functional.softmax(category_scores, dim=1)

In [6]:
# Smoke test: a tiny memory without a separate value tensor, queried with an
# all-ones batch of two rows (5 category-key columns + 6 attribute-key columns).
hm = HierarchicalMemory(
    num_categories=3,
    num_attributes_per_category=4,
    categories_key_size=5,
    attributes_key_size=6,
    value_size=None,
)
attributes = hm.attributes
categories = hm.categories
# Size of the first returned tensor (the read-out values).
hm(torch.ones(2, 11))[0].size()

torch.Size([2, 6])

In [4]:
# NOTE(review): this cell raised ModuleNotFoundError in the saved run. `mac` is
# presumably located under 'code/', which is only put on sys.path by a later
# cell (sys.path.insert(0, 'code/')) -- confirm and run that cell first.
# When the import succeeds it rebinds the HierarchicalMemory name that is also
# defined as a class earlier in this notebook.
from mac import HierarchicalMemory

ModuleNotFoundError: No module named 'mac'

In [None]:
# Scratch cell (no execution count in the saved run): draw a length-8 random vector.
torch.randn(8)

In [57]:
# Scratch check of the broadcasting used for attention-weighted read-outs:
# softmax over the last axis of a (2, 5, 6) tensor -> unsqueeze to (2, 5, 6, 1)
# -> multiply by a (5, 6, 8) tensor (broadcast to (2, 5, 6, 8)) -> sum over
# axis 2, giving a (2, 5, 8) result.
(nn.functional.softmax(torch.randn(2, 5, 6), dim=2).unsqueeze(3) *  torch.randn(5, 6, 8)).sum(dim=2)

tensor([[[-0.4126,  0.6920,  0.2098, -0.2235, -0.3082,  0.2059,  0.4995,
           0.1840],
         [-0.8027, -0.1491, -0.2894, -0.7281,  0.5694, -0.1165, -0.1300,
          -0.0618],
         [-0.5812,  1.0318, -0.1973, -0.1821, -0.3681, -0.3152, -0.9683,
          -0.3389],
         [ 0.3253,  0.0450, -0.0862, -0.0870, -0.3185,  0.1147,  0.0950,
          -0.5925],
         [ 0.4874, -0.6049, -0.0925,  0.0021,  0.3434,  0.9950, -0.0963,
           0.3307]],

        [[-0.4825, -0.1253, -0.1234, -0.1229,  0.3158, -0.5058, -0.2277,
           0.6234],
         [-1.0779, -0.6746, -0.4625,  0.3930, -0.2796, -0.1995, -0.3210,
          -0.5292],
         [-0.5685,  1.2189,  0.1967,  0.2939, -0.1047,  0.0913, -1.1220,
          -0.4336],
         [ 0.2339, -0.0343,  0.0202, -0.9019, -0.5439,  0.3890,  0.2416,
          -0.6860],
         [ 0.7302, -0.9310,  0.3548,  0.2108, -0.2683,  0.9025, -0.1845,
           0.1122]]])

torch.Size([2, 11])
torch.Size([2, 3])
torch.Size([3, 4, 6])
torch.Size([3, 4, 2])
torch.Size([2, 3, 4])
torch.Size([2, 3])
torch.Size([2, 3, 6])


tensor([[ 0.1142, -0.0093,  1.2074,  0.8915, -0.6113,  0.2692],
        [-0.0981,  1.2242, -0.0969,  0.5737,  0.4226, -0.1635]],
       grad_fn=<SqueezeBackward1>)

In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
import sys

# Put the relative directory 'code/' first on the module search path. Being
# relative, it is resolved against the kernel's current working directory.
# NOTE(review): presumably this is where the project modules imported below
# (mac, utils, datasets, config) live -- confirm.
sys.path.insert(0, 'code/')

In [10]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchsummaryX import summary


from mac import MACNetwork
from utils import load_vocab
from datasets import ClevrDataset, collate_fn

In [4]:
from config import cfg_from_file, __C, cfg

# Apply the settings file, then override two fields on __C by hand.
cfg_from_file('cfg/local.yml')
# NOTE(review): presumably these two overrides force CPU-only execution, and
# `cfg` reflects changes made through `__C` -- confirm in config.py.
__C.CUDA = False
__C.GPU_ID = '-1'
# `vocab` is passed to MACNetwork when the model is constructed further down.
vocab = load_vocab(cfg)
# cfg.TRAIN.RECV_OBJECTS = False

  yaml_cfg = edict(yaml.load(f))


In [12]:
# Build the dataset for the 'val' split.
# NOTE(review): data_dir is a hard-coded absolute path on one user's machine;
# the notebook will not run elsewhere without editing it. Consider a
# configurable DATA_DIR defined once near the top of the notebook.
ds = ClevrDataset(
    data_dir='/Users/sebamenabar/Documents/datasets/CLEVR/data',
    # img_dir='/Users/sebamenabar/Documents/datasets/CLEVR/CLEVR_v1.0/images/',
    # scenes_json='/Users/sebamenabar/Documents/TAIA/individual/sm/data/clevr/train/scenes.json',
    # raw_image=True,
    split='val',
)

In [13]:
# Shuffled batches of two samples from `ds`, loaded by two worker processes;
# the final partial batch is kept and batches are assembled by collate_fn.
loader = DataLoader(
    dataset=ds,
    batch_size=2,
    shuffle=True,
    num_workers=2,
    drop_last=False,
    collate_fn=collate_fn,
)

In [23]:
# Build the network and restore weights from the 'model' entry of a checkpoint,
# mapping all tensors to CPU. The saved output reports no missing and no
# unexpected keys.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. The checkpoint path is also a hard-coded absolute local path.
model = MACNetwork(cfg=cfg, max_step=4, vocab=vocab)
model.load_state_dict(torch.load('/Users/sebamenabar/Documents/concept_mac.pth', map_location='cpu')['model'])
# model(b['image'], b['question'], b['question_length'])

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [30]:
# Display the restored weight of the control unit's `attn` layer.
# NOTE(review): this prints the whole tensor; showing `.shape` or a small slice
# would keep the saved notebook output short.
model.mac.control.attn.weight

Parameter containing:
tensor([[ 9.1099e-04, -5.2064e-02, -1.6987e-03,  6.7334e-04,  4.5890e-02,
         -3.0428e-02,  5.2594e-02,  3.4504e-03,  1.2701e-02, -2.5365e-03,
         -2.6305e-02,  4.4835e-02, -2.5110e-02,  6.1536e-02,  1.0283e-02,
         -3.6328e-02,  1.7856e-03,  5.1706e-03,  1.5270e-01, -1.4143e-02,
          7.4833e-03, -1.6326e-02,  1.1401e-03, -5.9885e-02,  1.3202e-03,
          1.6604e-04,  9.7150e-03,  1.0755e-01,  5.1553e-02,  1.3866e-04,
          4.2734e-02,  2.5243e-02, -2.5355e-02,  3.2289e-02,  2.3130e-02,
          1.2027e-01, -2.9729e-02, -1.3794e-03, -6.0275e-02, -6.9985e-02,
          9.5877e-02,  2.0320e-03, -2.1570e-03,  7.6947e-02,  3.5740e-02,
          8.1291e-04, -2.9686e-02, -4.6840e-02, -9.4602e-04,  7.1419e-02,
          1.5206e-01,  4.9460e-02, -7.8515e-05, -7.4855e-02, -3.1934e-03,
         -4.4200e-02,  2.3165e-03, -1.8579e-03, -3.1161e-03,  3.0754e-04,
         -3.7320e-02,  4.1890e-02, -1.1318e-01, -1.2373e-02, -9.0849e-02,
          3.7740