In [None]:
%%HTML
<style>
/* Presentation tweaks applied to the whole notebook page */
.container {
    width:80% ! important;
}
/* Center rendered figures inside their output area */
.output_png {
    display: table-cell;
    text-align: center;
    vertical-align: middle;
}
.rendered_html { 
    font-size:1.0em; 
}
.rendered_html table{
    width: 80%;
    margin-left:auto; 
    margin-right:auto;
    padding: 20px;
    border: 0px solid black;    
    /* was "#ff", which is not a valid hex colour and was ignored */
    background-color: #fff;
}
.rendered_html td, .rendered_html th 
{
    vertical-align: top;
    text-align: left;
    font-size: 14px;
    /* was "font-face", which is not a CSS property and was ignored */
    font-family: sans-serif;
}
</style>

<center>
<h1> Analysis and Classification of Periodic Variable Stars</h1>
<h2>Pablo Huijse H. (phuijse at inf dot uach dot cl)</h2>
<h3>Universidad Austral de Chile & Millennium Institute of Astrophysics</h3>
</center>

LIVE at https://github.com/phuijse/tutorial_periodic_stars

Thanks to
- The organizers
- The Millennium Institute of Astrophysics
- CONICYT FONDECYT 1170305 and PAI 79170017


# Variable stars

- Stars whose brightness changes over time
- Different reasons behind this

### Pulsating variables
- Some variable stars pulsate radially
- They expand/heat and contract/cool regularly
- Examples: Cepheid and RR Lyrae

<center>
<a href="https://www.youtube.com/watch?v=sXJBrRmHPj8">
    <img src="https://media.giphy.com/media/QP4taxvfVmVEI/giphy.gif" width="400">
</a>
</center>

### Eclipsing Binaries

- System of two stars
- The rotational plane is aligned with us
- From our point of view we see brightness decrease with the mutual eclipses
<center>
<table>
    <tr><td>
        <a href="http://www.physast.uga.edu/~rls/astro1020/ch16/ovhd.html">
            <img src="img/intro-eb.gif" width="300">
        </a>
    </td>
    <td>
        <a href="https://en.wikipedia.org/wiki/File:Algol_AB_movie_imaged_with_the_CHARA_interferometer_-_labeled.gif">
            <img src="https://media.giphy.com/media/aYb0Ob2GHJ280/giphy.gif" width="300">
        </a>
    </td></tr>
</table>
</center>

# Scientific motivation
***

- Variable stars as distance tracers: Milky-way maps
<table>
    <tr><td>   
        <img src="img/period-luminosity-relation.gif" width="400">
    </td><td>
        <img src="img/intro-milky-way.jpg" width="400">
    </td></tr>
</table>


- Variable star analysis and classification: **Astrophysics**
<center>
<a href="http://www.atnf.csiro.au/outreach/education/senior/astrophysics/variable_types.html">
    <img src="img/variable-star-classification.gif" width="400">
</a>
</center>

- New methods to analyze astronomical data: **Signal processing** and **Data Science**
    - Room for interdisciplinary research
    - Astroinformatics and Astrostatistics


# Light curve
***

- Time series of a star's flux (brightness) on a given passband
- The "apparent" brightness is estimated through **Photometry**
- Variable stars are studied through their light curves

<table><tr><td>
    <img src="img/intro-vista.png" width="250">
</td><td>
    <img src="img/intro-sources.png" width="300">
</td></tr></table>

<center>
    <img src="img/intro-sources-time.png" width="600">
</center>

In [None]:
import gzip
import pickle
import numpy as np
%matplotlib notebook
import matplotlib.pylab as plt
from matplotlib import rcParams, animation
# Global plot defaults used by every figure in the notebook
rcParams.update({'font.size': 12, 'axes.grid': True})

# Get some light curves to play
# NOTE(review): pickle can execute arbitrary code while loading; only open
# data files that come from a trusted source.
with gzip.open("data/lc_data.pgz", mode="rb") as f:
    lc_data = pickle.load(f)

