In [1]:
import tensorflow as tf
import os
import matplotlib.image as mpimg

In [2]:
import random

In [3]:
# Record which TensorFlow release produced the outputs below.
print(tf.version.VERSION)

2.0.0


In [4]:
class GenTFRecord:
    """Serialize a folder of class-labelled images into a single TFRecord file.

    Images are expected under ``<img_folder>/<class_name>/``; each image's label is
    looked up in ``labels`` using the name of the directory that directly contains it.
    """

    def __init__(self, labels):
        """
        Args:
            labels: dict mapping a class directory name (e.g. 'NORMAL') to the
                integer written into the ``label`` feature of each example.
        """
        self.labels = labels

    def conv_img_folder(self, img_folder, tfrecord_file_name, labels_file_name):
        """Write one ``tf.train.Example`` per image found in the class sub-folders.

        Args:
            img_folder: root folder containing one sub-folder per class.
            tfrecord_file_name: path of the TFRecord file to create (overwritten).
            labels_file_name: text file listing one class sub-folder name per line;
                blank lines are ignored.
        """
        print('Determining list of input files and labels from %s.' % img_folder)
        with tf.io.gfile.GFile(labels_file_name, 'r') as label_file:
            unique_labels = self._parse_label_lines(label_file.readlines())

        filenames = []
        # Collect the image files, one class sub-folder at a time.
        for class_count, text in enumerate(unique_labels, start=1):
            jpeg_file_path = '%s/%s/*' % (img_folder, text)
            # Skip nested directories: they cannot be read as images.
            matching_files = [path for path in tf.io.gfile.glob(jpeg_file_path)
                              if not tf.io.gfile.isdir(path)]
            filenames.extend(matching_files)
            if not class_count % 100:
                print('Finished finding files in %d of %d classes.'
                      % (class_count, len(unique_labels)))

        img_paths = [os.path.abspath(path) for path in filenames]
        with tf.io.TFRecordWriter(tfrecord_file_name) as writer:
            for img_path in img_paths:
                example = self._convert_image(img_path)
                writer.write(example.SerializeToString())

    @staticmethod
    def _parse_label_lines(lines):
        """Strip each line and drop blank ones.

        A blank class name would turn the glob into ``<img_folder>//*`` and pick up
        the class directories (and the output file) as if they were images.
        """
        return [line.strip() for line in lines if line.strip()]

    @staticmethod
    def _file_stem(path):
        """Return the basename of ``path`` without its final extension."""
        return os.path.splitext(os.path.basename(path))[0]

    def _is_png_image(self, filename):
        """Return True if ``filename`` has a ``.png`` extension (case-insensitive)."""
        ext = os.path.splitext(filename)[1].lower()
        return ext == '.png'

    def _convert_png_to_jpeg(self, img):
        """Re-encode raw PNG bytes as JPEG bytes (RGB, quality 100)."""
        png_enc = tf.image.decode_png(img, channels=3)
        # encode_jpeg yields a scalar string tensor; BytesList needs plain bytes.
        return tf.image.encode_jpeg(png_enc, format='rgb', quality=100).numpy()

    def _convert_image(self, img_path):
        """Build a ``tf.train.Example`` for one image file.

        Features written: ``filename`` (basename without extension), ``rows`` and
        ``cols`` (from the image as read by matplotlib), ``channels`` (always 3),
        ``image`` (file bytes, PNG re-encoded to JPEG) and ``label``.
        """
        label = self._get_label_with_filename(os.path.dirname(os.path.abspath(img_path)))
        img_shape = mpimg.imread(img_path).shape
        filename = self._file_stem(img_path)
        print(filename, label)

        # Read image data in terms of bytes.
        with tf.io.gfile.GFile(img_path, 'rb') as fid:
            image_data = fid.read()

        # Encode PNG data to JPEG data so all records share one encoding.
        if self._is_png_image(img_path):
            image_data = self._convert_png_to_jpeg(image_data)

        example = tf.train.Example(features=tf.train.Features(feature={
            'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[filename.encode('utf-8')])),
            'rows': tf.train.Feature(int64_list=tf.train.Int64List(value=[img_shape[0]])),
            'cols': tf.train.Feature(int64_list=tf.train.Int64List(value=[img_shape[1]])),
            'channels': tf.train.Feature(int64_list=tf.train.Int64List(value=[3])),
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        }))
        return example

    def _get_label_with_filename(self, filename):
        """Map a directory path to its label via the directory's last path component.

        Raises:
            KeyError: if the directory name is not a key of ``self.labels``.
        """
        basename = os.path.basename(os.path.normpath(filename))
        return self.labels[basename]
    
