In [None]:
import glob
import os
from os import path

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

In [None]:
# Root directory that is globbed for record files, one sub-folder per dataset.
# NOTE(review): absolute, machine-specific path - adjust when running elsewhere.
base_path = '/home/rshuai/research/u-net-reconstruction/data/datasets/gt-records'

# Sub-folder names under base_path; each one is loaded as a separate dataset.
dataset_names = [
    'broad-institute',
    'cil',
    'denoising-fluorescence',
    'hpa',
    'nucleus-seg',
]

In [None]:
# Parsing configuration shared by the helper functions below.

# Fixed shape declared for the 'plane' feature of every serialized example.
# NOTE(review): axis order (rows/cols vs. width/height) is not visible here - confirm.
obj_dims = (648, 486)

# Feature spec passed to tf.io.parse_single_example: a single float32 tensor
# named 'plane' with shape obj_dims.
feature_description = {
    'plane': tf.io.FixedLenFeature(shape=obj_dims, dtype=tf.float32),
}

def _parse_function(example_proto):
    '''Decodes one serialized tf.Example using the module-level feature spec.

    Args:
        example_proto: a single serialized example, as yielded by the raw
            record dataset this function is mapped over.

    Returns:
        The result of tf.io.parse_single_example, i.e. a dict keyed by the
        names in feature_description (here only 'plane').
    '''
    parsed = tf.io.parse_single_example(example_proto, feature_description)
    return parsed

def np_to_pil(img_np):
    '''Builds a PIL image from a NumPy array.

    The values are multiplied by 255, clipped to [0, 255] and cast to uint8
    (the cast truncates the fractional part) before being handed to
    Image.fromarray.

    Args:
        img_np: image data as a NumPy array.
            NOTE(review): presumably floats in [0, 1] - confirm against callers.

    Returns:
        The PIL image produced by Image.fromarray from the uint8 array.
    '''
    scaled = img_np * 255
    clipped = np.clip(scaled, 0, 255)
    as_uint8 = clipped.astype(np.uint8)
    return Image.fromarray(as_uint8)

In [None]:
# Build one parsed tf.data pipeline per dataset and keep them in a dict
# keyed by dataset name.
datasets = {}
for dataset_name in dataset_names:
    # glob.glob returns matches in arbitrary order; sort them so the records
    # (and therefore the per-sample index used later in output file names)
    # come out in the same order on every run and every machine.
    filenames = sorted(glob.glob(path.join(base_path, dataset_name, '*')))
    raw_dataset = tf.data.TFRecordDataset(filenames=filenames)
    # Decode each serialized example into a dict containing the 'plane' tensor.
    datasets[dataset_name] = raw_dataset.map(_parse_function)

In [None]:
# Export every sample of every dataset as a PNG so it can be inspected visually.
output_dir = 'temp'
# Image.save does not create missing directories, so make sure the target
# folder exists; otherwise the first save fails on a fresh checkout.
os.makedirs(output_dir, exist_ok=True)

for dataset_name, dataset in datasets.items():
    for i, sample in enumerate(dataset):
        # 'plane' is the only feature produced by the parsing step.
        img_np = sample['plane'].numpy()
        img_pil = np_to_pil(img_np)
        # File name: <dataset>_<zero-padded sample index>.png
        img_pil.save(path.join(output_dir, '{}_{:04d}.png'.format(dataset_name, i)))