In [1]:
import math
import mxnet as mx
from mxnet import gluon, autograd, nd
from mxnet.gluon import nn,utils
from mxnet.gluon.data.vision import transforms
import mxnet.ndarray as F
import numpy as np
import os, sys
import os
import numpy as np
import collections
from PIL import Image
import csv
import random
from mxnet.gluon.data import Dataset, DataLoader
from tqdm import tqdm

  from ._conv import register_converters as _register_converters


In [2]:
# coding=utf-8
from utils import init_dataset
from omniglot_dataset import OmniglotDataset
from batch_sampler import BatchSampler
import argparse
import torch

In [3]:
mx.__version__

'1.2.0'

In [4]:
class CasualConv1d(nn.Block):
    
    def __init__(self,in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True,**kwargs):
        super(CasualConv1d,self).__init__(**kwargs)
        self.dilation = dilation
        self.padding = dilation * (kernel_size - 1)
        
        with self.name_scope():
            self.casual_conv = nn.Conv1D(in_channels=in_channels,channels=out_channels,kernel_size=kernel_size,padding = self.padding, dilation = dilation, groups=groups, use_bias=bias)
            
    def forward(self,x):
        out = self.casual_conv(x)
        return out[:,:,:-self.dilation]

In [5]:
class DenseBlock(nn.Block):
    def __init__(self, in_channels, filters, dilation=1, kernel_size=2, **kwargs):
        super(DenseBlock,self).__init__(**kwargs)
        
        with self.name_scope():
            self.casual_conv1 = CasualConv1d(in_channels, filters, kernel_size, dilation = dilation)
            self.casual_conv2 = CasualConv1d(in_channels, filters, kernel_size, dilation = dilation)
            
    def forward(self, x):
        #print(x.shape)
        tanh = F.tanh(self.casual_conv1(x))
        sigmoid = F.sigmoid(self.casual_conv1(x))
        out =  F.concat(x,tanh*sigmoid, dim=1)
        #print("Dense success")
        return out  

In [6]:
class TCBlock(nn.Block):
    def __init__(self, in_channels,seq_len, filters, **kwargs):
        super(TCBlock,self).__init__(**kwargs)
        layer_count = int(math.ceil(math.log(seq_len)))
        channel_count = in_channels
        with self.name_scope():
            self.blocks = nn.Sequential()
            for i in range(layer_count):
                self.blocks.add(DenseBlock(in_channels + i * filters, filters, dilation =2 ** (i+1)))
                
    def forward(self, x):
        x = x.swapaxes(1,2)
        out = self.blocks(x)
        return out.swapaxes(1,2)

In [7]:
class AttentionBlock(nn.Block):
    def __init__(self, k_size, v_size,ctx=mx.cpu(),show_shape=False, **kwargs):
        super(AttentionBlock,self).__init__(**kwargs)
        self.ctx = ctx
        self.sqrt_k = math.sqrt(k_size)
        self.show_shape = False
        with self.name_scope():
            self.key_layer = nn.Dense(k_size,flatten=False)
            self.query_layer = nn.Dense(k_size,flatten=False)
            self.value_layer = nn.Dense(v_size,flatten=False)
            
    
    def forward(self, x):
        #x = x.swapaxes(1,2)
        keys = self.key_layer(x)       
        queries = self.query_layer(x)
        values = self.value_layer(x)
        logits = nd.linalg_gemm2(queries,keys.swapaxes(2,1))
        if self.show_shape:
            print("keys shape:{}".format(keys.shape))
            print("queries shape:{}".format(queries.shape))
            print("logits shape:{}".format(logits.shape))
        mask = np.full(shape=(logits.shape[1],logits.shape[2]),fill_value=1).astype('float')
        mask = np.triu(mask,1)
        mask = np.expand_dims(mask,0)
        mask = np.repeat(mask,logits.shape[0],0)
        np.place(mask,mask==1,0.0)
        np.place(mask,mask==0,1.0)
        #np.place(mask,mask==0,0.0)
        #np.place(mask,mask==1,1.0)
        mask = nd.array(mask,ctx=self.ctx)
        logits =  F.elemwise_mul(logits,mask)
        probs = F.softmax(logits / self.sqrt_k, axis=2)
        if self.show_shape:
            print("probs shape:{}".format(probs.shape))
            print("values shape:{}".format(values.shape))
        read = nd.linalg_gemm2(probs,values)
        concat_data = F.concat(x,read,dim=2)
        return concat_data
        #return queries,probs

