In [1]:
import numpy as np
import torchvision

import torch.nn.functional as F
import torch.nn as nn
import torch
import torch.optim as optim

import math
from torch.nn import init
from torch.autograd import Variable
from tqdm import trange
from torch.distributions.categorical import Categorical
import scipy
import scipy.linalg
from collections import Counter


#from model import CnnActorCriticNetwork, RNDModel
from utils import global_grad_norm_
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the
# CPU so the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [3]:
# --- Reproducibility -------------------------------------------------------
random_seed = 1

# --- Optimiser hyper-parameters ---------------------------------------------
learning_rate, momentum = 0.01, 0.5
n_epochs = 3

# --- Batch sizes and logging cadence ----------------------------------------
batch_size_train, batch_size_test = 64, 1000
log_interval = 10

# cuDNN is switched off (presumably to avoid its non-deterministic kernels --
# TODO confirm) and the global torch RNG is seeded. The seeding call is kept
# as the last expression so the cell still displays the returned Generator.
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

<torch._C.Generator at 0x7f5ce26d4fb0>

In [61]:
# Preprocessing shared by both splits: convert the image to a tensor, then
# normalise with a fixed mean/std (presumably the MNIST pixel statistics --
# TODO confirm).
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# One sample per batch, in shuffled order: the few-shot selection loop
# inspects training samples one at a time.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/MNIST/', train=True, download=True,
                               transform=mnist_transform),
    batch_size=1, shuffle=True)

# Evaluation only accumulates sums over the whole split, so the order of the
# test samples is irrelevant and shuffling is not needed. The batch size
# comes from the configuration cell instead of a duplicated literal.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/MNIST/', train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_test, shuffle=False)

