In [1]:
import io
import pickle
import datetime
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import *
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam

In [2]:
# All models assume MNIST-shaped inputs: image shape = (28, 28, 1).
# Layers that are immediately followed by BatchNormalization are built with use_bias=False.

In [3]:
# Weight initializer shared by every Dense / convolutional layer below:
# normal distribution with mean 0 and standard deviation 0.02.
w_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

# Generator: 100-d noise -> 7x7x256 -> 7x7x128 -> 14x14x64 -> 28x28x1 (tanh).
G = Sequential(
    [
        Dense(
            units=7 * 7 * 256,
            use_bias=False,
            input_shape=(100,),
            kernel_initializer=w_init,
        ),
        BatchNormalization(),
        ReLU(),

        Reshape(target_shape=(7, 7, 256)),
        Conv2DTranspose(
            filters=128,
            kernel_size=(5, 5),
            strides=(1, 1),
            padding='same',
            use_bias=False,
            kernel_initializer=w_init,
        ),
        BatchNormalization(),
        ReLU(),

        Conv2DTranspose(
            filters=64,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding='same',
            use_bias=False,
            kernel_initializer=w_init,
        ),
        BatchNormalization(),
        ReLU(),

        # Output layer: keeps its bias because no BatchNormalization follows.
        Conv2DTranspose(
            filters=1,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding='same',
            activation='tanh',
            kernel_initializer=w_init,
        ),
    ],
    name='Generator',
)

# Discriminator: 28x28x1 -> 14x14x64 -> 7x7x128 -> flatten -> one unit.
# The final Dense layer has no activation, so the model emits raw logits.
D = Sequential(
    [
        Conv2D(
            filters=64,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding='same',
            use_bias=False,
            input_shape=(28, 28, 1),
            kernel_initializer=w_init,
        ),
        BatchNormalization(),
        LeakyReLU(0.2),

        Conv2D(
            filters=128,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding='same',
            use_bias=False,
            kernel_initializer=w_init,
        ),
        BatchNormalization(),
        LeakyReLU(0.2),

        Flatten(),
        Dense(units=1, kernel_initializer=w_init),
    ]
)

