In [1]:
%load_ext autoreload
%autoreload 2

# Imports

In [2]:
import os
import sys

In [3]:
# Put the local `code/` directory first on the module search path so that modules
# stored there can be imported. The path is relative, so it is resolved against
# the working directory the kernel was started in.
sys.path.insert(0, 'code/')

In [4]:
import copy
import pickle

import h5py
import numpy as np

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
from torch.nn import functional as F

from tqdm import tqdm_notebook as tqdm
from easydict import EasyDict as edict

In [5]:
# Use the GPU when PyTorch reports one as available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cpu


# Utils

In [6]:
## UTILS

def load_vocab(cfg):
    """Load the question/answer vocabularies stored in ``dic.pkl``.

    The pickle is read from ``cfg.DATASET.DATA_DIR`` and must contain the keys
    ``"word_dic"`` and ``"answer_dic"`` (token -> index mappings). The token
    ``'pad'`` is (re)assigned index 0 in the question mapping before the
    inverse mappings are built.

    Args:
        cfg: configuration object exposing ``cfg.DATASET.DATA_DIR``.

    Returns:
        dict with the keys ``question_token_to_idx``, ``answer_token_to_idx``,
        ``question_idx_to_token`` and ``answer_idx_to_token``.
    """
    dic_path = os.path.join(cfg.DATASET.DATA_DIR, 'dic.pkl')
    # NOTE: unpickling executes arbitrary code; only point this at trusted files.
    with open(dic_path, 'rb') as handle:
        dictionaries = pickle.load(handle)

    question_to_idx = dictionaries["word_dic"]
    answer_to_idx = dictionaries["answer_dic"]
    question_to_idx['pad'] = 0

    return {
        'question_token_to_idx': question_to_idx,
        'answer_token_to_idx': answer_to_idx,
        'question_idx_to_token': {idx: token for token, idx in question_to_idx.items()},
        'answer_idx_to_token': {idx: token for token, idx in answer_to_idx.items()},
    }

def init_modules(modules, w_init='kaiming_uniform'):
    """Re-initialise the weights of conv / linear / recurrent layers in place.

    Args:
        modules: iterable of ``nn.Module`` objects (e.g. ``model.modules()``).
        w_init: name of the ``torch.nn.init`` scheme to use for weights; one of
            ``normal``, ``xavier_normal``, ``xavier_uniform``,
            ``kaiming_normal``, ``kaiming_uniform`` or ``orthogonal``.

    Raises:
        NotImplementedError: if ``w_init`` is not one of the names above. The
            check happens before any module is touched.

    Biases of the handled layers are set to zero. Modules of other types are
    left untouched.
    """
    supported = ("normal", "xavier_normal", "xavier_uniform",
                 "kaiming_normal", "kaiming_uniform", "orthogonal")
    if w_init not in supported:
        raise NotImplementedError
    # torch.nn.init exposes the in-place variants as "<name>_"
    fill_weight_ = getattr(init, w_init + "_")

    feedforward_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
    recurrent_types = (nn.LSTM, nn.GRU)

    for module in modules:
        if isinstance(module, feedforward_types):
            fill_weight_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        if isinstance(module, recurrent_types):
            # recurrent layers hold several weight_* / bias_* tensors
            for param_name, tensor in module.named_parameters():
                if 'bias' in param_name:
                    nn.init.zeros_(tensor)
                elif 'weight' in param_name:
                    fill_weight_(tensor)

# Model

