In [54]:
""" Implementing training confidence calibrated classifier paper for synthetic data experiment """

import numpy as np

from IPython.core.debugger import Tracer

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Activation, Lambda
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Model, Sequential
from keras.optimizers import Adam

from sklearn.model_selection import train_test_split
from keras.utils import np_utils

import matplotlib.pyplot as plt


from tensorflow import set_random_seed
# Seed the global NumPy and TensorFlow random generators (both with 1) for reproducibility
np.random.seed(1)
set_random_seed(1)

In [90]:

class GAN(object):

    """ Generative Adversarial Network class """
    def __init__(self):
        """Set up hyper-parameters, the synthetic data set and all networks.

        Everything is hard-coded. The in-distribution data is sampled first
        (this draws from NumPy's global random state), split 50/50 into train
        and test sets, and then the generator, discriminator, classifier and
        the stacked generator->(discriminator, classifier) model are built and
        compiled, in that order.
        """
        # ---- problem size and training schedule ----
        self.number_of_classes = 2
        self.input_dim = 2          # dimensionality of one data point
        self.gen_input_dim = 100    # length of the generator's noise vector
        self.epochs = 20000
        self.batch_size = 200
        self.save_interval = 100    # train() logs and plots every this many iterations

        # ---- geometry of the in-distribution data ----
        # Two inlier circles of radius r; their centres lie on the first axis
        # at x1 and x2 and at 0 in every other dimension.
        self.x1 = 0.3
        self.x2 = 0.6
        self.r = 0.1
        self.n = 1000   # number of inlier samples per class
        self.s = 100    # side length of the outer outlier square

        # ---- confidence-loss weights ----
        self.beta_G = 1.0   # weight of the classifier output in the stacked model's loss
        self.beta_C = 0.5   # sample weight of generated points in the classifier update

        # ---- data: two circles, one per class ----
        self.data, self.labels = self.get_data()
        print(self.data.shape, self.labels.shape)

        split = train_test_split(
            self.data, self.labels,
            test_size=0.5, shuffle=True, stratify=self.labels, random_state=1)
        self.X_train, self.X_test, self.y_train, self.y_test = split

        # One-hot versions of the integer labels
        self.Y_train, self.Y_test = (
            np_utils.to_categorical(y, self.number_of_classes)
            for y in (self.y_train, self.y_test))

        # ---- networks ----
        # This optimizer instance is shared by G, D and the stacked model
        self.optimizer = Adam(lr=0.0002, beta_1=0.5, decay=8e-8)

        self.G = self.__generator()
        self.G.compile(loss='binary_crossentropy', optimizer=self.optimizer)

        self.D = self.__discriminator()
        self.D.compile(
            loss='binary_crossentropy',
            optimizer=self.optimizer,
            metrics=['accuracy'])

        # The classifier gets its own optimizer with a larger decay
        self.C = self.__classifier()
        self.C.compile(
            loss='categorical_crossentropy',
            optimizer=Adam(lr=0.0002, beta_1=0.5, decay=1e-4),
            metrics=['accuracy'])

        stacked = self.__stacked_generator_discriminator_classifier()
        self.stacked_generator_discriminator_classifier = stacked
        stacked.compile(
            loss={'out_d': 'binary_crossentropy', 'out_c': 'categorical_crossentropy'},
            loss_weights={'out_d': 1.0, 'out_c': self.beta_G},
            optimizer=self.optimizer,
            metrics={'out_d': ['acc'], 'out_c': ['categorical_crossentropy']})




    # return True if point p (a list) lies inside a hypersphere centered at the origin with radius r
    def is_in_hypersphere(self, r, p):
        p = np.array(p)
        return np.sum(p*p) <= r**2


    # return True if point p (a list) lies inside a hypersphere of radius r centered at (x, 0, 0, ...)
    def is_in_hypersphere_centered_at(self, x, r, p):
        p = np.array(p)
        p[0] -= x
        return np.sum(p*p) <= r**2


    # uniformly sample a point from a square centered at the origin with side = s
    def sample_point_from_square(self, s):
        p = []
        # first point
        for i in range(0, self.input_dim):
            p.append(np.random.uniform(-1*s/2, s/2, 1))
        return p


    # sample points using rejection sampling from a sphere centered at the origin
    def sample_point_from_hypersphere(self, r):
        p = self.sample_point_from_square(2*r)
        while not self.is_in_hypersphere(r, p):
            p = self.sample_point_from_square(2*r)
        return p
            

    # sample n points from a hypersphere centered at (x, 0, 0, ...) using rejection sampling
    def get_hypersphere_data(self, x, r, n):
        data = np.zeros((n, self.input_dim))
        count = 0
        while count<n:
            data[count, :] = self.sample_point_from_hypersphere(r)
            # update the point with offset for the 1st dimension
            data[count, 0] += x
            count += 1

        data = data.astype(np.float32)
        return data
        # plt.scatter(x, y, s = 4)
    
    
    def get_data(self):
        # class 0 data
        data = self.get_hypersphere_data(self.x1, self.r, self.n)
        labels = np.ones((self.n, 1))*0
        print(labels.shape)
        # class 1 data
        data = np.concatenate((data, self.get_hypersphere_data(self.x2, self.r, self.n)), axis=0)
        labels = np.concatenate((labels, np.ones((self.n, 1))*1), axis=0)
        print(labels.shape)

        return data.astype(np.float32), labels.astype(np.float32)


    def __generator(self):
        """Build the (uncompiled) generator network.

        Maps a noise vector of length ``self.gen_input_dim`` to a point with
        ``self.input_dim`` coordinates through three Dense(256) layers, each
        followed by LeakyReLU(alpha=0.2), and a linear output layer. No batch
        normalisation is applied.
        """
        hidden_units = 256
        model = Sequential()
        for layer_index in range(3):
            if layer_index == 0:
                # only the first layer declares the input shape
                model.add(Dense(hidden_units, input_shape=(self.gen_input_dim,)))
            else:
                model.add(Dense(hidden_units))
            model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(self.input_dim, activation='linear'))

        return model

    def __discriminator(self):
        """Build the (uncompiled) discriminator network and print its summary.

        Takes a point with ``self.input_dim`` coordinates, passes it through
        three Dense(256) + LeakyReLU(alpha=0.2) stages and ends in a single
        sigmoid unit.
        """
        hidden_units = 256
        model = Sequential()
        model.add(Dense(hidden_units, input_shape=(self.input_dim, )))
        model.add(LeakyReLU(alpha=0.2))
        for _ in range(2):
            model.add(Dense(hidden_units))
            model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        return model

    
    """
    def __generator(self):
        # Declare generator 

        model = Sequential()
        model.add(Dense(500, input_shape=(self.gen_input_dim,)))
        model.add(Activation('relu'))
        model.add(Dense(500))
        model.add(Activation('relu'))
        model.add(Dense(self.input_dim, activation='linear'))

        return model

    def __discriminator(self):
        # Declare discriminator

        model = Sequential()
        model.add(Dense(500, input_shape=(self.input_dim, )))
        model.add(Activation('relu'))
        model.add(Dense(500))
        model.add(Activation('relu'))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        return model
    """

    def __classifier(self):
        """Build the (uncompiled) classifier network.

        Two Dense(500) + ReLU stages followed by a softmax over
        ``self.number_of_classes`` outputs; the input is a point with
        ``self.input_dim`` coordinates.
        """
        model = Sequential()
        layer_stack = (
            Dense(500, input_shape=(self.input_dim,)),
            Activation('relu'),
            Dense(500),
            Activation('relu'),
            Dense(self.number_of_classes),
            Activation('softmax'),
        )
        for layer in layer_stack:
            model.add(layer)
        return model


    def __stacked_generator_discriminator_classifier(self):
        self.D.trainable = False
        self.C.trainalble = False

        # generator input
        in_g = Input(shape=(self.gen_input_dim, ))
        # generator output
        out_g = self.G(in_g)

        # discriminator output
        out_d = self.D(out_g)
        # classifier output
        out_c = self.C(out_g)

        out_d = Lambda(lambda x: x, name = 'out_d')(out_d)
        out_c = Lambda(lambda x: x, name = 'out_c')(out_c)

        model = Model(inputs=[in_g], outputs=[out_d, out_c])
        return model


    def train(self):

        # gen_data = np.empty((self.epochs*10, self.input_dim))
        # count = 0
        y_discriminator_ones = np.ones((self.batch_size/2, 1)).astype(np.float32)
        y_discriminator_zeros = np.zeros((self.batch_size/2, 1)).astype(np.float32)
        y_out = np.ones((self.batch_size/2, 2)).astype(np.float32)
        sample_weight = np.ones((self.batch_size)).astype(np.float32)
        sample_weight[self.batch_size/2:] = self.beta_C

        for cnt in range(self.epochs):

            ## train discriminator
            random_index = np.random.randint(0, len(self.X_train) - self.batch_size/2)
            legit_images = self.X_train[random_index : random_index + self.batch_size/2]

            gen_noise = np.random.normal(0, 1, (self.batch_size/2, self.gen_input_dim))
            synthetic_images = self.G.predict(gen_noise)

            x_combined_batch = np.concatenate((legit_images, synthetic_images))
            y_combined_batch = np.concatenate((y_discriminator_ones, y_discriminator_zeros))

            d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)

            """
            gen_noise = np.random.normal(0, 1, (100, 100))
            syntetic_images = self.G.predict(gen_noise)
            preds = self.D.predict(syntetic_images)
            ind = np.where(preds < self.thresh)[0]
            if len(ind) > 0:
                gen_data[count:count+min(10, ind.shape[0]),] = syntetic_images[ind[0:min(10, ind.shape[0])],]
                count = count + min(10, ind.shape[0])
            """

            # train generator
            noise = np.random.normal(0, 1, (self.batch_size/2, self.gen_input_dim))
            g_loss = self.stacked_generator_discriminator_classifier.train_on_batch([noise], [y_discriminator_ones, y_out])

            # train classifier
            random_index = np.random.randint(0, len(self.X_train) - self.batch_size/2)
            legit_images = self.X_train[random_index : random_index + self.batch_size/2]
            legit_labels = self.Y_train[random_index : random_index + self.batch_size/2]

            gen_noise = np.random.normal(0, 1, (self.batch_size/2, self.gen_input_dim))
            synthetic_images = self.G.predict(gen_noise)

            x_combined_batch = np.concatenate((legit_images, synthetic_images))
            y_combined_batch = np.concatenate((legit_labels, y_out))

            c_loss = self.C.train_on_batch(x_combined_batch, y_combined_batch, sample_weight=sample_weight)
            #print ('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f], [ Classifier :: loss: %f]' % 
            #       (cnt, d_loss[0], g_loss, 10.1))

            if cnt % self.save_interval == 0:
                print('epoch: %d' % cnt)
                print("d_loss", d_loss)
                print("g_loss", g_loss)
                c_loss = self.C.evaluate(self.X_test, self.Y_test)
                print("c_loss:", c_loss)
                self.plot_data(save2file=True, step=cnt)



    def plot_data(self, save2file, step=0):
        """Scatter-plot the test points of both classes together with generated points.

        As many points are generated as there are test samples. The figure is
        closed afterwards and never shown interactively.

        Args:
            save2file: when true, the figure is written to
                ``./images/mnist_<step>.png`` (the directory is created if it
                does not exist yet).
            step: iteration number used in the file name.
        """
        import os  # local import: only needed to make sure the output directory exists

        gen_noise = np.random.normal(0, 1, (self.X_test.shape[0], self.gen_input_dim))
        synthetic_images = self.G.predict(gen_noise)

        fig, ax = plt.subplots(figsize=(5, 5))
        # one scatter series per class: (label, colour, marker)
        for label, colour, marker in ((0, 'r', 'D'), (1, 'b', 'o')):
            ind = np.where(self.y_test == label)[0]
            ax.scatter(self.X_test[ind, 0], self.X_test[ind, 1], s=10, c=colour, marker=marker)
        ax.scatter(synthetic_images[:, 0], synthetic_images[:, 1], s=10, c='y', marker='*')
        # raw strings: "\ " is a mathtext space, not a valid Python escape sequence
        ax.legend((r'$class\ 0$', r'$class\ 1$', r'$OOD\ sample$'), fontsize='x-small')

        if save2file:
            filename = "./images/mnist_%d.png" % step
            out_dir = os.path.dirname(filename)
            if not os.path.isdir(out_dir):
                os.makedirs(out_dir)
            fig.savefig(filename)
        # close only this figure so repeated calls do not accumulate open figures
        plt.close(fig)


        

