## Composite variational auto-encoder

This notebook implements the composite variational auto-encoder.

At the moment, only the generator is implemented.

In [None]:
%matplotlib inline

import matplotlib.pylab as plt

import os

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.nn.functional import grid_sample, affine_grid
from torch.utils.data import TensorDataset

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam, Adamax

In [None]:
# Global notebook configuration (runs once, before any model code).
#
# Switch Pyro's validation on ...
pyro.enable_validation(True)
# ... and then switch it off again at the `pyro.distributions` level.
# NOTE(review): the order of these two calls matters if the first call also
# toggles distribution validation -- confirm against the installed Pyro version.
pyro.distributions.enable_validation(False)
# Fixed seed so that the samples drawn further down are reproducible.
pyro.set_rng_seed(0)
# True when a `CI` environment variable is set. Not read anywhere in the
# visible part of this notebook.
smoke_test = 'CI' in os.environ

In [None]:
class CompositeVAEModelParams(dict):
    """Default hyper-parameters of the composite VAE, stored as a plain ``dict``.

    Keys are namespaced strings of the form ``'<component>.<name>'``. Every new
    instance starts with the same defaults and can afterwards be modified like
    any other dictionary.
    """

    def __init__(self):
        super().__init__()

        defaults = (
            # parameters for BasicObjectDecoder
            ('basic_object_decoder.width', 64),
            ('basic_object_decoder.height', 64),
            ('basic_object_decoder.z_dim', 64),
            ('basic_object_decoder.hidden_dim', 1024),

            # parameters for the object grid
            ('object_grid.width', 8),
            ('object_grid.height', 6),

            # parameters of the dataset
            ('observed_image.width', 384),
            ('observed_image.height', 256),

            # parameters for CompositeVAE
            ('composite_vae.active_cell_prob', 0.5),
            ('composite_vae.dx_scale', 0.2),
            ('composite_vae.dy_scale', 0.2),
            ('composite_vae.s_loc', 1.0),
            ('composite_vae.s_scale', 0.1),
        )

        # insert one by one so that the key order matches the table above
        for key, value in defaults:
            self[key] = value
        

class BasicObjectDecoder(nn.Module):
    """Decoder that turns the latent code of a single object into its image.

    The network is a two-layer perceptron: a ReLU hidden layer followed by a
    sigmoid output layer with one unit per pixel, so all intensities lie in
    the open interval (0, 1).
    """

    def __init__(self, params: dict):
        super().__init__()

        # geometry of the decoded image and sizes of the latent / hidden layers
        self.width = params['basic_object_decoder.width']
        self.height = params['basic_object_decoder.height']
        self.z_dim = params['basic_object_decoder.z_dim']
        self.hidden_dim = params['basic_object_decoder.hidden_dim']
        self.pixel_count = self.width * self.height

        # latent code -> hidden layer -> one output unit per pixel
        self.fc1 = nn.Linear(self.z_dim, self.hidden_dim)
        self.fc21 = nn.Linear(self.hidden_dim, self.pixel_count)

        # non-linearities
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, z_obj):
        """Decode latent codes into images.

        Arguments:
            z_obj: A tensor with shape (..., z_dim); all leading dimensions are
                treated as batch dimensions.

        Returns:
            A tensor with shape (..., width, height) holding pixel intensities.
        """
        activations = self.fc1(z_obj)
        activations = self.relu(activations)
        flat_images = self.sigmoid(self.fc21(activations))
        # keep the batch dimensions and unfold the pixel axis into (width, height)
        image_shape = z_obj.shape[:-1] + (self.width, self.height)
        return flat_images.reshape(image_shape)

    def load_params(self, decoder_state_dict):
        """Load previously trained weights from a state dictionary."""
        self.load_state_dict(decoder_state_dict)

        
