In [1]:
import math
import mxnet as mx
from mxnet import gluon, autograd, nd
from mxnet.gluon import nn,utils
from mxnet.gluon.data.vision import transforms
import mxnet.ndarray as F
import numpy as np
import os, sys
import os
import numpy as np
import collections
from PIL import Image
import csv
import random
from mxnet.gluon.data import Dataset, DataLoader
from tqdm import tqdm
from random import randint

  from ._conv import register_converters as _register_converters


In [2]:
mx.__version__

'1.2.0'

In [4]:
class MiniImagenet(Dataset):
    """
    Episodic (few-shot) dataset over mini-imagenet.

    Put the mini-imagenet files as:
    root :
        |- images/*.jpg includes all images
        |- train.csv
        |- test.csv
        |- val.csv
    NOTICE: meta-learning is different from general supervised learning, especially the concept of batch and set.
    batch: contains several sets
    sets: contains n_way * k_shot images for the support set and n_way * k_query images for the query set.
    """

    def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, ctx=mx.cpu(), startidx=0):
        """
        Read ``<root>/<mode>.csv`` and pre-sample ``batchsz`` episodes.

        :param root: root path of mini-imagenet
        :param mode: train, val or test (selects the csv file)
        :param batchsz: number of sets (episodes) to pre-sample, not a number of images
        :param n_way: classes per set
        :param k_shot: support images per class
        :param k_query: num of query imgs per class
        :param resize: images are resized to (resize, resize)
        :param ctx: context on which the label arrays returned by __getitem__ are created
        :param startidx: start to index label from startidx
        """

        self.batchsz = batchsz  # batch of set, not batch of imgs
        self.n_way = n_way  # n-way
        self.k_shot = k_shot  # k-shot
        self.k_query = k_query  # for evaluation
        self.setsz = self.n_way * self.k_shot  # num of samples per set
        self.querysz = self.n_way * self.k_query  # number of samples per set for evaluation
        self.resize = resize  # resize to
        self.ctx = ctx
        self.startidx = startidx  # index label not from 0, but from startidx
        print('shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d' % (mode, batchsz, n_way, k_shot, k_query, resize))

        # The train and eval pipelines were identical (the train-only augmentations
        # were commented out), so a single pipeline is built for every mode.
        self.transform = transforms.Compose([
            lambda x: nd.array(np.array(Image.open(x).convert('RGB'))).astype(dtype=np.uint8),
            transforms.Resize((self.resize, self.resize)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        self.path = os.path.join(root, 'images')  # image path
        csvdata = self.loadCSV(os.path.join(root, mode + '.csv'))  # csv path
        self.data = []
        self.img2label = {}
        for i, (k, v) in enumerate(csvdata.items()):
            self.data.append(v)  # [[img1, img2, ...], [img111, ...]]
            self.img2label[k] = i + self.startidx  # {"img_name[:9]":label}
        self.cls_num = len(self.data)

        self.create_batch(self.batchsz)

    def loadCSV(self, csvf):
        """
        Return a dict holding the content of a split csv.

        The first row is treated as the (filename, label) header and skipped.
        Rows with fewer than two columns (e.g. blank lines) are ignored.

        :param csvf: csv file name
        :return: {label:[file1, file2 ...]}
        """
        dictLabels = {}
        with open(csvf) as csvfile:
            csvreader = csv.reader(csvfile, delimiter=',')
            next(csvreader, None)  # skip (filename, label)
            for row in csvreader:
                if len(row) < 2:
                    continue
                filename, label = row[0], row[1]
                dictLabels.setdefault(label, []).append(filename)
        return dictLabels

    def create_batch(self, batchsz):
        """
        Pre-sample the episodes served by __getitem__.

        *episode* here means batch, and it means how many sets we want to retain.
        Fills ``self.support_x_batch`` and ``self.query_x_batch``; each entry is a
        list (one item per selected class) of lists of image file names.

        :param batchsz: number of sets to sample
        :return:
        """
        self.support_x_batch = []  # support set batch
        self.query_x_batch = []  # query set batch
        for b in range(batchsz):  # for each batch
            # 1.select n_way classes randomly
            selected_cls = np.random.choice(self.cls_num, self.n_way, False)  # no duplicate
            np.random.shuffle(selected_cls)
            support_x = []
            query_x = []
            for cls in selected_cls:
                # 2. select k_shot + k_query distinct images for each class
                selected_imgs_idx = np.random.choice(len(self.data[cls]), self.k_shot + self.k_query, False)
                np.random.shuffle(selected_imgs_idx)
                indexDtrain = np.array(selected_imgs_idx[:self.k_shot])  # idx for Dtrain
                indexDtest = np.array(selected_imgs_idx[self.k_shot:])  # idx for Dtest
                cls_files = np.array(self.data[cls])
                support_x.append(cls_files[indexDtrain].tolist())  # file names for the support set
                query_x.append(cls_files[indexDtest].tolist())  # file names for the query set

            # shuffle the class order of the support set and of the query set independently
            random.shuffle(support_x)
            random.shuffle(query_x)

            self.support_x_batch.append(support_x)  # append set to current sets
            self.query_x_batch.append(query_x)  # append sets to current sets

    def __getitem__(self, index):
        """
        Load the images of one pre-sampled set.

        index means index of sets, 0<= index <= batchsz-1
        :param index: index of the set
        :return: (support_x, support_y, query_x, query_y) where the x arrays have
                 shape [setsz or querysz, 3, resize, resize] and the y arrays hold
                 labels relative to this set (0 .. number of classes in the set - 1),
                 created on ``self.ctx``
        """
        # Image buffers; every row is overwritten by the loops below.
        # [setsz, 3, resize, resize]
        support_x = nd.empty(shape=(self.setsz, 3, self.resize, self.resize))
        # [querysz, 3, resize, resize]
        query_x = nd.empty(shape=(self.querysz, 3, self.resize, self.resize))

        flatten_support_x = [os.path.join(self.path, item)
                             for sublist in self.support_x_batch[index] for item in sublist]
        support_y = np.array(
            [self.img2label[item[:9]]  # filename:n0153282900000005.jpg, the first 9 characters treated as label
             for sublist in self.support_x_batch[index] for item in sublist]).astype(np.int32)

        flatten_query_x = [os.path.join(self.path, item)
                           for sublist in self.query_x_batch[index] for item in sublist]
        query_y = np.array([self.img2label[item[:9]]
                            for sublist in self.query_x_batch[index] for item in sublist]).astype(np.int32)

        # support_y: [setsz]
        # query_y: [querysz]
        # unique: [n-way], sorted
        unique = np.unique(support_y)
        # relative means the label ranges from 0 to n-way
        support_y_relative = np.zeros(self.setsz)
        query_y_relative = np.zeros(self.querysz)
        for idx, l in enumerate(unique):
            support_y_relative[support_y == l] = idx
            query_y_relative[query_y == l] = idx

        for i, path in enumerate(flatten_support_x):
            support_x[i] = self.transform(path)

        for i, path in enumerate(flatten_query_x):
            query_x[i] = self.transform(path)

        return support_x, nd.array(support_y_relative, ctx=self.ctx), query_x, nd.array(query_y_relative, ctx=self.ctx)

    def __len__(self):
        # as we have built up to batchsz of sets, you can sample some small batch size of sets.
        return self.batchsz

In [5]:
class CasualConv1d(nn.Block):
    """
    1-D causal convolution: the output at step t only depends on inputs at steps <= t.

    The wrapped Conv1D pads both ends of the time axis with
    ``dilation * (kernel_size - 1)`` zeros; forward() then drops that many
    trailing steps so the output has the same length as the input.
    """

    def __init__(self,in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True,**kwargs):
        super(CasualConv1d,self).__init__(**kwargs)
        self.dilation = dilation
        self.padding = dilation * (kernel_size - 1)

        # NOTE(review): `stride` is accepted but not forwarded to Conv1D, same as before.
        with self.name_scope():
            self.casual_conv = nn.Conv1D(in_channels=in_channels,channels=out_channels,kernel_size=kernel_size,padding = self.padding, dilation = dilation, groups=groups, use_bias=bias)

    def forward(self,x):
        out = self.casual_conv(x)
        # Trim exactly the number of steps that were padded. Trimming `dilation`
        # steps is only correct for kernel_size == 2, and a `[:-0]` slice would
        # return an empty tensor when no padding was added (kernel_size == 1).
        if self.padding == 0:
            return out
        return out[:,:,:-self.padding]

In [6]:
class DenseBlock(nn.Block):
    """
    Gated causal-convolution block with a dense (concatenating) connection.

    Two separate causal convolutions produce a tanh "filter" and a sigmoid
    "gate"; their product is concatenated to the input along the channel axis
    (dim=1), so the output has ``in_channels + filters`` channels.
    """
    def __init__(self, in_channels, filters, dilation=1, kernel_size=2, **kwargs):
        super(DenseBlock,self).__init__(**kwargs)

        with self.name_scope():
            self.casual_conv1 = CasualConv1d(in_channels, filters, kernel_size, dilation = dilation)
            self.casual_conv2 = CasualConv1d(in_channels, filters, kernel_size, dilation = dilation)

    def forward(self, x):
        tanh = F.tanh(self.casual_conv1(x))
        # The gate must come from the second convolution; reusing casual_conv1
        # left casual_conv2 without any gradient and tied the gate to the filter.
        sigmoid = F.sigmoid(self.casual_conv2(x))
        out =  F.concat(x,tanh*sigmoid, dim=1)
        return out

In [7]:
class TCBlock(nn.Block):
    """
    Stack of DenseBlocks with dilation doubling at every level (2, 4, 8, ...).

    Every level adds ``filters`` channels, so the output has
    ``in_channels + depth * filters`` channels where
    ``depth = ceil(ln(seq_len))``.

    NOTE(review): the depth uses the natural logarithm. SNAIL.num_filters uses the
    same expression to size the following layers, so the two must change together.
    """
    def __init__(self, in_channels,seq_len, filters, **kwargs):
        super(TCBlock,self).__init__(**kwargs)
        depth = int(math.ceil(math.log(seq_len)))
        with self.name_scope():
            self.blocks = nn.Sequential()
            for level in range(depth):
                # channels seen by this level = original input + what earlier levels appended
                grown_channels = in_channels + level * filters
                self.blocks.add(DenseBlock(grown_channels, filters, dilation=2 ** (level + 1)))

    def forward(self, x):
        # swap axes 1 and 2 so the convolutions see channels on axis 1, then swap back
        return self.blocks(x.swapaxes(1,2)).swapaxes(1,2)

In [8]:
class AttentionBlock(nn.Block):
    """
    Single-head causal self-attention with a dense (concatenating) connection.

    For an input of shape (batch, seq, channels) the block computes key/query/value
    projections, lets every step attend to itself and to earlier steps only, and
    returns the input concatenated with the attention read-out along axis 2, i.e.
    ``channels + v_size`` output channels.

    :param k_size: size of the key/query projections
    :param v_size: size of the value projection
    :param ctx: stored on the instance; forward() runs on the context of its input
    :param show_shape: print intermediate shapes during forward() when True
    """
    def __init__(self, k_size, v_size,ctx=mx.cpu(),show_shape=False, **kwargs):
        super(AttentionBlock,self).__init__(**kwargs)
        self.ctx = ctx
        self.sqrt_k = math.sqrt(k_size)
        # honour the constructor argument (it used to be forced to False)
        self.show_shape = show_shape
        with self.name_scope():
            self.key_layer = nn.Dense(k_size,flatten=False)
            self.query_layer = nn.Dense(k_size,flatten=False)
            self.value_layer = nn.Dense(v_size,flatten=False)


    def forward(self, x):
        with x.context:
            keys = self.key_layer(x)
            queries = self.query_layer(x)
            values = self.value_layer(x)
            # logits[b, q, k]: score of query step q against key step k
            logits = nd.linalg_gemm2(queries,keys.swapaxes(2,1))
            if self.show_shape:
                print("keys shape:{}".format(keys.shape))
                print("queries shape:{}".format(queries.shape))
                print("logits shape:{}".format(logits.shape))
            # Causal mask: entries with k > q (strict upper triangle) are future
            # steps. They get a large negative additive penalty so softmax gives
            # them zero weight. A multiplicative 0/1 mask would only set those
            # logits to 0, which still receives probability mass, and the previous
            # mask construction evaluated to all ones, i.e. no masking at all.
            seq_len = logits.shape[1]
            future = np.triu(np.ones((seq_len, seq_len), dtype='float32'), 1)
            penalty = nd.array(future[np.newaxis] * -1e30, ctx=x.context)
            logits = F.broadcast_add(logits, penalty)
            probs = F.softmax(logits / self.sqrt_k, axis=2)
            if self.show_shape:
                print("probs shape:{}".format(probs.shape))
                print("values shape:{}".format(values.shape))
            read = nd.linalg_gemm2(probs,values)
            concat_data = F.concat(x,read,dim=2)
            return concat_data
            #return queries,probs

In [9]:
class Residual_Block(nn.Block):
    """
    Three 3x3 conv + BatchNorm + LeakyReLU(0.1) stages with a 1x1-conv shortcut,
    followed by 2x2 max pooling and dropout.

    :param filters: output channels of every convolution in the block
    :param pool_padding: padding passed to the max-pooling layer
    """
    def __init__(self,filters, pool_padding=0,**kwargs):
        super(Residual_Block,self).__init__( **kwargs)
        conv3x3 = dict(kernel_size=3, strides=1, padding=1, use_bias=False)
        with self.name_scope():
            # layers are created in the same order as before so parameter names stay stable
            self.conv1 = nn.Conv2D(filters, **conv3x3)
            self.bn1 =  nn.BatchNorm()
            self.relu1 = nn.LeakyReLU(alpha=0.1)
            self.conv2 = nn.Conv2D(filters, **conv3x3)
            self.bn2 = nn.BatchNorm()
            self.relu2 = nn.LeakyReLU(alpha=0.1)
            self.conv3 = nn.Conv2D(filters, **conv3x3)
            self.bn3 = nn.BatchNorm()
            self.relu3 = nn.LeakyReLU(alpha=0.1)
            # 1x1 projection used as the shortcut branch
            self.conv4 = nn.Conv2D(filters,kernel_size=1,strides=1,padding=0,use_bias=False)

            self.max_pooling = nn.MaxPool2D(2, padding=pool_padding)
            # NOTE(review): rate=0.9 is passed as Gluon's drop fraction; confirm that a
            # keep-probability of 0.9 (rate=0.1) was not what was intended.
            self.dropout = nn.Dropout(rate=0.9)

    def forward(self, x):
        shortcut = self.conv4(x)

        stages = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        features = x
        for conv, norm, act in stages:
            features = act(norm(conv(features)))

        pooled = self.max_pooling(features + shortcut)
        return self.dropout(pooled)

In [10]:
class MiniImageNet_emb(nn.Block):
    """
    Image embedding network: four Residual_Blocks (64, 96, 128, 256 filters), a
    1x1 conv to 2048 channels, max pooling, ReLU, dropout and a final 1x1 conv to
    384 channels. forward() returns the result flattened to (batch, -1).
    """

    def __init__(self,**kwargs):
        super(MiniImageNet_emb,self).__init__(**kwargs)
        with self.name_scope():
            self.block1 = Residual_Block(64)
            self.block2 = Residual_Block(96)
            self.block3 = Residual_Block(128, pool_padding=1)
            self.block4 = Residual_Block(256, pool_padding=1)
            # NOTE(review): padding=1 on a 1x1 convolution adds a border that only
            # carries the bias; confirm this is intended before the 6x6 pooling.
            self.conv1 = nn.Conv2D(2048,kernel_size=1,padding=1)
            self.max_pooling = nn.MaxPool2D(6)
            self.relu = nn.LeakyReLU(alpha=0)  # alpha=0: plain ReLU
            # NOTE(review): rate is Gluon's drop fraction; confirm 0.9 is intended.
            self.dropout = nn.Dropout(rate=0.9)
            self.conv2 = nn.Conv2D(384,kernel_size=1)


    def forward(self,x):
        pipeline = (
            self.block1, self.block2, self.block3, self.block4,
            self.conv1, self.max_pooling, self.relu, self.dropout, self.conv2,
        )
        out = x
        for layer in pipeline:
            out = layer(out)
        # flatten everything but the batch axis
        return out.reshape(out.shape[0],-1)

In [11]:
class SNAIL(nn.Block):
    """
    SNAIL few-shot classifier: image embedding followed by alternating
    attention and temporal-convolution blocks and a per-step dense classifier.

    :param N: number of classes per episode
    :param K: number of support samples per class
    :param input_dims: size of the image embedding produced by MiniImageNet_emb
    :param ctx: context handed to the attention blocks; defaults to mx.cpu().
                (Previously this was read from a notebook-level global named
                ``ctx``, which made the class depend on cell execution order.)
    """
    def __init__(self,N,K,input_dims,ctx=None,**kwargs):
        super(SNAIL,self).__init__(**kwargs)
        self.N = N
        self.K = K
        # Number of DenseBlock levels inside each TCBlock for a sequence of N*K+1
        # steps; must stay in sync with the expression TCBlock uses (natural log).
        self.num_filters = int(math.ceil(math.log(N * K + 1)))
        self.ctx = mx.cpu() if ctx is None else ctx
        self.num_channels = input_dims + N
        with self.name_scope():
            self.cnn_emb = MiniImageNet_emb()
            self.attn1 = AttentionBlock(64, 32, ctx=self.ctx)
            attn1_out_shape = self.num_channels + 32
            self.tc1 = TCBlock(attn1_out_shape ,N*K+1 , 128)
            tc1_out_shape = attn1_out_shape + self.num_filters * 128
            self.attn2 = AttentionBlock(256, 128, ctx=self.ctx)
            attn2_out_shape = tc1_out_shape + 128
            self.tc2 = TCBlock(attn2_out_shape ,N*K+1 , 128)
            self.attn3 = AttentionBlock(512, 256, ctx=self.ctx)
            self.fc = nn.Dense(N,flatten=False)

    def forward(self, x, labels):
        """
        :param x: images of all sequences stacked along axis 0; forward() assumes
                  batch_size * (N*K+1) rows
        :param labels: one-hot labels with one row per image. NOTE: the row of the
                       last step of every sequence is zeroed IN PLACE so the
                       query label is hidden from the network.
        :return: the per-step output of the final dense layer, reshaped to
                 (batch_size, N*K+1, N)
        """
        # use the instance configuration instead of notebook-level globals N / K
        seq_len = self.N * self.K + 1
        with x.context:
            batch_size = labels.shape[0] // seq_len
            last_idxs = [(i + 1) * seq_len - 1 for i in range(batch_size)]
            labels[last_idxs] = nd.zeros(shape=(batch_size, labels.shape[1]))
            x = self.cnn_emb(x)
            x = F.concat(x,labels,dim=1)
            x = x.reshape((batch_size,seq_len,-1))
            x = self.attn1(x)
            x = self.tc1(x)
            x = self.attn2(x)
            x = self.tc2(x)
            x = self.attn3(x)
            x = self.fc(x)

        return x
        

In [12]:
# Working directory that contains the `mini-imagenet/` folder. Set the
# SNAIL_DATA_ROOT environment variable to run on another machine; the original
# location is kept as the default.
os.chdir(os.environ.get('SNAIL_DATA_ROOT', '/home/skinet/work/datasets/'))

In [13]:
# Episode layout
N = 10        # classes per episode (n-way)
K = 5         # support samples per class (k-shot)
k_query = 1   # query samples per class

# Devices used for data-parallel training
GPU_INDEX = [4,5,6,7]
ctx = [mx.gpu(gpu_id) for gpu_id in GPU_INDEX]

# 1000 episodes in total, shared between the devices; 4 episodes per device per step
iterations = 1000 // len(ctx)
batch_size = 4 * len(ctx)
epoches = 200

In [14]:
# Pre-sample `iterations` episodes for the train split, then for the test split
# (same episode layout and image size for both).
mini, mini_test = (
    MiniImagenet('mini-imagenet/', mode=split, n_way=N, k_shot=K, k_query=k_query,
                 batchsz=iterations, resize=84)
    for split in ('train', 'test')
)

shuffle DB :train, b:250, 10-way, 5-shot, 1-query, resize:84
shuffle DB :test, b:250, 10-way, 5-shot, 1-query, resize:84


In [15]:
# One loader per split; every loader item is a batch of `batch_size` episodes.
# NOTE(review): the training cell's recorded output shows
# `OSError: [Errno 28] No space left on device` at the very first batch; with
# num_workers=1 batches presumably travel through shared memory, so check the
# capacity of /dev/shm (or try num_workers=0) - confirm the actual cause.
db, db_test = (
    DataLoader(episodes, batch_size=batch_size, shuffle=True, num_workers=1)
    for episodes in (mini, mini_test)
)

In [16]:
def snail_data_generation(batch, N):
    """
    Flatten a batch of episodes into SNAIL input sequences.

    For every episode the support samples are followed by ONE query sample, picked
    at a random position in ``[0, N-1]`` of that episode's query set, and the
    sequences of all episodes are stacked along axis 0.

    :param batch: (support_x, support_y, query_x, query_y); each item must provide
                  ``.shape`` and ``.asnumpy()`` and has the episode index on axis 0
    :param N: upper bound (exclusive) of the random query position; must not exceed
              the number of query samples per episode
    :return: (x, y) numpy arrays with ``n_episodes * (n_support + 1)`` rows. An
             empty batch yields empty arrays (it used to raise UnboundLocalError).
    """
    batch_size = batch[0].shape[0]
    support_x = batch[0].asnumpy()
    support_y = batch[1].asnumpy()
    query_x = batch[2].asnumpy()
    query_y = batch[3].asnumpy()

    # One draw per episode, in episode order (same random stream as a per-episode loop).
    picks = np.array([randint(0, N - 1) for _ in range(batch_size)], dtype=np.intp)
    rows = np.arange(batch_size)
    picked_x = query_x[rows, picks]   # one query sample per episode
    picked_y = query_y[rows, picks]   # and its label

    # Append the query as the last step of every episode and flatten in a single
    # pass; growing the result with np.append re-copied all images on every step.
    x = np.concatenate([support_x, picked_x[:, np.newaxis]], axis=1)
    y = np.concatenate([support_y, picked_y[:, np.newaxis]], axis=1)
    x = x.reshape((batch_size * x.shape[1],) + x.shape[2:])
    y = y.reshape((batch_size * y.shape[1],))
    return x, y

In [17]:
def batch_for_few_shot(num_cls,num_samples,batch_size, x, y):
    """
    Turn flattened episode labels into one-hot labels plus last-step targets.

    ``y`` is cut into ``batch_size`` consecutive sequences of
    ``num_cls * num_samples + 1`` labels; every sequence is one-hot encoded on its
    own, and the class index of its last element becomes the target.

    :return: (x, y, last_targets) as NDArrays: the inputs, the stacked one-hot
             labels and one target index per sequence
    """
    seq_size = num_cls * num_samples + 1
    # (one_hot, idxs) per sequence, in sequence order
    encoded = [labels_to_one_hot(y[start:start + seq_size])
               for start in range(0, batch_size * seq_size, seq_size)]

    stacked_one_hot = np.concatenate([one_hot for one_hot, _ in encoded], 0)
    final_idxs = np.array([idxs[-1] for _, idxs in encoded])

    x = nd.array(x)
    y = nd.array(stacked_one_hot)
    last_targets = nd.array(final_idxs)
    return x, y, last_targets

In [18]:
def labels_to_one_hot(labels):
    """
    One-hot encode a 1-D label array against its own sorted unique values.

    :param labels: 1-D numpy array of labels
    :return: (one_hot, idxs) where ``one_hot`` has shape
             (labels.size, number of distinct labels) and ``idxs`` is a list with,
             for every label, its rank among the distinct labels
    """
    distinct, ranks = np.unique(labels, return_inverse=True)
    idxs = ranks.reshape(-1).tolist()
    one_hot = np.zeros((labels.size, distinct.size))
    one_hot[np.arange(labels.size), idxs] = 1
    return one_hot, idxs

In [26]:
# Dry run over the test loader: build the flattened sequences and one-hot labels
# for every batch without touching the model. `batch`, `x`, `y` and
# `last_targets` of the final iteration stay in the namespace afterwards.
for step, batch in enumerate(tqdm(db_test)):
    x,y = snail_data_generation(batch,N)
    # batch[0].shape[0] is the number of sets actually present in this batch
    x, y, last_targets = batch_for_few_shot(N, K ,batch[0].shape[0], x, y)



  0%|          | 0/16 [00:00<?, ?it/s][A[A

  6%|▋         | 1/16 [00:08<02:02,  8.15s/it][A[A

 12%|█▎        | 2/16 [00:15<01:45,  7.56s/it][A[A

 19%|█▉        | 3/16 [00:22<01:36,  7.40s/it][A[A

 25%|██▌       | 4/16 [00:29<01:29,  7.42s/it][A[A

 31%|███▏      | 5/16 [00:37<01:22,  7.51s/it][A[A

 38%|███▊      | 6/16 [00:45<01:15,  7.54s/it][A[A

 44%|████▍     | 7/16 [00:53<01:08,  7.58s/it][A[A

 50%|█████     | 8/16 [01:00<01:00,  7.52s/it][A[A

 56%|█████▋    | 9/16 [01:07<00:52,  7.53s/it][A[A

 62%|██████▎   | 10/16 [01:15<00:45,  7.52s/it][A[A

 69%|██████▉   | 11/16 [01:22<00:37,  7.50s/it][A[A

 75%|███████▌  | 12/16 [01:29<00:29,  7.42s/it][A[A

 81%|████████▏ | 13/16 [01:37<00:22,  7.46s/it][A[A

 88%|████████▊ | 14/16 [01:44<00:14,  7.49s/it][A[A

 94%|█████████▍| 15/16 [01:52<00:07,  7.52s/it][A[A

100%|██████████| 16/16 [01:57<00:00,  7.33s/it][A[A

[A[A

In [21]:
batch[0].shape

(16, 50, 3, 84, 84)

In [25]:
batch[0].shape[0]

10

In [22]:
16*50

800

In [19]:
# Network: 384 is the embedding size handed to SNAIL as input_dims.
model = SNAIL(N=N, K=K, input_dims=384)
# Parameters are initialised on every training device.
model.collect_params().initialize(ctx=ctx)

# Softmax cross-entropy on the last-step prediction, optimised with Adam.
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(
    model.collect_params(),
    optimizer='Adam',
    optimizer_params=dict(learning_rate=0.0001),
)

In [20]:
train_acc = mx.metric.Accuracy()
test_acc = mx.metric.Accuracy()
global_va_acc = 0.0
for epoch in range(epoches):
    # Metrics accumulate until reset; without this the reported numbers were a
    # running average over ALL previous epochs.
    train_acc.reset()
    test_acc.reset()

    for step, batch in enumerate(tqdm(db)):
        # The final batch can hold fewer sets than batch_size. Keep only as many
        # sets as can be shared evenly between the devices, so every device gets
        # whole sequences, and skip the batch if nothing is left.
        n_sets = batch[0].shape[0]
        n_sets -= n_sets % len(ctx)
        if n_sets == 0:
            continue
        batch = [part[:n_sets] for part in batch]

        x,y = snail_data_generation(batch,N)
        x, y, last_targets = batch_for_few_shot(N, K, n_sets, x, y)
        x_split = gluon.utils.split_and_load(x,ctx)
        y_split = gluon.utils.split_and_load(y,ctx)
        last_targets_split = gluon.utils.split_and_load(last_targets,ctx)

        # Only the forward pass and the loss need to be recorded.
        with autograd.record():
            # prediction for the last (query) step of every sequence
            last_model = [model(X,Y)[:,-1,:] for X, Y in zip(x_split,y_split)]
            loss_val = [loss_fn(P,T) for P, T in zip(last_model,last_targets_split)]
        for l in loss_val:
            l.backward()
        # Normalise by the number of sets actually used in this step.
        # ignore_stale_grad=True tolerates parameters that received no gradient.
        trainer.step(n_sets,ignore_stale_grad=True)

        for pred,target in zip(last_model,last_targets_split):
            train_acc.update(preds=nd.argmax(pred,1),labels=target)

    # Evaluation runs on the first device only, outside of autograd.record().
    for step, batch in enumerate(tqdm(db_test)):
        x,y = snail_data_generation(batch,N)
        x, y, last_targets = batch_for_few_shot(N, K ,batch[0].shape[0], x, y)
        x = x.copyto(ctx[0])
        y = y.copyto(ctx[0])
        last_targets = last_targets.copyto(ctx[0])
        model_output = model(x,y)
        last_model = model_output[:,-1,:]
        test_acc.update(preds=nd.argmax(last_model,1),labels=last_targets)

    epoch_tr_acc = train_acc.get()[1]
    epoch_te_acc = test_acc.get()[1]
    current_va_acc = epoch_te_acc
    # keep a checkpoint whenever the test accuracy improves
    if global_va_acc < current_va_acc:
        filename = '/home/skinet/work/research/SNAIL/imagenet_models/best_perf_epoch_'+str(epoch)+"_tr_acc_"+str(round(epoch_tr_acc,2))+"_te_acc_"+str(round(epoch_te_acc,2))
        model.save_params(filename)
        global_va_acc = current_va_acc
    print("epoch {e}  train_acc:{ta} test_acc:{tea} ".format(e=epoch,ta=epoch_tr_acc,tea=epoch_te_acc))


  0%|          | 0/16 [00:00<?, ?it/s][A

OSError: [Errno 28] No space left on device

In [33]:
batch[0].shape

(16, 50, 3, 84, 84)

In [34]:
x,y = snail_data_generation(batch,N)

In [38]:
y.shape

(816,)

In [36]:
N

10

In [35]:
x.shape

(816, 3, 84, 84)

In [27]:
batch[0].shape

(16, 50, 3, 84, 84)

In [18]:
# Dry run over the test loader (data pipeline only, no model).
for step, batch in enumerate(tqdm(db_test)):
    x,y = snail_data_generation(batch,N)
    # Pass the number of sets actually present in this batch: the final batch can
    # be smaller than the configured batch_size, and slicing `y` for sets that do
    # not exist yields empty label sequences that cannot be encoded.
    x, y, last_targets = batch_for_few_shot(N, K ,batch[0].shape[0], x, y)

  0%|          | 0/50 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [26]:
labels.shape

(1224, 10)