In [1]:
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data

  from ._conv import register_converters as _register_converters


In [2]:
# Select matplotlib's interactive `notebook` backend for figures shown in this notebook.
%matplotlib notebook

In [2]:
# Load MNIST via the TensorFlow tutorial helper from ../../MNIST_data.
# one_hot=True: each label is a one-hot vector (later cells rely on this).
mnist = input_data.read_data_sets('../../MNIST_data',one_hot=True)

Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz


In [3]:
# --- episode / batch settings ---
# n_way is only referenced by the disabled few-shot sampler in the training cell.
n_way, n_shot = 5, 5
mb_size = 1 * n_shot

# --- sizes read from the dataset ---
n_data = mnist.train.images.shape[0]   # number of training examples
X_dim = mnist.train.images.shape[1]    # width of one image row
y_dim = mnist.train.labels.shape[1]    # width of one label row

# --- model / optimiser hyper-parameters ---
Z_dim = 100   # latent dimension
h_dim = 128   # hidden layer width
lr = 1e-3     # Adam learning rate

# Running index used to number the sample images saved during training.
cnt = 0

In [4]:
def xavier_init(size):
    """Create a trainable weight tensor with fan-in scaled Gaussian entries.

    Args:
        size: sequence of dimensions; ``size[0]`` is taken as the fan-in.

    Returns:
        A ``Variable`` of shape ``size`` with ``requires_grad=True`` whose
        entries are standard-normal draws scaled by ``1 / sqrt(size[0] / 2)``.
    """
    fan_in = size[0]
    scale = 1. / np.sqrt(fan_in / 2.)
    weights = torch.randn(*size) * scale
    return Variable(weights, requires_grad=True)

# Encoder parameters, used by Q(X, c).
# Weight matrices first, in the same creation order as before so the random
# draws are unchanged; the zero biases consume no random numbers.
Wxh = xavier_init(size=[X_dim + y_dim, h_dim])   # [image ‖ label] -> hidden
Whz_mu = xavier_init(size=[h_dim, Z_dim])        # hidden -> latent mean
Whz_var = xavier_init(size=[h_dim, Z_dim])       # hidden -> latent log-variance

bxh = Variable(torch.zeros(h_dim), requires_grad=True)
bhz_mu = Variable(torch.zeros(Z_dim), requires_grad=True)
bhz_var = Variable(torch.zeros(Z_dim), requires_grad=True)



def Q(X, c):
    """Encoder: map a batch of inputs and their conditions to latent statistics.

    Args:
        X: batch of inputs, one example per row.
        c: batch of condition vectors, concatenated to ``X`` along dim 1.

    Returns:
        ``(z_mu, z_var)``. The second output is used as a log-variance by the
        rest of the notebook (it is exponentiated in ``sample_z`` and in the
        KL term of the loss).
    """
    joint = torch.cat([X, c], 1)
    pre_hidden = torch.matmul(joint, Wxh) + bxh.repeat(joint.size(0), 1)
    hidden = nn.relu(pre_hidden)

    n_rows = hidden.size(0)
    z_mu = torch.matmul(hidden, Whz_mu) + bhz_mu.repeat(n_rows, 1)
    z_var = torch.matmul(hidden, Whz_var) + bhz_var.repeat(n_rows, 1)
    return z_mu, z_var


def sample_z(mu, log_var):
    """Draw a latent sample with the reparameterisation trick.

    Computes ``mu + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``.

    Args:
        mu: tensor of latent means.
        log_var: tensor of latent log-variances, same shape as ``mu``.

    Returns:
        A tensor with the same shape as ``mu``.
    """
    # The noise takes its shape from `mu` rather than from the mb_size / Z_dim
    # globals, so the function works for any batch size.
    eps = Variable(torch.randn(*mu.size()))
    return mu + torch.exp(log_var / 2) * eps

# Decoder parameters, used by P(z, c).
# Weights are created in the original order so the random draws are unchanged.
Wzh = xavier_init(size=[Z_dim + y_dim, h_dim])   # [latent ‖ label] -> hidden
Whx = xavier_init(size=[h_dim, X_dim])           # hidden -> reconstructed image

bzh = Variable(torch.zeros(h_dim), requires_grad=True)
bhx = Variable(torch.zeros(X_dim), requires_grad=True)


def P(z, c):
    """Decoder: map latent codes and their conditions back to input space.

    Args:
        z: batch of latent vectors, one per row.
        c: batch of condition vectors, concatenated to ``z`` along dim 1.

    Returns:
        Batch of reconstructions; every entry lies in (0, 1) because the
        output layer is a sigmoid.
    """
    joint = torch.cat([z, c], 1)
    pre_hidden = torch.matmul(joint, Wzh) + bzh.repeat(joint.size(0), 1)
    hidden = nn.relu(pre_hidden)

    logits = torch.matmul(hidden, Whx) + bhx.repeat(hidden.size(0), 1)
    X = nn.sigmoid(logits)
    return X

