# Variational Autoencoder for OSIC

This notebook adapts Carlos' 3D-CNN autoencoder (https://www.kaggle.com/carlossouza/osic-autoencoder-training) into a variational autoencoder (VAE). I had done something similar for the Google Landmark Recognition challenge, so it wasn't too much work to adapt.

## What is a variational autoencoder?

A standard autoencoder maps each input deterministically to a single latent code, which the decoder maps back to an output. In a variational autoencoder, the encoder instead maps the input to a probability distribution over the latent space (we assume the input was 'generated' from a latent sample drawn from this distribution). Usually this distribution is a multivariate Gaussian with a diagonal covariance, so each latent dimension has its own mean and variance.

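Mapping the inputs to a Gaussian acts as a kind of regularization. It is also what allows us to use VAEs to generate new images: because the encodings are pushed towards a known prior, we can sample new latent vectors from that prior and decode them.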

In [None]:
from tqdm.notebook import tqdm
import cv2
import copy
import matplotlib.pyplot as plt
# Apply matplotlib's built-in 'ggplot' style sheet to every figure drawn below.
plt.style.use('ggplot')
import pydicom.pixel_data_handlers.gdcm_handler as gdcm_handler 
import scipy
import cv2
import pydicom
import os
from matplotlib import cm
import imageio
from pathlib import Path
from skimage.segmentation import clear_border
from skimage.morphology import ball, disk, dilation, binary_erosion, remove_small_objects, erosion, closing, reconstruction, binary_closing
from skimage.measure import label, regionprops
from skimage.segmentation import clear_border
from skimage.filters import roberts, sobel
from scipy import ndimage as ndi
from skimage import measure, morphology
from scipy.stats import kurtosis
import seaborn as sns
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
from torch import optim
from sklearn.model_selection import KFold
import random
import copy
from torchvision import models
import torch.multiprocessing as mp
from torch.nn import functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
# Compute device for the model and tensors: CUDA when PyTorch can see a GPU,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Seed for reproducibility of experiments

In [None]:
def seed_all(seed=42):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG (used by
    ``np.random.shuffle`` for the train/validation split) and PyTorch on all
    devices, and switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicit for older PyTorch releases; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    # Trade some convolution speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_all()

### Model and Loss function

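The VAE loss function has two parts:
- The familiar mean squared error (MSE) between the input and its reconstruction. This encourages the decoder to faithfully reconstruct the input.
- The Kullback–Leibler (KL) divergence, a measure of the difference between two distributions. Here it compares the encoded distribution $\mathcal{N}(\mu, \sigma^2)$ with a standard normal prior $\mathcal{N}(0, I)$, for which it has the closed form $D_{KL} = -\frac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$. This term pushes the encodings towards a standard Gaussian.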

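In code, the main difference between the autoencoder and the variational autoencoder is that the encoder has two outputs: a mean (`mu`) and a variance, here represented by the log variance (`logvar`). There is also a `reparameterize` function. It takes the encoded distribution defined by the mean and variance and draws a sample $z = \mu + \sigma \epsilon$, with $\epsilon \sim \mathcal{N}(0, I)$ and $\sigma = \exp(\tfrac{1}{2}\,\text{logvar})$. That sample is the input to the decoder. Writing the sample this way (the 'reparameterization trick') keeps it differentiable with respect to `mu` and `logvar`.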

The rest of the model is essentially the same as Carlos' autoencoder, except that I use a 224x224 image size, a leftover from earlier experiments with ResNet.

In [None]:
def vae_loss_function(recon_x, x, mu, logvar, KL_weight=1):
    """Return the two unweighted terms of the VAE objective.

    Args:
        recon_x: decoder output, same shape as ``x``.
        x: target batch; ``x.shape[0]`` is used as the batch size.
        mu: means of the encoded distribution q(z|x).
        logvar: log-variances of q(z|x), same shape as ``mu``.
        KL_weight: accepted for call-site compatibility but NOT used here;
            callers apply any KL weighting to the returned ``KLD`` themselves.

    Returns:
        ``(MSE, KLD)`` scalar tensors. ``MSE`` is the squared error averaged
        over every element of ``x``. ``KLD`` is the closed-form KL divergence
        between N(mu, exp(logvar)) and N(0, I), summed over all entries of
        ``mu`` and divided by the batch size.
    """
    n_elements = torch.numel(x)
    squared_err = F.mse_loss(recon_x, x, reduction='none')
    recon_term = (squared_err / n_elements).sum()

    kl_entries = 1 + logvar - mu ** 2 - logvar.exp()
    kl_term = -0.5 * kl_entries.sum() / x.shape[0]
    return recon_term, kl_term


class VariationalAutoEncoder(nn.Module):
    """3D convolutional variational autoencoder for CT volumes.

    The encoder is four conv -> max-pool stages followed by two linear heads
    that produce the mean and log-variance of q(z|x). The decoder mirrors it
    with MaxUnpool3d (reusing the encoder's pooling indices and pre-pool
    shapes) and ConvTranspose3d layers.

    Layer sizes are hard-wired: the encoder output is flattened with
    ``view(-1, 64)``, so each sample must reduce to a (1, 1, 8, 8) tensor.
    Inputs of shape (batch, 1, 50, 224, 224), as produced by the ``OSIC``
    dataset in this notebook, reduce to exactly that.
    """

    def __init__(self, latent_features=10):
        """Build the layers.

        Args:
            latent_features: dimensionality of the latent code z.
        """
        super().__init__()
        # Encoder (pooling indices are returned so the decoder can unpool).
        self.conv1 = nn.Conv3d(1, 16, 3)
        self.conv2 = nn.Conv3d(16, 32, 3)
        self.conv3 = nn.Conv3d(32, 96, 2)
        self.conv4 = nn.Conv3d(96, 1, 1)
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.pool2 = nn.MaxPool3d(kernel_size=3, stride=3, return_indices=True)
        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.fc1_mu = nn.Linear(8 * 8, latent_features)
        self.fc1_logvar = nn.Linear(8 * 8, latent_features)
        # Decoder
        self.fc2 = nn.Linear(latent_features, 8 * 8)
        self.deconv0 = nn.ConvTranspose3d(1, 96, 1)
        self.deconv1 = nn.ConvTranspose3d(96, 32, 2)
        self.deconv2 = nn.ConvTranspose3d(32, 16, 3)
        self.deconv3 = nn.ConvTranspose3d(16, 1, 3)
        self.unpool0 = nn.MaxUnpool3d(kernel_size=2, stride=2)
        self.unpool1 = nn.MaxUnpool3d(kernel_size=2, stride=2)
        self.unpool2 = nn.MaxUnpool3d(kernel_size=3, stride=3)
        self.unpool3 = nn.MaxUnpool3d(kernel_size=2, stride=2)

    def encode(self, x, return_partials=True):
        """Encode a batch into the parameters of q(z|x).

        Args:
            x: 5D tensor (batch, 1, D, H, W); ``conv1`` has one input channel.
            return_partials: when True (default) also return the pre-pooling
                shapes and pooling indices that ``decode`` needs.

        Returns:
            If ``return_partials``: ``(mu, logvar, up3out_shape, i1,
            up2out_shape, i2, up1out_shape, i3, up0out_shape, i4)``.
            Otherwise only ``mu``, the posterior mean (the deterministic
            latent code).
        """
        x = self.conv1(x)
        up3out_shape = x.shape
        x, i1 = self.pool1(x)

        x = self.conv2(x)
        up2out_shape = x.shape
        x, i2 = self.pool2(x)

        x = self.conv3(x)
        up1out_shape = x.shape
        x, i3 = self.pool3(x)

        x = self.conv4(x)
        up0out_shape = x.shape
        x, i4 = self.pool4(x)

        mu, logvar = self._latent_head(x.view(-1, 8 * 8))

        if return_partials:
            return (mu, logvar, up3out_shape, i1, up2out_shape, i2,
                    up1out_shape, i3, up0out_shape, i4)
        return mu

    def _latent_head(self, features):
        """Map flattened (batch, 64) encoder features to ``(mu, logvar)``.

        Both heads are linear on purpose. A ReLU here would force mu >= 0 and
        logvar >= 0 (std >= 1), so q(z|x) could never be narrower than the
        N(0, I) prior.
        """
        return self.fc1_mu(features), self.fc1_logvar(features)

    def reparameterize(self, mu, logvar):
        """Draw z ~ N(mu, exp(logvar)) with the reparameterization trick.

        In training mode this returns ``mu + eps * std`` with
        ``eps ~ N(0, I)``, so gradients flow to ``mu`` and ``logvar``. In eval
        mode it returns ``mu`` deterministically.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, up3out_shape, i1, up2out_shape, i2,
               up1out_shape, i3, up0out_shape, i4):
        """Reconstruct a volume from ``z``, using the partials from ``encode``.

        The shape/index arguments must come from the same ``encode`` call, in
        the order ``encode`` returns them.
        """
        x = F.relu(self.fc2(z))
        x = x.view(-1, 1, 1, 8, 8)
        x = self.unpool0(x, output_size=up0out_shape, indices=i4)
        x = self.deconv0(x)
        x = self.unpool1(x, output_size=up1out_shape, indices=i3)
        x = self.deconv1(x)
        x = self.unpool2(x, output_size=up2out_shape, indices=i2)
        x = self.deconv2(x)
        x = self.unpool3(x, output_size=up3out_shape, indices=i1)
        x = self.deconv3(x)
        return x

    def forward(self, x):
        """Encode, sample and decode ``x``.

        Returns:
            ``(x_reconst, z, mu, logvar)``.
        """
        mu, logvar, *partials = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_reconst = self.decode(z, *partials)
        return x_reconst, z, mu, logvar

In [None]:
def plot_training_loss(train, val, title='loss'):
    """Plot per-epoch train/validation curves and save the figure.

    Args:
        train: sequence of per-epoch training values.
        val: sequence of per-epoch validation values.
        title: ``'loss'`` gives the "Model Training Loss" title; any other
            value gives "Model Metric Loss".

    The figure is saved as ``f'training_{title}'``. No extension is given,
    so matplotlib's default format is used. The default call still writes
    ``training_loss``, and a ``title='metric'`` call no longer overwrites it.
    """
    plt.figure()
    plt.plot(train, label='Train')
    plt.plot(val, label='Val')
    plt.title('Model Training Loss' if title == 'loss' else 'Model Metric Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    # A log axis needs at least one positive value; skip it for empty or
    # non-positive series rather than triggering matplotlib warnings.
    if any(v > 0 for v in list(train) + list(val)):
        plt.yscale('log')
    plt.legend()
    plt.savefig(f'training_{title}')

### Helper functions for preprocessing image data

In [None]:
def load_scan(path):
    """Load every DICOM file in a folder, sorted along the patient z-axis.

    Each slice's ``SliceThickness`` is overwritten with the spacing measured
    between the first two slices. It falls back to ``SliceLocation`` when
    ``ImagePositionPatient`` is missing, and to the stored ``SliceThickness``
    when the measured spacing is 0 or only one slice exists.

    Args:
        path: folder containing only the DICOM files of one series
            (``str`` or ``pathlib.Path``).

    Returns:
        List of ``pydicom`` datasets sorted by ``ImagePositionPatient[2]``.

    Raises:
        ValueError: if the folder contains no files.
    """
    path = Path(path)
    slices = [pydicom.dcmread(path / name) for name in os.listdir(path)]
    if not slices:
        raise ValueError(f'No DICOM files found in {path}')
    slices.sort(key=lambda s: float(s.ImagePositionPatient[2]))

    if len(slices) < 2:
        # Spacing cannot be measured from a single slice.
        slice_thickness = slices[0].SliceThickness
    else:
        try:
            slice_thickness = np.abs(slices[0].ImagePositionPatient[2] - slices[1].ImagePositionPatient[2])
        except AttributeError:
            slice_thickness = np.abs(slices[0].SliceLocation - slices[1].SliceLocation)
    if slice_thickness == 0:
        slice_thickness = slices[0].SliceThickness

    for s in slices:
        s.SliceThickness = slice_thickness
    return slices

def get_pixels_hu(slices):
    """Stack DICOM slices into an int16 volume and apply each rescale slope/intercept.

    Raw values of -2000 are first set to 0, before rescaling. When a slice's
    ``RescaleSlope`` is not 1, the scaled values are truncated back to int16.
    Each slice's ``RescaleIntercept`` is then added.

    Args:
        slices: sequence of objects exposing ``pixel_array``,
            ``RescaleSlope`` and ``RescaleIntercept``.

    Returns:
        int16 array of shape (n_slices, rows, cols).
    """
    volume = np.stack([np.array(dcm.pixel_array, dtype=np.int16) for dcm in slices]).astype(np.int16)

    # Outside-of-scan marker -> 0 (raw units, before the intercept is applied).
    volume[volume == -2000] = 0

    for idx, dcm in enumerate(slices):
        slope = dcm.RescaleSlope
        offset = dcm.RescaleIntercept
        if slope != 1:
            volume[idx] = (slope * volume[idx].astype(np.float64)).astype(np.int16)
        volume[idx] += np.int16(offset)

    return np.array(volume, dtype=np.int16)

def resample(image, scan, new_spacing=(1, 1, 1)):
    """Resample a volume to (approximately) the requested voxel spacing.

    Args:
        image: 3D array ordered (slice, row, column).
        scan: sequence of DICOM datasets. Only ``scan[0].SliceThickness`` and
            ``scan[0].PixelSpacing`` are read.
        new_spacing: target spacing per axis, in the same units as those
            DICOM attributes. A tuple default avoids a shared mutable default.

    Returns:
        ``(resampled_image, actual_spacing)``. The output shape is rounded to
        whole voxels, so ``actual_spacing`` may differ slightly from
        ``new_spacing``.
    """
    spacing = np.array([scan[0].SliceThickness] + list(scan[0].PixelSpacing), dtype=np.float32)
    resize_factor = spacing / np.asarray(new_spacing)
    new_shape = np.round(image.shape * resize_factor)
    real_resize_factor = new_shape / image.shape
    actual_spacing = spacing / real_resize_factor

    # scipy.ndimage.zoom is the public name; the ``interpolation`` submodule
    # path is deprecated.
    image = scipy.ndimage.zoom(image, real_resize_factor, mode='nearest')
    return image, actual_spacing

def get_segmented_lungs(im, plot=False):
    """Mask out everything except the lungs in a single 2D slice.

    Pipeline:
    1. Threshold at -200.
    2. Drop blobs touching the image border.
    3. Keep the two largest connected regions.
    4. Erode with a radius-2 disk.
    5. Close with a radius-10 disk.
    6. Fill holes inside the Roberts edge map of the result.

    Args:
        im: 2D slice. Values below -200 are candidate lung; this presumably
            assumes Hounsfield units (confirm with the caller). The array is
            modified IN PLACE.
        plot: when ``plot == True``, draw the 8 intermediate stages in a
            column of subplots.

    Returns:
        ``im`` itself, with pixels outside the final mask set to 0.
    """
    show = plot == True

    if show:
        _, axes = plt.subplots(8, 1, figsize=(5, 40))

    def show_stage(position, picture):
        # Draw one intermediate result; a no-op unless plotting was requested.
        if show:
            axes[position].axis('off')
            axes[position].imshow(picture, cmap=plt.cm.bone)

    # 1) Binarise.
    mask = im < -200
    show_stage(0, mask)

    # 2) Remove blobs connected to the image border.
    mask = clear_border(mask)
    show_stage(1, mask)

    # 3) Label connected components.
    regions_img = label(mask)
    show_stage(2, regions_img)

    # 4) Zero every component smaller than the second-largest one, so only
    #    the two largest areas (plus any ties) survive.
    region_areas = sorted(r.area for r in regionprops(regions_img))
    if len(region_areas) > 2:
        min_keep = region_areas[-2]
        for region in regionprops(regions_img):
            if region.area < min_keep:
                regions_img[region.coords[:, 0], region.coords[:, 1]] = 0
    mask = regions_img > 0
    show_stage(3, mask)

    # 5) Erode with a radius-2 disk to separate nodules attached to blood vessels.
    mask = binary_erosion(mask, disk(2))
    show_stage(4, mask)

    # 6) Close with a radius-10 disk to keep nodules attached to the lung wall.
    mask = binary_closing(mask, disk(10))
    show_stage(5, mask)

    # 7) Fill the small holes inside the lung mask.
    mask = ndi.binary_fill_holes(roberts(mask))
    show_stage(6, mask)

    # 8) Superimpose the mask on the input slice (in place).
    im[mask == 0] = 0
    show_stage(7, im)

    return im



### Helper functions for loading and preprocessing data

In [None]:
def load_and_prepare_data(add_pixel_stats=True):
    """Load the OSIC CSV files and build model-ready train/test tables.

    Train rows gain the following columns:
    - ``base_Weeks``: the patient's earliest week.
    - ``base_FVC`` / ``base_Percent``: FVC and Percent measured in that week.
    - ``Week_passed``: weeks since that baseline.

    Test rows are built from the sample submission: one row per
    ``Patient_Week`` target, joined with the patient's test.csv visit
    (renamed to the ``base_*`` columns). They are restricted to the train
    columns minus ``FVC``/``Percent``.

    Args:
        add_pixel_stats: when True, merge the CT histogram statistics
            (kurtosis/std/mean/median). Train statistics come from a
            precomputed CSV; test statistics are computed with
            ``get_kurtosis_stats``.

    Returns:
        ``(train, test)`` DataFrames.
    """
    data_dir = '../input/osic-pulmonary-fibrosis-progression'
    train = pd.read_csv(f'{data_dir}/train.csv')
    test = pd.read_csv(f'{data_dir}/test.csv')

    # --- Train: attach each patient's baseline (earliest-week) measurements ---
    train['base_Weeks'] = train.groupby(['Patient'])['Weeks'].transform('min')
    base = train[train.Weeks == train.base_Weeks]
    base = base.rename(columns={'FVC': 'base_FVC', 'Percent': 'base_Percent'})
    # Guard against several rows for the same (Patient, baseline week).
    base.drop_duplicates(subset=['Patient', 'Weeks'], keep='first', inplace=True)
    train = train.merge(base[['Patient', 'base_FVC', 'base_Percent']], on='Patient', how='left')
    train['Week_passed'] = train['Weeks'] - train['base_Weeks']

    # --- Test: treat each test.csv visit as the patient's baseline ---
    test = test.rename(columns={'Weeks': 'base_Weeks', 'FVC': 'base_FVC', 'Percent': 'base_Percent'})
    # Each sample-submission ID is "<Patient>_<Weeks>"; read the file once here.
    submission = pd.read_csv(f'{data_dir}/sample_submission.csv')
    submission['Patient'] = submission['Patient_Week'].apply(lambda x: x.split('_')[0])
    submission['Weeks'] = submission['Patient_Week'].apply(lambda x: x.split('_')[1]).astype(int)
    test = submission.drop(columns=['FVC', 'Confidence']).merge(test, on='Patient')
    test['Week_passed'] = test['Weeks'] - test['base_Weeks']
    test = test[train.columns.drop(['FVC', 'Percent'])]

    if add_pixel_stats:
        stat_cols = ['Patient', 'kurtosis', 'std', 'mean', 'median']
        pixel_stats = pd.read_csv('../input/osic-histogram-features/train_pixel_stats.csv')
        train = train.merge(pixel_stats[stat_cols], how='left', on='Patient')

        test_ids = test.Patient.unique()
        ct_scans_dir = Path('/kaggle/input/osic-pulmonary-fibrosis-progression') / 'test'
        # One row per patient, in first-appearance order (the same order as
        # test_ids), so the per-patient statistics below line up row by row.
        pixel_stats_test = test.drop_duplicates(subset=['Patient']).copy()
        # NOTE(review): get_kurtosis_stats is not defined in this notebook
        # section; it is assumed to return four per-patient sequences.
        k, s, m, me = get_kurtosis_stats(test_ids, ct_scans_dir)
        pixel_stats_test['kurtosis'] = np.array(k)
        pixel_stats_test['std'] = np.array(s)
        pixel_stats_test['mean'] = np.array(m)
        pixel_stats_test['median'] = np.array(me)
        test = test.merge(pixel_stats_test[stat_cols], how='left', on='Patient')
    return train, test

def OH_encode(train, test):
    """One-hot encode ``Sex`` and ``SmokingStatus`` in both frames, in place.

    One 0/1 integer column is added per category seen in ``train`` (named
    after the category value) to both ``train`` and ``test``. The original
    categorical columns are then dropped. Categories that only appear in
    ``test`` get no column.
    (Scheme from https://www.kaggle.com/ulrich07/osic-keras-starter-with-custom-metrics.)

    Returns:
        The same ``(train, test)`` objects, modified in place.
    """
    for column in ('Sex', 'SmokingStatus'):
        for level in train[column].unique():
            for frame in (train, test):
                frame[level] = (frame[column] == level).astype(int)
        for frame in (train, test):
            frame.drop(column, axis=1, inplace=True)
    return train, test

def Scale(train):
    """Fit a RobustScaler on the feature columns of ``train`` and return it.

    Every column except Patient, FVC, Percent, Weeks and base_Weeks is a
    feature. As a side effect, those columns of ``train`` are overwritten in
    place with their scaled values.

    Returns:
        The fitted ``sklearn.preprocessing.RobustScaler``.
    """
    from sklearn import preprocessing

    feature_cols = train.columns.difference(['Patient', 'FVC', 'Percent', 'Weeks', 'base_Weeks'])
    scaler = preprocessing.RobustScaler()
    train.loc[:, feature_cols] = scaler.fit_transform(train.loc[:, feature_cols])
    return scaler

### Pytorch Dataset class
Tidied up compared with some of my other notebooks!

In [None]:
class OSIC(Dataset):
    """Patient-level dataset pairing OSIC tabular features with CT volumes.

    There is one item per distinct patient (``self.patients``).

    Training items are dicts with:
    - ``data``: the patient's first tabular row.
    - ``image``: the CT volume, zero-padded or cropped to (1, 50, 224, 224).
    - ``allfvc``: length-146 FVC array indexed by ``Weeks + 12``, zero where
      unmeasured.
    - ``base_fvc``: FVC of the patient's first row.

    Non-training items contain only ``data``.
    """

    def __init__(self, patient_ids, df, scaler=None, train=True, add_pixel_stats=True, impute_vals=None):
        """Select ``patient_ids`` from ``df`` and scale/impute the tabular features.

        Args:
            patient_ids: patients to include.
            df: visit-level table with ``Patient`` and ``Weeks`` columns
                (plus ``FVC`` when ``train`` is True).
            scaler: fitted transformer applied to every column except
                Patient/FVC/Percent/Weeks/base_Weeks. ``None`` skips scaling.
            train: True for labelled data. NaN-imputation means are then
                computed here; otherwise ``impute_vals`` is used.
            add_pixel_stats: unused; accepted for call-site compatibility.
            impute_vals: per-feature NaN fill values used when ``train`` is False.
        """
        root_dir = Path('/kaggle/input/osic-pulmonary-fibrosis-progression')
        # Stored on the instance: __getitem__ needs it to read raw DICOMs.
        self.ct_scans_dir = root_dir / ('train' if train else 'test')

        self.df = df.loc[df.Patient.isin(patient_ids), :].copy()
        if scaler is not None:
            feature_cols = self.df.columns.difference(['Patient', 'FVC', 'Percent', 'Weeks', 'base_Weeks'])
            self.df.loc[:, feature_cols] = scaler.transform(self.df.loc[:, feature_cols])
        self.data = self.df.loc[:, self.df.columns.difference(['FVC', 'Patient', 'Percent'])].values

        self.impute_vals = np.nanmean(self.data, axis=0) if train else impute_vals
        nan_rows, nan_cols = np.where(np.isnan(self.data))
        self.data[nan_rows, nan_cols] = np.take(self.impute_vals, nan_cols)

        self.patients = self.df['Patient'].unique()
        # Positional row indices of each patient's visits, in df order.
        patient_col = self.df['Patient'].values
        self.patient_rows = {pid: np.flatnonzero(patient_col == pid) for pid in self.patients}

        self.train = train
        if self.train:
            self.fvc = self.df['FVC'].values

    def __len__(self):
        """Number of distinct patients."""
        return len(self.patients)

    def __getitem__(self, idx):
        """Return the item for ``self.patients[idx]`` (see class docstring)."""
        patient_id = self.patients[idx]
        rows = self.patient_rows[patient_id]
        # idx indexes patients, not visit rows: use the patient's first row.
        features = self.data[rows[0]]
        if not self.train:
            return {'data': features}

        image = self._load_image(patient_id)
        # NOTE(review): HU volumes built from raw DICOMs are mostly negative
        # and wrap under this cast; presumably the cached .npy volumes are
        # already in 0-255. Confirm.
        image = image.astype(np.uint8)
        padded_image = np.zeros((1, 50, 224, 224))
        max_ind_i = min(50, image.shape[0])
        max_ind_j = min(224, image.shape[1])
        max_ind_k = min(224, image.shape[2])
        padded_image[:, :max_ind_i, :max_ind_j, :max_ind_k] = image[:max_ind_i, :max_ind_j, :max_ind_k]

        # Weeks are shifted by 12 so that index 0 corresponds to week -12.
        all_fvc = np.zeros(146)
        all_fvc[self.df['Weeks'].values[rows] + 12] = self.fvc[rows]
        return {'data': features,
                'image': padded_image,
                'allfvc': all_fvc,
                'base_fvc': self.fvc[rows[0]]}

    def _load_image(self, patient_id):
        """Return the patient's CT volume, preferring the preprocessed ``.npy``.

        Otherwise the volume is built from raw DICOMs. If that fails for any
        reason, a warning is issued and a blank (1, 224, 224) volume is
        returned, so one bad scan does not abort an epoch.
        """
        cached = '/kaggle/input/processed-osic/processed/' + patient_id + '.npy'
        if os.path.isfile(cached):
            return np.load(cached)
        try:
            scan = load_scan(self.ct_scans_dir / patient_id)
            image = get_pixels_hu(scan)
            image, _ = resample(image, scan, new_spacing=[1, 1, 1])
            # NOTE(review): segment_lung_mask is not defined in this notebook
            # section; if it is missing, the warning below fires for every scan.
            return segment_lung_mask(image, False)
        except Exception as err:  # deliberate best-effort fallback, now visible
            import warnings
            warnings.warn(f'Could not load CT scan for patient {patient_id}: {err!r}; using a blank volume.')
            return np.zeros((1, 224, 224))

### Functions to train and validate a single epoch (pass over the data)

In [None]:
def train_epoch(model, optimizer, train_loader, KL_weight):
    """Run one training pass over ``train_loader``.

    Each batch is optimised on ``MSE + KL_weight * KLD``. ``KL_weight`` is
    the caller's KL warm-up factor.

    Args:
        model: the VAE; switched to training mode here.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: yields dicts whose ``'image'`` entry is the input batch.
        KL_weight: multiplier applied to the KL term.

    Returns:
        ``(model, optimizer, epoch_loss, epoch_mse, epoch_kld)``. The three
        losses are sums of the per-batch values over the epoch.
    """
    model.train()
    epoch_loss = epoch_mse = epoch_kld = 0
    for batch in train_loader:
        img = batch['image'].float().to(device)
        optimizer.zero_grad()
        X_reconst, _, mu, logvar = model(img)
        # vae_loss_function returns unweighted terms; the KL weighting
        # (warm-up) is applied here.
        MSE, KLD = vae_loss_function(X_reconst, img, mu, logvar, KL_weight)
        loss = KL_weight * KLD + MSE
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_mse += MSE.item()
        epoch_kld += KLD.item()
    return model, optimizer, epoch_loss, epoch_mse, epoch_kld

def val_epoch(model, val_loader):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    The loss is the unweighted ``MSE + KLD``. The model is put in eval mode,
    where its ``reparameterize`` returns ``mu``, so ``all_z`` equals ``all_mu``.

    Returns:
        ``(epoch_loss, epoch_mse, epoch_kld, all_y, all_z, all_mu,
        all_logvar)``. The losses are sums of per-batch values. The lists
        hold one entry per sample: ``all_y`` holds the batch's ``'base_fvc'``
        values, and the others hold the latent vectors.
    """
    model.eval()
    epoch_loss = epoch_mse = epoch_kld = 0
    all_y, all_z, all_mu, all_logvar = [], [], [], []
    # No gradients are needed here; disabling autograd avoids keeping every
    # activation of the 3D volumes alive during validation.
    with torch.no_grad():
        for batch in val_loader:
            img = batch['image'].float().to(device)
            X_reconst, z, mu, logvar = model(img)
            MSE, KLD = vae_loss_function(X_reconst, img, mu, logvar)
            loss = KLD + MSE
            epoch_loss += loss.item()
            epoch_mse += MSE.item()
            epoch_kld += KLD.item()
            all_y.extend(batch['base_fvc'].cpu().numpy())
            all_z.extend(z.cpu().numpy())
            all_mu.extend(mu.cpu().numpy())
            all_logvar.extend(logvar.cpu().numpy())
    return epoch_loss, epoch_mse, epoch_kld, all_y, all_z, all_mu, all_logvar

## Outer loop to train a model

In [None]:
def train_model(ids, train, max_epochs=100, patience=40, batch_size=1, plot_losses=True, add_pixel_stats=True):
    """Train a VariationalAutoEncoder on a random 90/10 patient split.

    Args:
        ids: patient ids to split. Not modified; a shuffled copy is used.
        train: prepared visit-level table.
        max_epochs: maximum number of epochs.
        patience: stop once the validation loss has not improved for this
            many consecutive epochs.
        batch_size: DataLoader batch size.
        plot_losses: plot the per-epoch average train/val losses at the end.
        add_pixel_stats: forwarded to ``OSIC``.

    The KL term is disabled for epochs 1-4. After that it is weighted by
    ``epoch / max_epochs`` (KL warm-up).

    Returns:
        ``(model, scaler, impute_vals, train_ids, val_ids, all_y, all_z,
        all_mu, all_logvar)``. The last four come from the final validation
        pass and are empty if no epoch ran.
    """
    # Shuffle a copy so the caller's array is left untouched.
    ids = np.array(ids, copy=True)
    np.random.shuffle(ids)
    train_ids, val_ids = np.split(ids, [int(round(0.9 * len(ids), 0))])

    scaler = Scale(train.loc[train.Patient.isin(train_ids), :].copy())
    train_dataset = OSIC(train_ids, train, scaler=scaler, add_pixel_stats=add_pixel_stats)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
    val_dataset = OSIC(val_ids, train, scaler=scaler, impute_vals=train_dataset.impute_vals,
                       add_pixel_stats=add_pixel_stats)
    # Sample order does not change the validation loss, so no shuffling.
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

    model = VariationalAutoEncoder().to(device)
    print('Number of parameters:')
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))

    optimizer = optim.Adam(model.parameters(), lr=5e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.05)

    epoch_train_loss, epoch_val_loss = [], []
    all_y, all_z, all_mu, all_logvar = [], [], [], []
    # Per-batch losses are already batch means, so epoch averages divide by
    # the number of batches only (guarding against an empty loader).
    n_train_batches = max(len(train_dataloader), 1)
    n_val_batches = max(len(val_dataloader), 1)
    min_val_loss = float('inf')
    stale_epochs = 0

    for epoch in range(1, max_epochs + 1):
        KL_weight = 0 if epoch < 5 else epoch / max_epochs
        model, optimizer, train_loss, train_mse, train_kld = train_epoch(
            model, optimizer, train_dataloader, KL_weight)
        scheduler.step()
        print('====> Epoch: {} Average train loss: {:.4f}'.format(epoch, train_loss / n_train_batches))
        print('====> Epoch: {} Average train mse: {:.4f}'.format(epoch, train_mse / n_train_batches))
        print('====> Epoch: {} Average train kld: {:.4f}'.format(epoch, train_kld / n_train_batches))

        val_loss, val_mse, val_kld, all_y, all_z, all_mu, all_logvar = val_epoch(model, val_dataloader)
        print('====> Epoch: {} Average val loss: {:.4f}'.format(epoch, val_loss / n_val_batches))
        print('====> Epoch: {} Average val mse: {:.4f}'.format(epoch, val_mse / n_val_batches))
        print('====> Epoch: {} Average val kld: {:.4f}'.format(epoch, val_kld / n_val_batches))

        # Store per-batch averages so train and val curves are comparable
        # despite having different numbers of batches.
        epoch_train_loss.append(train_loss / n_train_batches)
        epoch_val_loss.append(val_loss / n_val_batches)

        # Early stopping on the validation loss (its KL weight is always 1).
        if val_loss < min_val_loss:
            min_val_loss = val_loss
            stale_epochs = 0
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                print('====> Early stopping at epoch {}: no val loss improvement for {} epochs'.format(
                    epoch, patience))
                break

    if plot_losses:
        plot_training_loss(epoch_train_loss, epoch_val_loss)
    return model, scaler, train_dataset.impute_vals, train_ids, val_ids, all_y, all_z, all_mu, all_logvar

In [None]:
# Build the train/test tables without CT pixel statistics, one-hot encode the
# categorical columns, then train on all training patients.
train, test = load_and_prepare_data(add_pixel_stats=False)
train, test = OH_encode(train, test)

ids = train['Patient'].unique()
(model, scaler, impute_vals, train_ids, val_ids,
 all_y, all_z, all_mu, all_logvar) = train_model(
    ids,
    train,
    max_epochs=100,
    patience=40,
    batch_size=4,
    plot_losses=True,
    add_pixel_stats=True,
)

## Inspection
When we wish to inspect the output of the model, we can take just the mean of the encoded distribution as the input to the decoder.

In [None]:
slc = 0.5  # fraction of the way through the depth axis to display
val_dataset = OSIC(val_ids, train, scaler=scaler, impute_vals=impute_vals)
sample_id = np.random.randint(len(val_dataset))
print(f'Inspecting CT Scan {val_dataset.patients[sample_id]}')
# Build the item once; loading a volume (cached .npy or raw DICOMs) is costly.
sample_item = val_dataset[sample_id]

fig, axs = plt.subplots(1, 2, figsize=(10, 7))

sample = sample_item['image'].squeeze(0)
# NOTE(review): 40 looks like a leftover depth. OSIC pads volumes to 50
# slices, so this shows slice 20 rather than the middle one.
axs[0].imshow(sample[int(40 * slc), :, :], cmap=cm.bone)
axs[0].axis('off')
imageio.mimsave("sample_input.gif", sample, duration=0.0001)

# In eval mode the model decodes the posterior mean instead of a random
# sample, so the reconstruction is deterministic.
model.eval()
with torch.no_grad():
    img = torch.tensor(sample_item['image']).unsqueeze(0).float().to(device)
    latent_features = model.encode(img, return_partials=False).squeeze().cpu().numpy().tolist()
    outputs = model(img)[0].squeeze().cpu().numpy()

axs[1].imshow(outputs[int(40 * slc), :, :], cmap=cm.bone)
axs[1].axis('off')

imageio.mimsave("sample_output.gif", outputs, duration=0.0001)

# Mean squared reconstruction error over the whole volume (not a root MSE).
mse = ((sample - outputs) ** 2).mean()
plt.show()
print(f'Latent features: {latent_features} \nLoss: {mse}')

In [None]:
# Mean squared error between reconstruction and input on the displayed slice.
slice_idx = int(40 * slc)
sq_err = (outputs[slice_idx, :, :] - sample[slice_idx, :, :]) ** 2
(sq_err / sq_err.size).sum()

In [None]:
from IPython.display import HTML
# Show the pre-rendered input/output GIFs side by side.
gif_markup = (
    '<br/><img src="https://i.ibb.co/gFxgRq6/sample-input.gif" style="float: left; width: 30%; margin-right: 1%; margin-bottom: 0.5em;">'
    '<img src="https://i.ibb.co/Jm57fWw/sample-output.gif" style="float: left; width: 30%; margin-right: 1%; margin-bottom: 0.5em;">'
    '<p style="clear: both;">'
)
HTML(gif_markup)

## What else might we try?


### Identify clusters of patients in the latent space?


In [None]:
from sklearn.manifold import TSNE

z_points = np.asarray(all_z)
# t-SNE requires perplexity < n_samples, and the validation split can be
# smaller than scikit-learn's default perplexity of 30.
perplexity = min(30.0, max(len(z_points) - 1, 1))
tsne_kwargs = dict(n_components=2, perplexity=perplexity)
try:
    tsne = TSNE(max_iter=12000, **tsne_kwargs)  # scikit-learn >= 1.5
except TypeError:
    tsne = TSNE(n_iter=12000, **tsne_kwargs)  # older scikit-learn
z_embed = tsne.fit_transform(z_points)

fig = plt.figure(figsize=(12, 10))
# Colour each patient's embedded latent code by its baseline FVC.
plt.scatter(z_embed[:, 0], z_embed[:, 1], c=all_y, s=8)

plt.axis('off')
plt.title('t-SNE: 2-dim')
plt.show()

### Identify anomalous images 

If a patient's scan lies outside the usual distribution of lungs, its reconstruction error is likely to be higher. We could use this to flag patients who may follow an unusual pulmonary fibrosis (PF) trajectory.

Alternatively, with a 2D version of the VAE applied to either the full lung or patches of it, we could identify regions that are unusual and may show symptoms of PF.

TBC