# Use a context manager so the file handle is closed deterministically
with open("data/lc_periods.pkl", "rb") as f:
    lc_periods = pickle.load(f)

## Inspecting a light curve

In this case light curves are text files with three columns
- **Modified Julian Date (MJD):** Corresponds to time 
- **Magnitude:** Corresponds to apparent brightness (log scale)
- **Error:** Photometric error estimation of the magnitude

In [None]:
# Raw (unfolded) light curve of entry 6: magnitude against observation date
mjd, mag, err = lc_data[6].T
fig, ax = plt.subplots(figsize=(9, 4), tight_layout=True)
ax.errorbar(mjd, mag, yerr=err, fmt='o')
# Smaller magnitudes are brighter, so flip the axis to put bright on top
ax.invert_yaxis()
ax.set(xlabel='Modified Julian Date (MJD)\n ',
       ylabel='Magnitude\n(The smaller the brighter)');

- Irregular sampling, data gaps
- Heteroscedastic noise: Error variance changes over time

This light curve is actually from a periodic variable star...

In [None]:
def fold(time, period):
    """
    Fold `time` with `period` and return the phase.

    The phase is the fractional part of time/period, i.e.
    time/period - floor(time/period), computed as mod(time, period)/period.
    """
    remainder = np.mod(time, period)
    return remainder/period

# Fold light curve `idx` with its period from lc_periods and draw two cycles
idx = 6
mjd, mag, err = lc_data[idx].T
phi = fold(mjd, lc_periods[idx])
# Repeat the data shifted by one cycle so the shape is easier to read
two_cycles = np.concatenate((phi, phi+1))
fig, ax = plt.subplots(figsize=(9, 4), tight_layout=True)
ax.errorbar(two_cycles, np.tile(mag, 2), yerr=np.tile(err, 2), fmt='o')
ax.invert_yaxis()
ax.set_xlabel('Phase @ Period %0.6f' %(lc_periods[idx]))
ax.set_ylabel('Magnitude\n(The smaller the brighter)');

### Folding the light curve

- Technique used by astronomers to visually inspect periodic variables
- You need a candidate period $P$ to perform the folding
- The time axis is divided into chunks of size $P$, which are plotted on top of each other

$$
\phi = \text{modulo}(\text{MJD}, P)/P
$$
- Then you plot the magnitude as a function of $\phi$ 
    - If $P$ is close to the true period:  Nice periodic shape
    - Otherwise: Noisy pattern



In [None]:
# Animate the folded light curve of entry 6 while sweeping the trial period
# around the value stored in lc_periods. Loading the data here keeps the
# cell independent of variables left over from earlier cells.
mjd, mag, err = lc_data[6].T
fig, ax = plt.subplots(figsize=(9, 4), tight_layout=True)
period_grid = np.linspace(lc_periods[6]-0.001, lc_periods[6]+0.001, num=100)
phi = fold(mjd, period_grid[0])
line, caps, errorbars = ax.errorbar(np.hstack((phi, phi+1)), 
                                    np.hstack((mag, mag)), 
                                    np.hstack((err, err)), fmt='o')
# One vertical segment per plotted point: first cycle, then the shifted copy
segs = errorbars[0].get_segments()
ax.invert_yaxis(); 
ax.set_ylabel('Magnitude\n(The smaller the brighter)');