In [4]:
class Vanilla_DCGAN:
    """DCGAN wrapper bundling a generator, a discriminator and one training step.

    The discriminator is treated as producing raw logits: the losses use
    ``BinaryCrossentropy(from_logits=True)`` and ``Accuracy_stats`` applies a
    sigmoid before thresholding.

    Args:
        G: Keras generator model; its input is noise of shape
            ``(batch, noise_size)``.
        D: Keras discriminator model applied to real and generated samples.
        isGraph: When True, logged samples are drawn as directed graphs
            instead of images.
    """

    def __init__(self,G,D,isGraph=False):
        self.G=G
        self.D=D
        self.isGraph=isGraph

        self.G_optim=Adam(learning_rate=0.0002,beta_1=0.5)
        self.D_optim=Adam(learning_rate=0.0002,beta_1=0.5)

        # Length of the noise vector, read from the generator's input shape.
        self.noise_size=self.G.input_shape[1]
        self.cross_entropy = BinaryCrossentropy(from_logits=True)

        # Shape used when converting generated samples to numpy for plotting:
        # a trailing single channel is dropped, e.g. (28, 28, 1) -> (28, 28).
        image_shape=self.G.output_shape[1:]
        if image_shape[-1]==1:
            self.reshape_dim=image_shape[:-1]
        else:
            self.reshape_dim=image_shape

    def sample_z(self,batch_size):
        """Return a ``(batch_size, noise_size)`` noise tensor.

        The noise is drawn with ``tf.random.uniform`` using its default bounds.
        """
        return tf.random.uniform([batch_size,self.noise_size])

    @tf.function
    def generator_loss(self,fake_output):
        """Cross-entropy of the fake logits against an all-ones ("real") target."""
        return self.cross_entropy(tf.ones_like(fake_output), fake_output)

    @tf.function
    def discriminator_loss(self,real_output,fake_output):
        """Sum of cross-entropies: real logits vs. ones and fake logits vs. zeros."""
        real_loss = self.cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = self.cross_entropy(tf.zeros_like(fake_output), fake_output)
        return real_loss + fake_loss

    @tf.function
    def train_on_batch(self,x):
        """Run one simultaneous update of the discriminator and the generator.

        Args:
            x: Batch of real samples fed to the discriminator.

        Returns:
            Tuple ``(d_loss, g_loss)`` of scalar loss tensors computed before
            the weights are updated.
        """
        # tf.shape yields the run-time batch size, so the step also works when
        # the static batch dimension is unknown (x.shape[0] would be None).
        batch_size=tf.shape(x)[0]
        z=self.sample_z(batch_size)

        # Two tapes: each loss is differentiated w.r.t. its own network only.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
            real_output=self.D(x,training=True)
            fake_output=self.D(self.G(z,training=True),training=True)
            d_loss=self.discriminator_loss(real_output,fake_output)
            g_loss=self.generator_loss(fake_output)

        d_grads=dis_tape.gradient(d_loss,self.D.trainable_variables)
        g_grads=gen_tape.gradient(g_loss,self.G.trainable_variables)

        self.D_optim.apply_gradients(zip(d_grads,self.D.trainable_variables))
        self.G_optim.apply_gradients(zip(g_grads,self.G.trainable_variables))

        return d_loss,g_loss

    def sample_images(self,number):
        """Generate ``number`` samples as a numpy array of shape ``(number, *reshape_dim)``."""
        z=self.sample_z(number)
        return self.G(z,training=False).numpy().reshape(number,*self.reshape_dim)

    def sample_graphs(self,number):
        """Generate ``number`` samples and binarise them into adjacency matrices.

        Generator outputs are mapped with ``v*0.5+0.5`` (tanh range [-1, 1]
        becomes [0, 1]) and then rounded with ``np.round``.
        """
        scaled_adjs=self.sample_images(number)
        rescale_adj=(scaled_adjs*0.5)+0.5
        return np.round(rescale_adj)

    def get_img_of_gen_samples_from_fig(self,figure):
        """Render ``figure`` to PNG and return it as a ``(1, H, W, 4)`` image tensor.

        The figure is closed after rendering.
        """
        buf = io.BytesIO()
        # Save the figure that was passed in; plt.savefig would instead save
        # whichever figure happens to be current.
        figure.savefig(buf, format='png')
        plt.close(figure)
        image = tf.image.decode_png(buf.getvalue(), channels=4)
        # Leading batch axis, as expected by tf.summary.image.
        return tf.expand_dims(image, 0)

    def get_generated_image_grid(self):
        """Return a 5x5 matplotlib figure of 25 freshly generated images."""
        gen_images=self.sample_images(25)
        figure = plt.figure(figsize=(10,10))
        for i in range(25):
            # Explicit axes keep the drawing independent of pyplot's global state.
            ax=figure.add_subplot(5, 5, i + 1, title=f"gen:{i}")
            ax.set_xticks([])
            ax.set_yticks([])
            ax.grid(False)
            ax.imshow(gen_images[i], cmap=plt.cm.binary)
        return figure

    def get_generated_graph_grid(self):
        """Return a 5x5 matplotlib figure of 25 generated graphs drawn as digraphs."""
        gen_images=self.sample_graphs(25)
        figure = plt.figure(figsize=(10,10))
        for i in range(25):
            ax=figure.add_subplot(5, 5, i + 1, title=f"gen:{i}")
            network=nx.from_numpy_array(gen_images[i],create_using=nx.DiGraph)
            nx.draw(network,ax=ax,node_size=50)
        return figure

    def get_GenImg_out_2_Log(self):
        """Return the sample grid (graphs or images, per ``isGraph``) as an image tensor."""
        if self.isGraph:
            figure=self.get_generated_graph_grid()
        else:
            figure=self.get_generated_image_grid()
        return self.get_img_of_gen_samples_from_fig(figure)

    def Accuracy_stats(self,x):
        """Compute discriminator and generator accuracies on one batch.

        Args:
            x: Batch of real samples; an equally sized fake batch is generated.

        Returns:
            Tuple ``(dis_acc, gen_acc)``: ``dis_acc`` is the mean of the
            fraction of real samples predicted real and the fraction of fake
            samples predicted fake; ``gen_acc`` is the fraction of fake samples
            predicted real.
        """
        batch_size=x.shape[0]
        z=self.sample_z(batch_size)
        # Hard 0/1 decisions: sigmoid of the logits, then rounding.
        real_pred=tf.round(tf.nn.sigmoid(self.D(x,training=False)))
        fake_pred=tf.round(tf.nn.sigmoid(self.D(self.G(z,training=False),training=False)))

        dis_real_acc=tf.reduce_mean(tf.cast(tf.equal(tf.ones_like(real_pred),real_pred),dtype=tf.float32))
        dis_fake_acc=tf.reduce_mean(tf.cast(tf.equal(tf.zeros_like(fake_pred),fake_pred),dtype=tf.float32))

        dis_acc=(dis_real_acc+dis_fake_acc).numpy()/2

        gen_fake_acc=tf.reduce_mean(tf.cast(tf.equal(tf.ones_like(fake_pred),fake_pred),dtype=tf.float32))
        gen_acc=gen_fake_acc.numpy()

        return dis_acc,gen_acc