In [5]:
train_images = mnist.train.images
train_labels = mnist.train.labels

# Group the training set by class: img[k] / lab[k] hold every image / label
# row of class k, in their original order.
# Labels are one-hot (loaded with one_hot=True), so the class id of a row is
# the position of its largest entry. Boolean masks replace the per-row Python
# loop, and the class count comes from the label width instead of a literal 10.
class_ids = train_labels.argmax(axis=1)
n_classes = train_labels.shape[1]

img = [train_images[class_ids == k] for k in range(n_classes)]
lab = [train_labels[class_ids == k] for k in range(n_classes)]


In [6]:
# All trainable tensors of the encoder (Q) and decoder (P); one Adam optimiser
# updates them jointly.
params = [
    Wxh, bxh,
    Whz_mu, bhz_mu,
    Whz_var, bhz_var,
    Wzh, bzh,
    Whx, bhx,
]

solver = optim.Adam(params=params, lr=lr)

import random
# Train the conditional VAE. Every 5000 iterations: report the loss and save a
# grid of images generated for one randomly chosen class to out/NNN.png.
for it in range(100000):
    # Disabled experiment, kept for reference: draw few-shot batches from the
    # per-class `img` / `lab` lists instead of using next_batch.
    #select_class = random.sample(range(0,len(img)),n_way)
    #for i in select_class:
    #    training_set, training_label = [], []
    #    index = random.sample(range(1, len(img[i])), n_shot)
    #    for j in index:
    #        training_set.append(img[i][j])
    #        training_label.append(lab[i][j])

    #    X ,c = np.asarray(training_set), np.asarray(training_label)
    X, c = mnist.train.next_batch(mb_size)

    X = Variable(torch.from_numpy(X))
    c = Variable(torch.from_numpy(c.astype('float32')))

    # Forward
    z_mu, z_var = Q(X, c)
    z = sample_z(z_mu, z_var)
    X_sample = P(z, c)

    # Loss: summed reconstruction error averaged over the batch, plus the KL
    # divergence between N(z_mu, exp(z_var)) and the standard normal prior.
    recon_loss = nn.binary_cross_entropy(X_sample, X, reduction='sum') / mb_size
    kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var, 1))
    loss = recon_loss + kl_loss

    # Backward. Gradients accumulate across backward() calls, so clear them
    # through the optimiser before computing new ones.
    solver.zero_grad()
    loss.backward()
    solver.step()

    if it % 5000 == 0:
        # loss is a 0-dim tensor; .item() extracts the Python float
        # (indexing it with [0] fails on current PyTorch).
        print('Iter-{}; Loss: {:.4}'.format(it, loss.item()))

        # Condition the whole batch on a single randomly chosen class.
        digit = np.random.randint(0, y_dim)
        c = np.zeros(shape=[mb_size, y_dim], dtype='float32')
        c[:, digit] = 1.
        print('now generating imgae with {}'.format(digit))
        c = Variable(torch.from_numpy(c))
        z = Variable(torch.randn(mb_size, Z_dim))
        # At most 16 samples fit the 4x4 grid below.
        samples = P(z, c).data.numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        # exist_ok makes this a no-op when the directory is already there.
        os.makedirs('out/', exist_ok=True)

        plt.savefig('out/{}.png'.format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        # Close the figure so figures do not pile up over the run.
        plt.close(fig)

Iter-0; Loss: 665.6
now generating imgae with 9
Iter-5000; Loss: 139.4
now generating imgae with 0
Iter-10000; Loss: 113.8
now generating imgae with 3
Iter-15000; Loss: 122.8
now generating imgae with 7
Iter-20000; Loss: 128.2
now generating imgae with 5


KeyboardInterrupt: 

In [12]:
from mpl_toolkits.mplot3d import Axes3D
# 3-D scatter of the most recent latent batch `z`:
# x = row (sample) index, y = column (latent dimension) index, height = value.
zz = z.data.numpy()

# Row / column index of every entry, flattened in the same row-major order as
# the values so the three coordinate arrays line up element for element.
xx, yy = np.indices(zz.shape)
xx, yy, value = xx.ravel(), yy.ravel(), zz.ravel()

# Draw on a dedicated figure, and pass the flattened values (not the 2-D
# array) as the z coordinates.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(xx, yy, value, s=1)
ax.set_xlabel('sample index')
ax.set_ylabel('latent dimension')
ax.set_zlabel('z value')
plt.show()

<IPython.core.display.Javascript object>

-2.4094708