In [5]:
class Net(nn.Module):
    """Small convolutional classifier producing log-probabilities over 10 classes.

    Architecture: two 5x5 convolutions (1 -> 10 -> 20 channels), each followed
    by 2x2 max-pooling and ReLU, then two fully connected layers
    (320 -> 50 -> 10). For a 1x28x28 input the convolutional stack yields
    20x4x4 = 320 features per sample, which is what ``fc1`` expects.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # Channel-wise dropout applied to the output of conv2.
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: image batch whose conv features flatten to 320 values per
               sample (e.g. shape ``(batch, 1, 28, 28)``).

        Returns:
            Tensor of shape ``(batch, 10)`` holding log-softmax scores.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten everything except the batch dimension.
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        # Functional dropout must be told explicitly whether we are training.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Normalise over the class dimension explicitly; relying on the
        # implicit dim is deprecated and emits a warning.
        return F.log_softmax(x, dim=1)

In [62]:
# Fresh model plus an SGD-with-momentum optimiser over all of its parameters;
# the learning rate and momentum come from the configuration cell.
network = Net()
optimizer = optim.SGD(
    network.parameters(),
    momentum=momentum,
    lr=learning_rate,
)

In [51]:
# Empty histories for the loss curves.
train_losses, train_counter, test_losses = [], [], []

# Cumulative number of training samples at each epoch boundary
# (0, N, 2N, ..., n_epochs * N), where N is the size of the training set.
samples_per_epoch = len(train_loader.dataset)
test_counter = [epoch_idx * samples_per_epoch
                for epoch_idx in range(n_epochs + 1)]

In [52]:
# Build a small, class-capped training set by scanning the shuffled
# single-sample train_loader.
num_of_shots = 11
# Upper bound on how far into the loader we scan: the loop stops after
# processing the sample at index break_trashold + 1.
break_trashold = num_of_shots * 15
few_shot_dataset = []
few_shot_dataset_y = []
# The label list is seeded with one placeholder entry per class (0..9), so
# every class already has a count of 1. Consequently at most
# num_of_shots - 1 real samples per class are kept.
few_shot_dataset_y_np = list(range(0, 10))
# Running per-class tally, updated incrementally instead of recounting the
# whole label list on every iteration.
class_counts = Counter(few_shot_dataset_y_np)
for batch_idx, (data, target) in enumerate(train_loader):
    label = target.item()  # batch_size == 1, so the target holds one label
    if class_counts[label] < num_of_shots:
        few_shot_dataset.append(data)
        few_shot_dataset_y.append(target)
        few_shot_dataset_y_np.append(label)
        class_counts[label] += 1
    # Stop at the scan limit, or earlier once every class is full (no later
    # sample could be accepted, so the collected set is unchanged).
    all_classes_full = all(count >= num_of_shots
                           for count in class_counts.values())
    if batch_idx > break_trashold or all_classes_full:
        break

In [53]:
def train(epoch, dataset):
    """Run one pass of SGD updates over ``dataset`` using the global model.

    Args:
        epoch: epoch number supplied by the caller; not used by the body.
        dataset: iterable yielding ``(data, target)`` pairs that are fed to
            the module-level ``network`` one pair at a time.

    The module-level ``optimizer`` takes one step per pair. Returns nothing.
    """
    network.train()  # enable dropout for the training pass
    for sample, label in dataset:
        optimizer.zero_grad()
        nll = F.nll_loss(network(sample), label)
        nll.backward()
        optimizer.step()

In [58]:
def test():
    """Evaluate the global ``network`` on ``test_loader`` and report the result.

    Sums the negative log-likelihood and the number of correct top-1
    predictions over every batch, divides the loss by the number of samples
    in ``test_loader.dataset``, appends that average to ``test_losses`` and
    prints a one-line summary. Returns nothing.
    """
    network.eval()  # disable dropout for evaluation
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            # reduction='sum' replaces the deprecated size_average=False, so
            # the per-sample average can be taken once at the end.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            # .item() keeps `correct` a plain Python int rather than a tensor.
            correct += pred.eq(target.view_as(pred)).sum().item()
    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    test_losses.append(test_loss)
    print('Test set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples, 100. * correct / num_samples))

In [63]:
# Evaluate once before any training in this cell, then alternate one pass
# over the few-shot samples with a full evaluation on the test split.
test()
few_shot_epochs = 30
for epoch in range(1, few_shot_epochs + 1):
    # zip() returns a one-shot iterator, so the pairs are rebuilt every epoch.
    few_shot_pairs = zip(few_shot_dataset, few_shot_dataset_y)
    train(epoch, few_shot_pairs)
    test()



Test set: Avg. loss: 2.3297, Accuracy: 923/10000 (9%)

Test set: Avg. loss: 2.2975, Accuracy: 1733/10000 (17%)

Test set: Avg. loss: 2.2817, Accuracy: 1439/10000 (14%)

Test set: Avg. loss: 2.2352, Accuracy: 2883/10000 (28%)

Test set: Avg. loss: 2.1042, Accuracy: 4201/10000 (42%)

Test set: Avg. loss: 2.0229, Accuracy: 3545/10000 (35%)

Test set: Avg. loss: 1.6951, Accuracy: 3911/10000 (39%)

Test set: Avg. loss: 1.6238, Accuracy: 4654/10000 (46%)

Test set: Avg. loss: 1.6330, Accuracy: 5453/10000 (54%)

Test set: Avg. loss: 1.6424, Accuracy: 5772/10000 (57%)

Test set: Avg. loss: 1.4794, Accuracy: 5897/10000 (58%)

Test set: Avg. loss: 1.6292, Accuracy: 4766/10000 (47%)

Test set: Avg. loss: 1.2377, Accuracy: 6760/10000 (67%)

Test set: Avg. loss: 1.4848, Accuracy: 5974/10000 (59%)

Test set: Avg. loss: 1.1240, Accuracy: 6316/10000 (63%)

Test set: Avg. loss: 1.9670, Accuracy: 2819/10000 (28%)

Test set: Avg. loss: 1.3919, Accuracy: 5894/10000 (58%)

Test set: Avg. loss: 1.2393, Accu