In [1]:
import matplotlib.pyplot as plt
import numpy as np

from keras import backend as K
from keras.datasets import mnist
from keras.layers import Dense, Flatten, Reshape, LeakyReLU, Dropout, Lambda
from keras.layers import Activation, BatchNormalization
from keras.layers import Conv2D, Conv2DTranspose
from keras.utils import to_categorical

from keras.models import Sequential
from keras.optimizers import Adam

In [2]:
class Dataset:
    def __init__(self, labeled_num):
        self.labeled_num = labeled_num
        (self.x_train, self.y_train),(self.x_test, self.y_test) = mnist.load_data()

        self.x_train = self._preprocess_img(self.x_train)
        self.y_train = self._preprocess_label(self.y_train)
        self.x_test = self._preprocess_img(self.x_test)
        self.y_test = self._preprocess_label(self.y_test)

    @staticmethod
    def _preprocess_img(x):
        x = (x.astype(np.float32) - 127.5) / 127.5 # Normalize data
        x = np.expand_dims(x, axis=3) # expand data
        return x

    @staticmethod
    def _preprocess_label(y):
        return y.reshape(-1, 1)

    def get_batch_labeled(self, batch_size):
        idx = np.random.randint(0, self.labeled_num, batch_size)
        imgs = self.x_train[idx]
        labels = self.y_train[idx]

        return imgs, labels

    def get_batch_unlabeled(self, batch_size):
        idx = np.random.randint(self.labeled_num, self.x_train.shape[0], batch_size)
        imgs = self.x_train[idx]

        return imgs

    def read_traindata(self):
        self.x_train = self.x_train[range(self.labeled_num)]
        self.y_train = self.y_train[range(self.labeled_num)]
        return self.x_train, self.y_train

    def read_testdata(self):
        return self.x_test, self.y_test


# Preprocessing:

## Load Dataset:

In [3]:
(x_train, _),(x_test, _) = mnist.load_data()
print(x_train.shape)
print(x_test.shape)

(60000, 28, 28)
(10000, 28, 28)


## Concat train & test data:

In [4]:
# to have more data for train
real_data = np.concatenate([x_train, x_test], axis=0)
real_data.shape

(70000, 28, 28)

## Normalize & Extend(dim) data:

In [5]:
normalized_real_data = np.expand_dims(real_data/127.5 - 1.0, axis=3)

# Define Networks:

In [6]:
def generator(zdim):
    """ Generator Network """
    model = Sequential()

    model.add(Dense(7*7*256, use_bias=False, input_shape=(zdim,))) # >>> (12544, 1, 1)
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Reshape((7, 7, 256))) # >>> (7, 7, 256)

    model.add(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) # >>> (7, 7, 128)
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) # >>> (14, 14, 64)
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    model.add(Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) # >>> (28, 28, 1)

    return model


In [7]:
def discriminator(img_shape, num_classes):
    """ Discriminator Network """
    model = Sequential()

    # Input >>> (28, 28, 1)
    model.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=img_shape)) # >>> (14, 14, 64)
    model.add(LeakyReLU())
    model.add(Dropout(0.3))

    model.add(Conv2D(128, (5, 5), strides=(2, 2), padding='same')) # >>> (7, 7, 128)
    model.add(LeakyReLU())
    model.add(Dropout(0.3))

    model.add(Flatten()) # >>> 7 * 7 * 128
    model.add(Dense(num_classes))

    return model

def supervised_discriminator(dis_net):
    """ multi classification """
    model = Sequential()
    model.add(dis_net)
    model.add(Activation('softmax'))

    model.compile(
        loss='categorical_crossentropy',
        optimizer=Adam(),
        metrics=['accuracy']
      )

    return model

def unsupervised_discriminator(dis_net):
    """ binary classification """
    model = Sequential()
    model.add(dis_net)

    def predict(x):
        """ custome sigmoid for comparison of 10 real data classes with fake data """
        prediction = 1.0 - (1.0 / (K.sum(K.exp(x), axis=-1, keepdims=True) + 1.0))
        return prediction

    model.add(Lambda(predict))

    model.compile(
        loss='binary_crossentropy',
        optimizer=Adam()
      )

    return model

In [8]:
def define_gan(gen, dis):
    """ Generative Adversarial Network """
    model = Sequential()
    model.add(gen)
    model.add(dis)

    model.compile(
        loss='binary_crossentropy',
        optimizer=Adam()
      )

    return model

## Set Shapes:

In [9]:
z_dim = 100
labeled_num = 100
num_classes = 10
img_shape = (real_data.shape[1], real_data.shape[2], 1) # (28, 28, 1)
img_shape

