In [1]:
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [9]:
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1):
    """Build a bias-free 2-D convolution layer.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: kernel size forwarded to ``nn.Conv2d`` (default 3).
        stride: stride forwarded to ``nn.Conv2d`` (default 1).
        padding: padding forwarded to ``nn.Conv2d`` (default 1).

    Returns:
        An ``nn.Conv2d`` constructed with ``bias=False``.
    """
    layer_options = dict(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **layer_options)

def conv_block(in_channels, out_channels):
    '''
    Assemble one encoder stage: 3x3 conv (padding 1) -> batch norm -> ReLU -> 2x2 max-pool.

    The layers are registered, in that order, under the names
    'conv', 'bn', 'relu' and 'pool' inside an nn.Sequential.
    '''
    stages = OrderedDict()
    stages['conv'] = nn.Conv2d(in_channels, out_channels, 3, padding=1)
    # momentum=1 is passed explicitly instead of relying on the BatchNorm2d default.
    stages['bn'] = nn.BatchNorm2d(out_channels, momentum=1)
    stages['relu'] = nn.ReLU()
    stages['pool'] = nn.MaxPool2d(2)
    return nn.Sequential(stages)

In [10]:
class OmniglotNet(nn.Module):
    '''
    Four-stage convolutional encoder (conv-bn-relu-pool x 4) whose output is flattened
    to one vector per sample.

    Model as described in the reference paper,
    source: https://github.com/jakesnell/prototypical-networks/blob/f0c48808e496989d01db59f86d4449d7aee9ab0c/protonets/models/few_shot.py#L62-L84

    Args:
        x_dim: number of input channels.
        hid_dim: number of channels produced by blocks 1-3.
        z_dim: number of channels produced by block 4.
    '''
    def __init__(self, x_dim=1, hid_dim=64, z_dim=64):
        super(OmniglotNet, self).__init__()
        self.encoder = nn.Sequential(OrderedDict([
            ('block1', conv_block(x_dim, hid_dim)),
            ('block2', conv_block(hid_dim, hid_dim)),
            ('block3', conv_block(hid_dim, hid_dim)),
            ('block4', conv_block(hid_dim, z_dim)),
        ]))

    def _block_with_weights(self, x, name, block, weights):
        '''Run one encoder block on x using tensors from `weights` instead of the block's parameters.'''
        prefix = 'encoder.{}.'.format(name)
        # Take stride/padding from the module itself so this path cannot drift
        # from what self.encoder computes (the conv is built with padding=1).
        x = F.conv2d(x, weights[prefix + 'conv.weight'], weights[prefix + 'conv.bias'],
                     stride=block.conv.stride, padding=block.conv.padding)
        # No running statistics are supplied, so training=True normalises with the
        # statistics of the current batch.
        x = F.batch_norm(x, None, None,
                         weight=weights[prefix + 'bn.weight'],
                         bias=weights[prefix + 'bn.bias'],
                         training=True, eps=block.bn.eps)
        x = F.relu(x)
        return F.max_pool2d(x, block.pool.kernel_size, block.pool.stride)

    def forward(self, x, weights=None):
        '''
        Encode a batch and flatten each sample's features.

        Args:
            x: input batch; it is passed straight to 2-D convolutions.
            weights: optional mapping keyed like ``named_parameters()``
                ('encoder.block1.conv.weight', 'encoder.block1.bn.bias', ...).
                When given, these tensors are used in place of the module's own
                parameters; a missing key raises KeyError.

        Returns:
            Tensor of shape (x.size(0), -1).
        '''
        if weights is None:
            x = self.encoder(x)
        else:
            for name, block in self.encoder.named_children():
                x = self._block_with_weights(x, name, block, weights)
        return x.view(x.size(0), -1)


In [11]:
# Load the array stored in 'x.npy'; the path is relative to the current working directory.
x = np.load('x.npy')

In [14]:
# Display the dimensions of the loaded array.
x.shape

(1224, 1, 28, 28)

In [15]:
# Standard-normal dummy batch of shape (1224, 1, 28, 28); unseeded, so values differ per run.
inputs = torch.randn((1224, 1, 28, 28))

In [16]:
# Instantiate the encoder with its default arguments (x_dim=1, hid_dim=64, z_dim=64).
encoder = OmniglotNet()

In [17]:
# Forward pass through the module path (no `weights` mapping supplied).
out = encoder(inputs)

In [18]:
# Display the shape of the flattened encoder output.
out.shape

torch.Size([1224, 64])

