The methods used here are discussed in [this](https://www.kaggle.com/gautham11/generating-scan-and-segemented-scan-jpg-fastai/notebook) kernel; refer to it if you want to understand what is going on. Note that the method does not work perfectly, as can be seen from the several dark images in the plots generated while making the TFRecords.

In [None]:
!conda install -c conda-forge pillow -y
!conda install -c conda-forge pydicom -y
!conda install -c conda-forge gdcm -y
!pip install pylibjpeg pylibjpeg-libjpeg

In [None]:
import os
import cv2

import pydicom
import pandas as pd
import numpy as np 
import torch
import tensorflow as tf 
import matplotlib.pyplot as plt 
from pathlib import Path
import scipy.ndimage as ndimage
from skimage import measure, morphology, segmentation
from scipy.ndimage.interpolation import zoom
from PIL import Image 

from tqdm.notebook import tqdm
%matplotlib inline

In [None]:
!pip install ../input/fastai2-wheels/fastcore-0.1.18-py3-none-any.whl -q
!pip install ../input/fastai2-wheels/fastai2-0.0.17-py3-none-any.whl -q

In [None]:
from fastai2.basics import *
from fastai2.medical.imaging import *
import cv2
import torch

In [None]:
# Root folder of the competition data; process_image is called with paths built from data_dir/'train/...'.
data_dir = Path('/kaggle/input/osic-pulmonary-fibrosis-progression/')

In [None]:
## Code from https://www.kaggle.com/aadhavvignesh/lung-segmentation-by-marker-controlled-watershed 

def generate_markers(image):
    """
    Generates the watershed markers for a single 2-D scan slice.

    Parameters: image (2-D array of pixel intensities; the -400 threshold
                presumably assumes Hounsfield units -- TODO confirm)

    Returns: Internal Marker (bool), External Marker (bool),
             Watershed Marker (int array: 255 on the internal marker,
             128 on the external marker, 0 elsewhere)
    """

    # Creation of the internal Marker: dark pixels that do not touch the border
    marker_internal = image < -400
    marker_internal = segmentation.clear_border(marker_internal)
    marker_internal_labels = measure.label(marker_internal)

    regions = measure.regionprops(marker_internal_labels)
    areas = sorted(region.area for region in regions)

    if len(areas) > 2:
        # Keep only regions at least as large as the second-largest one.
        # Clearing them with one vectorised mask replaces a per-pixel Python loop.
        second_largest = areas[-2]
        small_labels = [region.label for region in regions
                        if region.area < second_largest]
        marker_internal_labels[np.isin(marker_internal_labels, small_labels)] = 0

    marker_internal = marker_internal_labels > 0

    # Creation of the External Marker: a ring between the 10x and 55x dilations
    external_a = ndimage.binary_dilation(marker_internal, iterations=10)
    external_b = ndimage.binary_dilation(marker_internal, iterations=55)
    marker_external = external_b ^ external_a

    # Creation of the Watershed Marker
    # (builtin int is used because the np.int alias was removed from NumPy)
    marker_watershed = np.zeros(image.shape, dtype=int)
    marker_watershed += marker_internal * 255
    marker_watershed += marker_external * 128

    return marker_internal, marker_external, marker_watershed


def seperate_lungs(image, iterations = 1):
    """
    Segments lungs with a marker-controlled watershed on the Sobel gradient.

    Parameters: image (2-D scan slice), iterations (number of times the
                black top-hat structuring element is iterated, i.e. enlarged)

    Returns:
        - Segmented Lung (image where the lung filter is set, -2000 elsewhere)
        - Lung Filter (boolean mask)
        - Outline Lung
        - Watershed Lung
        - Sobel Gradient
    """

    marker_internal, marker_external, marker_watershed = generate_markers(image)

    # Sobel gradient magnitude, rescaled so that its maximum is 255.
    sobel_filtered_dx = ndimage.sobel(image, 1)
    sobel_filtered_dy = ndimage.sobel(image, 0)
    sobel_gradient = np.hypot(sobel_filtered_dx, sobel_filtered_dy)
    max_gradient = np.max(sobel_gradient)
    if max_gradient > 0:
        # A perfectly flat image has zero gradient everywhere; skip the
        # rescale instead of dividing by zero and filling the array with NaN.
        sobel_gradient *= 255.0 / max_gradient

    # Watershed of the gradient image seeded with the markers.
    # scikit-image moved watershed from `morphology` to `segmentation`;
    # prefer the new location and fall back for old installations.
    watershed_fn = getattr(segmentation, "watershed", None) or morphology.watershed
    watershed = watershed_fn(sobel_gradient, marker_watershed)

    # Reduce the watershed labelling to its outlines.
    outline = ndimage.morphological_gradient(watershed, size=(3,3))
    outline = outline.astype(bool)

    # Black top-hat: morphological closing minus the original image. It
    # returns the dark spots smaller than the structuring element (as bright
    # spots), which are added back onto the outline.
    blackhat_struct = [[0, 0, 1, 1, 1, 0, 0],
                       [0, 1, 1, 1, 1, 1, 0],
                       [1, 1, 1, 1, 1, 1, 1],
                       [1, 1, 1, 1, 1, 1, 1],
                       [1, 1, 1, 1, 1, 1, 1],
                       [0, 1, 1, 1, 1, 1, 0],
                       [0, 0, 1, 1, 1, 0, 0]]

    blackhat_struct = ndimage.iterate_structure(blackhat_struct, iterations)

    outline += ndimage.black_tophat(outline, structure=blackhat_struct)

    # Lung filter: internal marker plus outline, closed to fill small gaps.
    # (ndimage.binary_closing is the public name; ndimage.morphology is a
    # deprecated namespace.)
    lungfilter = np.bitwise_or(marker_internal, outline)
    lungfilter = ndimage.binary_closing(lungfilter, structure=np.ones((5,5)), iterations=3)

    # Keep the image inside the filter and write -2000 everywhere else.
    segmented = np.where(lungfilter == 1, image, -2000*np.ones(image.shape))

    return segmented, lungfilter, outline, watershed, sobel_gradient


In [None]:
#DICOM Read utils
def fix_pxrepr(dcm):
    """Shift and wrap the pixel data of a DICOM dataset in place.

    The dataset is returned untouched when ``PixelRepresentation`` is non-zero
    or ``RescaleIntercept`` is below -100. Otherwise 1000 is added to every
    pixel, values that reach 4096 wrap around by 4096, ``PixelData`` is
    rewritten from the result and ``RescaleIntercept`` is set to -1000.
    NOTE(review): presumably this normalises scans stored with an intercept
    near zero -- confirm against the kernel this was taken from.
    """
    if dcm.PixelRepresentation != 0:
        return dcm
    if dcm.RescaleIntercept < -100:
        return dcm

    wrap_at = 4096
    shifted = dcm.pixel_array + 1000
    overflow = shifted >= wrap_at
    shifted[overflow] -= wrap_at

    dcm.PixelData = shifted.tobytes()
    dcm.RescaleIntercept = -1000
    return dcm

def read_dcm(path):
    """Read a DICOM file, apply fix_pxrepr, and zoom it to 512x512 if needed."""
    target_size = (512, 512)
    dcm = Path(path).dcmread()
    dcm = fix_pxrepr(dcm)
    if (dcm.Rows, dcm.Columns) != target_size:
        dcm.zoom_to(target_size)
    return dcm

In [None]:
# Bin edges handed to `hist_scaled` when process_image normalises the full
# (unmasked) slice.
# NOTE(review): the values look like precomputed intensity quantiles,
# presumably taken from the kernel linked at the top -- confirm provenance.
bins = torch.tensor([-4096., -3024., -2048., -2000., -1109., -1025., -1024., -1023., -1020.,
        -1017., -1014., -1011., -1008., -1005., -1003., -1001.,  -998.,  -996.,
         -993.,  -991.,  -988.,  -985.,  -981.,  -978.,  -973.,  -969.,  -964.,
         -958.,  -952.,  -946.,  -940.,  -933.,  -926.,  -919.,  -912.,  -904.,
         -895.,  -886.,  -875.,  -862.,  -847.,  -829.,  -806.,  -777.,  -738.,
         -684.,  -607.,  -501.,  -386.,  -301.,  -247.,  -210.,  -183.,  -163.,
         -148.,  -136.,  -126.,  -117.,  -109.,  -101.,   -93.,   -85.,   -77.,
          -69.,   -60.,   -51.,   -42.,   -33.,   -24.,   -15.,    -6.,     0.,
            5.,    13.,    20.,    27.,    34.,    41.,    48.,    55.,    62.,
           70.,    79.,    90.,   102.,   118.,   137.,   162.,   193.,   237.,
          301.,   417.,   750.,  1278.])

In [None]:
# Bin edges handed to `hist_scaled` for the two masked channels built in
# process_image (lung-only and everything-but-lung, both filled with -2048).
masked_bins = torch.tensor([-2048., -1000.,  -942.,  -908.,  -879.,  -847.,  -808.,  -751.,  -656.,
         -467.,   -81.,   251.])

## Generate tfrecords
This section is based on [this](https://www.kaggle.com/cdeotte/how-to-create-tfrecords) kernel.

In [None]:
# Patient IDs grouped into CV splits; one TFRecord file is written per group.
# NOTE(review): allow_pickle=True unpickles the file, which can execute
# arbitrary code -- only acceptable because this is a trusted input dataset.
groups = np.load('../input/cv-splits/groups.npy',allow_pickle=True)

# Load meta data
INPUT_FOLDER = '../input/osic-pulmonary-fibrosis-progression/train/'
train = pd.read_csv('../input/osic-pulmonary-fibrosis-progression/train.csv')
# patients = os.listdir(INPUT_FOLDER)
patients = train['Patient'].unique()

# Map each patient ID to the number of entries in that patient's scan folder
# (`os` is already imported in the first import cell).
director = "../input/osic-pulmonary-fibrosis-progression/train"
num_im_dict = {pid: len(os.listdir(director + "/" + pid)) for pid in patients}

In [None]:
def _bytes_feature(value):
  """Wrap one string / bytes value in a tf.train.Feature holding a BytesList."""
  eager_tensor_type = type(tf.constant(0))
  # BytesList won't unpack a string from an EagerTensor, so pull the raw value out first.
  raw = value.numpy() if isinstance(value, eager_tensor_type) else value
  payload = tf.train.BytesList(value=[raw])
  return tf.train.Feature(bytes_list=payload)

# def _float_feature(value):
#   """Returns a float_list from a float / double."""
#   return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _floats_feature(value):
  """Flatten `value` with reshape(-1) and wrap it in a tf.train.Feature holding a FloatList."""
  flattened = value.reshape(-1)
  payload = tf.train.FloatList(value=flattened)
  return tf.train.Feature(float_list=payload)

def _int64_feature(value):
  """Wrap one bool / enum / int / uint value in a tf.train.Feature holding an Int64List."""
  payload = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=payload)

In [None]:
def serialize_example(feature0, feature1, feature2, feature3):
  """Serialize one slice as a tf.train.Example string.

  feature0 -> 'image' (bytes), feature1 -> 'image_name' (bytes),
  feature2 -> 'image_dim' (int), feature3 -> 'slice' (int).
  """
  fields = {}
  fields['image'] = _bytes_feature(feature0)
  fields['image_name'] = _bytes_feature(feature1)
  fields['image_dim'] = _int64_feature(feature2)
  fields['slice'] = _int64_feature(feature3)
  features = tf.train.Features(feature=fields)
  return tf.train.Example(features=features).SerializeToString()

In [None]:
def process_image(path):
    """Read one DICOM slice and build a 3-channel, 512x512 array from it.

    Parameters: path (location of the .dcm file)

    Returns: array stacked as (3, H, W) with channels
        0 - full slice, histogram-scaled with `bins`
        1 - pixels inside the lung mask (rest set to -2048), scaled with `masked_bins`
        2 - pixels outside the lung mask (rest set to -2048), scaled with `masked_bins`

    Raises whatever the underlying file read raises (e.g. FileNotFoundError
    for a missing slice, which the TFRecord writer loop relies on).
    """
    target = 512

    dcom = fix_pxrepr(Path(path).dcmread())
    img = dcom.scaled_px.numpy()

    height, width = img.shape[0], img.shape[1]
    if height != width:
        # Centre-crop non-square slices. When both sides are >= 512 this is
        # the same central 512x512 window as before. When a side is shorter
        # than 512 the window shrinks to that side, so slice bounds never go
        # negative (negative bounds used to wrap around and silently return a
        # wrong crop); the resize below then brings it back up to 512.
        side = min(target, height, width)
        top = height // 2 - side // 2
        left = width // 2 - side // 2
        img = img[top:top + side, left:left + side]
    if img.shape[0] != target:
        resize_factor = np.array([target, target]) / np.array(img.shape)
        # `ndimage` is the scipy.ndimage alias imported at the top of the
        # notebook; the bare `scipy` name is only in scope through fastai's
        # star import. mode='nearest' is the boundary mode, not the
        # interpolation order.
        img = ndimage.zoom(img, resize_factor, mode='nearest')

    scaled_img = torch.tensor(img).hist_scaled(bins).numpy()

    _, lungmask, _, _, _ = seperate_lungs(img)
    masked_lung = torch.tensor(np.where(lungmask, img, -2048)).hist_scaled(masked_bins).numpy()
    chest_lung = torch.tensor(np.where(~lungmask, img, -2048)).hist_scaled(masked_bins).numpy()

    return np.stack((scaled_img,masked_lung,chest_lung))

In [None]:
# One preview panel per patient (the middle slice), 4 per row.
fig, ax = plt.subplots(44,4,figsize=(20,44*5))
pax = ax.flatten()

CT = len(groups)
z = 0
# Square placeholder so the shape check below passes before any slice is read.
scaled_img = np.stack((np.eye(3),np.eye(3)))
# The loop variable is deliberately NOT called `patients`, so the module-level
# array of all patient IDs is not overwritten.
for j, group_patients in enumerate(groups):
    print(); print('Writing TFRecord %i of %i...'%(j,CT))
    CT2 = len(group_patients)
    tot_im_num = np.sum([num_im_dict[pid] for pid in group_patients])
    # Abort if the last processed slice was not square.
    if scaled_img.shape[1]!=scaled_img.shape[2]:print('Image size not equal');break
    with tf.io.TFRecordWriter('train%.2i-%i-%i.tfrec'%(j,CT2,tot_im_num)) as writer:
        for pt in group_patients:
            nim = num_im_dict[pt]
            indx = nim // 2
            name = str.encode(pt)
            slc = 0; imcnt = 0
            # If the image is not there don't count it, but do increase the slice number.
            # NOTE(review): the file read is {slc+1}.dcm while the stored slice
            # index is slc, and only nim-1 slices are written per patient, so
            # 1.dcm is never read and the count in the file name (tot_im_num)
            # exceeds the number of records by one per patient. Left unchanged
            # because downstream consumers may depend on this numbering -- confirm.
            # NOTE(review): this loop never terminates if fewer than nim-1
            # numbered files can be found for a patient.
            while imcnt < nim-1:
                slc += 1
                try:
                    scaled_img = process_image(data_dir/'train/{}/{}.dcm'.format(pt,slc+1))
                except FileNotFoundError:
                    continue
                imcnt += 1
                dimg = scaled_img.shape[1]
                # NOTE(review): OpenCV treats a 3-channel array as BGR, while
                # tf.image.decode_jpeg returns RGB, so the channel order is
                # presumably reversed when the records are read back -- confirm.
                ok, encoded = cv2.imencode('.jpg', np.rollaxis(scaled_img,0,3)*255, (cv2.IMWRITE_JPEG_QUALITY, 94))
                if not ok:
                    raise RuntimeError('JPEG encoding failed for patient %s slice %i' % (pt, slc))
                # tobytes() replaces the deprecated ndarray.tostring().
                img = encoded.tobytes()
                example = serialize_example(
                    img,
                    name,
                    dimg,
                    slc)
                writer.write(example)
                # Preview the middle slice; the bound check keeps extra
                # patients from overflowing the 44x4 grid.
                if slc == indx and z < len(pax):
                    pax[z].imshow(scaled_img[0], cmap=plt.cm.gray)
                    z += 1

fig.show()

# Verify TFRecords
We will verify the TFRecords we just made by using code from the Flower Comp starter notebook [here][1] to display the TFRecords below.

[1]: https://www.kaggle.com/mgornergoogle/getting-started-with-100-flowers-on-tpu

In [None]:
# numpy print defaults: summarise arrays with more than 15 elements, wrap lines at 80 characters
np.set_printoptions(threshold=15, linewidth=80)
# Display names looked up by label index in title_from_label_and_target.
CLASSES = [0,1]

def batch_to_numpy_images_and_labels(data):
    """Split an (images, labels) batch and convert both halves with `.numpy()`.

    `data` must unpack into exactly two objects that expose `.numpy()`.
    Labels are returned as-is; no substitution with None is performed.
    """
    image_batch, label_batch = data
    return image_batch.numpy(), label_batch.numpy()

def title_from_label_and_target(label, correct_label):
    """Build a plot title for a prediction and report whether it is correct.

    Returns (title, correct). With no target the title is just the predicted
    class and correct is True. Otherwise the title is "<pred> [OK]" or
    "<pred> [NO→<target>]".
    """
    predicted_name = CLASSES[label]
    if correct_label is None:
        return predicted_name, True

    correct = (label == correct_label)
    if correct:
        verdict, arrow, expected = 'OK', '', ''
    else:
        verdict, arrow, expected = 'NO', u"\u2192", CLASSES[correct_label]
    return "{} [{}{}{}]".format(predicted_name, verdict, arrow, expected), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Draw one image into the given (rows, cols, index) subplot slot.

    A non-empty title is drawn in black, or in red with a smaller font when
    `red` is set. Returns the subplot triple advanced to the next index.
    """
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        if red:
            font_size, font_color = int(titlesize/1.2), 'red'
        else:
            font_size, font_color = int(titlesize), 'black'
        plt.title(title, fontsize=font_size, color=font_color,
                  fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
    n_rows, n_cols, position = subplot[0], subplot[1], subplot[2]
    return (n_rows, n_cols, position+1)
    
def display_batch_of_images(databatch, predictions=None):
    """Plot a batch of images on a square-ish grid.

    This will work with:
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)

    `databatch` must be an (images, labels) pair, because
    batch_to_numpy_images_and_labels unpacks exactly two elements. Each label
    is used as the title unless `predictions` is given, in which case the
    title comes from title_from_label_and_target and wrong predictions are
    drawn in red. An empty batch draws nothing.
    """
    # Imported locally so the function does not depend on `math` leaking into
    # the namespace through a star import.
    import math

    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None for _ in enumerate(images)]

    # Nothing to draw; without this guard rows would be 0 and the column
    # computation below would raise ZeroDivisionError.
    if len(images) == 0:
        return

    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot=(rows,cols,1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))

    # display
    last_label = None  # label of the last image drawn; decides the spacing below
    for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
        last_label = label
        title = label
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)

    #layout
    plt.tight_layout()
    if last_label is None and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

In [None]:
def decode_image(image_data):
    """Decode JPEG bytes into a float32 image in [0, 1] with a fixed shape."""
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    # convert image to floats in [0, 1] range
    scaled = tf.cast(pixels, tf.float32) / 255.0
    # explicit size needed for TPU
    static_shape = [*IMAGE_SIZE, 3]
    return tf.reshape(scaled, static_shape)

def read_labeled_tfrecord(example):
    """Parse one serialized example into an (image, image_name) pair.

    Only the 'image' and 'image_name' features are read.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "image_name": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    decoded = decode_image(parsed['image'])
    image_name = parsed['image_name']
    return decoded, image_name