if __name__ == '__main__':
    # Dataset root; set the CHEST_XRAY_DIR environment variable to point elsewhere
    # instead of editing the hard-coded default.
    chest_xray_dir = os.environ.get('CHEST_XRAY_DIR', 'C:\\Users\\Prasanta\\Pictures\\images\\chest_xray')
    val_dir = os.path.join(chest_xray_dir, 'val')
    labels = {'NORMAL': 0, 'PNEUMONIA': 1}
    t = GenTFRecord(labels)
    t.conv_img_folder(val_dir,
                      os.path.join(val_dir, 'images.tfrecord'),
                      os.path.join(chest_xray_dir, 'labels'))

Determining list of input files and labels from C:\Users\Prasanta\Pictures\images\chest_xray\val.
NORMAL2-IM-1427-0001 0
Instructions for updating:
Use tf.gfile.GFile.
NORMAL2-IM-1430-0001 0
NORMAL2-IM-1431-0001 0
NORMAL2-IM-1436-0001 0
NORMAL2-IM-1437-0001 0
NORMAL2-IM-1438-0001 0
NORMAL2-IM-1440-0001 0
NORMAL2-IM-1442-0001 0
person1946_bacteria_4874 1
person1946_bacteria_4875 1
person1947_bacteria_4876 1
person1949_bacteria_4880 1
person1950_bacteria_4881 1
person1951_bacteria_4882 1
person1952_bacteria_4883 1
person1954_bacteria_4886 1


In [5]:
# Folder and name of the TFRecord file written by GenTFRecord above.
tfrecord_location, name = ('C:\\Users\\Prasanta\\Pictures\\images\\chest_xray\\val',
                           "images.tfrecord")
filename = os.path.join(tfrecord_location, name)
print(filename)

C:\Users\Prasanta\Pictures\images\chest_xray\val\images.tfrecord


In [6]:
# Stream the serialized examples back from the TFRecord file.
dataset = tf.data.TFRecordDataset(filenames=filename)
type(dataset)

tensorflow.python.data.ops.readers.TFRecordDatasetV2

In [7]:
# Count the records by iterating over the dataset once.
num_examples = sum(1 for _ in dataset)

print('Total Number of Images: {}'.format(num_examples))

Total Number of Images: 16


In [8]:
def decode(tfrecord, image_size=224):
  """Parse one serialized example written by ``GenTFRecord`` into ``(image, label)``.

  Used as the map function for ``dataset.map``.

  Args:
    tfrecord: scalar string tensor holding one serialized ``tf.train.Example``.
    image_size: side length (pixels) the decoded image is resized to.

  Returns:
    image: tensor of shape (image_size, image_size, 3), resized and divided by 255.
    label: int32 scalar tensor.
  """
  # All keys the writer produces; FixedLenFeature without a default makes a
  # missing key a parse error, so malformed records fail loudly.
  features = {
      'filename': tf.io.FixedLenFeature([], tf.string),
      'rows': tf.io.FixedLenFeature([], tf.int64),
      'cols': tf.io.FixedLenFeature([], tf.int64),
      'channels': tf.io.FixedLenFeature([], tf.int64),
      'image': tf.io.FixedLenFeature([], tf.string),
      'label': tf.io.FixedLenFeature([], tf.int64),
  }

  # Extract the data record.
  sample = tf.io.parse_single_example(tfrecord, features)

  # The writer re-encodes PNGs to JPEG; other inputs are assumed to already be JPEG.
  image = tf.image.decode_jpeg(sample['image'], channels=3)
  image = tf.image.resize(image, (image_size, image_size)) / 255.0
  label = tf.cast(sample['label'], tf.int32)

  return image, label

In [9]:
# Rebuild from the TFRecord file rather than re-mapping `dataset` in place, so
# re-running this cell does not apply `decode` to already-decoded elements.
dataset = tf.data.TFRecordDataset(filename).map(decode)

In [11]:
# Show the label of every decoded example; the image tensor is not needed here.
for _image, label in dataset:
    print(label)

tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)


In [12]:
# `map` wraps the source dataset in a new dataset object; display its class.
dataset.__class__

tensorflow.python.data.ops.dataset_ops.MapDataset