# Partial re-implementation of [Prior-aware autoencoders for lung pathology segmentation](https://www.sciencedirect.com/science/article/pii/S1361841522001384)

The method in this paper is made up of three models:

1. a Partial Convolutional Neural Network (PCNN), implemented as an autoencoder
2. a Normal Appearance Autoencoder model (NAA)
3. the Prior UNet which predicts the lung pathology segmentation mask.

This notebook implements models 2 and 3 as described in the paper, but uses the [Semantic Diffusion Model](https://github.com/WeilunWang/semantic-diffusion-model) in place of model 1.

In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from PIL import Image

## Module 1: Semantic Diffusion Model

In [17]:
# Module 1 is used for inference only: pathology-free images are generated with
# the already-trained SDM model by running /scripts/luna16.sh from inside the
# Semantic Diffusion Model directory.

# Input directories for the training and test data used below.
# Images that contain lesions:
LESION_PATH = "/home/user/data/prior_train/lesion"
# Lesion-free images:
LESION_FREE_PATH = "/home/user/data/prior_train/lesion-free"

class NAAImageDataset(Dataset):
    """Dataset pairing each lesion image with a lesion-free image of the same name.

    File names are listed from ``lesion_path``; for each name, the item is
    ``(lesion_image, lesion_free_image)`` loaded from ``lesion_path/<name>``
    and ``lesion_free_path/<name>``.

    NOTE(review): pairing by identical file name is assumed from the two
    directories this class receives -- confirm it matches how the SDM
    lesion-free images are written out.
    """

    def __init__(self, lesion_path, lesion_free_path):
        self.lesion_path = lesion_path
        self.lesion_free_path = lesion_free_path
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always maps to the same file across runs and machines.
        self.file_list = sorted(os.listdir(self.lesion_path))

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _load_grayscale(path):
        """Load ``path`` as a (1, H, W) float32 tensor with values in [0, 1]."""
        # Single channel because the NAA's first conv is built with
        # in_channels=1; [0, 1] scaling matches the model's sigmoid output.
        # NOTE(review): assumes 8-bit images -- confirm for 16-bit CT slices.
        with Image.open(path) as img:  # `with` closes the file handle
            pixels = np.asarray(img.convert("L"), dtype=np.float32) / 255.0
        return torch.from_numpy(pixels).unsqueeze(0)

    def __getitem__(self, idx):
        """Return ``(lesion_image, lesion_free_image)`` for the ``idx``-th file."""
        filename = self.file_list[idx]
        lesion_image = self._load_grayscale(os.path.join(self.lesion_path, filename))
        lesion_free_image = self._load_grayscale(
            os.path.join(self.lesion_free_path, filename)
        )
        return lesion_image, lesion_free_image

## Module 2: Normal Appearance Autoencoder

In [3]:
# Define the Normal Appearance Autoencoder (NAA). The intended architecture has an input layer of size 512, 5 hidden layers of size 256, 128, 64, 64, 64 and a sampling layer for mean and logvar of size 64.
# NOTE: the layers below instead use encoder channel widths 8, 16, 32, 64, 128 and 128-channel mean/logvar convolutions.
# First, define a convolution block: a conv2d layer, a batch-norm layer, a second conv2d layer, a second batch-norm layer and a ReLU, optionally followed by 2x2 max-pooling, and finally dropout (p=0.2).
class ConvBlock(nn.Module):
    """Two conv + batch-norm stages, a ReLU, optional 2x2 max-pooling, then dropout.

    Forward order: conv1 -> bn1 -> conv2 -> bn2 -> ReLU -> (MaxPool2d(2, 2) if
    ``max_pool``) -> Dropout(p=0.2).

    Args:
        in_channels: Channel count of the input tensor (consumed by ``conv1``).
        out_channels: Channel count produced by both convolutions.
        kernel_size, stride, padding, bias: Passed to both ``nn.Conv2d`` layers.
        max_pool: If True, halve the spatial size with 2x2 max-pooling.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=True, max_pool=True):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        self.bn1 = nn.BatchNorm2d(out_channels, eps=1e-03)
        # conv2 consumes bn1's output, which has out_channels channels. Using
        # in_channels here crashed whenever in_channels != out_channels.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=bias)
        self.bn2 = nn.BatchNorm2d(out_channels, eps=1e-03)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(2, 2) if max_pool else None
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Apply the block to ``x`` and return the transformed feature map."""
        x = self.bn1(self.conv1(x))
        x = self.relu(self.bn2(self.conv2(x)))
        if self.maxpool is not None:
            x = self.maxpool(x)
        return self.dropout(x)

class NAA(nn.Module):
    """Normal Appearance Autoencoder: a convolutional variational autoencoder.

    Encoder: five ``ConvBlock``s (1 -> 8 -> 16 -> 32 -> 64 -> 128 channels,
    max-pooling only in the first three). Two 3x3 convs then produce ``mean``
    and ``logvar``. A latent ``z`` is sampled with the reparameterization
    trick and decoded by alternating stride-2 transposed convs and
    ``ConvBlock``s back to 1 channel. The result goes through a 1x1 conv +
    sigmoid head.

    NOTE(review): ``mean`` and ``logvar`` pass through ReLU, forcing
    mean >= 0 and logvar >= 0 (std >= 1). That is unusual for a VAE; confirm
    it is intended.

    NOTE(review): every conv uses padding=0, and the encoder downsamples 3
    times while the decoder upsamples 5 times. Output spatial size therefore
    differs from input size. By the layer arithmetic, a 512x512 input gives a
    1507x1507 output, so a pixel-wise reconstruction loss will not line up
    without resizing.
    """

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            ConvBlock(1, 8, kernel_size=3, bias=True),
            ConvBlock(8, 16, kernel_size=3, bias=True),
            ConvBlock(16, 32, kernel_size=3, bias=True),
            ConvBlock(32, 64, kernel_size=3, bias=True, max_pool=False),
            ConvBlock(64, 128, kernel_size=3, bias=True, max_pool=False),
        )
        # Heads producing the parameters of the latent Gaussian.
        self.mean = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0, bias=True)
        self.logvar = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0, bias=True)
        self.mean_relu = nn.ReLU()
        self.logvar_relu = nn.ReLU()
        self.decoder = torch.nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=3, stride=2, padding=0, bias=True),
            nn.ReLU(),
            nn.Dropout(0.2),
            ConvBlock(128, 64, kernel_size=3, bias=True, max_pool=False),
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=0, bias=True),
            nn.ReLU(),
            nn.Dropout(0.2),
            ConvBlock(64, 32, kernel_size=3, bias=True, max_pool=False),
            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=0, bias=True),
            nn.ReLU(),
            nn.Dropout(0.2),
            ConvBlock(32, 16, kernel_size=3, bias=True, max_pool=False),
            nn.ConvTranspose2d(16, 16, kernel_size=3, stride=2, padding=0, bias=True),
            nn.ReLU(),
            nn.Dropout(0.2),
            ConvBlock(16, 8, kernel_size=3, bias=True, max_pool=False),
            nn.ConvTranspose2d(8, 8, kernel_size=3, stride=2, padding=0, bias=True),
            nn.ReLU(),
            nn.Dropout(0.2),
            ConvBlock(8, 1, kernel_size=3, bias=True, max_pool=False),
        )
        # Output head: 1x1 conv then sigmoid, so reconstructions lie in [0, 1].
        self.final = nn.Sequential(
            nn.Conv2d(1, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def reparameterize(self, mean, logvar):
        """Sample ``z = mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mean + eps*std

    def forward(self, x, return_stats=False):
        """Encode ``x``, sample a latent, decode it and apply the output head.

        Args:
            x: Input images. They must have one channel, because the first
                encoder conv is built with ``in_channels=1``.
            return_stats: If True, also return the latent ``mean`` and
                ``logvar`` (needed to compute a KL term). Defaults to False,
                which keeps the single-tensor return.

        Returns:
            The reconstruction after the 1x1 conv + sigmoid head, or
            ``(reconstruction, mean, logvar)`` when ``return_stats`` is True.
        """
        features = self.encoder(x)
        mean = self.mean_relu(self.mean(features))
        logvar = self.logvar_relu(self.logvar(features))
        z = self.reparameterize(mean, logvar)
        # self.final was previously built but never applied, so outputs were
        # not passed through the sigmoid.
        reconstruction = self.final(self.decoder(z))
        if return_stats:
            return reconstruction, mean, logvar
        return reconstruction

## Module 3: Prior UNet