# QA - gluon
* http://gluon.mxnet.io/chapter08_computer-vision/visual-question-answer.html

In [1]:
from __future__ import print_function
import numpy as np
import mxnet as mx
import mxnet.ndarray as F
import mxnet.contrib.ndarray as C
import mxnet.gluon as gluon
from mxnet.gluon import nn
from mxnet import autograd
import bisect
from IPython.core.display import display, HTML
import logging
logging.basicConfig(level=logging.INFO)
import os
from mxnet.test_utils import download
import json
from IPython.display import HTML, display

  import OpenSSL.SSL


# Define the model

In [2]:
# Mini-batch size; the FFT/IFFT compute_size is tied to it.
batch_size = 64
compute_size = batch_size
# Output width of the count-sketch projection.
out_dim = 10000
# Device count (presumably the number of GPUs the batch is split across -- confirm).
gpus = 2
# Device on which helper arrays are allocated.
ctx = mx.cpu()

## In the first model, we will concatenate the image and question features and use a multilayer perceptron (MLP) to predict the answer

In [3]:
class Net1(gluon.Block):
    """Baseline model: concatenate two feature arrays and classify with an MLP.

    ``forward`` expects ``x`` to be indexable, with ``x[0]`` and ``x[1]``
    being the two feature arrays (presumably question and image features of
    shape (batch, dim) -- confirm against the caller). The output has 1000
    units per row, i.e. one score per candidate answer.
    """

    def __init__(self, **kwargs):
        super(Net1, self).__init__(**kwargs)
        # Children registered inside name_scope() share this block's prefix.
        with self.name_scope():
            self.bn = nn.BatchNorm()
            self.dropout = nn.Dropout(0.3)
            self.fc1 = nn.Dense(8192, activation='relu')
            self.fc2 = nn.Dense(1000)

    def forward(self, x):
        """Return the un-normalised answer scores for the feature pair ``x``."""
        # L2-normalise each modality on its own before joining them.
        first = F.L2Normalization(x[0])
        second = F.L2Normalization(x[1])
        joined = F.concat(first, second, dim=1)
        # Dense(relu) -> BatchNorm -> Dropout -> linear output layer.
        hidden = self.fc1(joined)
        hidden = self.dropout(self.bn(hidden))
        return self.fc2(hidden)


In [4]:
# Instantiate the baseline model; the bare expression on the last line makes
# the notebook display the block's layer summary.
net1 = Net1()
net1

Net1(
  (bn): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, in_channels=None)
  (dropout): Dropout(p = 0.3)
  (fc1): Dense(None -> 8192, Activation(relu))
  (fc2): Dense(None -> 1000, linear)
)

## In the second model, instead of linearly combining the image and text features, we use count sketch to estimate the outer product of the image and question features. This is also known as multimodal compact bilinear pooling (MCB)