In [8]:
class CnnEmbedding(nn.Block):
    
    def __init__(self,**kwargs):
        super(CnnEmbedding,self).__init__(**kwargs)
        with self.name_scope():
            self.cnn1 = nn.Conv2D(64,3,padding=1,activation='relu')
            self.bn1 =  nn.BatchNorm()
            self.max1 = nn.MaxPool2D(2,2)
            self.cnn2 = nn.Conv2D(64,3,padding=1,activation='relu')
            self.bn2 = nn.BatchNorm()
            self.max2 = nn.MaxPool2D(2,2)
            self.cnn3 = nn.Conv2D(64,3,padding=1,activation='relu')
            self.bn3 = nn.BatchNorm()
            self.max3 = nn.MaxPool2D(2,2)
            self.cnn4 = nn.Conv2D(64,3,padding=1,activation='relu')
            self.bn4 = nn.BatchNorm()
            self.max4 = nn.MaxPool2D(2)
    
    def forward(self,x):
        out = self.cnn1(x)
        out = self.bn1(out)
        out = self.max1(out)
        out = self.cnn2(out)
        out = self.bn2(out)
        out = self.max2(out)
        out = self.cnn3(out)
        out = self.bn3(out)
        out = self.max3(out)
        out = self.cnn4(out)
        out = self.bn4(out)
        out = self.max4(out)
        return out.reshape(out.shape[0],-1)

In [9]:
class SNAIL(nn.Block):
    def __init__(self,N,K,input_dims,ctx=mx.cpu(),**kwargs):
        super(SNAIL,self).__init__(**kwargs)
        self.N = N
        self.K = K
        self.num_filters = int(math.ceil(math.log(N * K + 1)))
        self.ctx = ctx
        self.num_channels = input_dims + N
        with self.name_scope():
            self.attn1 = AttentionBlock(64, 32, ctx=self.ctx)
            attn1_out_shape = self.num_channels + 32
            self.tc1 = TCBlock(attn1_out_shape ,N*K+1 , 128)
            tc1_out_shape = attn1_out_shape + self.num_filters * 128
            self.attn2 = AttentionBlock(256, 128, ctx=self.ctx)
            attn2_out_shape = tc1_out_shape + 128
            self.tc2 = TCBlock(attn2_out_shape ,N*K+1 , 128)
            tc2_out_shape = attn2_out_shape + self.num_filters * 128
            self.attn3 = AttentionBlock(512, 256, ctx=self.ctx)
            attn3_out_shape = tc2_out_shape + 128
            self.fc = nn.Dense(N,flatten=False)
                        
    def forward(self, x, labels):
        batch_size = int(labels.shape[0] / (N * K + 1))
        last_idxs = [(i + 1) * (N * K + 1) - 1 for i in range(batch_size)]
        labels[last_idxs] = nd.zeros(shape=(batch_size, labels.shape[1]), ctx=self.ctx)
        x = F.concat(x,labels,dim=1)
        x = x.reshape((batch_size,N*K+1,-1))
        x = self.attn1(x)
        x = self.tc1(x)
        x = self.attn2(x)
        x = self.tc2(x)
        x = self.attn3(x)
        x = self.fc(x)
        
        return x
        

In [10]:
train_dataset = OmniglotDataset(mode='train',download=False)

/home/skinet/work/datasets/omniglot/data
== Dataset: Found 82240 items 
== Dataset: Found 4112 classes


