### Imports

In [1]:
% matplotlib inline

import sys
import os
os.chdir('../..')
sys.path.append('../..')

from io import BytesIO
from pathlib import Path

import menpo.io as mio
import tensorflow as tf
from menpo.visualize import print_progress
from project.utils import tfrecords

# Generate TFRecords files

The recommended format for TensorFlow is a TFRecords file containing `tf.train.Example` protocol buffers (which contain `Features` as a field).

The cells below load the images, wrap each one in a `tf.train.Example` protocol buffer, serialize that protocol buffer to a string, and write the string to a TFRecords file with `tf.python_io.TFRecordWriter` (named `tf.io.TFRecordWriter` in TensorFlow 2).

In [2]:
""" Functions for writing TFRecord features """


def _int_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def get_jpg_string(image):
    """Serialize a menpo `Image` to JPEG and return the encoded bytes.

    The image is exported into an in-memory buffer via ``mio.export_image``,
    so nothing is written to disk.
    """
    with BytesIO() as buffer:
        mio.export_image(image, buffer, extension='jpg')
        # getvalue() returns the whole buffer regardless of the current position.
        return buffer.getvalue()

In [3]:
def face_iterator(images):
    """Yield ``(file_name, image)`` pairs for an iterable of imported images.

    Progress is reported with menpo's ``print_progress`` while the images are
    consumed.

    Args:
        images: iterable of images, each exposing ``path.name`` (presumably set
            by ``mio.import_images`` — confirm it survives any cropping/rescaling
            done before this is called).

    Yields:
        tuple: ``(img.path.name, img)`` for every image, in input order.
    """
    for img in print_progress(images, end_with_newline=False):
        yield img.path.name, img


def generate(iterator,
             store_path='./',
             record_name='inference.tfrecords',
             store_records=True,
             image_size=(256, 256)):
    """Resize images and serialize them as `tf.train.Example` records.

    Each record holds three features: ``'image'`` (the resized image encoded as
    JPEG bytes by `get_jpg_string`), ``'height'`` and ``'width'`` (taken from
    ``shape[0]`` and ``shape[1]`` of the resized image).

    Args:
        iterator: iterable of ``(name, image)`` pairs, e.g. from `face_iterator`.
            ``name`` is only used in error messages.
        store_path: directory in which the TFRecords file is created.
        record_name: file name of the TFRecords file inside ``store_path``.
        store_records: when False, images are still iterated and resized but no
            file is opened or written.
        image_size: target shape passed to ``image.resize``; defaults to
            256 x 256.

    Images whose serialization or write fails are reported on stdout and
    skipped; any other exception propagates after the writer is closed.
    """
    store_path = Path(store_path)
    writer = None
    if store_records:
        writer = tf.python_io.TFRecordWriter(str(store_path / record_name))

    try:
        for img_name, pimg in iterator:
            cimg = pimg.resize(list(image_size))
            if writer is None:
                continue
            try:
                example = tf.train.Example(
                    features=tf.train.Features(
                        # map of feature name -> Feature proto
                        feature={
                            'image': tfrecords.bytes_feature(get_jpg_string(cimg)),
                            'height': tfrecords.int_feature(cimg.shape[0]),
                            'width': tfrecords.int_feature(cimg.shape[1]),
                        }))
                writer.write(example.SerializeToString())
            except Exception as e:
                # Best effort: one bad image should not abort the whole export.
                print('Something bad happened when processing image: "{}"'.format(img_name))
                print(e)
    finally:
        # Always close the writer so the file is finalized, even when the
        # iterator or a resize raises.
        if writer is not None:
            writer.close()

## Generate TFRecords for Inference

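Run the following cells to generate a TFRecords file from a directory of images on which to perform inference. Each image is cropped around the region found by the left-ventricle detector before being written.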

In [4]:
# Input: directory containing the images to run inference on.
images_folder = Path('data') / 'images'

# Output: the TFRecords file is written into the same directory as the input
# images, under the name below.
store_path = Path('data') / 'images'
inference_record_name = 'inference.tfrecords'

In [5]:
# Run this to generate the TFRecord file!
from menpo.landmark import labeller, left_ventricle_34,left_ventricle_34_trimesh,left_ventricle_34_trimesh1
from menpodetect import load_dlib_left_ventricle_detector
# Load the trained dlib left-ventricle detector that `load_database` uses to
# locate the region to crop. The path is relative to the current working
# directory (changed by the `os.chdir` call in the imports cell).
DETECTOR_PATH = "detector.svm"
detector = load_dlib_left_ventricle_detector(DETECTOR_PATH)
def load_database(path_to_images, crop_percentage, max_diagonal=400, max_images=None,
                  boundary=40):
    """Import images and crop each one around its detected bounding box.

    For every image returned by ``mio.import_images(path_to_images)``:

    1. 3-channel images are converted to greyscale (luminosity).
    2. Images whose diagonal exceeds ``max_diagonal`` pixels are rescaled by
       ``max_diagonal / diagonal``.
    3. The module-level ``detector`` is run and the image is cropped to the
       first returned bounding box, padded by ``boundary``.

    Images for which the detector returns no bounding box are reported and
    skipped instead of aborting the whole import.

    Args:
        path_to_images: path/pattern passed to ``mio.import_images``.
        crop_percentage: unused; kept for backward compatibility with callers.
        max_diagonal: largest allowed image diagonal, in pixels, before cropping.
        max_images: optional cap on the number of images imported.
        boundary: padding passed to ``crop_to_pointcloud``; defaults to 40.

    Returns:
        list of cropped images, in import order.
    """
    images = []
    for image in mio.import_images(path_to_images, max_images=max_images, verbose=True):
        # Convert colour images before detection (presumably what the detector
        # was trained on — confirm).
        if image.n_channels == 3:
            image = image.as_greyscale(mode='luminosity')

        # Bound the image size before running the detector.
        diagonal = image.diagonal()
        if diagonal > max_diagonal:
            image = image.rescale(float(max_diagonal) / diagonal)

        bboxes = detector(image)
        if len(bboxes) == 0:
            print('No bounding box detected, skipping image: "{}"'.format(
                getattr(image, 'path', '<unknown>')))
            continue
        # Only the first detection is used; any additional boxes are ignored.
        images.append(image.crop_to_pointcloud(bboxes[0], boundary=boundary))
    return images
crop_percentage = 0.5
path_to_lfpw = images_folder

# Detect, crop and collect every image in the input folder.
images = load_database(path_to_images=path_to_lfpw, crop_percentage=crop_percentage)
n_found = len(images)
print('Found {} assets'.format(n_found))

# Write the cropped images into a single TFRecords file.
generate(face_iterator(images),
         store_path=store_path,
         record_name=inference_record_name,
         store_records=True)

Found 168 assets, index the returned LazyList to import.
Found 168 assets