<a href="https://colab.research.google.com/github/rahiakela/gans-in-action/blob/part-2-advanced-topics-in-gans/7-semi_supervised_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Semi-Supervised GAN (SGAN)


Semi-supervised learning is one of the most promising areas of practical application of GANs. 

Unlike supervised learning, in which we need a label for every example in our
dataset, and unsupervised learning, in which no labels are used, semi-supervised
learning has a class label for only a small subset of the training dataset. 

By internalizing hidden structures in the data, semi-supervised learning strives to generalize from the small subset of labeled data points to effectively classify new, previously unseen examples. 

Importantly, for semi-supervised learning to work, the labeled and unlabeled
data must come from the same underlying distribution.

Interestingly, semi-supervised learning may also be one of the closest machine
learning analogs to the way humans learn. 

When schoolchildren learn to read and write, the teacher does not have to take them on a road trip to see tens of thousands of examples of letters and numbers, ask them to identify these symbols, and correct them as needed, which is roughly how a supervised learning algorithm would operate.
Instead, a single set of examples is all that is needed for children to learn letters and numerals and then be able to recognize them regardless of font, size, angle, lighting conditions, and many other factors. Semi-supervised learning aims to teach machines in a similarly efficient manner.

Serving as a source of additional information that can be used for training, generative models proved useful in improving the accuracy of semi-supervised models. Unsurprisingly, GANs have proven the most promising.

## What is a Semi-Supervised GAN?

A Semi-Supervised GAN (SGAN) is a Generative Adversarial Network whose Discriminator is a multiclass classifier. Instead of distinguishing between only two classes (real and fake), it learns to distinguish between N + 1 classes, where N is the number of classes in the training dataset, with one added for the fake examples produced by the Generator.

For example, the MNIST dataset of handwritten digits has 10 labels (one label for each numeral, 0 to 9), so the SGAN Discriminator trained on this dataset would predict between 10 + 1 = 11 classes. 

In our implementation, the output of the SGAN Discriminator will be represented as a vector of 10 class probabilities (that sum up to 1.0) plus another probability that represents whether the image is real or fake.
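
The N + 1 class view and the "10 class probabilities plus a real-versus-fake probability" view can be related with a small NumPy sketch. It assumes one common formulation in which the extra "fake" class has its logit fixed at 0. The logit values and variable names are hypothetical, and this is not necessarily how the network in this notebook computes its outputs.

```python
import numpy as np

num_classes = 10

# Hypothetical Discriminator logits for one image, one per digit class
logits = np.array([2.0, -1.0, 0.5, 0.0, 3.0, -2.0, 0.1, 0.3, -0.5, 1.0])

# N + 1 view: append a "fake" class whose logit is fixed at 0 (an assumption)
extended = np.append(logits, 0.0)
probs_11 = np.exp(extended) / np.exp(extended).sum()
p_fake = probs_11[-1]
p_real = 1.0 - p_fake

# Split view: 10 class probabilities (conditional on the image being real)
# plus a single real-versus-fake probability
class_probs = np.exp(logits) / np.exp(logits).sum()
p_real_direct = 1.0 - 1.0 / (np.exp(logits).sum() + 1.0)

print("Predicted digit:", class_probs.argmax())
print("P(real):", p_real, "P(real), computed directly:", p_real_direct)
```

Under this assumption the two views agree: multiplying each of the 10 class probabilities by `p_real` recovers the first 10 entries of the 11-class distribution.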


<img src='https://github.com/rahiakela/img-repo/blob/master/gans-in-action/sgan-architecture.png?raw=1' width='800'/>

---

In this Semi-Supervised GAN, the Generator takes in a random noise
vector z and produces a fake example x*. The Discriminator receives three kinds of data inputs: fake data from the Generator, real unlabeled examples x, and real labeled examples (x, y), where y is the label corresponding to the given example. The Discriminator then outputs a classification; its goal is to distinguish fake examples from the real ones and, for the real examples, identify the correct class. Notice that the portion of examples with labels is much smaller than the portion of the unlabeled data. In practice, the contrast is even starker than the one shown, with labeled data forming only a tiny fraction (often as little as 1–2%) of the training data.

---

The task of distinguishing between multiple classes not only
impacts the Discriminator itself, but also adds complexity to the SGAN architecture, its training process, and its training objectives, as compared to the traditional GAN.






## Architecture

The SGAN Generator’s purpose is the same as in the original GAN: it takes in a vector of random numbers and produces fake examples whose goal is to be indistinguishable from the training dataset—no change here.

The SGAN Discriminator, however, diverges considerably from the original GAN
implementation. Instead of two, it receives three kinds of inputs: fake examples produced by the Generator (x*), real examples without labels from the training dataset (x), and real examples with labels from the training dataset (x, y), where y denotes the label for the given example x. Instead of binary classification, the SGAN Discriminator’s goal is to correctly categorize the input example into its corresponding class if the example is real, or reject the example as fake (which can be thought of as a special additional class).

<img src='https://github.com/rahiakela/img-repo/blob/master/gans-in-action/sgan-network-table.png?raw=1' width='800'/>



## Training process

In a regular GAN, we train the Discriminator by computing the loss for
$D(x)$ and $D(x^*)$ and backpropagating the total loss to update the Discriminator’s trainable parameters to minimize the loss. The Generator is trained by backpropagating the Discriminator’s loss for $D(x^*)$, seeking to maximize it, so that the fake examples it synthesizes are misclassified as real.

To train the SGAN, in addition to $D(x)$ and $D(x^*)$, we also have to compute the loss for the supervised training examples: $D((x, y))$. These losses correspond to the dual learning objective that the SGAN Discriminator has to grapple with: distinguishing real examples from the fake ones while also learning to classify real examples to their correct classes.
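
One common way to write this dual objective is given below as a sketch; it is not necessarily the exact loss used in this notebook's implementation. The notation is assumed for illustration: $p_D(y \mid x)$ is the Discriminator's predicted probability of class $y$ for a real labeled example, $D(\cdot)$ is its real-versus-fake probability, and $x^* = G(z)$ is a fake example.

$$
L_D = L_{\text{supervised}} + L_{\text{unsupervised}}
$$

$$
L_{\text{supervised}} = -\mathbb{E}_{(x, y)}\left[\log p_D(y \mid x)\right]
$$

$$
L_{\text{unsupervised}} = -\mathbb{E}_{x}\left[\log D(x)\right] - \mathbb{E}_{z}\left[\log\left(1 - D(x^*)\right)\right]
$$

The supervised term uses only the small labeled subset, whereas the unsupervised term uses the unlabeled real examples and the Generator's fakes.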

## Training objective

In a typical GAN, the goal is to produce realistic-looking data samples; hence, the Generator network has been of primary interest. The main purpose of the Discriminator network has been to help the Generator improve the quality of the images it produces.

**At the end of the training, we often disregard the Discriminator and use only the fully trained Generator to create realistic-looking synthetic data.**

In contrast, in an SGAN, we care primarily about the Discriminator. The goal of the training process is to make this network into a semi-supervised classifier whose accuracy is as close as possible to a fully supervised classifier (one that has labels available for each example in the training dataset), while using only a small fraction of the labels. The Generator’s goal is to aid this process by serving as a source of additional information (the fake data it produces) that helps the Discriminator learn the relevant patterns in the data, enhancing its classification accuracy.

**At the end of the training, the Generator gets discarded, and we use the trained Discriminator as a classifier.**

## Implementing a Semi-Supervised GAN

We implement an SGAN model that learns to classify handwritten digits in the MNIST dataset by using only 100 training examples. At the end of it, we compare the model’s classification accuracy to an equivalent fully supervised model to see for ourselves the improvement achieved by semi-supervised learning.

<img src='https://github.com/rahiakela/img-repo/blob/master/gans-in-action/sgan-diagram.png?raw=1' width='800'/>

The Generator turns random noise into fake examples. The Discriminator receives real images with labels (x, y), real images without labels (x), and fake images produced by the Generator (x*). To distinguish real examples from fake ones, the Discriminator uses the sigmoid function. To distinguish between the real classes, the Discriminator uses the
softmax function.


---

To solve the multiclass classification problem of distinguishing between the real labels, the Discriminator uses the softmax function, which gives a probability distribution over a specified number of classes (in our case, 10). The higher the probability assigned to a given label, the more confident the Discriminator is that the example belongs to the given class. To compute the classification error, we use cross-entropy loss, which measures the difference between the output probabilities and the target, one-hot-encoded labels.

To output the real-versus-fake probability, the Discriminator uses the sigmoid activation function and trains its parameters by backpropagating the binary cross-entropy loss.
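
The two output heads and their losses can be illustrated with a minimal NumPy sketch. The helper functions and logit values below are hypothetical and written only for illustration; the actual model relies on the equivalent Keras activations and losses.

```python
import numpy as np

def softmax(logits):
    # Subtract the row maximum for numerical stability
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def categorical_cross_entropy(probs, one_hot, eps=1e-7):
    # Mean over the batch of -sum(target * log(prediction))
    probs = np.clip(probs, eps, 1.0)
    return -np.mean(np.sum(one_hot * np.log(probs), axis=-1))

def binary_cross_entropy(p, target, eps=1e-7):
    p = np.clip(p, eps, 1.0 - eps)
    return -np.mean(target * np.log(p) + (1.0 - target) * np.log(1.0 - p))

# Hypothetical class logits for two labeled images over 10 digit classes
class_logits = np.array([[0.0] * 10, [5.0] + [0.0] * 9])
class_probs = softmax(class_logits)
labels = np.eye(10)[[3, 0]]  # one-hot targets for digits 3 and 0
supervised_loss = categorical_cross_entropy(class_probs, labels)

# Hypothetical real-versus-fake logits: one real image, one fake image
real_fake_logits = np.array([2.0, -2.0])
real_prob = sigmoid(real_fake_logits)
unsupervised_loss = binary_cross_entropy(real_prob, np.array([1.0, 0.0]))

print("Supervised (softmax) loss:", supervised_loss)
print("Unsupervised (sigmoid) loss:", unsupervised_loss)
```

The first image gets a uniform prediction and so contributes a large classification loss, whereas the second is confidently and correctly classified and contributes a small one.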



### Setup

In [1]:
import tensorflow as tf

from tensorflow.keras.layers import (Activation, BatchNormalization, Concatenate, Dense, Dropout,
                                     Flatten, Input, Lambda, Reshape, LeakyReLU, Conv2D, Conv2DTranspose)
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
from tensorflow.keras import backend as K

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

We also specify the input image size, the size of the noise vector $z$, and the number of the real classes for the semi-supervised classification (one for each numeral our Discriminator will learn to identify), as shown in the following listing.

In [2]:
# Model input dimensions
img_rows = 28
img_cols = 28
channels = 1

# Input image dimensions
img_shape = (img_rows, img_cols, channels)

# Size of the noise vector, used as input to the Generator
z_dim = 100

# Number of classes in the dataset
num_classes = 10

### The dataset

Although the MNIST training dataset has `60,000` labeled training images, we will use only a small fraction of them (specified by the `num_labeled` parameter) for training and pretend that all the remaining ones are unlabeled. We accomplish this by sampling only from the first `num_labeled` images when generating batches of labeled data and from the remaining `(60,000 - num_labeled)` images when generating batches of unlabeled examples.

The Dataset object also provides a function to return all the
`num_labeled` training examples along with their labels as well as a function to return all `10,000` labeled test images in the MNIST dataset. After training, we will use the test set to evaluate how well the model’s classifications generalize to previously unseen examples.

In [5]:
class Dataset:

  def __init__(self, num_labeled):
    # Number of labeled examples to use for training
    self.num_labeled = num_labeled
    # Load the MNIST dataset
    (self.x_train, self.y_train), (self.x_test, self.y_test) = mnist.load_data()

    def preprocess_imgs(x):
      # Rescale [0, 255] grayscale pixel values to [-1, 1]
      x = (x.astype(np.float32) - 127.5) / 127.5
      # Expand image dimensions to width x height x channels
      x = np.expand_dims(x, axis=3)
      return x

    def preprocess_labels(y):
      # Reshape labels into a column vector of shape (num_examples, 1)
      return y.reshape(-1, 1)

    # Training data
    self.x_train = preprocess_imgs(self.x_train)
    self.y_train = preprocess_labels(self.y_train)

    # Testing data
    self.x_test = preprocess_imgs(self.x_test)
    self.y_test = preprocess_labels(self.y_test)
  
  def batch_labeled(self, batch_size):
    # Get a random batch of labeled images and their labels
    idx = np.random.randint(0, self.num_labeled, batch_size)
    imgs = self.x_train[idx]
    labels = self.y_train[idx]
    return imgs, labels

  def batch_unlabeled(self, batch_size):
    # Get a random batch of unlabeled images
    idx = np.random.randint(self.num_labeled, self.x_train.shape[0], batch_size)
    imgs = self.x_train[idx]
    return imgs

  def training_set(self):
    # Return the labeled subset: the first num_labeled training examples
    x_train = self.x_train[:self.num_labeled]
    y_train = self.y_train[:self.num_labeled]
    return x_train, y_train

  def test_set(self):
    return self.x_test, self.y_test
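
The sampling scheme can be checked without downloading MNIST. The sketch below uses a small synthetic stand-in array (the sizes, seed, and variable names are assumptions made for illustration) to show the pixel rescaling and the split between labeled and unlabeled indices.

```python
import numpy as np

rng = np.random.default_rng(0)

num_examples = 1000  # synthetic stand-in, much smaller than MNIST
num_labeled = 100
batch_size = 32

# Fake "images" with integer pixel values in [0, 255] and fake digit labels
raw = rng.integers(0, 256, size=(num_examples, 28, 28)).astype(np.uint8)
y_train = rng.integers(0, 10, size=(num_examples, 1))

# Same preprocessing as above: rescale to [-1, 1] and add a channel axis
x_train = (raw.astype(np.float32) - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=3)

# Labeled batches draw only from the first num_labeled examples
labeled_idx = rng.integers(0, num_labeled, batch_size)
imgs_labeled, labels = x_train[labeled_idx], y_train[labeled_idx]

# Unlabeled batches draw only from the remaining examples; labels are ignored
unlabeled_idx = rng.integers(num_labeled, num_examples, batch_size)
imgs_unlabeled = x_train[unlabeled_idx]

print(imgs_labeled.shape, labels.shape, imgs_unlabeled.shape)
```

Because the two index ranges never overlap, no label from the "unlabeled" portion can leak into the supervised loss.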

We will pretend that we have only 100 labeled MNIST images for training:

In [7]:
# Number of labeled examples to use (rest will be used as unlabeled)
num_labeled = 100

dataset = Dataset(num_labeled)
print(dataset.num_labeled)

100


### The Generator

The Generator network is the same as the one we implemented for the DCGAN. Using transposed convolution layers, the Generator transforms the input random noise vector into a 28 × 28 × 1 image; see the following listing.