In [11]:
N = 10     #num_class
K = 5  #num_samples
iterations = 10
batch_size = 24

In [12]:
tr_sampler = BatchSampler(labels=train_dataset.y,
                                          classes_per_it=N,
                                          num_samples=K,
                                          iterations=iterations,
                                          batch_size=batch_size)

In [13]:
tr_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                batch_sampler=tr_sampler)

In [14]:
def batch_for_few_shot(num_cls,num_samples,batch_size, x, y):
    seq_size = num_cls * num_samples + 1
    one_hots = []
    last_targets = []
    for i in range(batch_size):
        one_hot, idxs = labels_to_one_hot(y[i * seq_size: (i + 1) * seq_size])
        one_hots.append(one_hot)
        last_targets.append(idxs[-1])
    #last_targets = Variable(torch.Tensor(last_targets).long())
    last_targets =torch.Tensor(last_targets).long()
    one_hots = [torch.Tensor(temp) for temp in one_hots]
    y = torch.cat(one_hots, dim=0)
    #x, y = Variable(x), Variable(y)
    x = nd.array(x.data.numpy())
    y = nd.array(y.data.numpy())
    last_targets = nd.array(last_targets.data.numpy())
    return x, y, last_targets

In [15]:
def labels_to_one_hot(labels):
    labels = labels.numpy()
    unique = np.unique(labels)
    map = {label:idx for idx, label in enumerate(unique)}
    idxs = [map[labels[i]] for i in range(labels.size)]
    one_hot = np.zeros((labels.size, unique.size))
    one_hot[np.arange(labels.size), idxs] = 1
    return one_hot, idxs

In [16]:
model = SNAIL(N=N,K=K,input_dims=64)
model.collect_params().initialize()
loss = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(model.collect_params(),optimizer='Adam',optimizer_params={'learning_rate':0.0001})

In [17]:
tr_iter = iter(tr_dataloader)

In [18]:
for batch in tqdm(tr_iter):
    x,y = batch
    break

  0%|          | 0/10 [00:00<?, ?it/s]


In [19]:
x, y, last_targets = batch_for_few_shot(N, K ,batch_size, x, y)

In [20]:
x.shape

(1224, 1, 28, 28)

In [21]:
y.shape

(1224, 10)

In [22]:
model_output = model(x,y)

MXNetError: [15:09:27] src/operator/nn/concat.cc:66: Check failed: shape_assign(&(*in_shape)[i], dshape) Incompatible input shape: expected [1224,0,28,28], got [1224,10]

Stack trace returned 10 entries:
[bt] (0) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x30cbe2) [0x7fb0035e0be2]
[bt] (1) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x30d1b8) [0x7fb0035e11b8]
[bt] (2) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x531348) [0x7fb003805348]
[bt] (3) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x29bf312) [0x7fb005c93312]
[bt] (4) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x29c8dee) [0x7fb005c9cdee]
[bt] (5) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2900c4b) [0x7fb005bd4c4b]
[bt] (6) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(MXImperativeInvokeEx+0x6f) [0x7fb005bd520f]
[bt] (7) /opt/miniconda/lib/libffi.so.6(ffi_call_unix64+0x4c) [0x7fb1ec772ec0]
[bt] (8) /opt/miniconda/lib/libffi.so.6(ffi_call+0x22d) [0x7fb1ec77287d]
[bt] (9) /usr/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(_ctypes_callproc+0x2fd) [0x7fb1ec98719d]



In [22]:
y.shape

(1224,)

In [32]:
last_targets


[6. 5. 0. 7. 6. 5. 2. 2. 3. 3. 7. 7. 0. 9. 6. 4. 2. 0. 6. 5. 3. 6. 1. 1.]
<NDArray 24 @cpu(0)>

In [23]:
x = np.load('x.npy')
x = nd.array(x)

In [24]:
y = np.load('y.npy')
labels = nd.array(y)

In [25]:
x.shape

(1224, 1, 28, 28)

In [26]:
labels.shape

(1224, 10)

