# Deep probabilistic models with applications in astronomy

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import YouTubeVideo
from tqdm.notebook import tqdm
import corner

import torch
import pyro
print(f"Torch version: {torch.__version__}")
print(f"Pyro version: {pyro.__version__}")

%load_ext autoreload
%autoreload 2
from astro_utils import plot_params, plot_lc, plot_lc_folded, featurize_lc, plot_lc_features, make_train_plots

# Pyro basics

Example: Bayesian linear regression 

$$
y_i = w x_i + b +  \epsilon_i
$$

with $N$ observations $(x_i, y_i)$, Gaussian noise, and Gaussian priors for $w$ and $b$

The generative process in this case is 

- Sample $w \sim \mathcal{N}(\mu_w, \sigma_w^2)$
- Sample $b \sim \mathcal{N}(\mu_b, \sigma_b^2)$
- For $i=1,2,\ldots, N$
    - Sample $y_i \sim \mathcal{N}(w x_i + b, \sigma_\epsilon^2)$
    
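where $\mu_w, \sigma_w, \mu_b, \sigma_b, \sigma_\epsilon$ are hyperparameters. (In the Pyro model below we set $\mu_w = \mu_b = 0$ and $\sigma_w = \sigma_b = 1$. Instead of fixing $\sigma_\epsilon$, we give it a $\text{Uniform}(0, 1)$ prior, called `s` in the code, and infer it as well.)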



## Writing a model

To write the model we use the submodules and primitives

- `pyro.distributions` to define prior/likelihoods 
- `pyro.sample` to define random variables (RV): Expects name, distribution and optionally observations
- `pyro.plate` for conditionally independent RV: Expects name, and size

In [None]:
from pyro.distributions import Normal, Uniform

def model(x, y=None):
    """Bayesian linear regression: y ~ Normal(w * x + b, s).

    Priors: w ~ Normal(0, 1), b ~ Normal(0, 1), s ~ Uniform(0, 1).
    When ``y`` is given it is conditioned on; otherwise ``y`` is sampled.
    Returns the value of the ``y`` site (observed or sampled).
    """
    weight = pyro.sample("w", Normal(0.0, 1.0))
    bias = pyro.sample("b", Normal(0.0, 1.0))
    noise_std = pyro.sample("s", Uniform(0.0, 1.0))
    mean = x*weight + bias
    # Observations are conditionally independent given (w, b, s)
    with pyro.plate('dataset', size=len(x)):
        y_site = pyro.sample("y", Normal(mean, noise_std), obs=y)
    return y_site

We can use 

- `pyro.infer.Predictive` 

to draw samples from the model

In [None]:
# Sample the priors and the likelihood jointly (no data, no guide): prior predictive
predictive = pyro.infer.Predictive(model, num_samples=500)

# Grid of inputs on which the model is evaluated
hatx = np.linspace(-6, 6, num=100).astype('float32')
apriori_samples = predictive(torch.from_numpy(hatx))

# Plot samples from the priors
figure = corner.corner(np.stack([apriori_samples[var].detach().numpy()[:, 0] for var in ['b', 'w', 's']]).T,
                       smooth=1., labels=["bias", "weight", "noise_std"], bins=20,
                       quantiles=[0.16, 0.5, 0.84],
                       show_titles=True, title_kwargs={"fontsize": 12})

# Plot prior predictive draws of y given x (one line per prior sample)
y_trace = apriori_samples["y"].detach().numpy()

fig, ax = plt.subplots(figsize=(7, 3), tight_layout=True)
ax.plot(hatx, y_trace.T)
ax.set_xlabel("x")
ax.set_ylabel("y");

## Inference

In the Bayesian setting we want the posterior distribution

$$
p(\theta | \mathcal{D}) = \frac{p(\mathcal{D}|\theta) p(\theta)}{\int p(\mathcal{D}|\theta) p(\theta) \,d\theta}
$$

where $\mathcal{D}$ is our dataset and $\theta = (w, b)$ are the latent variables (in the code below the noise scale $s$ is inferred too)

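For complex models the posterior is intractable, because the normalizing integral in the denominator has no closed form. So we typically use one of two approaches:

- MCMC: Run a Markov chain whose stationary distribution is the posterior, so that (after burn-in) its states behave like samples from the actual posterior: sampling based
- Variational Inference: Choose a simpler family of distributions and find the member closest to the actual posterior: optimization based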



### Variational Inference

Propose an approximate (simple) posterior $q_\phi(\theta)$, e.g. factorized Gaussian

Optimize $\phi$ so that $q_\phi$ approximates $p(\theta|\mathcal{D})$

This is typically done by maximizing a lower bound on the evidence

$$
\text{ELBO}(\phi) = \mathbb{E}_{\theta \sim q_\phi}[ \log p(\mathcal{D}|\theta)] - \text{KL}[q_\phi(\theta) \,\|\, p(\theta)]
$$

- Maximize the expected log-likelihood of the data under $q_\phi$
- Minimize the KL divergence between the approximate posterior and the prior (a regularizer; KL is not a true distance)

Once $q$ has been trained we use it as a replacement for $p(\theta|\mathcal{D})$ to calculate the **posterior predictive distribution**

$$
p(\mathbf{y}|\mathbf{x}, \mathcal{D}) = \int p(\mathbf{y}|\mathbf{x}, \theta) p(\theta| \mathcal{D}) \,d\theta
$$



## VI with Pyro

We use `pyro.infer.SVI` to perform **Stochastic Variational Inference**, which expects

- A generative model
- An approximate posterior (guide)
- Cost function: Typically ELBO
- Optimizer: How to optimize the ELBO, typically gradient descent based

We can use the `pyro.infer.autoguide` to create approximate posteriors from predefined recipes, for example a factorized Gaussian posterior (`AutoDiagonalNormal`)

In [None]:
# Validation checks distribution arguments at every sample site (useful for debugging)
pyro.enable_validation(True)
# Start from an empty parameter store so nothing from a previous run leaks in
pyro.clear_param_store()

# Approximate posterior: a factorized (diagonal) Gaussian over the latent sites of `model`
from pyro.infer.autoguide import AutoDiagonalNormal as approx_posterior
guide = approx_posterior(model)

# Stochastic Variational Inference: optimize the ELBO with clipped Adam
elbo_loss = pyro.infer.Trace_ELBO()
optimizer = pyro.optim.ClippedAdam({"lr": 1e-2})
svi = pyro.infer.SVI(model=model, guide=guide, loss=elbo_loss, optim=optimizer)

Let's consider a dataset with two observations $\mathcal{D} = \{ (-2, -2), (2, 2) \}$

`svi.step(x, y)` performs one gradient step to maximize the ELBO (internally Pyro minimizes the negative ELBO) and returns the current loss

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(10, 2.5), dpi=80, tight_layout=True)

nepochs = 1000
plot_every = 10  # redraw the diagnostics every `plot_every` steps
loss = np.zeros(shape=(nepochs, ))
# params[k, 0, :] = guide locations, params[k, 1, :] = guide scale parameters,
# one column per latent site (3 sites: presumably w, b, s in sampling order).
# NOTE(review): relies on guide.parameters() yielding the location vector first and the
# scale vector second; the scale may be stored unconstrained depending on the Pyro version — confirm.
params = np.zeros(shape=(nepochs, 2, 3))

# Observed data
x = torch.tensor([-2., 2.])
y = torch.tensor([-2., 2.])

for k in tqdm(range(nepochs)):
    loss[k] = svi.step(x, y)

    phi = [param.detach().numpy() for param in guide.parameters()]
    params[k, 0, :] = phi[0]  # Locations
    params[k, 1, :] = phi[1]  # Scales
    if k % plot_every == 0:
        plot_params(ax, k+1, loss, params)
        fig.canvas.draw()

# The loop only redraws every `plot_every` steps, so draw once more to include the final epochs
plot_params(ax, nepochs, loss, params)
fig.canvas.draw()

Again we can use `pyro.infer.Predictive` to draw samples from the model

This time we use the guide to sample $w$ and $b$

In [None]:
# Draw w, b, s from the trained guide and push them through the model to obtain y
predictive = pyro.infer.Predictive(model, guide=guide, num_samples=1000)
posterior_samples = predictive(torch.from_numpy(hatx))

# Corner plot of the approximate posterior of w, b and s
site_names = ['b', 'w', 's']
site_draws = [posterior_samples[name].detach().numpy()[:, 0] for name in site_names]
figure = corner.corner(np.stack(site_draws).T,
                       smooth=1., labels=["bias", "weight", "noise_std"], bins=20,
                       quantiles=[0.16, 0.5, 0.84],
                       show_titles=True, title_kwargs={"fontsize": 12})

# Posterior predictive of y given x: median curve and 5%-95% quantile band
y_trace = posterior_samples["y"].detach().numpy()
med = np.median(y_trace, axis=0)
qua = np.quantile(y_trace, (0.05, 0.95), axis=0)
lower, upper = qua

fig, ax = plt.subplots(figsize=(7, 3), tight_layout=True)
ax.plot(hatx, med)
ax.fill_between(hatx, lower, upper, alpha=0.5);

# Observed points with error bars of +/- 2 x the posterior-median noise std
noise_std_median = posterior_samples['s'].median().item()
ax.errorbar(x.numpy(), y.numpy(), yerr=2*noise_std_median,
            fmt='none', c='k', zorder=100);


# PyTorch basics

Artificial Neuron: Linear regressor plus activation function (e.g. sigmoid -> Logistic regression)

$$
\hat y = g\left(\sum_i w_i x_i + b\right)
$$

Multilayer perceptron: several fully connected layers stacked one after another (below, a single hidden layer with activation $g$)

$$
\begin{align}
f_i &=   b_i + \sum_{j=1}^H w_{ij} h_j  \nonumber \\
&=  b_i + \sum_{j=1}^H w_{ij} g \left( b_j + \sum_{d=1}^D w_{jd} x_d  \right) \nonumber
\end{align}
$$

<img src="images/MLP.png" width="400">

Neural network model in pytorch

- We write a class that inherits from `torch.nn.Module`
- The constructor `__init__(self, args)` defines the layers, e.g. fully connected (`Linear`) or convolutional (`Conv1d`, `Conv2d`)
- The method `forward(self, x)` defines how the layers are connected

In [None]:
class MultiLayerPerceptron(torch.nn.Module):
    """Multilayer perceptron with a single hidden layer.

    Computes ``output_layer(sigmoid(hidden_layer(x)))``.

    Args:
        D: number of input features.
        H: number of hidden units.
        O: number of outputs.
    """

    def __init__(self, D, H, O):
        # Zero-argument super(): `super(type(self), self)` recurses forever
        # as soon as this class is subclassed.
        super().__init__()
        self.hidden_layer = torch.nn.Linear(D, H)  # W x + b
        self.output_layer = torch.nn.Linear(H, O)
        self.activation = torch.nn.Sigmoid()

    def forward(self, x):
        """Map inputs with last dimension D to outputs with last dimension O."""
        z = self.activation(self.hidden_layer(x))
        return self.output_layer(z)

# Variational Autoencoder (VAE) [(Kingma & Welling, 2014)](https://arxiv.org/abs/1312.6114)

Non-linear latent variable model

<img src="images/VAE.png" width="800">






- Latent variables: $z \sim \mathcal{N}(0, I)$
- Observed variable: $x|z \sim \mathcal{N}(\hat \mu, \hat \sigma^2)$

We use a neural network $f_\theta$ to model a non-linear function $\hat \mu (z)$

We want to infer the latent given the observed

$$
p(z|x) = \frac{p(x|z) p(z)}{\int p(x|z) p(z) dz}
$$

But this is not tractable, so we use a variational approximation (a factorized Gaussian)

$$
q_\phi(z|x) = \mathcal{N}\left(z;\, \mu(x), \operatorname{diag}\left(\sigma(x)^2\right)\right)
$$

Samples from $q_\phi$ are drawn with the reparameterization trick

$$
z = \mu(x) + \sigma(x) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)
$$

where a neural network $g_\phi$ is used to model the non-linear functions $\mu(x)$ and $\sigma(x)$

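We amortize the parameters of $q$ using the neural network: a single network outputs the parameters of $q$ for any $x$, so no per-datapoint optimization is needed, which makes inference very efficient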

Note: The VAE is not a fully Bayesian neural network; the weights of the neural networks are point estimates

In [None]:
from pyro.distributions import Normal
import torch.nn as nn

class Encoder(torch.nn.Module):
    """Amortized encoder for q(z|x): maps x to the location and scale of a diagonal Gaussian.

    Architecture: Linear(data_dim -> hidden_dim) + softplus, followed by two linear
    heads; the scale head is passed through softplus to keep it positive.

    Args:
        data_dim: size of the last dimension of x.
        latent_dim: dimension of z.
        hidden_dim: width of the hidden layer.
    """
    def __init__(self, data_dim, latent_dim, hidden_dim):
        # Zero-argument super(): `super(type(self), self)` recurses forever in subclasses.
        super().__init__()
        self.hidden = nn.Linear(data_dim, hidden_dim)
        self.z_loc = nn.Linear(hidden_dim, latent_dim)
        self.z_scale = nn.Linear(hidden_dim, latent_dim)
        self.activation = nn.Softplus()

    def forward(self, x):
        """Return ``(z_loc, z_scale)``, each with last dimension latent_dim."""
        h = self.activation(self.hidden(x))
        return self.z_loc(h), self.activation(self.z_scale(h))
    
class Decoder(torch.nn.Module):
    """Decoder for p(x|z): maps a latent code z to the mean of x.

    Architecture: Linear(latent_dim -> hidden_dim) + softplus, then Linear(hidden_dim -> data_dim).

    Args:
        data_dim: size of the reconstructed x.
        latent_dim: dimension of z.
        hidden_dim: width of the hidden layer.
    """
    def __init__(self, data_dim, latent_dim, hidden_dim):
        # Zero-argument super(): `super(type(self), self)` recurses forever in subclasses.
        super().__init__()
        self.hidden = nn.Linear(latent_dim, hidden_dim)
        self.x_loc = nn.Linear(hidden_dim, data_dim)
        self.activation = nn.Softplus()

    def forward(self, z):
        """Return the location of x given z (last dimension data_dim)."""
        h = self.activation(self.hidden(z))
        return self.x_loc(h)
    
class VariationalAutoEncoder(torch.nn.Module):
    """VAE with a standard-normal prior p(z) and a Gaussian likelihood p(x|z) of fixed scale.

    ``model`` is the generative model (uses the decoder) and ``guide`` is the amortized
    approximate posterior q(z|x) (uses the encoder); both are meant for ``pyro.infer.SVI``.

    Args:
        data_dim: size of the last dimension of x.
        latent_dim: dimension of the latent code z.
        hidden_dim: width of the hidden layers.
        sigmax: standard deviation of the likelihood p(x|z).
    """

    def __init__(self, data_dim, latent_dim, hidden_dim, sigmax):
        super().__init__()
        self.encoder = Encoder(data_dim, latent_dim, hidden_dim)
        self.decoder = Decoder(data_dim, latent_dim, hidden_dim)
        self.latent_dim = latent_dim
        self.data_dim = data_dim
        self.sigmax = sigmax

    def model(self, x):
        """Generative model p(x|z) p(z); returns the decoded location of x."""
        pyro.module("decoder", self.decoder)
        batch_size = x.shape[0]
        latent_shape = (batch_size, self.latent_dim)
        with pyro.plate("data", size=batch_size):
            # Standard-normal prior over each latent vector
            prior = Normal(torch.zeros(latent_shape, device=x.device),
                           torch.ones(latent_shape, device=x.device)).to_event(1)
            z = pyro.sample("latent", prior)
            # Likelihood: x ~ Normal(decoder(z), sigmax), last dim treated as one event
            x_loc = self.decoder.forward(z)
            pyro.sample("observed", Normal(x_loc, self.sigmax).to_event(1), obs=x)
        return x_loc

    def guide(self, x):
        """Amortized variational posterior q(z|x) = Normal(encoder(x))."""
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", size=x.shape[0]):
            loc, scale = self.encoder.forward(x)
            pyro.sample("latent", Normal(loc, scale).to_event(1))

# Astronomical time series



- **Light curve:** Time series of a star's flux (brightness)
- We will consider two optical passbands
- The "apparent" brightness is estimated through **Photometry**
- Main tool for variable star studies

<table><tr><td>
    <img src="images/ZTF.png" width="250">
</td><td>
    <img src="images/intro-sources.png" width="300">
</td></tr></table>

<center>
    <img src="images/intro-sources-time.png" width="600">
</center>

In [None]:
import bz2
import pickle

# NOTE(review): pickle.load can execute arbitrary code — only load lcdata.pbz2 from a trusted source.
with bz2.BZ2File("lcdata.pbz2", 'r') as compressed:
    # The pickle unpacks into three objects: light curves, their periods, and class labels
    lcs, periods, labels = pickle.load(compressed)

# Raw (unfolded) light curve of the first star
fig, ax = plt.subplots(figsize=(7, 3), tight_layout=True)
plot_lc(ax, lcs[0])

- Light curve "features": Irregular sampling, gaps in observations, [heteroscedastic](https://en.wikipedia.org/wiki/Heteroscedasticity) noise
- **Variable stars**: Stars whose brightness changes over time, either regularly or stochastically
- Some variable stars are radial pulsators. They expand/heat and contract/cool regularly. Examples: RR Lyrae and Cepheid

In [None]:
# Embedded YouTube video (id sXJBrRmHPj8); rendering it requires network access
YouTubeVideo('sXJBrRmHPj8')

If we know the pulsation period we can use the **epoch folding transformation** to obtain a phase diagram

<img src="images/folding.png" width="800">

In [None]:
def fold(time, period):
    """Epoch-fold observation times into phases.

    phase = time/period - floor(time/period), computed as ``mod(time, period) / period``.
    For a positive period the result lies in [0, 1) (up to floating-point rounding).
    Works element-wise on numpy arrays.
    """
    remainder = np.remainder(time, period)
    return remainder / period

# Phase diagram of the first light curve, folded with its known period
fig, ax = plt.subplots(figsize=(7, 3), tight_layout=True)
plot_lc_folded(ax, lcs[0], periods[0])

# Training a VAE for periodic light curves

Very simple preprocessing and data preparation routine

In [None]:
# Common phase grid (40 points on [0, 1]) on which every light curve is resampled
pha_interp = np.linspace(0, 1, num=40)
n_lcs, n_bands, n_phases = len(lcs), 2, len(pha_interp)
# Axis 1 holds the two passbands
mag_interp = np.zeros(shape=(n_lcs, n_bands, n_phases))
err_interp = np.zeros(shape=(n_lcs, n_bands, n_phases))

for i_lc, (lc, period) in enumerate(zip(lcs, periods)):
    # Third return value (stats) is not used below
    mag_interp[i_lc], err_interp[i_lc], stats = featurize_lc(lc, period, pha_interp)

# Interpolated features of the first light curve
fig, ax = plt.subplots(figsize=(7, 3), tight_layout=True)
plot_lc_features(ax, pha_interp, mag_interp[0], err_interp[0])

In [None]:
from sklearn.preprocessing import LabelEncoder  
from torch.utils.data import TensorDataset, DataLoader, Subset

# reproducibility: numpy drives the train/validation permutation below,
# torch drives DataLoader shuffling and network initialization
torch.manual_seed(12345);
np.random.seed(12345)

# Class balance: one unit-wide bin centred on each class index.
# Derive the bins from the number of classes instead of hard-coding 5,
# so additional classes are not silently dropped from the histogram.
n_classes = len(np.unique(labels))
fig, ax = plt.subplots(figsize=(7, 2), tight_layout=True)
ax.hist(labels, bins=np.arange(-0.5, n_classes + 0.5, step=1), rwidth=0.8);

le = LabelEncoder()
labels_int = le.fit_transform(labels)
# Create light curve dataset from numpy arrays: (magnitudes, errors)
lc_dataset = TensorDataset(torch.from_numpy(mag_interp.astype('float32')),
                           torch.from_numpy(err_interp.astype('float32'))
                           )

# Generate data loaders: hold out 20% of the light curves for validation
n_valid = len(lcs) // 5
idx = np.random.permutation(len(lcs))
train_loader = DataLoader(dataset=Subset(lc_dataset, idx[n_valid:]),
                          batch_size=32, shuffle=True)

valid_loader = DataLoader(dataset=Subset(lc_dataset, idx[:n_valid]), batch_size=512)

Define a VAE that receives two inputs (the g- and r-band light curves) and generates two outputs (one reconstruction per band)

In [None]:
from pyro.distributions import Normal
import torch.nn as nn

class VariationalAutoEncoder(torch.nn.Module):
    """Two-band VAE for interpolated light curves.

    Prior p(z) = N(0, I). Likelihood: each band's magnitudes are Gaussian around the
    decoder output, with the per-point errors ``errs`` used as standard deviations.

    Args:
        data_dim: number of points per band.
        latent_dim: dimension of the latent code z.
        hidden_dim: width of the hidden layers.
    """

    def __init__(self, data_dim, latent_dim, hidden_dim):
        super().__init__()
        self.encoder = Encoder(data_dim, latent_dim, hidden_dim)
        self.decoder = Decoder(data_dim, latent_dim, hidden_dim)
        self.latent_dim = latent_dim
        self.data_dim = data_dim

    def model(self, mags, errs):
        """Generative model; returns the decoded (g, r) curves.

        ``mags`` and ``errs`` are indexed as [batch, band, point]; band 0 is g, band 1 is r.
        """
        pyro.module("decoder", self.decoder)
        n_batch = mags.shape[0]
        with pyro.plate("minibatch", size=n_batch):
            # Standard-normal prior over each latent vector
            standard_normal = Normal(
                torch.zeros(n_batch, self.latent_dim, device=mags.device),
                torch.ones(n_batch, self.latent_dim, device=mags.device),
            ).to_event(1)
            z = pyro.sample("latent", standard_normal)
            # One Gaussian likelihood per band, heteroscedastic via errs
            loc_g, loc_r = self.decoder.forward(z)
            pyro.sample("observed_g", Normal(loc_g, errs[:, 0, :]).to_event(1), obs=mags[:, 0, :])
            pyro.sample("observed_r", Normal(loc_r, errs[:, 1, :]).to_event(1), obs=mags[:, 1, :])
        return loc_g, loc_r

    def guide(self, mags, errs):
        """Amortized posterior q(z|mags); ``errs`` is unused and only mirrors the model's signature."""
        pyro.module("encoder", self.encoder)
        with pyro.plate("minibatch", size=mags.shape[0]):
            loc, scale = self.encoder.forward(mags)
            pyro.sample("latent", Normal(loc, scale).to_event(1))

In [None]:
class Encoder(torch.nn.Module):
    """Two-branch encoder for q(z|x) with x of shape (batch, 2, data_dim).

    Channel 0 goes through the g branch and channel 1 through the r branch; the two
    hidden representations are concatenated, mixed by a shared layer, and mapped to
    the location and (softplus-positive) scale of a diagonal Gaussian over z.

    Args:
        data_dim: number of points per channel.
        latent_dim: dimension of z.
        hidden_dim: width of each hidden layer.
    """
    def __init__(self, data_dim, latent_dim, hidden_dim):
        # Zero-argument super(): `super(type(self), self)` recurses forever in subclasses.
        super().__init__()
        self.hidden1_g = nn.Linear(data_dim, hidden_dim)
        self.hidden1_r = nn.Linear(data_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim*2, hidden_dim)
        self.z_loc = nn.Linear(hidden_dim, latent_dim)
        self.z_scale = nn.Linear(hidden_dim, latent_dim)
        self.activation = nn.Softplus()

    def forward(self, x):
        """Return ``(z_loc, z_scale)``, each of shape (batch, latent_dim)."""
        hg = self.activation(self.hidden1_g(x[:, 0, :]))
        hr = self.activation(self.hidden1_r(x[:, 1, :]))
        h = torch.cat([hg, hr], dim=1)
        h = self.activation(self.hidden2(h))
        return self.z_loc(h), self.activation(self.z_scale(h))
    
class Decoder(torch.nn.Module):
    """Two-head decoder: maps z to the locations of the g and r curves.

    Two shared softplus hidden layers feed two linear heads (g and r).

    Args:
        data_dim: number of points per band.
        latent_dim: dimension of z.
        hidden_dim: width of each hidden layer.
    """
    def __init__(self, data_dim, latent_dim, hidden_dim):
        # Zero-argument super(): `super(type(self), self)` recurses forever in subclasses.
        super().__init__()
        self.hidden1 = nn.Linear(latent_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.x_locg = nn.Linear(hidden_dim, data_dim)
        self.x_locr = nn.Linear(hidden_dim, data_dim)
        self.activation = nn.Softplus()

    def forward(self, z):
        """Return ``(g_loc, r_loc)``, each with last dimension data_dim."""
        h = self.activation(self.hidden1(z))
        h = self.activation(self.hidden2(h))
        return self.x_locg(h), self.x_locr(h)

Model and SVI object definitions and training loop

In [None]:
pyro.enable_validation(True)
pyro.clear_param_store()

# Use the GPU when available, otherwise fall back to the CPU so the notebook
# also runs on machines without CUDA (the original hard-coded .cuda()).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data_dim must match the length of the interpolation grid used to build the features
vae = VariationalAutoEncoder(data_dim=len(pha_interp), latent_dim=2, hidden_dim=20).to(device)

svi = pyro.infer.SVI(model=vae.model,
                     guide=vae.guide,
                     optim=pyro.optim.Adam({"lr": 1e-3}),
                     loss=pyro.infer.Trace_ELBO())

fig, ax = plt.subplots(1, 2, figsize=(8, 3), tight_layout=True)

nepochs = 100
# Column 0: training loss, column 1: validation loss (both averaged per light curve)
epoch_loss = np.zeros(shape=(nepochs, 2))
# Loop-invariant: move the full set of magnitudes to the device once
all_mags = lc_dataset.tensors[0].to(device)
for nepoch in tqdm(range(nepochs)):

    # Training loop
    for mags, errs in train_loader:
        epoch_loss[nepoch, 0] += svi.step(mags.to(device), errs.to(device))
    epoch_loss[nepoch, 0] /= len(train_loader.dataset)

    # Validation loop (no parameter updates)
    for mags, errs in valid_loader:
        epoch_loss[nepoch, 1] += svi.evaluate_loss(mags.to(device), errs.to(device))
    epoch_loss[nepoch, 1] /= len(valid_loader.dataset)

    # Plot latent space and learning curves.
    # Z columns: encoder location (latent_dim) followed by encoder scale (latent_dim).
    with torch.no_grad():  # plotting only, no autograd graph needed
        Z = torch.cat(vae.encoder(all_mags), dim=1)
    Z = Z.cpu().numpy()
    make_train_plots(ax, nepoch, Z, labels_int, epoch_loss)
    fig.canvas.draw()

## Input and reconstructions

In [None]:
vae = vae.cpu()

n_show = 10  # number of validation light curves displayed
fig, ax = plt.subplots(2, n_show, figsize=(8, 2), tight_layout=True)

mags, errs = next(iter(valid_loader))
# Inference only: no autograd graph is needed for plotting
with torch.no_grad():
    # Reconstruct from the encoder's location output (z_scale is not used here)
    z_loc, z_scale = vae.encoder.forward(mags)
    mug, mur = vae.decoder(z_loc)
for k, (c, mu) in enumerate(zip(['g', 'r'], [mug, mur])):
    for i in range(n_show):
        # Solid: input magnitudes; dashed: reconstruction
        ax[k, i].plot(mags.numpy()[i, k, :], c=c)
        ax[k, i].plot(mu[i].detach().numpy(), ls='--', c=c)
        # Astronomical magnitudes: brighter means smaller, hence the inverted axis
        ax[k, i].invert_yaxis()
        ax[k, i].axis('off')

## Latent space visualization

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6), tight_layout=True)

# One errorbar series per class: position = encoder location, bars = encoder scale.
# LabelEncoder assigns integer k to the k-th entry of le.classes_ (sorted unique labels).
for k, name in enumerate(le.classes_):
    in_class = labels_int == k
    loc_x, loc_y = Z[in_class, 0], Z[in_class, 1]
    scale_x, scale_y = Z[in_class, 2], Z[in_class, 3]
    ax.errorbar(x=loc_x, y=loc_y, xerr=scale_x, yerr=scale_y,
                fmt='none', alpha=0.5, label=name)
ax.legend(loc=4)
ax.set_xlim([-5., 5.])
ax.set_ylim([-5., 5.]);

## Latent space interpolation

In [None]:
M = 10
z_plot = np.linspace(-5., 5., num=M)

fig, ax = plt.subplots(M, M, figsize=(8, 6), tight_layout=True)
# Rows run top to bottom, so walk the second latent coordinate in reverse
# to have it increase upwards; columns follow the first coordinate left to right.
for row, z2 in enumerate(z_plot[::-1]):
    for col, z1 in enumerate(z_plot):
        z = torch.tensor([z1, z2], dtype=torch.float32)
        mug, mur = (band.detach().numpy() for band in vae.decoder.forward(z))
        panel = ax[row, col]
        panel.plot(mug, lw=2, c='g')
        panel.plot(mur, lw=2, c='r')
        panel.invert_yaxis()
        panel.axis('off')

# Interested in Astroinformatics?

<a href="http://alerce.science"><img src="images/alerce.png" width="600"></a>