In [1]:
import gpytorch
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torchvision import transforms

gpytorch.functions.use_toeplitz = False

In [2]:
class FeatureExtractor(nn.Sequential):
    
    def __init__(self):
        super(FeatureExtractor, self).__init__(nn.Conv2d(1, 32, kernel_size=5, padding=2),
                                 nn.BatchNorm2d(32),
                                 nn.ReLU(),
                                 nn.MaxPool2d(2, 2),
                                 nn.Conv2d(32, 64, kernel_size=5, padding=2),
                                 nn.BatchNorm2d(64),
                                 nn.ReLU(),
                                 nn.MaxPool2d(2, 2))
        
class Bottleneck(nn.Sequential):
    
    def __init__(self):
        super(Bottleneck, self).__init__(nn.Linear(64*7*7, 128),
                                         nn.BatchNorm1d(128),
                                 nn.ReLU(),
                                 nn.Linear(128, 128),
                                 nn.BatchNorm1d(128),
                                 nn.ReLU(),
                                 nn.Linear(128,64),
                                 nn.BatchNorm1d(64))

class LeNet(nn.Module):
    
    def __init__(self):
        super(LeNet, self).__init__()
        self.feature_extractor = FeatureExtractor()
        self.bottleneck = Bottleneck()
        self.final_layer = nn.Sequential(
                                 nn.ReLU(),
                                 nn.Linear(64,10))
    
    def forward(self, x):
        features = self.feature_extractor(x)
        bottlenecked_features = self.bottleneck(features.view(-1, 64 * 7 * 7))
        classification = self.final_layer(bottlenecked_features)
        return classification
        

In [3]:
train_mnist = torchvision.datasets.MNIST('/scratch/bw462/mnist/', train=True,
                                         download=True, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))                

train_indices = train_mnist.train_labels.lt(7).nonzero().squeeze()
train_mnist.train_data = train_mnist.train_data.index_select(0, train_indices)
train_mnist.train_labels = train_mnist.train_labels.index_select(0, train_indices)

test_mnist = torchvision.datasets.MNIST('/scratch/bw462/mnist/', train=False,
                                        download=True, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))

test_indices = test_mnist.test_labels.lt(7).nonzero().squeeze()
test_mnist.test_data = test_mnist.test_data.index_select(0, test_indices)
test_mnist.test_labels = test_mnist.test_labels.index_select(0, test_indices)

len(train_mnist), len(test_mnist)

(41935, 6989)

In [4]:
train_data_loader = torch.utils.data.DataLoader(train_mnist, shuffle=True, pin_memory=True, batch_size=256)

In [5]:
criterion = nn.CrossEntropyLoss().cuda()

In [6]:
model = LeNet().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

In [7]:
num_epochs = 0
if num_epochs > 0:
    for i in range(num_epochs):
        for x, y in train_data_loader:
            optimizer.zero_grad()
            x = Variable(x.cuda())
            y = Variable(y.cuda())
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
        print("Loss: %.3f" % loss.data[0])
    torch.save(model.state_dict(), '/scratch/bw462/mnist/lenet_oneshot.dat')
else:
    model.load_state_dict(torch.load('/scratch/bw462/mnist/lenet_oneshot.dat'))
    
model.eval()
test_data_loader = torch.utils.data.DataLoader(test_mnist, shuffle=False, pin_memory=True, batch_size=256)
avg = 0.
i = 0.
for test_batch_x, test_batch_y in test_data_loader:
    predictions = model(Variable(test_batch_x).cuda()).max(-1)[1]
    test_batch_y = Variable(test_batch_y).cuda()
    avg += torch.eq(predictions, test_batch_y).float().mean().data[0]
    i += 1.
print('Accuracy: %.4f' % (avg / i))

Accuracy: 0.9958


In [8]:
list(model.bottleneck.modules())[-1].weight.data.fill_(1)
None

In [9]:
len(train_mnist)

41935

In [10]:
one_shot_model = LeNet().cuda()
one_shot_model.feature_extractor = model.feature_extractor
one_shot_model.bottleneck = model.bottleneck

In [11]:
train_mnist_789 = torchvision.datasets.MNIST('/scratch/bw462/mnist/', train=True,
                                         download=True, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))                

train_indices = torch.cat([
    train_mnist_789.train_labels.eq(7).nonzero().squeeze()[:5],
    train_mnist_789.train_labels.eq(8).nonzero().squeeze()[:5],
    train_mnist_789.train_labels.eq(9).nonzero().squeeze()[:5],
])
train_mnist_789.train_data = train_mnist_789.train_data.index_select(0, train_indices)
train_mnist_789.train_labels = train_mnist_789.train_labels.index_select(0, train_indices) - 7

test_mnist_789 = torchvision.datasets.MNIST('/scratch/bw462/mnist/', train=False,
                                        download=True, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))

test_indices = test_mnist_789.test_labels.ge(7).nonzero().squeeze()
test_mnist_789.test_data = test_mnist_789.test_data.index_select(0, test_indices)
test_mnist_789.test_labels = test_mnist_789.test_labels.index_select(0, test_indices) - 7

len(train_mnist_789), len(test_mnist_789)

(15, 3011)

In [12]:
few_shot_train_loader = torch.utils.data.DataLoader(train_mnist_789, batch_size=2048.,
                                                    pin_memory=True, shuffle=True)
optimizer = torch.optim.SGD(one_shot_model.parameters(), lr=0.01)

In [None]:
num_epochs = 1000
for i in range(num_epochs):
    for x, y in few_shot_train_loader:
        optimizer.zero_grad()
        x = Variable(x.cuda())
        y = Variable(y.cuda())
        output = one_shot_model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
    print("Loss: %.3f" % loss.data[0])

    
one_shot_model.eval()
few_shot_test_data_loader = torch.utils.data.DataLoader(test_mnist_789, shuffle=False, pin_memory=True, batch_size=256)
avg = 0.
i = 0.
for test_batch_x, test_batch_y in few_shot_test_data_loader:
    predictions = one_shot_model(Variable(test_batch_x).cuda()).max(-1)[1]
    test_batch_y = Variable(test_batch_y).cuda()
    avg += torch.eq(predictions, test_batch_y).float().mean().data[0]
    i += 1.
print('Accuracy: %.4f' % (avg / i))

Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss: 0.001
Loss