<a href="https://colab.research.google.com/github/tbonne/Net_VarEncoder/blob/main/code/ITE_VarEncoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Introduction

Here we're going to test out the use of variational autoencoders to estimate the latent network which can best predict observed ITE interactions between vervet monkey groups.



In [11]:
import pandas as pd
from tensorflow import keras
from tensorflow.keras import layers

#### Data

In [17]:
df_ite = pd.read_csv("https://raw.githubusercontent.com/tbonne/Net_VarEncoder/main/data/mat_all.csv", header=None)

ite = df_ite.to_numpy()

ite.shape

(157, 2618)

#### Autoencoder

lay out the encoder... this converts the observed ite into a latent representation.

In [28]:
latent_dim = 2

encoder_inputs = keras.Input(shape=(157, 1))
x = layers.Dense(32, activation="relu")(encoder_inputs)
x = layers.Dense(64, activation="relu")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var], name="encoder")

encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_9 (InputLayer)           [(None, 157, 1)]     0           []                               
                                                                                                  
 dense_8 (Dense)                (None, 157, 32)      64          ['input_9[0][0]']                
                                                                                                  
 dense_9 (Dense)                (None, 157, 64)      2112        ['dense_8[0][0]']                
                                                                                                  
 flatten_4 (Flatten)            (None, 10048)        0           ['dense_9[0][0]']                
                                                                                            

Write the sampler

In [29]:
import tensorflow as tf

class Sampler(layers.Layer):
  
  def call(self, z_mean, z_log_var):
    batch_size = tf.shape(z_mean)[0]
    z_size = tf.shape(z_mean)[1]
    epsilon = tf.random.normal(shape=(batch_size, z_size))
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon



Write the decoder

In [31]:
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Dense(64, activation="relu")(x)
x = layers.Dense(32, activation="relu")(x)
decoder_outputs = layers.Dense(1, activation="sigmoid")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")

decoder.summary()

Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_11 (InputLayer)       [(None, 2)]               0         
                                                                 
 dense_12 (Dense)            (None, 3136)              9408      
                                                                 
 reshape_3 (Reshape)         (None, 7, 7, 64)          0         
                                                                 
 dense_13 (Dense)            (None, 7, 7, 64)          4160      
                                                                 
 dense_14 (Dense)            (None, 7, 7, 32)          2080      
                                                                 
 dense_15 (Dense)            (None, 7, 7, 1)           33        
                                                                 
Total params: 15,681
Trainable params: 15,681
Non-trainable

Write the training code

In [32]:
class VAE(keras.Model):
  def __init__(self, encoder, decoder, **kwargs):
    super().__init__(**kwargs)
    self.encoder = encoder
    self.decoder = decoder
    self.sampler = Sampler()
    self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
    self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
    self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

@property
def metrics(self):
  return [self.total_loss_tracker,self.reconstruction_loss_tracker,self.kl_loss_tracker]

def train_step(self, data):
  with tf.GradientTape() as tape:
    z_mean, z_log_var = self.encoder(data)
    z = self.sampler(z_mean, z_log_var)
    reconstruction = decoder(z)
    reconstruction_loss = tf.reduce_mean(
      tf.reduce_sum(
        keras.losses.binary_crossentropy(data, reconstruction),axis=(1, 2) )
    )
    kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
    total_loss = reconstruction_loss + tf.reduce_mean(kl_loss)
    grads = tape.gradient(total_loss, self.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    self.total_loss_tracker.update_state(total_loss)
    self.reconstruction_loss_tracker.update_state(reconstruction_loss)
    self.kl_loss_tracker.update_state(kl_loss)
    return {
    "total_loss": self.total_loss_tracker.result(),
    "reconstruction_loss": self.reconstruction_loss_tracker.result(),
    "kl_loss": self.kl_loss_tracker.result(),}

Try training the model...

In [35]:
import numpy as np

#(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()


#mnist_digits = np.concatenate([x_train, x_test], axis=0)
ite_scaled = np.expand_dims(ite, -1).astype("float32") / 4

#mnist_digits.shape
#ite.shape
vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam(), run_eagerly=True)
vae.fit(ite_scaled, epochs=30, batch_size=128)

KeyboardInterrupt: ignored

Epoch 1/30


Exception ignored in: 'zmq.backend.cython.message.Frame.__dealloc__'
Traceback (most recent call last):
  File "zmq/backend/cython/checkrc.pxd", line 13, in zmq.backend.cython.checkrc._check_rc
KeyboardInterrupt


NotImplementedError: ignored