<h1 class="text-center">Generative Modeling of Neuroimaging Data using Generative
Adversarial Networks</h1>
<br>
<h6 class="text-center">Andrew Van</h6>
<h6 class="text-center">Advisor: Nico Dosenbach</h6>
<script type="text/javascript">
$(window).load(function(){
    Reveal.configure({
        transition: 'fade' // none/fade/slide/convex/concave/zoom
    })
});    
</script>

from IPython.display import HTML
HTML('''<script>
code_show=true; 
function code_toggle() {
 if (code_show){
   $('div.input').hide();
   $('div.prompt.output_prompt').css('opacity', 0);
 } else {
   $('div.input').show();
   $('div.prompt.output_prompt').css('opacity', 1);
 }
 code_show = !code_show
} 
$(document).ready(code_toggle);
</script>
<a href="javascript:code_toggle()"><button>Toggle Code</button></a>
''')

<h3 class="text-center">Overview</h3>

- What are GANs?
    - Theory
    - Limitations
- Model Training
    - Progressive GAN and Wasserstein Distance (How do these address limitations?)
    - Results
- Applications
    - Reconstruction in Underdetermined Systems
    - Anomaly Detection
- Future Directions

<h3 class="text-center">What are Generative Adversarial Networks (GANs)?</h3>

<h3 class="text-center">What are GANs?</h3>

- Take two networks and train them in an adversarial manner...

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/GANdiag.png" width=60% height=auto>
</div>
<p style="text-align: center">GAN Training Process</p>

<h3 class="text-center">What are GANs?</h3>

- ...and they can make really convincing fake images.

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/stylegan-teaser.png" width=90% height=auto>
</div>
<p style="text-align: center">These are not real people! <a href="https://arxiv.org/abs/1812.04948" target="_blank">StyleGAN</a></p>

<h3 class="text-center">Fake Images? So What?</h3>

- GANs learn the underlying process that generated our dataset.
    - Creating realistic fake examples is a byproduct of a well-trained model.
- A data-driven prior!
        
<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/imageinpainting.png" width=75% height=auto>
</div>
<p style="text-align: center"><a href="https://arxiv.org/abs/1604.07379" target="_blank">Image Inpainting</a></p>

<h3 class="text-center">What is Generative Modeling?</h3>

- Discriminative models
    - Map high-dimensional data to a class label
    - e.g. Given an image, is it a cat/dog?
- Generative models answer a somewhat opposite question
    - Given a sample drawn from a distribution, we want to learn an estimate of that distribution
    - e.g. Given a set of cat and dog images, learn the distribution that generates the images of each class.

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/density_estimation.png" width=75% height=auto>
</div>
<p><center>Generative Modeling: Estimating a distribution from samples</center></p>

<h3><center>Taxonomy of generative modeling</center></h3>

<div>
    <img style="display:block;margin:auto;" src="images/taxonomy.png" width=80% height=auto>
</div>
<p><center>Types of generative models (maximum likelihood)</center></p>

<h3 class="text-center">GANs compared to other generative models</h3>

- GANs are implicit density models
    - A black box: we don't have access to the density itself; we can only draw samples from it.
- Explicit models define the density explicitly, with a known parameterization.
    - The difficulty lies in building an explicit model complex enough to capture the complexity of the data while remaining computationally tractable.
        - Model complexity must be bounded to keep it tractable
    - e.g. Fully Visible Belief Networks (FVBNs), Variational Autoencoders (VAEs)
- Advantages of GANs
    - More computationally tractable than explicit models
        - No Markov chains needed
        - Can be trained with back-propagation
    - No variational bound or approximate inference needed; in principle, can learn any data distribution (given enough examples and model capacity).
    - Subjectively, the research community has found that GANs produce better samples than other generative methods.

<h3 class="text-center">GAN Theory</h3>

- Two networks, a generator and a discriminator, set up in a zero-sum game
    - Discriminator Objective: Pick out fake from real
    - Generator Objective: Create a fake indistinguishable from real

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/GANdiag.png" width=60% height=auto>
</div>
<p style="text-align: center">GAN Training Process</p>

<h3 class="text-center">GAN Theory</h3>

- This is a two-player minimax game:

$$ \min_{G} \max_{D} \underset{\mathbf{x} \sim \mathbb{P}_{r}}{\mathbb{E}}[\log D(\mathbf{x})] + \underset{\mathbf{z} \sim \mathbb{P}_{\mathbf{z}}}{\mathbb{E}}[\log (1 - D(G(\mathbf{z})))] $$

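- From game theory, training seeks the Nash equilibrium of this game
    - The Nash equilibrium is reached when $\mathbb{P}_{g} = \mathbb{P}_{r}$ (where $\mathbb{P}_{g}$ is the distribution of generated samples $G(\mathbf{z})$) and $D(\mathbf{x}) = 0.5$ for all inputs.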
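As a small numeric check of the minimax objective, the sketch below evaluates $V(D, G)$ on toy discrete distributions (the probabilities are made up for illustration). It uses the known result that, for a fixed generator, the optimal discriminator is $D^{*}(\mathbf{x}) = p_{r}(\mathbf{x}) / (p_{r}(\mathbf{x}) + p_{g}(\mathbf{x}))$; when $\mathbb{P}_{g} = \mathbb{P}_{r}$ this is 0.5 everywhere and the value reaches its minimum over generators, $-\log 4$.

```python
import math


def gan_value(p_r, p_g, d):
    """V(D, G) for discrete distributions on a shared finite support.

    p_r, p_g: probability of each support point under the real and generated
    distributions; d: discriminator output D(x) in (0, 1) at each point.
    """
    real_term = sum(pr * math.log(dx) for pr, dx in zip(p_r, d))
    fake_term = sum(pg * math.log(1.0 - dx) for pg, dx in zip(p_g, d))
    return real_term + fake_term


def optimal_discriminator(p_r, p_g):
    """D*(x) = p_r(x) / (p_r(x) + p_g(x)), the best response to a fixed generator."""
    return [pr / (pr + pg) for pr, pg in zip(p_r, p_g)]


p_r = [0.5, 0.3, 0.2]  # toy "real" distribution
p_g = [0.2, 0.3, 0.5]  # toy generator that has not converged yet
print(gan_value(p_r, p_g, optimal_discriminator(p_r, p_g)))  # above -log 4

d_eq = optimal_discriminator(p_r, p_r)  # equilibrium: P_g == P_r
print(d_eq, gan_value(p_r, p_r, d_eq), -math.log(4))  # D = 0.5 everywhere
```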

<h3 class="text-center">Limitations</h3>

- Difficult to achieve Nash Equilibrium
- Low Support
- Vanishing Gradients
- Mode Collapse
- Lack of Evaluation Metric

<h3 class="text-center">Difficult to achieve Nash Equilibrium</h3>

- Minimax optimization may not converge stably
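- Example: $\min_{x} \max_{y} f(x, y) = xy$
    - The gradient for $x$ is $y$ and the gradient for $y$ is $x$, so each player's update direction flips with the other's sign: the iterates oscillate around the equilibrium $(0, 0)$ instead of converging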

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/nash_equilibrium.png" width=80% height=auto>
</div>
<p style="text-align: center">Instability of objective functions</p>
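A minimal sketch of the oscillation, assuming simultaneous gradient steps with a fixed step size `lr` (both choices are illustrative, not from the slides). Since $\partial f / \partial x = y$ and $\partial f / \partial y = x$, each update pushes the other variable around the equilibrium; with simultaneous updates the distance from $(0, 0)$ actually grows by a factor of $\sqrt{1 + \text{lr}^{2}}$ per step.

```python
import math


def simultaneous_gda(x, y, lr=0.1, steps=100):
    """Gradient descent on x and ascent on y for f(x, y) = x * y, updated simultaneously.

    df/dx = y and df/dy = x. Returns the trajectory as a list of (x, y) points.
    """
    path = [(x, y)]
    for _ in range(steps):
        x, y = x - lr * y, y + lr * x  # both updates use the previous values
        path.append((x, y))
    return path


path = simultaneous_gda(1.0, 1.0)
radii = [math.hypot(px, py) for px, py in path]
print(radii[0], radii[-1])  # distance from the equilibrium (0, 0) grows
```

Alternating the updates (using the new $x$ when updating $y$) keeps the iterates on a bounded orbit instead, but they still circle the equilibrium rather than converging.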

<h3 class="text-center">Low Support</h3>

- When the real and generated distributions lie on low-dimensional manifolds of the high-dimensional image space, their supports are (with very high probability) disjoint
    - A perfect discriminator then exists that separates them exactly

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/low_dim_manifold.png" width=80% height=auto>
</div>
<p style="text-align: center">Distributions of both networks lie in a lower dimensional space</p>

<h3 class="text-center">Vanishing Gradients</h3>

- Low support lets the discriminator improve too fast
    - If the discriminator becomes perfect too early, $\log(1 - D(G(\mathbf{z})))$ saturates and the generator receives (nearly) zero gradients to backpropagate
- So slow down discriminator learning?
    - If the discriminator is poor, the generator receives inaccurate feedback and doesn't learn accurately

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/GAN_vanishing_gradient.png" width=60% height=auto>
</div>
<p style="text-align: center">Vanishing Gradients</p>
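The saturation can be seen directly, assuming the discriminator outputs $D = \sigma(a)$ for a logit $a$ (a common but assumed parameterization). The generator's term $\log(1 - \sigma(a))$ has derivative $-\sigma(a)$ with respect to $a$, which vanishes as the discriminator confidently rejects the fake ($a \to -\infty$). The sketch prints only this factor of the chain rule, not a full backward pass.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def generator_grad(a):
    """d/da of log(1 - D), with D = sigmoid(a) the discriminator's output on a fake.

    Simplifies to -sigmoid(a); the chain rule multiplies this factor into G's gradient.
    """
    return -sigmoid(a)


# The more confidently the discriminator rejects the fake (a -> -inf, D -> 0),
# the closer this factor gets to zero.
for logit in (0.0, -2.0, -5.0, -10.0):
    print(f"D(G(z)) = {sigmoid(logit):.2e}   d/da log(1 - D) = {generator_grad(logit):.2e}")
```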

<h3 class="text-center">Mode Collapse</h3>

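- The generator maps many different latent inputs $\mathbf{z}$ to the same (or nearly the same) output
- Why does it happen?
    - One hypothesis: simultaneous gradient descent may behave like $\max_{D} \min_{G}$ rather than $\min_{G} \max_{D}$, and the inner minimization then maps every $\mathbf{z}$ to the single output the discriminator currently finds most realistic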

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/mode_collapse.png" width=100% height=auto>
</div>
<p style="text-align: center">Helvetica Scenario</p>

<h3 class="text-center">Lack of Evaluation Metric</h3>

- It is currently unclear how to evaluate GANs quantitatively in a universal way
- Popular measures in GAN papers are the Fréchet Inception Distance (FID) and the Inception Score (IS)
    - Both rely on an Inception network trained on natural images, so they are ill-suited to medical imaging applications
- [Borji 2018](https://arxiv.org/abs/1802.03446) gives a comprehensive review of proposed evaluation measures (will touch on this again)
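For reference, FID compares Gaussians fitted to Inception features of real and generated images: $\lVert \mu_{r} - \mu_{g} \rVert_{2}^{2} + \operatorname{Tr}(\Sigma_{r} + \Sigma_{g} - 2(\Sigma_{r}\Sigma_{g})^{1/2})$. The sketch below assumes diagonal covariances (standard FID uses full covariance matrices), in which case the trace term reduces to $\sum_{j} (\sigma_{r,j} - \sigma_{g,j})^{2}$. The feature vectors here are arbitrary toy lists; in practice they would come from the Inception network.

```python
import statistics


def frechet_distance_diag(real_feats, fake_feats):
    """Frechet distance between Gaussians fitted to two sets of feature vectors.

    Simplifying assumption: diagonal covariances, so the trace term
    Tr(S_r + S_g - 2 (S_r S_g)^(1/2)) reduces to sum_j (sd_r[j] - sd_g[j])^2.
    """
    total = 0.0
    for j in range(len(real_feats[0])):
        r = [feat[j] for feat in real_feats]
        g = [feat[j] for feat in fake_feats]
        total += (statistics.mean(r) - statistics.mean(g)) ** 2  # mean term
        total += (statistics.stdev(r) - statistics.stdev(g)) ** 2  # covariance term
    return total


real = [[0.0, 1.0], [1.0, 3.0], [2.0, 2.0]]
shifted = [[a + 1.0, b + 1.0] for a, b in real]  # same spread, mean moved by 1 per dim
# Identical sets give 0; a shift of 1 in each of the 2 dimensions gives 1 + 1 = 2.
print(frechet_distance_diag(real, real), frechet_distance_diag(real, shifted))
```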

<h3 class="text-center">Model Training</h3>

<h3 class="text-center">Dataset Characterization</h3>

- Human Connectome Project (HCP) 1200 Young Adult Dataset
- 1113 subjects with T1-weighted (T1) scans, 1783 unique volumes
    - 1707 unique volumes after QC (issue code A)
- 256 × 256 × 320 volumes, 0.7 mm voxels
- Volumes unstacked in the axial direction
    - 256 slices of 256 × 320, zero-padded to 512 × 512
- Total dataset size: 546,240 T1 slices at 512 × 512 resolution
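A sketch of the zero-padding step, assuming each slice is stored as a list of rows and that the padding is centered (the slides do not say where the zeros are placed). Under that assumption, a 256 × 320 slice padded to 512 × 512 gets 128 rows of zeros above and below and 96 columns on each side.

```python
def zero_pad_center(slice2d, out_h=512, out_w=512):
    """Zero-pad a 2-D slice (a list of equal-length rows) to out_h x out_w, centered."""
    h, w = len(slice2d), len(slice2d[0])
    if h > out_h or w > out_w:
        raise ValueError("slice is larger than the target size")
    top, left = (out_h - h) // 2, (out_w - w) // 2
    padded = [[0.0] * out_w for _ in range(out_h)]
    for i, row in enumerate(slice2d):
        padded[top + i][left:left + w] = row  # copy the slice row into the zero canvas
    return padded


axial_slice = [[1.0] * 320 for _ in range(256)]  # stand-in for one 256 x 320 slice
padded = zero_pad_center(axial_slice)
print(len(padded), len(padded[0]))  # 512 512
```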

<h3 class="text-center">Progressive GAN</h3>

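- Grows the architecture of both networks during training, starting from 4 × 4 images and adding layers until reaching 512 × 512
    - Stabilizes training, since the networks first learn coarse structure and only later fine detail (fewer features to learn at a time)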

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/progarch.png" width=70% height=auto>
</div>
<p style="text-align: center">Progressive GAN Architecture</p>

<h3 class="text-center">Progressive GAN</h3>

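- Transition to the next resolution occurs smoothly
    - During the transition, the new layers act like a residual block: the upsampled lower-resolution image is passed to the output and blended with the new layers' output using a weight $\alpha$ that increases linearly from 0 to 1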

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/proggrow.png" width=70% height=auto>
</div>
<p style="text-align: center">Growth Block</p>
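A minimal sketch of the fade-in, following the Progressive GAN scheme: the previous resolution's output is upsampled by nearest-neighbor duplication and blended with the new layers' output. Images are single-channel lists of rows here for brevity; the real networks blend image tensors, and the helper names are hypothetical.

```python
def upsample_nearest(img):
    """2x nearest-neighbor upsampling of an image given as a list of rows."""
    out = []
    for row in img:
        wide = [v for v in row for _ in range(2)]  # duplicate each pixel horizontally
        out.append(wide)
        out.append(list(wide))  # duplicate each row vertically
    return out


def fade_in(low_res_out, high_res_out, alpha):
    """Blend the upsampled previous-resolution output with the new layers' output.

    alpha ramps linearly from 0 (old path only) to 1 (new layers only) during the transition.
    """
    skip = upsample_nearest(low_res_out)
    return [[(1.0 - alpha) * s + alpha * h for s, h in zip(skip_row, new_row)]
            for skip_row, new_row in zip(skip, high_res_out)]


low = [[1.0, 2.0], [3.0, 4.0]]       # 2 x 2 output of the previous stage
high = [[0.0] * 4 for _ in range(4)]  # 4 x 4 output of the newly added layers
print(fade_in(low, high, 0.25))
```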

<h3 class="text-center">Wasserstein Distance</h3>

- Instead of the original GAN loss function, we use the Wasserstein distance (Earth Mover's distance)
    - The Wasserstein distance remains meaningful even when the two distributions have disjoint supports, addressing low support and vanishing gradients
        - Also mitigates mode collapse?
    - Replaces the discriminator with a critic that outputs an unbounded score, giving a measure of how different the fake samples are from the real ones (in contrast to the original discriminator, which makes a binary real/fake choice)

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/WD.png" width=50% height=auto>
</div>
<p style="text-align: center">Wasserstein Distance</p>
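To see why this helps with low support, compare the two distances on point masses that never overlap. This is a minimal 1-D sketch assuming equal-size empirical samples (in 1-D, the optimal transport plan simply matches sorted samples). The Jensen-Shannon divergence, which the original GAN objective implicitly minimizes, stays at $\log 2$ however far apart the distributions are, while the Wasserstein distance shrinks as the fake samples move toward the real ones, giving the generator a usable signal.

```python
import math


def wasserstein_1d(xs, ys):
    """W1 distance between two 1-D empirical distributions with equal sample counts."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    # In 1-D the optimal transport plan matches the sorted samples in order.
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


def js_divergence(xs, ys):
    """Jensen-Shannon divergence (in nats) between the empirical distributions of xs and ys."""
    support = set(xs) | set(ys)
    p = {v: xs.count(v) / len(xs) for v in support}
    q = {v: ys.count(v) / len(ys) for v in support}
    m = {v: 0.5 * (p[v] + q[v]) for v in support}

    def kl(a, b):
        return sum(a[v] * math.log(a[v] / b[v]) for v in support if a[v] > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


real = [0.0, 0.0, 0.0]
for theta in (4.0, 1.0, 0.5):  # fake samples moving toward the real ones
    fake = [theta] * 3
    print(theta, js_divergence(real, fake), wasserstein_1d(real, fake))
```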

<h3 class="text-center">Wasserstein Distance with Gradient Penalty</h3>

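- The Wasserstein distance requires the critic to be 1-Lipschitz (gradient norm at most 1 everywhere); the optimal critic was shown to have gradient norm 1 almost everywhere between real and generated samples
    - So we push the critic's gradient norm toward 1
    - A gradient penalty term softly enforces this constraint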

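$$ \min_{G} \max_{D} \underset{\mathbf{x} \sim \mathbb{P}_{r}}{\mathbb{E}}[D(\mathbf{x})] - \underset{\mathbf{z} \sim \mathbb{P}_{\mathbf{z}}}{\mathbb{E}}[D(G(\mathbf{z}))] - \lambda \underset{\hat{\mathbf{x}} \sim \mathbb{P}_{\hat{\mathbf{x}}}}{\mathbb{E}}[(\lVert \nabla_{\hat{\mathbf{x}}} D(\hat{\mathbf{x}}) \rVert_{2} - 1)^{2}] $$

where $\hat{\mathbf{x}} = \epsilon \mathbf{x} + (1 - \epsilon) G(\mathbf{z})$ with $\epsilon \sim U[0, 1]$, i.e. $\mathbb{P}_{\hat{\mathbf{x}}}$ samples uniformly along straight lines between real and generated samples.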
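The penalty term can be estimated as sketched below, assuming a critic that maps a list of floats to a scalar. Real training would obtain $\nabla_{\hat{\mathbf{x}}} D(\hat{\mathbf{x}})$ by automatic differentiation in TensorFlow; this sketch uses central differences so it runs without a framework, and $\lambda = 10$ is the default from the WGAN-GP paper, not a value stated in these slides.

```python
import math
import random


def numerical_grad(f, x, h=1e-5):
    """Central-difference gradient of a scalar function f at point x (a list of floats)."""
    grad = []
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        grad.append((f(xp) - f(xm)) / (2.0 * h))
    return grad


def gradient_penalty(critic, real, fake, lam=10.0, rng=None):
    """Monte Carlo estimate of lam * E[(||grad D(x_hat)||_2 - 1)^2].

    x_hat = eps * x_real + (1 - eps) * x_fake with eps ~ U[0, 1], one eps per pair.
    """
    if rng is None:
        rng = random.Random(0)
    total = 0.0
    for xr, xf in zip(real, fake):
        eps = rng.random()
        x_hat = [eps * a + (1.0 - eps) * b for a, b in zip(xr, xf)]
        norm = math.sqrt(sum(g * g for g in numerical_grad(critic, x_hat)))
        total += (norm - 1.0) ** 2
    return lam * total / len(real)


def steep_critic(x):
    # Hypothetical linear critic with gradient (3, 4) everywhere, i.e. norm 5.
    return 3.0 * x[0] + 4.0 * x[1]


real = [[1.0, 2.0], [0.5, -1.0]]
fake = [[-1.0, 0.0], [2.0, 2.0]]
print(gradient_penalty(steep_critic, real, fake))  # about 10 * (5 - 1)^2 = 160
```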

<h3 class="text-center">Training Parameters</h3>

- Trained with TensorFlow on an NVIDIA Tesla V100 GPU (AWS) or an NVIDIA GeForce RTX 2080 Ti (locally)
- Adam optimizer with an initial learning rate of 1e-3 at each resolution
- Resolution doubled every 1.2 million images (600k images of fade-in transition, 600k images at the new resolution)
- Mini-batch size by resolution:

| Resolution | Mini-batch size |
|---|---|
| 4 × 4 | 128 |
| 8 × 8 | 128 |
| 16 × 16 | 128 |
| 32 × 32 | 64 |
| 64 × 64 | 32 |
| 128 × 128 | 16 |
| 256 × 256 | 8 |
| 512 × 512 | 4 |
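The schedule above can be written as a small lookup. This sketch assumes the Progressive GAN reference ordering (stabilize at 4 × 4 first, then alternate fade-in and stabilization phases of 600k images each) and that the batch size follows the resolution currently being faded in; `training_state` is a hypothetical helper, not part of the training code.

```python
# Mini-batch sizes from the training slide, keyed by image side length.
BATCH_SIZES = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}
PHASE_IMAGES = 600_000  # images per fade-in phase and per stabilization phase
FINAL_PHASE = 14        # 4x4 -> 512x512 is 7 doublings of (fade-in + stabilization)


def training_state(images_seen):
    """Return (resolution, alpha, batch_size) after `images_seen` training images.

    Assumed phase order: stabilize 4x4, fade in 8x8, stabilize 8x8, fade in 16x16, ...
    alpha is the fade-in weight of the newest layers (1.0 outside a fade-in phase).
    """
    phase = images_seen // PHASE_IMAGES
    if phase >= FINAL_PHASE:  # training continues at the final resolution
        return 512, 1.0, BATCH_SIZES[512]
    if phase % 2 == 0:  # stabilization phase
        res = 4 * 2 ** (phase // 2)
        return res, 1.0, BATCH_SIZES[res]
    res = 4 * 2 ** ((phase + 1) // 2)  # fade-in phase for the next resolution
    alpha = (images_seen % PHASE_IMAGES) / PHASE_IMAGES
    return res, alpha, BATCH_SIZES[res]


print(training_state(900_000))  # halfway through fading in 8 x 8
```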

<h3 class="text-center">Results</h3>

<h3 class="text-center">Training Process</h3>
<br>
<video style="display:block;margin-right:auto;margin-left:auto;" width=90% height=auto controls loop>
    <source src="videos/train.webm" type="video/webm">
</video>
<p style="text-align: center">Snapshots throughout training</p>

<h3 class="text-center">Loss Functions</h3>

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/loss.png" width=50% height=auto>
</div>
<p style="text-align: center">Wasserstein Distance</p>

<h3 class="text-center">Fake vs. Real</h3>
<br>
<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/real.png" width=90% height=auto>
</div>

<p style="text-align: center">Real Images</p>

<h3 class="text-center">Fake vs. Real</h3>
<br>
<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/fake.png" width=90% height=auto>
</div>

<p style="text-align: center">Fake Images</p>

<h3 class="text-center">Latent Vector Walk</h3>
<br>
<video style="display:block;margin-right:auto;margin-left:auto;" width=60% height=auto controls loop>
    <source src="videos/interp.webm" type="video/webm">
</video>
<p style="text-align: center">Walking through the generator's latent space</p>
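The interpolation scheme used in the video is not stated. A common choice for Gaussian latent vectors is spherical interpolation (slerp), sketched below next to linear interpolation: in high dimensions, points on the straight line between two Gaussian samples have atypically small norm, while slerp between endpoints of equal norm preserves that norm. Each returned vector would be passed through the generator to render one frame.

```python
import math


def lerp(z0, z1, t):
    """Linear interpolation between two latent vectors."""
    return [(1.0 - t) * a + t * b for a, b in zip(z0, z1)]


def slerp(z0, z1, t):
    """Spherical linear interpolation along the arc between two latent vectors."""
    dot = sum(a * b for a, b in zip(z0, z1))
    norms = math.sqrt(sum(a * a for a in z0)) * math.sqrt(sum(b * b for b in z1))
    omega = math.acos(max(-1.0, min(1.0, dot / norms)))  # angle between the vectors
    if math.isclose(math.sin(omega), 0.0, abs_tol=1e-12):
        return lerp(z0, z1, t)  # (anti)parallel vectors: fall back to a straight line
    s0 = math.sin((1.0 - t) * omega) / math.sin(omega)
    s1 = math.sin(t * omega) / math.sin(omega)
    return [s0 * a + s1 * b for a, b in zip(z0, z1)]


def latent_walk(z0, z1, n_frames, interp=slerp):
    """Latent vectors for n_frames (>= 2) evenly spaced video frames from z0 to z1."""
    return [interp(z0, z1, i / (n_frames - 1)) for i in range(n_frames)]


a, b = [1.0, 0.0], [0.0, 1.0]
print(slerp(a, b, 0.5), lerp(a, b, 0.5))  # slerp stays on the unit circle; lerp cuts inside it
```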

<h3 class="text-center">Nearest Neighbor</h3>

- For a generated image, find the closest matching image in the training set (a check that the generator is not simply memorizing training images)

<div>
<img style="margin-right:0;margin-left:0;margin-top:0;float:left" src="images/fakeimg.png" width=45% height=auto>
<img style="margin-right:0;margin-left:0;margin-top:0;float:right" src="images/realimg.png" width=45% height=auto>
</div>

<p style="text-align: center">(left) random output from generator (right) nearest neighbor to fake image in training set</p>
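A brute-force sketch of the nearest-neighbor search, assuming images are flattened to lists and compared by pixel-wise squared L2 distance (the slides do not state the metric). Over the full training set of 512 × 512 slices a scan like this is expensive, so a practical version would vectorize the distances or use an approximate index.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two flattened images."""
    return sum((p - q) ** 2 for p, q in zip(a, b))


def nearest_neighbor(fake, training_set):
    """Return (index, distance) of the training image closest to a generated image."""
    best = min(range(len(training_set)), key=lambda i: squared_l2(fake, training_set[i]))
    return best, squared_l2(fake, training_set[best])


train = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.9, 0.1, 0.5]]  # toy 3-pixel "images"
print(nearest_neighbor([0.8, 0.2, 0.5], train))  # index 2, distance about 0.02
```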

<h3 class="text-center">How to evaluate GANs?</h3>

- What to use?

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/evaltable.png" width=75% height=auto>
</div>

<h3 class="text-center">Task-Based Assessments of Image Quality (TAIQ)</h3>

- Evaluation of GANs based on their performance on a task
    - This would be an evaluation of a system that a GAN is one component of
    - e.g. Comparing the reconstruction accuracy of a GAN-based reconstruction algorithm across various trained GAN models
    - e.g. Doing an ROC analysis on a GAN-based anomaly detector
- TAIQ gives metrics on the model that are relevant to the tasks we care about
    - Downside: does not tell you how a GAN model fails; independent metrics may still be needed for mode collapse, sample fidelity, etc.
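For an ROC analysis like the one above, the area under the ROC curve can be computed from anomaly scores alone. The sketch below uses the Mann-Whitney formulation and assumes that a higher score means "more anomalous"; it is quadratic in the number of samples, which is fine for illustration.

```python
def roc_auc(scores_normal, scores_anomalous):
    """Area under the ROC curve, via the Mann-Whitney U statistic.

    Equals the probability that a randomly chosen anomalous sample scores higher
    than a randomly chosen normal one (ties count 1/2).
    """
    wins = 0.0
    for a in scores_anomalous:
        for n in scores_normal:
            if a > n:
                wins += 1.0
            elif a == n:
                wins += 0.5
    return wins / (len(scores_anomalous) * len(scores_normal))


print(roc_auc([0.1, 0.4], [0.35, 0.8]))  # 0.75: 3 of the 4 (anomalous, normal) pairs ranked correctly
```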

<h3 class="text-center">Applications</h3>

<h3 class="text-center">Underdetermined Reconstruction</h3>

- In image reconstruction, we are interested in the following problem:

$$y = Ax + \eta$$

where $x \in \mathbb{R}^{n}$ is the object to reconstruct, $y \in \mathbb{R}^{m}$ describes the measurements, $A \in \mathbb{R}^{m \times n}$ is the system operator/measurement matrix, and $\eta$ is additive noise. The goal is to recover the object, $x$, using measurements from our system, $y$.

<h3 class="text-center">Underdetermined Reconstruction with Sparsity</h3>

- In reconstruction applications like Super-Resolution or Compressed Sensing, the system is underdetermined.
    - $\operatorname{rank}(A) < n$, so if any $x$ satisfies $Ax = y$, infinitely many do.
- In order to guarantee unique recovery, further assumptions about the data must be used.
    - One possible solution: sparsity (commonly used in compressed sensing)
    - Under suitable conditions on $A$, a sufficiently sparse $x$ is recovered uniquely from an underdetermined system through $\ell_{1}$-minimization:

$$ \min_{x} \lVert x \rVert_{1} \quad \text{s.t.} \quad Ax = y $$
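A sketch of sparse recovery on a toy underdetermined system. Note the swap: instead of the equality-constrained problem above (which needs a linear-programming solver), it solves the closely related Lagrangian form $\min_{x} \frac{1}{2}\lVert Ax - y \rVert_{2}^{2} + \lambda \lVert x \rVert_{1}$ with the iterative shrinkage-thresholding algorithm (ISTA). The matrix, $\lambda$, and step size are illustrative choices.

```python
def soft_threshold(v, t):
    """Proximal operator of t * |v|: shrink v toward zero by t."""
    if v > t:
        return v - t
    if v < -t:
        return v + t
    return 0.0


def ista(A, y, lam, step, iters=5000):
    """Iterative shrinkage-thresholding for min_x 0.5 * ||Ax - y||_2^2 + lam * ||x||_1.

    A is a list of rows (m x n); step should be at most 1 / (largest eigenvalue of A^T A).
    """
    m, n = len(A), len(A[0])
    x = [0.0] * n
    for _ in range(iters):
        residual = [sum(A[i][j] * x[j] for j in range(n)) - y[i] for i in range(m)]
        grad = [sum(A[i][j] * residual[i] for i in range(m)) for j in range(n)]  # A^T r
        x = [soft_threshold(x[j] - step * grad[j], step * lam) for j in range(n)]
    return x


# Underdetermined toy system: 3 measurements, 4 unknowns, rank(A) = 3 < n = 4.
A = [[1.0, 0.0, 0.0, 1.0],
     [0.0, 1.0, 0.0, 1.0],
     [0.0, 0.0, 1.0, 1.0]]
y = [2.0, 2.0, 2.0]  # produced by the 1-sparse x = (0, 0, 0, 2)
x_hat = ista(A, y, lam=0.03, step=0.25)  # largest eigenvalue of A^T A is 4
print(x_hat)
```

Both $x = (0, 0, 0, 2)$ and $x = (2, 2, 2, 0)$ satisfy $Ax = y$, but the $\ell_{1}$ penalty favors the sparse one: `x_hat` comes out as $(0, 0, 0, 1.99)$, where the small shrinkage below 2 is the bias introduced by $\lambda$.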

<h3 class="text-center">Underdetermined Reconstruction with GANs</h3>

- A stronger prior: restrict reconstructions to the range of a well-trained generator.

$$ \hat{\mathbf{x}} = G(\hat{\mathbf{z}}), \quad \text{where} \quad \hat{\mathbf{z}} = \arg \min_{\mathbf{z}} \lVert A G(\mathbf{z}) - y \rVert_{2}^{2} $$

<div>
<img style="display:block;margin-right:auto;margin-left:auto;" src="images/lassogen.png" width=60% height=auto>
</div>
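A minimal sketch of the procedure, assuming a toy *linear* generator $G(\mathbf{z}) = W\mathbf{z}$ so the gradient $2W^{T}A^{T}(AW\mathbf{z} - y)$ can be written by hand; with a trained network, the gradient with respect to $\mathbf{z}$ would come from backpropagation through $G$. The learning rate, iteration count, and helper names are hypothetical.

```python
def matvec(M, v):
    """Matrix (list of rows) times vector."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


def reconstruct_with_generator(A, y, W, z0, lr=0.01, iters=500):
    """Minimize ||A G(z) - y||_2^2 over z for a toy linear generator G(z) = W z.

    Gradient: 2 (A W)^T (A W z - y). Returns (x_hat, z_hat) with x_hat = G(z_hat).
    """
    Wt = transpose(W)
    AW = [matvec(Wt, a_row) for a_row in A]  # rows of the m x k matrix A W
    AWt = transpose(AW)
    z = list(z0)
    for _ in range(iters):
        residual = [p - q for p, q in zip(matvec(AW, z), y)]
        grad = [2.0 * g for g in matvec(AWt, residual)]
        z = [zi - lr * gi for zi, gi in zip(z, grad)]
    return matvec(W, z), z


W = [[1.0], [2.0], [3.0]]  # generator range: multiples of (1, 2, 3)
A = [[1.0, 1.0, 1.0]]      # one measurement: the sum of the three pixels
y = [12.0]                 # measurement of the true signal (2, 4, 6)
x_hat, z_hat = reconstruct_with_generator(A, y, W, z0=[0.0])
print(x_hat)  # close to [2.0, 4.0, 6.0]
```

Here a single measurement recovers the full 3-pixel signal because the signal is assumed to lie in the one-dimensional range of $G$; without that prior, the system would have infinitely many solutions.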

<h3 class="text-center">Anomaly Detection</h3>

- Anomaly: a sample that differs markedly from the rest of the dataset
    - We are interested in anomaly detection for quality control (QC)
- Deep learning methods like GANs can model complex, high-dimensional data distributions
    - Potential improvement in anomaly detection performance

<h3 class="text-center">AnoGAN</h3>

- Score anomalies with a combined residual and discrimination loss; a large $L(\mathbf{x})$ flags a likely anomaly

$$ L(\mathbf{x}) = \min_{\mathbf{z}} \lVert \mathbf{x} - G(\mathbf{z}) \rVert_{2}^{2} + \lVert \mathbf{f}(\mathbf{x}) - \mathbf{f}(G(\mathbf{z})) \rVert_{2}^{2}$$

- Residual loss: how far (in L2) a new example is from the most similar sample the generator can produce
- Discrimination loss: how far the new example is from that same sample in terms of discriminator features $\mathbf{f}(\cdot)$ (feature matching)
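A sketch of the score for a single example, following the equation above with both terms unweighted. It assumes a scalar latent $z$, a toy generator whose outputs lie on the line $x_{2} = 2x_{1}$, and identity features standing in for $\mathbf{f}$; the minimization over $z$ uses finite-difference gradient descent, whereas a real implementation would backpropagate through $G$ and the discriminator.

```python
def anogan_score(x, G, f, z0=0.0, lr=0.02, iters=500, h=1e-5):
    """L(x) = min_z ||x - G(z)||^2 + ||f(x) - f(G(z))||^2 for a scalar latent z.

    The minimization uses gradient descent with a central-difference gradient.
    """
    def sq_dist(a, b):
        return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

    fx = f(x)

    def loss(z):
        gz = G(z)
        return sq_dist(x, gz) + sq_dist(fx, f(gz))  # residual + discrimination loss

    z = z0
    for _ in range(iters):
        z -= lr * (loss(z + h) - loss(z - h)) / (2.0 * h)
    return loss(z)


def G(z):
    # Hypothetical generator: every output lies on the line x2 = 2 * x1.
    return [z, 2.0 * z]


def f(v):
    # Hypothetical feature extractor (identity) standing in for discriminator features.
    return list(v)


print(anogan_score([1.0, 2.0], G, f))   # on the generator's range: about 0
print(anogan_score([2.0, -1.0], G, f))  # off the range: about 10
```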

<h3 class="text-center">Future Directions</h3>

- Examine the efficacy of each evaluation method proposed in [Borji 2018](https://arxiv.org/abs/1802.03446) on our own data
- Try newer, better GAN architectures such as [StyleGAN](https://arxiv.org/abs/1812.04948)
- Explore generalization of GANs with training/test splits
- Investigate and define a reconstruction problem to apply GAN-based reconstruction techniques
- Explore effectiveness of GAN anomaly detection frameworks

<h3 class="text-center">Acknowledgements</h3>

- Mark Anastasio, PhD
- Sayantan Bhadra
- Abhinav Jha, PhD    
- Nico Dosenbach, MD/PhD    