In [17]:
class ControlUnit(nn.Module):
    """Control unit of a MAC cell: attends over the question words.

    At reasoning step ``step`` the question vector goes through that step's
    own linear layer, is multiplied element-wise with every word
    representation and scored by a shared linear layer. The length-masked
    softmax of those scores is then used to average the word representations
    into the next control state.

    Args:
        module_dim: feature size of the question vector and word representations.
        max_step: number of reasoning steps (one question projection per step).
        separate_syntax_semantics: if True, ``forward`` expects ``context`` to
            be a ``(syntactics, semantics)`` pair; attention scores are computed
            from the first element and the weighted sum is taken over the second.
    """

    def __init__(self,
                 module_dim,
                 max_step=4,
                 separate_syntax_semantics=False,
                ):
        super().__init__()
        self.attn = nn.Linear(module_dim, 1)

        # one question projection per reasoning step
        self.control_input_u = nn.ModuleList(
            [nn.Linear(module_dim, module_dim) for _ in range(max_step)]
        )

        self.module_dim = module_dim
        self.separate_syntax_semantics = separate_syntax_semantics

    def mask(self, question_lengths, device):
        """Return an additive attention mask of shape [batch, max(question_lengths)].

        Valid positions hold 0 and padded positions hold -1e30, so adding the
        mask to attention logits removes the padding from a softmax. (The
        previous implementation used ``1e-30``, which left logits unchanged.)

        Not used by ``forward``, which relies on ``mask_by_length`` instead.
        """
        lengths = torch.as_tensor(question_lengths, device=device)
        max_len = int(lengths.max())
        positions = torch.arange(max_len, device=device).expand(len(lengths), max_len)
        is_padding = positions >= lengths.unsqueeze(1)
        return torch.zeros(len(lengths), max_len, device=device).masked_fill(is_padding, -1e30)

    @staticmethod
    def mask_by_length(x, lengths, device=None):
        """Set the entries of ``x`` that lie beyond each sequence's length to -inf.

        Args:
            x: tensor of shape [batch, seq_len, features].
            lengths: per-example number of valid positions (sequence or tensor).
            device: device on which to build the mask; defaults to ``x.device``.

        Returns:
            Tensor shaped like ``x``: valid positions are unchanged, positions
            ``>= length`` are ``-inf`` so they get zero weight in a softmax
            over dimension 1.
        """
        if device is None:
            device = x.device
        lengths = torch.as_tensor(lengths, dtype=torch.long, device=device)
        # Size the mask from x itself so it always matches the logits, even
        # when x is padded further than the longest sequence in the batch.
        positions = torch.arange(x.size(1), device=device)
        is_padding = positions.unsqueeze(0) >= lengths.unsqueeze(1)
        # masked_fill avoids the former `1 - 1 / mask` division by zero and
        # keeps non-finite values in padded slots from turning into NaN.
        return x.masked_fill(is_padding.unsqueeze(2), float('-inf'))

    def forward(self, question, context, question_lengths, step):
        """
        Args:
            question: external inputs to control unit (the question vector).
                [batchSize, ctrlDim]
            context: the representation of the words used to compute the attention.
                [batchSize, questionLength, ctrlDim]
                When ``separate_syntax_semantics`` is set this is a
                ``(syntactics, semantics)`` pair of such tensors.
            question_lengths: the length of each question.
                [batchSize]
            step: which step in the reasoning chain (selects the projection layer)

        Returns:
            The new control state: attention-weighted sum of the word
            representations over the question-length dimension.
        """
        if self.separate_syntax_semantics:
            syntactics, semantics = context
        else:
            syntactics, semantics = context, context

        # compute interactions with question words
        question = self.control_input_u[step](question)
        interactions = question.unsqueeze(1) * syntactics

        # compute attention distribution over words, ignoring padded positions
        logits = self.attn(interactions)
        logits = self.mask_by_length(logits, question_lengths, device=syntactics.device)
        attn = F.softmax(logits, 1)

        # apply soft attention to current context words
        next_control = (attn * semantics).sum(1)

        return next_control


