Distribued Tensorflow for genetic data
=========================

# Initial Thoughts

- It might be easier to just estimate the paramters with a regression for each block than running it in parallel over a neural network
- I probably need to add an illustration
- running a local neural network would give me more flexibility in terms on intra-block interaction (although that seems not really likely)
- A way to go this would probably with tensforflow and dask. Although I am not sure how this would work in detail

## Current approach

- genotype matrix was re-written in sample-major mode

In [46]:
%matplotlib inline
import os
from wepredict.plink_reader import *
import logging
import pickle
import pandas as pd
import tensorflow as tf
from sklearn.preprocessing import scale
from tqdm import tqdm_notebook
import re

In [26]:
lg = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

batch_folder = '../data/sample_major_1kg/'
batch_files = os.listdir(batch_folder)
batch_files = [os.path.join(batch_folder, k) for k in batch_files]
lg.info('There are %s file present', len(batch_files))

ld_blocks = pickle.load(open('../data/sim_1000G_chr10.ld_blocks.pickel', 'rb'))
ld_blocks = ld_blocks[10] # remove the chromosom
lg.info('There are %s LD blocks to process', len(ld_blocks))

fam = pd.read_table('../data/sim_1000G_chr10.txt')
pheno = fam['V1'].values
sub_pheno = pheno[0:100]
sub_pheno = sub_pheno.reshape(100, 1)
lg.info('There are %s subjects', len(pheno))

bim = pd.read_table('../data/sim_1000G_chr10.bim', header=None)
snps = bim[1].values
p = len(snps)
lg.info('There are %s SNPs', len(snps))

INFO:__main__:There are 25 file present
INFO:__main__:There are 85 LD blocks to process
INFO:__main__:There are 2504 subjects
INFO:__main__:There are 405378 SNPs


In [3]:
def make_block_id(snps, blocks):
    output = list()
    u = 0
    for i, b in enumerate(blocks):
        nn = len(b)
        mask = np.zeros(len(snps), dtype=bool)
        mask[u:(u+nn)] = True
        u+=nn
        output.append(mask)
        if i % 10 == 0:
            lg.debug('Processing LD block %s', i)
    return output

from sklearn.utils import shuffle

def data_iter(paths, pheno, shuffle_values=True):
    np.random.shuffle(paths)
    for p in paths:
        dat, bool_index = np.load(p)
        dat = dat.astype(np.float32)
        batch_pheno = pheno[bool_index].astype(float)
        batch_pheno = batch_pheno.reshape(len(batch_pheno), 1)
        if shuffle_values:
            dat, batch_pheno = shuffle(dat, batch_pheno)
        yield dat, batch_pheno, p
        
def get_block_matrix(data, block):
    num_rows, num_cols = data.shape
    num_features = np.sum(block)
    new_block = tf.boolean_mask(data, block, axis=1)
    new_block = tf.reshape(new_block, shape=(num_rows, num_features))
    lg.debug('size of new_block is %s with type %s', new_block.shape, new_block.dtype)
    return new_block

## Now the Tensorflow part

In [68]:
mini_batch_size = 100

def linear(in_var, num_var):
    lg.debug('inpyt type is: %s', in_var.dtype)
    weights = tf.get_variable('weights', [num_var, 1],
                              initializer=tf.random_normal_initializer(0, 0.001))
    bias = tf.get_variable('bias', [1, 1],
                           initializer=tf.constant_initializer(0.0))
    comb = tf.matmul(in_var, weights)
    return tf.nn.relu(comb + bias)

def linear_block_wise(data, blocks):
    linear_block_variables = list()
    with tf.variable_scope('LD_blocks'):
        for i, b in tqdm_notebook(enumerate(blocks)):
            with tf.variable_scope('linear_block_'+str(i)):
                block_matrix = get_block_matrix(data, b)
                p = sum(b)
                lg.debug('processing LD block %s wiht %s SNPs', i, p)
                linear_block_variables.append(linear(block_matrix, p))
    return linear_block_variables