(28, 28, 1)

In [10]:
dataset = Dataset(labeled_num)

## Call Networks:

In [11]:
dis = discriminator(img_shape, num_classes)

dis_supervised = supervised_discriminator(dis)
dis_unsupervised = unsupervised_discriminator(dis)

gen = generator(z_dim)

dis_unsupervised.trainable = False
GAN = define_gan(gen, dis_unsupervised)

# Train Model:

## Useful Tools:

In [12]:
def generate_latent_vector(batch_size, zdim):
    """ generate a random latent vector """
    return np.random.normal(0, 1, (batch_size, zdim))


## Define operation algorithm:

In [13]:
supervised_losses = []
iter_checks = []

def train(r_data, gen, dis, dis_sup, dis_unsup, GAN, zdim, iters_range, batch_size, interval):
    r_labels = np.ones((batch_size, 1))
    f_labels = np.zeros((batch_size, 1))

    for iter in range(iters_range):
        print(f"------------------{iter+1}------------------")

        # get imgs
        imgs_labeled, labels = dataset.get_batch_labeled(batch_size)
        labels = to_categorical(labels, num_classes=num_classes)

        imgs_unlabeled = dataset.get_batch_unlabeled(batch_size)

        # gen part
        z = generate_latent_vector(batch_size, zdim)
        f_imgs = gen.predict(z)

        # dis part
        dloss_sup, acc_sup = dis_sup.train_on_batch(imgs_labeled, labels)

        r_dloss_unsup = dis_unsupervised.train_on_batch(imgs_unlabeled, r_labels)
        f_dloss_unsup = dis_unsupervised.train_on_batch(f_imgs, f_labels)

        dloss_unsup = 0.5 * (r_dloss_unsup + f_dloss_unsup)

        # GAN part
        z = generate_latent_vector(batch_size, zdim)
        gloss = GAN.train_on_batch(z, r_labels)

        if (iter + 1) % interval == 0:
            supervised_losses.append(dloss_sup)
            iter_checks.append(iter+1)

            print(
                """
                D supervised_loss[%.4f] acc[%.2f] \n
                D unsupervised_loss[%.4f]

                """
                % (dloss_sup, 100.0 * acc_sup, dloss_unsup)
            )


## Start training:

In [14]:
train(
    normalized_real_data, gen, dis, dis_supervised, dis_unsupervised, GAN,
    z_dim, 4000, 256, 1000
  )

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
------------------1510------------------
------------------1511------------------
------------------1512------------------
------------------1513------------------
------------------1514------------------
------------------1515------------------
------------------1516------------------
------------------1517------------------
------------------1518------------------
------------------1519------------------
------------------1520------------------
------------------1521------------------
------------------1522------------------
------------------1523------------------
------------------1524------------------
------------------1525------------------
------------------1526------------------
------------------1527------------------
------------------1528------------------
------------------1529------------------
------------------1530------------------
------------------1531------------------
------------------1532-----------

# Test & Compare supervised with semi-supervised:

## supervised:

In [15]:
x_traindata, y_traindata = dataset.read_traindata()
y_traindata = to_categorical(y_traindata, num_classes=num_classes)

_, acc_train = dis_supervised.evaluate(x_traindata, y_traindata)
print("Training Accuracy: %.2f" % (100.0 * acc_train))

x_testdata, y_testdata = dataset.read_testdata()
y_testdata = to_categorical(y_testdata, num_classes=num_classes)

_, acc_test = dis_supervised.evaluate(x_testdata, y_testdata)
print("Test Accuracy: %.2f" % (100.0 * acc_test))

Training Accuracy: 100.00
Test Accuracy: 91.08


## semi-supervised:

In [24]:
mnist_classifier = supervised_discriminator(discriminator(img_shape, num_classes))
imgs, labels = dataset.read_traindata()
labels = to_categorical(labels, num_classes=num_classes)

history = mnist_classifier.fit(imgs, labels, epochs=100, batch_size=256)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

In [None]:
x_traindata, y_traindata = dataset.read_traindata()
y_traindata = to_categorical(y_traindata, num_classes=num_classes)

_, acc_train = mnist_classifier.evaluate(x_traindata, y_traindata)
print("Training Accuracy: %.2f" % (100.0 * acc_train))

x_testdata, y_testdata = dataset.read_testdata()
y_testdata = to_categorical(y_testdata, num_classes=num_classes)

_, acc_test = mnist_classifier.evaluate(x_testdata, y_testdata)
print("Test Accuracy: %.2f" % (100.0 * acc_test))

# ***FIN :3***