# 1. OSIC AutoEncoder training
This notebook demonstrates how to train a convolutional AutoEncoder to learn latent features from the 3D CT scans dataset.

One of the main applications of AutoEncoders is dimensionality reduction. We will use them for that: reducing 3D images (preprocessed to 1 x 40 x 256 x 256 tensors) to vectors (with 10 dimensions).
![autoencoder](https://hackernoon.com/hn-images/1*8ixTe1VHLsmKB3AquWdxpQ.png)

Once we have the trained model, the idea is to apply it to extract these latent features and combine them with the OSIC tabular data.

My first experiments had a less strangled bottleneck (started with 96 x 2 x 20 x 20), which was already a reduction of over 34:1 (the inputs are 3D images of 1 x 40 x 256 x 256). The AutoEncoder output was great and easy to see. However, using latent features of 96 x 2 x 20 x 20 meant that, in the tabular model, I had to combine 76,800 features (flattened) with the 9 tabular features. In order to have a better balance between tabular and latent features, I decided to strangle the bottleneck further, squeezing the 3D images to 10 features (already flattened in the AutoEncoder model). As you can see below, the model learns, as the loss keeps going down. However, the output of the AutoEncoder is not as visible as with the less strangled bottleneck.

# 2. Imports and global variables

In [None]:
import copy
from datetime import timedelta, datetime
import imageio
import matplotlib.pyplot as plt
from matplotlib import cm
import multiprocessing
import numpy as np
import os
# from pathlib import Path
import pydicom
import pytest
import scipy.ndimage
import scipy.ndimage as ndimage
from scipy.ndimage.interpolation import zoom
from skimage import measure, morphology, segmentation
from time import time, sleep
from tqdm import trange, tqdm
import torch                                                                                           
import torch.nn as nn
import torch.nn.functional as F                                   
from torch.utils.data import Dataset, random_split, DistributedSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import warnings                        
import pandas as pd
from pathlib import Path

from joblib import Parallel, delayed


In [None]:
# --- Paths -------------------------------------------------------------------
# Directory holding the data this notebook reads (currently the pre-processed
# dataset; the raw competition directory is kept below for reference).
root_dir = '/kaggle/input/osic-preprocrsseddata'
# root_dir="/kaggle/input/osic-pulmonary-fibrosis-progression/"
# Raw DICOM test scans of the competition dataset.
test_dir = '/kaggle/input/osic-pulmonary-fibrosis-progression/test'
# model_file = '/kaggle/working/'
# Presumably where the trained model is written; it is not used in this section.
model_file = '/kaggle/working/diophantus.pt'
# Directory of cached `.pt` tensors read by CTTensorsDataset (same as root_dir).
cache_dir='/kaggle/input/osic-preprocrsseddata'

# --- Pre-processing ----------------------------------------------------------
# Target (slices, rows, columns) shape for the Resize transform.
resize_dims = (40, 256, 256)
# Lower/upper bounds, in Hounsfield Units, intended for Clip and Normalize;
# the lower bound is also the fill value passed to MaskWatershed as `min_hu`.
clip_bounds = (-1000, 200)
# Number of structuring-element iterations for the MaskWatershed transform.
watershed_iterations = 1
# Dataset mean subtracted from every cached tensor by ZeroCenter.
pre_calculated_mean = 0.02865046213070556

# --- Model / training hyper-parameters ---------------------------------------
# NOTE(review): the values below are consumed by cells outside this section;
# their descriptions are inferred from the names and should be confirmed.
latent_features = 10   # size of the AutoEncoder bottleneck vector
batch_size = 16
learning_rate = 3e-03
num_epochs = 150
val_size = 0.2         # presumably the validation fraction for random_split
tensorboard_dir = '/kaggle/working/runs'

In [None]:
# !pip install --upgrade pip
# !pip install pydicom 

In [None]:
# from platform import python_version

# print(python_version())

In [None]:
# !conda install -c conda-forge pillow -y
# !conda install -c conda-forge pydicom -y
# !conda install gdcm -c conda-forge -y 
# !pip install pylibjpeg pylibjpeg-libjpeg
# !conda install -c conda-forge gdcm -y

In [None]:
# import gdcm

# 3. Dataset interface

## 3.1. ctscans_dataset.py
This interface ingests the data from the 3D CT scans, porting them to a PyTorch Dataset.

In [None]:
# Module-level bookkeeping table. CTScansDataset.load_scan adds one row per
# loaded patient (indexed by PatientID) holding the slice spacing it settled on
# and the number of slices in the scan.
ds=pd.DataFrame(columns=['Spacing','No Of Slices'])
# a=[1,"ff"]                                         

In [None]:
# ds.loc["ID123423212"]=a
# Display the bookkeeping table (it stays empty until scans have been loaded).
ds

In [None]:
# class CTScansDataset(Dataset):
#     def __init__(self, root_dir, transform=None):
#         self.root_dir = Path(root_dir)
#         self.patients = [p for p in self.root_dir.glob('*') if p.is_dir()]
#         self.transform = transform

#     def __len__(self):
#         return len(self.patients)

#     def __getitem__(self, idx):
#         if torch.is_tensor(idx):
#             idx = idx.tolist()

#         image, metadata = self.load_scan(self.patients[idx])
#         sample = {'image': image, 'metadata': metadata}
#         if self.transform:
#             sample = self.transform(sample)

#         return sample

#     def save(self, path):
#         t0 = time()
#         Path(path).mkdir(exist_ok=True, parents=True)
#         print('Saving pre-processed dataset to disk')
#         sleep(1)
#         cum = 0

#         bar = trange(len(self))
#         for i in bar:
#             sample = self[i]
#             image, data = sample['image'], sample['metadata']
#             cum += torch.mean(image).item()

#             bar.set_description(f'Saving CT scan {data.PatientID}')
#             fname = Path(path) / f'{data.PatientID}.pt'
#             torch.save(image, fname)

#         sleep(1)
#         bar.close()
#         print(f'Done! Time {timedelta(seconds=time() - t0)}\n'
#               f'Mean value: {cum / len(self)}')

#     def get_patient(self, patient_id):
#         patient_ids = [str(p.stem) for p in self.patients]
#         return self.__getitem__(patient_ids.index(patient_id))

#     @staticmethod
#     def load_scan(path):
#         slices = [pydicom.read_file(p) for p in path.glob('*.dcm')]
#         try:
#             slices.sort(key=lambda x: float(x.ImagePositionPatient[2]))
#         except AttributeError:
#             warnings.warn(f'Patient {slices[0].PatientID} CT scan does not '
#                           f'have "ImagePositionPatient". Assuming filenames '
#                           f'in the right scan order.')

#         image = np.stack([s.pixel_array.astype(float) for s in slices])
#         return image, slices[0]





class CTScansDataset(Dataset):
    """PyTorch ``Dataset`` over a directory of per-patient DICOM CT scans.

    Every sub-directory of ``root_dir`` is treated as one patient. Indexing
    the dataset loads all ``*.dcm`` files of that patient, stacks them into a
    3D volume and returns ``{'image': ..., 'metadata': ...}`` where
    ``metadata`` is the first DICOM slice of the scan.

    A handful of patients are special-cased by ID (see the class constants).

    Side effect: ``load_scan`` records the slice spacing and slice count of
    every scan it loads in the module-level ``ds`` DataFrame.
    """

    # Scans returned untouched by ``load_scan`` (no pixel stacking) and never
    # written by ``save``.
    _UNPROCESSED_PATIENTS = (
        "ID00052637202186188008618",
        "ID00011637202177653955184",
    )
    # Scan that is loaded normally but bypasses ``transform`` and ``save``.
    _NO_TRANSFORM_PATIENT = "ID00108637202209619669361"
    # Scans whose slice spacing is hard-coded instead of read from the tags.
    _FIXED_SLICE_THICKNESS = {
        "ID00132637202222178761324": 0.625,
        "ID00128637202219474716089": 5.0,
        "ID00173637202238329754031": 1.0,
    }

    def __init__(self, root_dir, transform=None):
        """Index the patient directories found directly under ``root_dir``.

        Args:
            root_dir: Directory containing one sub-directory per patient.
            transform: Optional callable applied to each sample dict.
        """
        self.root_dir = Path(root_dir)
        self.patients = [p for p in self.root_dir.glob('*') if p.is_dir()]
        self.transform = transform

    def __len__(self):
        """Return the number of patient directories."""
        return len(self.patients)

    def __getitem__(self, idx):
        """Load, and unless special-cased transform, the scan at ``idx``.

        Returns:
            dict with keys ``'image'`` and ``'metadata'``. For the patients in
            ``_UNPROCESSED_PATIENTS`` ``'image'`` is whatever ``load_scan``
            returned for them (the first DICOM slice, not a pixel volume).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image, slices = self.load_scan(self.patients[idx])
        patient_id = slices[0].PatientID

        if patient_id in self._UNPROCESSED_PATIENTS:
            return {'image': image, 'metadata': slices[0]}

        image, metadata = self.resample(image, slices)
        sample = {'image': image, 'metadata': metadata}

        if patient_id == self._NO_TRANSFORM_PATIENT:
            print("in 361 GETITEM")
            return sample

        if self.transform:
            sample = self.transform(sample)

        return sample

    def save(self, path):
        """Pre-process every scan and cache it under ``path``.

        For each saved patient two files are written: ``<PatientID>.pt`` (the
        image tensor) and ``<PatientID>.dcm`` (the metadata slice). The
        special-cased patients are skipped. The printed mean is averaged over
        the scans that were actually saved.

        Args:
            path: Destination directory; created if it does not exist.
        """
        t0 = time()
        os.makedirs(path, exist_ok=True)
        print('Saving pre-processed dataset to disk')
        sleep(1)  # let the message flush before the progress bar starts
        cum = 0
        saved = 0
        print("self length")
        print(len(self))
        skipped = self._UNPROCESSED_PATIENTS + (self._NO_TRANSFORM_PATIENT,)
        bar = trange(len(self))

        for i in bar:
            sample = self[i]
            image, data = sample['image'], sample['metadata']
            if data.PatientID in skipped:
                continue
            cum += torch.mean(image).item()
            saved += 1

            bar.set_description(f'Saving CT scan {data.PatientID}')
            torch.save(image, os.path.join(path, f'{data.PatientID}.pt'))
            data.save_as(os.path.join(path, f'{data.PatientID}.dcm'))

        sleep(1)
        bar.close()
        print((len(self)))
        # Average over saved scans only; skipped scans contribute nothing to
        # ``cum`` and must not be counted in the denominator.
        mean_value = cum / saved if saved else float('nan')
        print(f'Done! Time {timedelta(seconds=time() - t0)}\n'
              f'Mean value: {mean_value}')

    def get_patient(self, patient_id):
        """Return the sample whose directory name equals ``patient_id``.

        Raises:
            ValueError: If no patient directory has that name.
        """
        patient_ids = [str(p.stem) for p in self.patients]
        return self.__getitem__(patient_ids.index(patient_id))

    @staticmethod
    def load_scan(path):
        """Read all DICOM slices in ``path`` and stack them into a volume.

        The slice spacing is taken, in order of preference, from the
        hard-coded table, ``SpacingBetweenSlices`` (after sorting by
        ``InstanceNumber``), the distance between the first two
        ``ImagePositionPatient`` z-values, or the first two ``SliceLocation``
        values. It is written back to every slice as ``SliceThickness``.
        Non-square slices are cropped of their all-zero rows and columns.

        Args:
            path: ``pathlib.Path`` of one patient directory.

        Returns:
            ``(image, slices)`` where ``image`` is a float array stacked along
            axis 0 and ``slices`` the list of DICOM datasets. For patients in
            ``_UNPROCESSED_PATIENTS`` the tuple is ``(slices[0], slices)``.
        """
        slices = [pydicom.dcmread(p) for p in path.glob('*.dcm')]
        patient_id = slices[0].PatientID
        print(patient_id)

        if patient_id in CTScansDataset._UNPROCESSED_PATIENTS:
            return slices[0], slices

        if patient_id in CTScansDataset._FIXED_SLICE_THICKNESS:
            # NOTE(review): scans with a hard-coded spacing are not re-sorted;
            # confirm that the file order is acceptable for them.
            slice_thickness = CTScansDataset._FIXED_SLICE_THICKNESS[patient_id]
        else:
            try:
                slices.sort(key=lambda x: float(x.InstanceNumber))
                slice_thickness = np.abs(slices[0].SpacingBetweenSlices)
            except Exception:
                # Tag missing or unusable: fall back to positional tags.
                if 'ImagePositionPatient' in slices[0].dir():
                    slices.sort(key=lambda x: float(x.ImagePositionPatient[2]))
                    slice_thickness = np.abs(
                        slices[0].ImagePositionPatient[2]
                        - slices[1].ImagePositionPatient[2])
                else:
                    slice_thickness = np.abs(
                        slices[0].SliceLocation - slices[1].SliceLocation)

        arr = []
        for s in slices:
            if s.Rows != s.Columns:
                cropped = CTScansDataset.crop_slice(s.pixel_array)
                s.PixelData = cropped.tobytes()
                (s.Rows, s.Columns) = cropped.shape

            arr.append(s.pixel_array)
            s.SliceThickness = slice_thickness

        # Record spacing and slice count in the module-level table.
        ds.loc[slices[0].PatientID] = [slice_thickness, len(slices)]

        image = np.stack(arr).astype(float)
        return image, slices

    @staticmethod
    def crop_slice(s):
        """Drop the rows, then the columns, of 2D array ``s`` that are all zero."""
        s_cropped = s[~np.all(s == 0, axis=1)]
        s_cropped = s_cropped[:, ~np.all(s_cropped == 0, axis=0)]
        return s_cropped

    @staticmethod
    def resample(image, scan, new_spacing=[1, 1, 1]):
        """Return ``(image, scan[0])`` unchanged.

        Resampling to ``new_spacing`` is currently disabled, so the argument
        is unused (and never mutated); it is kept for interface compatibility.
        """
        if scan[0].PatientID == CTScansDataset._NO_TRANSFORM_PATIENT:
            print("in 361 resample if")
        return image, scan[0]


# class CTScansDataset(Dataset):
#     def __init__(self, root_dir, transform=None):
#         self.root_dir = Path(root_dir)
#         self.patients = [p for p in self.root_dir.glob('*') if p.is_dir()]
#         self.transform = transform

#     def __len__(self):
#         return len(self.patients)

#     def __getitem__(self, idx):
#         if torch.is_tensor(idx):
#             idx = idx.tolist()
# #         print(idx)
#         image,slices = self.load_scan(self.patients[idx])
#         # if slices[0].PatientID =="ID00078637202199415319443":
#         #   return slices
# #         print("after load_scan")
#         image,metadata=self.resample(image,slices)
#         sample = {'image': image, 'metadata': metadata}
#         if metadata.PatientID =="ID00078637202199415319443":
#         # print(metadata.PatientID)
#           return sample 
#         if self.transform:
          
#           sample = self.transform(sample)

#         return sample

#     def parallel_save(self,f,path,bar):
#         sample=self[f]
#         image, data = sample['image'], sample['metadata']
#         cum+=torch.mean(image).item()
#         bar.set_description(f'Saving CT scan {data.PatientID}')
#         fname = Path(path) / f'{data.PatientID}.pt'
#         torch.save(image, fname)


#     def save(self, path):       
#         t0 = time()
#         Path(path).mkdir(exist_ok=True, parents=True)
#         print('Saving pre-processed dataset to disk')
#         sleep(1)
#         cum = 0
#         print("sekf length")
#         print(len(self))
#         bar = trange(len(self))
#         # Parallel(n_jobs=32, verbose=10)(delayed(self.parallel_save)(f,path,bar) for f in bar)
#         # Parallel(n_jobs=32, verbose=10)(delayed(convert)(f) for f in JPEG_FILES)
        
#         for i in bar:
#             sample = self[i]
#             image, data = sample['image'], sample['metadata']
# #             if(data.PatientID=="ID00078637202199415319443"):
# #                 continue
#             cum += torch.mean(image).item()

#             bar.set_description(f'Saving CT scan {data.PatientID}')
#             fname = Path(path) / f'{data.PatientID}.pt'
#             torch.save(image, fname)

#         sleep(1)
#         bar.close() 
# #         print(len(self))
#         print(f'Done! Time {timedelta(seconds=time() - t0)}\n'
#               f'Mean value: {cum / len(self)}')

#     def get_patient(self, patient_id):
#         patient_ids = [str(p.stem) for p in self.patients]
#         return self.__getitem__(patient_ids.index(patient_id))

#     @staticmethod
#     def load_scan(path):
#         slices = [pydicom.read_file(p) for p in path.glob('*.dcm')]
#         # print((slices[0].PatientID))
#         # slices = [pydicom.read_file(path + '/' + s) for s in os.listdir(path)]

#         print()
#         print(slices[0].PatientID)
# #         print("ok load_scan")
#         if slices[0].PatientID=="ID00132637202222178761324":
            
#             slice_thickness=0.625
        
#         elif slices[0].PatientID=="ID00128637202219474716089 ":
#             slice_thickness=5
        
#         else:
# #             print(slices[0].dir())
#             try:
# #                 print(" try ok")
#                 # print(slices[0].dir())
#                 # print(slices[0].ImagePositionPatient[2])
#                 # c=1
#                 # for i in slices:
#                 #   print(i.InstanceNumber)
#                 #   print(i.ImagePositionPatient)
#                 #   print(c)
#                 #   c=c+1
#                 if 'ImagePositionPatient' in slices[0].dir():
                    
#                     slices.sort(key = lambda x: (x.InstanceNumber))
                  
#                 slice_thickness=np.abs(slices[0].SpacingBetweenSlices)
       
#             except:
# #                 print("in except")
#                 # print()
#                 # print(slices[0].PatientID)
#                 # slice_thickness=slices[0].SpacingBetweenSlices 
#                 # slice_thickness = np.abs(slices[0].SliceLocation - slices[1].SliceLocation)
#                 if ("ImagePositionPatient" in slices[0].dir()):
                    
#                     slices.sort(key = lambda x: float(x.ImagePositionPatient[2]))
#                     slice_thickness = np.abs(slices[0].ImagePositionPatient[2] - slices[1].ImagePositionPatient[2])


#                   # slice_thickness= np.abs(slices[0].SliceLocation - slices[1].SliceLocation)
                
#                 else:
                    
#                     slice_thickness=np.abs(slices[0].SliceThickness)
# #                 print("ende except")
        
#         arr=[]
#         # print("after sorting")
#         c=1
#         # for i in slices:
#         #   print(i.InstanceNumber)
#         #   print(i.ImagePositionPatient)
#         #   print(c)
#         #   c=c+1
#         # image = np.stack([s.pixel_array.astype(float) for s in slices])
# #         print("before slices loop")
#         for s in slices:
            
#             if s.Rows!=s.Columns:
                
#                 s_data=s.pixel_array
#                 s_cropped = s_data[~np.all(s_data == 0, axis=1)]
#                 s_cropped = s_cropped[:, ~np.all(s_cropped == 0, axis=0)]
#                 new_pix_array=s_cropped
#                 s.PixelData = new_pix_array.tostring()
#                 (s.Rows,s.Columns)=new_pix_array.shape

#             arr.append(s.pixel_array)
#             s.SliceThickness = slice_thickness
# #         print("after slice")
        
#         image=np.stack(arr)
# #         print("after image")
#         # print(arr[0])
#         image=image.astype(float)
#         # print(slices[0].SliceThickness)
#         # print((slices))
#         return image,slices

    
#     def crop_slice(s):
        
#         s_cropped = s[~np.all(s == 0, axis=1)]
#         s_cropped = s_cropped[:, ~np.all(s_cropped == 0, axis=0)]
#         return s_cropped


#     @staticmethod
#     def resample(image, scan,new_spacing=[1,1,1]):
        
# #         print("in resample")
# #         print(scan[0].SliceThickness)
#         spacing=np.array([scan[0].SliceThickness, scan[0].PixelSpacing[0], scan[0].PixelSpacing[1]], dtype=np.float32)
# #         print("after spacing")
#         resize_factor = spacing / new_spacing
#         new_real_shape = image.shape * resize_factor
#         new_shape = np.round(new_real_shape)
#         real_resize_factor = new_shape / image.shape
#         new_spacing = spacing / real_resize_factor   
# #         print("before zoom resizr") 
#         image = scipy.ndimage.interpolation.zoom(image, real_resize_factor, mode='nearest')
# #         print("after zoom resie")
#         return image, scan[0]

## 3.2. Pre-processing
There are some pre-processing steps to be done. Let's tackle them one step at a time.
### 3.2.1. crop_bounding_box.py

In [None]:
class CropBoundingBox:
    """Transform that crops a constant-valued frame around a CT volume.

    The frame is detected on the middle slice only, using the value of its
    top-left pixel as the background value; the same crop is applied to every
    slice of the volume.
    """

    @staticmethod
    def bounding_box(img3d: np.array):
        """Tell whether the middle slice appears to be framed.

        Returns:
            ``True`` when both the first row and the first column of the
            middle slice are entirely equal to the top-left pixel.
        """
        middle = img3d[int(img3d.shape[0] / 2)]
        corner = middle[0, 0]
        first_row_flat = (middle[0, :] == corner).all()
        first_col_flat = (middle[:, 0] == corner).all()
        return bool(first_col_flat and first_row_flat)

    @staticmethod
    def _content_span(is_background):
        """Return ``(start, stop)`` of the first run of non-background lines.

        ``start`` is the first index flagged as non-background and ``stop``
        the first background index after it; either is ``None`` if not found.
        """
        start = next(
            (i for i, flat in enumerate(is_background) if not flat), None)
        if start is None:
            return None, None
        stop = next(
            (i for i in range(start + 1, len(is_background))
             if is_background[i]),
            None)
        return start, stop

    def __call__(self, sample):
        """Crop ``sample['image']`` if framed; otherwise return ``sample``."""
        image, data = sample['image'], sample['metadata']
        if not self.bounding_box(image):
            return sample

        middle = image[int(image.shape[0] / 2)]
        background = middle[0, 0]
        rows_flat = [(middle[r, :] == background).all()
                     for r in range(middle.shape[0])]
        cols_flat = [(middle[:, c] == background).all()
                     for c in range(middle.shape[1])]

        r_min, r_max = self._content_span(rows_flat)
        c_min, c_max = self._content_span(cols_flat)

        return {'image': image[:, r_min:r_max, c_min:c_max], 'metadata': data}

### 3.2.2. convert_to_hu.py
Credits to [Guido Zuidhof's tutorial](https://www.kaggle.com/gzuidhof/full-preprocessing-tutorial).

In [None]:
class ConvertToHU:
    """Transform that rescales raw pixel values with the DICOM rescale tags."""

    def __call__(self, sample):
        """Apply ``image * RescaleSlope + RescaleIntercept`` and cast to int16.

        Args:
            sample: dict with ``'image'`` (array) and ``'metadata'`` (object
                exposing ``RescaleSlope`` and ``RescaleIntercept``).

        Returns:
            New dict with the rescaled ``int16`` image and the same metadata.
        """
        image, data = sample['image'], sample['metadata']

        # The former ``ImageType`` inspection only fed a disabled warning and
        # could raise on scans lacking that tag, so it has been dropped.
        intercept = data.RescaleIntercept
        slope = data.RescaleSlope
        image = (image * slope + intercept).astype(np.int16)
        return {'image': image, 'metadata': data}

### 3.2.3. resize.py

In [None]:
class Resize:
    """Transform that zooms a volume to a fixed output shape."""

    def __init__(self, output_size):
        """Store the target shape; ``output_size`` must be a tuple."""
        assert isinstance(output_size, tuple)
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample whose image has shape ``output_size``."""
        volume, meta = sample['image'], sample['metadata']
        target_shape = np.array(self.output_size)
        current_shape = np.array(volume.shape)
        # Per-axis zoom factor that maps the current shape onto the target.
        scale = target_shape / current_shape
        resized = zoom(volume, scale, mode='nearest')
        return {'image': resized, 'metadata': meta}

### 3.2.4. clip.py

In [None]:
class Clip:
    """Transform that saturates image values to a closed interval, in place."""

    def __init__(self, bounds=(-1000, 500)):
        """Store the smallest and largest value of ``bounds``."""
        self.min, self.max = min(bounds), max(bounds)

    def __call__(self, sample):
        """Clamp ``sample['image']`` in place and return it in a new dict."""
        image, data = sample['image'], sample['metadata']
        floor, ceiling = self.min, self.max

        too_low = image < floor
        image[too_low] = floor
        too_high = image > ceiling
        image[too_high] = ceiling

        return {'image': image, 'metadata': data}

### 3.2.5. mask_watershed.py
Credits to [Aadhav Vignesh's amazing kernel](https://www.kaggle.com/aadhavvignesh/lung-segmentation-by-marker-controlled-watershed).

IMPORTANT: I made some changes to Vignesh's code below to make it scalable, most notably reducing the number of iterations from 8 to 1. This was important to reduce the time to generate masks from ~8-9 seconds/slice (which would take over 17 hours to complete) to ~100 ms/slice. I'm satisfied with the quality of the masks, as you can see in some samples below. However, using 8 iterations generates even better masks.

In [None]:
class MaskWatershed:
    """Transform that masks everything outside the lungs, slice by slice.

    Each 2D slice is segmented with a marker-controlled watershed; pixels
    outside the resulting lung filter are replaced by ``min_hu``.
    """

    def __init__(self, min_hu, iterations, show_tqdm):
        """Store the segmentation settings.

        Args:
            min_hu: Fill value written outside the lung mask.
            iterations: Number of times the black top-hat structuring element
                is iterated (larger is slower).
            show_tqdm: Whether to show a per-slice progress bar.
        """
        self.min_hu = min_hu
        self.iterations = iterations
        self.show_tqdm = show_tqdm

    def __call__(self, sample):
        """Return a new sample whose image is the stack of masked slices."""
        image, data = sample['image'], sample['metadata']

        if self.show_tqdm:
            bar = trange(image.shape[0])
            bar.set_description(f'Masking CT scan {data.PatientID}')
        else:
            bar = range(image.shape[0])

        stack = [
            self.seperate_lungs(image[slice_idx], self.min_hu, self.iterations)
            for slice_idx in bar
        ]

        return {
            'image': np.stack(stack),
            'metadata': sample['metadata']
        }

    @staticmethod
    def seperate_lungs(image, min_hu, iterations):
        """Segment the lungs of one 2D slice.

        Args:
            image: 2D array holding one slice.
            min_hu: Value assigned to pixels outside the lung filter.
            iterations: Iteration count for the black top-hat structure.

        Returns:
            2D array equal to ``image`` inside the lung filter and to
            ``min_hu`` elsewhere.
        """
        h, w = image.shape[0], image.shape[1]

        marker_internal, _, marker_watershed = \
            MaskWatershed.generate_markers(image)

        # Sobel-Gradient
        sobel_filtered_dx = ndimage.sobel(image, 1)
        sobel_filtered_dy = ndimage.sobel(image, 0)
        sobel_gradient = np.hypot(sobel_filtered_dx, sobel_filtered_dy)
        # A perfectly flat slice has a zero gradient everywhere; skip the
        # normalisation instead of dividing by zero.
        max_gradient = np.max(sobel_gradient)
        if max_gradient > 0:
            sobel_gradient *= 255.0 / max_gradient

        watershed = segmentation.watershed(sobel_gradient, marker_watershed)

        outline = ndimage.morphological_gradient(watershed, size=(3, 3))
        outline = outline.astype(bool)

        # Structuring element used for the filter
        blackhat_struct = [[0, 0, 1, 1, 1, 0, 0],
                           [0, 1, 1, 1, 1, 1, 0],
                           [1, 1, 1, 1, 1, 1, 1],
                           [1, 1, 1, 1, 1, 1, 1],
                           [1, 1, 1, 1, 1, 1, 1],
                           [0, 1, 1, 1, 1, 1, 0],
                           [0, 0, 1, 1, 1, 0, 0]]

        blackhat_struct = ndimage.iterate_structure(blackhat_struct, iterations)

        # Perform Black Top-hat filter
        outline += ndimage.black_tophat(outline, structure=blackhat_struct)

        lungfilter = np.bitwise_or(marker_internal, outline)
        lungfilter = ndimage.binary_closing(
            lungfilter, structure=np.ones((5, 5)), iterations=3)

        segmented = np.where(lungfilter == 1, image, min_hu * np.ones((h, w)))

        return segmented

    @staticmethod
    def generate_markers(image, threshold=-400):
        """Build the internal, external and combined watershed markers.

        Args:
            image: 2D array holding one slice.
            threshold: Pixels below this value seed the internal marker.

        Returns:
            ``(marker_internal, marker_external, marker_watershed)``; the
            first two are boolean masks, the last an integer image with 255
            on the internal marker and 128 on the external one.
        """
        h, w = image.shape[0], image.shape[1]

        marker_internal = image < threshold
        marker_internal = segmentation.clear_border(marker_internal)
        marker_internal_labels = measure.label(marker_internal)

        regions = measure.regionprops(marker_internal_labels)
        areas = sorted(r.area for r in regions)

        # Keep only the regions at least as large as the second largest one.
        if len(areas) > 2:
            for region in regions:
                if region.area < areas[-2]:
                    coords = region.coords
                    marker_internal_labels[coords[:, 0], coords[:, 1]] = 0

        marker_internal = marker_internal_labels > 0

        # Creation of the External Marker
        external_a = ndimage.binary_dilation(marker_internal, iterations=10)
        external_b = ndimage.binary_dilation(marker_internal, iterations=55)
        marker_external = external_b ^ external_a

        # Creation of the Watershed Marker
        marker_watershed = np.zeros((h, w), dtype=int)
        marker_watershed += marker_internal * 255
        marker_watershed += marker_external * 128

        return marker_internal, marker_external, marker_watershed

### 3.2.6. normalize.py, to_tensor.py, zero_center.py

In [None]:
class Normalize:
    """Transform that linearly maps ``[min(bounds), max(bounds)]`` to ``[0, 1]``."""

    def __init__(self, bounds=(-1000, 500)):
        """Store the smallest and largest value of ``bounds``."""
        self.min = min(bounds)
        self.max = max(bounds)

    def __call__(self, sample):
        """Return a new sample whose image is rescaled as float64."""
        image, data = sample['image'], sample['metadata']
        # ``np.float`` was only an alias of the builtin ``float`` and has been
        # removed from NumPy; the builtin gives the same float64 result.
        image = image.astype(float)
        image = (image - self.min) / (self.max - self.min)
        return {'image': image, 'metadata': data}
    

class ToTensor:
    """Transform that converts the NumPy image of a sample to a torch tensor."""

    def __init__(self, add_channel=True):
        """Remember whether a leading channel axis should be inserted."""
        self.add_channel = add_channel

    def __call__(self, sample):
        """Return a new sample holding the image as a ``torch.Tensor``."""
        array, meta = sample['image'], sample['metadata']
        # Optionally prepend a singleton channel axis before the conversion.
        array = np.expand_dims(array, axis=0) if self.add_channel else array
        return {'image': torch.from_numpy(array), 'metadata': meta}
    
    
class ZeroCenter:
    """Transform that subtracts a fixed, pre-computed mean from its input."""

    def __init__(self, pre_calculated_mean):
        """Store the mean to subtract."""
        self.pre_calculated_mean = pre_calculated_mean

    def __call__(self, tensor):
        """Return ``tensor`` shifted down by the stored mean."""
        offset = self.pre_calculated_mean
        return tensor - offset

### 3.2.7. Inspecting some slices

In [None]:
def show(list_imgs, cmap=cm.bone):
    """Plot the middle slice of each 3D volume side by side.

    One subplot is created per volume, so any number of volumes can be shown
    (the figure used to be fixed at five panels). Nothing is drawn when
    ``list_imgs`` is empty.

    Args:
        list_imgs: Iterable of 3D arrays indexed as ``[slice, row, column]``.
        cmap: Matplotlib colormap used to render the slices.
    """
    list_slices = [img3d[int(img3d.shape[0] / 2)] for img3d in list_imgs]
    if not list_slices:
        return

    # squeeze=False keeps ``axs`` two-dimensional even for a single panel.
    fig, axs = plt.subplots(1, len(list_slices), figsize=(15, 7),
                            squeeze=False)
    for ax, img in zip(axs[0], list_slices):
        ax.imshow(img, cmap=cmap)
        ax.axis('off')

    plt.show()

In [None]:
# test = CTScansDataset(
#     root_dir=test_dir,
#     transform=transforms.Compose([
#         CropBoundingBox(),
#         ConvertToHU(),
#         Resize(resize_dims),
#         Clip(bounds=clip_bounds),
#         MaskWatershed(min_hu=min(clip_bounds), iterations=1, show_tqdm=True),
#         Normalize(bounds=clip_bounds)
#     ]))

# list_imgs = [test[i]['image'] for i in range(len(test))]
# show(list_imgs)

## 3.3. Caching pre-processed images in the disk
Pre-processing all 176 3D CT scans takes some time. Depending on the parameters we choose, it can take hours. 

With the current choice of parameters, it takes around 15 minutes. To accelerate experimentation, I already pre-cached the images with the preprocessing parameters in this notebook, saving them in a [public dataset](https://www.kaggle.com/carlossouza/osic-cached-dataset). 

This way, you can preprocess only once, and experiment with the same preprocessed tensors. The code to preprocess and cache images in the disk is:
```
data = CTScansDataset(
    root_dir=root_dir,
    transform=transforms.Compose([
        CropBoundingBox(),
        ConvertToHU(),
        Resize(resize_dims),
        Clip(bounds=clip_bounds),
        MaskWatershed(
            min_hu=min(clip_bounds),
            iterations=watershed_iterations,
            show_tqdm=False),
        Normalize(bounds=clip_bounds),
        ToTensor()
    ]))
data.save(dest_dir)
```

From this point on, we use the `CTTensorsDataset` as the interface to ingest the preprocessed tensors, taking the data to training.

In [None]:
# %mkdir newdata3

In [None]:
# !cp ../input/osic-pulmonary-fibrosis-progression/train/ID00108637202209619669361 -r /kaggle/working/newdata3/

In [None]:
# li=[f for f in os.listdir("/kaggle/working/newdata/ID00035637202182204917484")]
# print(li)

In [None]:
# %mkdir PreprocessedData

In [None]:
# data = CTScansDataset(
#     root_dir='/kaggle/input/osic-pulmonary-fibrosis-progression/train',
#     transform=transforms.Compose([
#         CropBoundingBox(),
#         ConvertToHU(),
#         Resize(resize_dims),
#         Clip(bounds=clip_bounds),
#         MaskWatershed(
#             min_hu=min(clip_bounds),
#             iterations=watershed_iterations,
#             show_tqdm=False),
#         Normalize(bounds=clip_bounds),
#         ToTensor()
#     ]))
# data.save("/kaggle/working/")

In [None]:
# ds.to_csv("fileinfo.csv")                                                           

In [None]:
# import fnmatch
# import os

# matches = []
# for root, dirnames, filenames in os.walk('src'):
#     for filename in fnmatch.filter(filenames, '*.c'):
#         matches.append(os.path.join(root, filename))

In [None]:
import os
# import fnmatch

# class CTTensorsDataset(Dataset):
#     def __init__(self, root_dir, transform=None):
#         self.root_dir = os.walk(root_dir)
        
# #         for f in os.walk(self.root_dir):
#         print(self.root_dir)
#         self.tensor_files = sorted([f for f in fnmatch.filter(self.root_dir,'*.pt')])
#         print(self.tensor_files)
#         self.transform = transform

#     def __len__(self):
#         return len(self.tensor_files)

#     def __getitem__(self, item):
#         if torch.is_tensor(item):
#             item = item.tolist()
        
#         image = torch.load(self.tensor_files[item])
#         print(self.tensor_files)
#         if self.transform:
#             image = self.transform(image)

#         return {
#             'patient_id': self.tensor_files[item].stem,
#             'image': image
#         }

#     def mean(self):
#         cum = 0
#         for i in range(len(self)):
#             sample = self[i]['image']
#             cum += torch.mean(sample).item()

#         return cum / len(self)

#     def random_split(self, val_size: float):
#         num_val = int(val_size * len(self))
#         num_train = len(self) - num_val
#         return random_split(self, [num_train, num_val])

class CTTensorsDataset(Dataset):
    """Dataset over pre-processed CT scans stored as ``*.pt`` tensor files.

    Every file in ``root_dir`` whose name ends in ``.pt`` is one sample. The
    file name without its extension is reported as the patient id.

    :param root_dir: directory that contains the cached ``.pt`` files
    :param transform: optional callable applied to each loaded tensor
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs and machines.
        self.tensor_files = sorted(
            os.path.join(self.root_dir, f)
            for f in os.listdir(self.root_dir)
            if f.endswith('.pt')
        )
        self.transform = transform

    def __len__(self):
        """Return the number of cached tensor files found."""
        return len(self.tensor_files)

    def __getitem__(self, item):
        """Load one sample.

        :param item: integer index (a tensor index is converted first)
        :return: dict with keys ``'patient_id'`` and ``'image'``
        """
        if torch.is_tensor(item):
            item = item.tolist()

        tensor_file = self.tensor_files[item]
        image = torch.load(tensor_file)
        if self.transform:
            image = self.transform(image)

        return {
            # splitext drops exactly the extension instead of assuming that
            # it is three characters long
            'patient_id': os.path.splitext(os.path.basename(tensor_file))[0],
            'image': image
        }

    def mean(self):
        """Return the average of the per-sample means (after ``transform``).

        :raises ZeroDivisionError: if the dataset is empty
        """
        cum = 0
        for i in range(len(self)):
            cum += torch.mean(self[i]['image']).item()

        return cum / len(self)

    def random_split(self, val_size: float):
        """Randomly split the dataset into train and validation subsets.

        :param val_size: fraction of the samples assigned to validation
        :return: ``[train_subset, val_subset]`` as produced by
            ``torch.utils.data.random_split``
        """
        num_val = int(val_size * len(self))
        num_train = len(self) - num_val
        return random_split(self, [num_train, num_val])

### 3.3.1. Checking data pipeline

In [None]:
# Sanity check: load every cached tensor through the zero-centering
# transform and accumulate the per-sample means in `cum`.
train = CTTensorsDataset(
    transform=ZeroCenter(pre_calculated_mean=pre_calculated_mean),
    root_dir=cache_dir,
)
print(len(train))

cum = 0
for i in range(len(train)):
    sample = train[i]['image']
    cum += sample.mean().item()

# assert cum / len(train) == pytest.approx(0)

# 4. AutoEncoder
Credits to [Srinjay Paul's great tutorial](https://srinjaypaul.github.io/3D_Convolutional_autoencoder_for_brain_volumes/), and lots of papers (I will link them later).

As mentioned, I strangled the bottleneck to force very few latent features (10). The image below shows the transformations:
![autoencoder](https://i.ibb.co/2hYZFc1/autoencoder.jpg)

In [None]:
class AutoEncoder(nn.Module):
    """3D convolutional auto-encoder with a fully-connected bottleneck.

    The encoder is four ``Conv3d`` + ``MaxPool3d`` stages followed by a
    linear layer that maps the flattened ``10 * 10`` activations to
    ``latent_features`` values. The decoder mirrors it with a linear layer
    and four ``MaxUnpool3d`` + ``ConvTranspose3d`` stages, re-using the
    pooling indices and pre-pooling shapes recorded by the encoder.

    :param latent_features: size of the bottleneck vector
    """

    def __init__(self, latent_features=latent_features):
        super().__init__()
        flat_size = 10 * 10  # flattened size of the last pooled activation

        # Encoder
        self.conv1 = nn.Conv3d(1, 16, 3)
        self.conv2 = nn.Conv3d(16, 32, 3)
        self.conv3 = nn.Conv3d(32, 96, 2)
        self.conv4 = nn.Conv3d(96, 1, 1)
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.pool2 = nn.MaxPool3d(kernel_size=3, stride=3, return_indices=True)
        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.fc1 = nn.Linear(flat_size, latent_features)

        # Decoder
        self.fc2 = nn.Linear(latent_features, flat_size)
        self.deconv0 = nn.ConvTranspose3d(1, 96, 1)
        self.deconv1 = nn.ConvTranspose3d(96, 32, 2)
        self.deconv2 = nn.ConvTranspose3d(32, 16, 3)
        self.deconv3 = nn.ConvTranspose3d(16, 1, 3)
        self.unpool0 = nn.MaxUnpool3d(kernel_size=2, stride=2)
        self.unpool1 = nn.MaxUnpool3d(kernel_size=2, stride=2)
        self.unpool2 = nn.MaxUnpool3d(kernel_size=3, stride=3)
        self.unpool3 = nn.MaxUnpool3d(kernel_size=2, stride=2)

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, 1)``.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        :return: sampled tensor with the same shape as ``mu``
        """
        std = (0.5 * log_var).exp()  # standard deviation
        noise = torch.randn_like(std)  # same shape/dtype/device as std
        return mu + (noise * std)

    def encode(self, x, return_partials=True):
        """Run the encoder.

        :param x: input batch
        :param return_partials: when true, also return what the decoder
            needs to undo each pooling step
        :return: the latent tensor alone, or the 9-tuple ``(latent,
            shape1, idx1, shape2, idx2, shape3, idx3, shape4, idx4)`` where
            ``shapeK`` / ``idxK`` are the pre-pooling shape and the pooling
            indices of stage K
        """
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
            (self.conv4, self.pool4),
        )

        partials = []
        for conv, pool in stages:
            x = conv(x)
            pre_pool_shape = x.shape
            x, pool_indices = pool(x)
            partials += [pre_pool_shape, pool_indices]

        x = F.relu(self.fc1(x.view(-1, 10 * 10)))

        if not return_partials:
            return x
        return (x, *partials)

    def forward(self, x):
        """Encode, sample through ``reparameterize`` and decode.

        The same encoder output is used both as the mean and as the log
        variance of the sampled latent vector.

        :return: tuple ``(reconstruction, mu, log_var)``; the reconstruction
            is passed through a sigmoid
        """
        (x, shape1, idx1, shape2, idx2,
         shape3, idx3, shape4, idx4) = self.encode(x)

        mu = log_var = x
        # get the latent vector through reparameterization
        x = self.reparameterize(mu, log_var)

        # Decoder: undo the pooling stages in reverse order
        x = F.relu(self.fc2(x)).view(-1, 1, 1, 10, 10)
        decoder_stages = (
            (self.unpool0, shape4, idx4, self.deconv0),
            (self.unpool1, shape3, idx3, self.deconv1),
            (self.unpool2, shape2, idx2, self.deconv2),
            (self.unpool3, shape1, idx1, self.deconv3),
        )
        for unpool, out_shape, pool_indices, deconv in decoder_stages:
            x = unpool(x, output_size=out_shape, indices=pool_indices)
            x = deconv(x)

        return torch.sigmoid(x), mu, log_var

In [None]:
def final_loss(bce_loss, mu, logvar):
    """
    Combine the reconstruction loss with the KL-Divergence term.
    KL-Divergence = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    :param bce_loss: reconstruction loss (already computed by the caller)
    :param mu: the mean from the latent vector
    :param logvar: log variance from the latent vector
    :return: ``bce_loss`` plus the KL-Divergence
    """
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return bce_loss + kl_divergence

# 5. Training
I decided to take this opportunity to learn how to use TPUs. However, after days of intense frustration, I gave up. XLA documentation is very poor. Notebook examples are either so simple they are not useful at all, or so advanced/complex it is impossible to understand what is happening. The code frequently freezes, and it is impossible to know what is happening in the background…

The code below runs smoothly on GPU.

## 5.1. Monitoring on Tensorboard
Credits to [Shivam Kumar's tutorial](https://www.kaggle.com/shivam1600/tensorboard-on-kaggle).

In [None]:
!rm -rf ./logs/ 
!mkdir ./logs/
# Download Ngrok to tunnel the tensorboard port to an external port
!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
!unzip ngrok-stable-linux-amd64.zip
# Launch TensorBoard and the ngrok tunnel as shell commands. Each command
# is submitted to a worker process through os.system and ends with `&`, so
# the shell starts it in the background. The AsyncResult handles are kept in
# `results_of_processes`; nothing in this cell waits on them.
pool = multiprocessing.Pool(processes = 10)
results_of_processes = [
    pool.apply_async(os.system, args=(cmd, ), callback=None) for cmd in [
        # TensorBoard reads `tensorboard_dir` and listens on port 6006
        f"tensorboard --logdir {tensorboard_dir}/ --host 0.0.0.0 --port 6006 &",
        # ngrok exposes that same local port over HTTP
        "./ngrok http 6006 &"
    ]
]

In [None]:
! curl -s http://localhost:4040/api/tunnels | python3 -c \
    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

## 5.2. Training loop
IMPORTANT: For the sake of the demonstration, I'm running this training only for 10 epochs. To have usable results, we need at least 100 epochs.

In [None]:
# Wall-clock reference for the "Done!" message printed after training
t0 = time()

# Load the cached tensors (zero-centered on the fly) and split them
data = CTTensorsDataset(
    transform=ZeroCenter(pre_calculated_mean=pre_calculated_mean),
    root_dir=cache_dir,
)
train_set, val_set = data.random_split(val_size)
datasets = {'train': train_set, 'val': val_set}

# One loader per phase; only the training loader shuffles
dataloaders = {
    name: DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=(name == 'train'),
        num_workers=0,
        pin_memory=False,
    )
    for name, subset in datasets.items()
}
dataset_sizes = {name: len(subset) for name, subset in datasets.items()}

# Model, loss and optimizer
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = AutoEncoder(latent_features=latent_features)
model = model.to(device)
# criterion = torch.nn.MSELoss()
criterion = nn.BCELoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Best-checkpoint tracking and a time-stamped TensorBoard run directory
best_model_wts = None
best_loss = np.inf
date_time = datetime.now().strftime("%Y%m%d-%H%M")
log_dir = Path(tensorboard_dir) / date_time
writer = SummaryWriter(log_dir)

In [None]:
print(device)

In [None]:
# Training loop
# NOTE(review): the epoch count is hard-coded to 150 here although the
# surrounding text talks about 10 epochs - confirm the intended value.
for epoch in range(0,150):

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_preds = 0

        # Iterate over data.
        bar = tqdm(dataloaders[phase])
        for inputs in bar:
            bar.set_description(f'Epoch {epoch + 1} {phase}'.ljust(20))
            inputs = inputs['image'].to(device, dtype=torch.float)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            # track history if only in train
            with torch.set_grad_enabled(phase == 'train'):
                reconstruction, mu, logvar = model(inputs)
                bce_loss = criterion(reconstruction, inputs)
                loss = final_loss(bce_loss, mu, logvar)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # statistics
            # `criterion` is built with reduction='sum' and the KL term is a
            # sum as well, so `loss` is already the total over the batch.
            # Multiplying it by the batch size again (as was done before)
            # inflated the reported per-sample loss and weighted a smaller
            # last batch differently from the others.
            running_loss += loss.item()
            running_preds += inputs.size(0)
            bar.set_postfix(loss=f'{running_loss / running_preds:0.6f}')

        epoch_loss = running_loss / dataset_sizes[phase]
        writer.add_scalar(f'Loss/{phase}', epoch_loss, epoch)

        # deep copy the model
        if phase == 'val' and epoch_loss < best_loss:
            best_loss = epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(best_model_wts, model_file)

# load best model weights (none are recorded if the validation loss never
# dropped below infinity, e.g. when it was NaN in every epoch)
if best_model_wts is not None:
    model.load_state_dict(best_model_wts)

print(f'Done! Time {timedelta(seconds=time() - t0)}')

In [None]:
!nvidia-smi

# 6. Inference and inspection
The code below inspects a random sample. As mentioned, the quality can be improved by increasing the number of latent features. However, that will become a problem later when we combine the latent features with the tabular features in the Quant Model.

In [None]:
# Inspect one random sample: show the middle slice of the input next to the
# middle slice of its reconstruction and print the latent vector.
slc = 0.5
sample_id = np.random.randint(len(data))
# Load the sample once instead of re-reading the file for every use
inspected = data[sample_id]
print(f'Inspecting CT Scan {inspected["patient_id"]}')

fig, axs = plt.subplots(1, 2, figsize=(10, 7))

sample = inspected['image'].squeeze(0).numpy()
slice_idx = int(sample.shape[0] * slc)
axs[0].imshow(sample[slice_idx, :, :], cmap=cm.bone)
axs[0].axis('off')
imageio.mimsave("sample_input.gif", sample, duration=0.0001)

with torch.no_grad():
    img = inspected['image'].unsqueeze(0).float().to(device)
    # Stored under its own name: assigning to `latent_features` would
    # overwrite the hyper-parameter that the training setup cell reads.
    latent_vector = model.encode(img, return_partials=False)\
        .squeeze().cpu().numpy().tolist()
    # forward() returns (reconstruction, mu, log_var); only the
    # reconstruction is an image.
    outputs = model(img)[0].squeeze().cpu().numpy()

axs[1].imshow(outputs[slice_idx, :, :], cmap=cm.bone)
axs[1].axis('off')

imageio.mimsave("sample_output.gif", outputs, duration=0.0001)

# Mean squared error (no square root is taken, despite the variable name)
rmse = ((sample - outputs)**2).mean()
plt.show()
print(f'Latent features: {latent_vector} \nLoss: {rmse}')

In [None]:
from IPython.display import HTML
HTML('<br/><img src="https://i.ibb.co/gFxgRq6/sample-input.gif" style="float: left; width: 30%; margin-right: 1%; margin-bottom: 0.5em;">'
     '<img src="https://i.ibb.co/Jm57fWw/sample-output.gif" style="float: left; width: 30%; margin-right: 1%; margin-bottom: 0.5em;">'
     '<p style="clear: both;">')

# 7. Next steps
- Train longer: 10 epochs is not enough to achieve good results
- Use the latent features in the Quant Model, and check how much they improve the predictions
- Investigate/debug why some of the latent features are zero

Clinical Dataset

In [None]:
import pickle
class ClinicalDataset(Dataset):
    """Tabular OSIC features paired with each patient's CT scan.

    Two construction modes:

    * ``cache_dir`` given: the pre-processed table ``tabular_{mode}.csv`` and
      the feature list ``features_list.pkl`` are read from ``cache_dir``;
      images and metadata are later read from ``{patient_id}.pt`` /
      ``{patient_id}.dcm`` in the same directory.
    * ``cache_dir`` omitted: ``train.csv``, ``test.csv`` and
      ``sample_submission.csv`` are read from ``root_dir`` and the features
      are engineered here; scans are read from ``ctscans_dir/<patient id>``.

    :param root_dir: directory with the competition CSV files
    :param ctscans_dir: directory with one sub-directory of DICOM files per
        patient
    :param mode: which rows to keep: ``'train'``, ``'val'`` or ``'test'``
    :param transform: optional callable applied to each sample dict
    :param cache_dir: optional directory previously filled by :meth:`cache`
    """

    def __init__(self, root_dir, ctscans_dir, mode, transform=None,
                 cache_dir=None):
        self.transform = transform
        self.mode = mode
        self.ctscans_dir = Path(ctscans_dir)
        self.cache_dir = None if cache_dir is None else Path(cache_dir)

        # If cache_dir is set, use cached values...
        if self.cache_dir is not None:
            self.raw = pd.read_csv(self.cache_dir/f'tabular_{mode}.csv')
            # NOTE(review): pickle must only be used on trusted files; this
            # one is expected to have been written by `cache()`.
            with open(self.cache_dir/'features_list.pkl', "rb") as fp:
                self.FE = pickle.load(fp)
            return

        # ...otherwise, pre-process
        root_dir = Path(root_dir)
        tr = pd.read_csv(root_dir/"train.csv")
        tr.drop_duplicates(keep=False, inplace=True, subset=['Patient', 'Weeks'])
        chunk = pd.read_csv(root_dir/"test.csv")

        sub = pd.read_csv(root_dir/"sample_submission.csv")
        sub['Patient'] = sub['Patient_Week'].apply(lambda x: x.split('_')[0])
        sub['Weeks'] = sub['Patient_Week'].apply(lambda x: int(x.split('_')[-1]))
        sub = sub[['Patient', 'Weeks', 'Confidence', 'Patient_Week']]
        sub = sub.merge(chunk.drop('Weeks', axis=1), on="Patient")

        tr['WHERE'] = 'train'
        chunk['WHERE'] = 'val'
        sub['WHERE'] = 'test'
        # DataFrame.append was removed in pandas 2.0; pd.concat is the
        # equivalent that works on every pandas version.
        data = pd.concat([tr, chunk, sub])

        # Earliest measured week per patient. Submission rows are masked out
        # so they do not count; `where` yields NaN for them without writing
        # NaN into an integer column.
        data['min_week'] = data['Weeks'].where(data.WHERE != 'test')
        data['min_week'] = data.groupby('Patient')['min_week'].transform('min')

        # FVC at that earliest week, one row per patient (first one wins)
        base = data.loc[data.Weeks == data.min_week]
        base = base[['Patient', 'FVC']].copy()
        base.columns = ['Patient', 'min_FVC']
        base = base.drop_duplicates(subset='Patient', keep='first')

        data = data.merge(base, on='Patient', how='left')
        data['base_week'] = data['Weeks'] - data['min_week']
        del base

        # One indicator column per category value of Sex / SmokingStatus
        COLS = ['Sex', 'SmokingStatus']
        self.FE = []
        for col in COLS:
            for mod in data[col].unique():
                self.FE.append(mod)
                data[mod] = (data[col] == mod).astype(int)

        # Min-max scale the numeric features over all rows
        data['age'] = (data['Age'] - data['Age'].min()) / \
                      (data['Age'].max() - data['Age'].min())
        data['BASE'] = (data['min_FVC'] - data['min_FVC'].min()) / \
                       (data['min_FVC'].max() - data['min_FVC'].min())
        data['week'] = (data['base_week'] - data['base_week'].min()) / \
                       (data['base_week'].max() - data['base_week'].min())
        data['percent'] = (data['Percent'] - data['Percent'].min()) / \
                          (data['Percent'].max() - data['Percent'].min())
        self.FE += ['age', 'percent', 'week', 'BASE']

        self.raw = data.loc[data.WHERE == mode].reset_index()
        del data

    def __len__(self):
        """Return the number of tabular rows for the selected mode."""
        return len(self.raw)

    def __getitem__(self, idx):
        """Return the sample for row ``idx``.

        :return: dict with ``'features'``, ``'image'``, ``'metadata'`` and
            ``'target'``, or whatever ``transform`` returns for that dict
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        patient_id = self.raw['Patient'].iloc[idx]
        if self.cache_dir is None:
            patient_path = self.ctscans_dir / patient_id
            image, metadata = self.load_scan(patient_path)
        else:
            image = torch.load(self.cache_dir / f'{patient_id}.pt')
            metadata = pydicom.dcmread(self.cache_dir / f'{patient_id}.dcm')

        sample = {
            'features': self.raw[self.FE].iloc[idx].values,
            'image': image,
            'metadata': metadata,
            'target': self.raw['FVC'].iloc[idx]
        }
        if self.transform:
            sample = self.transform(sample)

        return sample

    def cache(self, cache_dir):
        """Write the table, the feature list and one image per patient.

        :param cache_dir: destination directory (created if missing)
        """
        cache_dir = Path(cache_dir)
        cache_dir.mkdir(exist_ok=True, parents=True)

        # Cache raw features table
        self.raw.to_csv(cache_dir/f'tabular_{self.mode}.csv', index=False)

        # Cache features list
        with open(cache_dir/'features_list.pkl', "wb") as fp:
            pickle.dump(self.FE, fp)

        # Cache images and metadata: one scan per patient, taken from the
        # first row in which the patient appears. Positions are computed
        # directly so that `self.raw` is not modified.
        first_rows = np.flatnonzero(~self.raw['Patient'].duplicated().values)
        bar = tqdm(first_rows.tolist())
        for idx in bar:
            sample = self[idx]
            patient_id = sample['metadata'].PatientID
            torch.save(sample['image'], cache_dir/f'{patient_id}.pt')
            sample['metadata'].save_as(cache_dir/f'{patient_id}.dcm')

    @staticmethod
    def load_scan(path):
        """Read all ``*.dcm`` slices below ``path`` into one volume.

        :param path: ``pathlib.Path`` of a patient's scan directory
        :return: tuple ``(image, first_slice)`` where ``image`` stacks the
            float pixel arrays of all slices
        :raises ValueError: if the directory holds no ``*.dcm`` file
        """
        def filename_order(p):
            # Numeric names sort numerically ("2" before "10"), others
            # alphabetically after them.
            stem = p.stem
            return (0, int(stem), '') if stem.isdigit() else (1, 0, stem)

        # glob() yields files in arbitrary order; put them in filename order
        # first so that the fallback below really is "filename order".
        files = sorted(path.glob('*.dcm'), key=filename_order)
        slices = [pydicom.dcmread(p) for p in files]
        if not slices:
            raise ValueError(f'No DICOM (*.dcm) files found in {path}')

        try:
            slices.sort(key=lambda x: float(x.ImagePositionPatient[2]))
        except AttributeError:
            import warnings
            warnings.warn(f'Patient {slices[0].PatientID} CT scan does not '
                          f'have "ImagePositionPatient". Assuming filenames '
                          f'in the right scan order.')

        image = np.stack([s.pixel_array.astype(float) for s in slices])
        return image, slices[0]

            

# class CTScansDataset(Dataset):
#     def __init__(self, root_dir, transform=None):
#         self.root_dir = Path(root_dir)
#         self.patients = [p for p in self.root_dir.glob('*') if p.is_dir()]
#         self.transform = transform

#     def __len__(self):
#         return len(self.patients)

#     def __getitem__(self, idx):
#         if torch.is_tensor(idx):
#             idx = idx.tolist()

#         image, metadata = self.load_scan(self.patients[idx])
#         sample = {'image': image, 'metadata': metadata}
#         if self.transform:
#             sample = self.transform(sample)

#         return sample

#     def save(self, path):
#         t0 = time()
#         Path(path).mkdir(exist_ok=True, parents=True)
#         print('Saving pre-processed dataset to disk')
#         sleep(1)
#         cum = 0

#         bar = trange(len(self))
#         for i in bar:
#             sample = self[i]
#             image, data = sample['image'], sample['metadata']
#             cum += torch.mean(image).item()

#             bar.set_description(f'Saving CT scan {data.PatientID}')
#             fname = Path(path) / f'{data.PatientID}.pt'
#             torch.save(image, fname)

#         sleep(1)
#         bar.close()
#         print(f'Done! Time {timedelta(seconds=time() - t0)}\n'
#               f'Mean value: {cum / len(self)}')

#     def get_patient(self, patient_id):
#         patient_ids = [str(p.stem) for p in self.patients]
#         return self.__getitem__(patient_ids.index(patient_id))

#     @staticmethod
#     def load_scan(path):
#         slices = [pydicom.read_file(p) for p in path.glob('*.dcm')]
#         try:
#             slices.sort(key=lambda x: float(x.ImagePositionPatient[2]))
#         except AttributeError:
#             warnings.warn(f'Patient {slices[0].PatientID} CT scan does not '
#                           f'have "ImagePositionPatient". Assuming filenames '
#                           f'in the right scan order.')

#         image = np.stack([s.pixel_array.astype(float) for s in slices])
#         return image, slices[0]

In [None]:
# test = CTScansDataset(
#     root_dir=test_dir,
#     transform=transforms.Compose([
#         CropBoundingBox(),
#         ConvertToHU(),
#         Resize(resize_dims),
#         Clip(bounds=clip_bounds),
#         MaskWatershed(min_hu=min(clip_bounds), iterations=1, show_tqdm=True),
#         Normalize(bounds=clip_bounds)
#     ]))
import os
data = ClinicalDataset(
    root_dir=root_dir,
    ctscans_dir=test_dir,
    mode='val',
    transform=transforms.Compose([
        CropBoundingBox(),
        ConvertToHU(),
        Resize((40, 256, 256)),
        Clip(bounds=clip_bounds),
        MaskWatershed(min_hu=min(clip_bounds), iterations=1, show_tqdm=True),
        Normalize(bounds=clip_bounds)
    ]))

# Every data[i] access re-runs the whole pre-processing pipeline, so each
# sample is fetched once and used for both the shape check and the plot.
list_imgs = []
for i in range(len(data)):
    image = data[i]['image']
    assert image.shape == (40, 256, 256)
    list_imgs.append(image)

show(list_imgs)

In [None]:
!conda install gdcm -c conda-forge -y 

In [None]:
# Pre-processing applied to every training scan before it is cached
train_pipeline = transforms.Compose([
    CropBoundingBox(),
    ConvertToHU(),
    Resize(resize_dims),
    Clip(bounds=clip_bounds),
    MaskWatershed(
        min_hu=min(clip_bounds),
        iterations=watershed_iterations,
        show_tqdm=False,
    ),
    Normalize(bounds=clip_bounds),
    ToTensor(),
])

# Build the training dataset from the raw files and write table, feature
# list, images and metadata to the working directory.
data = ClinicalDataset(
    root_dir=root_dir,
    ctscans_dir=os.path.join(root_dir,'train'),
    mode='train',
    transform=train_pipeline,
)
data.cache("/kaggle/working/")

4.1. Quant model

In [None]:
class QuantModel(nn.Module):
    """Quantile regressor fed by tabular features and CT latent features.

    Each input goes through its own 512-unit linear layer; the two results
    are concatenated and passed through two more linear layers.

    :param in_tabular_features: number of tabular input features
    :param in_ctscan_features: number of latent values per sample once
        flattened
    :param out_quantiles: number of outputs per sample
    """

    def __init__(self, in_tabular_features=9, in_ctscan_features=76800,
                 out_quantiles=3):
        super().__init__()
        # Needed up front: forward() flattens the latent tensor to this size
        self.in_ctscan_features = in_ctscan_features

        self.fc1 = nn.Linear(in_tabular_features, 512)
        self.fc2 = nn.Linear(in_ctscan_features, 512)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, out_quantiles)

    def forward(self, x1, x2):
        """Predict the quantiles.

        :param x1: tabular features
        :param x2: pre-computed latent features (flattened internally)
        :return: tensor with ``out_quantiles`` values per sample
        """
        tabular = F.relu(self.fc1(x1))

        flat_latent = x2.view(-1, self.in_ctscan_features)
        latent = F.relu(self.fc2(flat_latent))

        # Join both branches side by side before the shared layers
        joined = torch.cat([tabular, latent], dim=1)
        hidden = F.relu(self.fc3(joined))
        return self.fc4(hidden)

4.2. Quant loss

In [None]:
def quantile_loss(preds, target, quantiles):
    """Pinball (quantile) loss averaged over the batch.

    Column ``i`` of ``preds`` is scored against ``quantiles[i]``; the
    per-quantile losses of a sample are summed and the sums are averaged.

    :param preds: predictions, one column per quantile
    :param target: ground-truth values (must not require gradients)
    :param quantiles: quantile levels, in column order
    :return: scalar loss tensor
    """
    assert not target.requires_grad
    assert preds.size(0) == target.size(0)

    per_quantile = []
    for column, level in enumerate(quantiles):
        residual = target - preds[:, column]
        pinball = torch.max((level - 1) * residual, level * residual)
        per_quantile.append(pinball)

    # Shape (batch, n_quantiles): sum across quantiles, then batch mean
    stacked = torch.stack(per_quantile, dim=1)
    return stacked.sum(dim=1).mean()

Cache all pre-processed 3D CT Scans and pre-compute all latent features

In [None]:
# Helper function that generates all latent features
class GenerateLatentFeatures:
    """Transform that swaps a sample's image for its latent features.

    Latent vectors are computed with ``autoencoder.encode`` and cached on
    disk as ``{PatientID}_lat.pt`` inside ``latent_dir``, so each patient is
    encoded only once.

    :param autoencoder: model exposing ``encode(img, return_partials=False)``
    :param latent_dir: directory for the cached latent files (created if it
        does not exist)
    """

    def __init__(self, autoencoder, latent_dir):
        self.autoencoder = autoencoder
        self.latent_dir = Path(latent_dir)
        # torch.save() does not create missing directories
        self.latent_dir.mkdir(parents=True, exist_ok=True)
        # Kept for backward compatibility; reads the notebook-level
        # `cache_dir` and is not used by this class.
        self.cache_dir = Path(cache_dir)

    def __call__(self, sample):
        """Return tabular features, latent features and target of a sample.

        :param sample: dict with ``'metadata'`` (has ``PatientID``),
            ``'image'``, ``'features'`` and ``'target'``
        :return: dict with ``'tabular_features'``, ``'latent_features'`` and
            ``'target'``
        """
        patient_id = sample['metadata'].PatientID
        cached_latent_file = self.latent_dir/f'{patient_id}_lat.pt'

        if cached_latent_file.is_file():
            # Always hand back CPU tensors, even if the file was written
            # from a GPU tensor
            latent_features = torch.load(cached_latent_file,
                                         map_location='cpu')
        else:
            # Run the encoder on whatever device the model lives on
            first_param = next(self.autoencoder.parameters(), None)
            with torch.no_grad():
                img = sample['image'].float().unsqueeze(0)
                if first_param is not None:
                    img = img.to(first_param.device)
                latent_features = self.autoencoder.encode(
                    img, return_partials=False).squeeze(0).cpu()
            torch.save(latent_features, cached_latent_file)

        return {
            'tabular_features': sample['features'],
            'latent_features': latent_features,
            'target': sample['target']
        }

In [None]:
# Load the pre-trained auto-encoder. It stays on the CPU: the transform
# below feeds it CPU tensors loaded from the cache, so moving only the model
# to a GPU would raise a device-mismatch error.
autoencoder = AutoEncoder()
autoencoder.load_state_dict(torch.load(
    pretrained_ae_weigths,
    map_location=torch.device('cpu')
))
autoencoder.eval()

data = ClinicalDataset(
    root_dir=root_dir,
    # Path(...) so that a plain-string root_dir works as well
    ctscans_dir=Path(root_dir)/'train',
    cache_dir=cache_dir,
    mode='train',
    transform=GenerateLatentFeatures(autoencoder, latent_dir)
)

# AutoEncoder.encode() ends in the `fc1` linear layer, so each sample's
# latent tensor is a flat vector with `fc1.out_features` entries.
expected_latent_shape = (autoencoder.fc1.out_features,)
for i in trange(len(data)):
    sample = data[i]
    assert sample['latent_features'].shape == expected_latent_shape

Overfit a small batch before moving forward

In [None]:
# Overfit a single batch: the loss should fall steadily if model, loss and
# optimizer are wired correctly.
dataloader = DataLoader(data, batch_size=batch_size,
                        shuffle=True, num_workers=2)
batch = next(iter(dataloader))

# The batch never changes, so move it to the device once
inputs1 = batch['tabular_features'].float().to(device)
inputs2 = batch['latent_features'].float().to(device)
targets = batch['target'].to(device)

# Size the latent branch from the data instead of relying on the model's
# default, which only fits one particular latent shape.
model = QuantModel(in_ctscan_features=inputs2[0].numel()).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

bar = trange(50)
for epoch in bar:
    optimizer.zero_grad()
    preds = model(inputs1, inputs2)
    loss = quantile_loss(preds, targets, quantiles)
    loss.backward()
    if use_TPU:
        xm.optimizer_step(optimizer, barrier=True)
    else:
        optimizer.step()

    bar.set_postfix(loss=f'{loss.item():0.1f}')

Training

In [None]:
# Helper that splits a dataset so that no group ends up on both sides
def group_split(dataset, groups, test_size=0.2):
    """Split ``dataset`` into train/validation subsets by group.

    :param dataset: dataset exposing its table as ``dataset.raw``
    :param groups: group label of every row
    :param test_size: passed to ``GroupShuffleSplit`` as ``test_size``
    :return: tuple ``(train_subset, val_subset)``
    """
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size)
    # n_splits=1, so the first (and only) split is the one to use
    train_idx, val_idx = next(splitter.split(dataset.raw, dataset.raw, groups))
    return Subset(dataset, train_idx), Subset(dataset, val_idx)
        
# Helper function with competition metric
def metric(preds, targets):
    """Per-sample competition score.

    ``sigma`` is the spread between the third and first prediction column,
    floored at 70; ``delta`` is the absolute error of the second column,
    capped at 1000.

    :param preds: predictions with at least three columns
    :param targets: ground-truth values
    :return: tensor with one score per sample
    """
    root_two = np.sqrt(2)
    sigma = torch.clamp(preds[:, 2] - preds[:, 0], min=70)
    delta = torch.clamp((preds[:, 1] - targets).abs(), max=1000)
    return -root_two * delta / sigma - torch.log(root_two * sigma)

In [None]:
# Load the data
autoencoder = AutoEncoder()
autoencoder.load_state_dict(torch.load(
    pretrained_ae_weigths,
    map_location=torch.device('cpu')
))
autoencoder.eval()

data = ClinicalDataset(
    root_dir=root_dir,
    # Path(...) so that a plain-string root_dir works as well
    ctscans_dir=Path(root_dir)/'train',
    cache_dir=cache_dir,
    mode='train',
    transform=GenerateLatentFeatures(autoencoder, latent_dir)
)

trainset, valset = group_split(data, data.raw['Patient'], test_size)
t0 = time()

# Prepare to save model weights
Path(model_dir).mkdir(parents=True, exist_ok=True)
now = datetime.now()
fname = f'{model_name}-{now.year}{now.month:02d}{now.day:02d}.pth'
model_file = Path(model_dir) / fname

dataset_sizes = {'train': len(trainset), 'val': len(valset)}
dataloaders = {
    'train': DataLoader(trainset, batch_size=batch_size,
                        shuffle=True, num_workers=2),
    'val': DataLoader(valset, batch_size=batch_size,
                      shuffle=False, num_workers=2)
}

# Create the model and optimizer. The latent branch is sized from an actual
# sample instead of relying on the model's default, which only fits one
# particular latent shape.
in_ctscan_features = data[0]['latent_features'].numel()
model = QuantModel(in_ctscan_features=in_ctscan_features).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Set global tracking variables
epoch_loss = {'train': np.inf, 'val': np.inf}
epoch_metric = {'train': -np.inf, 'val': -np.inf}
best_loss = np.inf
best_model_wts = None
history = []  # one dict per epoch, turned into a DataFrame after training

# Training loop
for epoch in range(num_epochs):
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # Set model to training mode
        else:
            model.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_metric = 0.0

        # Iterate over data
        num_samples = 0
        bar = tqdm(dataloaders[phase])
        for batch in bar:
            bar.set_description(f'Epoch {epoch} {phase}'.ljust(20))
            inputs1 = batch['tabular_features'].float().to(device)
            inputs2 = batch['latent_features'].float().to(device)
            targets = batch['target'].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()
            # forward
            # track gradients if only in train
            with torch.set_grad_enabled(phase == 'train'):
                preds = model(inputs1, inputs2)
                loss = quantile_loss(preds, targets, quantiles)
                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    if use_TPU:
                        xm.optimizer_step(optimizer, barrier=True)
                    else:
                        optimizer.step()

            # quantile_loss is a batch mean, hence the re-weighting
            running_loss += loss.item() * inputs1.size(0)
            # Accumulate a plain float: adding the tensor itself would keep
            # every training batch's autograd graph alive until epoch end.
            with torch.no_grad():
                running_metric += metric(preds, targets).sum().item()

            # batch statistics
            num_samples += inputs1.size(0)
            bar.set_postfix(loss=f'{running_loss / num_samples:0.1f}',
                            metric=f'{running_metric / num_samples:0.4f}')

        # epoch statistics
        epoch_loss[phase] = running_loss / dataset_sizes[phase]
        epoch_metric[phase] = running_metric / dataset_sizes[phase]

        # deep copy the model
        if phase == 'val' and epoch_loss['val'] < best_loss:
            best_loss = epoch_loss['val']
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(best_model_wts, model_file)

    history.append({
        'epoch': epoch + 1,
        'train_loss': epoch_loss["train"],
        'val_loss': epoch_loss["val"]
    })

# Save training statistics (DataFrame.append was removed in pandas 2.0, so
# the table is built once from the collected rows)
df = pd.DataFrame(history, columns=['epoch', 'train_loss', 'val_loss'])
fname = f'{model_name}-{now.year}{now.month:02d}{now.day:02d}.csv'
csv_file = Path(model_dir) / fname
df.to_csv(csv_file, index=False)

# load best model weights (none are recorded if the validation loss never
# dropped below infinity, e.g. when it was NaN in every epoch)
if best_model_wts is not None:
    model.load_state_dict(best_model_wts)

print(f'Training complete! Time: {timedelta(seconds=time() - t0)}')
models = [model]

Generate Submission CSV

In [None]:
# Pre-process the test scans and cache table, images and metadata in
# `latent_dir`, where the submission cell reads them back.
data = ClinicalDataset(
    root_dir=root_dir,
    # Path(...) so that a plain-string root_dir works as well
    ctscans_dir=Path(root_dir)/'test',
    mode='test',
    transform=transforms.Compose([
        CropBoundingBox(),
        ConvertToHU(),
        Resize((40, 256, 256)),
        Clip(bounds=(-1000, 500)),
        Mask(method=MaskMethod.MORPHOLOGICAL, threshold=-500),
        # NOTE(review): these bounds differ from the Clip bounds above
        # ((-1000, 500)) - confirm that (-1000, -500) is intended.
        Normalize(bounds=(-1000, -500)),
        ToTensor(),
        ZeroCenter(pre_calculated_mean=0.029105728564346046)
    ]))

data.cache(latent_dir)

In [None]:
data = ClinicalDataset(
    root_dir=root_dir,
    # Path(...) so that a plain-string root_dir works as well
    ctscans_dir=Path(root_dir)/'test',
    cache_dir=latent_dir,
    mode='test',
    transform=GenerateLatentFeatures(autoencoder, latent_dir)
)

avg_preds = np.zeros((len(data), len(quantiles)))

# The loader does not depend on the model, so it is built once
dataloader = DataLoader(data, batch_size=batch_size,
                        shuffle=False, num_workers=2)

for model in models:
    model.eval()
    # Feed each model on the device its weights live on
    model_device = next(model.parameters()).device
    preds = []
    for batch in tqdm(dataloader):
        inputs1 = batch['tabular_features'].float().to(model_device)
        inputs2 = batch['latent_features'].float().to(model_device)
        with torch.no_grad():
            # Back to the CPU: .numpy() below does not accept GPU tensors
            preds.append(model(inputs1, inputs2).cpu())

    preds = torch.cat(preds, dim=0).numpy()
    avg_preds += preds

# Average the predictions of all models
avg_preds /= len(models)
df = pd.DataFrame(data=avg_preds, columns=list(quantiles))
df['Patient_Week'] = data.raw['Patient_Week']
# Second quantile is the FVC estimate; the spread between the outer two is
# reported as the confidence
df['FVC'] = df[quantiles[1]]
df['Confidence'] = df[quantiles[2]] - df[quantiles[0]]
df = df.drop(columns=list(quantiles))
df.to_csv('submission.csv', index=False)