In [1]:
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import os
import io
from PIL import Image

def _bytes_feature(value):
    """Wrap one bytes value in a ``tf.train.Feature`` backed by a ``BytesList``.

    Args:
        value: The bytes payload, or an eager ``tf.Tensor`` holding it. A tensor
            is converted with ``.numpy()`` first, because ``BytesList`` cannot
            unpack a string from an EagerTensor.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` contains exactly ``[value]``.
    """
    eager_tensor_type = type(tf.constant(0))
    payload = value.numpy() if isinstance(value, eager_tensor_type) else value
    bytes_list = tf.train.BytesList(value=[payload])
    return tf.train.Feature(bytes_list=bytes_list)

def _int64_feature(value):
    """Wrap one integer-like value (bool / enum / int / uint) in an ``Int64List`` Feature.

    Args:
        value: A single integer-like scalar.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` contains exactly ``[value]``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def image_example(image_string):
    """Build a ``tf.train.Example`` describing one JPEG-encoded image.

    The image is decoded once so that its shape can be stored next to the
    encoded bytes.

    Args:
        image_string: JPEG-encoded image bytes.

    Returns:
        A ``tf.train.Example`` with int64 features ``height``, ``width`` and
        ``depth`` (read from the decoded image's 3-D shape) plus the unchanged
        encoded bytes under ``image_raw``.
    """
    decoded_shape = tf.io.decode_jpeg(image_string).shape
    feature_map = {
        'height': _int64_feature(decoded_shape[0]),
        'width': _int64_feature(decoded_shape[1]),
        'depth': _int64_feature(decoded_shape[2]),
        'image_raw': _bytes_feature(image_string),
    }
    features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=features)

def _parse_image_function(example_proto, image_feature_description):
    """Parse one serialized ``tf.train.Example`` using the given feature spec.

    Args:
        example_proto: A single serialized Example (string tensor), as yielded
            by ``tf.data.TFRecordDataset``.
        image_feature_description: Mapping of feature name to parsing config
            (e.g. ``tf.io.FixedLenFeature``), passed straight through.

    Returns:
        A dict mapping each feature name to its parsed tensor.
    """
    parsed = tf.io.parse_single_example(
        serialized=example_proto,
        features=image_feature_description,
    )
    return parsed
        
def load_resize_compress_image(image_path, target_size=(512, 512), quality=95):
    """Load an image, resize it to a fixed size and re-encode it as JPEG.

    The aspect ratio is not preserved: the image is resized to exactly
    ``target_size``.

    Args:
        image_path: Path of the image file to read.
        target_size: Output size in pixels, in PIL's ``(width, height)`` order.
        quality: JPEG quality passed to PIL's JPEG encoder.

    Returns:
        The JPEG-encoded bytes of the resized image.
    """
    # Modes Pillow's JPEG encoder can write; anything else (RGBA, P, LA, ...)
    # would make save() raise, e.g. for a PNG that was given a .jpg name.
    jpeg_writable_modes = ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr')

    # The context manager releases the underlying file handle once the pixels
    # have been resampled; resize() loads the data, so the result does not
    # depend on the file staying open.
    with Image.open(image_path) as image:
        if image.mode not in jpeg_writable_modes:
            # JPEG has no alpha/palette support; flatten to RGB (alpha is dropped).
            image = image.convert('RGB')
        resized_image = image.resize(target_size, resample=Image.LANCZOS)

    output_image = io.BytesIO()
    resized_image.save(output_image, format='JPEG', quality=quality)
    return output_image.getvalue()

def convert_to_tfrecord(images_dir, output_path, display_count=False,
                        extensions=('.jpg', '.jpeg')):
    """Write every JPEG file in ``images_dir`` into a single TFRecord file.

    Each matching file is resized and re-encoded by
    ``load_resize_compress_image`` and stored as one ``tf.train.Example``
    built by ``image_example``.

    Args:
        images_dir: Directory to scan (non-recursive).
        output_path: Path of the TFRecord file to write.
        display_count: If truthy, print a running count after each image is written.
        extensions: File-name suffixes to include. Matching is case-insensitive,
            so ``photo.JPG`` and ``photo.jpeg`` are picked up too.

    Returns:
        The number of images written to the TFRecord file.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    # os.listdir returns entries in arbitrary order; sort so the record order
    # is reproducible. Skip directories that happen to match the suffix.
    filenames = sorted(
        name for name in os.listdir(images_dir)
        if name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(images_dir, name))
    )

    counter = 0
    # The context manager closes the writer even if an image fails to load.
    with tf.io.TFRecordWriter(output_path) as writer:
        for filename in filenames:
            image_path = os.path.join(images_dir, filename)
            compressed_image = load_resize_compress_image(image_path)
            tf_example = image_example(compressed_image)
            writer.write(tf_example.SerializeToString())
            counter += 1
            if display_count:
                print(counter)
    return counter

def plot_img_from_TFRecord(PATH, file_name, max_images=None):
    """Display the images stored in a TFRecord file, one figure per record.

    Records are expected to carry a JPEG-encoded ``image_raw`` string plus
    int64 ``height``/``width``/``depth`` features, as written by
    ``image_example``.

    Args:
        PATH: Directory containing the TFRecord file ('' for the current directory).
        file_name: Name of the TFRecord file inside ``PATH``.
        max_images: If given, show at most this many images; ``None`` shows all.
    """
    # os.path.join works whether or not PATH ends with a separator.
    record_path = os.path.join(PATH, file_name)
    image_feature_description = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'depth': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string),
    }
    parsed_image_dataset = tf.data.TFRecordDataset(record_path).map(
        lambda x: _parse_image_function(x, image_feature_description))
    if max_images is not None:
        parsed_image_dataset = parsed_image_dataset.take(max_images)

    for image_features in parsed_image_dataset:
        image = tf.image.decode_jpeg(image_features['image_raw']).numpy()
        fig, ax = plt.subplots()
        if image.shape[-1] == 1:
            # Grayscale JPEGs decode to (H, W, 1), which imshow rejects; drop the
            # channel axis and pin the range so intensities are not auto-stretched.
            ax.imshow(image[..., 0], cmap='gray', vmin=0, vmax=255)
        else:
            ax.imshow(image)
        ax.axis('off')
        plt.show()
        # Release the figure so plotting many records does not accumulate memory.
        plt.close(fig)

In [2]:
images_dir = ""  # Directory containing the source JPEG files (fill in before running)
output_path = ""  # Directory the TFRecord file is written to ('' = current directory)
file_name = 'test_images.tfrecord'
# os.path.join avoids silently producing e.g. "outtest_images.tfrecord" when
# output_path has no trailing separator.
convert_to_tfrecord(images_dir, os.path.join(output_path, file_name), display_count=True)

1
2
3
4
5
6
7
8