class BasicImageTransformer(nn.Module):
    """This module transforms a batch of object images into their final location in the
    to-be-composed image (see `forward` below for more details).

    Note:
        The grid-cell centres computed in `__init__` are stored as plain tensor attributes
        (not as registered buffers), so they are not moved by `Module.to(...)`.
    """
    def __init__(self, params: dict):
        super(BasicImageTransformer, self).__init__()
        self.params = params

        # extent of a single grid cell in normalized [-1, 1] image coordinates
        delta_x = 2. / params['object_grid.width']
        delta_y = 2. / params['object_grid.height']

        # normalized coordinates of the cell centres; both tensors have shape
        # (object_grid.width, object_grid.height)
        self.grid_offset_x, self.grid_offset_y = torch.meshgrid((
            -1. + 0.5 * delta_x + delta_x * torch.arange(0, params['object_grid.width']).float(),
            -1. + 0.5 * delta_y + delta_y * torch.arange(0, params['object_grid.height']).float()))

    @classmethod
    def generate_stn_theta(cls, tx, ty, sx, sy):
        """Generates STN matrix from explicit scales and translations.

        The returned affine matrix is laid out as

            [[sy,  0, ty],
             [ 0, sx, tx]]

        i.e. the "y" quantities occupy the first row. Images in this notebook are stored as
        (width, height), so the last image axis -- the one `grid_sample` addresses with the
        first row of the matrix -- is the height axis. See Eq. (1) in Ref. [1] for the
        general form of the matrix.

        Note:
            All of the dimensions of `tx`, `ty`, `sx`, and `sy` are treated as batch dimensions,
            and all tensors must have the same shape. Crucially, these tensors must not have unsqueezed
            flanking singleton dimensions on the right.

        Returns:
            A tensor with shape (..., 2, 3) corresponding to the affine transformation to be ingested
            by `torch.nn.functional.affine_grid`. It is created on the same device and with the
            same dtype as the inputs.

        References:
            [1] Jaderberg, Max, Karen Simonyan, and Andrew Zisserman. "Spatial transformer networks."
            Advances in Neural Information Processing Systems, pp. 2017-2025. 2015.
        """
        batch_shape = sx.shape
        tx, ty, sy = (t.reshape(batch_shape) for t in (tx, ty, sy))
        # `zeros_like` follows the device and dtype of the inputs
        zeros = torch.zeros_like(sx)
        row_0 = torch.stack((sy, zeros, ty), dim=-1)
        row_1 = torch.stack((zeros, sx, tx), dim=-1)
        return torch.stack((row_0, row_1), dim=-2)

    @classmethod
    def transform_objects(cls, obj_images, obj_tx, obj_ty, obj_sx, obj_sy, output_width, output_height):
        """Applies an affine transformation on `obj_images` as specified by `obj_tx`, `obj_ty`, `obj_sx`,
        and `obj_sy`.

        Arguments:
            obj_images: A tensor with shape (..., Ws, Hs) where Ws and Hs correspond to the horizontal
                and vertical source dimensions, respectively.
            obj_tx: A tensor with shape (...) (see `BasicImageTransformer.generate_stn_theta`)
            obj_ty: A tensor with shape (...) (see `BasicImageTransformer.generate_stn_theta`)
            obj_sx: A tensor with shape (...) (see `BasicImageTransformer.generate_stn_theta`)
            obj_sy: A tensor with shape (...) (see `BasicImageTransformer.generate_stn_theta`)
            output_width: (positive integer) output width
            output_height: (positive integer) output height

        Returns:
            A tensor with shape (..., output_width, output_height) containing the spatially
            transformed images.
        """
        obj_width, obj_height = obj_images.shape[-2:]

        # `affine_grid` / `grid_sample` expect a single batch axis and a channel axis, so all
        # batch dimensions are flattened and a singleton channel is inserted
        theta_flat = cls.generate_stn_theta(obj_tx, obj_ty, obj_sx, obj_sy).reshape(-1, 2, 3)
        obj_images_flat = obj_images.reshape(-1, 1, obj_width, obj_height)
        n_objects_total = obj_images_flat.size(0)

        grid = affine_grid(theta_flat, torch.Size((n_objects_total, 1, output_width, output_height)))
        out = grid_sample(obj_images_flat, grid)

        # restore the original batch dimensions
        return out.reshape(obj_images.shape[:-2] + (output_width, output_height))

    def forward(self, obj_images, obj_dx, obj_dy, obj_s):
        """Transforms a batch of object images into their final spatial position.

        Arguments:
            obj_images: A tensor with the following shape,

                    (...,
                     object_grid.width,
                     object_grid.height,
                     basic_object_decoder.width,
                     basic_object_decoder.height),

                containing images of individual objects in the order in which they are to be
                placed on the target grid. Here, ... denotes extra batch dimensions.

            obj_dx: A tensor with shape (..., object_grid.width, object_grid.height), denoting
                the horizontal centroid-to-centroid displacement of each object image with respect
                to the target grid cell; dx = 0 centres the object on its cell.

            obj_dy: A tensor with shape (..., object_grid.width, object_grid.height), denoting
                the vertical centroid-to-centroid displacement of each object image with respect
                to the target grid cell; dy = 0 centres the object on its cell.

            obj_s: A (positive) tensor with shape (..., object_grid.width, object_grid.height), denoting
                the relative scale of each object image with respect to the target grid cell dimensions.
                For example, s = 1 implies fitting the object image to the target grid cell whereas s = 0.5
                implies shrinking the object in half and then embedding inside the target grid cell.

        Returns:
            A tensor with the following shape:

                (...,
                     object_grid.width,
                     object_grid.height,
                     observed_image.width,
                     observed_image.height).

            Crucially, the output is differentiable with respect to all of the input arguments.
        """
        grid_width = self.params['object_grid.width']
        grid_height = self.params['object_grid.height']

        # generate STN transformation parameters
        #
        # Note:
        #
        #   These expressions are obtained by inverting the affine transformation
        #   (see the Mathematica notebook `affine_trans_eq.nb` for details.)
        #
        # NOTE(review): with these expressions dx = +/-1 moves the centroid by
        #   2 / object_grid.width in normalized coordinates, i.e. by one full cell width,
        #   whereas earlier documentation described +/-1 as reaching the cell edge (half a
        #   cell). Confirm the intended convention against `affine_trans_eq.nb`.
        #
        obj_tx = -(grid_width * self.grid_offset_x + 2 * obj_dx) / obj_s
        obj_ty = -(grid_height * self.grid_offset_y + 2 * obj_dy) / obj_s
        obj_sx = float(grid_width) / obj_s
        obj_sy = float(grid_height) / obj_s

        # transform to the main image
        transformed_obj_images = self.transform_objects(
            obj_images, obj_tx, obj_ty, obj_sx, obj_sy,
            self.params['observed_image.width'],
            self.params['observed_image.height'])

        return transformed_obj_images
    