class ReadUnit(nn.Module):
    """Read unit of a MAC cell: retrieves information from the knowledge base.

    The knowledge base is compared with the current memory, modulated by the
    control state, turned into an attention distribution over knowledge-base
    positions and summed up with those weights.
    """

    def __init__(self, module_dim):
        super().__init__()

        # Layers are created in the same order as before so that seeded
        # initialisation gives the same parameters.
        self.concat = nn.Linear(2 * module_dim, module_dim)
        self.concat_2 = nn.Linear(module_dim, module_dim)
        self.attn = nn.Linear(module_dim, 1)
        self.dropout = nn.Dropout(0.15)
        self.kproj = nn.Linear(module_dim, module_dim)
        self.mproj = nn.Linear(module_dim, module_dim)

        self.activation = nn.ELU()
        self.module_dim = module_dim

    def forward(self, memory, know, control, memDpMask=None):
        """
        Args:
            memory: the cell's memory state
                [batchSize, memDim]

            know: representation of the knowledge base (image).
                [batchSize, kbSize (Height * Width), memDim]

            control: the cell's control state
                [batchSize, ctrlDim]

            memDpMask: variational dropout mask (if used)
                [batchSize, memDim]

        Returns:
            The attention-weighted sum of ``know`` (after dropout) over the
            knowledge-base dimension.
        """
        # --- Step 1: knowledge base / memory interactions ---
        know = self.dropout(know)
        if memDpMask is None:
            memory = self.dropout(memory)
        elif self.training:
            # NOTE(review): applyVarDpMask is not defined in the visible part of
            # this notebook; it must be provided elsewhere for this branch to run.
            memory = applyVarDpMask(memory, memDpMask, 0.85)

        projected_know = self.kproj(know)
        projected_memory = self.mproj(memory).unsqueeze(1)
        mixed = projected_know * projected_memory

        # fold the memory interactions back to the hidden dimension
        mixed = self.concat(torch.cat([mixed, projected_know], -1))
        mixed = self.concat_2(self.activation(mixed))

        # --- Step 2: modulate by the control state ---
        mixed = self.activation(mixed * control.unsqueeze(1))

        # --- Step 3: attention over knowledge-base positions ---
        mixed = self.dropout(mixed)
        scores = self.attn(mixed).squeeze(-1)
        weights = F.softmax(scores, 1).unsqueeze(-1)

        # summarise the knowledge base according to the distribution
        return (weights * know).sum(1)


class WriteUnit(nn.Module):
    """Write unit of a MAC cell: produces the next memory state.

    Args:
        module_dim: feature size of the memory and of the retrieved information.
        rtom: if True the retrieved information becomes the new memory as-is;
            otherwise the previous memory and the retrieved information are
            concatenated and passed through a linear layer.
    """

    def __init__(self, module_dim, rtom=True):
        super().__init__()
        # The layer is always created, even though it is only used when rtom is False.
        self.linear = nn.Linear(2 * module_dim, module_dim)
        self.rtom = rtom

    def forward(self, memory, info):
        """Return the new memory given the previous ``memory`` and the read ``info``."""
        if not self.rtom:
            merged = torch.cat([memory, info], -1)
            return self.linear(merged)
        return info


class MACUnit(nn.Module):
    """Recurrent MAC cell: runs control -> read -> write for ``max_step`` steps.

    Args:
        units_cfg: configuration exposing ``common``, ``control_unit``,
            ``read_unit`` and ``write_unit`` mappings; each is unpacked as
            keyword arguments for the corresponding unit, and later mappings
            override ``module_dim`` / ``max_step`` given here.
        module_dim: feature size of the memory and control states.
        max_step: number of reasoning steps.
        var_dropout: whether ``zero_state`` should create a variational dropout
            mask for the memory. When left as ``None`` the value is read once,
            at construction time, from ``cfg.TRAIN.VAR_DROPOUT`` of a
            module-level ``cfg`` if one exists (the behaviour older callers
            relied on), and defaults to ``False`` otherwise.

    The constructor used to read the notebook-global ``cfg`` unconditionally,
    which raised ``NameError`` whenever that global did not exist yet.
    ``self.cfg`` now holds ``units_cfg``.
    """

    def __init__(self, units_cfg, module_dim=512, max_step=4, var_dropout=None):
        super().__init__()
        if var_dropout is None:
            var_dropout = self._legacy_var_dropout()
        self.cfg = units_cfg
        self.var_dropout = bool(var_dropout)

        self.control = ControlUnit(
            **{
                'module_dim': module_dim,
                'max_step': max_step,
                **units_cfg.common,
                **units_cfg.control_unit
            })
        self.read = ReadUnit(
            **{
                'module_dim': module_dim,
                **units_cfg.common,
                **units_cfg.read_unit,
            })
        self.write = WriteUnit(
            **{
                'module_dim': module_dim,
                **units_cfg.common,
                **units_cfg.write_unit,
            })

        self.initial_memory = nn.Parameter(torch.zeros(1, module_dim))

        self.module_dim = module_dim
        self.max_step = max_step

    @staticmethod
    def _legacy_var_dropout():
        """Read TRAIN.VAR_DROPOUT from a module-level ``cfg``; False if unavailable."""
        try:
            return bool(globals()['cfg'].TRAIN.VAR_DROPOUT)
        except (KeyError, AttributeError):
            return False

    def zero_state(self, batch_size, question):
        """Build the initial (control, memory, memory-dropout-mask) triple.

        The learned initial memory is expanded to ``batch_size`` rows and the
        question vector is returned as the initial control state.
        """
        initial_memory = self.initial_memory.expand(batch_size, self.module_dim)
        initial_control = question

        if self.var_dropout:
            # NOTE(review): generateVarDpMask is not defined in the visible part
            # of this notebook; it must be provided elsewhere for this branch.
            memDpMask = generateVarDpMask((batch_size, self.module_dim), 0.85)
        else:
            memDpMask = None

        return initial_control, initial_memory, memDpMask

    def forward(self, context, question, knowledge, question_lengths):
        """Run the reasoning chain and return the final memory state.

        Args:
            context: word representations handed to the control unit.
            question: question vector; its first dimension is the batch size.
            knowledge: knowledge-base representation handed to the read unit.
            question_lengths: length of each question in the batch.
        """
        batch_size = question.size(0)
        # the initial control is replaced by the control unit on the first step
        control, memory, memDpMask = self.zero_state(batch_size, question)

        for i in range(self.max_step):
            # control unit
            control = self.control(question, context, question_lengths, i)
            # read unit
            info = self.read(memory, knowledge, control, memDpMask)
            # write unit
            memory = self.write(memory, info)

        return memory


