In [1]:
import os, sys 
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.distributions.normal import Normal
from model import *

# # --- plotting --- 
%matplotlib inline
import matplotlib.pyplot as plt

from accelerate import Accelerator

In [2]:
device = torch.device(type='cuda', index=0)
#save_data = "/scratch/network/melchoir/sdss_spectra.38816.npz"
save_data = "/scratch/gpfs/yanliang/sdss-spectra/sdss_spectra.38816.npz"
data = load_data(save_data, which="test", device=device)

Loading 5761 spectra (which = test)


RuntimeError: No CUDA GPUs are available

In [None]:
# plot raw spectra 
fig = plt.figure(figsize=(12,5))
sub = fig.add_subplot(111)
for i in np.random.randint(0, data['N'], 1): 
    sub.plot(data['wave'], data['y'][i].cpu(), c='k')
    
    # show masked regions
    ylim = sub.get_ylim()
    sel = data['w'][i].cpu() == 0
    t = np.zeros(len(data['wave']))
    t[sel] = ylim[1]
    sub.fill_between(data['wave'], t, facecolor='k', alpha=0.3)
    
sub.set_xlim(3800,9200);
sub.set_xlabel('wavelength');
sub.set_ylabel('normalized flux');
sub.set_ylim(0,ylim[1]); 

In [None]:
label, n_config = "model.series.z-forward.10.sdss.38816", 10
models, losses, best_model = load_models(label, n_config)
print("best model:", best_model)

In [None]:
emissionlines = {
1033.82: "O VI",
1215.24: "Lyα",
1240.81: "N V",
1305.53: "O I",
1335.31: "C II",
1397.61: "Si IV",
1399.8: "Si IV + O IV",
1549.48: "C IV",
1640.4: "He II",
1665.85: "O III",
1857.4: "Al III",
1908.734: "C III",
2326.0: "C II",
2439.5: "Ne IV",
2799.117: "Mg II",
3346.79: "Ne V",
3426.85: "Ne VI",
3727.092: "", #"O II",
3729.875: "O II",
3889.0: "He I",
4072.3: "S II",
4102.89: "Hδ",
4341.68: "Hγ",
4364.436: "O III",
4862.68: "Hβ",
4932.603: "O III",
4960.295: "O III",
5008.240: "O III",
6302.046: "O I",
6365.536: "O I",
6529.03: "N I",
6549.86: "N II",
6564.61: "Hα",
6585.27: "N II",
6718.29: "S II",
6732.67: "S II",
}

absorptionlines = {
3934.777: "K",
3969.588: "H",
4305.61: "G",
5176.7: "Mg",
5895.6: "Na",
# 6496.9: "Ba II",
8500.36: "Ca II",
8544.44: "Ca II",
8664.52: "Ca II",
}

skylines = [5578.5, 5894.6, 6301.7, 7246.0]
lines = {**emissionlines, **absorptionlines}

In [None]:
import matplotlib
from operator import itemgetter
from itertools import groupby

matplotlib.rcParams.update({'font.size': 14})
matplotlib.rcParams.update({'font.sans-serif': 'DejaVu Sans'})
matplotlib.rcParams.update({'font.family': 'sans-serif'})

def find_edges(y):
    ranges = []
    for k,g in groupby(enumerate(y),lambda x:x[0]-x[1]):
        group = (map(itemgetter(1),g))
        group = list(map(int,group))
        if len(group) > 1:
            ranges.append((group[0], group[-1]))
        else:
            ranges.append((group[0], group[0]))
    return ranges