def add_layers_on_blocks(ld_blocks_weights, num_layers):
    n_blocks = len(ld_blocks_weights)
    stacked_blocks = tf.concat(ld_blocks_weights, 1)
    lg.debug('stacked blocks are %s', stacked_blocks.shape)
    with tf.variable_scope('layer_over_blocks'):
        weights = tf.get_variable('weights',
                                  [n_blocks, num_layers],
                                 initializer=tf.random_normal_initializer(0, 0.001))
        bias = tf.get_variable('bias', [num_layers],
                               initializer=tf.constant_initializer(0.0))
        comb = tf.nn.bias_add(tf.matmul(stacked_blocks, weights), bias)
    return tf.nn.relu(comb)

def add_layers(origin, num_layers, scope_name, output=False):
    with tf.variable_scope(scope_name):
        n, p = origin.shape
        weights = tf.get_variable('weights',
                                  [p, num_layers],
                                 initializer=tf.random_normal_initializer())
        bias = tf.get_variable('bias', [num_layers],
                               initializer=tf.constant_initializer(0.0))
        comb = tf.nn.bias_add(tf.matmul(origin, weights), bias)
    if output:
        return comb
    else:
        return tf.nn.relu(comb)

In [69]:
tf.reset_default_graph()
bool_ldblocks = make_block_id(snps, ld_blocks)

DEBUG:wepredict.plink_reader:Processing LD block 0
DEBUG:wepredict.plink_reader:Processing LD block 10
DEBUG:wepredict.plink_reader:Processing LD block 20
DEBUG:wepredict.plink_reader:Processing LD block 30
DEBUG:wepredict.plink_reader:Processing LD block 40
DEBUG:wepredict.plink_reader:Processing LD block 50
DEBUG:wepredict.plink_reader:Processing LD block 60
DEBUG:wepredict.plink_reader:Processing LD block 70
DEBUG:wepredict.plink_reader:Processing LD block 80


In [70]:
plink_data = tf.placeholder(tf.float32, [mini_batch_size, p], name='genotype')

In [71]:
pheno_p = tf.placeholder(tf.float32, [mini_batch_size, 1], name='phenotype')
block_layer = linear_block_wise(plink_data, bool_ldblocks)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

DEBUG:wepredict.plink_reader:size of new_block is (100, 1506) with type <dtype: 'float32'>
DEBUG:wepredict.plink_reader:processing LD block 0 wiht 1506 SNPs
DEBUG:wepredict.plink_reader:inpyt type is: <dtype: 'float32'>
DEBUG:wepredict.plink_reader:size of new_block is (100, 2986) with type <dtype: 'float32'>
DEBUG:wepredict.plink_reader:processing LD block 1 wiht 2986 SNPs
DEBUG:wepredict.plink_reader:inpyt type is: <dtype: 'float32'>
DEBUG:wepredict.plink_reader:size of new_block is (100, 3177) with type <dtype: 'float32'>
DEBUG:wepredict.plink_reader:processing LD block 2 wiht 3177 SNPs
DEBUG:wepredict.plink_reader:inpyt type is: <dtype: 'float32'>
DEBUG:wepredict.plink_reader:size of new_block is (100, 5049) with type <dtype: 'float32'>
DEBUG:wepredict.plink_reader:processing LD block 3 wiht 5049 SNPs
DEBUG:wepredict.plink_reader:inpyt type is: <dtype: 'float32'>
DEBUG:wepredict.plink_reader:size of new_block is (100, 4088) with type <dtype: 'float32'>
DEBUG:wepredict.plink_reader:

