In [1]:
import gpytorch
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torchvision import transforms

gpytorch.functions.use_toeplitz = False

In [2]:
ways = 20
shots = 5
train_dir_str = "way%dshot%d" %(ways, shots)
test_dir_str = "way%dtest" %ways

In [3]:
train_base_omni = torchvision.datasets.ImageFolder('/scratch/bw462/omni_data/general', transform=transforms.Compose([
                        transforms.Resize((28,28)),
                        transforms.ToTensor()
                   ]))                                              
"""
test_mnist = torchvision.datasets.ImageFolder('/tmp', split='test',
                                        download=True, transform=transforms.Compose([
                       transforms.ToTensor()
                   ]))
"""

"\ntest_mnist = torchvision.datasets.ImageFolder('/tmp', split='test',\n                                        download=True, transform=transforms.Compose([\n                       transforms.ToTensor()\n                   ]))\n"

In [4]:
class FeatureExtractor(nn.Sequential):
    
    def __init__(self):
        super(FeatureExtractor, self).__init__(nn.Conv2d(1, 32, kernel_size=5, padding=2),
                                 nn.BatchNorm2d(32),
                                 nn.ReLU(),
                                 nn.MaxPool2d(2, 2),
                                 nn.Conv2d(32, 64, kernel_size=5, padding=2),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU(),
                                 nn.MaxPool2d(2, 2))
        
class Bottleneck(nn.Sequential):
    
    def __init__(self):
        super(Bottleneck, self).__init__(nn.Linear(64*7*7, 128),
                                         nn.BatchNorm1d(128),
                                 nn.ReLU(),
                                 nn.Linear(128, 128),
                                 nn.BatchNorm1d(128),
                                 nn.ReLU(),
                                 nn.Linear(128,64),
                                 nn.BatchNorm1d(64))

class LeNet(nn.Module):
    
    def __init__(self):
        super(LeNet, self).__init__()
        self.feature_extractor = FeatureExtractor()
        self.bottleneck = Bottleneck()
        self.final_layer = nn.Sequential(
                                 nn.ReLU(),
                                 nn.Linear(64,1319))
    
    def forward(self, x):
        input_x = x[:,0,:,:].unsqueeze(1)
        features = self.feature_extractor(input_x)
        bottlenecked_features = self.bottleneck(features.view(-1, 64 * 7 * 7))
        classification = self.final_layer(bottlenecked_features)
        return classification
        

In [5]:
train_data_loader = torch.utils.data.DataLoader(train_base_omni, shuffle=True, pin_memory=True, batch_size=256)

In [6]:
criterion = nn.CrossEntropyLoss().cuda()

In [7]:
model = LeNet().cuda() 
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

In [8]:
num_epochs = 0
if num_epochs > 0:
    model.train()
    for i in range(num_epochs):
        for x, y in train_data_loader:
            optimizer.zero_grad()
            x = Variable(x).cuda()
            y = Variable(y).cuda()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
        print("Loss: %.3f" % loss.data[0])
    torch.save(model.state_dict(), '/scratch/bw462/omni_net.dat')
else:
    model.load_state_dict(torch.load('/scratch/bw462/omni_net.dat'))


In [9]:
model.eval()
#test_data_loader = torch.utils.data.DataLoader(test_mnist, shuffle=False, pin_memory=True, batch_size=256)
avg = 0.
i = 0.
for test_batch_x, test_batch_y in train_data_loader:
    predictions = model(Variable(test_batch_x).cuda()).max(-1)[1]
    test_batch_y = Variable(test_batch_y).cuda()
    avg += torch.eq(predictions, test_batch_y).float().mean().data[0]
    i += 1.
print('Accuracy: %.4f' % (avg / i))

Accuracy: 0.9992


In [10]:
list(model.bottleneck.modules())[-1].weight.data.fill_(1)
None

In [27]:
ways = 20
shots = 5
train_dir_str = "way%dshot%d" %(ways, shots)
test_dir_str = "way%dtest" %ways