In [5]:
def load_mnist():
    """Load the MNIST training images rescaled for a tanh-output generator.

    Labels and the test split are discarded.

    Returns:
        float32 numpy array of shape ``(num_images, 28, 28, 1)`` holding
        ``(pixel - 127.5) / 127.5``.
    """
    (x_train, _), _ = tf.keras.datasets.mnist.load_data(path="mnist.npz")
    # Cast before scaling: integer pixels combined with a Python float would
    # promote to float64, doubling memory and forcing a cast on every batch
    # (Keras models compute in float32 by default).
    x_train = x_train.astype(np.float32)
    return ((x_train - 127.5) / 127.5).reshape(-1, 28, 28, 1)

def load_graphs(community_size=28):
    """Load pickled graphs and return their rescaled adjacency matrices.

    Args:
        community_size: Selects the file
            ``./data/true_data_{community_size}.pickle`` and the side length
            the adjacency matrices are reshaped to (presumably the number of
            nodes per graph -- confirm against the data).

    Returns:
        float32 numpy array of shape
        ``(num_graphs, community_size, community_size, 1)`` where every
        adjacency entry ``a`` is mapped to ``(a - 0.5) / 0.5`` (so a 0/1
        matrix becomes -1/+1).
    """
    # NOTE(security): pickle.load can execute arbitrary code; only use it on
    # data files from a trusted source.
    with open(f"./data/true_data_{community_size}.pickle", 'rb') as handle:
        data = pickle.load(handle)['true_data']

    # nx.adjacency_matrix is the supported name; nx.graphmatrix.adj_matrix is a
    # deprecated alias that is no longer available in recent NetworkX releases.
    graphs = np.array(
        [nx.adjacency_matrix(d).toarray() for d in data],
        dtype=np.float32,
    )
    # Optional label smoothing: multiply the result by 0.95.
    scaled_graphs = (graphs - 0.5) / 0.5
    return scaled_graphs.reshape(-1, community_size, community_size, 1)

