In [1]:
import numpy as np
from keras.layers import *
from keras.models import Model
from keras import backend as K
import imageio, os
from keras.datasets import mnist

Using TensorFlow backend.


In [2]:
# Training schedule
batch_size, epochs = 128, 50

# Model / data dimensions
latent_dim = 20          # size of the latent code z
num_classes = 10         # number of mixture components (clusters)
img_dim = 28             # MNIST images are img_dim x img_dim
filters = 16             # base conv width; the encoder loop doubles this global in place
intermediate_dim = 256   # hidden width of the classifier head

In [3]:
# Load MNIST, rescale pixel values from [0, 255] to [0, 1] as float32 and add a
# trailing single-channel axis: each split becomes (num_images, img_dim, img_dim, 1).
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = (
    (images.astype('float32') / 255.).reshape((-1, img_dim, img_dim, 1))
    for images in (x_train, x_test)
)

In [4]:
#encoder
# Convolutional encoder: image -> (z_mean, z_log_var), the mean and log-variance
# of the latent posterior that `sampling` draws from.
# NOTE(review): the loop below mutates the global `filters` in place, and the
# decoder cells read the mutated value. This cell is therefore not idempotent:
# re-running it without re-running the config cell doubles `filters` again.
x = Input(shape = (img_dim, img_dim, 1))
h = x
for i in range(2):
    # Double the channel count per stage (32, then 64 when starting from filters = 16).
    filters *= 2
    # Stride-2 conv with 'same' padding halves the spatial size: 28 -> 14 -> 7.
    h = Conv2D(filters = filters, kernel_size = 3, strides = 2, padding = 'same')(h)
    h = LeakyReLU(0.2)(h)
    # Stride-1 conv keeps the spatial size.
    h = Conv2D(filters = filters, kernel_size = 3, strides = 1, padding = 'same')(h)
    h = LeakyReLU(0.2)(h)

# Final feature-map shape without the batch axis; the decoder reshapes its dense
# projection back to this shape.
h_shape = K.int_shape(h)[1:]
h = Flatten()(h)
z_mean = Dense(latent_dim)(h) #mean of p(z|x)
z_log_var = Dense(latent_dim)(h) #log variance of p(z|x)
# The standalone encoder exposes only the mean, giving deterministic embeddings
# for prediction.
encoder = Model(x, z_mean)

In [5]:
# Layer-by-layer table of the encoder (output shapes and parameter counts).
encoder.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 32)        320       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 32)        9248      
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 64)          18496     
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 7, 7, 64)          0         
__________

In [6]:
#decoder
# Project a latent code back to the encoder's final feature-map shape so the
# transposed convolutions below can upsample it into an image.
z = Input(shape=(latent_dim,))
h = Dense(np.prod(h_shape))(z)
h = Reshape(h_shape)(h)

In [7]:
# Decoder body, mirroring the encoder. Each stage applies a stride-1 transposed
# conv, then a stride-2 transposed conv that doubles the spatial size
# (7 -> 14 -> 28), and halves `filters` afterwards.
# NOTE(review): relies on the global `filters` as left by the encoder cell (64
# after one run) and mutates it again (64 -> 32 -> 16). Running cells out of
# order changes the layer widths.
for i in range(2):
    h = Conv2DTranspose(filters = filters, kernel_size=3, strides = 1, padding = 'same')(h)
    h = LeakyReLU(0.2)(h)
    h = Conv2DTranspose(filters = filters, kernel_size=3, strides = 2, padding = 'same')(h)
    h = LeakyReLU(0.2)(h)
    filters //= 2

# Single-channel output with a sigmoid, so reconstructed pixels lie in (0, 1) like
# the [0, 1]-scaled inputs.
x_recon = Conv2DTranspose(filters = 1, kernel_size = 3, activation = 'sigmoid', padding = 'same')(h)
decoder = Model(z, x_recon)
decoder.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 20)                0         
_________________________________________________________________
dense_3 (Dense)              (None, 3136)              65856     
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 7, 7, 64)          36928     
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 14, 14, 64)        36928     
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 14, 14, 64)        0         
__________

In [9]:
# Classifier head: latent code -> probabilities over the num_classes clusters.
z = Input(shape=(latent_dim,))
y = Dense(intermediate_dim, activation='relu')(z)    # hidden layer
y = Dense(num_classes, activation='softmax')(y)      # class probabilities
classifier = Model(inputs=z, outputs=y)

In [12]:
def sampling(args):
    """Reparameterised draw from N(z_mean, exp(z_log_var)).

    Args:
        args: pair ``(z_mean, z_log_var)`` of equally shaped tensors.

    Returns:
        ``z_mean + exp(z_log_var / 2) * eps``, where ``eps`` is standard-normal
        noise with the same shape as ``z_mean``. Writing the sample this way keeps
        it differentiable with respect to the mean and log-variance.
    """
    mu, log_var = args
    noise = K.random_normal(shape=K.shape(mu))
    sigma = K.exp(log_var / 2)
    return mu + sigma * noise