def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of (image, image_name) pairs from TFRecord files.

    Files are read in parallel. Unless `ordered` is set, deterministic
    ordering is switched off so records are used as soon as they stream in.
    `labeled` is accepted but not used by this implementation.
    """
    read_options = tf.data.Options()
    if not ordered:
        # disable order, increase speed
        read_options.experimental_deterministic = False

    # automatically interleaves reads from multiple files
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    records = records.with_options(read_options)
    return records.map(read_labeled_tfrecord)

def get_training_dataset():
    """Return the training pipeline: repeated, shuffled (buffer 2048), batched, prefetched."""
    return (
        load_dataset(TRAINING_FILENAMES, labeled=True)
        .repeat()            # the training dataset must repeat for several epochs
        .shuffle(2048)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)      # prefetch next batch while training (autotune prefetch buffer size)
    )

def count_data_items(filenames):
    """Sum the item counts encoded in TFRecord file names.

    The count is the number between a '-' and the following '.', e.g.
    flowers00-230.tfrec -> 230 and train00-5-230.tfrec -> 230.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    counts = []
    for fname in filenames:
        counts.append(int(count_pattern.search(fname).group(1)))
    return np.sum(counts)

In [None]:
# INITIALIZE VARIABLES
# Shape and batch size used by decode_image / get_training_dataset.
IMAGE_SIZE = [512, 512]
BATCH_SIZE = 32
# Let tf.data choose parallelism and prefetch buffer sizes.
AUTO = tf.data.experimental.AUTOTUNE
# Every TFRecord file in the working directory whose name matches train*.tfrec.
TRAINING_FILENAMES = tf.io.gfile.glob('train*.tfrec')
print('There are %i train images'%count_data_items(TRAINING_FILENAMES))

In [None]:
# DISPLAY TRAIN IMAGES
# Re-batch the training stream into groups of 20 and plot the first group.
training_dataset = get_training_dataset().unbatch().batch(20)
train_batch = iter(training_dataset)

first_twenty = next(train_batch)
display_batch_of_images(first_twenty)