def plot_spec(model, y, z):
    
    plt.figure()
    plt.plot(model.wave_obs.cpu(), y.cpu().T)
    for line in skylines:
        plt.axvline(line, color='#888888', lw=0.5, zorder=-1)
    plt.xlabel('Observed $\lambda$ [Å]')

    with torch.no_grad():
        s, spectrum_restframe, spectrum_observed = model._forward(y, z0=z)
        
    plt.figure()
    plt.plot(model.wave_obs.cpu(), spectrum_observed.cpu().T);
    plt.xlabel('Observed $\lambda$ [Å]')

    plt.figure()
    plt.plot(model.wave_rest.cpu(), spectrum_restframe.cpu().T);
    
    xlims = plt.gca().get_xlim()
    for line,label in lines.items():
        if xlims[0] < line < xlims[1]:
            plt.axvline(line, color='#888888', lw=0.5, zorder=-1)
    plt.xlabel('Restframe $\lambda$ [Å]')
    

def plot_spec_zoom(model, y, z, id, w=None):
    
    n_test = len(y)    
    fig, axes = plt.subplots(n_test, 5, sharex="col", sharey="row", figsize=(5*5, n_test*5))
    
    with torch.no_grad():
        _, spec_rest, spec_recon = model._forward(y, z0=z)
        spec_rest = spec_rest.cpu()
        spec_recon = spec_recon.cpu()
        
    def plot_panel(sub, wave_rest, spec, wave_model, spec_model, w=None, xlims=None, plot_lines=True, label_lines=False):
        sub.plot(wave_rest, spec, drawstyle='steps-mid', c='k', label='Data', lw=2) 
        ylims = sub.get_ylim()
        sub.plot(wave_model, spec_model, drawstyle='steps-mid', c='tab:red', label='Model', lw=2, zorder=5) 
            
        # lines
        if plot_lines:
            for line,label in lines.items():
                if xlims[0] < line < xlims[1]:
                    sub.axvline(line, ymax=1.05, color='#888888', lw=0.5, zorder=-1, clip_on=False)
                    if label_lines:
                        sub.text(line, 1.08 * ylims[1], label, ha='center', va='bottom', rotation=90, color='#888888')                    

        # weights: only show in overview panel because horizontal location is mildly off
        if w is not None:
            cmap = matplotlib.cm.get_cmap('BuPu')
            max_color = 0.35
            if xlims is None:
                colors = (np.sqrt(w.max()) - np.sqrt(w))/np.sqrt(w.max()) * max_color
                sub.imshow(colors.reshape(1,-1), aspect='auto', extent=(wave_rest.min(), wave_rest.max(), 0, ylims[1]), cmap=cmap, vmax=1, zorder=-2, rasterized=True)
                
                # show unobserved region as masked
                sub.axvspan(wave_model.min(), wave_rest.min(), color=cmap(max_color), zorder=-2)
                sub.axvspan(wave_rest.max(), wave_model.max(), color=cmap(max_color), zorder=-2)
                
            else:
                # show unobserved region as masked
                if xlims[0] < wave_rest.min():
                    sub.axvspan(wave_model.min(), wave_rest.min(), color=cmap(max_color), zorder=-2)
                if xlims[1] > wave_rest.max():
                    sub.axvspan(wave_rest.max(), wave_model.max(), color=cmap(max_color), zorder=-2)
                
                # find masked regions in w
                sel = np.arange(len(w))[w <= 1e-6]
                ranges = find_edges(sel)
                for r in ranges:
                    sub.axvspan(wave_rest[r[0]], max(wave_rest[r[1]], wave_rest[r[0]+1]), color=cmap(max_color), zorder=-2)
            
        if xlims is not None:
            sub.set_xlim(*xlims)
        sub.set_ylim(ymin=0, ymax=ylims[1])
        
    for i in range(n_test):
        wave_rest = (model.wave_obs/(1+z[i])).cpu()
        wave_model = model.wave_rest.cpu()
        spec_ = y[i].cpu()
        z_ = z[i].cpu()
        w_ = w[i].cpu()
        id_ = id[i]
        spec_recon_ = spec_recon[i]
        spec_rest_ = spec_rest[i]
        
        label_lines = i == 0
        sub = axes[i][0]
        xlims = None
        plot_panel(sub, wave_rest, spec_, wave_model, spec_rest_, w=w_, xlims=xlims, plot_lines=False, label_lines=label_lines)
        sub.set_ylabel('normalized flux') 
        sub.legend(loc='best', frameon=False, title='ID {}\n $z={:2f}$'.format(id_, z_))
        sub.set_xlim(3000, 9200)
        
        # OII
        lmbda, delta = 3727, 50
        xlims = (lmbda - delta, lmbda + delta)
        plot_panel(axes[i][1], wave_rest, spec_, wave_model, spec_rest_, w=w_, xlims=xlims, label_lines=label_lines)