In [91]:
# Build the data set and all networks, then run the full training loop.
# Training prints losses and saves a plot every gan.save_interval iterations.
gan = GAN()
gan.train()

(1000, 1)
(2000, 1)
((2000, 2), (2000, 1))
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_330 (Dense)            (None, 256)               768       
_________________________________________________________________
leaky_re_lu_142 (LeakyReLU)  (None, 256)               0         
_________________________________________________________________
dense_331 (Dense)            (None, 256)               65792     
_________________________________________________________________
leaky_re_lu_143 (LeakyReLU)  (None, 256)               0         
_________________________________________________________________
dense_332 (Dense)            (None, 256)               65792     
_________________________________________________________________
leaky_re_lu_144 (LeakyReLU)  (None, 256)               0         
_________________________________________________________________
dense_333 (Dense)            (Non

epoch: 3400
('d_loss', [0.6185015, 0.755])
('g_loss', [2.1979754, 0.7987216, 1.3992538, 0.0, 1.3992538])
('c_loss:', [0.2506622346639633, 1.0])
epoch: 3500
('d_loss', [0.61125624, 0.735])
('g_loss', [2.3001392, 0.9074166, 1.3927226, 0.0, 1.3927226])
('c_loss:', [0.3020480636358261, 0.988])
epoch: 3600
('d_loss', [0.5239623, 0.755])
('g_loss', [2.3323357, 0.94179255, 1.3905431, 0.0, 1.3905431])
('c_loss:', [0.26105241930484774, 0.979])
epoch: 3700
('d_loss', [0.63140684, 0.685])
('g_loss', [2.3216884, 0.92838556, 1.3933029, 0.0, 1.3933029])
('c_loss:', [0.34361365461349486, 0.922])
epoch: 3800
('d_loss', [0.616743, 0.685])
('g_loss', [2.3243597, 0.9293277, 1.3950319, 0.0, 1.3950319])
('c_loss:', [0.37417788314819334, 0.773])
epoch: 3900
('d_loss', [0.67713284, 0.56])
('g_loss', [2.299767, 0.90664923, 1.3931178, 0.0, 1.3931178])
('c_loss:', [0.39975203704833984, 0.856])


KeyboardInterrupt: 