In [1]:
import sys
import tensorflow as tf
sys.path.append('../../generative-adversarial-nets/')
from ganetwork import GAN, CGAN
from sklearn import datasets
from sklearn.datasets import fetch_mldata
from sklearn.preprocessing import minmax_scale, LabelBinarizer
import matplotlib.pyplot as plt
from random import choices
from math import sqrt
import numpy as np
from imblearn.datasets import make_imbalance

In [2]:
# Load MNIST (cached under the current directory) and build a binary 1-vs-7 problem.
mnist = fetch_mldata('MNIST original', data_home=".")
keep_indices = np.logical_or(mnist.target == 1, mnist.target == 7)

# Pixel columns are min-max scaled over the FULL dataset first, then the 1/7 rows are selected.
X = minmax_scale(mnist.data.astype(np.float32))[keep_indices]
y = mnist.target[keep_indices]

# Artificially imbalance the two classes. NOTE(review): ratio=0.01 is presumably the
# minority/majority size proportion after resampling — confirm for the installed imblearn version.
X, y = make_imbalance(X, y, ratio=0.01, random_state=0)

# One-hot style label encoding; with only two classes LabelBinarizer returns a single
# 0/1 column (sklearn behaviour), so y.shape[1] is expected to be 1 downstream.
y = LabelBinarizer().fit_transform(y)

In [3]:
# Conditional GAN architecture. Each layer spec appears to be (units, activation);
# the first entry of each list is the input width, the last the output width.
NOISE_DIM = 100     # generator input width excluding the label columns (presumably the latent noise size)
HIDDEN_UNITS = 150  # width of the single tanh hidden layer shared by both networks

n_features = X.shape[1]      # flattened image size
n_label_cols = y.shape[1]    # width of the conditioning label vector

# Discriminator sees an image concatenated with its label and emits one logit-like score.
discriminator_layers = [
    (n_features + n_label_cols, None),
    (HIDDEN_UNITS, tf.nn.tanh),
    (1, None),
]
# Generator sees noise concatenated with a label and emits a flattened image.
generator_layers = [
    (NOISE_DIM + n_label_cols, None),
    (HIDDEN_UNITS, tf.nn.tanh),
    (n_features, None),
]
cgan = CGAN(discriminator_layers, generator_layers)

In [None]:
# Train the conditional GAN on the imbalanced 1/7 data with its binarized labels.
# NOTE(review): in the recorded run below, the discriminator loss stays around 1e-3 while
# the generator loss is 5-14, i.e. the discriminator dominates early on — inspect the
# generated samples before trusting them.
cgan.train(
    X,
    y,
    nb_epoch=1000,
    batch_size=64,
)

Epoch: 0, discriminator loss: 0.0017298910534009337, generator loss: 14.272872924804688
Epoch: 1, discriminator loss: 0.0007518412894569337, generator loss: 14.251837730407715
Epoch: 2, discriminator loss: 0.0002260994806420058, generator loss: 13.241905212402344
Epoch: 3, discriminator loss: 0.000799640896730125, generator loss: 8.753179550170898
Epoch: 4, discriminator loss: 0.0005654217675328255, generator loss: 8.116411209106445
Epoch: 5, discriminator loss: 0.00056139484513551, generator loss: 8.895748138427734
Epoch: 6, discriminator loss: 0.0005146056064404547, generator loss: 7.125438213348389
Epoch: 7, discriminator loss: 0.0009390073246322572, generator loss: 7.236150741577148
Epoch: 8, discriminator loss: 0.0013366274069994688, generator loss: 7.750124931335449
Epoch: 9, discriminator loss: 0.0012763781705871224, generator loss: 6.274715423583984
Epoch: 10, discriminator loss: 0.0008186969207599759, generator loss: 5.032691955566406
Epoch: 11, discriminator loss: 0.000538144

In [None]:
# Draw 50 synthetic samples conditioned on label value 1. Since LabelBinarizer sorts the
# {1, 7} targets, value 1 presumably corresponds to digit 7 — TODO confirm how
# CGAN.generate_samples interprets its second argument.
X_generated = cgan.generate_samples(
    50,
    1,
)

In [None]:
# Show the first generated samples side by side as images.
n_show = min(20, X_generated.shape[0])  # never index past the number of generated samples
assert n_show > 0, "no generated samples to plot"

img_dim = int(sqrt(X.shape[1]))  # rows of X are assumed to be flattened square images
X_img = X_generated.reshape(X_generated.shape[0], img_dim, -1)

# figsize is passed locally instead of mutating the global rcParams (which would affect
# every later figure); squeeze=False keeps `axes` 2-D even when n_show == 1.
fig, axes = plt.subplots(1, n_show, figsize=(20, 2), squeeze=False)
for i, ax in enumerate(axes[0]):
    ax.imshow(X_img[i], cmap='gray_r')
    ax.axis('off')
fig.suptitle('CGAN-generated samples')
plt.show()