In [1]:
import sys
import tensorflow as tf
sys.path.append('../../generative-adversarial-nets/')
from ganetwork import GAN, CGAN
from sklearn import datasets
from sklearn.datasets import fetch_mldata
from sklearn.preprocessing import minmax_scale, LabelBinarizer
import matplotlib.pyplot as plt
from random import choices
from math import sqrt
import numpy as np
from imblearn.datasets import make_imbalance

In [2]:
# Load MNIST from the cache under data_home=".".
# NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and removed in
# 0.22 (mldata.org is offline), so this only runs on an older scikit-learn with
# a cached copy. fetch_openml('mnist_784', version=1) is the replacement, but its
# targets are strings and its row order differs, so the filtering below would
# need adjusting.
mnist = fetch_mldata('MNIST original', data_home=".")

# Reduce to a binary problem: keep only the digits 1 and 7.
keep_indices = np.isin(mnist.target, (1, 7))
X = minmax_scale(mnist.data.astype(np.float32))[keep_indices]
y = mnist.target[keep_indices]

# Undersample to create an imbalanced dataset (presumably minority/majority ~= 0.01).
# NOTE(review): `ratio` is deprecated in imbalanced-learn >= 0.4 in favour of
# `sampling_strategy`.
X, y = make_imbalance(X, y, ratio=0.01, random_state=0)

# Encode labels for conditioning. With exactly two classes LabelBinarizer
# returns a single 0/1 column (digit 1 -> 0, digit 7 -> 1), not a one-hot matrix.
label_binarizer = LabelBinarizer()
y = label_binarizer.fit_transform(y)

In [3]:
# Layer specs appear to be (units, activation) pairs, with the first entry
# giving the input size (assumption based on usage; confirm against ganetwork).
n_features = X.shape[1]
n_conditions = y.shape[1]
noise_dim = 100  # presumably the latent noise size concatenated with the condition

discriminator_layers = [
    (n_features + n_conditions, None),
    (128, tf.nn.relu),
    (1, None),
]
# NOTE(review): X was min-max scaled to [0, 1], but the generator's output layer
# has no activation, so generated pixels are unbounded; consider tf.nn.sigmoid.
generator_layers = [
    (noise_dim + n_conditions, None),
    (128, tf.nn.relu),
    (n_features, None),
]
cgan = CGAN(discriminator_layers, generator_layers)

In [None]:
# Training hyper-parameters.
# NOTE(review): the saved output shows this run was interrupted around epoch 11,
# and the discriminator loss is already ~1e-9 at epoch 0. That may indicate the
# discriminator is overpowering the generator; investigate before a full run.
N_EPOCHS = 10000
BATCH_SIZE = 64
cgan.train(X, y, nb_epoch=N_EPOCHS, batch_size=BATCH_SIZE)

Epoch: 0, discriminator loss: 1.257534298915175e-09, generator loss: 0.22027164697647095
Epoch: 1, discriminator loss: 4.1022425028368445e-16, generator loss: 0.00012214250455144793
Epoch: 2, discriminator loss: 6.677991552621744e-23, generator loss: 0.06112115457653999
Epoch: 3, discriminator loss: 7.758989739151443e-24, generator loss: 3.9213037490844727
Epoch: 4, discriminator loss: 7.33682384588113e-21, generator loss: 9.206265449523926
Epoch: 5, discriminator loss: 1.8565339815058537e-21, generator loss: 12.966229438781738
Epoch: 6, discriminator loss: 1.967146002811058e-20, generator loss: 5.976703643798828
Epoch: 7, discriminator loss: 7.342251522785318e-16, generator loss: 2.1401638984680176
Epoch: 8, discriminator loss: 1.2421468179732377e-16, generator loss: 3.3467154502868652
Epoch: 9, discriminator loss: 2.932823119560987e-14, generator loss: 2.8317461013793945
Epoch: 10, discriminator loss: 1.1427316341447447e-12, generator loss: 0.9329402446746826
Epoch: 11, discriminator

In [None]:
# Draw samples from the trained conditional generator.
n_samples = 50
# NOTE(review): presumably the class condition fed to the generator. With the
# single-column binary encoding above, 1 would correspond to digit 7; confirm
# against ganetwork's generate_samples.
condition = 1
X_generated = cgan.generate_samples(n_samples, condition)

In [None]:
# Display the first few generated samples as a single row of images.
N_SHOW = 20  # maximum number of samples to display

# Assumes each sample is a flattened square image (e.g. 28x28 for MNIST).
img_dim = int(sqrt(X.shape[1]))
X_img = X_generated.reshape(X_generated.shape[0], img_dim, -1)

# Never index past the number of generated samples.
n_show = min(N_SHOW, X_img.shape[0])
assert n_show > 0, "no generated samples to plot"

# Set figsize per figure instead of mutating the global rcParams.
# squeeze=False keeps `axes` 2-D even when n_show == 1.
fig, axes = plt.subplots(1, n_show, figsize=(n_show, 1.5), squeeze=False)
for ax, img in zip(axes[0], X_img[:n_show]):
    ax.imshow(img, cmap='gray_r')
    ax.axis('off')
fig.suptitle('CGAN generated samples')
plt.show()