class BasicImageComposer(nn.Module):
    """Combines the spatially transformed object images into one final image."""

    def __init__(self, params: dict):
        super().__init__()
        self.params = params

    def forward(self, trasformed_obj_images, obj_active):
        """Compose the final image by summing the images of the active objects and squashing
        the summed intensities with a sigmoid.

        Arguments:
            trasformed_obj_images: A tensor with the following shape,

                    (...,
                         object_grid.width,
                         object_grid.height,
                         observed_image.width,
                         observed_image.height),

                that contains spatially transformed images of individual objects (e.g. the
                output of `BasicImageTransformer.forward`.)

            obj_active: A tensor with shape (..., object_grid.width, object_grid.height) that
                multiplies each object image, i.e. a value of 0 removes the object from the
                final image.

        Returns:
            A tensor with shape (..., observed_image.width, observed_image.height) corresponding
            to a fully-composed image.
        """
        # add two trailing singleton axes so the mask broadcasts over the pixel axes
        mask = obj_active.view(obj_active.shape + (1, 1))
        masked_images = trasformed_obj_images * mask

        # collapse the grid axes: first object_grid.height, then object_grid.width
        over_grid_height = masked_images.sum(dim=-3)
        added_intensities = over_grid_height.sum(dim=-3)

        return torch.sigmoid(added_intensities)
    

