In [1]:
import numpy as np
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt

from os import path
import glob

In [2]:
# Root directory containing one sub-directory of record files per dataset
base_path = '/home/rshuai/research/u-net-reconstruction/data/datasets/gt-records'

# Names of the dataset sub-directories to read
dataset_names = [
    'broad-institute',
    'cil',
    'denoising-fluorescence',
    'hpa',
    'nucleus-seg',
]

In [3]:
# Shared parsing constants and helpers
obj_dims = (648, 486)  # fixed shape declared for every stored 'plane' feature

# Schema passed to tf.io.parse_single_example: a single float32 feature of shape obj_dims
feature_description = {
    'plane': tf.io.FixedLenFeature(shape=obj_dims, dtype=tf.float32),
}

def _parse_function(example_proto):
    """Decode one serialized tf.Example using the module-level ``feature_description`` schema.

    Returns whatever ``tf.io.parse_single_example`` produces for that schema
    (callers index the result with the ``'plane'`` key).
    """
    parsed = tf.io.parse_single_example(example_proto, feature_description)
    return parsed

def np_to_pil(img_np):
    '''Builds a PIL image from an image stored as a np.array.

    Values are multiplied by 255, clipped to [0, 255] and cast to uint8
    before being handed to ``Image.fromarray``.
    '''
    scaled = img_np * 255
    as_uint8 = np.clip(scaled, 0, 255).astype(np.uint8)
    return Image.fromarray(as_uint8)

In [4]:
# Read in datasets and store them in a dict keyed by dataset name
datasets = dict()
for dataset_name in dataset_names:
    # glob returns matches in arbitrary order; sort so the record order
    # (and any index derived from it downstream) is reproducible across runs.
    filenames = sorted(glob.glob(path.join(base_path, dataset_name, '*')))
    raw_dataset = tf.data.TFRecordDataset(filenames=filenames)
    # Lazily decode each serialized example with the shared schema.
    dataset = raw_dataset.map(_parse_function)
    datasets[dataset_name] = dataset

In [5]:
# Convert every record of every dataset to a PIL image.
# The PNG export below is commented out, so this loop currently writes and displays nothing.
# NOTE(review): the re-save cell further down reads files from temp/<dataset_name>/,
# which looks like the output of this disabled export — confirm before a fresh run.
for dataset_name in datasets:
    records = datasets[dataset_name]
    for i, sample in enumerate(records):
        img_pil = np_to_pil(sample['plane'].numpy())
#         img_pil.save('temp/{}/{}_{:04d}.png'.format(dataset_name, dataset_name, i))

KeyboardInterrupt: 

### Read and resave dataset

In [5]:
def normalize(im):
    """
    Normalizes im from 0 to 1.

    The minimum of ``im`` maps to 0 and the maximum to 1. A constant input
    (max == min) has no range to stretch, so it maps to all zeros instead of
    dividing by zero and producing NaN/inf.

    Args:
        im: array-like of numbers; must be non-empty (np.max/np.min raise
            ValueError on empty input).

    Returns:
        np.ndarray with the same shape as ``im`` and the dtype produced by
        true division of the input dtype.
    """
    im = np.asarray(im)
    im_max = np.max(im)
    im_min = np.min(im)
    span = im_max - im_min
    if span == 0:
        # Divide by a unit of the same dtype so the result dtype matches the
        # regular path while every value becomes 0.
        span = span.dtype.type(1)
    return (im - im_min) / span

def _create_example(plane):
    """
    Wraps a numpy array in a tf.Example holding a single float-list feature named 'plane'.

    The array is flattened with ``ravel()`` before being stored.
    """
    flat_values = tf.train.FloatList(value=plane.ravel())
    features = tf.train.Features(feature={
        'plane': tf.train.Feature(float_list=flat_values)
    })
    return tf.train.Example(features=features)

In [6]:
# Re-encode every image found under temp/<dataset_name>/ as its own single-example TFRecord file.
for dataset_name in dataset_names:
    # glob order is arbitrary and the index below becomes part of the output
    # file name, so sort to keep the index -> image mapping reproducible.
    filenames = sorted(glob.glob(path.join('temp/{}'.format(dataset_name), '*')))

    # TFRecordWriter does not create missing parent directories; make sure the
    # target directory exists (no-op when it already does).
    tf.io.gfile.makedirs('temp/temp_records/{}'.format(dataset_name))

    for i, filename in enumerate(filenames):
        # Use a context manager so the image file handle is released as soon
        # as the pixel data has been copied into a numpy array.
        with Image.open(filename) as image:
            sample = np.asarray(image)
        sample = normalize(sample)

        # Optional blur step, currently disabled:
#         fwhm_scaled = fwhm_pixels
#         sigmas = fwhm_scaled / np.sqrt(8 * np.log(2))
#         sample = gaussian_filter(sample, sigmas)
#         sample = normalize(sample)

#         sample.astype(np.float16)

        # Same path as before: temp/temp_records/<name>/<name>-<5-digit index>.tfrecord
        record_file = 'temp/temp_records/{0}/{0}-{1:05d}.tfrecord'.format(dataset_name, i)
        with tf.io.TFRecordWriter(record_file) as writer:
            tf_example = _create_example(sample)
            writer.write(tf_example.SerializeToString())