#         # only for paper
#         if i == 0:
#             sub = axes[i][1]
#             sub.text(0.45, 0.95, 'unobserved', c='w', ha='right', va='top', transform=sub.transAxes)
#             sub.text(0.65, 0.95, 'observed', c='k', ha='left', va='top', transform=sub.transAxes)
        
        # Hbeta, OIII
        lmbda, delta = 4932, 100
        xlims = (lmbda - delta, lmbda + delta)
        plot_panel(axes[i][2], wave_rest, spec_, wave_model, spec_rest_, w=w_, xlims=xlims, label_lines=label_lines)
        
        # Na
        lmbda, delta = 5895, 100
        xlims = (lmbda - delta, lmbda + delta)
        plot_panel(axes[i][3], wave_rest, spec_, wave_model, spec_rest_, w=w_, xlims=xlims, label_lines=label_lines)
        
        lmbda, delta = 6564, 100
        xlims = (6450, 6650)
        plot_panel(axes[i][4], wave_rest, spec_, wave_model, spec_rest_, w=w_, xlims=xlims, label_lines=label_lines)

    for sub in axes[-1]:
        sub.set_xlabel('Restframe $\lambda$ [Å]')
    fig.tight_layout()
    fig.subplots_adjust(hspace=0.05, wspace=0.05)
    return fig

In [None]:
indices = [4029, 
           3592, 
           2480, 
           1102, 
           85]
fig = plot_spec_zoom(models[0], data['y'][indices], data['z'][indices], data['id'][indices], w=data['w'][indices])
# fig.savefig('examples.pdf')

In [None]:
n_test = 5
indices = np.random.randint(0, data['N'], n_test)
plot_spec(models[0], data['y'][indices], data['z'][indices])

## Embedding the latents

In [None]:
model = models[0]

batch_size=1024
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(data['y'],),
    batch_size=batch_size,
    shuffle=False)

s = None
for y_, in loader:
    with torch.no_grad():
        s_ = model.encode(y_)
        if s is None:
            s = s_.cpu().numpy()
        else:
            s = np.concatenate((s, s_.cpu().numpy()), axis=0)

In [None]:
import corner
corner.corner(s);

In [None]:
import umap

reducer = umap.UMAP()
embedding = reducer.fit_transform(s)

In [None]:
matplotlib.rcParams.update({'font.size': 10})

fig, axes = plt.subplots(1, 3, sharey="row", figsize=(15,5))
sc = axes[0].scatter(embedding[:,0], embedding[:,1], s=3, c=data['z'].cpu(), rasterized=True)
axes[0].set_xticks([])
axes[0].set_yticks([])
cb = plt.colorbar(sc, ax=axes[0], orientation='horizontal', aspect=50, fraction=0.05, pad=0.02)
cb.set_label('$z$')

sc = axes[1].scatter(embedding[:,0], embedding[:,1], s=3, c=np.log10(data['norm']), rasterized=True)
axes[1].set_xticks([])
axes[1].set_yticks([])
cb = plt.colorbar(sc, ax=axes[1], orientation='horizontal', aspect=50, fraction=0.05, pad=0.02)
cb.set_label('$\log_{10}(\mathrm{median\ flux})$')