In [6]:
class TrainingVGAN:
    """Epoch-based training loop with TensorBoard logging for a GAN wrapper.

    Args:
        VGAN: Object exposing ``Accuracy_stats(batch)``,
            ``train_on_batch(batch)`` and ``get_GenImg_out_2_Log()``.
        dataset: Array of training samples; the first axis indexes samples.
        batch_size: Samples per batch; an incomplete final batch is dropped.
        experiment_name: Suffix appended to the timestamped log directory.

    Raises:
        ValueError: If ``batch_size`` is not positive or ``dataset`` holds
            fewer samples than one batch.
    """

    def __init__(self,VGAN,dataset,batch_size=128,experiment_name=""):
        if batch_size<1:
            raise ValueError("batch_size must be a positive integer")
        num_samples=dataset.shape[0]
        num_of_batch=num_samples//batch_size
        if num_of_batch==0:
            # With drop_remainder=True there would be no batch to train on, and
            # the per-epoch averages would divide by zero.
            raise ValueError("dataset is smaller than one batch")

        self.VGAN=VGAN
        self.dataset=(tf.data.Dataset.from_tensor_slices(dataset)
                      .shuffle(num_samples)
                      .batch(batch_size,drop_remainder=True))
        self.epoch=0
        self.num_of_batch=num_of_batch

        self.current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        self.log_dir = 'logs/VGAN/' + self.current_time + experiment_name
        self.summary_writer = tf.summary.create_file_writer(self.log_dir)

    def train(self,epochs=50,log_epoch=1):
        """Train until ``self.epoch`` reaches ``epochs``.

        ``self.epoch`` persists between calls, so calling ``train`` again with
        a larger ``epochs`` resumes where the previous call stopped.

        Args:
            epochs: Total number of epochs to have completed on return.
            log_epoch: Write summaries every ``log_epoch`` epochs.

        Raises:
            ValueError: If ``log_epoch`` is not positive.
        """
        if log_epoch<1:
            raise ValueError("log_epoch must be a positive integer")

        # Strict '<': epochs are counted from 0, so '<=' would run one extra pass.
        while self.epoch<epochs:
            D_LOSS=0;G_LOSS=0
            D_ACC=0;G_ACC=0
            for images in self.dataset:
                # Accuracy is measured before this batch's weight update.
                d_acc,g_acc=self.VGAN.Accuracy_stats(images)
                d_loss,g_loss=self.VGAN.train_on_batch(images)
                D_LOSS+=d_loss.numpy();G_LOSS+=g_loss.numpy()
                D_ACC+=d_acc;G_ACC+=g_acc
            # Per-epoch means over all batches.
            D_LOSS/=self.num_of_batch;G_LOSS/=self.num_of_batch
            D_ACC/=self.num_of_batch;G_ACC/=self.num_of_batch

            if self.epoch%log_epoch==0:
                print("ON EPOCH {}".format(self.epoch))
                GEN_IMAGES=self.VGAN.get_GenImg_out_2_Log()
                with self.summary_writer.as_default():
                    tf.summary.scalar('loss/Generator', G_LOSS, step=self.epoch)
                    tf.summary.scalar('loss/Discriminator', D_LOSS, step=self.epoch)
                    tf.summary.scalar('acc/Generator', G_ACC, step=self.epoch)
                    tf.summary.scalar('acc/Discriminator', D_ACC, step=self.epoch)
                    tf.summary.image('Generated_images',GEN_IMAGES,step=self.epoch)
                # Push the events to disk so they are visible while training runs.
                self.summary_writer.flush()

            self.epoch+=1

# Training

# Each experiment trains its own copies of the networks.
# tf.keras.models.clone_model rebuilds the architecture with newly initialized
# weights, so the graph experiment does not start from weights that were
# already trained on MNIST (handing the same G and D objects to both
# experiments would share their weights), and re-running this cell always
# starts from scratch.
GAN=Vanilla_DCGAN(tf.keras.models.clone_model(G),tf.keras.models.clone_model(D),isGraph=False)

algo=TrainingVGAN(GAN,load_mnist(),128,"mnist")

algo.train(50)

GAN=Vanilla_DCGAN(tf.keras.models.clone_model(G),tf.keras.models.clone_model(D),isGraph=True)

algo=TrainingVGAN(GAN,load_graphs(28),128,"yeast-graph")

algo.train(300) # 120 epochs is enough

In [2]:
#NICE