# Writing TFRecords with DICOM Metadata
 
This Notebook creates Tensorflow Records for the Pulmonary Embolism competition.

It is heavily derived from [How To Create TFRecords](https://www.kaggle.com/cdeotte/how-to-create-tfrecords), courtesy of Chris Deotte.

Please let me know if there are sections of the Notebook where additional credit should be given.


TFRecords can be used by CPU, GPU and TPU. To use them with a TPU I usually make them public. There are ways to use them when they are private, but there is some configuration involved that I haven't tackled.

There are two ways to get these TFRecords into a dataset:

1. If you run this notebook in interactive mode, download the output and then upload to a dataset.

2. If you run in Commit mode, you can send the output files directly to a dataset. 


This notebook stores Images, targets and metadata from the DICOM files in the TFRecords. They can be used in models that combine Image and Metadata into one model.

First, we import our typical libraries (probably more than I need for this notebook).

Paths point to the train or test dataset and a location to store the TFRec files. They might take some adjustment if you are running outside of Notebooks. There is a flag that should allow you to run them for both train and test data.

Due to the size of the data, you can specify a start record and a size of the extract. Note that this is based on a sorted order (StudyInstanceUID and SOPInstanceUID), so as not to rely on the order of the train/test file or the directories.

To break things up, we specify a start position in the train file list and a number of files to get.

10,000 images result in about 250 MB. So I estimate 45 GB of data.

Reasonable block sizes would be between 50,000 and 100,000 images. 100,000 might get close to 9 hours of processing time.

Processing in notebooks will hit disk space and time limitations.


In [None]:
import numpy as np, pandas as pd, os
import re, math, random, csv, cv2
import pydicom
import tensorflow as tf
from skimage.transform import resize

!conda install -y -c conda-forge gdcm

# Which split to process: 'train' or 'test'.
train_or_test = 'train'

# Input locations (competition dataset) and output folder for the TFRecords.
mypath = '../input/rsna-str-pulmonary-embolism-detection'
images_folder = '/'.join([mypath, train_or_test])
tfrec_folder = ''

csv_file = '{}/{}.csv'.format(mypath, train_or_test)

# Change this prefix _a_ -> _b_, etc. each time you run this routine.
prefix = '{}_tfrec_version1_a_-'.format(train_or_test)

# For processing, the data is broken into smaller chunks: df_start is the
# position in the sorted csv where this chunk begins.
df_start = 100000

# Number of csv rows in this chunk. Change to 50000 for typical processing;
# 1000 is just for testing.
df_size = 1000

# Per-chunk metadata csv, named after the split and the chunk start.
metadata_file = 'metadata_{}_{}.csv'.format(train_or_test, df_start)


# First we will extract the DICOM Metadata.

This routine reads a DICOM directory and gets the Metadata from all the images.

I use try/except blocks since there is no guarantee that a DICOM Tag exists in every file. Pydicom also lets you explicitly check for a tag before reading it.

I provide default values if the tag doesn't exist.

I write these out to a file, which can be Downloaded. I'll use it as we create TFRecords.

I break out Tags that contain arrays into separate elements. For some, I just take the first element.

In [None]:
def _first_value(value):
    """Return the first element of a value that has a length, else the value itself."""
    return value[0] if hasattr(value, "__len__") else value


def _optional_tag(ds, keyword, default):
    """Read attribute ``keyword`` from ``ds``; return ``default`` if it cannot be read."""
    try:
        return getattr(ds, keyword)
    except Exception:
        # Best-effort by design: a missing or unreadable optional tag falls
        # back to its default. Unlike a bare ``except`` this does not swallow
        # KeyboardInterrupt / SystemExit.
        return default


def _optional_component(ds, keyword, index, default=0):
    """Read element ``index`` of attribute ``keyword``; return ``default`` on failure."""
    try:
        return getattr(ds, keyword)[index]
    except Exception:
        # Covers a missing tag as well as a value that is too short.
        return default


def extract_DICOM_attributes(folder):
    """Write one CSV row of DICOM header values for every file in ``folder``.

    Rows are written to the module-level file object ``file1``, which the
    caller must have opened for writing. The column order matches the header
    line written by the calling cell.

    StudyInstanceUID, SeriesInstanceUID, SOPInstanceUID, SliceThickness, Rows,
    Columns, PixelSpacing, WindowCenter, WindowWidth, RescaleIntercept and
    RescaleSlope are read directly, so a file missing one of them raises.
    All other tags fall back to a default (``'FFS'`` for PatientPosition,
    ``0`` otherwise).

    Args:
        folder: path of a directory whose entries are all DICOM files.
    """
    # Sorted so the output file is reproducible; os.listdir order is arbitrary.
    for image in sorted(os.listdir(folder)):
        dicom_file_path = os.path.join(folder, image)
        # Only header tags are used here, so skip loading the pixel data.
        ds = pydicom.read_file(dicom_file_path, stop_before_pixels=True)

        fields = [
            ds.StudyInstanceUID.strip(),
            ds.SeriesInstanceUID.strip(),
            ds.SOPInstanceUID.strip(),
            ds.SliceThickness,
            ds.PixelSpacing[0],
            ds.PixelSpacing[1],
            # Window center/width can be multi-valued; keep the first entry.
            _first_value(ds.WindowCenter),
            _first_value(ds.WindowWidth),
            ds.RescaleIntercept,
            ds.RescaleSlope,
            ds.Rows,
            ds.Columns,
            _optional_tag(ds, 'PatientPosition', 'FFS'),
        ]
        fields.extend(_optional_component(ds, 'ImagePositionPatient', i)
                      for i in range(3))
        fields.extend(_optional_component(ds, 'ImageOrientationPatient', i)
                      for i in range(6))
        fields.extend([
            _optional_tag(ds, 'InstanceNumber', 0),
            _optional_tag(ds, 'KVP', 0),
            _optional_tag(ds, 'XRayTubeCurrent', 0),
            _optional_tag(ds, 'Exposure', 0),
        ])

        file1.write(','.join(str(field) for field in fields) + '\n')

We sort the csv file and images directory list so they are in the same order.

Then we look through the patients and add their DICOM attributes to a metadata file. I read through the metadata first so that I can do post-processing on the metadata within each patient before I write the TFRecords (not present in this notebook). In practice I captured all the metadata once and just reference a file of the full metadata. In this stand-alone notebook, I build it on the fly. Also, for the private test data, you will need to build it on the fly anyway.

It would be more efficient to combine those two steps and cache the DICOM data, calculate on the patient level and then write the TFRecords.

In [None]:
# Select this run's chunk of the csv. Sorting makes the chunk independent of
# the row order in the csv file and of the directory listing order.
print('only reading limited records', df_size)
df = pd.read_csv(csv_file)
df = df.sort_values(by=['StudyInstanceUID','SOPInstanceUID'])
df = df[df_start:(df_start+df_size)].reset_index(drop=True)

df_temp = df
df_patients = df_temp['StudyInstanceUID'].unique()
patient_count = len(df_patients)
print('only reading limited patients', patient_count)

count = 0
# extract_DICOM_attributes() writes to the module-level name ``file1``. The
# context manager guarantees the file is flushed and closed even if reading
# one of the DICOM files raises part-way through.
with open(metadata_file,'w') as file1:
    file1.write('StudyInstanceUID,SeriesInstanceUID,SOPInstanceUID,dcm_slice_thickness,dcm_pixel_spacing0,dcm_pixel_spacing1,dcm_window_center,dcm_window_width,dcm_rescale_intercept,dcm_rescale_slope,dcm_rows,dcm_columns,dcm_patient_position,dcm_image_position_patient0,dcm_image_position_patient1,dcm_image_position_patient2,dcm_image_orientation0,dcm_image_orientation1,dcm_image_orientation2,dcm_image_orientation3,dcm_image_orientation4,dcm_image_orientation5,dcm_instance_number,dcm_kvp,dcm_xray_tube_current,dcm_exposure\n')

    for mydir in df_patients:
        study_folder = images_folder+'/'+mydir
        # Only the first series directory listed for each study is read.
        SeriesUID = os.listdir(study_folder)
        extract_DICOM_attributes(study_folder+'/'+SeriesUID[0])
        count = count + 1
        if count % 100 == 0:
            print('patient count processed',count)

# Whole study folders are read, so the row count can exceed df_size.
md = pd.read_csv(metadata_file)
df_count = len(md)
print('image count (adjusted to get full exams)',df_count)
print('done')

Next, we include standard Tensorflow routines for different datatypes. I think this is straight from the Tensorflow examples.

In [None]:
def _bytes_feature(value):
    """Wrap a single string / bytes value in a ``tf.train.Feature`` bytes_list."""
    eager_tensor_type = type(tf.constant(0))
    # BytesList won't unpack a string from an EagerTensor, so extract it first.
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    bytes_list = tf.train.BytesList(value=[raw])
    return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
    """Wrap a single float / double value in a ``tf.train.Feature`` float_list."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Wrap a single bool / enum / int / uint value in a ``tf.train.Feature`` int64_list."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

Next, we have a custom routine that lists all the variables that we want to store in the Tensorflow records.

Note that the first variable is the image itself. The rest are from the train.csv file or the DICOM files.

In [None]:
def serialize_example(feature0, feature1, feature2, feature3, feature4, feature5, feature6,\
    feature7, feature8, feature9, feature10,\
    feature11, feature12, feature13, feature14, feature15, feature16, feature17, \
    feature18, feature19, feature20, feature21, feature22, feature23, feature24, \
    feature25, feature26, feature27, feature28, feature29, feature30, feature31):
    """Serialize one image plus its labels and DICOM metadata as a tf.train.Example.

    The positional arguments map, in order, onto the record keys listed in
    ``schema`` below: the encoded image, three UID byte strings, fourteen
    int64 label columns, and the ``dcm_*`` metadata values.

    Returns:
        The serialized ``tf.train.Example`` as a byte string.
    """
    # (record key, feature wrapper, value), in the order of the signature.
    schema = (
        ('image', _bytes_feature, feature0),
        ('StudyInstanceUID', _bytes_feature, feature1),
        ('SeriesInstanceUID', _bytes_feature, feature2),
        ('SOPInstanceUID', _bytes_feature, feature3),
        ('pe_present_on_image', _int64_feature, feature4),
        ('negative_exam_for_pe', _int64_feature, feature5),
        ('qa_motion', _int64_feature, feature6),
        ('qa_contrast', _int64_feature, feature7),
        ('flow_artifact', _int64_feature, feature8),
        ('rv_lv_ratio_gte_1', _int64_feature, feature9),
        ('rv_lv_ratio_lt_1', _int64_feature, feature10),
        ('leftsided_pe', _int64_feature, feature11),
        ('chronic_pe', _int64_feature, feature12),
        ('true_filling_defect_not_pe', _int64_feature, feature13),
        ('rightsided_pe', _int64_feature, feature14),
        ('acute_and_chronic_pe', _int64_feature, feature15),
        ('central_pe', _int64_feature, feature16),
        ('indeterminate', _int64_feature, feature17),
        ('dcm_slice_thickness', _float_feature, feature18),
        ('dcm_pixel_spacing0', _float_feature, feature19),
        ('dcm_patient_position', _bytes_feature, feature20),
        ('dcm_image_position_patient2', _float_feature, feature21),
        ('dcm_image_orientation0', _float_feature, feature22),
        ('dcm_image_orientation1', _float_feature, feature23),
        ('dcm_image_orientation2', _float_feature, feature24),
        ('dcm_image_orientation3', _float_feature, feature25),
        ('dcm_image_orientation4', _float_feature, feature26),
        ('dcm_image_orientation5', _float_feature, feature27),
        ('dcm_instance_number', _int64_feature, feature28),
        ('dcm_kvp', _int64_feature, feature29),
        ('dcm_xray_tube_current', _int64_feature, feature30),
        ('dcm_exposure', _int64_feature, feature31),
    )
    feature = {key: wrap(value) for key, wrap, value in schema}
    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features).SerializeToString()

def serialize_example_image_only(feature0,feature1):
    """Serialize an encoded image and its SOPInstanceUID as a tf.train.Example.

    Args:
        feature0: value stored under the ``'image'`` key (bytes feature).
        feature1: value stored under the ``'SOPInstanceUID'`` key (bytes feature).

    Returns:
        The serialized ``tf.train.Example`` as a byte string.
    """
    feature = dict(
        image=_bytes_feature(feature0),
        SOPInstanceUID=_bytes_feature(feature1),
    )
    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features).SerializeToString()

The window and level function is a pretty standard routine. (I'm not sure who claims original authorship.)

The DICOM data comes in Hounsfield units, which are being converted to grayscale. The range is roughly -1000 to 3000, so we pick a subset of the range that is clinically appropriate to our task.

I read the Window/Level (also called Window/Center) from the DICOM file. For this Competition, I am actually plugging in a standard W/C rather than using the DICOM entries. But I keep the code in case I want to use the actual DICOM values.

Note that WindowCenter and WindowWidth can be multi-valued. I just get the first entry.

I get the image data from the DICOM file within a "try:" expression. Some of the image data requires a library GDCM for decompression. This is tricky to install (see Melanoma contest for a wide range of attempts). At least in the Train Dataset, they are not that common, so I just replace them with a blank image.

I then define a "crop" function. CT images have a lot of space on the edges and a pulmonary embolism will never be outside the lungs. How much to crop is based on trial and error.

In [None]:
def window_image(img, window_center,window_width, intercept, slope):
    """Rescale raw pixel values, clip them to a window and map them to 0..255.

    Steps: apply ``img * slope + intercept``, clip to
    ``window_center -/+ window_width // 2``, then stretch the clipped values
    so that their minimum maps to 0 and their maximum to 255.

    Note that the stretch uses the min/max of the clipped image itself, not
    the window bounds, so an image that does not reach both bounds is still
    stretched over the full 0..255 range.

    Args:
        img: array of raw pixel values.
        window_center: centre of the intensity window.
        window_width: width of the window; its absolute value is used.
        intercept: additive rescale term.
        slope: multiplicative rescale term.

    Returns:
        ``uint8`` array with the same shape as ``img``. A constant image
        (for example an all-zero placeholder) yields all zeros.
    """
    window_width = abs(window_width)
    # Work in float so clipping never truncates the window bounds.
    img = np.asarray(img, dtype=np.float64) * slope + intercept
    half_width = window_width // 2
    img = np.clip(img, window_center - half_width, window_center + half_width)
    img = img - np.min(img)
    peak = np.max(img)
    # A constant image has peak == 0; dividing would give 0/0 = NaN and an
    # undefined uint8 cast, so leave it at zero instead.
    if peak > 0:
        img = img / peak
    img = (img*255.0).astype('uint8')
    return img


def dicom_window(dcm_data, use_dicom_window=False):
    """Return the windowed 8-bit image for one DICOM dataset.

    Args:
        dcm_data: DICOM dataset; must support ``'Keyword' in dcm_data``,
            attribute access for the tags read below, and ``pixel_array``.
        use_dicom_window: when False (default) a fixed PE-specific window
            (center 100, width 700; idea courtesy of Ian Pan) is applied.
            When True the WindowCenter / WindowWidth stored in the file are
            used, falling back to 40 / 400 if the tags are absent.

    Returns:
        The array produced by ``window_image``. If the pixel data cannot be
        decoded, a 512x512 all-zero image is windowed instead.
    """
    RI = dcm_data.RescaleIntercept if 'RescaleIntercept' in dcm_data else -1024
    RS = dcm_data.RescaleSlope if 'RescaleSlope' in dcm_data else 1

    if use_dicom_window:
        WC = dcm_data.WindowCenter if 'WindowCenter' in dcm_data else 40
        WW = dcm_data.WindowWidth if 'WindowWidth' in dcm_data else 400
        # Both tags can be multi-valued; keep the first entry.
        if hasattr(WC, "__len__"):
            WC = WC[0]
        if hasattr(WW, "__len__"):
            WW = WW[0]
    else:
        # Standard PE-specific Window/Level.
        WC = 100
        WW = 700

    # Get image from DICOM file. Sometimes error due to missing GDCM library.
    try:
        img = dcm_data.pixel_array
    except Exception:
        # Deliberately best-effort: an undecodable image becomes a blank
        # image so processing continues. ``Exception`` (not a bare except)
        # keeps Ctrl-C / SystemExit working.
        print('DICOM Image error (usually GDCM library)',dcm_data.SOPInstanceUID)
        img = np.zeros(shape=(512,512))

    img = window_image(img,WC,WW,RI,RS)
    return(img)

def crop_center(img,cropx,cropy):
    """Return the centred ``cropy`` x ``cropx`` window of a 2-D array.

    Args:
        img: 2-D array indexed as (row, column).
        cropx: width of the crop in columns.
        cropy: height of the crop in rows.
    """
    height, width = img.shape
    left = width // 2 - cropx // 2
    top = height // 2 - cropy // 2
    row_span = slice(top, top + cropy)
    col_span = slice(left, left + cropx)
    return img[row_span, col_span]

Because of the size of the dataset, I break it into manageable chunks. Typically 100,000 images at a time. Not perfect, because it could duplicate patients across TFRecords.

I read train.csv and the new metadata file and merge them.

In [None]:
print('only reading limited records', df_size)

# Sort the csv file because we might not have all metadata created, then keep
# only this run's chunk of rows.
df = (
    pd.read_csv(csv_file)
    .sort_values(by=['StudyInstanceUID','SOPInstanceUID'])
    .iloc[df_start:(df_start+df_size)]
    .reset_index(drop=True)
)
# 'name' is the image path relative to images_folder; 'name2' is the key
# used to join with the metadata file.
study_uid, series_uid, sop_uid = (
    df['StudyInstanceUID'], df['SeriesInstanceUID'], df['SOPInstanceUID'])
df['name'] = study_uid + '/' + series_uid + '/' + sop_uid + '.dcm'
df['name2'] = study_uid + '/' + sop_uid

# The test csv has no label columns; add zero-filled placeholders so the
# same record layout can be written for both splits.
if train_or_test == 'test':
    placeholder_columns = [
        'pe_present_on_image',
        'negative_exam_for_pe',
        'qa_motion',
        'qa_contrast',
        'flow_artifact',
        'rv_lv_ratio_gte_1',
        'rv_lv_ratio_lt_1',
        'leftsided_pe',
        'chronic_pe',
        'true_filling_defect_not_pe',
        'rightsided_pe',
        'acute_and_chronic_pe',
        'central_pe',
        'indeterminate',
    ]
    df = df.assign(**{column: 0 for column in placeholder_columns})

# Load the DICOM metadata, build the same join key, and drop the UID columns
# that df already carries.
md = pd.read_csv(metadata_file)
md['name2'] = md['StudyInstanceUID'] + '/' + md['SOPInstanceUID']
md = md.drop(columns=['StudyInstanceUID', 'SeriesInstanceUID', 'SOPInstanceUID'])

df = pd.merge(df,md,on='name2')

filenames = df['name'].copy()

print(train_or_test, 'count',len(filenames))

# Write TFRecords

The main routine to read the DICOM images and train.csv data and write the TFRecords. Written in 10,000 record chunks.

All the images in this competition seem to be 512x512, but I check anyway and resize if necessary.

Then I crop to 400x400 because there is a lot of "air" on the edge of the patients and pulmonary embolism is never in the soft tissues of the chest. The crop size is based on some trial and error and hasn't been proven to be useful.

Then I resize to 256 x 256 for my final images.

I leave the images as Gray-Scale. Ian Pan (in his separate JPEG Dataset) uses the color channels to store three different Windows/Levels. I'm not sure if that takes up more disk storage. Probably a good idea, but I haven't implemented it yet.

I'll change the Gray-Scale images to RGB during training/testing since I use Models that are pretrained on color images.

Image errors related to decompression errors requiring the GDCM library result in blank images, but do not interfere with processing. 

In [None]:
#optional shuffle for randomization
#random.shuffle(filenames)

# Column groups, in the positional order serialize_example() expects.
UID_COLUMNS = ['StudyInstanceUID', 'SeriesInstanceUID', 'SOPInstanceUID']
LABEL_COLUMNS = [
    'pe_present_on_image', 'negative_exam_for_pe', 'qa_motion', 'qa_contrast',
    'flow_artifact', 'rv_lv_ratio_gte_1', 'rv_lv_ratio_lt_1', 'leftsided_pe',
    'chronic_pe', 'true_filling_defect_not_pe', 'rightsided_pe',
    'acute_and_chronic_pe', 'central_pe', 'indeterminate',
]
GEOMETRY_COLUMNS = [
    'dcm_image_position_patient2',
    'dcm_image_orientation0', 'dcm_image_orientation1', 'dcm_image_orientation2',
    'dcm_image_orientation3', 'dcm_image_orientation4', 'dcm_image_orientation5',
]
INT_METADATA_COLUMNS = [
    'dcm_instance_number', 'dcm_kvp', 'dcm_xray_tube_current', 'dcm_exposure',
]


def _as_int(value):
    """Convert a csv-derived number to a plain int for an int64 feature."""
    # A column parsed as float (e.g. "120.0") would otherwise be rejected.
    return int(round(value))


SIZE = 10000
# A plain list gives positional indexing regardless of the Series index.
IMGS = list(filenames)
CT = len(IMGS)//SIZE + int(len(IMGS)%SIZE!=0)

# One dict lookup per image instead of scanning the whole DataFrame once per
# column per image. drop_duplicates keeps the first row for a name, matching
# the previous ``.iloc[0]`` behaviour.
records = df.drop_duplicates(subset='name').set_index('name').to_dict('index')

for j in range(CT):
    print(); print('Writing TFRecord %i of %i...'%(j,CT))
    CT2 = min(SIZE,len(IMGS)-j*SIZE)
    with tf.io.TFRecordWriter(tfrec_folder + prefix + '%.2i-%i.tfrec'%(j,CT2)) as writer:
        for k in range(CT2):
            thisfile = IMGS[SIZE*j+k]
            row = records[thisfile]
            ds = pydicom.read_file(images_folder+'/'+thisfile)
            img = dicom_window(ds)
            # ndarray.shape is a tuple; comparing it with a list is always
            # unequal, which forced a resize of every image.
            if img.shape != (512,512):
                img = resize(img, (512,512), \
                   mode='reflect', anti_aliasing=True,preserve_range=True)
            img = crop_center(img,400,400)
            img = resize(img, (256,256), \
                   mode='reflect', anti_aliasing=True,preserve_range=True)
            # resize() returns floats; make the 8-bit conversion explicit
            # rather than leaving it to the JPEG encoder.
            img = np.clip(np.rint(img), 0, 255).astype(np.uint8)
            # tobytes() replaces the deprecated ndarray.tostring().
            img = cv2.imencode('.jpg', img, (cv2.IMWRITE_JPEG_QUALITY, 94))[1].tobytes()

            example = serialize_example(
                img,
                *(str.encode(row[column]) for column in UID_COLUMNS),
                *(_as_int(row[column]) for column in LABEL_COLUMNS),
                float(row['dcm_slice_thickness']),
                float(row['dcm_pixel_spacing0']),
                str.encode(row['dcm_patient_position']),
                *(float(row[column]) for column in GEOMETRY_COLUMNS),
                *(_as_int(row[column]) for column in INT_METADATA_COLUMNS)
            )
            writer.write(example)
            if k%1000==0: print(k,', ',end='')
print('done')

# Now we should test our files by reading a sample back.

Keeping some header code (not needed in this notebook), so that this section can be a standalone notebook also.

If you use this section in a separate notebook, you'll need to update the location of the TFRecords to point to your stored DataSet.

In [None]:
import tensorflow as tf 
from tensorflow import keras 
from tensorflow.keras.utils import Sequence 
import tensorflow.keras.backend as K 
import tensorflow.keras.backend 
import tensorflow.keras.layers as L

# adjust to train or test if broken out to separate notebook
#train_or_test = 'train'
#train_or_test = 'test'

# adjust to your directory if not working directory. Will need "/"
tfrec_folder = ''

#if accessing through a Dataset, enter name here:
my_GCS_PATH = 'your TFRecs Dataset here'

DEVICE = "TPU"

# Defined up front so the checks below work even when DEVICE is not "TPU".
tpu = None

if DEVICE == "TPU":
    print("connecting to TPU...")
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Running on TPU ', tpu.master())
    except ValueError:
        print("Could not connect to TPU")
        tpu = None

if tpu:
    try:
        print("initializing  TPU ...")
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
        print("TPU initialized")
    except Exception as err:
        # ``except _:`` is not a valid handler (``_`` is not an exception
        # class). Fall back to the default strategy so ``strategy`` is always
        # defined below.
        print("failed to initialize TPU", err)
        DEVICE = "GPU"
else:
    DEVICE = "GPU"
if DEVICE != "TPU":
    print("Using default strategy for CPU and single GPU")
    strategy = tf.distribute.get_strategy()

if DEVICE == "GPU":
    print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

AUTO = tf.data.experimental.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync

# Sizes used by the reading pipeline; batch_size is per replica.
CFG = dict(
    batch_size        = 32,
    read_size         = 256,
    crop_size         = 256,
    net_size          = 256)

import os, gc, random, cv2, csv
import numpy as np
from numpy.random import seed 
import pandas as pd
from skimage import measure
from skimage.transform import resize 
from sklearn.metrics import roc_auc_score
import tensorflow as tf 
#from tensorflow import keras 
#from tensorflow.keras.utils import Sequence 
#from matplotlib import pyplot as plt 
#import matplotlib.pyplot as plt 
#import matplotlib.image as mpimg 
import pydicom
import PIL 
from tqdm import tqdm

from kaggle_datasets import KaggleDatasets 
mypath = '../input/metadata'

#If using the TPU, you have to use a different path.
# Previously GCS_PATH and tfrec_folder were set but never used, so the glob
# always searched the working directory. The pattern is now rooted at the GCS
# dataset path on TPU and at tfrec_folder otherwise.
if DEVICE == 'TPU':
    GCS_PATH = KaggleDatasets().get_gcs_path(my_GCS_PATH)
    tfrec_root = GCS_PATH + '/'
else:
    GCS_PATH = ''
    # tfrec_folder is '' for the working directory, or a path ending in "/".
    tfrec_root = tfrec_folder

tfrec_filter = tfrec_root + train_or_test + '*.tfrec'
print(tfrec_filter)

tfrec_files = np.sort(np.array(tf.io.gfile.glob(tfrec_filter)))
print('Number of files', len(tfrec_files))

Routines here to 
    1. Load images and metadata.
    2. Augment the images (rotate, intensity, etc). Not needed here, but might be used in a training model.

In [None]:
def read_labeled_tfrecord_metadata(example):
    """Parse one serialized record into ``(image, metadata, target)``.

    Args:
        example: a serialized ``tf.train.Example``.

    Returns:
        A tuple of the raw ``'image'`` string, a two-element metadata list
        (slice thickness, pixel spacing) and a ten-element target list.
    """
    # (key, dtype) for every scalar feature stored in the record.
    scalar_fields = [
        ('image', tf.string),
        ('StudyInstanceUID', tf.string),
        ('SeriesInstanceUID', tf.string),
        ('SOPInstanceUID', tf.string),
        ('pe_present_on_image', tf.int64),
        ('negative_exam_for_pe', tf.int64),
        ('qa_motion', tf.int64),
        ('qa_contrast', tf.int64),
        ('flow_artifact', tf.int64),
        ('rv_lv_ratio_gte_1', tf.int64),
        ('rv_lv_ratio_lt_1', tf.int64),
        ('leftsided_pe', tf.int64),
        ('chronic_pe', tf.int64),
        ('true_filling_defect_not_pe', tf.int64),
        ('rightsided_pe', tf.int64),
        ('acute_and_chronic_pe', tf.int64),
        ('central_pe', tf.int64),
        ('indeterminate', tf.int64),
        ('dcm_slice_thickness', tf.float32),
        ('dcm_pixel_spacing0', tf.float32),
        ('dcm_patient_position', tf.string),
        ('dcm_image_position_patient2', tf.float32),
        ('dcm_image_orientation0', tf.float32),
        ('dcm_image_orientation1', tf.float32),
        ('dcm_image_orientation2', tf.float32),
        ('dcm_image_orientation3', tf.float32),
        ('dcm_image_orientation4', tf.float32),
        ('dcm_image_orientation5', tf.float32),
        ('dcm_kvp', tf.int64),
        ('dcm_instance_number', tf.int64),
        ('dcm_xray_tube_current', tf.int64),
        ('dcm_exposure', tf.int64),
    ]
    tfrec_format = {key: tf.io.FixedLenFeature([], dtype)
                    for key, dtype in scalar_fields}
    parsed = tf.io.parse_single_example(example, tfrec_format)

    # NOTE(review): dcm_instance_number is parsed as int64 while these two are
    # float32; mixing dtypes in one list is presumably why it could not be
    # returned here - confirm before adding it.
    metadata_keys = ['dcm_slice_thickness', 'dcm_pixel_spacing0']
    # qa_motion, qa_contrast, flow_artifact and true_filling_defect_not_pe
    # are parsed but not part of the target.
    target_keys = [
        'pe_present_on_image',
        'negative_exam_for_pe',
        'rv_lv_ratio_gte_1',
        'rv_lv_ratio_lt_1',
        'leftsided_pe',
        'chronic_pe',
        'rightsided_pe',
        'acute_and_chronic_pe',
        'central_pe',
        'indeterminate',
    ]
    metadata = [parsed[key] for key in metadata_keys]
    target = [parsed[key] for key in target_keys]

    return parsed['image'], metadata, target

def prepare_image(img, cfg=None, augment=True):
    """Decode a JPEG string into a float RGB tensor of size ``cfg['net_size']``.

    Args:
        img: encoded JPEG string tensor.
        cfg: mapping that provides ``'net_size'``; it is indexed
            unconditionally, so the ``None`` default is not usable.
        augment: currently has no effect; no augmentation is applied here.

    Returns:
        Tensor of shape ``[net_size, net_size, 3]`` with values scaled by 1/255.
    """
    decoded = tf.image.decode_jpeg(img, channels=1)
    scaled = tf.cast(decoded, tf.float32) / 255.0

    # Augmentation (random crop / flip / hue / saturation / contrast /
    # brightness) and the read_size / crop_size settings are not applied.

    side = cfg['net_size']
    resized = tf.image.resize(scaled, [side, side])
    # The records hold gray-scale images; replicate the channel to get RGB.
    rgb = tf.image.grayscale_to_rgb(resized, name=None)
    return tf.reshape(rgb, [side, side, 3])

def get_dataset_metadata(files, cfg, augment = False, shuffle = False, repeat = False, 
                labeled=True, return_image_names=True):
    """Build a batched ``tf.data`` pipeline yielding ``((image, metadata), label_or_name)``.

    Args:
        files: TFRecord file names passed to ``tf.data.TFRecordDataset``.
        cfg: mapping providing ``'batch_size'`` and whatever
            ``prepare_image`` reads from it.
        augment: forwarded to ``prepare_image``.
        shuffle: shuffle with a buffer of 8192 and allow non-deterministic order.
        repeat: repeat the dataset indefinitely.
        labeled: parse with ``read_labeled_tfrecord_metadata`` when True,
            otherwise with ``read_unlabeled_tfrecord_metadata``.
        return_image_names: forwarded to the unlabeled parser only.

    Returns:
        A dataset batched to ``cfg['batch_size'] * REPLICAS`` and prefetched.
    """
    def parse_unlabeled(example):
        return read_unlabeled_tfrecord_metadata(example, return_image_names)

    def decode_image(img, metadata, imgname_or_label):
        return prepare_image(img, augment=augment, cfg=cfg), metadata, imgname_or_label

    def group_inputs(img, metadata, imgname_or_label):
        # The image and metadata are paired as the inputs of each element.
        return (img, metadata), imgname_or_label

    dataset = tf.data.TFRecordDataset(files, num_parallel_reads=AUTO).cache()

    if repeat:
        dataset = dataset.repeat()

    if shuffle:
        dataset = dataset.shuffle(1024*8)
        options = tf.data.Options()
        options.experimental_deterministic = False
        dataset = dataset.with_options(options)

    parser = read_labeled_tfrecord_metadata if labeled else parse_unlabeled
    for step in (parser, decode_image, group_inputs):
        dataset = dataset.map(step, num_parallel_calls=AUTO)

    batch = cfg['batch_size'] * REPLICAS
    return dataset.batch(batch).prefetch(AUTO)
print("done") 

def show_dataset_metadata(thumb_size, cols, rows, ds):
    """Display the first ``cols * rows`` images of ``ds`` as a thumbnail grid.

    Args:
        thumb_size: edge length in pixels of each square thumbnail.
        cols: number of grid columns.
        rows: number of grid rows.
        ds: iterable of ``((image, metadata), target_or_imgid)`` elements,
            where ``image`` has a ``.numpy()`` method and values in 0..1.
    """
    capacity = cols * rows
    # One pixel of spacing is left between neighbouring thumbnails.
    mosaic = PIL.Image.new(mode='RGB', size=(thumb_size*cols + (cols-1),
                                             thumb_size*rows + (rows-1)))
    for idx, data in enumerate(iter(ds)):
        if idx >= capacity:
            # The grid is full; further images would land outside the canvas.
            break
        inputs, target_or_imgid = data
        img = inputs[0]
        ix  = idx % cols
        iy  = idx // cols
        img = np.clip(img.numpy() * 255, 0, 255).astype(np.uint8)
        img = PIL.Image.fromarray(img)
        img = img.resize((thumb_size, thumb_size), resample=PIL.Image.BILINEAR)
        mosaic.paste(img, (ix*thumb_size + ix,
                           iy*thumb_size + iy))
    display(mosaic)

Load a sample of images and display.

In [None]:
files_sample = np.sort(np.array(tfrec_files))

# Grid shown below; only as many images as the grid holds are taken.
# (Previously 12*5 images were taken for a 2x5 grid.)
PREVIEW_COLS = 2
PREVIEW_ROWS = 5

ds = get_dataset_metadata(files_sample, CFG, shuffle=True,augment=True, labeled=True).unbatch().take(PREVIEW_COLS*PREVIEW_ROWS)
show_dataset_metadata(256, PREVIEW_COLS, PREVIEW_ROWS, ds)
print("done")