class InputUnit(nn.Module):
    """Encodes the image features and the question for the MAC cell.

    Args:
        vocab_size: number of rows of the word-embedding table.
        wordvec_dim: size of the word vectors fed to the LSTM.
        rnn_dim: size of the question vector; with a bidirectional LSTM each
            direction gets ``rnn_dim // 2`` hidden units.
        module_dim: number of channels produced by the convolutional stem.
        bidirectional: whether the question LSTM is bidirectional.
        separate_syntax_semantics: if True, ``forward`` returns the contextual
            words as a ``(lstm_outputs, embeddings)`` pair.
        separate_syntax_semantics_embeddings: only effective together with
            ``separate_syntax_semantics``; the embedding table is made twice as
            wide, the first half feeding the LSTM and the second half being
            returned as the "semantics" embeddings.
    """

    def __init__(self,
                 vocab_size,
                 wordvec_dim=300,
                 rnn_dim=512,
                 module_dim=512,
                 bidirectional=True,
                 separate_syntax_semantics=False,
                 separate_syntax_semantics_embeddings=False,
                ):
        super(InputUnit, self).__init__()

        self.dim = module_dim
        self.wordvec_dim = wordvec_dim
        self.separate_syntax_semantics = separate_syntax_semantics
        self.separate_syntax_semantics_embeddings = separate_syntax_semantics and separate_syntax_semantics_embeddings

        # the stem expects image features with 1024 channels
        self.stem = nn.Sequential(nn.Dropout(p=0.18),
                                  nn.Conv2d(1024, module_dim, 3, 1, 1),
                                  nn.ELU(),
                                  nn.Dropout(p=0.18),
                                  nn.Conv2d(module_dim, module_dim, kernel_size=3, stride=1, padding=1),
                                  nn.ELU())

        self.bidirectional = bidirectional
        if bidirectional:
            # both directions are concatenated back to rnn_dim in forward()
            rnn_dim = rnn_dim // 2

        self.encoder = nn.LSTM(wordvec_dim, rnn_dim, batch_first=True, bidirectional=bidirectional)
        if self.separate_syntax_semantics_embeddings:
            # one half of each embedding for the LSTM, the other half for "semantics"
            wordvec_dim *= 2
        self.encoder_embed = nn.Embedding(vocab_size, wordvec_dim)
        self.embedding_dropout = nn.Dropout(p=0.15)
        self.question_dropout = nn.Dropout(p=0.08)

    def forward(self, image, question, question_len):
        """
        Args:
            image: image feature map with 1024 channels, [batchSize, 1024, H, W].
            question: padded word indices, [batchSize, questionLength].
            question_len: length of each question; passed to
                ``pack_padded_sequence`` with its default settings, so the
                batch has to be sorted by decreasing length.

        Returns:
            question_embedding: final LSTM hidden state, [batchSize, rnn_dim].
            contextual_words: LSTM outputs [batchSize, max(question_len), rnn_dim],
                or a ``(lstm_outputs, embeddings)`` pair when
                ``separate_syntax_semantics`` is set.
            img: stem output flattened to [batchSize, H * W, module_dim].
        """
        b_size = question.size(0)

        # get image features
        img = self.stem(image)
        img = img.view(b_size, self.dim, -1)
        img = img.permute(0,2,1)

        # get question and contextual word embeddings
        embed = self.encoder_embed(question)
        embed = self.embedding_dropout(embed)
        if self.separate_syntax_semantics_embeddings:
            semantics = embed[:, :, self.wordvec_dim:]
            embed = embed[:, :, :self.wordvec_dim]
        else:
            semantics = embed

        embed = nn.utils.rnn.pack_padded_sequence(embed, question_len, batch_first=True)
        contextual_words, (question_embedding, _) = self.encoder(embed)

        # h_n is [num_directions, batchSize, hidden]; reduce it to [batchSize, rnn_dim]
        if self.bidirectional:
            question_embedding = torch.cat([question_embedding[0], question_embedding[1]], -1)
        else:
            # previously left as [1, batchSize, hidden], which downstream units cannot use
            question_embedding = question_embedding[0]
        question_embedding = self.question_dropout(question_embedding)

        contextual_words, _ = nn.utils.rnn.pad_packed_sequence(contextual_words, batch_first=True)

        if self.separate_syntax_semantics:
            # pad_packed_sequence only pads up to the longest question in the
            # batch; trim the embeddings so both tensors share the same length
            semantics = semantics[:, :contextual_words.size(1)]
            contextual_words = (contextual_words, semantics)

        return question_embedding, contextual_words, img