In [28]:
train_shots_omni = torchvision.datasets.ImageFolder('/scratch/bw462/omni_data/' + train_dir_str, transform=transforms.Compose([
                        transforms.Scale((28,28)),
                        transforms.ToTensor()
                   ]))    

In [29]:
oneshot_model = LeNet().cuda()
oneshot_model.feature_extractor = model.feature_extractor
oneshot_model.bottleneck = model.bottleneck
shots_loader = torch.utils.data.DataLoader(train_shots_omni, batch_size=512., pin_memory=True, shuffle=True)

In [30]:
# Find optimal model hyperparameters
oneshot_model.train()

optimizer = torch.optim.Adam(oneshot_model.final_layer.parameters(), lr=0.1)
optimizer.n_iter = 0
for i in range(200):
    for j, (x, y) in enumerate(shots_loader):
        optimizer.zero_grad()
        x = Variable(x).cuda()
        y = Variable(y).cuda()
        output = oneshot_model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.n_iter += 1
        print('Iter %d/200 - Loss: %.3f' % (
            i + 1, loss.data[0],
        ))
        optimizer.step()

oneshot_model.eval()
test_shots_omni = torchvision.datasets.ImageFolder('/scratch/bw462/omni_data/' + test_dir_str, transform=transforms.Compose([
                        transforms.Scale((28,28)),
                        transforms.ToTensor()
                   ]))    
test_shots_loader = torch.utils.data.DataLoader(test_shots_omni, batch_size=512., pin_memory=True, shuffle=True)
oneshot_model.eval()
avg = 0.
i = 0.
for test_batch_x, test_batch_y in test_shots_loader:
    predictions = oneshot_model(Variable(test_batch_x).cuda()).max(-1)[1]
    test_batch_y = Variable(test_batch_y).cuda()
    avg += torch.eq(predictions, test_batch_y).float().mean().data[0]
    i += 1.
print('Accuracy: %.4f' % (avg / i))

Iter 1/200 - Loss: 7.256
Iter 2/200 - Loss: 3.316
Iter 3/200 - Loss: 3.147
Iter 4/200 - Loss: 2.885
Iter 5/200 - Loss: 2.110
Iter 6/200 - Loss: 1.925
Iter 7/200 - Loss: 1.148
Iter 8/200 - Loss: 0.901
Iter 9/200 - Loss: 0.789
Iter 10/200 - Loss: 0.598
Iter 11/200 - Loss: 0.334
Iter 12/200 - Loss: 0.111
Iter 13/200 - Loss: 0.062
Iter 14/200 - Loss: 0.066
Iter 15/200 - Loss: 0.111
Iter 16/200 - Loss: 0.172
Iter 17/200 - Loss: 0.183
Iter 18/200 - Loss: 0.117
Iter 19/200 - Loss: 0.053
Iter 20/200 - Loss: 0.038
Iter 21/200 - Loss: 0.045
Iter 22/200 - Loss: 0.038
Iter 23/200 - Loss: 0.023
Iter 24/200 - Loss: 0.022
Iter 25/200 - Loss: 0.029
Iter 26/200 - Loss: 0.029
Iter 27/200 - Loss: 0.020
Iter 28/200 - Loss: 0.014
Iter 29/200 - Loss: 0.017
Iter 30/200 - Loss: 0.021
Iter 31/200 - Loss: 0.017
Iter 32/200 - Loss: 0.012
Iter 33/200 - Loss: 0.012
Iter 34/200 - Loss: 0.014
Iter 35/200 - Loss: 0.013
Iter 36/200 - Loss: 0.009
Iter 37/200 - Loss: 0.007
Iter 38/200 - Loss: 0.008
Iter 39/200 - Loss: 0

In [None]:
#5 way 5 shot: 95.71
# 5 way 1 shot: 58.57
# 20 way 5 shot: 89.64%
# 20 way 1 shot: 41.07