In [1]:
import gpytorch
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torchvision import transforms

gpytorch.functions.use_toeplitz = False

In [2]:
ways = 5
shots = 5
train_dir_str = "way%dshot%d" %(ways, shots)
test_dir_str = "way%dtest" %ways

In [3]:
train_base_omni = torchvision.datasets.ImageFolder('/scratch/bw462/omni_data/general', transform=transforms.Compose([
                        transforms.Resize((28,28)),
                        transforms.ToTensor()
                   ]))                                              
"""
test_mnist = torchvision.datasets.ImageFolder('/tmp', split='test',
                                        download=True, transform=transforms.Compose([
                       transforms.ToTensor()
                   ]))
"""

"\ntest_mnist = torchvision.datasets.ImageFolder('/tmp', split='test',\n                                        download=True, transform=transforms.Compose([\n                       transforms.ToTensor()\n                   ]))\n"

In [4]:
class FeatureExtractor(nn.Sequential):
    
    def __init__(self):
        super(FeatureExtractor, self).__init__(nn.Conv2d(1, 32, kernel_size=5, padding=2),
                                 nn.BatchNorm2d(32),
                                 nn.ReLU(),
                                 nn.MaxPool2d(2, 2),
                                 nn.Conv2d(32, 64, kernel_size=5, padding=2),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU(),
                                 nn.MaxPool2d(2, 2))
        
class Bottleneck(nn.Sequential):
    
    def __init__(self):
        super(Bottleneck, self).__init__(nn.Linear(64*7*7, 128),
                                         nn.BatchNorm1d(128),
                                 nn.ReLU(),
                                 nn.Linear(128, 128),
                                 nn.BatchNorm1d(128),
                                 nn.ReLU(),
                                 nn.Linear(128,64),
                                 nn.BatchNorm1d(64))

class LeNet(nn.Module):
    
    def __init__(self):
        super(LeNet, self).__init__()
        self.feature_extractor = FeatureExtractor()
        self.bottleneck = Bottleneck()
        self.final_layer = nn.Sequential(
                                 nn.ReLU(),
                                 nn.Linear(64,1319))
    
    def forward(self, x):
        input_x = x[:,0,:,:].unsqueeze(1)
        features = self.feature_extractor(input_x)
        bottlenecked_features = self.bottleneck(features.view(-1, 64 * 7 * 7))
        classification = self.final_layer(bottlenecked_features)
        return classification
        

In [5]:
train_data_loader = torch.utils.data.DataLoader(train_base_omni, shuffle=True, pin_memory=True, batch_size=256)

In [6]:
criterion = nn.CrossEntropyLoss().cuda()

In [7]:
model = LeNet().cuda() 
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

In [8]:
num_epochs = 0
if num_epochs > 0:
    model.train()
    for i in range(num_epochs):
        for x, y in train_data_loader:
            optimizer.zero_grad()
            x = Variable(x).cuda()
            y = Variable(y).cuda()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
        print("Loss: %.3f" % loss.data[0])
    torch.save(model.state_dict(), '/scratch/bw462/omni_net.dat')
else:
    model.load_state_dict(torch.load('/scratch/bw462/omni_net.dat'))


In [9]:
model.eval()
#test_data_loader = torch.utils.data.DataLoader(test_mnist, shuffle=False, pin_memory=True, batch_size=256)
avg = 0.
i = 0.
for test_batch_x, test_batch_y in train_data_loader:
    predictions = model(Variable(test_batch_x).cuda()).max(-1)[1]
    test_batch_y = Variable(test_batch_y).cuda()
    avg += torch.eq(predictions, test_batch_y).float().mean().data[0]
    i += 1.
print('Accuracy: %.4f' % (avg / i))

Accuracy: 0.9992


In [10]:
list(model.bottleneck.modules())[-1].weight.data.fill_(1)
None

In [11]:
ways = 20
shots = 5
train_dir_str = "way%dshot%d" %(ways, shots)
test_dir_str = "way%dtest" %ways

In [12]:
train_shots_omni = torchvision.datasets.ImageFolder('/scratch/bw462/omni_data/' + train_dir_str, transform=transforms.Compose([
                        transforms.Scale((28,28)),
                        transforms.ToTensor()
                   ]))    



In [13]:
oneshot_model = LeNet().cuda()
oneshot_model.feature_extractor = model.feature_extractor
oneshot_model.bottleneck = model.bottleneck
shots_loader = torch.utils.data.DataLoader(train_shots_omni, batch_size=512., pin_memory=True, shuffle=True)

In [14]:
# Find optimal model hyperparameters
oneshot_model.train()

optimizer = torch.optim.Adam(oneshot_model.final_layer.parameters(), lr=0.5)
optimizer.n_iter = 0
for i in range(200):
    for j, (x, y) in enumerate(shots_loader):
        optimizer.zero_grad()
        x = Variable(x).cuda()
        y = Variable(y).cuda()
        output = oneshot_model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.n_iter += 1
        if i%10 == 0:
            print('Iter %d/200 - Loss: %.3f' % (
                i + 1, loss.data[0],
            ))
        optimizer.step()

oneshot_model.eval()
test_shots_omni = torchvision.datasets.ImageFolder('/scratch/bw462/omni_data/' + test_dir_str, transform=transforms.Compose([
                        transforms.Scale((28,28)),
                        transforms.ToTensor()
                   ]))    
test_shots_loader = torch.utils.data.DataLoader(test_shots_omni, batch_size=512., pin_memory=True, shuffle=True)
oneshot_model.eval()
avg = 0.
summed_confidence = 0.
i = 0.
for test_batch_x, test_batch_y in test_shots_loader:
    model_output = oneshot_model(Variable(test_batch_x).cuda())
    predictions = model_output.max(-1)[1]
    confidences = nn.functional.softmax(model_output, dim=1).max(dim=1)[0]
    avg_confidence = confidences.mean()
    summed_confidence += avg_confidence
    test_batch_y = Variable(test_batch_y).cuda()
    avg += torch.eq(predictions, test_batch_y).float().mean().data[0]
    i += 1.
print('Accuracy: %.4f' % (avg / i))
print('Average Confidence')
print(summed_confidence.data[0] / i)

Iter 1/200 - Loss: 7.763
Iter 11/200 - Loss: 12.044
Iter 21/200 - Loss: 1.846
Iter 31/200 - Loss: 0.041
Iter 41/200 - Loss: 0.002
Iter 51/200 - Loss: 0.001
Iter 61/200 - Loss: 0.000
Iter 71/200 - Loss: 0.000
Iter 81/200 - Loss: 0.000
Iter 91/200 - Loss: 0.000
Iter 101/200 - Loss: 0.000
Iter 111/200 - Loss: 0.000
Iter 121/200 - Loss: 0.000
Iter 131/200 - Loss: 0.000
Iter 141/200 - Loss: 0.000
Iter 151/200 - Loss: 0.000
Iter 161/200 - Loss: 0.000
Iter 171/200 - Loss: 0.000
Iter 181/200 - Loss: 0.000
Iter 191/200 - Loss: 0.000
Accuracy: 0.8429
Average Confidence
0.984777152538


In [15]:
#5 way 5 shot: Accuracy 95.71%    Confidence 98.5%
# 5 way 1 shot: Accuracy 58.57%   Confidence 98.5%
# 20 way 5 shot: Accuracy 89.64%  Confidence 99.2%
# 20 way 1 shot: Accuracy 41.07%  Confidence 98.6%