In [6]:
class Net2(gluon.Block):
    """Compact bilinear pooling (count sketch + FFT) followed by an MLP.

    ``forward`` expects ``x`` to be indexable, with ``x[0]`` the text
    features and ``x[1]`` the image features, each a 2-D array of shape
    (batch, dim) with ``dim`` smaller than 3072. Each modality is
    L2-normalised, padded with ones to 3072 columns, count-sketched to
    ``out_dim`` columns with its own random hash/sign tables, and the two
    sketches are multiplied element-wise in the FFT domain. The inverse FFT
    of that product feeds Dense(relu) -> BatchNorm -> Dropout -> Dense(1000).

    Relies on the module-level ``ctx``, ``out_dim`` and ``compute_size``.
    """

    def __init__(self, **kwargs):
        super(Net2, self).__init__(**kwargs)
        # Children registered inside name_scope() share this block's prefix.
        with self.name_scope():
            self.bn = nn.BatchNorm()
            self.dropout = nn.Dropout(0.3)
            self.fc1 = nn.Dense(8192, activation = 'relu')
            self.fc2 = nn.Dense(1000)

    def forward(self, x):
        """Return the un-normalised answer scores for the feature pair ``x``."""
        # Column count shared by both padded inputs and by the hash tables.
        sketch_in_dim = 3072

        x1 = F.L2Normalization(x[0])
        x2 = F.L2Normalization(x[1])
        # Pad each modality with ones up to sketch_in_dim columns. The row
        # count and pad width come from the actual inputs, so the shapes stay
        # consistent with the hash tables whatever the batch split or
        # per-modality feature width is.
        text_ones = F.ones((x1.shape[0], sketch_in_dim - x1.shape[1]), ctx = ctx)
        img_ones = F.ones((x2.shape[0], sketch_in_dim - x2.shape[1]), ctx = ctx)
        text_data = F.Concat(x1, text_ones, dim = 1)
        image_data = F.Concat(x2, img_ones, dim = 1)

        # Initialize hash tables: S holds random signs in {-1, +1}, H holds
        # random target indices in [0, out_dim).
        # NOTE(review): the tables are re-drawn on every forward call, so the
        # projection differs between calls; consider creating them once in
        # __init__ if a fixed projection is intended.
        S1 = F.array(np.random.randint(0, 2, (1, sketch_in_dim)) * 2 - 1, ctx = ctx)
        H1 = F.array(np.random.randint(0, out_dim, (1, sketch_in_dim)), ctx = ctx)
        S2 = F.array(np.random.randint(0, 2, (1, sketch_in_dim)) * 2 - 1, ctx = ctx)
        H2 = F.array(np.random.randint(0, out_dim, (1, sketch_in_dim)), ctx = ctx)

        # Count Sketch: each modality uses its own (S, H) pair.
        cs1 = C.count_sketch(data = image_data, s = S1, h = H1, name = 'cs1', out_dim = out_dim)
        cs2 = C.count_sketch(data = text_data, s = S2, h = H2, name = 'cs2', out_dim = out_dim)
        fft1 = C.fft(data = cs1, name = 'fft1', compute_size = compute_size)
        fft2 = C.fft(data = cs2, name = 'fft2', compute_size = compute_size)
        c = fft1 * fft2 # Elementwise product
        ifft1 = C.ifft(data = c, name = 'ifft1', compute_size = compute_size)

        # MLP
        z = self.fc1(ifft1)
        z = self.bn(z)
        z = self.dropout(z)
        z = self.fc2(z)
        return z

# Data Iterator

