# Genotype Factorized Variational Autoencoder

Putting together our various strains of work for a toy factorized variational autoencoder based on genotypes. The idea is to get the basic framework running, with the aim of getting some sensible results during training (i.e. can we maximize the evidence lower bound).

In [3]:
import tensorflow as tf
from pandas_plink import read_plink
import pandas as pd
import numpy as np
import datetime

tfd = tf.contrib.distributions

  from ._conv import register_converters as _register_converters


Define model and analysis parameters.

In [4]:
N = 100 # samples
M = 1000 # sites
D = 2 # latent dimension

batch_size = N
epochs = 10

Setup some `io` related things.

In [5]:
def decode_tfrecords(tfrecords_filename, m_variants):
    '''
    Parse a tf.string pointing to *.tfrecords into a genotype tensor,  rows: variants, cols: samples)
    Helpful blog post:
    http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
    '''
    data = tf.parse_example([tfrecords_filename],
        {'genotypes': tf.FixedLenFeature([], tf.string)})

    gene_vector = tf.decode_raw(data['genotypes'], tf.int8)
    gene_vector = tf.reshape(gene_vector, [1, m_variants])

    return gene_vector

Some helper functions for model specification.

In [6]:
def make_encoder(data, z_dim, batch_size, num_features):
    data = tf.reshape(data, [batch_size, num_features])

    # sample latent variables
    x = tf.layers.dense(inputs=data,
            units=64, activation=tf.nn.sigmoid)
    x = tf.layers.dense(inputs=x,
            units=32, activation=tf.nn.sigmoid)
    x = tf.layers.dense(inputs=x,
            units=128, activation=tf.nn.sigmoid)
    u_net = tf.layers.dense(inputs=x,
                      units = z_dim * 2,
                      activation=None)
    u_loc = u_net[..., :z_dim]
    u_scale = tf.nn.softplus(u_net[..., z_dim:] + 0.5)
    u = tfd.MultivariateNormalDiag(u_loc, scale_diag=u_scale,
                                   name='sample_latent_U')
    
    # observation latent variables
    x_t = tf.transpose(data)
    x_t = tf.layers.dense(inputs=x_t,
            units=64, activation=tf.nn.sigmoid)
    x_t = tf.layers.dense(inputs=x_t,
            units=32, activation=tf.nn.sigmoid)
    x_t = tf.layers.dense(inputs=x_t,
            units=16, activation=tf.nn.sigmoid)
    v_net = tf.layers.dense(inputs=x_t,
                      units = z_dim * 2,
                      activation=None)
    v_loc = v_net[..., z_dim:]    
    v_scale = tf.nn.softplus(v_net[..., :z_dim] + 0.5)
    
    v = tfd.MultivariateNormalDiag(v_loc, scale_diag=v_scale,
                                   name='observation_latent_V')
    
    return u, v


def make_decoder(u, v, batch_size, num_features, z_dim):
    
    # "dot product decoder"
    z = tf.tensordot(u, v, axes=[[1], [1]])
    z = tf.reshape(z, [1, num_features*batch_size])
    logits = tf.nn.softplus(z)
    
    # assume fixed, unit variance
    data_dist = tfd.Independent(tfd.Binomial(logits=logits, total_count=2.0),
                    reinterpreted_batch_ndims=1,
                    name='posterior_p')
        
    return data_dist


def make_prior(z_dim):
    u_prior =  tfd.MultivariateNormalDiag(scale_diag=tf.ones(z_dim),
                                    name='U')
    v_prior = tfd.MultivariateNormalDiag(scale_diag=tf.ones(z_dim),
                                    name='V')

    return u_prior, v_prior

Model and input pipeline definition.

In [7]:
graph = tf.Graph()
with graph.as_default():
    # input pipeline
    dataset = tf.data.TFRecordDataset('data/test.tfrecords', compression_type=tf.constant('ZLIB'))
    dataset = dataset.map(lambda fn: decode_tfrecords(fn, M))
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    data = iterator.get_next()
    data = tf.cast(data, tf.float32)
    
    with tf.variable_scope('priors'):
        u_prior, u_prior = make_prior(z_dim=D)
        
    # inference network; encoder
    with tf.variable_scope('encoder'):
        u_encoder, v_encoder = make_encoder(data, z_dim=D,
                                            batch_size=batch_size,
                                            num_features=M)
    
    u = u_encoder.sample()
    v = v_encoder.sample()

    # generative network; decoder
    with tf.variable_scope('decoder'):
        decoder_p = make_decoder(u, v, z_dim=D, num_features=M,
                                 batch_size=batch_size)
    
    # prior
    with tf.variable_scope('prior'):
        u_prior, v_prior = make_prior(z_dim=D)

    # loss
    u_kl = tf.reduce_sum(tfd.kl_divergence(u_encoder, u_prior))
    v_kl = tf.reduce_sum(tfd.kl_divergence(v_encoder, v_prior))
    likelihood = tf.reduce_sum(decoder_p.log_prob(tf.reshape(data, [1, N*M])))
    elbo = -u_kl - v_kl + likelihood
    tf.summary.scalar('elbo', elbo)
    tf.summary.scalar('minus_u_kl', tf.negative(u_kl))
    tf.summary.scalar('minus_v_kl', tf.negative(v_kl))
    tf.summary.scalar('likelihood', likelihood)

    
    # optimizer
    optimizer = tf.train.AdamOptimizer(0.001).minimize(-elbo)
    merged = tf.summary.merge_all()

Estimate model parameters and monitor the routine in tensorboard.

In [8]:
# tensorboard
run = 'run-{date:%d.%m.%Y_%H:%M:%S}'.format(date=datetime.datetime.now())
tb_writer = tf.summary.FileWriter('/logs/geno_fvae/' + run, graph=graph)

# training
with tf.Session(graph=graph) as sess:    
    sess.run(tf.global_variables_initializer())
    for epoch in range(epochs):
        sess.run(iterator.initializer)
        while True:
            try:
                _, tb_summary, epoch_elbo = sess.run([optimizer, merged, elbo])
                print(epoch_elbo)
                tb_writer.add_summary(tb_summary, epoch)
            except tf.errors.OutOfRangeError:
                break

-100976.47
-100537.88
-97555.02
-96816.87
-96311.48
-95627.76
-94476.54
-94997.836
-95082.195
-94932.6
