In [1]:
# -*- coding: utf-8 -*-
import math
import numpy as np
import pandas as pd

In [2]:
import mxnet as mx
from mxnet import gluon, autograd, nd
from mxnet.gluon import nn,utils
from mxnet.gluon.data.vision import transforms
import mxnet.ndarray as F
import numpy as np
import os, sys
import os
import numpy as np
import collections
from PIL import Image
import csv
import random
from mxnet.gluon.data import Dataset, DataLoader

  from ._conv import register_converters as _register_converters


In [3]:
mx.__version__

'1.2.0'

In [4]:
class MiniImagenet(Dataset):
	"""
	Episodic (few-shot) dataset over mini-imagenet.

	Put the mini-imagenet files as:
	root :
		|- images/*.jpg  includes all images
		|- train.csv
		|- test.csv
		|- val.csv
	NOTICE: meta-learning is different from general supervised learning, especially the concept of batch and set.
	batch: contains several sets
	sets: contains n_way * k_shot images for the support set and n_way * k_query images for the query set.

	All `batchsz` sets are sampled once, at construction time, by `create_batch`.
	"""

	def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, ctx=mx.cpu(), startidx=0):
		"""

		:param root: root path of mini-imagenet
		:param mode: train, val or test; selects `<mode>.csv` under root
		:param batchsz: number of sets to pre-sample, not a batch of imgs
		:param n_way: number of classes per set
		:param k_shot: number of support imgs per class
		:param k_query: number of query imgs per class
		:param resize: images are resized to (resize, resize)
		:param ctx: context on which decoded images and label arrays are created
		:param startidx: start to index label from startidx
		"""

		self.batchsz = batchsz  # batch of set, not batch of imgs
		self.n_way = n_way  # n-way
		self.k_shot = k_shot  # k-shot
		self.k_query = k_query  # for evaluation
		self.setsz = self.n_way * self.k_shot  # num of samples per set
		self.querysz = self.n_way * self.k_query  # number of samples per set for evaluation
		self.resize = resize  # resize to
		self.ctx = ctx
		self.startidx = startidx  # index label not from 0, but from startidx
		print('shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d' % (mode, batchsz, n_way, k_shot, k_query, resize))

		# The 'train' and non-train pipelines were identical (the train-only
		# augmentations were commented out), so the pipeline is built once.
		self.transform = transforms.Compose([
			lambda x: nd.array(np.array(Image.open(x).convert('RGB')), ctx=self.ctx),
			transforms.Resize((self.resize, self.resize)),
			transforms.ToTensor(),
			transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
		])

		self.path = os.path.join(root, 'images')  # image path
		csvdata = self.loadCSV(os.path.join(root, mode + '.csv'))  # csv path
		self.data = []
		self.img2label = {}
		for i, (k, v) in enumerate(csvdata.items()):
			self.data.append(v)  # [[img1, img2, ...], [img111, ...]]
			self.img2label[k] = i + self.startidx  # {"img_name[:9]":label}
		self.cls_num = len(self.data)

		self.create_batch(self.batchsz)

	def loadCSV(self, csvf):
		"""
		Read a (filename, label) csv file and group the file names by label.

		The first row is treated as a header and skipped; rows with fewer than
		two columns (e.g. a trailing blank line) are ignored.
		:param csvf: csv file name
		:return: {label:[file1, file2 ...]}
		"""
		dictLabels = {}
		# newline='' is what the csv module documents for files handed to csv.reader
		with open(csvf, newline='') as csvfile:
			csvreader = csv.reader(csvfile, delimiter=',')
			next(csvreader, None)  # skip (filename, label)
			for row in csvreader:
				if len(row) < 2:
					continue
				filename, label = row[0], row[1]
				dictLabels.setdefault(label, []).append(filename)
		return dictLabels

	def create_batch(self, batchsz):
		"""
		Pre-sample `batchsz` sets and store their file names in
		`self.support_x_batch` / `self.query_x_batch`.

		Each stored set is a list of n_way lists of file names.
		:param batchsz: how many sets to sample
		:return:
		"""
		self.support_x_batch = []  # support set batch
		self.query_x_batch = []  # query set batch
		for _ in range(batchsz):  # for each set
			# 1. select n_way classes randomly
			selected_cls = np.random.choice(self.cls_num, self.n_way, False)  # no duplicate
			np.random.shuffle(selected_cls)
			support_x = []
			query_x = []
			for cls in selected_cls:
				cls_imgs = self.data[cls]
				# 2. select k_shot + k_query distinct images for each class
				selected_imgs_idx = np.random.choice(len(cls_imgs), self.k_shot + self.k_query, False)
				np.random.shuffle(selected_imgs_idx)
				# the first k_shot picks go to the support set, the rest to the query set
				support_x.append([cls_imgs[j] for j in selected_imgs_idx[:self.k_shot]])
				query_x.append([cls_imgs[j] for j in selected_imgs_idx[self.k_shot:]])

			# shuffle the class order of the support and query sets independently;
			# labels are recovered from the file names, so they stay consistent
			random.shuffle(support_x)
			random.shuffle(query_x)

			self.support_x_batch.append(support_x)  # append set to current sets
			self.query_x_batch.append(query_x)  # append sets to current sets

	def __getitem__(self, index):
		"""
		Load the images of one pre-sampled set.

		index means index of sets, 0<= index <= batchsz-1
		:param index: index of the set
		:return: (support_x, support_y, query_x, query_y); the labels are relative,
			i.e. the position of the global label among the sorted unique support labels
		"""
		support_files = [item for sublist in self.support_x_batch[index] for item in sublist]
		query_files = [item for sublist in self.query_x_batch[index] for item in sublist]

		# filename:n0153282900000005.jpg, the first 9 characters treated as label
		support_y = np.array([self.img2label[item[:9]] for item in support_files]).astype(np.int32)
		query_y = np.array([self.img2label[item[:9]] for item in query_files]).astype(np.int32)

		# unique: [n-way], sorted
		unique = np.unique(support_y)
		# relative means the label ranges from 0 to n-way
		support_y_relative = np.zeros(self.setsz)
		query_y_relative = np.zeros(self.querysz)
		for idx, l in enumerate(unique):
			support_y_relative[support_y == l] = idx
			query_y_relative[query_y == l] = idx

		# [setsz, 3, resize, resize]
		# NOTE(review): the image buffers are allocated on the default context while
		# the labels below are created on self.ctx; unchanged from the original.
		support_x = nd.empty(shape=(self.setsz, 3, self.resize, self.resize))
		for i, item in enumerate(support_files):
			support_x[i] = self.transform(os.path.join(self.path, item))

		# [querysz, 3, resize, resize]
		query_x = nd.empty(shape=(self.querysz, 3, self.resize, self.resize))
		for i, item in enumerate(query_files):
			query_x[i] = self.transform(os.path.join(self.path, item))

		return support_x, nd.array(support_y_relative, ctx=self.ctx), query_x, nd.array(query_y_relative, ctx=self.ctx)

	def __len__(self):
		# as we have built up to batchsz of sets, you can sample some small batch size of sets.
		return self.batchsz