In [13]:
# Full VAE path: draw z from the encoder's outputs via `sampling`, then feed the
# same sample to the decoder (reconstruction) and the classifier (class probabilities).
# Note: this rebinds `z`, which earlier cells used for the decoder/classifier Inputs.
z = Lambda(sampling, output_shape = (latent_dim,))([z_mean, z_log_var])
x_recon = decoder(z)
y = classifier(z)

In [29]:
class Gaussian(Layer):
    """Learnable per-class latent means for the mixture prior.

    Holds a weight ``mean`` of shape ``(num_classes, latent_dim)``, initialised to
    zeros. It maps an input ``z`` of shape ``(batch, latent_dim)`` to ``z - mean[k]``
    for every class ``k``, i.e. an output of shape ``(batch, num_classes, latent_dim)``.

    Args:
        num_classes: number of mixture components (one mean vector each).
        **kwargs: forwarded to the Keras ``Layer`` constructor.
    """
    #define the mean value of q(z|y), every class has a mean value
    #output z - mu
    def __init__(self, num_classes, **kwargs):
        super(Gaussian, self).__init__(**kwargs)
        self.num_classes = num_classes

    def build(self, input_shape):
        latent_dim = input_shape[-1]
        self.mean = self.add_weight(name = 'mean',
                                    shape = (self.num_classes, latent_dim),
                                    initializer = 'zeros')
        # Standard Keras custom-layer idiom: marks the layer as built.
        super(Gaussian, self).build(input_shape)

    def call(self, inputs):
        z = K.expand_dims(inputs, 1)  # (batch, 1, latent_dim)
        # Broadcast against (1, num_classes, latent_dim) -> (batch, num_classes, latent_dim).
        return z - K.expand_dims(self.mean, 0)

    def compute_output_shape(self, input_shape):
        # Propagate the incoming batch dimension instead of hard-coding None.
        return (input_shape[0], self.num_classes, input_shape[-1])

    def get_config(self):
        """Return the layer config including ``num_classes`` so it can be re-created."""
        config = super(Gaussian, self).get_config()
        config['num_classes'] = self.num_classes
        return config

In [30]:
# One learnable mean per class. z_prior_mean holds z - mean_k for every class k,
# with shape (batch, num_classes, latent_dim) per Gaussian.compute_output_shape.
# NOTE(review): despite its name, this tensor is the offset of z from each class
# mean, not the mean itself.
gaussian = Gaussian(num_classes)
z_prior_mean = gaussian(z)

In [31]:
# Training model: image -> (reconstruction, per-class offsets z - mean_k, class
# probabilities). The objective is attached separately with add_loss.
vae = Model(x, [x_recon, z_prior_mean, y])

In [32]:
# Layer table of the full training model.
vae.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 14, 14, 32)   320         input_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 14, 14, 32)   0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 14, 14, 32)   9248        leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
leaky_re_l

In [33]:
#compute loss
# Add a class axis so the posterior tensors broadcast against the
# (batch, num_classes, latent_dim) output of the Gaussian layer.
# NOTE(review): these reassign the encoder tensors in place, so the cell is not
# idempotent (re-running adds another axis). The KL cell depends on the expanded
# z_log_var; the expanded z_mean is not referenced by the visible loss cells.
z_mean = K.expand_dims(z_mean, 1)
z_log_var = K.expand_dims(z_log_var, 1)

lamb = 2.5 #weight of the reconstruction error
# Half squared error per pixel, averaged over the batch axis (summed over pixels
# when the total loss is formed).
recon_loss = 0.5 * K.mean((x - x_recon)**2, 0)

In [41]:
# Per-class Gaussian term -0.5 * (log sigma^2 - (z - mean_k)^2), evaluated at the
# sampled z; broadcasts to (batch, num_classes, latent_dim).
kl_loss = -0.5 * (z_log_var - K.square(z_prior_mean))
# Weight each class term by the classifier's probabilities y:
# (batch, 1, num_classes) . (batch, num_classes, latent_dim) -> (batch, 1, latent_dim),
# then average over the batch.
kl_loss = K.mean(K.batch_dot(K.expand_dims(y, 1), kl_loss), 0)

In [43]:
# y * log(y) per class, averaged over the batch; its sum in the total loss is the
# negative entropy of the predicted class distribution. K.epsilon() guards log(0).
cat_loss = K.mean(y * K.log(y+K.epsilon()), 0)

In [44]:
# Total scalar objective: lamb-weighted reconstruction + class-weighted Gaussian
# term + y*log(y) term, each reduced with K.sum.
vae_loss = lamb * K.sum(recon_loss) + K.sum(kl_loss) + K.sum(cat_loss)

In [45]:
# Attach the hand-built objective; compile therefore needs no per-output loss.
vae.add_loss(vae_loss)
vae.compile(optimizer = 'adam')