class OutputUnit(nn.Module):
    """Answer classifier applied to the final memory and the question vector.

    Args:
        module_dim: feature size of the memory and of the question vector.
        num_answers: number of output classes.
    """

    def __init__(self, module_dim=512, num_answers=28):
        super().__init__()

        self.question_proj = nn.Linear(module_dim, module_dim)

        classifier_layers = [
            nn.Dropout(0.15),
            nn.Linear(2 * module_dim, module_dim),
            nn.ELU(),
            nn.Dropout(0.15),
            nn.Linear(module_dim, num_answers),
        ]
        self.classifier = nn.Sequential(*classifier_layers)

    def forward(self, question_embedding, memory):
        """Return answer logits of shape [batchSize, num_answers]."""
        # project the question, then classify [memory ; projected question]
        projected_question = self.question_proj(question_embedding)
        features = torch.cat([memory, projected_question], 1)
        return self.classifier(features)


class MACNetwork(nn.Module):
    """Full MAC network: input unit -> MAC cell -> output classifier.

    Args:
        cfg: configuration exposing ``cfg.model`` (``common``, ``input_unit``,
            ``control_unit``, ``read_unit``, ``write_unit``, ``output_unit``,
            ``max_step`` and optionally ``separate_syntax_semantics``) and
            ``cfg.TRAIN.WEIGHT_INIT``.
        vocab: vocabulary dict; only ``vocab['question_token_to_idx']`` is used,
            to size the word-embedding table.
        num_answers: number of answer classes.

    When ``cfg.model.separate_syntax_semantics`` is ``True`` the flag is handed
    to both the input unit and the control unit. This is done on local copies:
    ``cfg`` itself is no longer modified, so switching the top-level flag off
    and rebuilding the model really disables the feature.
    """

    def __init__(self, cfg, vocab, num_answers=28):
        super().__init__()

        self.cfg = cfg
        model_cfg = cfg.model

        input_unit_cfg = dict(model_cfg.input_unit)
        mac_cfg = model_cfg
        # a missing flag is treated as "off" instead of raising AttributeError
        if getattr(model_cfg, 'separate_syntax_semantics', False) is True:
            input_unit_cfg['separate_syntax_semantics'] = True
            # shallow copy so that only the copy receives the patched control_unit
            mac_cfg = copy.copy(model_cfg)
            mac_cfg.control_unit = dict(model_cfg.control_unit, separate_syntax_semantics=True)

        encoder_vocab_size = len(vocab['question_token_to_idx'])

        self.input_unit = InputUnit(
            vocab_size=encoder_vocab_size,
            **model_cfg.common,
            **input_unit_cfg,
        )

        self.output_unit = OutputUnit(
            num_answers=num_answers,
            **model_cfg.common,
            **model_cfg.output_unit,
        )

        self.mac = MACUnit(
            mac_cfg,
            max_step=model_cfg.max_step,
            **model_cfg.common,
        )

        # generic initialisation first, then the two special cases
        init_modules(self.modules(), w_init=cfg.TRAIN.WEIGHT_INIT)
        nn.init.uniform_(self.input_unit.encoder_embed.weight, -1.0, 1.0)
        nn.init.normal_(self.mac.initial_memory)

    def forward(self, image, question, question_len):
        """Return answer logits for a batch of (image features, question) pairs."""
        # get image, word, and sentence embeddings
        question_embedding, contextual_words, img = self.input_unit(image, question, question_len)

        # apply MacCell
        memory = self.mac(contextual_words, question_embedding, img, question_len)

        # get classification
        out = self.output_unit(question_embedding, memory)

        return out