In [28]:
snail = SNAIL(N=N,K=K,input_dims=64)
snail.collect_params().initialize()

In [29]:
out = snail(x,labels)

MXNetError: [15:10:26] src/operator/nn/concat.cc:66: Check failed: shape_assign(&(*in_shape)[i], dshape) Incompatible input shape: expected [1224,0,28,28], got [1224,10]

Stack trace returned 10 entries:
[bt] (0) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x30cbe2) [0x7fb0035e0be2]
[bt] (1) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x30d1b8) [0x7fb0035e11b8]
[bt] (2) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x531348) [0x7fb003805348]
[bt] (3) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x29bf312) [0x7fb005c93312]
[bt] (4) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x29c8dee) [0x7fb005c9cdee]
[bt] (5) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2900c4b) [0x7fb005bd4c4b]
[bt] (6) /opt/venv/lib/python3.6/site-packages/mxnet/libmxnet.so(MXImperativeInvokeEx+0x6f) [0x7fb005bd520f]
[bt] (7) /opt/miniconda/lib/libffi.so.6(ffi_call_unix64+0x4c) [0x7fb1ec772ec0]
[bt] (8) /opt/miniconda/lib/libffi.so.6(ffi_call+0x22d) [0x7fb1ec77287d]
[bt] (9) /usr/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(_ctypes_callproc+0x2fd) [0x7fb1ec98719d]



In [26]:
cnn_emb = CnnEmbedding()
cnn_emb.collect_params().initialize()

In [27]:
emb_out = cnn_emb(x)

In [28]:
out = snail(emb_out,labels)

In [33]:
out_label = out[:,-1,:]

In [36]:
out_label.shape

(24, 10)

In [29]:
out.shape

(24, 51, 10)

In [343]:
x.shape

(1224, 1, 28, 28)

In [344]:
labels.shape

(1224, 10)

In [345]:
cnn_emb = CnnEmbedding()
cnn_emb.collect_params().initialize()

In [346]:
emb_out = cnn_emb(x)

In [347]:
emb_out.shape

(1224, 64)

In [348]:
batch_size = int(labels.shape[0] / (N * K + 1))

In [349]:
labels.shape[0] 

1224

In [350]:
batch_size

24

In [351]:
last_idxs = [(i + 1) * (N * K + 1) - 1 for i in range(batch_size)]

In [352]:
x_cat = F.concat(emb_out,label,dim=1)

In [353]:
x_cat.shape

(1224, 74)

In [354]:
x_cat = x_cat.reshape((batch_size,N*K+1,-1))

In [355]:
x_cat.shape

(24, 51, 74)

In [356]:
attention1 = AttentionBlock(64, 32)
attention1.collect_params().initialize()

In [357]:
a1_out = attention1(x_cat)

In [358]:
a1_out.shape

(24, 51, 106)

In [359]:
106-32

74

In [253]:
tc1 = TCBlock(a1_out.shape[2] ,N*K+1 , 128)
tc1.collect_params().initialize()

In [259]:
tc1_out = tc1(a1_out)
tc1_out.shape

(24, 51, 618)

In [278]:
618 - 106

512

In [268]:
attention2 = AttentionBlock(256, 128)
attention2.collect_params().initialize()

In [269]:
a2_out = attention2(tc1_out)
a2_out.shape

(24, 51, 746)

In [270]:
tc2 = TCBlock(a2_out.shape[2] ,N*K+1 , 128)
tc2.collect_params().initialize()

In [271]:
tc2_out = tc2(a2_out)
tc2_out.shape

(24, 51, 1258)

In [272]:
attention3 = AttentionBlock(512, 256)
attention3.collect_params().initialize()

In [273]:
a3_out = attention3(tc2_out)
a3_out.shape

(24, 51, 1514)

In [274]:
fc = nn.Dense(N,flatten=False)
fc.collect_params().initialize()

In [275]:
out = fc(a3_out)

In [276]:
out.shape

(24, 51, 10)