def update(n):
    """
    Redraw the folded light curve using the n-th period of `period_grid`.

    Moves the markers and their error-bar segments to the new phases,
    updates the x-label with the trial period and returns the modified
    artists.
    """
    phi = fold(mjd, period_grid[n])
    for i in range(len(segs)//2):
        segs[i][:, 0] = phi[i]
        segs[i+len(phi)][:, 0] = phi[i]+1

    line.set_xdata(np.hstack((phi, phi+1)))
    errorbars[0].set_segments(segs)
    ax.set_xlabel('Phase @ Period %0.6f' %(period_grid[n]))
    return line, errorbars[0]

# blit=False: the x-label changes on every frame and lies outside the region
# that blitting redraws; blitting also requires the callback to return artists.
anim = animation.FuncAnimation(fig, update, frames=len(period_grid), interval=100, 
                               repeat=False, blit=False)

### Periodogram

- We want to find the period (fundamental frequency) of the star
- This is generally done using the **Fourier transform** (FT) or correlation
- FT and correlation assume regular time sampling
- Estimating the period with irregular sampling
    1. Least squares: Lomb-Scargle periodogram
    1. ANOVA periodogram
    1. Conditional Entropy and **Mutual Information** periodograms
    1. ....

In [None]:
!pip install P4J --user

In [None]:
import P4J

# Periodogram of light curve 6 computed with P4J's 'QMIEU' method
# (presumably a mutual-information criterion; see the P4J documentation)
my_per = P4J.periodogram(method='QMIEU') 
mjd, mag, err = lc_data[6].T
my_per.set_data(mjd, mag, err, h_KDE_P=0.2)
# Coarse sweep over trial frequencies, then refine around the best optima
my_per.frequency_grid_evaluation(fmin=0.0, fmax=4.0, fresolution=1e-4)
my_per.finetune_best_frequencies(fresolution=1e-5, n_local_optima=10)
freq, per = my_per.get_periodogram()
fbest, pbest  = my_per.get_best_frequencies()

fig, ax = plt.subplots(figsize=(9, 4), tight_layout=True)
ax.plot(freq, per)
# Times are in days (MJD), so frequencies are in cycles per day
ax.set_xlabel('Frequency [1/day]')
ax.set_ylabel('Periodogram')
print("Best period: %f days" %(1.0/fbest[0]))

In [None]:
# Fold the light curve with the best period found by the periodogram
phi = fold(mjd, 1.0/fbest[0])
# Plot two consecutive cycles of the same data
two_cycles = np.concatenate((phi, phi+1))
fig, ax = plt.subplots(figsize=(9, 4), tight_layout=True)
ax.errorbar(two_cycles, np.tile(mag, 2), yerr=np.tile(err, 2), fmt='o')
ax.invert_yaxis()
ax.set_xlabel('Phase @ Period %0.6f' %(1.0/fbest[0]))
ax.set_ylabel('Magnitude\n(The smaller the brighter)');

## Getting features from our periodic light curves

- We want to train a neural network to discriminate a particular type of star: **RR Lyrae** 
- Given that we have the period we train on the folded light curve
- We will normalize and interpolate the folded light curve using kernel regression



In [None]:
def featurize_lc(lc_data, period, phi_interp, sp=0.15): 
    """
    Turn one light curve into a fixed-length, normalized folded profile.

    Parameters
    ----------
    lc_data : array whose transpose unpacks into (mjd, mag, err)
    period : period used to fold the observation times
    phi_interp : phases at which the folded curve is evaluated
    sp : width of the periodic kernel used for smoothing

    Returns
    -------
    mag_interp : smoothed magnitudes, shifted so the smallest magnitude comes
        first and rescaled so that mag_interp +/- err_interp spans [-1, 1]
    err_interp : uncertainty band, rescaled by the same factor
    [max_val, min_val, idx_max] : values needed to undo the normalization
        and the shift (see defeaturize_lc)
    """
    mjd, mag, err = lc_data.T
    phi = np.mod(mjd, period)/period
    # Inverse-variance weights: precise observations count more
    w = 1.0/err**2
    mag_interp = np.zeros_like(phi_interp)
    err_interp = np.zeros_like(phi_interp)
    for k, phase in enumerate(phi_interp):
        # Periodic kernel in phase, combined with the observation weights
        wk = w*np.exp((np.cos(2.0*np.pi*(phase - phi)) -1)/sp**2)
        norm = np.sum(wk)
        mag_interp[k] = np.sum(wk*mag)/norm
        # Weighted scatter of the data around the smoothed value
        err_interp[k] = np.sqrt(np.sum(wk*(mag - mag_interp[k])**2)/norm)
    # Add the typical photometric error as a floor for the band
    err_interp += np.sqrt(np.median(err**2))
    # Align every curve so that its smallest magnitude sits at index 0
    idx_max = np.argmin(mag_interp)
    shift = -idx_max
    mag_interp, err_interp = np.roll(mag_interp, shift), np.roll(err_interp, shift)
    max_val = np.amax(mag_interp + err_interp)
    min_val = np.amin(mag_interp - err_interp)
    span = max_val - min_val
    mag_interp = 2*(mag_interp - min_val)/span - 1
    err_interp = 2*err_interp/span
    return mag_interp, err_interp, [max_val, min_val, idx_max]

def defeaturize_lc(mag, err, norm):
    """
    Undo the normalization and phase shift applied by featurize_lc.

    `norm` holds (max_val, min_val, idx_max) as returned by featurize_lc.
    Returns the magnitude curve and its uncertainty band in original units.
    """
    max_val, min_val = norm[0], norm[1]
    idx_max = int(norm[2])
    mag_out = 0.5*(np.roll(mag, idx_max) +1)*(max_val - min_val) + min_val
    err_out = 0.5*np.roll(err, idx_max)*(max_val - min_val)
    return mag_out, err_out


# Phase grid shared by all interpolated light curves
phi_interp = np.linspace(0, 1, num=40)
n_curves, n_phases = len(lc_data), len(phi_interp)

# Per-curve normalized profile, uncertainty band and normalization constants
features = np.zeros(shape=(n_curves, n_phases))
weights = np.zeros(shape=(n_curves, n_phases))
norm = np.zeros(shape=(n_curves, 3))
for i in range(n_curves):
    features[i, :], weights[i, :], norm[i, :] = featurize_lc(lc_data[i], lc_periods[i], phi_interp)

import torch
from torch.utils.data import TensorDataset, DataLoader

# Each sample of the dataset is the triple (features, weights, period)
as_float32_tensor = lambda array: torch.from_numpy(array.astype('float32'))
lc_dataset = TensorDataset(as_float32_tensor(features), 
                           as_float32_tensor(weights),
                           as_float32_tensor(lc_periods))

In [None]:
from IPython.display import display
from ipywidgets import Button

next_button = Button(description="Next")
idx = 4950
fig, ax = plt.subplots(figsize=(6, 5), tight_layout=True)

def plot_features(idx):
    """
    Draw light curve `idx`: folded observations plus the interpolated curve.

    The interpolated features are mapped back to magnitudes with
    defeaturize_lc so that both layers share the same y-axis.
    """
    ax.cla()
    sample = lc_dataset[idx]
    # Plain float avoids mixing a 0-d torch tensor into numpy arithmetic
    period = float(sample[2])
    ax.set_title("Idx: %d\nPeriod: %0.6f" %(idx, period))
    mag_fit, err_fit = defeaturize_lc(sample[0][:40].numpy(), sample[1][:40].numpy(), norm[idx])
    ax.plot(phi_interp, mag_fit, lw=4)
    ax.fill_between(phi_interp, mag_fit - err_fit, mag_fit + err_fit, alpha=0.5)
    # The curve is de-normalized above, so the axis shows magnitudes
    ax.set_xlabel('Phase'); ax.set_ylabel('Magnitude')
    mjd, mag, err = lc_data[idx].T
    ax.errorbar(fold(mjd, period), mag, err, fmt='.', c='k', alpha=0.5, label='data')
    ax.invert_yaxis()
    # ax.legend (not plt.legend): this figure may no longer be the current
    # one when the button is clicked after later cells created new figures
    ax.legend()

def on_nbutton_clicked(b):
    """Advance to the next light curve, wrapping around at the end."""
    global idx
    idx = (idx + 1) % len(lc_dataset)
    plot_features(idx)
                
next_button.on_click(on_nbutton_clicked)
plot_features(idx)
next_button

# Training a variational autoencoder

- In this part we will train an [autoencoder](https://docs.google.com/presentation/d/1IJ2n8X4w8pvzNLmpJB-ms6-GDHWthfsJTFuyUqHfXg8/edit?usp=sharing) to visualize the feature space 
- We will use [PyTorch](https://pytorch.org/) to create and train the model
- We have light curves with unknown label and 50 light curves labeled as **RR Lyrae**
- Can we find unlabeled light curves that belong to the RR Lyrae class?

In [None]:
from torch.nn import functional as F

def logsumexp(inputs, dim=None, keepdim=True):    
    """
    Numerically stable log(sum(exp(inputs))) along `dim`.

    The previous implementation summed (inputs - log_softmax(inputs)) over
    `dim`; every element of that difference already equals the log-sum-exp,
    so the result was multiplied by the size of `dim`.

    Parameters
    ----------
    inputs : tensor to reduce
    dim : dimension to reduce; None reduces over all elements
    keepdim : keep the reduced dimension(s) with size 1 (default True)
    """
    if dim is None:
        out = torch.logsumexp(inputs.reshape(-1), dim=0)
        if keepdim:
            out = out.reshape((1,)*inputs.dim())
        return out
    return torch.logsumexp(inputs, dim=dim, keepdim=keepdim)

class VAE(torch.nn.Module):
    """
    Variational autoencoder with one hidden layer in encoder and decoder.

    The encoder outputs the mean and log-variance of a diagonal Gaussian
    over the latent variables. The decoder outputs a reconstruction mean and
    a single log-variance shared by all input dimensions. With
    importance_sampling=True the loss is the importance-weighted (IWAE)
    bound instead of the standard VAE bound.
    """
    def __init__(self, n_input=40, n_hidden=64, n_latent=2, importance_sampling=False):
        super(VAE, self).__init__()
        self.importance = importance_sampling
        # Encoder layers
        self.enc_hidden = torch.nn.Linear(n_input, n_hidden)
        self.enc_mu = torch.nn.Linear(n_hidden, n_latent)
        self.enc_logvar = torch.nn.Linear(n_hidden, n_latent)
        # decoder layers
        self.dec_hidden = torch.nn.Linear(n_latent, n_hidden) 
        self.dec_mu = torch.nn.Linear(n_hidden, n_input)
        self.dec_logvar = torch.nn.Linear(n_hidden, 1)

        
    def encode(self, x):
        """Return mean and log-variance of the latent posterior for x."""
        h = F.relu(self.enc_hidden(x))
        return self.enc_mu(h), self.enc_logvar(h)

    def sample(self, mu, logvar, k=1):
        """
        Draw k reparameterized samples per input.

        mu and logvar are (batch, n_latent); the result is
        (batch, k, n_latent).
        """
        batch_size, n_latent = logvar.shape
        std = (0.5*logvar).exp()
        eps = torch.randn(batch_size, k, n_latent, device=std.device, requires_grad=False)
        return eps.mul(std.unsqueeze(1)).add(mu.unsqueeze(1))

    def decode(self, z):
        """Return reconstruction mean and shared log-variance for latents z."""
        h = F.relu(self.dec_hidden(z))
        hatx, hatlogvar = self.dec_mu(h), self.dec_logvar(h)
        return hatx, hatlogvar        

    def forward(self, x, k=1):
        """
        Encode x, draw k latent samples and decode them.

        Returns (dec_mu, enc_mu, enc_logvar, dec_logvar, z).
        """
        enc_mu, enc_logvar = self.encode(x)
        z = self.sample(enc_mu, enc_logvar, k)
        dec_mu, dec_logvar = self.decode(z)
        return dec_mu, enc_mu, enc_logvar, dec_logvar, z
    
    def ELBO(self, x, w, mc_samples=1):  
        """
        Loss to minimize: the negative variational bound summed over the batch.

        Parameters
        ----------
        x : (batch, n_input) inputs
        w : per-point weights; accepted for interface compatibility but not
            used by the current likelihood
        mc_samples : number of latent samples K drawn per input

        Returns
        -------
        loss : scalar tensor, negative bound summed over the batch
        rec : Gaussian log-likelihood (additive constant dropped) summed over
            the batch and averaged over the K samples
        reg : log q(z|x) - log p(z) term summed over the batch (sample
            average for the importance-weighted model, analytic KL otherwise)
        """
        dec_mu, enc_mu, enc_logvar, dec_logvar, z = self.forward(x, mc_samples)
        # log p(x|z) per sample, shape (batch, K); dec_logvar broadcasts over inputs
        logpxz = -0.5*(dec_logvar + (x.unsqueeze(1) - dec_mu).pow(2)/dec_logvar.exp()).sum(dim=-1)    
        
        if self.importance: # Importance-Weighted autoencoder (IWAE)
            # log q(z|x) - log p(z) evaluated at every sample, shape (batch, K)
            logqzxpz = 0.5 * (z.pow(2) - z.sub(enc_mu.unsqueeze(1)).pow(2)/enc_logvar.unsqueeze(1).exp() - enc_logvar.unsqueeze(1)).sum(dim=-1)
            # -log(1/K sum_k w_k) with log w_k = log p(x|z_k) - (log q - log p(z));
            # the sign must sit outside the logsumexp
            log_k = float(np.log(mc_samples))
            loss = (log_k - torch.logsumexp(logpxz - logqzxpz, dim=1)).sum()
        else:  # Variational autoencoder
            # Analytic KL(q(z|x) || N(0, I)), shape (batch, 1)
            logqzxpz = -0.5 * (1.0 + enc_logvar - enc_mu.pow(2) - enc_logvar.exp()).sum(dim=-1).unsqueeze(1)
            # Standard bound: average the reconstruction term over the samples
            loss = (logqzxpz - logpxz).mean(dim=1).sum()
        return loss, logpxz.sum()/mc_samples, logqzxpz.sum()/logqzxpz.shape[1]
 

In [None]:
from vae import live_metric_plotter

batch_size_, nepochs, mc_samples = 8, 50, 32
n_train = 3000  # remaining samples are used for validation
torch.manual_seed(0)
np.random.seed(0)
# Random train/validation split over the whole dataset
P = np.random.permutation(len(lc_dataset))
train_idx, valid_idx = P[:n_train], P[n_train:]
train_loader = DataLoader(dataset=lc_dataset, batch_size=batch_size_, 
                          sampler=torch.utils.data.SubsetRandomSampler(train_idx))
valid_loader = DataLoader(dataset=lc_dataset, batch_size=batch_size_, 
                          sampler=torch.utils.data.SubsetRandomSampler(valid_idx))

model = VAE(n_input=40, n_hidden=20, n_latent=2, importance_sampling=True)
print(model)
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of trainable parameters: %d" %(n_trainable))

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
plotter = live_metric_plotter(figsize=(9, 4))
# metrics[epoch, split, term]: split 0=train / 1=validation,
# term 0=second / 1=third value returned by model.ELBO, averaged per sample
metrics = np.zeros(shape=(nepochs, 2, 2))

for epoch in range(nepochs):
    # Train 
    for feature, weight, period in train_loader:
        optimizer.zero_grad()        
        loss, rec_loss, reg_loss = model.ELBO(feature, weight, mc_samples)        
        loss.backward()
        optimizer.step()
        # Normalize by the subset size: len(loader.dataset) is the full
        # dataset, which would scale train and validation curves differently
        metrics[epoch, 0, 0] += rec_loss.item()/len(train_idx)
        metrics[epoch, 0, 1] += reg_loss.item()/len(train_idx)
    # Test (no gradients needed, so skip building the autograd graph)
    with torch.no_grad():
        for feature, weight, period in valid_loader:
            loss, rec_loss, reg_loss = model.ELBO(feature, weight, mc_samples)
            metrics[epoch, 1, 0] += rec_loss.item()/len(valid_idx)
            metrics[epoch, 1, 1] += reg_loss.item()/len(valid_idx)
    
    if epoch > 0:
        plotter.update(epoch, metrics)

In [None]:
# Encode and decode the whole dataset once; gradients are not needed here
with torch.no_grad():
    test_dec_mu, test_enc_mu, test_enc_logvar, test_dec_logvar, test_z = model(lc_dataset.tensors[0], k=10)
test_enc_mu, test_enc_sigma = test_enc_mu.numpy(), (test_enc_logvar*0.5).exp().numpy()
test_dec_mu = test_dec_mu.numpy()

# The figure shows each light curve in latent space as a dot (mean of the
# variational posterior) with errorbars (standard deviation of the
# variational posterior). Each point is a distribution!
fig = plt.figure(figsize=(10, 5), dpi=80)
ax_main = plt.subplot2grid((2, 3), (0, 0), colspan=2, rowspan=2)
ax_ori = plt.subplot2grid((2, 3), (0, 2))
ax_rec = plt.subplot2grid((2, 3), (1, 2))
_, _, err_lines = ax_main.errorbar(x=test_enc_mu[:, 0], y=test_enc_mu[:, 1], 
                                   xerr=test_enc_sigma[:, 0], yerr=test_enc_sigma[:, 1], 
                                   fmt='none', alpha=0.2, zorder=-1)
# The last 50 light curves are the ones labeled as RR Lyrae
n_labeled = 50
labels = np.zeros(shape=(len(lc_dataset),)); labels[-n_labeled:] = 1

sc = ax_main.scatter(test_enc_mu[:, 0], test_enc_mu[:, 1], s=2, alpha=0.2, 
                     c=labels, cmap=plt.cm.RdBu_r)
clb = plt.colorbar(sc, ax=ax_main)
# Color the error bars like the dots. Use the scatter's own mapping:
# Colorbar objects are not ScalarMappables in Matplotlib >= 3.1
for bar_collection in err_lines:
    bar_collection.set_color(sc.to_rgba(labels))
    
    
c_lim, r_lim = ax_main.get_xlim(), ax_main.get_ylim()
fig.tight_layout()
phi_interp = np.linspace(0, 1, num=40)
def onclick(event):
    """
    Show the light curve whose latent mean is closest to the clicked point.

    The top panel shows the folded data with the interpolated curve in
    magnitudes; the bottom panel shows the normalized features together with
    the mean and spread of the decoder reconstructions.
    """
    # Ignore clicks outside the latent-space panel (xdata/ydata are None
    # outside any axes, and the side panels use different coordinates)
    if event.inaxes is not ax_main or event.xdata is None or event.ydata is None:
        return
    z_closest = np.array([event.xdata, event.ydata])
    idx = np.argmin(np.sum((test_enc_mu[:, :2] - z_closest)**2, axis=1))
    ax_ori.cla(); ax_ori.set_title("Idx:%d, Label:%d" %(idx, labels[idx]))
    mjd, mag, err = lc_data[idx].T
    # Plain float avoids mixing a 0-d torch tensor into numpy arithmetic
    phi = fold(mjd, float(lc_dataset.tensors[2][idx]))
    ax_ori.errorbar(phi, mag, err, c='k', fmt='.')
    mag, err = defeaturize_lc(lc_dataset.tensors[0][idx].numpy(), 
                              lc_dataset.tensors[1][idx].numpy(), norm[idx])
    ax_ori.plot(phi_interp, mag, lw=2)
    ax_ori.fill_between(phi_interp, mag - err, mag + err, alpha=0.5)
    ax_ori.invert_yaxis(); 
    ax_rec.cla(); ax_rec.invert_yaxis()
    mag, err = lc_dataset.tensors[0][idx].numpy(), lc_dataset.tensors[1][idx].numpy()
    ax_rec.plot(phi_interp, mag, lw=2)
    ax_rec.fill_between(phi_interp, mag - err, mag + err, alpha=0.5)
    # Mean and spread of the reconstructions over the k latent samples
    mu_dec = np.mean(test_dec_mu[idx], axis=0)
    s_dec = np.std(test_dec_mu[idx], axis=0)
    ax_rec.plot(phi_interp, mu_dec, c='r', lw=2)
    ax_rec.fill_between(phi_interp, mu_dec-2*s_dec, mu_dec+2*s_dec, facecolor='r', alpha=0.5)
    fig.canvas.draw_idle()
    
cid = fig.canvas.mpl_connect('button_press_event', onclick);