In [5]:
# Switch to the datasets directory so the relative 'mini-imagenet/' root passed to
# MiniImagenet resolves.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
os.chdir('/home/skinet/work/datasets/')

In [6]:
##setting hyperparameters
n_way = 5  # classes per set
k_shot = 5  # support images per class
k_query = 2  # query images per class
batch_size = 10  # sets per DataLoader batch

In [7]:
# Training split; all 5000 sets are sampled inside the constructor (create_batch),
# images are resized to 84x84 when a set is fetched.
mini = MiniImagenet('mini-imagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query, batchsz=5000, resize=84)

shuffle DB :train, b:5000, 5-way, 5-shot, 2-query, resize:84


In [8]:
# Each batch stacks `batch_size` sets: (support_x, support_y, query_x, query_y).
# num_workers=1 loads sets in a separate worker process.
db = DataLoader(mini, batch_size=batch_size, shuffle=True, num_workers=1)

In [9]:
class CasualConv1d(nn.Block):
    """Dilated 1-D convolution with kernel size 2 whose output is trimmed at the end.

    The convolution is created with ``padding == dilation``; `forward` then drops
    the last ``padding`` positions along axis 2 of the convolution output.
    """

    def __init__(self,in_channels,out_channels, dilation=1,**kwargs):
        super().__init__(**kwargs)
        # amount of padding given to the convolution and number of trailing
        # positions removed again in forward()
        self.padding = dilation

        with self.name_scope():
            self.casual_conv = nn.Conv1D(
                channels=out_channels,
                kernel_size=2,
                padding=self.padding,
                dilation=dilation,
                in_channels=in_channels,
            )

    def forward(self,x):
        """Apply the convolution and cut off the last `self.padding` steps of axis 2."""
        conv_out = self.casual_conv(x)
        return conv_out[:, :, :-self.padding]
         

In [10]:
class DenseBlock(nn.Block):
    """Gated dense block: concatenates the input with tanh(conv1(x)) * sigmoid(conv2(x)).

    :param in_channels: channels (axis 1) of the input
    :param filters: channels produced by each of the two convolutions; the output
        has ``in_channels + filters`` channels
    :param dilation: dilation passed to both convolutions
    """

    def __init__(self, in_channels, filters, dilation=1, **kwargs):
        super(DenseBlock,self).__init__(**kwargs)

        with self.name_scope():
            self.casual_conv1 = CasualConv1d(in_channels, filters,dilation = dilation)
            self.casual_conv2 = CasualConv1d(in_channels, filters,dilation = dilation)

    def forward(self, x):
        tanh = F.tanh(self.casual_conv1(x))
        # The gate must come from the second convolution. The original called
        # casual_conv1 here as well, so casual_conv2 never ran and never
        # received a gradient.
        sigmoid = F.sigmoid(self.casual_conv2(x))
        return F.concat(x, tanh * sigmoid, dim=1)

In [11]:
class TCBlock(nn.Block):
    """Stack of DenseBlocks with dilations 1, 2, 4, ...

    ``ceil(log2(seq_len))`` blocks are created; each one adds `filters` channels,
    so the output has ``in_channels + ceil(log2(seq_len)) * filters`` channels.

    :param in_channels: channels (axis 1) of the input
    :param seq_len: sequence length used to choose the number of blocks
    :param filters: channels added by every block
    """

    def __init__(self, in_channels,seq_len, filters, **kwargs):
        super(TCBlock,self).__init__(**kwargs)
        # math.log2 is exact for powers of two; log(x)/log(2) can land just above
        # an integer (e.g. for 2**29) and ceil() would then add a spurious layer.
        layer_count = math.ceil(math.log2(seq_len))
        channel_count = in_channels
        with self.name_scope():
            self.blocks = nn.Sequential()
            for layer in range(layer_count):
                self.blocks.add(DenseBlock(channel_count, filters, dilation = 2**layer))
                # every DenseBlock concatenates `filters` new channels onto its input
                channel_count += filters

    def forward(self, x):
        return self.blocks(x)

In [12]:
class AttentionBlock(nn.Block):
    """Causally masked single-head soft attention whose read-out is concatenated to the input.

    `forward` swaps axes 1 and 2 of the input, attends along the new axis 1 and
    swaps back, so the result has ``v_size`` more channels on axis 1 than the input.

    :param k_size: width of the key and query projections
    :param v_size: width of the value projection
    """

    def __init__(self, k_size, v_size, **kwargs):
        super(AttentionBlock,self).__init__(**kwargs)

        with self.name_scope():
            self.key_layer = nn.Dense(k_size,flatten=False)
            self.query_layer = nn.Dense(k_size,flatten=False)
            self.value_layer = nn.Dense(v_size,flatten=False)
            self.sqrt_k = math.sqrt(k_size)

    def forward(self, x):
        x = x.swapaxes(1,2)
        keys = self.key_layer(x)
        queries = self.query_layer(x)
        values = self.value_layer(x)

        # scaled dot-product scores, one row per query position
        logits = nd.linalg_gemm2(queries,keys.swapaxes(2,1)) / self.sqrt_k

        # Causal mask: position i may only attend to positions <= i.
        # The mask has to be additive and applied before the softmax: a large
        # negative score becomes a ~0 probability, whereas multiplying a score
        # by 0 still leaves exp(0) = 1 of weight on the masked position.
        # (The original mask was also turned into all ones by its two np.place
        # calls, and was created on a hard-coded mx.gpu().)
        seq_len = logits.shape[1]
        future = np.triu(np.ones((seq_len, seq_len), dtype='float32'), k=1)
        penalty = nd.array(future * -1e9, ctx=logits.context, dtype=logits.dtype)
        logits = F.broadcast_add(logits, penalty.expand_dims(0))

        probs = F.softmax(logits, axis=2)
        read = nd.linalg_gemm2(probs,values)
        concat_data = F.concat(x,read,dim=2)
        return concat_data.swapaxes(2,1)


In [13]:
class CnnEmbedding(nn.Block):
    """Image embedding: conv -> max-pool -> conv -> max-pool -> dense(64).

    `forward` returns the dense output with a singleton axis inserted at position 1.
    """

    def __init__(self,**kwargs):
        super().__init__(**kwargs)
        with self.name_scope():
            self.cnn1 = nn.Conv2D(channels=64, kernel_size=3)
            self.max1 = nn.MaxPool2D(ceil_mode=True)
            self.cnn2 = nn.Conv2D(channels=64, kernel_size=3)
            self.max2 = nn.MaxPool2D(ceil_mode=True)
            self.fc = nn.Dense(units=64)

    def forward(self,x):
        features = x
        # run the layers strictly in the order they were declared
        for stage in (self.cnn1, self.max1, self.cnn2, self.max2, self.fc):
            features = stage(features)
        return nd.expand_dims(features, axis=1)

In [14]:
class SNAIL(nn.Block):
    """SNAIL-style classifier: CNN embedding followed by alternating attention and TC blocks.

    :param N: number of classes; also the number of output channels of the final 1x1 convolution
    :param K: number of shots; together with N it fixes the sequence length ``T = N * K + 1``
        that determines how many layers each TCBlock gets
    """

    def __init__(self,N,K,**kwargs):
        super(SNAIL,self).__init__(**kwargs)
        T = N * K + 1
        # number of DenseBlocks inside each TCBlock built for sequence length T
        layer_count = math.ceil(math.log2(T))
        tc_filters = 128
        with self.name_scope():
            self.cnn_emb =  CnnEmbedding()
            self.mod0 = AttentionBlock(64,32)
            # channel bookkeeping: the embedding has 1 channel and every
            # AttentionBlock appends its v_size channels
            channels = 1 + 32
            self.mod1 = TCBlock(channels,T,tc_filters)
            # a TCBlock appends tc_filters channels per layer
            channels += layer_count * tc_filters
            self.mod2 = AttentionBlock(256,128)
            channels += 128
            # Derived instead of the former hard-coded 801, which is only
            # correct when layer_count == 5 (e.g. N = K = 5).
            self.mod3 = TCBlock(channels,T,tc_filters)
            self.mod4 = AttentionBlock(512,256)
            self.conv1d = nn.Conv1D(N,1)

    def forward(self,x):
        """Return one score per class: the N-channel convolution output summed over axis 2."""
        out = self.cnn_emb(x)
        out = self.mod0(out)
        out = self.mod1(out)
        out = self.mod2(out)
        out = self.mod3(out)
        out = self.mod4(out)
        out = self.conv1d(out)
        out = out.sum(2)
        return out

In [15]:
# Device on which the model parameters are initialised and to which every
# task's data and labels are copied in the training loop.
ctx = mx.gpu()

In [16]:
# NOTE(review): N and K repeat n_way and k_shot defined earlier; keep them in sync.
N = 5  # passed to SNAIL as the number of classes
K = 5  # passed to SNAIL as the number of shots
max_grad_norm = 10  # threshold handed to gluon.utils.clip_global_norm in the training loop

In [17]:
# Same sequence length that SNAIL.__init__ computes internally from N and K.
T = N * K + 1

In [18]:
T

26

In [19]:
# Build the network, place its parameters on `ctx`, and set up the loss and optimiser.
model = SNAIL(N, K)
model.collect_params().initialize(ctx=ctx)
loss = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(
    params=model.collect_params(),
    optimizer='Adam',
    optimizer_params=dict(learning_rate=1e-4),
)

In [27]:
# History of the running metrics; the training loop appends to each list every 20 iterations.
perf_dis = {
    metric: []
    for metric in ('train_loss', 'val_loss', 'train_acc', 'val_acc')
}

In [27]:
# Meta-training loop.
# Every DataLoader batch holds several tasks (sets). For each batch:
#   1. forward every task's support set and record its loss,
#   2. back-propagate ALL task losses together, clip the gradients, take one optimiser step,
#   3. evaluate the updated model on every task's query set,
#   4. checkpoint when the rounded mean of the running validation accuracy improves.
global_va_acc = 0.0

for epoch in range(100):
    loss_mean = list()
    va_loss_mean = list()
    # train_acc / val_acc accumulate over the whole epoch, so tr_acc / va_acc
    # collect the *running* accuracy after every task.
    tr_acc = list()
    va_acc = list()
    train_acc = mx.metric.Accuracy()
    val_acc = mx.metric.Accuracy()
    for step, batch in enumerate(db):
        cnt = step
        # the final batch may contain fewer than batch_size tasks
        n_tasks = batch[0].shape[0]

        task_losses = []
        for i in range(n_tasks):
            support_x = batch[0][i]
            support_y = batch[1][i]
            data = support_x.copyto(ctx)
            label = support_y.copyto(ctx)
            with autograd.record():
                output = model(data)
                loss_val = loss(output,label)
            task_losses.append(loss_val)
            train_acc.update(preds=nd.argmax(output,1),labels=label)
            tr_acc.append(train_acc.get()[1])
            loss_mean.append(nd.mean(loss_val).asscalar())

        # Back-propagate every task of the batch. The original called backward()
        # only on the last task's loss, so the other tasks never contributed.
        autograd.backward(task_losses)
        # Clip AFTER backward so the freshly computed gradients are clipped; the
        # original clipped before backward(), i.e. the previous step's gradients.
        grads = [p.grad(ctx) for p in model.collect_params().values() if p.grad_req != 'null']
        gluon.utils.clip_global_norm(grads, max_grad_norm)
        trainer.step(n_tasks,ignore_stale_grad=True)

        for i in range(n_tasks):
            query_x = batch[2][i]
            query_y = batch[3][i]
            va_data = query_x.copyto(ctx)
            va_label = query_y.copyto(ctx)
            va_output = model(va_data)
            va_loss_val = loss(va_output,va_label)
            va_loss_mean.append(nd.mean(va_loss_val).asscalar())
            val_acc.update(preds=nd.argmax(va_output,1),labels=va_label)
            va_acc.append(val_acc.get()[1])
        current_va_acc = round(np.mean(va_acc),2)
        if global_va_acc < current_va_acc:
            filename = '/home/skinet/work/Research/SNAIL/models/best_perf_epoch_'+str(epoch)+"_iter_"+str(cnt)+"_val_acc_"+str(current_va_acc)
            model.save_params(filename)
            global_va_acc = current_va_acc
        if(cnt %100 == 0):
            print("epoch {e} iter {it} train loss :{l} train_acc:{ta} val loss : {vl} val_acc :{va}".format(e=epoch,it=cnt,l=np.mean(loss_mean),vl=np.mean(va_loss_mean),ta=np.mean(tr_acc),va=np.mean(va_acc)))

        if(cnt %20 == 0):
            perf_dis['train_loss'].append(np.mean(loss_mean))
            perf_dis['val_loss'].append(np.mean(va_loss_mean))
            perf_dis['train_acc'].append(np.mean(tr_acc))
            perf_dis['val_acc'].append(np.mean(va_acc))

Process Process-14:
  File "<ipython-input-4-1175bd14f7c6>", line 166, in __getitem__
    query_x[i] = self.transform(path)
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()


KeyboardInterrupt: 

  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/venv/lib/python3.6/site-packages/mxnet/gluon/data/dataloader.py", line 119, in worker_loop
    batch = batchify_fn([dataset[i] for i in samples])
  File "/opt/venv/lib/python3.6/site-packages/mxnet/gluon/data/dataloader.py", line 119, in <listcomp>
    batch = batchify_fn([dataset[i] for i in samples])
  File "/opt/venv/lib/python3.6/site-packages/mxnet/gluon/block.py", line 413, in __call__
    return self.forward(*args)
  File "/opt/venv/lib/python3.6/site-packages/mxnet/gluon/nn/basic_layers.py", line 53, in forward
    x = block(x)
  File "<ipython-input-4-1175bd14f7c6>", line 40, in <lambda>
    self.transform = transforms.Compose([lambda x: nd.array(np.array(Image.open(x).convert('RGB')),ctx=self.ctx),
  File "/opt/venv/lib/python3.6/site-packages/mxnet/ndarray/utils.py", line 146, in array
    return _array(source_array, ctx=ctx, dtype=dtype)
  File 

In [30]:
 for step, batch in enumerate(db):
        for i in range(batch_size):
            support_x = batch[0][i]
            support_y = batch[1][i]
            query_x = batch[2][i]
            query_y = batch[3][i]
            break

KeyboardInterrupt: 

Process Process-15:
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/venv/lib/python3.6/site-packages/mxnet/gluon/data/dataloader.py", line 119, in worker_loop
    batch = batchify_fn([dataset[i] for i in samples])
  File "/opt/venv/lib/python3.6/site-packages/mxnet/gluon/data/dataloader.py", line 119, in <listcomp>
    batch = batchify_fn([dataset[i] for i in samples])
  File "<ipython-input-4-1175bd14f7c6>", line 163, in __getitem__
    support_x[i] = self.transform(path)
  File "/opt/venv/lib/python3.6/site-packages/mxnet/gluon/block.py", line 413, in __call__
    return self.forward(*args)
  File "/opt/venv/lib/python3.6/site-packages/mxnet/gluon/nn/basic_layers.py", line 53, in forward
    x = block(x)
  File "<ipython-input-4-1175bd14f7c6>", line 40, in <lambda>
    self.

In [31]:
query_y


[2. 2. 3. 3. 0. 0. 1. 1. 4. 4.]
<NDArray 10 @cpu_shared(0)>

In [57]:
va_acc

[]

In [55]:
val_acc.get()[1]

0.26

In [51]:
train_acc.get()

('accuracy', 0.2)

In [29]:
data.shape

(25, 3, 84, 84)

In [38]:
nd.softmax(va_output)[3]


[4.5847210e-09 1.9899325e-08 5.3198522e-01 1.0563797e-04 3.7190720e-01
 1.1855067e-03 8.3273480e-04 8.4661752e-02 1.4400263e-13 9.3218870e-03]
<NDArray 10 @gpu(0)>

In [39]:
nd.argmax(va_output,1)


[7. 7. 2. 2. 2. 2. 4. 7. 4. 2.]
<NDArray 10 @gpu(0)>

In [None]:
support_x.shape[0]

In [42]:
loss_val


[3.1161484e+01 4.7606003e+01 3.2631065e+01 2.7621670e+01 2.8246191e+01
 5.5496621e-01 3.1153409e+00 1.7690205e+01 1.0726947e+00 1.7560464e+00
 2.9597314e-02 3.6860043e-01 1.8370737e+00 2.5596783e+00 2.7099288e+00
 1.2178973e+01 1.6275099e+01 1.0593153e+01 1.3429583e+01 2.0996267e+01
 8.1341248e+00 2.8600287e+00 2.2412050e-01 6.0982933e+00 3.4674497e+00
 9.2115488e+00 9.8984728e+00 1.1255022e+01 9.4889727e+00 5.8908181e+00
 1.1569217e+01 1.1682887e+01 1.1630000e+01 1.1166780e+01 1.1765790e+01
 1.1544019e+01 5.6001492e+00 1.0123260e+01 1.5429238e+01 2.8409996e+01
 3.0453697e+01 1.7150373e+01 1.9238771e+01 1.7895241e+01 1.9916370e+01
 1.9897079e+01 1.9414797e+01 1.3818964e+01 1.9827072e+01 2.6191713e+01]
<NDArray 50 @gpu(0)>

In [None]:
nd.argmax(output,axis=1)

In [41]:
data.shape

(50, 3, 84, 84)

In [21]:
acc = mx.metric.Accuracy()

In [22]:
acc.update(preds=nd.argmax(output,axis=1),labels=label)

In [25]:
print("{}".format(acc.get()))

('accuracy', 0.9375)


In [29]:
data.swapaxes(1,2).shape

(128, 28, 1, 28)

In [31]:
output = model(data)

In [32]:
output


[[ -4.70972    -1.0343139  -5.9070234 ...  10.077577    3.4620867
    6.4959126]
 [ -5.9397173  -1.4487193  -5.829402  ...  11.126983    2.5540066
    8.838758 ]
 [ 10.021524   31.977276   -2.28505   ... -20.696571   -4.942632
  -32.530014 ]
 ...
 [  7.2817163   4.4386163   1.8730931 ... -10.065019   -0.6611936
  -14.264185 ]
 [  2.403361    3.1729379   2.2249227 ...  -9.469499   -1.2017921
   -7.741617 ]
 [ -8.38825    -1.443956   -0.8961567 ...   2.689477    0.8111732
   11.417488 ]]
<NDArray 128x10 @gpu(0)>

In [42]:
label[0:20]


[7. 7. 1. 9. 4. 1. 9. 9. 1. 8. 1. 2. 3. 6. 5. 3. 5. 4. 6. 6.]
<NDArray 20 @gpu(0)>

In [34]:
soft = nd.softmax(output)

In [43]:
nd.argmax(soft[0:20],axis=1)


[7. 7. 1. 9. 2. 1. 9. 9. 1. 8. 1. 2. 3. 6. 5. 3. 5. 4. 6. 6.]
<NDArray 20 @gpu(0)>

In [37]:
soft[3]


[1.2977268e-08 3.4302653e-07 1.9245779e-06 1.4412031e-08 1.9874494e-06
 4.5362492e-03 1.3529815e-08 2.5637108e-03 2.0838407e-04 9.9268734e-01]
<NDArray 10 @gpu(0)>

### 

In [10]:
test = nd.random_normal(shape=(128, 64, 1),ctx=mx.gpu())

In [11]:
test.shape

(128, 64, 1)

In [12]:
attn =AttentionBlock(64,32)
attn.collect_params().initialize(ctx=mx.gpu())

In [13]:
a = attn(test)

In [14]:
a.shape

(128, 33, 64)

In [19]:
mod1 = TCBlock(33,T,128)
mod1.collect_params().initialize(ctx=mx.gpu())

In [20]:
out = mod1(a)

(128, 33, 64)
Dense success
(128, 161, 64)
Dense success
(128, 289, 64)
Dense success
(128, 417, 64)
Dense success
(128, 545, 64)
Dense success
(128, 673, 64)
Dense success
(128, 801, 64)
Dense success


In [21]:
out.shape

(128, 929, 64)

In [56]:
casual_conv = nn.Conv1D(65,kernel_size=2,padding = 3, dilation = 1)
casual_conv.collect_params().initialize()

In [57]:
casual_conv.collect_params().keys()

odict_keys(['conv4_weight', 'conv4_bias'])

In [60]:
casual_conv.collect_params()['conv4_weight'].data()

DeferredInitializationError: Parameter 'conv4_weight' has not been initialized yet because initialization was deferred. Actual initialization happens during the first forward pass. Please pass one batch of data through the network before accessing Parameters. You can also avoid deferred initialization by specifying in_units, num_features, etc., for network layers.

In [59]:
nd.argmax(out,1)


[[40. 33. 15. ... 55. 51. 62.]
 [43. 39.  3. ... 53. 43. 15.]
 [ 3. 56. 49. ... 40. 29. 30.]
 ...
 [18. 27. 61. ... 12.  3. 16.]
 [55. 17. 36. ... 35. 51. 13.]
 [17. 57. 14. ... 32.  9. 46.]]
<NDArray 128x65 @cpu(0)>

In [98]:
loss = gluon.loss.SoftmaxCrossEntropyLoss()

In [100]:
loss(out,label)


[2.2953513 2.268253  2.2788887 2.2653809 2.2001681 2.3938656 2.352553
 2.1504626 2.2000809 2.3043551 2.247101  2.2911987 2.3168476 2.3106627
 2.316557  2.3206165 2.3051178 2.400387  2.2482119 2.3681443 2.2640896
 2.3992994 2.21503   2.2837143 2.403338  2.2815537 2.2992458 2.3473132
 2.2210064 2.3172932 2.4186    2.35235   2.3988147 2.1612604 2.4434488
 2.2524354 2.3666632 2.381284  2.3241143 2.308886  2.3716798 2.3149261
 2.3523326 2.2905169 2.3369887 2.2060206 2.3598323 2.41046   2.3438096
 2.3870363 2.211364  2.322115  2.3808486 2.353376  2.1823812 2.2900634
 2.2563624 2.3698769 2.3828924 2.3387487 2.3133004 2.2479045 2.26498
 2.3250399 2.1810293 2.3330922 2.3212776 2.2873394 2.3081412 2.3115413
 2.384853  2.3017528 2.26551   2.291781  2.3923473 2.2603652 2.1976027
 2.3875916 2.2776318 2.291148  2.2299504 2.3043473 2.4108205 2.3028996
 2.2174864 2.3599975 2.3221505 2.2949238 2.3127408 2.3572457 2.3136406
 2.2745225 2.2646685 2.3358376 2.3021743 2.3337815]
<NDArray 96 @cpu(0)>

In [76]:
N = 10
K = 4
T = N * K + 1
layer_count = math.ceil(math.log(T)/math.log(2))

In [48]:
cnn_emb = CnnEmbedding()
cnn_emb.collect_params().initialize()

In [49]:
test1 = cnn_emb(data)

In [50]:
test1.shape

(96, 64)

In [52]:
test1.expand_dims(2).shape

(96, 64, 1)

In [77]:
snail = SNAIL(N, K)
snail.collect_params().initialize()

In [78]:
out = snail(data)

In [79]:
out.shape

(96, 1600, 2497)

In [80]:
conv1 = nn.Conv1D(N,1)
conv1.collect_params().initialize()

In [81]:
a = conv1(out)

In [82]:
a.shape

(96, 10, 2497)

In [83]:
b = a.sum(2)

In [85]:
b.shape

(96, 10)

In [89]:
c = nd.softmax(b[0])

In [90]:
sum(c)


[1.0000001]
<NDArray 1 @cpu(0)>

In [36]:
a[0][0][100]


[nan]
<NDArray 1 @cpu(0)>

In [71]:
mod0 = AttentionBlock(64, 32)
mod0.collect_params().initialize()

In [72]:
emb_rst.shape

(96, 64, 64)

In [73]:
mod0_rst=mod0(emb_rst)

In [74]:
mod0_rst.shape

(96, 64, 128)

In [66]:
mod0_rst.shape

(96, 64, 128)

In [76]:
mod1 = TCBlock(65+32, T, 128)
mod1.collect_params().initialize()

In [77]:
mod1_rst=mod1(mod0_rst)

In [78]:
mod1_rst.shape

(96, 1472, 128)

In [117]:
F.broadcast_mul(x,y).shape

(1, 2, 3)

In [322]:
mask = np.full(shape=(logits.shape[1],logits.shape[2]),fill_value=1).astype('float')
mask = np.triu(mask,1)
mask = np.expand_dims(mask,0)
mask = np.repeat(mask,logits.shape[0],0)

In [324]:
mask[0]

array([[0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [327]:
def subsequent_mask(size):
    """Return a (1, size, size) uint8 array that is 1 strictly above the diagonal and 0 elsewhere.

    A 1 at [0, i, j] marks a position j that comes after position i.
    """
    all_ones = np.ones((1, size, size))
    # k=1 keeps only the entries strictly above the main diagonal
    strictly_upper = np.triu(all_ones, k=1)
    return strictly_upper.astype('uint8')

In [328]:
msk = subsequent_mask(10)

In [329]:
msk

array([[[0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]], dtype=uint8)

In [71]:
        mask = np.full(shape=(5,5),fill_value=1).astype('float')
        mask = np.triu(mask,1)
        #mask = np.expand_dims(mask,0)
        #mask = np.repeat(mask,logits.shape[0],0)
        #np.place(mask,mask==1,float('-inf'))

In [72]:
mask

array([[0., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0.]])