In [46]:
# Unsupervised training: only images are passed (no targets) because the whole
# objective comes from add_loss; validation likewise passes (x_test, None).
vae.fit(x_train,
       shuffle = True,
       epochs = epochs,
       batch_size = batch_size,
       validation_data = (x_test, None))

Train on 60000 samples, validate on 10000 samples
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<keras.callbacks.History at 0x7f0b24a2bcf8>

In [47]:
# Pull the learned per-class latent means out of the Gaussian layer.
means = K.eval(gaussian.mean)

# Encode both splits with the deterministic encoder (z_mean, no sampling), then
# take the classifier's most probable class as each image's cluster id.
x_train_encoded = encoder.predict(x_train)
y_train_pred = np.argmax(classifier.predict(x_train_encoded), axis=1)
x_test_encoded = encoder.predict(x_test)
y_test_pred = np.argmax(classifier.predict(x_test_encoded), axis=1)

In [53]:
#observe the samples belonging to the same category
def cluster_sample(path, category = 0, n = 8):
    """Save an n x n grid of random training images assigned to one cluster.

    Args:
        path: output image path passed to ``imageio.imwrite``.
        category: cluster id; images come from ``x_train`` where
            ``y_train_pred == category``.
        n: grid side length (default 8, i.e. 64 tiles).

    Images are drawn with replacement, so small clusters may show repeats.
    If no training image was assigned to ``category``, a message is printed and
    nothing is written, instead of ``np.random.choice`` raising on an empty array
    and aborting a loop over all clusters.
    """
    idxs = np.where(y_train_pred == category)[0]
    if idxs.size == 0:
        print('cluster %s is empty; skipping %s' % (category, path))
        return
    figure = np.zeros((img_dim * n, img_dim * n))
    for i in range(n):
        for j in range(n):
            digit = x_train[np.random.choice(idxs)].reshape((img_dim, img_dim))
            figure[i * img_dim : (i + 1) * img_dim, j * img_dim : (j + 1) * img_dim] = digit
    # Pixels are in [0, 1]. Round/clip/cast explicitly so the saved values do not
    # depend on imageio's lossy float -> uint8 conversion.
    imageio.imwrite(path, np.clip(np.round(figure * 255), 0, 255).astype(np.uint8))

In [56]:
#generate specific category image based on the mean value of each class
def random_sample(path, category = 0, std = 1., n = 8):
    """Decode an n x n grid of latent samples drawn around one class mean.

    Each latent code is ``means[category] + std * eps`` with ``eps`` standard
    normal of size ``latent_dim``. All n*n codes are decoded in a single
    ``decoder.predict`` call and tiled into one image written to ``path``.

    Args:
        path: output image path passed to ``imageio.imwrite``.
        category: index into the learned class means ``means``.
        std: scale of the Gaussian noise around the class mean.
        n: grid side length (default 8, i.e. 64 tiles).

    Relies on the globals ``means``, ``decoder``, ``latent_dim`` and ``img_dim``
    from earlier cells.
    """
    # randn takes the dimensions as separate ints (a tuple raises TypeError); the
    # learned class means are stored in `means`.
    z_samples = np.random.randn(n * n, latent_dim) * std + means[category]
    digits = decoder.predict(z_samples).reshape((n * n, img_dim, img_dim))
    figure = np.zeros((img_dim * n, img_dim * n))
    for k, digit in enumerate(digits):
        i, j = divmod(k, n)
        figure[i * img_dim : (i + 1) * img_dim, j * img_dim : (j + 1) * img_dim] = digit
    # Decoder output is sigmoid-bounded; convert explicitly rather than relying on
    # imageio's lossy float -> uint8 conversion.
    imageio.imwrite(path, np.clip(np.round(figure * 255), 0, 255).astype(np.uint8))

In [51]:
# Create the output folder for the image grids on first run; later runs leave it as is.
if not os.path.exists('samples'):
    os.mkdir('samples')

In [58]:
# Save one grid of real training images per predicted cluster id.
for i in range(num_classes):
    cluster_sample(path = 'samples/clustering_%s.png' % i, category = i)



In [None]:
#calculate accuracy
# Unsupervised cluster accuracy (purity): each cluster is credited with the count
# of its most common true label, because cluster ids are an arbitrary permutation
# of the digit labels. Empty clusters contribute nothing.
right = 0.
for i in range(num_classes):
    labels_in_cluster = y_train[y_train_pred == i]
    if labels_in_cluster.size:
        right += np.bincount(labels_in_cluster).max()
print('train acc: %s' % (right / len(y_train)))

In [None]:
# Same purity-style accuracy on the test split: each cluster is credited with the
# count of its most common true label; empty clusters contribute nothing.
right = 0.
for i in range(num_classes):
    labels_in_cluster = y_test[y_test_pred == i]
    if labels_in_cluster.size:
        right += np.bincount(labels_in_cluster).max()
print('test acc: %s' % (right / len(y_test)))