## DAPI stain variational auto-encoder

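This notebook trains a variational auto-encoder (VAE) on 64×64 synthetic DAPI stain images, then uses the trained model to reconstruct test images and to explore the latent space interactively.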

In [None]:
%matplotlib inline

import matplotlib.pylab as plt

import os

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam, Adamax

In [None]:
# Turn Pyro's validation on globally, then switch it back off for the
# distributions module only.
# NOTE(review): presumably the second call skips distribution argument
# checks for speed while keeping the other checks -- confirm this is intended.
pyro.enable_validation(True)
pyro.distributions.enable_validation(False)
# Fix the random seed so that repeated runs are comparable.
pyro.set_rng_seed(0)
# Smoke-test mode is active whenever a 'CI' variable is present in the
# environment; it is read further down to shorten the training run.
smoke_test = 'CI' in os.environ

In [None]:
def setup_data_loaders(batch_size=128, use_cuda=False,
                       train_dataset_npy='./synth_dapi_stains/synth_dapi_stains_train_64_64_10k.npy',
                       test_dataset_npy='./synth_dapi_stains/synth_dapi_stains_test_64_64_1k.npy',
                       num_workers=1):
    """Build the training and test data loaders from two ``.npy`` files.

    Each file is read with ``np.load``, converted to a ``torch.Tensor`` and
    wrapped in a ``TensorDataset``, so every batch the loaders yield is a
    1-tuple ``(x,)``.

    Args:
        batch_size: number of samples per mini-batch for both loaders.
        use_cuda: forwarded as ``pin_memory`` to both loaders.
        train_dataset_npy: path of the ``.npy`` file holding the training set.
        test_dataset_npy: path of the ``.npy`` file holding the test set.
        num_workers: number of loader worker processes (0 loads in-process).

    Returns:
        ``(train_loader, test_loader)``; the training loader shuffles, the
        test loader keeps the on-disk order.

    Raises:
        Whatever ``np.load`` raises when a path is missing or unreadable.
    """
    def _load_dataset(npy_path):
        # one tensor per file -> single-element samples
        return TensorDataset(torch.Tensor(np.load(npy_path)))

    train_set = _load_dataset(train_dataset_npy)
    test_set = _load_dataset(test_dataset_npy)

    # options shared by both loaders
    kwargs = {'num_workers': num_workers, 'pin_memory': use_cuda}

    train_loader = torch.utils.data.DataLoader(dataset=train_set,
        batch_size=batch_size, shuffle=True, **kwargs)

    test_loader = torch.utils.data.DataLoader(dataset=test_set,
        batch_size=batch_size, shuffle=False, **kwargs)

    return train_loader, test_loader

In [None]:
class Encoder(nn.Module):
    """Dense encoder: flattened image -> one hidden layer -> (z_loc, z_scale).

    ``forward`` returns the location and the (strictly positive) scale of a
    diagonal Gaussian over a ``z_dim``-dimensional latent code.
    """
    def __init__(self, width, height, z_dim, hidden_dim):
        super(Encoder, self).__init__()

        # image geometry; inputs are flattened to width * height features
        self.width, self.height = width, height
        self.pixel_count = width * height

        # shared hidden layer followed by two heads (location / log-scale)
        self.fc1 = nn.Linear(self.pixel_count, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)

        # non-linearities
        self.softplus = nn.Softplus()
        self.relu = nn.ReLU()

    def forward(self, x):
        # collapse every image to a row vector; the batch size is inferred
        flat = x.reshape(-1, self.pixel_count)
        features = self.relu(self.fc1(flat))

        loc = self.fc21(features)
        # the second head is exponentiated so the scale is always positive
        log_scale = self.fc22(features)
        return loc, torch.exp(log_scale)


class Decoder(nn.Module):
    """Dense decoder: latent code -> one hidden layer -> flattened image.

    ``forward`` returns one row of ``width * height`` values per latent
    code, each squashed into (0, 1) by a sigmoid.
    """
    def __init__(self, width, height, z_dim, hidden_dim):
        super(Decoder, self).__init__()

        # image geometry; the output has width * height features
        self.width, self.height = width, height
        self.pixel_count = width * height

        # latent -> hidden -> pixels
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, self.pixel_count)

        # non-linearities
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, z):
        features = self.relu(self.fc1(z))
        pixel_logits = self.fc21(features)
        # sigmoid keeps every decoded pixel inside (0, 1)
        return self.sigmoid(pixel_logits)

    
class VAE(nn.Module):
    """Variational auto-encoder built from a dense Encoder and Decoder.

    ``model`` and ``guide`` are the two callables handed to Pyro's SVI:
    the model decodes a latent drawn from a standard-normal prior and scores
    the observed pixels under a Normal with fixed scale 0.1; the guide
    proposes the latent from the encoder's output.
    """
    def __init__(self, width, height, z_dim=50, hidden_dim=400, use_cuda=False):
        super(VAE, self).__init__()

        # image geometry
        self.width, self.height = width, height
        self.pixel_count = width * height

        # inference network and generative network
        self.encoder = Encoder(width, height, z_dim, hidden_dim)
        self.decoder = Decoder(width, height, z_dim, hidden_dim)

        if use_cuda:
            # moves the parameters of both sub-networks into GPU memory
            self.cuda()

        self.use_cuda = use_cuda
        self.z_dim = z_dim

    def model(self, x):
        """Generative model p(x | z) p(z) for a mini-batch ``x``."""
        # make the decoder's parameters visible to Pyro
        pyro.module("decoder", self.decoder)
        n_samples = x.size(0)

        # standard-normal prior, created with the dtype/device of ``x``
        prior_loc = x.new_zeros(n_samples, self.z_dim)
        prior_scale = x.new_ones(n_samples, self.z_dim)
        # observations flattened to one row per image
        observed = x.reshape(-1, self.pixel_count)

        with pyro.iarange("data", n_samples):
            prior = dist.Normal(prior_loc, prior_scale).independent(1)
            z = pyro.sample("latent", prior)

            pixel_loc = self.decoder.forward(z)
            pixel_scale = 0.1 * torch.ones_like(pixel_loc)

            # TODO: consider a better observation distribution
            likelihood = dist.Normal(pixel_loc, pixel_scale).independent(1)
            pyro.sample("obs", likelihood, obs=observed)

    def guide(self, x):
        """Variational posterior q(z | x) parameterised by the encoder."""
        # make the encoder's parameters visible to Pyro
        pyro.module("encoder", self.encoder)
        n_samples = x.size(0)

        with pyro.iarange("data", n_samples):
            posterior_loc, posterior_scale = self.encoder.forward(x)
            posterior = dist.Normal(posterior_loc, posterior_scale).independent(1)
            pyro.sample("latent", posterior)

    def reconstruct_img(self, x):
        """Encode ``x``, draw one latent sample and decode it.

        Returns the decoder output, i.e. one flattened image per input.
        """
        posterior_loc, posterior_scale = self.encoder(x)
        latent = dist.Normal(posterior_loc, posterior_scale).sample()
        return self.decoder(latent)

In [None]:
def train(svi, train_loader, use_cuda=False):
    """Run one pass over ``train_loader`` and return the per-sample loss.

    Every mini-batch is passed to ``svi.step``; the values it returns are
    summed and divided by the number of items in ``train_loader.dataset``.

    Args:
        svi: object exposing ``step(x)``.
        train_loader: iterable of 1-tuples ``(x,)`` with a ``dataset`` attribute.
        use_cuda: when true, each batch is moved with ``x.cuda()`` first.
    """
    running_total = 0.
    for (batch,) in train_loader:
        # on GPU runs the mini-batch has to live in CUDA memory
        running_total += svi.step(batch.cuda() if use_cuda else batch)

    # normalise by dataset size, not by the number of batches
    return running_total / len(train_loader.dataset)


def evaluate(svi, test_loader, use_cuda=False):
    """Score ``test_loader`` without a training step; return the per-sample loss.

    Every mini-batch is passed to ``svi.evaluate_loss``; the values it
    returns are summed and divided by the number of items in
    ``test_loader.dataset``.

    Args:
        svi: object exposing ``evaluate_loss(x)``.
        test_loader: iterable of 1-tuples ``(x,)`` with a ``dataset`` attribute.
        use_cuda: when true, each batch is moved with ``x.cuda()`` first.
    """
    running_total = 0.
    for (batch,) in test_loader:
        # on GPU runs the mini-batch has to live in CUDA memory
        running_total += svi.evaluate_loss(batch.cuda() if use_cuda else batch)

    # normalise by dataset size, not by the number of batches
    return running_total / len(test_loader.dataset)

In [None]:
# ---- Run options ----

# hardware / optimiser
USE_CUDA = False
LEARNING_RATE = 1e-3

# image geometry (pixels)
WIDTH = 64
HEIGHT = 64

# network sizes
Z_DIM = 64
HIDDEN_DIM = 1024

# training schedule: a single epoch is enough when smoke testing
if smoke_test:
    NUM_EPOCHS = 1
else:
    NUM_EPOCHS = 300
# the test set is scored once every TEST_FREQUENCY epochs
TEST_FREQUENCY = 5

In [None]:
# data: mini-batches of 256 images for both training and testing
train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=USE_CUDA)

# start from an empty Pyro parameter store
pyro.clear_param_store()

# model: sizes come from the run options
vae = VAE(width=WIDTH, height=HEIGHT, z_dim=Z_DIM, hidden_dim=HIDDEN_DIM, use_cuda=USE_CUDA)

# optimiser
adam_args = {"lr": LEARNING_RATE}
optimizer = Adam(adam_args)

# stochastic variational inference driven by the ELBO
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

# per-epoch ELBO history, filled in by the training loop
train_elbo, test_elbo = [], []

## Load model parameters (optional)

In [None]:
# Restore previously saved encoder / decoder weights.
# map_location='cpu' lets a checkpoint written from a GPU session be read on
# a machine without CUDA; load_state_dict then copies the values into the
# networks' existing parameters.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
vae.encoder.load_state_dict(
    torch.load('./synth_dapi_vae_params/dapi_vae_encoder_params.pt', map_location='cpu'))
vae.decoder.load_state_dict(
    torch.load('./synth_dapi_vae_params/dapi_vae_decoder_params.pt', map_location='cpu'))

## Train model (can skip if parameters loaded)

In [None]:
for epoch in range(NUM_EPOCHS):
    # one pass over the training data; the ELBO is the negated loss
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

    # the test set is only scored on every TEST_FREQUENCY-th epoch
    if epoch % TEST_FREQUENCY != 0:
        continue

    # report test diagnostics
    total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
    test_elbo.append(-total_epoch_loss_test)
    print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))

## Fun: reconstructions and latent-space exploration

In [None]:
# reconstruction of a real spot
# show one test image ...
(real_img,) = test_loader.dataset[28]
plt.figure()
plt.imshow(real_img.numpy().squeeze(), cmap=plt.cm.Greys_r)

# ... and its reconstruction (encode, sample a latent, decode)
plt.figure()
# the dataset tensor lives on the CPU; move it next to the model when the
# VAE was built on the GPU, otherwise the forward pass fails
model_input = real_img.cuda() if vae.use_cuda else real_img
rec_img = vae.reconstruct_img(model_input).reshape_as(real_img)
# bring the result back to host memory before converting to numpy
plt.imshow(rec_img.detach().cpu().numpy().squeeze(), cmap=plt.cm.Greys_r)

In [None]:
from ipywidgets import interactive, FloatSlider

# a random base point in latent space and three random directions to move along
z_rand = torch.Tensor(np.random.randn(Z_DIM))
dz_dir_1 = np.random.randn(Z_DIM)
dz_dir_2 = np.random.randn(Z_DIM)
dz_dir_3 = np.random.randn(Z_DIM)

def plot_dapi_stain(dz_1, dz_2, dz_3):
    """Decode and display the latent point ``z_rand + sum_i dz_i * dz_dir_i``.

    Args:
        dz_1, dz_2, dz_3: step sizes along the three random latent directions.
    """
    plt.figure(figsize=(6,6))
    with torch.no_grad():
        z_new = z_rand.numpy() + dz_1 * dz_dir_1 + dz_2 * dz_dir_2 + dz_3 * dz_dir_3
        z_tensor = torch.Tensor(z_new)
        if vae.use_cuda:
            # the decoder's weights are on the GPU, so its input must be too
            z_tensor = z_tensor.cuda()
        # NOTE(review): reshape uses (WIDTH, HEIGHT); harmless while both are
        # 64, but confirm the axis order before using non-square images.
        decoded = vae.decoder.forward(z_tensor).reshape(WIDTH, HEIGHT)
        # back to host memory before converting to numpy
        img = decoded.cpu().numpy()
    plt.axis('off')
    plt.imshow(img, cmap=plt.cm.Greys_r)
    plt.show()

# one slider per latent direction, each covering [-1, 1] in steps of 0.1
interactive_plot = interactive(plot_dapi_stain,
                               dz_1=FloatSlider(min=-1., max=1., step=0.1),
                               dz_2=FloatSlider(min=-1., max=1., step=0.1),
                               dz_3=FloatSlider(min=-1., max=1., step=0.1))
# the last child of the widget container is kept in `output`
output = interactive_plot.children[-1]
interactive_plot

## Save model to disk

In [None]:
# Create the target directory first: torch.save does not create missing
# directories, and failing here would throw away a finished training run.
params_dir = './synth_dapi_vae_params'
if not os.path.isdir(params_dir):
    os.makedirs(params_dir)

# Persist the encoder and decoder weights as separate state dicts.
torch.save(vae.encoder.state_dict(), './synth_dapi_vae_params/dapi_vae_encoder_params.pt')
torch.save(vae.decoder.state_dict(), './synth_dapi_vae_params/dapi_vae_decoder_params.pt')