In [16]:
class VQAtrainIter(mx.io.DataIter):
    """Bucketing data iterator yielding (text, image) data and answer labels.

    Sentences are grouped into length buckets and padded with
    ``invalid_label`` up to the bucket length. Sample ``k`` of ``img`` and
    ``answer`` is taken to belong to ``sentences[k]``; the original sample
    index of every bucketed sentence is remembered so that image and label
    rows stay aligned with the text rows inside each bucket.

    Parameters
    ----------
    img : array-like
        Image features, indexed by sample along the first axis.
    sentences : sequence of sequences
        One encoded sentence per sample.
    answer : array-like
        Labels, indexed by sample along the first axis.
    batch_size : int
        Number of samples per batch; incomplete trailing batches are dropped.
    buckets : list of int, optional
        Bucket lengths. When omitted, every sentence length that occurs at
        least ``batch_size`` times becomes a bucket. The list passed in is
        not modified.
    invalid_label : int
        Value used to pad sentences up to the bucket length.
    text_name, img_name, label_name : str
        Names reported in ``provide_data`` / ``provide_label``.
    dtype : str
        dtype of the arrays that are produced.
    layout : str
        The position of ``'N'`` selects batch-major (0) or time-major (1).

    Raises
    ------
    ValueError
        If no bucket is available, or ``'N'`` is not at position 0 or 1 of
        ``layout``.
    """

    def __init__(self, img, sentences, answer, batch_size, buckets=None, invalid_label=-1,
                 text_name='text', img_name = 'image', label_name='softmax_label', dtype='float32', layout='NTC'):
        super(VQAtrainIter, self).__init__()
        if not buckets:
            buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
                       if j >= batch_size]
        # Sort a copy so the caller's list is left untouched.
        buckets = sorted(buckets)
        if not buckets:
            raise ValueError("No buckets available: pass `buckets` explicitly or lower batch_size")

        ndiscard = 0
        self.data = [[] for _ in buckets]
        # Original sample index of every row stored in self.data, per bucket.
        self._rows = [[] for _ in buckets]
        for i, sentence in enumerate(sentences):
            # Smallest bucket that can hold the sentence.
            buck = bisect.bisect_left(buckets, len(sentence))
            if buck == len(buckets):
                ndiscard += 1
                continue
            buff = np.full((buckets[buck],), invalid_label, dtype=dtype)
            buff[:len(sentence)] = sentence
            self.data[buck].append(buff)
            self._rows[buck].append(i)

        self.data = [np.asarray(i, dtype=dtype) for i in self.data]
        self.answer = answer
        self.img = img
        if ndiscard:
            print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

        self.batch_size = batch_size
        self.buckets = buckets
        self.text_name = text_name
        self.img_name = img_name
        self.label_name = label_name
        self.dtype = dtype
        self.invalid_label = invalid_label
        self.nd_text = []
        self.nd_img = []
        self.ndlabel = []
        self.major_axis = layout.find('N')
        self.default_bucket_key = max(buckets)

        # NOTE(review): the image and label shapes advertised here reuse the
        # text shape; next() reports the actual per-batch shapes.
        if self.major_axis == 0:
            self.provide_data = [(text_name, (batch_size, self.default_bucket_key)),
                                 (img_name, (batch_size, self.default_bucket_key))]
            self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
        elif self.major_axis == 1:
            self.provide_data = [(text_name, (self.default_bucket_key, batch_size)),
                                 (img_name, (self.default_bucket_key, batch_size))]
            self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
        else:
            raise ValueError("Invalid layout %s: Must be NT (batch major) or TN (time major)" % layout)

        # (bucket index, row offset) of every complete batch.
        self.idx = []
        for i, buck in enumerate(self.data):
            self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
        self.curr_idx = 0

        self.reset()

    def reset(self):
        """Rewind to the first batch and rebuild the per-bucket NDArrays."""
        self.curr_idx = 0
        self.nd_text = []
        self.nd_img = []
        self.ndlabel = []
        img_all = mx.ndarray.array(self.img, dtype=self.dtype)
        label_all = mx.ndarray.array(self.answer, dtype=self.dtype)
        for buck, rows in zip(self.data, self._rows):
            # Gather the image/label rows of exactly the samples in this
            # bucket, in the same order as the text rows.
            rows_nd = mx.ndarray.array(rows, dtype='int32')
            self.nd_text.append(mx.ndarray.array(buck, dtype=self.dtype))
            self.nd_img.append(mx.ndarray.take(img_all, rows_nd))
            self.ndlabel.append(mx.ndarray.take(label_all, rows_nd))

    def next(self):
        """Return the next ``mx.io.DataBatch``; raise StopIteration when exhausted."""
        if self.curr_idx == len(self.idx):
            raise StopIteration
        i, j = self.idx[self.curr_idx]
        self.curr_idx += 1

        stop = j + self.batch_size
        img = self.nd_img[i][j:stop]
        text = self.nd_text[i][j:stop]
        label = self.ndlabel[i][j:stop]
        if self.major_axis == 1:
            # Time-major layout: transpose the data arrays, not the label.
            img = img.T
            text = text.T

        data = [text, img]
        return mx.io.DataBatch(data, [label],
                         bucket_key=self.buckets[i],
                         provide_data=[(self.text_name, text.shape),(self.img_name, img.shape)],
                         provide_label=[(self.label_name, label.shape)])

# Load the data

In [6]:
import numpy as np
# Scratch check: occurrence counts of 0..3 in [1, 2, 3], in ascending order.
a = np.sort(np.bincount([1, 2, 3]))
print(a)

[0 1 1 1]
