In [10]:
import keras
from keras.models import Sequential
import keras.layers as layers
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import GridSearchCV
from keras.datasets import mnist
from keras.utils.vis_utils import plot_model
import numpy as np

In [3]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape((x_train.shape[0],28,28,1))
x_test = x_test.reshape((x_test.shape[0],28,28,1))

In [62]:
def build_classifier(optim):
    classifier = Sequential()
    classifier.add(layers.Conv2D(32, (3, 3), input_shape = (28, 28, 1), activation='relu'))
    classifier.add(layers.MaxPooling2D(pool_size = (2,2)))
    classifier.add(layers.Conv2D(32, (3, 3), activation='relu'))
    classifier.add(layers.MaxPooling2D(pool_size = (2,2)))
    classifier.add(layers.Flatten())
    classifier.add(layers.Dense(units = 128, activation='relu'))
    classifier.add(layers.Dense(units = 10, activation = 'softmax'))
    classifier.compile(optimizer='adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])
    return classifier

In [64]:
classifier = KerasClassifier(build_fn = build_classifier)
parameters = {
    'batch_size' : [200],
    'epochs' : [10],
    'optim' : ['adam']
}
grid_search = GridSearchCV(estimator = classifier, 
                          param_grid = parameters,
                          scoring = 'accuracy',
                          cv = 10)
plot_model(build_classifier(parameters), 'fig.png', show_shapes=True, show_layer_names=True, rankdir = 'LR')
grid_search =  grid_search.fit(x_train, y_train)
best_params = grid_search.best_params_
best_acc = grid_search.best_scores_

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Epoch 1/10
Epoch 2/10

KeyboardInterrupt: 

In [65]:
build_classifier(parameters).evaluate(x_test, y_test)



[13.96299585723877, 0.0868]

In [None]:
N, H, W, D = x_train.shape
generator = Sequential()
generator.add(layers.Conv2D(32, (3, 3), input_shape = (H, W, D), activation='relu'))
generator.add(layers.MaxPooling2D(pool_size = (2,2)))
generator.add(layers.Conv2D(32, (3, 3), activation='relu'))
generator.add(layers.MaxPooling2D(pool_size = (2,2)))
generator.add(layers.Flatten())
generator.add(layers.Dense(units = 128, activation='relu'))
generator.add(layers.Dense(units = H * W * D, activation = 'relu'))
generator.compile(optimizer='adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])

In [None]:
discriminator = Sequential()
discriminator.add(layers.Conv2D(32, (3, 3), input_shape = (H, W, D), activation='relu'))
discriminator.add(layers.MaxPooling2D(pool_size = (2,2)))
discriminator.add(layers.Conv2D(32, (3, 3), activation='relu'))
discriminator.add(layers.MaxPooling2D(pool_size = (2,2)))
discriminator.add(layers.Flatten())
discriminator.add(layers.Dense(units = 128, activation='relu'))
discriminator.add(layers.Dense(units = 1, activation = 'sigmoid'))
def disc_loss(y_train, y_pred):
    
discriminator.compile(optimizer='adam', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])

In [50]:
epochs = 10
disc_epochs = 30
batch_size = 200
for i in range(epochs):
    for i in range(disc_epochs):
        noise = np.random.rand(batch_size * H * W * D).reshape(batch_size, H, W, D)
        target_data = np.random.choice(N, batch_size, replace=False)
        
        

In [42]:
x = np.random.randn(1000)
print(x.var())

0.9385792004643864


In [48]:
np.random.choice(x_train.shape[0], batch_size, replace=False)

array([35896, 58791, 35576, 58734, 23684, 42821, 13112, 30351,  8207,
       20264, 31491, 29591, 56112, 13268,   964, 48299, 10283, 50788,
       46817, 28239, 16088, 47194,  1385, 55406, 37001, 21511, 51026,
       53088, 37722, 42181,  2487, 24262, 52000, 18597, 48799, 53461,
       22899, 28270, 19284, 53861,  9668, 40636, 58529, 35506, 21091,
       32314, 46015, 45056, 28208, 16747,  1337, 13155, 57663, 10603,
          38, 28416, 38799, 37839, 53877, 41084, 18797,  1704, 26207,
       50557, 51299, 19403, 16524,  5671, 57389, 25144, 54666, 30844,
       53926, 25878, 39290, 38434, 24310, 57192, 58236,  5151, 56871,
       35708, 33984, 34208, 20784,  9674, 52314, 28302,  1177, 13826,
       15076, 39242, 12246, 50779, 51766, 32195, 16606, 25505,  5581,
       54703, 40044, 20171, 46964, 57562, 22605, 58526, 29748, 12101,
        3297, 42661,  4610, 11668,  3422, 19649,  5192, 17150, 59338,
       26613,  3050, 24890, 44131, 32479, 12355, 18875,  3444, 25591,
       14585, 41417,