class CompositeVAE(nn.Module):
    """Generative part of the composite VAE.

    Ties together the three building blocks: a decoder that turns latent codes
    into object images, a transformer that places each object image on its grid
    cell, and a composer that merges the placed objects into one image.
    """
    def __init__(self, params: dict):
        super(CompositeVAE, self).__init__()
        self.basic_object_decoder = BasicObjectDecoder(params)
        self.transformer = BasicImageTransformer(params) 
        self.composer = BasicImageComposer(params) 
        self.params = params
    
    def model(self, images: torch.Tensor):
        """Specification of the CompositeVAE model.
        
        Arguments:
            images: a tensor of shape (n_batch, observed_images.width, observed_images.height)
                where n_batch >= 1 is allowed to vary. Only its leading dimension is read
                here; the pixel values are not used.
                
        Returns:
            The composed images produced by `self.composer` for the sampled latent variables.
        """
        
        n_batch = images.size(0)
        # one random variable per image and per grid cell
        batch_grid_shape = [n_batch, self.params['object_grid.width'], self.params['object_grid.height']]
        
        with pyro.iarange("images", dim=-1, size=n_batch):    
            # draw a Bernoulli for each grid point (cell presence/absence)
            # NOTE(review): unlike the sample sites below, this one is not wrapped in
            # `.independent(3)` -- confirm whether that is intended.
            obj_active = pyro.sample("obj_active",
                dist.Bernoulli(probs=self.params['composite_vae.active_cell_prob'])
                    .expand_by(batch_grid_shape))

            # draw basic object latent codes for each grid point
            # (zero-mean, unit-scale normal over the z_dim latent axis)
            obj_z_loc = torch.zeros(self.params['basic_object_decoder.z_dim'])
            obj_z_scale = torch.ones(self.params['basic_object_decoder.z_dim'])
            obj_z = pyro.sample("obj_z",
                dist.Normal(obj_z_loc, obj_z_scale)
                    .expand_by(batch_grid_shape)
                    .independent(3))
            
            # decode the latent code of basic objects into images
            obj_images = self.basic_object_decoder(obj_z)
            
            # displacement and scale of basic objects
            obj_dx = pyro.sample("obj_dx",
                dist.Normal(0, self.params['composite_vae.dx_scale'])
                    .expand_by(batch_grid_shape)
                    .independent(3))
            obj_dy = pyro.sample("obj_dy",
                dist.Normal(0, self.params['composite_vae.dy_scale'])
                    .expand_by(batch_grid_shape)
                    .independent(3))
            
            # Gamma parameters chosen so that the scale has mean `s_loc`
            # (= alpha / beta) and standard deviation `s_scale` (= sqrt(alpha) / beta)
            alpha =  self.params['composite_vae.s_loc']**2 / self.params['composite_vae.s_scale']**2
            beta = self.params['composite_vae.s_loc'] / self.params['composite_vae.s_scale']**2
            obj_s = pyro.sample("obj_s",
                dist.Gamma(concentration=alpha, rate=beta)
                    .expand_by(batch_grid_shape)
                    .independent(3))
            
            # transform to the observed image frame
            transform_objects = self.transformer.forward(obj_images, obj_dx, obj_dy, obj_s)
            
            # compose to the observed image frame
            composed_images = self.composer.forward(transform_objects, obj_active)
            
        return composed_images
    

## Test the generator

In [None]:
# Build the model with the default hyper-parameters.
params = CompositeVAEModelParams()
cvae = CompositeVAE(params)

# initialize the decoder with the VAE we trained earlier on synthetic DAPI stains
# NOTE: `torch.load` unpickles the file, so only load checkpoints from a trusted source.
cvae.basic_object_decoder.load_params(torch.load('./synth_dapi_vae_params/dapi_vae_decoder_params.pt'))

# Draw one image from the generator. `model` only reads the leading (batch)
# dimension of its argument, so a dummy tensor with batch size 1 is enough.
composed_images = cvae.model(torch.zeros(1, 1, 1))

plt.figure(figsize=(8, 6))
# Take the first image of the batch; it is stored as (width, height), so transpose
# it to put the height axis first, which is how `imshow` lays out rows.
img = composed_images.detach().numpy()[0, :, :].T
plt.imshow(img, origin='lower', cmap=plt.cm.Greys_r)
plt.gca().set_aspect('equal')
plt.axis('off')