DEBUG:wepredict.plink_reader:processing LD block 37 wiht 5140 SNPs
DEBUG:wepredict.plink_reader:inpyt type is: <dtype: 'float32'>
DEBUG:wepredict.plink_reader:size of new_block is (100, 4175) with type <dtype: 'float32'>
DEBUG:wepredict.plink_reader:processing LD block 38 wiht 4175 SNPs
DEBUG:wepredict.plink_reader:inpyt type is: <dtype: 'float32'>
DEBUG:wepredict.plink_reader:size of new_block is (100, 2310) with type <dtype: 'float32'>
DEBUG:wepredict.plink_reader:processing LD block 39 wiht 2310 SNPs
DEBUG:wepredict.plink_reader:inpyt type is: <dtype: 'float32'>
DEBUG:wepredict.plink_reader:size of new_block is (100, 1159) with type <dtype: 'float32'>
DEBUG:wepredict.plink_reader:processing LD block 40 wiht 1159 SNPs
DEBUG:wepredict.plink_reader:inpyt type is: <dtype: 'float32'>
DEBUG:wepredict.plink_reader:size of new_block is (100, 6773) with type <dtype: 'float32'>
DEBUG:wepredict.plink_reader:processing LD block 41 wiht 6773 SNPs
DEBUG:wepredict.plink_reader:inpyt type is: <dtyp

DEBUG:wepredict.plink_reader:inpyt type is: <dtype: 'float32'>
DEBUG:wepredict.plink_reader:size of new_block is (100, 2486) with type <dtype: 'float32'>
DEBUG:wepredict.plink_reader:processing LD block 75 wiht 2486 SNPs
DEBUG:wepredict.plink_reader:inpyt type is: <dtype: 'float32'>
DEBUG:wepredict.plink_reader:size of new_block is (100, 6502) with type <dtype: 'float32'>
DEBUG:wepredict.plink_reader:processing LD block 76 wiht 6502 SNPs
DEBUG:wepredict.plink_reader:inpyt type is: <dtype: 'float32'>
DEBUG:wepredict.plink_reader:size of new_block is (100, 6804) with type <dtype: 'float32'>
DEBUG:wepredict.plink_reader:processing LD block 77 wiht 6804 SNPs
DEBUG:wepredict.plink_reader:inpyt type is: <dtype: 'float32'>
DEBUG:wepredict.plink_reader:size of new_block is (100, 3160) with type <dtype: 'float32'>
DEBUG:wepredict.plink_reader:processing LD block 78 wiht 3160 SNPs
DEBUG:wepredict.plink_reader:inpyt type is: <dtype: 'float32'>
DEBUG:wepredict.plink_reader:size of new_block is (10




In [72]:
next_layer = add_layers_on_blocks(block_layer, 50)
n1_layer = add_layers(next_layer, 20, 'layer_1')
n2_layer = add_layers(n1_layer, 20, 'layer_2')
n3_layer = add_layers(n2_layer, 20, 'layer_3')
out_layer = add_layers(n3_layer, 1, 'outputlayer', output=True)

DEBUG:wepredict.plink_reader:stacked blocks are (100, 85)


In [73]:
LD_block_weights = [k for k in tf.trainable_variables() if bool(re.match('(?=.*LD_blocks.*)(?=.*weight.*)', k.name))]

In [75]:
#writer = tf.summary.FileWriter('.')
regular = tf.contrib.layers.l1_regularizer(0.001, scope=None)
penal = tf.contrib.layers.apply_regularization(regularizer=regular, weights_list=LD_block_weights)
loss = tf.reduce_mean(tf.square(out_layer-pheno_p)) + penal
optimizer = tf.train.GradientDescentOptimizer(0.0001)
train = optimizer.minimize(loss)

tf.summary.scalar('loss', loss)

init =  tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for i in range(100):
        iterator = data_iter(batch_files, pheno)
        u = 0
        for x, y, path in iterator:
            _, c = sess.run([train, loss], feed_dict={plink_data: x, pheno_p: y})
            assert np.isfinite(c)
            if u % 10 == 0:
                print(u, ':', c)
            u += 1
        if i % 5 == 0:
            print(c)

0 : 1.1742045


AssertionError: 