# One GiGANtic Leap...

> Creating artificial life.

- toc: true 
- badges: true
- comments: true
- categories: [jupyter]


# Introduction

Generative Adversarial Networks are more than a novel deep learning architecture. They are powerful game-theoretic constructs with the potential to change much of our understanding of what is "real". If what is real is useful only insofar as what we perceive is useful, then we may be in trouble. The advent of GANs in 2014 has disrupted the way reality is made. Deep networks are becoming immensely powerful at learning class-posterior distributions $P(Y|X=x)$. However, generative models are doing something much trickier: learning the joint distribution $P(X,Y)$. GANs use a discriminator of $P(Y|X)$ to accomplish this. Thus, adversarial nets are a balancing act to implement and train, but they can produce strikingly realistic results. On the one hand, this could seriously damage media (Microsoft's MSN recently replaced human authors with generative machines), social stability, and our politics (think DeepFakes). On the other hand, there is reason to be hopeful.

One realm where I'm particularly excited to see GANs making an entrance is the biosciences. GANs (and really any high-fidelity generative model) may hold the promise of generating synthetic datasets to build better supervised learners on features such as genomes, proteomes, and bioimages. This could help bring machine learning closer to solving few-shot and zero-shot problems in medicine.

Another example is simulation and modeling. Tons of assumptions must be made during modeling (for, say, a pandemic) that may not prove realistic. At the same time, epidemiological datasets may be limited in scope and depth, so that features which are sparse or missing at the data-collection stage could have proved vital were it not for their low frequencies. Simulating patient risk pools for disease-fighting or drug design via a GAN is one potential method to overcome this, as we can sample from $P(Y,X=x_{minority})$ via a model trained to discriminate well on $P(Y|X)$.

Bioengineering is another intriguing space, whereby *in-silico* experimentation meets the lab bench through simulating data like yeast genomes. There's even the potential to build a *life factory* through generating realistic but novel oligomers and augmenting existing genomes (possible given how splicing is currently solved).

In any event, we should know just how to 1) Train a GAN, and 2) Use it to augment an existing workflow. In this post, I will begin by using the COVID-19 cell atlas as a source of $X$ and $y$. $y$ will be some form of clinical phenotype, and $X$ will be the single-cell read counts. That is, single-cell profiling of COVID-19 patients will provide the raw data to build a generative adversarial network. We will begin by training a discriminator $f$ to learn a classification $f(X) = y$, or $P(Y|X)$. Then, the generative model will begin to simulate $X,y$ pairs, and be trained to accurately generate $P(X,Y)$. This can be and is used to generate realistic data, or samples on $P(X|Y=y)$ (such as realistic read-counts of single-cell data given patient phenotypes). We can use this trained generative model to shore up minority classes to improve classifier performance (another $f(X) = y$, though on held-out test data), or simply to provide *more* data distributed in the same fashion as $X,y$.
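
To make the augmentation idea concrete, here is a minimal sketch of the bookkeeping involved in shoring up minority classes: count each label, then work out how many synthetic samples would be needed to match the largest class. Everything here is a toy assumption. The labels and feature matrix are random, and `sample_synthetic` is a hypothetical stand-in for a trained class-conditional generator, not something defined in this post.

```python
import numpy as np

def synthetic_deficit(y):
    """Return {label: number of extra samples needed to match the largest class}."""
    labels, counts = np.unique(y, return_counts=True)
    target = counts.max()
    return {int(label): int(target - count) for label, count in zip(labels, counts)}

def sample_synthetic(label, n, n_features, rng):
    """Hypothetical stand-in for a generator drawing from P(X | Y=label)."""
    return rng.normal(loc=float(label), scale=1.0, size=(n, n_features))

rng = np.random.default_rng(0)
y = np.array([0] * 90 + [1] * 10)   # toy, imbalanced phenotype labels
X = rng.normal(size=(y.size, 5))    # toy feature matrix (cells x genes)

deficit = synthetic_deficit(y)
X_parts, y_parts = [X], [y]
for label, n_extra in deficit.items():
    if n_extra > 0:
        # Top up each minority class with synthetic rows and matching labels.
        X_parts.append(sample_synthetic(label, n_extra, X.shape[1], rng))
        y_parts.append(np.full(n_extra, label))

X_aug = np.concatenate(X_parts)
y_aug = np.concatenate(y_parts)
```

After this step every class has as many rows as the largest one. Whether the synthetic rows actually help a downstream classifier depends entirely on how faithful the generator is.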

# Data

I will use the COVID-19 Cell Atlas and load it into a PyTorch `Dataset` object.

In [1]:
import scanpy as sc
from torch.utils.data import Dataset,DataLoader

In [7]:
dataset = sc.datasets.pbmc3k_processed()  # call the loader; it returns an AnnData object

In [None]:
class SingleCellDataset(Dataset):
    pass  # left as a stub in this notebook
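
The `SingleCellDataset` class is left unfinished in the notebook. One way it might look is sketched below, assuming a dense cells-by-genes matrix and integer-encoded labels. The constructor arguments `counts` and `labels` and the toy data are illustrative assumptions. With an AnnData object you would pass its `.X` matrix (densified first if it is sparse) and an encoded label column from `.obs`.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class SingleCellDataset(Dataset):
    """Map-style dataset over a (cells x genes) matrix and integer labels."""

    def __init__(self, counts, labels):
        # counts: array-like of shape (n_cells, n_genes); labels: shape (n_cells,)
        self.X = torch.as_tensor(np.asarray(counts), dtype=torch.float32)
        self.y = torch.as_tensor(np.asarray(labels), dtype=torch.long)
        if self.X.shape[0] != self.y.shape[0]:
            raise ValueError("counts and labels must have the same number of cells")

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        # One cell's feature vector and its label.
        return self.X[idx], self.y[idx]

# Toy stand-in data (random counts, three made-up classes).
rng = np.random.default_rng(0)
toy_counts = rng.poisson(lam=2.0, size=(32, 50))
toy_labels = rng.integers(0, 3, size=32)

toy_dataset = SingleCellDataset(toy_counts, toy_labels)
toy_loader = DataLoader(toy_dataset, batch_size=8, shuffle=False)
xb, yb = next(iter(toy_loader))  # one batch of features and labels
```

Converting to tensors once in `__init__` keeps `__getitem__` cheap, at the cost of holding the whole matrix in memory. For very large atlases a lazy, per-row conversion may be the better trade.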

# Model Architecture

## Probabilistic Intuition

A symmetric view:

i) A generator (in the data-faker view) samples from $P(X|Y=y)$, the distribution of the feature over each class or label.

- Note that $P(X,Y) = P(X|Y)P(Y) \implies P(X|Y) = P(X,Y)/P(Y)$

ii) A discriminator models $P(Y|X=x)$, the distribution of the labels given each feature vector.

- Note that $P(X,Y) = P(Y|X)P(X) \implies P(Y|X) = P(X,Y)/P(X)$


Both views are available to us because we are provided many instances of $(x,y)$ pairs, i.e. samples from the joint distribution, where the label space is finite.

Assuming the labels are discrete (i.e., in the case of a classifier and not a regressor), we can always sum to marginalize out the labels from the joint distribution to produce the feature prior $P(X) = \sum_y P(X,Y=y)$.

Likewise, for a continuous feature $X$, we can always integrate to marginalize out the features and produce the class prior $P(Y) = \int P(X=x,Y)\,dx$.

Then, either the "generator" $P(X|Y)$ or the "discriminator" $P(Y|X)$ can be derived by the definition of conditional probability and Bayes' rule (see the above bullets).
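
A small numeric check of these identities, assuming a made-up joint table with three discrete feature values and two labels (for a continuous feature the sum over $x$ becomes an integral). The numbers are arbitrary and chosen only so the arithmetic is easy to follow.

```python
import numpy as np

# Toy joint distribution P(X, Y): rows index 3 feature values, columns index 2 labels.
joint = np.array([
    [0.10, 0.20],
    [0.30, 0.05],
    [0.15, 0.20],
])
assert np.isclose(joint.sum(), 1.0)

p_x = joint.sum(axis=1)   # feature prior P(X): marginalize out the labels
p_y = joint.sum(axis=0)   # class prior P(Y): marginalize out the features

p_x_given_y = joint / p_y            # "generator" view: each column sums to 1
p_y_given_x = joint / p_x[:, None]   # "discriminator" view: each row sums to 1

# Bayes' rule recovers one conditional from the other: P(Y|X) = P(X|Y) P(Y) / P(X).
bayes_y_given_x = p_x_given_y * p_y / p_x[:, None]
```

Here `bayes_y_given_x` matches `p_y_given_x` up to floating-point error, which is the symmetry between the two views stated above.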

## Implementation

We integrate this construction by using a GAN. Classically, this implies that we train a discriminator $f(X) = y$ on $\vec{x},\vec{y}$ and subsequently use this 
    




## The Generative Adversarial Network in PyTorch

In [None]:
import os
from argparse import ArgumentParser
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

import pytorch_lightning as pl

In [None]:
class CovidDataModule(pl.LightningDataModule):

    def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 8):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download once; note that this data module loads MNIST images, not single-cell counts
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [None]:
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE: the second positional argument of BatchNorm1d is eps, so this sets eps=0.8 (not momentum)
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

In [None]:
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity

In [None]:
class CovidGAN(pl.LightningModule):

    def __init__(
        self,
        channels,
        width,
        height,
        latent_dim: int = 100,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = 64,
        **kwargs
    ):
        super().__init__()
        self.save_hyperparameters()

        # networks
        data_shape = (channels, width, height)
        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=data_shape)
        self.discriminator = Discriminator(img_shape=data_shape)

        self.validation_z = torch.randn(8, self.hparams.latent_dim)

        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        imgs, _ = batch

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # train generator
        if optimizer_idx == 0:

            # generate images
            self.generated_imgs = self(z)

            # log sampled images
            sample_imgs = self.generated_imgs[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('generated_images', grid, 0)

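            # target labels for the generator step: all ones ("real"),
            # because the generator is rewarded when its fakes are judged real
            # move to the same device/dtype as imgs, since this tensor is created inside the training loop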
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            # adversarial loss is binary cross-entropy
            g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
            tqdm_dict = {'g_loss': g_loss}
            output = OrderedDict({
                'loss': g_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

        # train discriminator
        if optimizer_idx == 1:
            # Measure discriminator's ability to classify real from generated samples

            # how well can it label as real?
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

            # how well can it label as fake?
            fake = torch.zeros(imgs.size(0), 1)
            fake = fake.type_as(imgs)

            fake_loss = self.adversarial_loss(
                self.discriminator(self(z).detach()), fake)

            # discriminator loss is the average of these
            d_loss = (real_loss + fake_loss) / 2
            tqdm_dict = {'d_loss': d_loss}
            output = OrderedDict({
                'loss': d_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def on_epoch_end(self):
        z = self.validation_z.type_as(self.generator.model[0].weight)

        # log sampled images
        sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image('generated_images', grid, self.current_epoch)

In [None]:
dm = CovidDataModule()
model = CovidGAN(*dm.size())
trainer = pl.Trainer(gpus=1, max_epochs=5, progress_bar_refresh_rate=20)
trainer.fit(model, dm)
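
For readers not using Lightning, the two `optimizer_idx` branches of `training_step` correspond to the alternating update sketched below. This is a minimal plain-PyTorch version on toy 2-D Gaussian data, not the notebook's training run. The tiny networks, the data, and names such as `toy_gen` and `toy_disc` are illustrative assumptions. The Adam settings mirror the defaults in `CovidGAN`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
latent_dim, data_dim, batch_size = 8, 2, 64

# Tiny illustrative networks (not the Generator/Discriminator classes above).
toy_gen = nn.Sequential(nn.Linear(latent_dim, 32), nn.LeakyReLU(0.2), nn.Linear(32, data_dim))
toy_disc = nn.Sequential(nn.Linear(data_dim, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1), nn.Sigmoid())

opt_g = torch.optim.Adam(toy_gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(toy_disc.parameters(), lr=2e-4, betas=(0.5, 0.999))

g_losses, d_losses = [], []
for step in range(20):
    real = torch.randn(batch_size, data_dim) + 3.0   # toy "real" samples
    z = torch.randn(batch_size, latent_dim)          # latent noise
    valid = torch.ones(batch_size, 1)                # label for "real"
    fake = torch.zeros(batch_size, 1)                # label for "generated"

    # Generator step: reward fakes that the discriminator scores as real.
    opt_g.zero_grad()
    g_loss = F.binary_cross_entropy(toy_disc(toy_gen(z)), valid)
    g_loss.backward()
    opt_g.step()

    # Discriminator step: detach the fakes so no gradient reaches the generator.
    opt_d.zero_grad()
    real_loss = F.binary_cross_entropy(toy_disc(real), valid)
    fake_loss = F.binary_cross_entropy(toy_disc(toy_gen(z).detach()), fake)
    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    opt_d.step()

    g_losses.append(g_loss.item())
    d_losses.append(d_loss.item())
```

Twenty steps are far too few to converge. The point is only the order of operations: one generator update against the "real" target, then one discriminator update on real and detached fake batches.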

# The Test Case: Data Augmentation for Predicting Clinical Phenotypes

## Hypothesis: Oversampling Features from Minority Cell-Type Membership 