# 1. OSIC AutoEncoder training
This notebook demonstrates how to train a convolutional AutoEncoder to learn latent features from the 3D CT scans dataset.

One of the main applications of AutoEncoders is dimensionality reduction. We will use them for that: reducing 3D images (preprocessed to 1 x 40 x 256 x 256 tensors) to vectors (with 10 dimensions).
![autoencoder](https://hackernoon.com/hn-images/1*8ixTe1VHLsmKB3AquWdxpQ.png)

Once we have the trained model, the idea is to apply it to extract these latent features and combine them with the OSIC tabular data.

My first experiments had a less strangled bottleneck (started with 96 x 2 x 20 x 20), which was already a reduction of over 34:1 (the inputs are 3D images of 1 x 40 x 256 x 256). The AutoEncoder output was great, easy to see. However, using latent features of 96 x 2 x 20 x 20 meant that, in the tabular model, I had to combine 76,800 features (flattened) with the 9 tabular features. In order to have a better balance between tabular and latent features, I decided to strangle the bottleneck further, squeezing the 3D images to 10 features (already flattened in the AutoEncoder model). As you can see below, the model learns, as the loss keeps going down. However, the output of the AutoEncoder is not as visible as with the less strangled bottleneck.

# 2. Imports and global variables

In [None]:
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In [None]:
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

import warnings
# Silence every Python warning for the rest of the notebook. This also hides
# the warnings.warn() emitted by CTScansDataset.load_scan.
warnings.filterwarnings("ignore")

import os
# Environment variables for the XLA runtime.
# NOTE(review): presumably XLA_USE_BF16=1 makes torch_xla run float tensors as
# bfloat16 - confirm against the torch_xla documentation.
os.environ['XLA_USE_BF16']="1"
# NOTE(review): looks like an upper bound for the XLA tensor allocator; the
# unit is not visible here - confirm.
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '100000000'

In [None]:
import copy
import gc
import multiprocessing
import os
from datetime import timedelta, datetime
from pathlib import Path
from time import time, sleep

import imageio
import matplotlib.pyplot as plt
import numpy as np
import pydicom
import pytest
import scipy.ndimage as ndimage
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import cm
from scipy.ndimage.interpolation import zoom
from skimage import measure, morphology, segmentation
from torch.utils.data import Dataset, random_split, DistributedSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import trange, tqdm

In [None]:
# --- Paths -----------------------------------------------------------------
# Directory of cached, pre-processed tensors (*.pt) read by CTTensorsDataset.
root_dir = '/kaggle/input/osic-cached-dataset'
# Raw DICOM test scans, used by the slice-inspection cell.
test_dir = '/kaggle/input/osic-pulmonary-fibrosis-progression/test'
# Where run() saves the trained model's state_dict.
model_file = '/kaggle/working/diophantus.pt'

# --- Pre-processing --------------------------------------------------------
# Target volume shape handed to Resize: (slices, rows, cols).
resize_dims = (40, 256, 256)
# Intensity window applied by Clip and used by Normalize to rescale to [0, 1].
clip_bounds = (-1000, 200)
# Number of structuring-element iterations used by MaskWatershed.
watershed_iterations = 1
# Mean subtracted from every cached tensor by ZeroCenter.
pre_calculated_mean = 0.02865046213070556

# --- Model / training ------------------------------------------------------
# Size of the AutoEncoder bottleneck vector.
latent_features = 10
batch_size = 4
# Base learning rate; run() multiplies it by the XLA world size.
learning_rate = 3e-5
num_epochs = 10
# Fraction of the dataset held out by CTTensorsDataset.random_split.
val_size = 0.2
# Not referenced anywhere else in the visible part of this notebook.
tensorboard_dir = '/kaggle/working/runs'
# Settings bundle passed to every spawned training process (see xmp.spawn).
flags = {
    'batch_size': batch_size,
    'num_epochs': num_epochs,
    'seed': 1234,
    'learning_rate': learning_rate,
    'model_file': model_file
}

# 3. Dataset interface
## 3.1. ctscans_dataset.py
This interface ingests the data from the 3D CT scans, porting them to a PyTorch Dataset.

In [None]:
class CTScansDataset(Dataset):
    """Dataset over raw DICOM CT scans, one sub-directory per patient.

    Each item is a dict ``{'image': ndarray, 'metadata': pydicom dataset}``
    where ``image`` stacks the slices of one scan along axis 0 and
    ``metadata`` is the first slice's DICOM header. If ``transform`` is
    given it is applied to that dict.
    """

    def __init__(self, root_dir, transform=None):
        """Index every sub-directory of ``root_dir`` as one patient scan."""
        self.root_dir = Path(root_dir)
        self.patients = [p for p in self.root_dir.glob('*') if p.is_dir()]
        self.transform = transform

    def __len__(self):
        return len(self.patients)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image, metadata = self.load_scan(self.patients[idx])
        sample = {'image': image, 'metadata': metadata}
        if self.transform:
            sample = self.transform(sample)

        return sample

    def save(self, path):
        """Run every scan through the pipeline and cache it as ``<PatientID>.pt``.

        Also prints the average of the per-scan means. Expects the transform
        pipeline to end in a tensor (``torch.mean`` / ``torch.save`` are
        applied to ``sample['image']``).
        """
        t0 = time()
        dest = Path(path)
        dest.mkdir(exist_ok=True, parents=True)
        print('Saving pre-processed dataset to disk')
        # The short sleeps keep the print output and the tqdm bar from
        # interleaving in the notebook.
        sleep(1)
        cum = 0

        bar = trange(len(self))
        for i in bar:
            sample = self[i]
            image, data = sample['image'], sample['metadata']
            cum += torch.mean(image).item()

            bar.set_description(f'Saving CT scan {data.PatientID}')
            fname = dest / f'{data.PatientID}.pt'
            torch.save(image, fname)

        sleep(1)
        bar.close()
        # An empty dataset has no mean; report NaN instead of dividing by zero.
        mean_value = cum / len(self) if len(self) else float('nan')
        print(f'Done! Time {timedelta(seconds=time() - t0)}\n'
              f'Mean value: {mean_value}')

    def get_patient(self, patient_id):
        """Return the sample whose directory name equals ``patient_id``.

        Raises:
            ValueError: if no patient directory has that name.
        """
        patient_ids = [str(p.stem) for p in self.patients]
        return self[patient_ids.index(patient_id)]

    @staticmethod
    def _slice_sort_key(path):
        """Order slice files numerically (1, 2, 10), then any others by name."""
        stem = path.stem
        if stem.isdecimal():
            return (0, int(stem), '')
        return (1, 0, stem)

    @staticmethod
    def load_scan(path):
        """Read all ``*.dcm`` files in ``path`` and stack them into a volume.

        Slices are ordered by ``ImagePositionPatient[2]`` when available.
        When that tag is missing the filename order is used instead, so the
        files are sorted by name first: ``glob`` itself returns them in
        arbitrary order.

        Returns:
            ``(image, first_slice)`` where ``image`` is a float ndarray with
            one slice per index along axis 0.

        Raises:
            ValueError: if ``path`` contains no ``*.dcm`` file.
        """
        dcm_paths = sorted(path.glob('*.dcm'),
                           key=CTScansDataset._slice_sort_key)
        if not dcm_paths:
            raise ValueError(f'No DICOM slices (*.dcm) found in {path}')

        slices = [pydicom.dcmread(p) for p in dcm_paths]
        try:
            slices.sort(key=lambda x: float(x.ImagePositionPatient[2]))
        except AttributeError:
            warnings.warn(f'Patient {slices[0].PatientID} CT scan does not '
                          f'have "ImagePositionPatient". Assuming filenames '
                          f'in the right scan order.')

        image = np.stack([s.pixel_array.astype(float) for s in slices])
        return image, slices[0]

## 3.2. Pre-processing
There are some pre-processing steps to be done. Let's tackle them one step at a time.
### 3.2.1. crop_bounding_box.py

In [None]:
class CropBoundingBox:
    """Crop away a uniform frame around the scan.

    The decision and the crop limits are derived from the middle slice only
    and then applied to every slice of the volume.
    """

    @staticmethod
    def bounding_box(img3d: np.array):
        """Tell whether the middle slice is framed by a uniform value.

        True when the first row and the first column of the middle slice are
        both entirely equal to its top-left pixel.
        """
        middle = img3d[int(img3d.shape[0] / 2)]
        corner = middle[0, 0]
        top_row_uniform = (middle[0, :] == corner).all()
        left_col_uniform = (middle[:, 0] == corner).all()
        return bool(left_col_uniform and top_row_uniform)

    @staticmethod
    def _content_span(is_background):
        """Locate the first run of non-background lines.

        ``is_background`` holds one boolean per row (or column). Returns
        ``(start, stop)``: ``start`` is the first non-background index and
        ``stop`` the first background index after it. Either is ``None``
        when it does not exist, which leaves that side of a slice open.
        """
        content = np.flatnonzero(~is_background)
        if content.size == 0:
            return None, None

        start = int(content[0])
        trailing = np.flatnonzero(is_background[start:])
        stop = start + int(trailing[0]) if trailing.size else None
        return start, stop

    def __call__(self, sample):
        volume, meta = sample['image'], sample['metadata']
        if not self.bounding_box(volume):
            # No frame detected: hand the sample back untouched.
            return sample

        middle = volume[int(volume.shape[0] / 2)]
        # A pixel counts as background when it equals the top-left pixel.
        is_background = middle == middle[0, 0]
        r_min, r_max = self._content_span(is_background.all(axis=1))
        c_min, c_max = self._content_span(is_background.all(axis=0))

        cropped = volume[:, r_min:r_max, c_min:c_max]
        return {'image': cropped, 'metadata': meta}

### 3.2.2. convert_to_hu.py
Credits to [Guido Zuidhof's tutorial](https://www.kaggle.com/gzuidhof/full-preprocessing-tutorial).

In [None]:
class ConvertToHU:
    """Rescale raw pixel values with the DICOM slope and intercept.

    Applies ``image * RescaleSlope + RescaleIntercept`` using the values
    stored in the sample's metadata and casts the result to ``int16``.
    No check is made on the DICOM ``ImageType`` tag.
    """

    def __call__(self, sample):
        image, data = sample['image'], sample['metadata']

        intercept = data.RescaleIntercept
        slope = data.RescaleSlope
        image = (image * slope + intercept).astype(np.int16)
        return {'image': image, 'metadata': data}

### 3.2.3. resize.py

In [None]:
class Resize:
    """Resample a volume to a fixed shape with ``scipy.ndimage`` ``zoom``."""

    def __init__(self, output_size):
        """Store the target shape; ``output_size`` must be a tuple."""
        assert isinstance(output_size, tuple)
        self.output_size = output_size

    def __call__(self, sample):
        volume, meta = sample['image'], sample['metadata']
        # Per-axis zoom factor = wanted length / current length.
        scale = np.divide(np.array(self.output_size), np.array(volume.shape))
        resized = zoom(volume, scale, mode='nearest')
        return {'image': resized, 'metadata': meta}

### 3.2.4. clip.py

In [None]:
class Clip:
    """Clamp voxel values to ``[min(bounds), max(bounds)]``.

    The image array of the sample is modified in place and the same array
    object is returned in the output dict.
    """

    def __init__(self, bounds=(-1000, 500)):
        self.min, self.max = min(bounds), max(bounds)

    def __call__(self, sample):
        volume, meta = sample['image'], sample['metadata']

        too_low = volume < self.min
        volume[too_low] = self.min

        too_high = volume > self.max
        volume[too_high] = self.max

        return {'image': volume, 'metadata': meta}

### 3.2.5. mask_watershed.py
Credits to [Aadhav Vignesh's amazing kernel](https://www.kaggle.com/aadhavvignesh/lung-segmentation-by-marker-controlled-watershed).

IMPORTANT: I made some changes in Vignesh's code below to make it scalable, most notably reducing the number of iterations from 8 to 1. This was important to reduce the time to generate masks from ~8-9 seconds/slice (which would take over 17 hours to complete) to ~100ms/slice. I'm satisfied with the quality of the masks, as you can see in some samples below. However, using 8 iterations generates even better masks.

In [None]:
class MaskWatershed:
    """Mask everything outside the lungs with marker-controlled watershed.

    Every slice of the volume is segmented independently; voxels outside the
    lung filter are replaced by ``min_hu``.

    Args:
        min_hu: value written to voxels outside the lung mask.
        iterations: how many times the black top-hat structuring element is
            iterated (larger is slower).
        show_tqdm: show a progress bar over the slices when True.
    """

    def __init__(self, min_hu, iterations, show_tqdm):
        self.min_hu = min_hu
        self.iterations = iterations
        self.show_tqdm = show_tqdm

    def __call__(self, sample):
        image, data = sample['image'], sample['metadata']

        stack = []
        if self.show_tqdm:
            bar = trange(image.shape[0])
            bar.set_description(f'Masking CT scan {data.PatientID}')
        else:
            bar = range(image.shape[0])
        for slice_idx in bar:
            sliced = image[slice_idx]
            stack.append(self.seperate_lungs(sliced, self.min_hu,
                                             self.iterations))

        return {
            'image': np.stack(stack),
            'metadata': data
        }

    @staticmethod
    def seperate_lungs(image, min_hu, iterations):
        """Segment one 2D slice and blank out everything outside the lungs.

        Returns an array of the slice's shape holding the original values
        inside the lung filter and ``min_hu`` elsewhere.
        """
        h, w = image.shape[0], image.shape[1]

        # marker_external is already folded into marker_watershed.
        marker_internal, _, marker_watershed = \
            MaskWatershed.generate_markers(image)

        # Sobel-Gradient
        sobel_filtered_dx = ndimage.sobel(image, 1)
        sobel_filtered_dy = ndimage.sobel(image, 0)
        sobel_gradient = np.hypot(sobel_filtered_dx, sobel_filtered_dy)
        # Scale to 0..255. A perfectly flat slice has a zero gradient, so
        # skip the scaling there rather than dividing by zero (-> NaN).
        peak = np.max(sobel_gradient)
        if peak > 0:
            sobel_gradient *= 255.0 / peak

        # watershed lives in skimage.segmentation; the morphology alias was
        # removed from scikit-image.
        watershed = segmentation.watershed(sobel_gradient, marker_watershed)

        outline = ndimage.morphological_gradient(watershed, size=(3, 3))
        outline = outline.astype(bool)

        # Structuring element used for the filter
        blackhat_struct = [[0, 0, 1, 1, 1, 0, 0],
                           [0, 1, 1, 1, 1, 1, 0],
                           [1, 1, 1, 1, 1, 1, 1],
                           [1, 1, 1, 1, 1, 1, 1],
                           [1, 1, 1, 1, 1, 1, 1],
                           [0, 1, 1, 1, 1, 1, 0],
                           [0, 0, 1, 1, 1, 0, 0]]

        blackhat_struct = ndimage.iterate_structure(blackhat_struct, iterations)

        # Perform Black Top-hat filter
        outline += ndimage.black_tophat(outline, structure=blackhat_struct)

        lungfilter = np.bitwise_or(marker_internal, outline)
        lungfilter = ndimage.binary_closing(lungfilter,
                                            structure=np.ones((5, 5)),
                                            iterations=3)

        segmented = np.where(lungfilter == 1, image, min_hu * np.ones((h, w)))

        return segmented  #, lungfilter, outline, watershed, sobel_gradient

    @staticmethod
    def generate_markers(image, threshold=-400):
        """Build the internal, external and combined watershed markers.

        Args:
            image: one 2D slice.
            threshold: pixels below this value are lung candidates.

        Returns:
            ``(marker_internal, marker_external, marker_watershed)``: two
            boolean masks and an integer marker image in which internal
            pixels are 255 and external pixels are 128.
        """
        h, w = image.shape[0], image.shape[1]

        marker_internal = image < threshold
        marker_internal = segmentation.clear_border(marker_internal)
        marker_internal_labels = measure.label(marker_internal)

        areas = [r.area for r in measure.regionprops(marker_internal_labels)]
        areas.sort()

        # Keep only the two largest connected regions.
        if len(areas) > 2:
            for region in measure.regionprops(marker_internal_labels):
                if region.area < areas[-2]:
                    coords = region.coords
                    marker_internal_labels[coords[:, 0], coords[:, 1]] = 0

        marker_internal = marker_internal_labels > 0

        # Creation of the External Marker
        external_a = ndimage.binary_dilation(marker_internal, iterations=10)
        external_b = ndimage.binary_dilation(marker_internal, iterations=55)
        marker_external = external_b ^ external_a

        # Creation of the Watershed Marker
        # The np.int alias was removed from NumPy; the builtin int gives the
        # same default integer dtype.
        marker_watershed = np.zeros((h, w), dtype=int)
        marker_watershed += marker_internal * 255
        marker_watershed += marker_external * 128

        return marker_internal, marker_external, marker_watershed

### 3.2.6. normalize.py, to_tensor.py, zero_center.py

In [None]:
class Normalize:
    """Linearly rescale voxel values so that ``bounds`` maps onto ``[0, 1]``.

    Values are not clipped here; anything outside ``bounds`` lands outside
    ``[0, 1]``.
    """

    def __init__(self, bounds=(-1000, 500)):
        self.min = min(bounds)
        self.max = max(bounds)

    def __call__(self, sample):
        image, data = sample['image'], sample['metadata']
        # The np.float alias was removed from NumPy; the builtin float is the
        # same float64 dtype.
        image = image.astype(float)
        image = (image - self.min) / (self.max - self.min)
        return {'image': image, 'metadata': data}
    

class ToTensor:
    """Convert the sample's NumPy image into a ``torch`` tensor.

    With ``add_channel`` set (the default) a leading axis of length 1 is
    inserted first.
    """

    def __init__(self, add_channel=True):
        self.add_channel = add_channel

    def __call__(self, sample):
        volume, meta = sample['image'], sample['metadata']
        array = np.expand_dims(volume, 0) if self.add_channel else volume
        return {'image': torch.from_numpy(array), 'metadata': meta}
    
    
class ZeroCenter:
    """Subtract a fixed, pre-computed mean from whatever it is called with."""

    def __init__(self, pre_calculated_mean):
        self.pre_calculated_mean = pre_calculated_mean

    def __call__(self, tensor):
        offset = self.pre_calculated_mean
        return tensor - offset

### 3.2.7. Inspecting some slices

In [None]:
def show(list_imgs, cmap=cm.bone):
    """Plot the middle slice of every 3D image side by side.

    Args:
        list_imgs: iterable of 3D arrays indexed as ``[slice, row, col]``.
        cmap: matplotlib colormap used for every panel.

    The grid gets exactly one panel per image. Nothing is drawn when
    ``list_imgs`` is empty.
    """
    list_slices = [img3d[int(img3d.shape[0] / 2)] for img3d in list_imgs]
    if not list_slices:
        return

    # squeeze=False keeps axs two-dimensional even for a single image.
    fig, axs = plt.subplots(1, len(list_slices), figsize=(15, 7),
                            squeeze=False)
    for ax, img in zip(axs[0], list_slices):
        ax.imshow(img, cmap=cmap)
        ax.axis('off')

    plt.show()

In [None]:
# Run the raw test scans through the whole pre-processing pipeline, using the
# parameters configured at the top of the notebook, and plot each scan's
# middle slice.
test = CTScansDataset(
    root_dir=test_dir,
    transform=transforms.Compose([
        CropBoundingBox(),
        ConvertToHU(),
        Resize(resize_dims),
        Clip(bounds=clip_bounds),
        MaskWatershed(
            min_hu=min(clip_bounds),
            iterations=watershed_iterations,
            show_tqdm=True),
        Normalize(bounds=clip_bounds)
    ]))

list_imgs = [test[i]['image'] for i in range(len(test))]
show(list_imgs)

## 3.3. Caching pre-processed images in the disk
Pre-processing all 176 3D CT scans takes some time. Depending on the parameters we choose, it can take hours. 

With the current choice of parameters, it takes around 15 minutes. To accelerate experimentation, I already pre-cached the images with the preprocessing parameters in this notebook, saving them in a [public dataset](https://www.kaggle.com/carlossouza/osic-cached-dataset). 

This way, you can preprocess only once, and experiment with the same preprocessed tensors. The code to preprocess and cache images in the disk is:
```
data = CTScansDataset(
    root_dir=root_dir,
    transform=transforms.Compose([
        CropBoundingBox(),
        ConvertToHU(),
        Resize(size),
        Clip(bounds=clip_bounds),
        MaskWatershed(
            min_hu=min(clip_bounds),
            iterations=watershed_iterations,
            show_tqdm=False),
        Normalize(bounds=clip_bounds),
        ToTensor()
    ]))
data.save(dest_dir)
```

From this point on, we use the `CTTensorsDataset` as the interface to ingest the preprocessed tensors, taking the data to training.

In [None]:
class CTTensorsDataset(Dataset):
    """Dataset over cached tensors stored as ``*.pt`` files in one directory.

    Files are visited in sorted path order. Each item is a dict with the
    file stem as ``'patient_id'`` and the loaded (optionally transformed)
    tensor as ``'image'``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.tensor_files = sorted(self.root_dir.glob('*.pt'))
        self.transform = transform

    def __len__(self):
        return len(self.tensor_files)

    def __getitem__(self, item):
        index = item.tolist() if torch.is_tensor(item) else item
        tensor_path = self.tensor_files[index]

        image = torch.load(tensor_path)
        if self.transform:
            image = self.transform(image)

        return {'patient_id': tensor_path.stem, 'image': image}

    def mean(self):
        """Return the average of the per-item image means."""
        per_item_means = (torch.mean(self[i]['image']).item()
                          for i in range(len(self)))
        return sum(per_item_means) / len(self)

    def random_split(self, val_size: float):
        """Split randomly into ``(train, val)``; ``val`` gets ``val_size`` of the items."""
        total = len(self)
        num_val = int(val_size * total)
        return random_split(self, [total - num_val, num_val])

### 3.3.1. Checking data pipeline

In [None]:
# Sanity check: once the pre-computed mean is subtracted, the average of the
# per-scan means over the cached dataset should be (approximately) zero.
train = CTTensorsDataset(
    root_dir=root_dir,
    transform=ZeroCenter(pre_calculated_mean=pre_calculated_mean)
)
cum = sum(torch.mean(train[i]['image']).item() for i in range(len(train)))

assert cum / len(train) == pytest.approx(0)

# 4. AutoEncoder
Credits to [Srinjay Paul's great tutorial](https://srinjaypaul.github.io/3D_Convolutional_autoencoder_for_brain_volumes/), and lots of papers (I will link them later).

As mentioned, I strangled the bottleneck to force very few latent features (10). The image below shows the transformations:
![autoencoder](https://i.ibb.co/2hYZFc1/autoencoder.jpg)

In [None]:
class AutoEncoder(nn.Module):
    """3D convolutional auto-encoder with a small fully-connected bottleneck.

    The encoder is four conv + max-pool stages followed by a linear layer
    that maps 10 * 10 values to ``latent_features``. The decoder mirrors it
    with a linear layer, max-unpooling (re-using the encoder's pooling
    indices and pre-pool shapes) and transposed convolutions.

    The flatten to 10 * 10 is hard-coded, so the input size must reduce to
    exactly 1 x 1 x 10 x 10 after the fourth pooling (the notebook uses
    1 x 40 x 256 x 256 volumes).
    """

    def __init__(self, latent_features=latent_features):
        super().__init__()
        # Encoder
        self.conv1 = nn.Conv3d(1, 16, 3)
        self.conv2 = nn.Conv3d(16, 32, 3)
        self.conv3 = nn.Conv3d(32, 96, 2)
        self.conv4 = nn.Conv3d(96, 1, 1)
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.pool2 = nn.MaxPool3d(kernel_size=3, stride=3, return_indices=True)
        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.fc1 = nn.Linear(10 * 10, latent_features)
        # Decoder
        self.fc2 = nn.Linear(latent_features, 10 * 10)
        self.deconv0 = nn.ConvTranspose3d(1, 96, 1)
        self.deconv1 = nn.ConvTranspose3d(96, 32, 2)
        self.deconv2 = nn.ConvTranspose3d(32, 16, 3)
        self.deconv3 = nn.ConvTranspose3d(16, 1, 3)
        self.unpool0 = nn.MaxUnpool3d(kernel_size=2, stride=2)
        self.unpool1 = nn.MaxUnpool3d(kernel_size=2, stride=2)
        self.unpool2 = nn.MaxUnpool3d(kernel_size=3, stride=3)
        self.unpool3 = nn.MaxUnpool3d(kernel_size=2, stride=2)

    def encode(self, x, return_partials=True):
        """Compress ``x`` into the latent vector.

        Returns only the latent tensor when ``return_partials`` is False.
        Otherwise returns a 9-tuple: the latent tensor followed by, for each
        encoder stage from first to last, the pre-pool shape and the pooling
        indices that the decoder needs to undo that stage.
        """
        encoder_stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
            (self.conv4, self.pool4),
        )

        pre_pool_shapes, pool_indices = [], []
        for conv, pool in encoder_stages:
            x = conv(x)
            pre_pool_shapes.append(x.shape)
            x, indices = pool(x)
            pool_indices.append(indices)

        x = x.view(-1, 10 * 10)
        x = F.relu(self.fc1(x))

        if not return_partials:
            return x

        # Interleave as (shape, indices) per stage, first stage first.
        partials = [entry
                    for stage in zip(pre_pool_shapes, pool_indices)
                    for entry in stage]
        return (x, *partials)

    def forward(self, x):
        """Encode ``x`` and decode it back to the input's shape."""
        latent, *partials = self.encode(x)
        pre_pool_shapes = partials[0::2]
        pool_indices = partials[1::2]

        # Decoder
        x = F.relu(self.fc2(latent))
        x = x.view(-1, 1, 1, 10, 10)

        decoder_stages = (
            (self.unpool0, self.deconv0),
            (self.unpool1, self.deconv1),
            (self.unpool2, self.deconv2),
            (self.unpool3, self.deconv3),
        )
        # The decoder undoes the encoder stages last to first.
        for (unpool, deconv), shape, indices in zip(
                decoder_stages,
                reversed(pre_pool_shapes),
                reversed(pool_indices)):
            x = unpool(x, output_size=shape, indices=indices)
            x = deconv(x)

        return x

# 5. Training on TPU
Trying again... Credits to [ilovescience's great tutorial](https://www.kaggle.com/tanlikesmath/the-ultimate-pytorch-tpu-tutorial-jigsaw-xlm-r).

In [None]:
# Built once at module level, before xmp.spawn forks the workers, so every
# process starts from the same model weights and the same train/val split.
mx = AutoEncoder(latent_features=latent_features)
loss_fn = torch.nn.MSELoss()
data = CTTensorsDataset(
    root_dir=root_dir,
    transform=ZeroCenter(pre_calculated_mean=pre_calculated_mean)
)
train_set, val_set = data.random_split(val_size)


def reduce_fn(vals):
    """Return the arithmetic mean of ``vals`` (used as the mesh_reduce reducer)."""
    total = sum(vals)
    return total / len(vals)

In [None]:
def train_loop_fn(data_loader, model, optimizer, device):
    """Run one training epoch of reconstruction (input == target) learning.

    Every 20th batch the loss is averaged across processes with
    ``xm.mesh_reduce`` and printed once via ``xm.master_print``. The model is
    left in eval mode when the epoch ends.
    """
    log_every = 20

    model.train()
    for step, sample in enumerate(data_loader):
        volumes = sample['image'].float().to(device)

        optimizer.zero_grad()
        reconstruction = model(volumes)
        loss = loss_fn(reconstruction, volumes)

        should_log = step % log_every == 0
        if should_log:
            # Average the loss over all processes (see reduce_fn); only the
            # master process prints it.
            mean_loss = xm.mesh_reduce('loss_reduce', loss, reduce_fn)
            xm.master_print(f'{step}: loss={mean_loss:0.6f}')

        loss.backward()
        xm.optimizer_step(optimizer)

    # Leave the model in eval mode for the validation pass that follows.
    model.eval()

In [None]:
def eval_loop_fn(data_loader, model, device):
    """Return the mean per-sample reconstruction MSE over ``data_loader``.

    The sum of batch losses (each weighted by its batch size) is divided by
    the number of samples this loader actually produced. That keeps the
    result a true mean when the loader only yields one replica's shard of
    the data, and it does not need the loader to expose ``.dataset``.

    Args:
        data_loader: iterable of batches, each a dict with an ``'image'``
            tensor.
        model: callable mapping an input batch to its reconstruction.
        device: device the inputs are moved to.

    Returns:
        The mean loss, or ``0.0`` when the loader yields no batch.
    """
    loss_fn = torch.nn.MSELoss()
    running_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        for batch in data_loader:
            inputs = batch['image'].float().to(device)
            outputs = model(inputs)
            n_in_batch = inputs.size(0)
            running_loss += loss_fn(outputs, inputs) * n_in_batch
            num_samples += n_in_batch

    if num_samples == 0:
        return running_loss

    return running_loss / num_samples

In [None]:
def run(index, flags):
    """Train and validate the AutoEncoder in one spawned XLA process.

    Args:
        index: index of this process, as passed in by ``xmp.spawn``.
        flags: dict with the keys ``'seed'``, ``'batch_size'``,
            ``'learning_rate'``, ``'num_epochs'`` and ``'model_file'``.

    Uses the module-level ``train_set``, ``val_set`` and ``mx``. After the
    last epoch the model's ``state_dict`` is saved to ``flags['model_file']``.
    """
    # Sets a common random seed - both for initialization and
    # ensuring graph is the same
    torch.manual_seed(flags['seed'])

    # Creates the (distributed) samplers, which let this process
    # only access its portion of each dataset
    train_sampler = DistributedSampler(
        train_set,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)

    val_sampler = DistributedSampler(
        val_set,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False)

    # Creates dataloaders, which load data in batches
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=flags['batch_size'],
        sampler=train_sampler,
        num_workers=0,
        drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=flags['batch_size'],
        sampler=val_sampler,
        num_workers=0,
        drop_last=False)

    # Acquires the (unique) Cloud TPU core corresponding
    # to this process's index
    device = xm.xla_device()

    model = mx.to(device)  # put model onto the TPU core
    xm.master_print('done loading model')

    lr = flags['learning_rate'] * xm.xrt_world_size()  # scale the learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    xm.master_print(met.metrics_report())

    ## Trains
    t0 = time()
    for epoch in range(flags['num_epochs']):
        # The sampler seeds its shuffle from the epoch number; without this
        # call every epoch would replay exactly the same sample order.
        train_sampler.set_epoch(epoch)

        para_loader = pl.ParallelLoader(train_loader, [device])
        xm.master_print('parallel loader created... training now')
        # call training loop:
        train_loop_fn(para_loader.per_device_loader(device),
                      model, optimizer, device)

        xm.master_print(met.metrics_report())

        # Drop the loader before building the next one to free its buffers.
        del para_loader
        gc.collect()

        para_loader = pl.ParallelLoader(val_loader, [device])
        # call evaluation loop
        loss = eval_loop_fn(para_loader.per_device_loader(device),
                            model, device)

        del para_loader
        gc.collect()

        # stats
        loss_reduced = xm.mesh_reduce('loss_reduce', loss, reduce_fn)
        xm.master_print(f'Epoch {epoch} loss: {loss_reduced:0.6f}')
        gc.collect()

    print(f"Process {index} finished evaluation. Time: {time() - t0}")

    # save our model
    xm.save(model.state_dict(), flags['model_file'])

In [None]:
def map_fn(index, flags):
    """Entry point executed in every process started by ``xmp.spawn``."""
    run(index, flags)


# Start 8 forked processes; each one calls map_fn(index, flags).
xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')