In [19]:
# Standard-normal dummy tensor of shape (24, 51, 1514); it is fed to the linear layer `fc`
# further down, so its last dimension has to equal that layer's input width.
tmp = torch.randn((24, 51, 1514))

In [22]:
import math

N = 10
K = 5

# math.log with a single argument is the natural logarithm.
# NOTE(review): if this count is meant to track a power-of-two growth, base 2 may be intended - confirm.
num_filters = int(math.ceil(math.log(N * K + 1)))

# Total width is the sum of the individual contributions, listed in the order they are added.
num_channels = sum([
    64 + N,
    32,
    num_filters * 128,
    128,
    num_filters * 128,
    256,
])

In [23]:
# Display the accumulated channel count.
num_channels

1514

In [24]:
# Linear layer mapping num_channels input features to N outputs.
fc = nn.Linear(in_features=num_channels, out_features=N)

In [25]:
# Apply the linear layer to tmp; nn.Linear acts on the last dimension.
fc_out = fc(tmp)

In [26]:
# Display the shape produced by the linear layer.
fc_out.shape

torch.Size([24, 51, 10])

In [6]:
# Side length of the square mask built in the next cell.
size = 5

In [7]:
# Integer (size, size) matrix with ones strictly above the main diagonal and zeros elsewhere.
mask = np.triu(np.ones((size, size), dtype=int), k=1)

In [8]:
# Display the NumPy mask.
mask

array([[0, 1, 1, 1, 1],
       [0, 0, 1, 1, 1],
       [0, 0, 0, 1, 1],
       [0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0]])

In [13]:
# Convert the NumPy mask to a uint8 torch tensor; note this rebinds `mask` from ndarray to tensor.
mask = torch.tensor(mask, dtype=torch.uint8)

In [14]:
# Display the tensor version of the mask.
mask

tensor([[ 0,  1,  1,  1,  1],
        [ 0,  0,  1,  1,  1],
        [ 0,  0,  0,  1,  1],
        [ 0,  0,  0,  0,  1],
        [ 0,  0,  0,  0,  0]], dtype=torch.uint8)

In [9]:
# Uniform [0, 1) samples of shape (20, 5, 5); unseeded, so values differ per run.
temp = torch.rand(20, 5, 5)

In [11]:
# Display the shape of the random tensor.
temp.shape

torch.Size([20, 5, 5])

In [15]:
# Overwrite, in place, every position where the mask is set with -inf. The (5, 5) mask
# broadcasts across the leading dimension of temp.
# masked_fill_ is documented to take a boolean mask, so the uint8 mask is cast with .bool()
# (a no-op if it is already boolean). The call is made on the tensor directly rather than
# through the legacy `.data` attribute; temp does not track gradients, so the result is the same.
temp.masked_fill_(mask.bool(), float('-inf'))

tensor([[[ 0.2699,    -inf,    -inf,    -inf,    -inf],
         [ 0.4614,  0.0633,    -inf,    -inf,    -inf],
         [ 0.2330,  0.2013,  0.7973,    -inf,    -inf],
         [ 0.2812,  0.3697,  0.2282,  0.8384,    -inf],
         [ 0.0285,  0.0449,  0.4213,  0.9759,  0.6087]],

        [[ 0.1471,    -inf,    -inf,    -inf,    -inf],
         [ 0.2804,  0.2778,    -inf,    -inf,    -inf],
         [ 0.6791,  0.7563,  0.8883,    -inf,    -inf],
         [ 0.1166,  0.2283,  0.1182,  0.3741,    -inf],
         [ 0.8630,  0.6852,  0.9552,  0.0516,  0.8987]],

        [[ 0.6251,    -inf,    -inf,    -inf,    -inf],
         [ 0.1287,  0.2921,    -inf,    -inf,    -inf],
         [ 0.3378,  0.7609,  0.7919,    -inf,    -inf],
         [ 0.0354,  0.4564,  0.8991,  0.5906,    -inf],
         [ 0.0520,  0.6638,  0.9711,  0.8368,  0.1523]],

        [[ 0.2933,    -inf,    -inf,    -inf,    -inf],
         [ 0.7090,  0.8675,    -inf,    -inf,    -inf],
         [ 0.2611,  0.8310,  0.1839,    -i

In [None]:
# Float mask of shape (logits.shape[0], logits.shape[1], logits.shape[2]): every slice along
# axis 0 is the same matrix with ones strictly above the main diagonal and zeros elsewhere.
# `logits` must already exist when this cell runs.
mask = np.repeat(
    np.triu(np.ones((logits.shape[1], logits.shape[2]), dtype='float'), k=1)[np.newaxis],
    logits.shape[0],
    axis=0,
)