# Semi-Supervised Learning using GAN as the Generative Model

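This demo is adapted from the Machine Learning Mastery tutorial on semi-supervised GANs:
https://machinelearningmastery.com/semi-supervised-generative-adversarial-network/

**Summary:** only 100 labeled MNIST images (10 per class) are used. In the run recorded below, the classifier head trained jointly with a GAN reaches about 93% test accuracy. The same network trained on the 100 labeled images alone reaches about 75% (see the two evaluation cells at the end).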

In [1]:
# example of semi-supervised gan for mnist
from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy import asarray
from numpy.random import randn
from numpy.random import randint
from tensorflow.keras.datasets.mnist import load_data
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Lambda
from tensorflow.keras.layers import Activation
from matplotlib import pyplot
from tensorflow.keras import backend

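### Define a custom activation function for the unsupervised discriminator

Given the $K$ class logits $l_k(x)$ produced by the shared network, the unsupervised head outputs

$$D(x) = \frac{Z(x)}{Z(x) + 1}, \qquad Z(x) = \sum_{k=1}^{K} \exp\big(l_k(x)\big),$$

so a large logit for any class pushes $D(x)$ towards 1 ("real").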

In [2]:
# custom activation function
def custom_activation(output):
    """Collapse per-class logits into a single "real" probability.

    Computes Z / (Z + 1) with Z = sum_k exp(output_k) over the last axis.
    Because Z / (Z + 1) == sigmoid(log Z), it is evaluated as
    sigmoid(logsumexp(output)). The direct form overflows exp() for large
    logits and then yields inf / inf = NaN; this form stays finite.

    Args:
        output: logits tensor; the reduction runs over the last axis.

    Returns:
        Tensor with the last axis reduced to size 1 (keepdims=True), values in [0, 1].
    """
    # log Z computed without materialising exp(output)
    log_z = backend.logsumexp(output, axis=-1, keepdims=True)
    return backend.sigmoid(log_z)

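### Define the discriminator
The discriminator has two output heads that share every layer up to the class logits:
- `c_model` is the supervised discriminator, i.e. the digit classifier (softmax over the 10 classes).
- `d_model` is the unsupervised discriminator, which scores an image as real or fake using the custom activation defined above.

Because the layers are shared, an update to either model also changes the other.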

In [3]:
# define the standalone supervised and unsupervised discriminator models
def define_discriminator(in_shape=(28,28,1), n_classes=10):
    """Build the supervised and unsupervised discriminator heads of the SGAN.

    Both returned models are built on the same layer objects: three conv
    blocks, Flatten, Dropout and the Dense logit layer. Training either one
    therefore updates the weights the other uses.

    Args:
        in_shape: shape of a single input image (height, width, channels).
        n_classes: number of supervised classes, i.e. the width of the logit layer.

    Returns:
        (d_model, c_model):
            d_model -- unsupervised real/fake model, custom_activation(logits),
                       compiled with binary cross-entropy.
            c_model -- supervised classifier, softmax(logits), compiled with
                       sparse categorical cross-entropy (integer labels) and an
                       accuracy metric.
    """
    # image input
    in_image = Input(shape=in_shape)
    # three stride-2 conv blocks: 28x28 -> 14x14 -> 7x7 -> 4x4 for the default input
    fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(in_image)
    fe = LeakyReLU(alpha=0.2)(fe)
    fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(fe)
    fe = LeakyReLU(alpha=0.2)(fe)
    fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(fe)
    fe = LeakyReLU(alpha=0.2)(fe)
    # flatten feature maps and regularise with dropout
    fe = Flatten()(fe)
    fe = Dropout(0.4)(fe)
    # shared class logits (no activation); both heads read from this tensor
    fe = Dense(n_classes)(fe)
    # supervised output: class probabilities
    c_out_layer = Activation('softmax')(fe)
    c_model = Model(in_image, c_out_layer)
    # `learning_rate` is the supported name; `lr` is a deprecated alias
    c_model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])
    # unsupervised output: a single real/fake probability per image
    d_out_layer = Lambda(custom_activation)(fe)
    d_model = Model(in_image, d_out_layer)
    d_model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return d_model, c_model

### Define the Generator

In [4]:
# define the standalone generator model
def define_generator(latent_dim):
    """Build the generator: latent vector -> 28x28x1 image with tanh output.

    Args:
        latent_dim: length of the input noise vector.

    Returns:
        An uncompiled Keras Model; it is only trained through the combined GAN.
    """
    noise = Input(shape=(latent_dim,))
    # project the noise onto 128 feature maps of size 7x7
    x = Dense(128 * 7 * 7)(noise)
    x = LeakyReLU(alpha=0.2)(x)
    x = Reshape((7, 7, 128))(x)
    # two transposed convolutions each double the spatial size: 7 -> 14 -> 28
    for _ in range(2):
        x = Conv2DTranspose(128, (4,4), strides=(2,2), padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
    # single-channel image; tanh keeps pixels in [-1, 1] like the scaled real data
    image = Conv2D(1, (7,7), activation='tanh', padding='same')(x)
    return Model(noise, image)

### Define the GAN

In [5]:
# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model):
    """Stack the generator and the unsupervised discriminator for generator updates.

    The combined model maps latent vectors -> g_model -> d_model and is compiled
    with binary cross-entropy. train() fits it with target 1 ("real"), so its
    gradients push the generator towards images that d_model scores as real.

    Args:
        g_model: generator model (latent vector -> image).
        d_model: unsupervised discriminator from define_discriminator.

    Returns:
        The compiled combined model.

    Side effects:
        Sets d_model.trainable = False, so the discriminator's weights are not
        updated through this combined model.
        NOTE(review): d_model (and c_model, which shares its layers) were compiled
        before this flag is set. The usual Keras 2 GAN pattern relies on those
        earlier compiles still updating the shared weights when they are trained
        directly; confirm this with the Keras version in use.
    """
    d_model.trainable = False
    # connect image output from generator as input to discriminator
    gan_output = d_model(g_model.output)
    model = Model(g_model.input, gan_output)
    # `learning_rate` is the supported name; `lr` is a deprecated alias
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

### Define helper functions (data loading, batch sampling, checkpointing)

In [6]:
# load the images
def load_real_samples():
    """Load the MNIST training split, scaled to match the tanh generator.

    Returns:
        [X, trainy]:
            X -- float32 images with a trailing channel axis added and pixels
                 scaled from [0, 255] to [-1, 1].
            trainy -- labels exactly as returned by load_data().
    """
    (images, labels), (_, _) = load_data()
    pixels = expand_dims(images, axis=-1).astype('float32')
    scaled = (pixels - 127.5) / 127.5
    print(scaled.shape, labels.shape)
    return [scaled, labels]

In [7]:
# select a supervised subset of the dataset, ensures classes are balanced
def select_supervised_samples(dataset, n_samples=100, n_classes=10):
    """Draw a class-balanced labeled subset for the supervised head.

    Picks n_samples // n_classes examples of each label 0..n_classes-1.
    Examples are drawn without replacement, so the subset holds distinct
    images. A class with fewer examples than requested falls back to sampling
    with replacement.

    Args:
        dataset: (X, y) pair of arrays; y holds integer class labels aligned with X.
        n_samples: total subset size (rounded down to a multiple of n_classes).
        n_classes: labels 0..n_classes-1 are sampled.

    Returns:
        (X_sub, y_sub) numpy arrays, grouped by class in ascending label order.

    Raises:
        ValueError: from numpy, if a class has no examples to draw from.
    """
    from numpy.random import choice

    X, y = dataset
    X_list, y_list = [], []
    n_per_class = n_samples // n_classes
    for label in range(n_classes):
        # get all images for this class
        X_with_class = X[y == label]
        n_available = len(X_with_class)
        # without replacement, so duplicates cannot shrink the labeled set
        ix = choice(n_available, n_per_class, replace=n_available < n_per_class)
        X_list.extend(X_with_class[ix])
        y_list.extend([label] * n_per_class)
    return asarray(X_list), asarray(y_list)

In [8]:
# select real samples
def generate_real_samples(dataset, n_samples):
    """Sample a random batch of real images, with replacement.

    Args:
        dataset: (images, labels) pair of aligned arrays.
        n_samples: batch size.

    Returns:
        ([X, labels], y): the sampled images and their class labels, plus an
        (n_samples, 1) array of ones marking them as "real" for the discriminator.
    """
    images, labels = dataset
    picks = randint(0, images.shape[0], n_samples)
    real_targets = ones((n_samples, 1))
    return [images[picks], labels[picks]], real_targets

In [9]:
# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
    """Draw a batch of standard-normal latent vectors.

    Args:
        latent_dim: length of each latent vector.
        n_samples: number of vectors.

    Returns:
        float array of shape (n_samples, latent_dim).
    """
    flat = randn(latent_dim * n_samples)
    return flat.reshape((n_samples, latent_dim))

In [10]:
# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
    """Produce a batch of generated images labeled "fake".

    Args:
        generator: model exposing predict(); it is fed latent vectors from
            generate_latent_points(latent_dim, n_samples).
        latent_dim: latent vector length.
        n_samples: batch size.

    Returns:
        (images, y): the generator.predict output and an (n_samples, 1) array of zeros.
    """
    fake_targets = zeros((n_samples, 1))
    images = generator.predict(generate_latent_points(latent_dim, n_samples))
    return images, fake_targets

In [11]:
# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, c_model, latent_dim, dataset, n_samples=100):
    """Checkpoint the SGAN: plot generated samples, report accuracy, save both models.

    Args:
        step: zero-based training step; output files are tagged with step + 1.
        g_model: generator, sampled via generate_fake_samples.
        c_model: supervised classifier; evaluated on `dataset` and saved.
        latent_dim: length of the generator's latent input.
        dataset: (X, y) pair to evaluate c_model on. train() passes the full
            training split, so the printed figure is training-set accuracy.
        n_samples: number of generated images to plot. They are laid out on a
            square grid with ceil(sqrt(n_samples)) cells per side (10x10 for
            the default 100).

    Side effects:
        Writes generated_plot_XXXX.png, g_model_XXXX.h5 and c_model_XXXX.h5 to
        the working directory and prints the accuracy and the saved file names.
    """
    import math

    # prepare fake examples
    X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
    # scale from [-1,1] to [0,1]
    X = (X + 1) / 2.0
    # size the grid from n_samples instead of assuming exactly 100 images
    side = max(1, math.ceil(math.sqrt(n_samples)))
    # draw on a dedicated figure, not whatever pyplot's current figure happens to be
    fig = pyplot.figure()
    for i in range(n_samples):
        ax = fig.add_subplot(side, side, 1 + i)
        ax.axis('off')
        ax.imshow(X[i, :, :, 0], cmap='gray_r')
    filename1 = 'generated_plot_%04d.png' % (step+1)
    fig.savefig(filename1)
    # close exactly this figure so repeated checkpoints do not accumulate open figures
    pyplot.close(fig)
    # evaluate the classifier model
    X, y = dataset
    _, acc = c_model.evaluate(X, y, verbose=0)
    print('Classifier Accuracy: %.3f%%' % (acc * 100))
    # save the generator model
    filename2 = 'g_model_%04d.h5' % (step+1)
    g_model.save(filename2)
    # save the classifier model
    filename3 = 'c_model_%04d.h5' % (step+1)
    c_model.save(filename3)
    print('>Saved: %s, %s, and %s' % (filename1, filename2, filename3))

### Train the semi-supervised GAN

In [12]:
# train the generator and discriminator
def train(g_model, d_model, c_model, gan_model, dataset, latent_dim, n_epochs=4, n_batch=100):
    """Run the semi-supervised GAN training loop.

    There are int(len(dataset[0]) / n_batch) steps per epoch, times n_epochs.
    Every step:
      1. updates c_model on half a batch drawn from a fixed labeled subset,
         chosen once up front by select_supervised_samples;
      2. updates d_model on half a batch of real images with target 1;
      3. updates d_model on half a batch of generated images with target 0;
      4. updates the generator through gan_model on a full batch of latent
         points with target 1.
    Losses are printed every step, and summarize_performance runs at the end
    of every epoch.

    Args:
        g_model, d_model, c_model, gan_model: models from define_generator,
            define_discriminator and define_gan.
        dataset: (images, labels) pair, e.g. from load_real_samples().
        latent_dim: length of the generator's latent input.
        n_epochs: number of epochs.
        n_batch: full batch size; the discriminator updates use half of it.
    """
    X_sup, y_sup = select_supervised_samples(dataset)
    print(X_sup.shape, y_sup.shape)
    labeled = [X_sup, y_sup]
    bat_per_epo = int(dataset[0].shape[0] / n_batch)
    n_steps = bat_per_epo * n_epochs
    half_batch = int(n_batch / 2)
    print('n_epochs=%d, n_batch=%d, 1/2=%d, b/e=%d, steps=%d' % (n_epochs, n_batch, half_batch, bat_per_epo, n_steps))
    for step in range(n_steps):
        step_no = step + 1
        # 1) supervised update on labeled examples only
        (X_lab, y_lab), _ = generate_real_samples(labeled, half_batch)
        c_loss, c_acc = c_model.train_on_batch(X_lab, y_lab)
        # 2) and 3) unsupervised real/fake updates, real half first
        (X_real, _), y_real = generate_real_samples(dataset, half_batch)
        d_loss1 = d_model.train_on_batch(X_real, y_real)
        X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
        d_loss2 = d_model.train_on_batch(X_fake, y_fake)
        # 4) generator update: generated images are labeled "real" to fool d_model
        latent_batch = generate_latent_points(latent_dim, n_batch)
        g_loss = gan_model.train_on_batch(latent_batch, ones((n_batch, 1)))
        print('>%d, c[%.3f,%.0f], d[%.3f,%.3f], g[%.3f]' % (step_no, c_loss, c_acc*100, d_loss1, d_loss2, g_loss))
        # checkpoint once per epoch
        if step_no % bat_per_epo == 0:
            summarize_performance(step, g_model, c_model, latent_dim, dataset)

In [13]:
# Seed the random number generators first so a fresh "Restart & Run All" is
# repeatable. NumPy drives the batch and latent sampling (randint, randn);
# TensorFlow drives Keras-side randomness such as weight initialisation and
# Dropout. (Some GPU kernels may still be nondeterministic.)
from numpy.random import seed as numpy_seed
import tensorflow as tf

SEED = 1
numpy_seed(SEED)
tf.random.set_seed(SEED)

# size of the latent space
latent_dim = 100
# create the discriminator models (they share every layer up to the logits)
d_model, c_model = define_discriminator()
# create the generator
g_model = define_generator(latent_dim)
# create the gan (marks d_model as non-trainable inside the combined model)
gan_model = define_gan(g_model, d_model)
# load image data, scaled to [-1, 1]
dataset = load_real_samples()
# train model; writes generated_plot / g_model / c_model checkpoints once per epoch
train(g_model, d_model, c_model, gan_model, dataset, latent_dim)

(60000, 28, 28, 1) (60000,)
(100, 28, 28, 1) (100,)
n_epochs=4, n_batch=100, 1/2=50, b/e=600, steps=2400
>1, c[2.304,8], d[0.097,2.399], g[0.095]
>2, c[2.301,8], d[0.092,2.398], g[0.095]
>3, c[2.296,18], d[0.088,2.398], g[0.095]
>4, c[2.271,18], d[0.085,2.400], g[0.095]
>5, c[2.258,22], d[0.081,2.402], g[0.095]
>6, c[2.283,14], d[0.080,2.402], g[0.096]
>7, c[2.240,28], d[0.080,2.396], g[0.097]
>8, c[2.249,22], d[0.082,2.385], g[0.098]
>9, c[2.251,20], d[0.083,2.373], g[0.100]
>10, c[2.224,34], d[0.087,2.368], g[0.100]
>11, c[2.215,18], d[0.086,2.363], g[0.100]
>12, c[2.187,28], d[0.082,2.358], g[0.101]
>13, c[2.151,36], d[0.072,2.350], g[0.102]
>14, c[2.143,30], d[0.062,2.341], g[0.103]
>15, c[2.123,30], d[0.053,2.324], g[0.106]
>16, c[2.076,36], d[0.042,2.298], g[0.110]
>17, c[2.097,32], d[0.033,2.264], g[0.115]
>18, c[2.067,32], d[0.025,2.219], g[0.125]
>19, c[2.009,38], d[0.022,2.140], g[0.141]
>20, c[1.940,46], d[0.026,2.004], g[0.174]
>21, c[1.848,42], d[0.047,1.825], g[0.214]
>22

>188, c[0.121,98], d[0.902,0.915], g[1.494]
>189, c[0.139,98], d[0.760,0.867], g[1.461]
>190, c[0.133,98], d[0.940,0.886], g[1.178]
>191, c[0.161,100], d[0.663,0.781], g[1.400]
>192, c[0.135,96], d[0.881,0.704], g[1.278]
>193, c[0.143,98], d[0.827,0.892], g[1.274]
>194, c[0.158,100], d[0.712,0.810], g[1.370]
>195, c[0.135,100], d[0.744,1.039], g[1.459]
>196, c[0.111,98], d[0.882,0.804], g[1.440]
>197, c[0.094,98], d[0.880,1.164], g[1.218]
>198, c[0.176,96], d[0.709,0.882], g[1.493]
>199, c[0.109,100], d[1.078,1.025], g[1.422]
>200, c[0.142,96], d[1.170,1.089], g[1.206]
>201, c[0.142,96], d[0.981,0.995], g[1.198]
>202, c[0.154,96], d[0.965,1.046], g[1.125]
>203, c[0.143,98], d[0.913,0.780], g[1.297]
>204, c[0.080,100], d[0.956,0.736], g[1.124]
>205, c[0.124,100], d[0.832,0.807], g[0.972]
>206, c[0.094,98], d[0.589,0.693], g[1.293]
>207, c[0.150,98], d[0.794,0.706], g[1.369]
>208, c[0.116,100], d[0.840,0.955], g[1.263]
>209, c[0.111,100], d[0.907,0.947], g[1.381]
>210, c[0.168,90], d[0.8

>372, c[0.099,98], d[0.604,0.895], g[1.094]
>373, c[0.057,100], d[0.646,0.930], g[1.400]
>374, c[0.041,100], d[0.707,0.608], g[1.167]
>375, c[0.045,100], d[0.827,0.630], g[1.101]
>376, c[0.044,100], d[0.755,0.703], g[1.014]
>377, c[0.043,100], d[0.584,0.814], g[1.206]
>378, c[0.071,100], d[0.897,0.848], g[1.207]
>379, c[0.031,100], d[0.940,1.053], g[1.099]
>380, c[0.058,100], d[0.634,0.818], g[1.183]
>381, c[0.050,100], d[0.681,0.830], g[1.625]
>382, c[0.047,100], d[0.742,0.538], g[1.447]
>383, c[0.045,100], d[0.630,0.692], g[1.464]
>384, c[0.034,100], d[1.019,0.660], g[1.152]
>385, c[0.064,100], d[0.653,0.926], g[1.375]
>386, c[0.047,100], d[0.679,0.689], g[1.486]
>387, c[0.066,98], d[1.058,0.841], g[1.350]
>388, c[0.046,100], d[1.008,0.857], g[1.154]
>389, c[0.064,100], d[0.602,0.733], g[1.232]
>390, c[0.059,100], d[0.806,0.910], g[1.321]
>391, c[0.052,100], d[0.827,0.739], g[1.257]
>392, c[0.062,100], d[0.744,0.742], g[1.131]
>393, c[0.061,100], d[0.754,1.011], g[1.270]
>394, c[0.04

>557, c[0.028,100], d[0.752,0.808], g[1.161]
>558, c[0.017,100], d[0.795,0.882], g[1.280]
>559, c[0.037,100], d[0.632,0.739], g[1.180]
>560, c[0.028,100], d[0.831,0.930], g[1.258]
>561, c[0.034,100], d[0.937,0.880], g[1.162]
>562, c[0.054,100], d[0.589,0.745], g[1.279]
>563, c[0.030,100], d[0.667,0.750], g[1.433]
>564, c[0.019,100], d[0.832,0.589], g[1.339]
>565, c[0.054,100], d[0.759,0.521], g[1.416]
>566, c[0.034,100], d[0.492,0.587], g[1.241]
>567, c[0.019,100], d[0.653,0.593], g[1.188]
>568, c[0.034,100], d[0.751,0.652], g[1.178]
>569, c[0.040,100], d[0.837,0.756], g[1.240]
>570, c[0.024,100], d[0.580,0.632], g[1.266]
>571, c[0.022,100], d[0.988,0.816], g[1.117]
>572, c[0.051,98], d[0.735,0.925], g[1.106]
>573, c[0.031,100], d[0.682,0.832], g[1.112]
>574, c[0.041,100], d[0.570,0.720], g[1.162]
>575, c[0.032,100], d[0.731,0.835], g[1.302]
>576, c[0.031,100], d[0.845,0.664], g[1.183]
>577, c[0.026,100], d[0.744,0.883], g[1.408]
>578, c[0.029,100], d[0.655,0.848], g[1.350]
>579, c[0.0

>738, c[0.017,100], d[0.716,0.667], g[1.162]
>739, c[0.046,100], d[0.631,0.682], g[1.217]
>740, c[0.027,100], d[0.669,0.585], g[1.234]
>741, c[0.020,100], d[0.667,0.721], g[1.324]
>742, c[0.019,100], d[0.776,0.600], g[1.129]
>743, c[0.031,100], d[0.608,0.836], g[1.088]
>744, c[0.027,100], d[0.739,0.598], g[1.291]
>745, c[0.018,100], d[0.565,0.704], g[1.292]
>746, c[0.024,100], d[0.636,0.654], g[1.444]
>747, c[0.016,100], d[0.762,0.539], g[1.292]
>748, c[0.017,100], d[0.659,0.652], g[1.291]
>749, c[0.017,100], d[0.718,0.666], g[1.238]
>750, c[0.042,100], d[0.679,0.588], g[1.130]
>751, c[0.046,100], d[0.794,0.715], g[1.115]
>752, c[0.018,100], d[0.465,0.537], g[1.285]
>753, c[0.011,100], d[0.600,0.653], g[1.290]
>754, c[0.035,100], d[0.653,0.727], g[1.366]
>755, c[0.010,100], d[0.690,0.555], g[1.386]
>756, c[0.010,100], d[0.544,0.525], g[1.390]
>757, c[0.024,100], d[0.815,0.756], g[1.175]
>758, c[0.016,100], d[0.690,0.566], g[1.163]
>759, c[0.026,100], d[0.507,0.731], g[1.277]
>760, c[0.

>922, c[0.012,100], d[0.776,0.839], g[1.308]
>923, c[0.014,100], d[0.634,0.771], g[1.273]
>924, c[0.017,100], d[0.629,0.570], g[1.264]
>925, c[0.009,100], d[0.755,0.853], g[1.366]
>926, c[0.016,100], d[0.588,0.848], g[1.228]
>927, c[0.022,100], d[0.679,0.512], g[1.390]
>928, c[0.021,100], d[0.700,0.608], g[1.509]
>929, c[0.008,100], d[0.734,0.802], g[1.138]
>930, c[0.010,100], d[0.713,0.773], g[1.228]
>931, c[0.020,100], d[0.858,0.649], g[1.257]
>932, c[0.031,100], d[0.711,0.740], g[1.391]
>933, c[0.022,100], d[0.632,0.763], g[1.401]
>934, c[0.010,100], d[0.681,0.594], g[1.370]
>935, c[0.016,100], d[0.623,0.793], g[1.400]
>936, c[0.014,100], d[0.592,0.627], g[1.365]
>937, c[0.019,100], d[0.609,0.620], g[1.463]
>938, c[0.014,100], d[0.852,0.688], g[1.253]
>939, c[0.021,100], d[0.917,0.643], g[1.159]
>940, c[0.014,100], d[0.535,1.024], g[1.443]
>941, c[0.016,100], d[0.713,0.570], g[1.288]
>942, c[0.012,100], d[0.676,0.761], g[1.232]
>943, c[0.015,100], d[0.615,0.611], g[1.339]
>944, c[0.

>1104, c[0.015,100], d[0.616,0.627], g[1.236]
>1105, c[0.011,100], d[0.612,0.805], g[1.514]
>1106, c[0.010,100], d[0.917,0.638], g[1.388]
>1107, c[0.049,98], d[0.763,0.576], g[1.206]
>1108, c[0.024,100], d[0.564,0.649], g[1.243]
>1109, c[0.009,100], d[0.671,0.660], g[1.148]
>1110, c[0.023,100], d[0.437,0.903], g[1.315]
>1111, c[0.011,100], d[0.733,0.776], g[1.428]
>1112, c[0.008,100], d[0.781,0.897], g[1.350]
>1113, c[0.016,100], d[1.007,0.720], g[1.281]
>1114, c[0.020,100], d[0.504,0.735], g[1.361]
>1115, c[0.015,100], d[0.736,0.522], g[1.134]
>1116, c[0.019,100], d[0.546,0.598], g[1.133]
>1117, c[0.022,100], d[0.766,0.760], g[1.277]
>1118, c[0.019,100], d[0.577,0.734], g[1.314]
>1119, c[0.009,100], d[1.053,0.837], g[1.190]
>1120, c[0.013,100], d[0.627,0.636], g[1.262]
>1121, c[0.021,100], d[0.696,0.766], g[1.514]
>1122, c[0.019,100], d[0.692,0.565], g[1.373]
>1123, c[0.013,100], d[0.718,0.571], g[1.206]
>1124, c[0.027,100], d[0.703,0.628], g[1.133]
>1125, c[0.009,100], d[0.703,0.649]

>1283, c[0.007,100], d[0.554,0.767], g[1.580]
>1284, c[0.013,100], d[0.682,0.646], g[1.633]
>1285, c[0.005,100], d[0.875,0.601], g[1.560]
>1286, c[0.010,100], d[0.612,0.683], g[1.408]
>1287, c[0.015,100], d[0.667,0.621], g[1.359]
>1288, c[0.012,100], d[0.910,0.664], g[1.399]
>1289, c[0.015,100], d[0.670,0.771], g[1.325]
>1290, c[0.011,100], d[0.634,0.594], g[1.250]
>1291, c[0.011,100], d[0.617,0.694], g[1.317]
>1292, c[0.017,100], d[0.847,0.767], g[1.274]
>1293, c[0.013,100], d[0.471,0.583], g[1.378]
>1294, c[0.015,100], d[0.614,0.631], g[1.315]
>1295, c[0.013,100], d[0.637,0.598], g[1.308]
>1296, c[0.014,100], d[0.759,0.850], g[1.363]
>1297, c[0.012,100], d[0.823,0.581], g[1.315]
>1298, c[0.008,100], d[0.643,0.568], g[1.272]
>1299, c[0.012,100], d[0.655,0.830], g[1.597]
>1300, c[0.016,100], d[0.611,0.836], g[1.492]
>1301, c[0.009,100], d[0.692,0.573], g[1.381]
>1302, c[0.012,100], d[0.703,0.693], g[1.448]
>1303, c[0.013,100], d[0.471,0.586], g[1.339]
>1304, c[0.010,100], d[0.688,0.443

>1464, c[0.006,100], d[0.796,0.666], g[1.417]
>1465, c[0.007,100], d[0.693,0.699], g[1.411]
>1466, c[0.008,100], d[0.751,0.744], g[1.352]
>1467, c[0.009,100], d[0.745,0.823], g[1.512]
>1468, c[0.009,100], d[0.998,0.642], g[1.305]
>1469, c[0.009,100], d[0.641,0.710], g[1.218]
>1470, c[0.010,100], d[0.867,0.777], g[1.236]
>1471, c[0.018,100], d[0.728,0.745], g[1.407]
>1472, c[0.021,100], d[0.731,0.696], g[1.193]
>1473, c[0.013,100], d[0.533,0.751], g[1.146]
>1474, c[0.027,100], d[0.675,0.761], g[1.182]
>1475, c[0.012,100], d[0.488,0.823], g[1.331]
>1476, c[0.008,100], d[0.586,0.856], g[1.379]
>1477, c[0.007,100], d[0.784,0.596], g[1.399]
>1478, c[0.014,100], d[0.669,0.660], g[1.387]
>1479, c[0.016,100], d[0.734,0.528], g[1.390]
>1480, c[0.005,100], d[0.580,0.698], g[1.295]
>1481, c[0.010,100], d[0.556,0.495], g[1.145]
>1482, c[0.007,100], d[0.632,0.743], g[1.368]
>1483, c[0.007,100], d[0.857,0.896], g[1.383]
>1484, c[0.007,100], d[0.649,0.785], g[1.365]
>1485, c[0.009,100], d[0.676,0.768

>1644, c[0.004,100], d[0.653,0.553], g[1.335]
>1645, c[0.006,100], d[0.737,0.610], g[1.204]
>1646, c[0.012,100], d[0.599,0.593], g[1.381]
>1647, c[0.003,100], d[0.765,0.793], g[1.219]
>1648, c[0.007,100], d[0.569,0.649], g[1.295]
>1649, c[0.005,100], d[0.683,0.605], g[1.139]
>1650, c[0.004,100], d[0.438,0.682], g[1.490]
>1651, c[0.014,100], d[0.620,0.595], g[1.500]
>1652, c[0.007,100], d[0.615,0.644], g[1.182]
>1653, c[0.009,100], d[0.829,0.760], g[1.289]
>1654, c[0.010,100], d[0.540,0.523], g[1.328]
>1655, c[0.018,100], d[0.729,0.653], g[1.424]
>1656, c[0.006,100], d[0.772,0.640], g[1.367]
>1657, c[0.004,100], d[0.764,0.701], g[1.476]
>1658, c[0.005,100], d[0.604,0.586], g[1.532]
>1659, c[0.007,100], d[0.794,0.650], g[1.417]
>1660, c[0.012,100], d[0.586,0.496], g[1.286]
>1661, c[0.007,100], d[0.682,0.717], g[1.175]
>1662, c[0.012,100], d[0.641,0.622], g[1.224]
>1663, c[0.004,100], d[0.483,0.650], g[1.220]
>1664, c[0.008,100], d[0.737,0.622], g[1.327]
>1665, c[0.010,100], d[0.551,0.784

>1822, c[0.003,100], d[0.749,0.630], g[1.435]
>1823, c[0.004,100], d[0.686,0.638], g[1.282]
>1824, c[0.003,100], d[0.713,0.419], g[1.157]
>1825, c[0.005,100], d[0.648,0.733], g[1.254]
>1826, c[0.003,100], d[0.720,0.770], g[1.183]
>1827, c[0.004,100], d[0.595,0.798], g[1.273]
>1828, c[0.007,100], d[0.565,0.586], g[1.280]
>1829, c[0.004,100], d[0.760,0.750], g[1.352]
>1830, c[0.005,100], d[0.708,0.876], g[1.267]
>1831, c[0.007,100], d[0.743,0.565], g[1.410]
>1832, c[0.005,100], d[0.481,0.873], g[1.412]
>1833, c[0.006,100], d[0.937,0.795], g[1.244]
>1834, c[0.008,100], d[0.579,0.671], g[1.432]
>1835, c[0.009,100], d[0.848,0.715], g[1.292]
>1836, c[0.006,100], d[0.725,0.638], g[1.450]
>1837, c[0.006,100], d[0.828,0.992], g[1.213]
>1838, c[0.005,100], d[0.595,0.791], g[1.297]
>1839, c[0.012,100], d[0.786,0.558], g[1.348]
>1840, c[0.013,100], d[0.641,0.532], g[1.331]
>1841, c[0.008,100], d[0.659,0.556], g[1.267]
>1842, c[0.004,100], d[0.637,0.639], g[1.288]
>1843, c[0.006,100], d[0.569,0.549

>2001, c[0.003,100], d[0.577,0.608], g[1.312]
>2002, c[0.003,100], d[0.652,0.561], g[1.168]
>2003, c[0.006,100], d[0.663,0.526], g[1.148]
>2004, c[0.007,100], d[0.610,0.834], g[1.283]
>2005, c[0.003,100], d[0.534,0.488], g[1.280]
>2006, c[0.004,100], d[0.520,0.731], g[1.253]
>2007, c[0.007,100], d[0.606,0.648], g[1.174]
>2008, c[0.006,100], d[0.682,0.805], g[1.232]
>2009, c[0.006,100], d[0.623,0.710], g[1.483]
>2010, c[0.007,100], d[0.829,0.615], g[1.226]
>2011, c[0.003,100], d[0.533,0.609], g[1.232]
>2012, c[0.004,100], d[0.884,0.710], g[1.125]
>2013, c[0.003,100], d[0.457,0.558], g[1.317]
>2014, c[0.007,100], d[0.546,0.749], g[1.351]
>2015, c[0.004,100], d[0.868,0.797], g[1.282]
>2016, c[0.004,100], d[0.825,0.738], g[1.410]
>2017, c[0.004,100], d[0.795,0.737], g[1.293]
>2018, c[0.006,100], d[0.523,0.496], g[1.239]
>2019, c[0.009,100], d[0.630,0.637], g[1.297]
>2020, c[0.004,100], d[0.766,0.617], g[1.213]
>2021, c[0.010,100], d[0.673,0.895], g[1.238]
>2022, c[0.005,100], d[0.760,0.792

>2181, c[0.003,100], d[0.645,0.667], g[1.368]
>2182, c[0.009,100], d[0.697,0.683], g[1.254]
>2183, c[0.005,100], d[0.701,0.588], g[1.226]
>2184, c[0.003,100], d[0.591,0.680], g[1.216]
>2185, c[0.003,100], d[0.663,0.701], g[1.216]
>2186, c[0.003,100], d[0.545,0.711], g[1.195]
>2187, c[0.005,100], d[0.704,0.636], g[1.212]
>2188, c[0.010,100], d[0.691,0.594], g[1.239]
>2189, c[0.004,100], d[0.678,0.574], g[1.236]
>2190, c[0.003,100], d[0.479,0.667], g[1.152]
>2191, c[0.008,100], d[0.697,0.621], g[1.083]
>2192, c[0.008,100], d[0.614,0.844], g[1.318]
>2193, c[0.007,100], d[0.464,0.548], g[1.266]
>2194, c[0.008,100], d[0.678,0.698], g[1.120]
>2195, c[0.003,100], d[0.655,0.629], g[1.192]
>2196, c[0.004,100], d[0.521,0.698], g[1.315]
>2197, c[0.003,100], d[0.692,0.629], g[1.134]
>2198, c[0.007,100], d[0.744,0.674], g[1.391]
>2199, c[0.003,100], d[0.655,0.725], g[1.412]
>2200, c[0.005,100], d[0.478,0.546], g[1.443]
>2201, c[0.006,100], d[0.649,0.602], g[1.377]
>2202, c[0.003,100], d[0.750,0.451

>2362, c[0.012,100], d[0.590,0.757], g[1.158]
>2363, c[0.004,100], d[0.520,0.852], g[1.339]
>2364, c[0.005,100], d[0.702,0.770], g[1.266]
>2365, c[0.004,100], d[0.774,0.726], g[1.399]
>2366, c[0.004,100], d[0.785,0.651], g[1.226]
>2367, c[0.003,100], d[0.563,0.774], g[1.194]
>2368, c[0.003,100], d[0.669,0.778], g[1.446]
>2369, c[0.002,100], d[0.611,0.739], g[1.408]
>2370, c[0.006,100], d[0.863,0.554], g[1.212]
>2371, c[0.005,100], d[0.580,0.580], g[1.301]
>2372, c[0.005,100], d[0.714,0.728], g[1.382]
>2373, c[0.006,100], d[0.662,0.824], g[1.156]
>2374, c[0.006,100], d[0.547,0.751], g[1.264]
>2375, c[0.003,100], d[0.777,1.034], g[1.362]
>2376, c[0.009,100], d[0.854,0.940], g[1.333]
>2377, c[0.012,100], d[0.766,0.553], g[1.232]
>2378, c[0.003,100], d[0.690,0.830], g[1.238]
>2379, c[0.004,100], d[0.638,0.443], g[1.207]
>2380, c[0.006,100], d[0.601,0.665], g[1.187]
>2381, c[0.005,100], d[0.582,0.767], g[1.173]
>2382, c[0.003,100], d[0.810,0.801], g[1.329]
>2383, c[0.006,100], d[0.625,0.719

### Accuracy of the classifier trained with the semi-supervised GAN

In [14]:
# Evaluate the SGAN classifier head saved after the final training step
# (c_model_2400.h5) on the full MNIST train and test splits.
from numpy import expand_dims
from tensorflow.keras.models import load_model
from tensorflow.keras.datasets.mnist import load_data

model = load_model('c_model_2400.h5')
(trainX, trainy), (testX, testy) = load_data()
# same preprocessing as load_real_samples: add a channel axis, cast to
# float32 and scale pixels from [0,255] to [-1,1]
trainX = (expand_dims(trainX, axis=-1).astype('float32') - 127.5) / 127.5
testX = (expand_dims(testX, axis=-1).astype('float32') - 127.5) / 127.5
_, train_acc = model.evaluate(trainX, trainy, verbose=0)
print('Train Accuracy: %.3f%%' % (train_acc * 100))
_, test_acc = model.evaluate(testX, testy, verbose=0)
print('Test Accuracy: %.3f%%' % (test_acc * 100))

Train Accuracy: 92.745%
Test Accuracy: 93.260%


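### Comparison with a directly trained classifier

As a baseline, a network with the same architecture as `c_model` is trained for the same number of steps. It sees only a labeled subset of 100 images: no GAN and no unlabeled data.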

In [15]:
def define_classifier(in_shape=(28,28,1), n_classes=10):
    """Build a standalone supervised classifier for the labeled-only baseline.

    The layer stack matches the body and softmax head that define_discriminator
    builds for c_model: three stride-2 conv blocks, Flatten, Dropout(0.4),
    Dense logits and softmax. The two can therefore be compared on equal
    architecture.

    Args:
        in_shape: shape of a single input image (height, width, channels).
        n_classes: number of classes.

    Returns:
        A model compiled with sparse categorical cross-entropy (integer labels)
        and an accuracy metric.
    """
    # image input
    in_image = Input(shape=in_shape)
    # three stride-2 conv blocks (downsampling)
    fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(in_image)
    fe = LeakyReLU(alpha=0.2)(fe)
    fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(fe)
    fe = LeakyReLU(alpha=0.2)(fe)
    fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(fe)
    fe = LeakyReLU(alpha=0.2)(fe)
    # flatten feature maps and regularise with dropout
    fe = Flatten()(fe)
    fe = Dropout(0.4)(fe)
    # class logits followed by softmax probabilities
    fe = Dense(n_classes)(fe)
    c_out_layer = Activation('softmax')(fe)
    # define and compile the classifier; `learning_rate` replaces the deprecated `lr`
    c_model = Model(in_image, c_out_layer)
    c_model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])
    return c_model

In [16]:
# train a classifier on the labeled subset only (supervised baseline, no GAN)
def train_classifier(classifier, dataset, n_epochs=4, n_batch=100, supervised=None):
    """Train `classifier` on a small labeled subset, as a baseline for the SGAN.

    Runs the same number of steps as train(): len(dataset[0]) // n_batch
    steps per epoch, times n_epochs. Every step is a single train_on_batch
    call on the whole labeled subset. The rest of `dataset` is used only to
    count steps and, by default, to draw the subset.

    Args:
        classifier: compiled model whose train_on_batch returns (loss, accuracy).
        dataset: (X, y) pair; the labeled subset is drawn from it by default.
        n_epochs: number of epochs, in units of len(dataset[0]) // n_batch steps.
        n_batch: only used to derive the step count (the labeled subset is the batch).
        supervised: optional (X_sup, y_sup) pair to train on, for control over
            which labeled examples the baseline sees. Defaults to a fresh draw
            from select_supervised_samples(dataset).

    Side effects:
        Prints per-step loss/accuracy and saves classifier_XXXX.h5 at the end
        of every epoch.
    """
    # select supervised dataset unless the caller supplied one
    if supervised is None:
        supervised = select_supervised_samples(dataset)
    X_sup, y_sup = supervised
    print(X_sup.shape, y_sup.shape)
    # same step budget as train(): batches per epoch over the full dataset
    bat_per_epo = dataset[0].shape[0] // n_batch
    n_steps = bat_per_epo * n_epochs
    # only reported, so the summary line has the same format as train()'s
    half_batch = n_batch // 2
    print('n_epochs=%d, n_batch=%d, 1/2=%d, b/e=%d, steps=%d' % (n_epochs, n_batch, half_batch, bat_per_epo, n_steps))
    for i in range(n_steps):
        # one update on the entire labeled subset
        c_loss, c_acc = classifier.train_on_batch(X_sup, y_sup)
        print('>%d, c[%.3f,%.0f]' % (i+1, c_loss, c_acc*100))
        # checkpoint at the end of every epoch
        if (i+1) % bat_per_epo == 0:
            filename = 'classifier_%04d.h5' % (i+1)
            classifier.save(filename)

In [17]:
# Baseline setup: same architecture as the SGAN classifier head, trained on labeled data only.
classifier = define_classifier()
dataset = load_real_samples()
# NOTE(review): X_sup / y_sup are not used by the next cell. train_classifier(classifier, dataset)
# draws its own labeled subset, so this draw only advances NumPy's global RNG.
X_sup, y_sup = select_supervised_samples(dataset)

(60000, 28, 28, 1) (60000,)


In [18]:
# Train the labeled-only baseline; prints one line per step and saves a checkpoint once per epoch.
train_classifier(classifier, dataset)

(100, 28, 28, 1) (100,)
n_epochs=4, n_batch=100, 1/2=50, b/e=600, steps=2400
>1, c[2.306,8]
>2, c[2.287,18]
>3, c[2.281,15]
>4, c[2.265,28]
>5, c[2.246,36]
>6, c[2.244,37]
>7, c[2.224,37]
>8, c[2.206,48]
>9, c[2.190,49]
>10, c[2.168,57]
>11, c[2.144,60]
>12, c[2.117,60]
>13, c[2.088,68]
>14, c[2.061,59]
>15, c[2.031,63]
>16, c[1.995,62]
>17, c[1.958,67]
>18, c[1.900,61]
>19, c[1.862,69]
>20, c[1.789,59]
>21, c[1.744,66]
>22, c[1.688,66]
>23, c[1.599,68]
>24, c[1.543,67]
>25, c[1.486,63]
>26, c[1.413,73]
>27, c[1.329,76]
>28, c[1.227,81]
>29, c[1.211,75]
>30, c[1.165,75]
>31, c[1.115,72]
>32, c[1.043,78]
>33, c[1.022,73]
>34, c[0.939,82]
>35, c[0.875,80]
>36, c[0.808,78]
>37, c[0.790,82]
>38, c[0.777,82]
>39, c[0.656,82]
>40, c[0.683,86]
>41, c[0.655,82]
>42, c[0.580,90]
>43, c[0.552,90]
>44, c[0.587,85]
>45, c[0.474,90]
>46, c[0.507,87]
>47, c[0.460,91]
>48, c[0.459,96]
>49, c[0.426,89]
>50, c[0.484,88]
>51, c[0.445,87]
>52, c[0.396,90]
>53, c[0.381,92]
>54, c[0.353,93]
>55, c[0.350,92

>445, c[0.003,100]
>446, c[0.003,100]
>447, c[0.003,100]
>448, c[0.002,100]
>449, c[0.002,100]
>450, c[0.002,100]
>451, c[0.007,100]
>452, c[0.003,100]
>453, c[0.002,100]
>454, c[0.001,100]
>455, c[0.003,100]
>456, c[0.002,100]
>457, c[0.002,100]
>458, c[0.006,100]
>459, c[0.004,100]
>460, c[0.003,100]
>461, c[0.003,100]
>462, c[0.002,100]
>463, c[0.002,100]
>464, c[0.003,100]
>465, c[0.002,100]
>466, c[0.001,100]
>467, c[0.003,100]
>468, c[0.003,100]
>469, c[0.001,100]
>470, c[0.002,100]
>471, c[0.002,100]
>472, c[0.001,100]
>473, c[0.004,100]
>474, c[0.002,100]
>475, c[0.002,100]
>476, c[0.003,100]
>477, c[0.002,100]
>478, c[0.002,100]
>479, c[0.002,100]
>480, c[0.002,100]
>481, c[0.002,100]
>482, c[0.001,100]
>483, c[0.002,100]
>484, c[0.002,100]
>485, c[0.002,100]
>486, c[0.003,100]
>487, c[0.002,100]
>488, c[0.003,100]
>489, c[0.008,100]
>490, c[0.001,100]
>491, c[0.002,100]
>492, c[0.001,100]
>493, c[0.003,100]
>494, c[0.002,100]
>495, c[0.006,100]
>496, c[0.001,100]
>497, c[0.00

>893, c[0.000,100]
>894, c[0.001,100]
>895, c[0.001,100]
>896, c[0.000,100]
>897, c[0.000,100]
>898, c[0.001,100]
>899, c[0.001,100]
>900, c[0.001,100]
>901, c[0.002,100]
>902, c[0.001,100]
>903, c[0.001,100]
>904, c[0.001,100]
>905, c[0.001,100]
>906, c[0.000,100]
>907, c[0.000,100]
>908, c[0.001,100]
>909, c[0.000,100]
>910, c[0.001,100]
>911, c[0.001,100]
>912, c[0.001,100]
>913, c[0.000,100]
>914, c[0.001,100]
>915, c[0.000,100]
>916, c[0.001,100]
>917, c[0.001,100]
>918, c[0.001,100]
>919, c[0.001,100]
>920, c[0.001,100]
>921, c[0.001,100]
>922, c[0.001,100]
>923, c[0.001,100]
>924, c[0.001,100]
>925, c[0.001,100]
>926, c[0.000,100]
>927, c[0.001,100]
>928, c[0.000,100]
>929, c[0.002,100]
>930, c[0.000,100]
>931, c[0.000,100]
>932, c[0.002,100]
>933, c[0.000,100]
>934, c[0.000,100]
>935, c[0.001,100]
>936, c[0.000,100]
>937, c[0.000,100]
>938, c[0.001,100]
>939, c[0.000,100]
>940, c[0.000,100]
>941, c[0.003,100]
>942, c[0.000,100]
>943, c[0.000,100]
>944, c[0.001,100]
>945, c[0.00

>1333, c[0.000,100]
>1334, c[0.000,100]
>1335, c[0.000,100]
>1336, c[0.000,100]
>1337, c[0.000,100]
>1338, c[0.000,100]
>1339, c[0.000,100]
>1340, c[0.000,100]
>1341, c[0.000,100]
>1342, c[0.000,100]
>1343, c[0.000,100]
>1344, c[0.000,100]
>1345, c[0.000,100]
>1346, c[0.000,100]
>1347, c[0.001,100]
>1348, c[0.001,100]
>1349, c[0.000,100]
>1350, c[0.001,100]
>1351, c[0.000,100]
>1352, c[0.001,100]
>1353, c[0.000,100]
>1354, c[0.000,100]
>1355, c[0.000,100]
>1356, c[0.000,100]
>1357, c[0.001,100]
>1358, c[0.000,100]
>1359, c[0.000,100]
>1360, c[0.000,100]
>1361, c[0.000,100]
>1362, c[0.000,100]
>1363, c[0.000,100]
>1364, c[0.000,100]
>1365, c[0.000,100]
>1366, c[0.000,100]
>1367, c[0.000,100]
>1368, c[0.000,100]
>1369, c[0.000,100]
>1370, c[0.000,100]
>1371, c[0.000,100]
>1372, c[0.000,100]
>1373, c[0.000,100]
>1374, c[0.000,100]
>1375, c[0.000,100]
>1376, c[0.001,100]
>1377, c[0.001,100]
>1378, c[0.000,100]
>1379, c[0.001,100]
>1380, c[0.000,100]
>1381, c[0.000,100]
>1382, c[0.000,100]


>1758, c[0.000,100]
>1759, c[0.000,100]
>1760, c[0.000,100]
>1761, c[0.000,100]
>1762, c[0.000,100]
>1763, c[0.000,100]
>1764, c[0.000,100]
>1765, c[0.000,100]
>1766, c[0.000,100]
>1767, c[0.000,100]
>1768, c[0.000,100]
>1769, c[0.000,100]
>1770, c[0.000,100]
>1771, c[0.000,100]
>1772, c[0.000,100]
>1773, c[0.000,100]
>1774, c[0.000,100]
>1775, c[0.000,100]
>1776, c[0.000,100]
>1777, c[0.000,100]
>1778, c[0.000,100]
>1779, c[0.000,100]
>1780, c[0.000,100]
>1781, c[0.000,100]
>1782, c[0.001,100]
>1783, c[0.000,100]
>1784, c[0.000,100]
>1785, c[0.000,100]
>1786, c[0.000,100]
>1787, c[0.000,100]
>1788, c[0.000,100]
>1789, c[0.000,100]
>1790, c[0.000,100]
>1791, c[0.000,100]
>1792, c[0.000,100]
>1793, c[0.000,100]
>1794, c[0.000,100]
>1795, c[0.000,100]
>1796, c[0.000,100]
>1797, c[0.000,100]
>1798, c[0.000,100]
>1799, c[0.000,100]
>1800, c[0.001,100]
>1801, c[0.000,100]
>1802, c[0.000,100]
>1803, c[0.000,100]
>1804, c[0.000,100]
>1805, c[0.000,100]
>1806, c[0.000,100]
>1807, c[0.000,100]


>2175, c[0.000,100]
>2176, c[0.000,100]
>2177, c[0.000,100]
>2178, c[0.000,100]
>2179, c[0.000,100]
>2180, c[0.000,100]
>2181, c[0.000,100]
>2182, c[0.000,100]
>2183, c[0.000,100]
>2184, c[0.000,100]
>2185, c[0.000,100]
>2186, c[0.000,100]
>2187, c[0.000,100]
>2188, c[0.000,100]
>2189, c[0.000,100]
>2190, c[0.000,100]
>2191, c[0.000,100]
>2192, c[0.000,100]
>2193, c[0.000,100]
>2194, c[0.000,100]
>2195, c[0.000,100]
>2196, c[0.000,100]
>2197, c[0.000,100]
>2198, c[0.000,100]
>2199, c[0.000,100]
>2200, c[0.000,100]
>2201, c[0.000,100]
>2202, c[0.000,100]
>2203, c[0.000,100]
>2204, c[0.000,100]
>2205, c[0.000,100]
>2206, c[0.000,100]
>2207, c[0.000,100]
>2208, c[0.000,100]
>2209, c[0.000,100]
>2210, c[0.000,100]
>2211, c[0.000,100]
>2212, c[0.000,100]
>2213, c[0.000,100]
>2214, c[0.000,100]
>2215, c[0.000,100]
>2216, c[0.000,100]
>2217, c[0.000,100]
>2218, c[0.000,100]
>2219, c[0.000,100]
>2220, c[0.000,100]
>2221, c[0.000,100]
>2222, c[0.000,100]
>2223, c[0.000,100]
>2224, c[0.000,100]


In [19]:
# Evaluate the labeled-only baseline saved after its final step
# (classifier_2400.h5) on the full MNIST train and test splits.
from numpy import expand_dims
from tensorflow.keras.models import load_model
from tensorflow.keras.datasets.mnist import load_data

model = load_model('classifier_2400.h5')
(trainX, trainy), (testX, testy) = load_data()
# preprocessing identical to training: channel axis, float32, [0,255] -> [-1,1]
trainX = (expand_dims(trainX, axis=-1).astype('float32') - 127.5) / 127.5
testX = (expand_dims(testX, axis=-1).astype('float32') - 127.5) / 127.5
# report accuracy on both splits, train first
_, train_acc = model.evaluate(trainX, trainy, verbose=0)
print('Train Accuracy: %.3f%%' % (train_acc * 100))
_, test_acc = model.evaluate(testX, testy, verbose=0)
print('Test Accuracy: %.3f%%' % (test_acc * 100))

Train Accuracy: 74.052%
Test Accuracy: 75.390%
