In [1]:
import tensorflow as tf
from pandas_plink import read_plink
import pandas as pd
import numpy as np

# Shared writer options: records are ZLIB-compressed, so any reader of the
# resulting file must be opened with the matching compression type.
tfrecords_opts = tf.python_io.TFRecordOptions(
    compression_type=tf.python_io.TFRecordCompressionType.ZLIB)

  from ._conv import register_converters as _register_converters


Read the PLINK data into dask/pandas data structures.

In [None]:
plink_file = 'data/large_test'
bim, fam, G = read_plink(plink_file)

# Materialise the transposed genotype matrix as floats first, so that missing
# calls (presumably encoded as NaN by read_plink -- confirm against the
# pandas-plink docs) can still be detected. Casting straight to int8 would
# turn NaN into an unspecified integer, and a NumPy array has no .fillna().
G = np.asarray(G.T, dtype=np.float32)

# Replace missing genotypes with 0, then store compactly as int8.
G = np.where(np.isnan(G), 0, G).astype(np.int8)

# After the transpose, each row is what write_record treats as one sample's
# genotype vector: N rows, M entries per row.
N, M = G.shape

Mapping files: 100%|██████████| 3/3 [00:01<00:00,  1.85it/s]


Write a .tfrecords file for the genotype matrix.

In [84]:
def write_record(row, writer_handle):
    '''
    Serialize one genotype vector as a tf.train.Example and write it out.

    row: a sample's genotype vector (NumPy array). Its raw bytes are stored
        under the 'genotypes' feature, so the reader must know the dtype and
        length to decode it.
    writer_handle: object exposing write(bytes); the serialized Example is
        passed to it.
    '''
    # wrap raw byte values
    # (tobytes() returns the same bytes as the deprecated tostring() alias,
    # which was removed in NumPy 2.0)
    genotypes_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[row.tobytes()]))

    # convert to Example
    example = tf.train.Example(
        features=tf.train.Features(
            feature={'genotypes': genotypes_feature}))

    writer_handle.write(example.SerializeToString())

# Write one compressed record per row of G. A plain loop is used instead of
# np.apply_along_axis: write_record is called only for its side effect, and
# apply_along_axis would build a throwaway array of None results and raise
# when G has no rows.
with tf.python_io.TFRecordWriter('data/test.tfrecords', options=tfrecords_opts) as tfwriter:
    for sample_row in G:
        write_record(sample_row, tfwriter)


Write a decoder for the .tfrecords file.

In [85]:
def decode_tfrecords(tfrecords_filename, m_variants):
    '''
    Decode one serialized Example into an int8 genotype tensor.

    tfrecords_filename: despite the name, this value is handed to
        tf.parse_example as a serialized Example (a scalar tf.string),
        not as a path to a *.tfrecords file.
    m_variants: number of genotype values expected in the record; the
        decoded bytes are reshaped to [1, m_variants].

    Returns a tf.int8 tensor of shape [1, m_variants].

    Helpful blog post:
    http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
    '''
    # The record carries a single bytes feature holding the raw genotype vector.
    feature_spec = {'genotypes': tf.FixedLenFeature([], tf.string)}
    parsed = tf.parse_example([tfrecords_filename], feature_spec)

    # Reinterpret the raw bytes as int8 values, then pin the static shape.
    raw_genotypes = tf.decode_raw(parsed['genotypes'], tf.int8)
    return tf.reshape(raw_genotypes, [1, m_variants])

Check that the decoded results match the input.

In [86]:
graph = tf.Graph()
with graph.as_default():
    # Read the ZLIB-compressed records and decode each one with
    # decode_tfrecords, using M as the expected vector length.
    dataset = tf.data.TFRecordDataset(
        'data/test.tfrecords',
        compression_type=tf.constant('ZLIB'),
    ).map(lambda serialized: decode_tfrecords(serialized, M))

    # x yields the next decoded record each time it is evaluated.
    iterator = dataset.make_one_shot_iterator()
    x = iterator.get_next()

In [87]:
# Pull the first decoded record out of the pipeline.
with tf.Session(graph=graph) as sess:
    decoded_rec = sess.run(x)
    print(decoded_rec)

# Actually verify the round trip instead of eyeballing the printout: the
# first record written was G[0], and the decoder returns shape [1, M].
np.testing.assert_array_equal(decoded_rec[0], G[0])

[[2 2 1 1 2 2 2 2 2 1 2 0 2 2 2 1 2 2 1 1 1 0 1 2 2 2 1 1 1 2 2 1 1 0 1 2
  2 1 1 2 2 2 2 2 1 2 2 2 2 2 1 2 2 1 1 2 2 2 2 1 1 1 1 1 2 1 1 2 2 2 2 2
  2 2 2 1 1 1 1 1 2 1 2 2 1 2 1 2 1 2 1 2 2 1 1 2 2 1 1 2 1 2 2 2 1 1 1 1
  2 1 1 1 1 2 1 2 2 2 2 2 2 2 0 2 1 2 2 1 2 1 2 0 1 2 2 1 2 2 1 2 2 2 2 2
  1 1 2 1 2 1 1 1 2 2 1 1 1 0 2 2 2 0 1 1 2 2 2 2 2 2 2 1 2 2 1 1 1 2 2 2
  2 2 2 1 1 2 2 2 1 2 1 2 1 1 2 1 1 1 1 1 1 1 1 2 1 1 1 1 2 2 0 0 2 2 2 2
  2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 1 2 2 2 2 1 2 1 1 2 1 1
  1 1 2 1 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
  2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 0 1 2 1 1 2 2 1 2 2 2 1 1 1 2 2 1 1 1 2
  2 1 2 0 2 1 2 1 2 1 1 2 2 1 2 1 1 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 1 2 2
  2 1 2 1 2 2 2 2 2 1 1 0 1 0 1 1 1 1 2 2 2 1 0 2 2 1 1 1 0 2 0 1 2 0 1 2
  2 1 2 1 1 2 2 1 2 1 1 2 1 1 1 2 2 1 1 1 2 2 2 1 1 2 2 2 1 1 1 1 1 2 1 2
  1 1 2 1 2 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1 1 1 2 1 1 2 1 1 1 2 2
  2 2 2 1 1 2 2 2 1 2 0 2 2 2 2 2 1 2 