In [8]:
# Experiment configuration. Nested dicts become attribute-accessible through EasyDict.
cfg = edict({
    'GPU_ID': '-1',
    'CUDA': False,
    'WORKERS': 4,
    'TRAIN': {
        'FLAG': True,
        'LEARNING_RATE': 0.0001,
        'BATCH_SIZE': 64,
        'MAX_EPOCHS': 25,
        'SNAPSHOT_INTERVAL': 5,
        'WEIGHT_INIT': 'xavier_uniform',
        'CLIP_GRADS': True,
        'CLIP': 8,
        # 'MAX_STEPS': 4,
        # NOTE(review): the key is spelled 'EALRY_STOPPING'; it is kept as-is
        # because code outside this notebook may look it up under this name.
        'EALRY_STOPPING': True,
        'PATIENCE': 5,
        'VAR_DROPOUT': False,
    },
    'DATASET': {
        # Set the CLEVR_DATA_DIR environment variable to point at another copy
        # of the data; the previous hard-coded location remains the default.
        # 'DATA_DIR': '/mnt/nas2/GrimaRepo/datasets/CLEVR_v1.0/features',
        'DATA_DIR': os.environ.get(
            'CLEVR_DATA_DIR',
            '/Users/sebamenabar/Documents/datasets/CLEVR/data/',
        ),
    },
    'model': {
        'max_step': 4,
        'separate_syntax_semantics': True,
        # keyword arguments shared by every unit
        'common': {
            'module_dim': 256,
        },
        'input_unit': {
            'wordvec_dim': 256,
            'rnn_dim': 256,
            'bidirectional': True,
            'separate_syntax_semantics_embeddings': True,
        },
        'control_unit': {
        },
        'read_unit': {},
        'write_unit': {
            'rtom': True,
        },
        'output_unit': {},
    }
})

vocab = load_vocab(cfg)

In [24]:
# Scratch check that EasyDict can be built from keyword arguments; nothing in
# the cells shown here depends on this cell.
edict(a=1)

{'a': 1}

# Dataset

In [9]:
class ClevrDataset(torch.utils.data.Dataset):
    """Questions plus pre-extracted image features for one CLEVR split.

    Reads ``<split>.pkl`` (a sequence of ``(imgfile, question, answer, family)``
    records) and the ``'features'`` dataset of ``<split>_features.h5`` from
    ``data_dir``. The HDF5 file stays open for the lifetime of the object.
    """

    def __init__(self, data_dir, split='train'):
        questions_path = os.path.join(data_dir, '{}.pkl'.format(split))
        features_path = os.path.join(data_dir, '{}_features.h5'.format(split))

        with open(questions_path, 'rb') as handle:
            self.data = pickle.load(handle)
        # Per the commented-out alternative that used to live here, some feature
        # dumps are named '<split>_features.hdf5' and keep the array under 'data'.
        self.img = h5py.File(features_path, 'r')['features']

    def __getitem__(self, index):
        """Return ``(image_features, question, len(question), answer, family)``."""
        imgfile, question, answer, family = self.data[index]
        # The number after the last '_' (minus a 4-character extension such as
        # '.png') is the row of this image in the feature array.
        suffix = imgfile.rsplit('_', 1)[1]
        feature_row = int(suffix[:-4])
        features = torch.from_numpy(self.img[feature_row])

        return features, question, len(question), answer, family

    def __len__(self):
        return len(self.data)