def load_n_embed(label):
    # wave, y, w, z, zerr, norm, id
    test_data = load_data(
        f'/scratch/network/melchoir/sdss_spectra.{label}.npz', 
         device=device)

    s_ = model.encode(test_data['y'])
    s_ = s_.cpu().detach().numpy()
    # sort by redshift for better visual distinction
    z = test_data['z'].cpu().numpy()
    order = np.argsort(z)
    embed = reducer.transform(s_)
    return embed[order], z[order]

axes[2].scatter(embedding[:,0], embedding[:,1], s=5, c='#e0e0e0', rasterized=True)

from matplotlib.lines import Line2D
legend_elements = []
for label,marker,vmax,cmap in zip(['starburst', 'agn_broad', 'starforming'], ['v', 's', 'o'], [0.15, 0.5, 0.2], ['Blues', 'YlOrBr', 'Reds']):
    embedding_, z_ = load_n_embed(label)
    cmap = matplotlib.cm.get_cmap(cmap)
    sc = axes[2].scatter(embedding_[:,0], embedding_[:,1], s=40, c=z_, cmap=cmap, vmin=0, vmax=vmax, marker=marker, label=label)
    legend_elements.append(Line2D([0], [0], marker=marker, color=cmap(150), label=label, lw=0))
axes[2].legend(handles=legend_elements, loc='upper left', frameon=False)

cb = plt.colorbar(sc, ax=axes[2], orientation='horizontal', ticks=[0.0, 0.1, 0.2], aspect=50, fraction=0.05, pad=0.02)
cb.set_label('$z$')
#cb.set_ticks([[0,0.1,0.2]])
axes[2].set_xticks([])
axes[2].set_yticks([])

fig.tight_layout()
fig.subplots_adjust(wspace=0.02, hspace=0)
# fig.savefig('umap.pdf')

## Look at specific subpopulations

In [None]:
# wave, y, w, z, zerr, norm, id
test_data = load_data(
    '/scratch/network/melchoir/sdss_spectra.agn_broad.npz', 
     device=device)

s_ = model.encode(test_data['y'])
s_ = s_.cpu().detach().numpy()

embedding_ = reducer.transform(s_)
sc = plt.scatter(embedding[:,0], embedding[:,1], s=2, c='#aaaaaa')
plt.scatter(embedding_[:,0], embedding_[:,1], s=10, c=test_data['z'].cpu().numpy())
cb = plt.colorbar(sc)
cb.set_label('$z$')

plt.grid()
# plt.xlim(10,15)
# plt.ylim(4,6)

In [None]:
n_test = 5
indices = np.random.randint(0, test_data['N'], n_test)
# indices = indices[:5]
fig = plot_spec_zoom(model, test_data['y'][indices], test_data['z'][indices], test_data['id'][indices], w=test_data['w'][indices])

## Redshift estimation (experimental!!)

In [None]:
def solve_z(self, x, w, z_guess, n_epoch=10, lr=1e-3):
    model.eval()
    accelerator = Accelerator()
    z = torch.tensor(z_guess.clone().detach(), requires_grad=True)
    print(z)
    optimizer = optim.Adam([z,], lr=lr)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, lr, total_steps=n_epoch)
    self, optimizer, z, x, w = accelerator.prepare(self, optimizer, z, x, w)
    
    for epoch in range(n_epoch):
        optimizer.zero_grad()
        s, spectrum_restframe, spectrum_observed = self._forward(x, z0=z)
        loss = self._loss(x, w, spectrum_observed)
        accelerator.backward(loss)
        optimizer.step()        
        scheduler.step()
    return z


In [None]:
s_ = model.forward(y_test)

In [None]:
z__ = solve_z(model, y_test[:1000], w_test[:1000], z_test[:1000], n_epoch=10, lr=1e-3)

In [None]:
plt.scatter(z_test.cpu(), z_.cpu(), s=1)
plt.gca().set_aspect('equal')

In [None]:
sc = plt.scatter(z_test[:1000].cpu(), z__.cpu(), s=2, alpha=0.5, c=np.log10(data['norms'][test_slice]))
cb = plt.colorbar(sc)
plt.gca().set_aspect('equal')