def collate_fn(batch):
    """Collate dataset items into a batch sorted by decreasing question length.

    Args:
        batch: list of ``(image, question, length, answer, family)`` tuples.

    Returns:
        dict with
            ``image``: stacked image tensors,
            ``question``: int64 tensor [batch, longest question], zero-padded,
            ``answer``: LongTensor of answers,
            ``question_length``: plain list of question lengths.
        All entries follow the sorted order; ``family`` is dropped.
    """
    def question_size(item):
        return len(item[1])

    longest = max(map(question_size, batch))
    padded = np.zeros((len(batch), longest), dtype=np.int64)
    ordered = sorted(batch, key=question_size, reverse=True)

    images, lengths, answers = [], [], []
    for row, (image, question, _, answer, _) in enumerate(ordered):
        n_tokens = len(question)
        padded[row, :n_tokens] = question
        images.append(image)
        lengths.append(n_tokens)
        answers.append(answer)

    return {
        'image': torch.stack(images),
        'question': torch.from_numpy(padded),
        'answer': torch.LongTensor(answers),
        'question_length': lengths,
    }

# Main: build the model and summarise it on one batch

In [10]:
# Load the validation split (questions pickle + pre-extracted image features)
# from the configured data directory.
ds = ClevrDataset(cfg.DATASET.DATA_DIR, split='val')

In [11]:
from torchsummaryX import summary

In [18]:
# Build the network from the config and vocabulary defined above.
model = MACNetwork(cfg=cfg, vocab=vocab)
# Unshuffled loader; collate_fn pads the questions and sorts each batch by length.
loader = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=False, collate_fn=collate_fn)
# Take the first batch only; it serves as the example input for the summary.
b = next(iter(loader))

# Per-layer table of kernel shapes, output shapes, parameter counts and mult-adds.
summary(model, b['image'], b['question'], b['question_length'])

                                              Kernel Shape       Output Shape  \
Layer                                                                           
0_input_unit.stem.Dropout_0                              -  [8, 1024, 14, 14]   
1_input_unit.stem.Conv2d_1               [1024, 256, 3, 3]   [8, 256, 14, 14]   
2_input_unit.stem.ELU_2                                  -   [8, 256, 14, 14]   
3_input_unit.stem.Dropout_3                              -   [8, 256, 14, 14]   
4_input_unit.stem.Conv2d_4                [256, 256, 3, 3]   [8, 256, 14, 14]   
5_input_unit.stem.ELU_5                                  -   [8, 256, 14, 14]   
6_input_unit.Embedding_encoder_embed             [512, 90]       [8, 19, 512]   
7_input_unit.Dropout_embedding_dropout                   -       [8, 19, 512]   
8_input_unit.LSTM_encoder                                -         [116, 256]   
9_input_unit.Dropout_question_dropout                    -           [8, 256]   
10_mac.control.control_input

Unnamed: 0_level_0,Kernel Shape,Output Shape,Params,Mult-Adds
Layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0_input_unit.stem.Dropout_0,-,"[8, 1024, 14, 14]",,
1_input_unit.stem.Conv2d_1,"[1024, 256, 3, 3]","[8, 256, 14, 14]",2359552.0,462422016.0
2_input_unit.stem.ELU_2,-,"[8, 256, 14, 14]",,
3_input_unit.stem.Dropout_3,-,"[8, 256, 14, 14]",,
4_input_unit.stem.Conv2d_4,"[256, 256, 3, 3]","[8, 256, 14, 14]",590080.0,115605504.0
5_input_unit.stem.ELU_5,-,"[8, 256, 14, 14]",,
6_input_unit.Embedding_encoder_embed,"[512, 90]","[8, 19, 512]",46080.0,46080.0
7_input_unit.Dropout_embedding_dropout,-,"[8, 19, 512]",,
8_input_unit.LSTM_encoder,-,"[116, 256]",395264.0,393216.0
9_input_unit.Dropout_question_dropout,-,"[8, 256]",,
