# Spender Diagnostics Plots

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
from accelerate import Accelerator

# Accelerator from the `accelerate` package, configured for fp16 mixed precision.
# Later cells use it for the target device (accelerator.device) and to wrap
# the data loaders and the model (accelerator.prepare).
accelerator = Accelerator(mixed_precision='fp16')

## Option 1: Load model explicitly (from local file or hub server, requires spender install)

In [None]:
import spender
from spender.data.sdss import SDSS

# The instrument is the second argument of load_model, so it has to exist
# before the model is loaded.
sdss = SDSS()

url = "https://hub.pmelchior.net/spender.sdss.paperII-c273bb69.pt"
model = spender.load_model(url, sdss, map_location=torch.device('cpu'))

## Option 2: Load model with hub (requires no prior spender install)

In [None]:
import torch.hub

# GitHub repository (owner/name) queried through torch.hub
github = "pmelchior/spender"
# list the entry points that repository exposes to torch.hub
torch.hub.list(github)

In [None]:
# print the help text of the 'sdss_II' hub entry point
print(torch.hub.help(github, 'sdss_II'))

In [None]:
# The 'sdss_II' entry point returns the instrument and the model as a pair.
# map_location is forwarded to the entry point (presumably so the weights end up
# on the accelerator device -- confirm against the hub entry point).
sdss, model = torch.hub.load(github, 'sdss_II', map_location=accelerator.device)

## Load data (requires local files)

In [None]:
# directory with the local data files
data_path = './DATA'
batch_size = 128
# test-set loader provided by the instrument, then wrapped by the accelerator
dataloader = sdss.get_data_loader(data_path, tag="variable", which="test", batch_size=batch_size)
dataloader = accelerator.prepare(dataloader)
# only the first batch is used by the plotting cells below
batch = next(iter(dataloader))
# a batch has six fields; `ids` entries are unpacked as (plate, mjd, fiberid) further down
# presumably: spectra, weights, redshifts, IDs, normalizations, redshift errors -- confirm with the loader
spec, w, z, ids, norm, zerr = batch

## Single spectrum overview plot

In [None]:
import matplotlib
# font settings shared by every figure in this notebook
matplotlib.rcParams.update({
    'font.size': 14,
    'font.sans-serif': 'DejaVu Sans',
    'font.family': 'sans-serif',
})

def plot_spec(spec, w, wave, label=None, ax=None):
    """Plot a single spectrum, with masked stretches drawn in light gray.

    Parameters
    ----------
    spec : 1D tensor
        Flux values.
    w : 1D tensor
        Weights, same length as `spec`. Pixels with ``w > 0`` are drawn in
        black, all others in light gray.
    wave : 1D tensor
        Wavelength of every pixel; also sets the x-axis limits.
    label : str, optional
        Text placed in the upper-left corner of the axes.
    ax : matplotlib axes, optional
        Axes to draw into. A new 10 x 4.5 inch figure is created if omitted.

    Returns
    -------
    fig, ax
        The figure and the axes that were drawn into.
    """
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(10, 4.5))
    else:
        fig = ax.get_figure()

    # Split the spectrum into runs of constant mask state. change[i] is True
    # when pixels i and i+1 differ, so a new run starts at index i+1. The
    # final boundary is the length of the spectrum itself, which keeps this
    # function independent of any global model and includes the last pixel.
    valid = w > 0
    change = valid[1:] != valid[:-1]
    boundaries = (change.nonzero().flatten() + 1).tolist() + [len(wave)]

    start = 0
    label_set = False
    for stop in boundaries:
        if stop == start:
            continue
        unmasked = bool(valid[start])
        color = "k" if unmasked else "#ccc"
        # only the first unmasked run carries the legend entry
        if unmasked and not label_set:
            label_, label_set = "Data", True
        else:
            label_ = None
        ax.plot(wave[start:stop], spec[start:stop].detach(), c=color, label=label_, zorder=1)
        start = stop

    if label is not None:
        ax.text(0.05, 0.95, label, ha='left', va='top', transform=ax.transAxes)

    ax.set_xlabel('Wavelength [Å]')
    ax.set_ylabel('Normalized Flux')
    ax.set_xlim(wave[0], wave[-1])
    # y-range is set by the brightest unmasked pixel
    ylim = [0, 1.1 * torch.max(valid * spec).item()]
    ax.set_ylim(*ylim)

    return fig, ax

def plot_spec_model(model, spec, w, z, id, restframe=False):
    """Plot one spectrum with the model reconstruction and the weighted residuals.

    Parameters
    ----------
    model : object providing `_forward`, `wave_obs` and `wave_rest`
    spec : 1D tensor
        Observed spectrum.
    w : 1D tensor
        Weights of the spectrum; the residual panel shows
        ``(spec - reconstruction) * sqrt(w)``.
    z : scalar tensor
        Redshift of the spectrum.
    id : sequence of three items
        Unpacked as (plate, mjd, fiberid) for the label.
    restframe : bool, optional
        If True, the x-axis is ``wave_obs / (1 + z)`` and the restframe model
        is overplotted; otherwise the observed-frame reconstruction is shown.

    Returns
    -------
    fig, axes
        The figure and its two axes (spectrum on top, residuals below).
    """
    with torch.no_grad():
        _, spec_rest, spec_reco = model._forward(spec.unsqueeze(0), z=z.unsqueeze(0))
    spec_rest = spec_rest.squeeze(0).cpu()
    spec_reco = spec_reco.squeeze(0).cpu()
    spec, w = spec.cpu(), w.cpu()

    fig, axes = plt.subplots(2, 1,
                             gridspec_kw={'height_ratios': (3,1), 'hspace': 0, 'wspace':0 },
                             figsize=(10,6))

    # rest or observed frame
    if restframe:
        spec_model = spec_rest
        wave = (model.wave_obs / (1 + z)).cpu()
        wave_model = model.wave_rest.cpu()
        frame = "Restframe"
    else:
        spec_model = spec_reco
        wave = wave_model = model.wave_obs.cpu()
        frame = "Observed"

    # add ID
    # NOTE(review): `:2f` is a minimum width of 2 with the default six decimals;
    # `.2f` may have been intended -- left unchanged.
    plate, mjd, fiberid = id
    label = f'ID {plate}-{mjd}-{fiberid}\n $z={z:2f}$'
    plot_spec(spec, w, wave, label=label, ax=axes[0])

    # add model
    axes[0].plot(wave_model, spec_model, c='tab:red', label='Reconstruction', zorder=10)
    axes[0].get_xaxis().set_visible(False)
    axes[0].legend(loc='upper right', frameon=False)
    axes[0].set_xlim(wave[0], wave[-1])

    # residuals are always against the observed-frame reconstruction, because
    # it shares the pixel grid of the data
    axes[1].plot(wave, (spec - spec_reco) * w.sqrt(), c='k', drawstyle='steps-mid')
    axes[1].set_ylabel(r'Residuals [$\sigma$]')
    axes[1].set_xlabel(f'{frame} Wavelength [Å]')
    axes[1].set_xlim(wave[0], wave[-1])

    # returned for consistency with the other plotting helpers in this notebook
    return fig, axes


# overview plot (observed frame) for the first spectrum of the batch
idx = 0
plot_spec_model(model, spec[idx], w[idx], z[idx], ids[idx])

## Overview plus zoom-in plots

In [None]:
from spender.data.emission_lines import *
from operator import itemgetter
from itertools import groupby

def find_edges(y):
    """Collapse a sequence of indices into runs of consecutive values.

    Returns a list of (first, last) integer pairs, one per run; an isolated
    value yields a pair with first == last.
    """
    edges = []
    # consecutive values share the same (position - value) offset
    for _, run in groupby(enumerate(y), key=lambda pair: pair[0] - pair[1]):
        members = [int(value) for _, value in run]
        edges.append((members[0], members[-1]))
    return edges

def plot_spec_zoom(model, y, z, w=None, ids=None, axes=None, color="tab:red", labels=('Data', 'Model')):
    """Plot a restframe overview plus four zoom-in panels for each spectrum.

    Every spectrum gets one row of five panels: the full restframe range,
    followed by 180 Å wide windows centred on 3727, 4940, 5895 and 6564 Å.

    Parameters
    ----------
    model : object providing `_forward`, `wave_obs` and `decoder.wave_rest`
    y : tensor, one spectrum per row
    z : tensor with one redshift per spectrum
    w : tensor, optional
        Weights per spectrum. Pixels with zero weight are shaded; if omitted,
        no shading is drawn and the y-range uses all pixels.
    ids : sequence, optional
        One (plate, mjd, fiberid) triple per spectrum, used for the row title.
    axes : 2D collection of axes, optional
        Existing axes (one row of five per spectrum) to draw into. When given,
        only the model curve is added; data and titles are assumed to be there.
    color : matplotlib colour of the model curve
    labels : pair of legend labels for (data, model)

    Returns
    -------
    fig, axes
    """
    n_test = len(y)
    inch = 5
    n_cols = 5  # overview + four zoom windows
    if axes is None:
        fig, axes = plt.subplots(n_test, n_cols, sharex="col", sharey="row", figsize=(n_cols*inch, n_test*inch))
        if n_test == 1: # make sure we have a doubly indexed set of axes
            axes = [axes,]
        reuse_fig = False
    else:
        fig = axes[0][0].get_figure()
        reuse_fig = True
    # logical negation: `~reuse_fig` would be a bitwise NOT (-1 or -2), which is always truthy
    plot_data = not reuse_fig

    with torch.no_grad():
        _, spec_rest, spec_recon = model._forward(y, z=z)

    def plot_panel(sub, wave_rest, spec, wave_model, spec_model, w=None, xlims=None, plot_spec=True, plot_lines=True, label_lines=False, labels=None):
        # Draw data and model into one panel; plot_lines=False marks the overview panel.
        if labels is None:
            labels = (None, None)
        drawstyle = "steps-mid" if plot_lines else None

        # y-range from the brightest unmasked pixel (all pixels without weights)
        visible = spec if w is None else (w > 0) * spec
        ylims = [0, 1.2 * torch.max(visible).item()]
        if plot_spec:
            sub.plot(wave_rest, spec, drawstyle=drawstyle, c='k', label=labels[0], lw=2)
        sub.plot(wave_model, spec_model, drawstyle=drawstyle, c=color, label=labels[1], lw=2, zorder=10)

        # lines
        if plot_lines:
            for line, label in lines.items():
                if xlims[0] < line < xlims[1]:
                    sub.axvline(line, ymax=1.05, color='#888888', lw=0.5, zorder=-1, clip_on=False)
                    if label_lines:
                        sub.text(line, 1.08 * ylims[1], label, ha='center', va='bottom', rotation=90, color='#888888')

        # weights: only show in overview panel because horizontal location is mildly off
        if w is not None:
            cmap = plt.get_cmap('BuPu')
            max_color = 0.15
            w_color = cmap(max_color)
            if plot_lines is False:
                weight_im = np.ma.masked_where(w > 0, max_color * np.ones(w.shape)).reshape(1,-1)
                sub.imshow(weight_im, aspect='auto', extent=(wave_rest.min(), wave_rest.max(), 0, ylims[1]), cmap=cmap, vmin=0, vmax=1, zorder=-2, rasterized=True)

                # show unobserved region as masked
                sub.axvspan(wave_model.min(), wave_rest.min(), color=w_color, zorder=-2)
                sub.axvspan(wave_rest.max(), wave_model.max(), color=w_color, zorder=-2)

            else:
                # show unobserved region as masked
                if xlims[0] < wave_rest.min():
                    sub.axvspan(wave_model.min(), wave_rest.min(), color=w_color, zorder=-2)
                if xlims[1] > wave_rest.max():
                    sub.axvspan(wave_rest.max(), wave_model.max(), color=w_color, zorder=-2)

                # find masked regions in w
                sel = np.arange(len(w))[w <= 1e-6]
                last_pixel = len(wave_rest) - 1
                for first, last in find_edges(sel):
                    # a single masked pixel is widened to its right neighbour;
                    # clamp so a masked final pixel does not index past the end
                    right = max(wave_rest[last], wave_rest[min(first + 1, last_pixel)])
                    sub.axvspan(wave_rest[first], right, color=w_color, zorder=-2)

        if xlims is not None:
            sub.set_xlim(*xlims)
        sub.set_ylim(ymin=0, ymax=ylims[1])

    # (central wavelength, half width) of the zoom panels, in restframe Å
    zoom_windows = (
        (3727, 90),  # OII
        (4940, 90),  # Hbeta, OIII
        (5895, 90),  # Na
        (6564, 90),  # Halpha, NII
    )
    wave_model = model.decoder.wave_rest.cpu()

    for i in range(n_test):
        row = axes[i]
        wave_rest = (model.wave_obs/(1+z[i])).cpu()
        spec_ = y[i].cpu()
        spec_rest_ = spec_rest[i].cpu()
        z_ = z[i].cpu()
        w_ = None if w is None else w[i].cpu()
        if ids is not None:
            plate, mjd, fiberid = ids[i]
            title = f'ID {plate}-{mjd}-{fiberid}\n $z={z_:2f}$'
        else:
            title = f'$z={z_:2f}$'

        # emission lines are labelled in the top row only
        label_lines = i == 0

        # overview panel over the full restframe range of the model
        xlims = wave_model.min(), wave_model.max()
        plot_panel(row[0], wave_rest, spec_, wave_model, spec_rest_, w=w_, xlims=xlims, plot_spec=plot_data, plot_lines=False, label_lines=label_lines, labels=labels)
        row[0].set_ylabel('Normalized Flux')
        if not reuse_fig:
            row[0].text(0.05, 0.95, title, ha='left', va='top', transform=row[0].transAxes)

        # zoom panels
        for col, (lmbda, delta) in enumerate(zoom_windows, start=1):
            xlims = (lmbda - delta, lmbda + delta)
            plot_panel(row[col], wave_rest, spec_, wave_model, spec_rest_, w=w_, xlims=xlims, plot_spec=plot_data, label_lines=label_lines, labels=labels)
        if i == 0:
            row[4].legend(loc='upper right', frameon=False).set_zorder(100)

    for sub in axes[-1]:
        # \AA has to sit inside math mode, otherwise it is printed literally
        sub.set_xlabel(r'Restframe $\lambda\,[\mathrm{\AA}]$')
    fig.tight_layout()
    fig.subplots_adjust(hspace=0.05, wspace=0.05)
    return fig, axes

In [None]:
# zoom-in rows for the first five spectra of the batch
idx = slice(0,5)
plot_spec_zoom(model, spec[idx], z[idx], w=w[idx], ids=ids[idx]);

In [None]:
# get special set for paper figure 2
ids = ((412, 52254, 308), (412, 52250, 129), (410, 51877, 560), (406, 51900, 15), (404, 51877, 83))
# make_batch is unpacked into six fields elsewhere in this notebook
# (spec, w, z, id, norm, zerr) but was unpacked into five here. Indexing from
# both ends works with either layout: the first three fields are spec, w, z
# and the last two are norm, zerr.
fig2_batch = sdss.make_batch(data_path, ids)
spec, w, z = fig2_batch[:3]
norm, zerr = fig2_batch[-2:]
plot_spec_zoom(model, spec, z, w=w, ids=ids);

## Attention plots: Grad-FAM

In [None]:
# get dataloader batch again
# (spec, w, z and ids were overwritten by the paper-figure cell above)
batch_size = 128
dataloader = sdss.get_data_loader(data_path, tag="variable", which="test", batch_size=batch_size)
# this time the model is passed through accelerator.prepare together with the loader
model, dataloader = accelerator.prepare(model, dataloader)
batch = next(iter(dataloader))
spec, w, z, ids, norm, zerr = batch

In [None]:
def grad_fam(model, spec, z, l_callback, combined=False):
    """Return the encoder attention and the gradient recorded for it.

    The attention is the softmax of the second output of
    `model.encoder._downsample`. A full forward pass is then run with
    gradients enabled, `l_callback(model, spec_reco, spec_rest)` is evaluated
    and back-propagated, and `model.encoder.attention_grad` is read back
    (presumably filled by a hook inside the encoder during the backward pass
    -- confirm against the encoder implementation).

    Returns `att * att_grad` if `combined` is True, otherwise the pair
    `(att, att_grad)`.
    """
    encoder = model.encoder

    # attention weights do not need a graph of their own
    with torch.no_grad():
        _, raw_attention = encoder._downsample(spec)
        attention = encoder.softmax(raw_attention)

    # reconstruction with gradients, then back-propagate the scalar of interest
    _, spec_rest, spec_reco = model._forward(spec, z=z)
    target = l_callback(model, spec_reco, spec_rest)
    target.backward()

    attention = attention.detach()
    attention_grad = encoder.attention_grad.detach()
    return attention * attention_grad if combined else (attention, attention_grad)
    
def l_halpha(model, spec_reco, spec_rest, dim=None):
    """Sum of (restframe flux - 1) over the 6560-6570 window around H-alpha.

    With `dim=None` a single scalar over the whole batch is returned; otherwise
    the excess is arranged as (number of spectra, pixels in window) and summed
    along `dim`. `spec_reco` is only used for the number of spectra.
    """
    wave = model.wave_rest
    in_window = (wave > 6560) & (wave < 6570)
    excess = spec_rest[:, in_window] - 1
    if dim is not None:
        per_spectrum = excess.reshape(len(spec_reco), in_window.sum())
        return torch.sum(per_spectrum, dim=dim)
    return torch.sum(excess)

def grad_fam_halpha(model, spec, z, combined=False):
    """Run `grad_fam` with `l_halpha` as the scalar that is back-propagated."""
    return grad_fam(model, spec, z, l_halpha, combined=combined)

# attention and its gradient for the whole batch (combined=False returns them separately)
att, att_grad = grad_fam_halpha(model, spec, z)

In [None]:
# forward pass for the whole batch without gradients; spec_rest is used by the stacked-spectra plot below
with torch.no_grad():
    s, spec_rest, spec_reco = model._forward(spec, z=z)

In [None]:
from matplotlib.collections import LineCollection

def att_plot(att, att_grad, spec, w, ids, wave_obs, label=None, receptive_field=1208, percentile=95):
    """Visualise attention times gradient for a single spectrum.

    Top panel: one marker per (segment, channel) cell whose positive part of
    ``att * att_grad`` exceeds the given percentile. Marker size grows with
    the square of that product, marker colour follows ``att_grad``. The channel
    with the largest product is marked with a line and its index.

    Bottom panel: the spectrum split into ``n_segment`` pieces, each coloured
    by the channel-summed product after a box-filter smoothing whose width is
    the receptive field expressed in segments.

    Parameters
    ----------
    att, att_grad : tensors of shape (n_channel, n_segment)
    spec : 1D tensor
        Spectrum drawn in the bottom panel.
    w : 1D tensor
        Weights; only pixels with ``w > 0`` set the y-range.
    ids : sequence of three items
        Unpacked as (plate, mjd, fiberid) for the title.
    wave_obs : 1D tensor
        Observed wavelength per pixel, used to place the wavelength ticks.
    label : str, optional
        Text placed in the upper-left corner of the top panel.
    receptive_field : int, optional
        CNN receptive window in spectral elements (default 1208).
    percentile : float, optional
        Percentile of the product above which cells are shown (default 95).

    Returns
    -------
    fig
    """
    fig, axes = plt.subplots(2, 1, gridspec_kw={'height_ratios': (3,1), 'hspace': 0, 'wspace':0 }, figsize=(6,6))

    # only positive contributions are shown
    att_grad_ = np.maximum(0, (att * att_grad).cpu().numpy())
    cutoff = np.percentile(att_grad_, percentile)
    size = 30 * att_grad_**2
    color = att_grad.cpu()

    n_channel, n_segment = att.shape
    y, x = np.mgrid[:n_channel, :n_segment]
    sel = (att_grad_ > cutoff)
    axes[0].scatter(x[sel], y[sel], s=size[sel], c=color[sel], cmap='YlOrRd', vmin=0, rasterized=True)
    axes[0].set_xlim(0,n_segment)
    axes[0].set_ylim(0,n_channel)
    axes[0].set_xlabel('Wavelength Segment')
    axes[0].set_ylabel('Attention Channel')
    axes[0].xaxis.tick_top()
    axes[0].xaxis.set_label_position('top')

    # add line and label to largest attention channel
    max_channel = np.argmax(att_grad_.max(-1))
    axes[0].axhline(max_channel, xmax=1.01, color='#888888', lw=0.5, zorder=-1, clip_on=False)
    axes[0].text(1.02, max_channel / n_channel, max_channel, c='#888', ha='left', va='center', transform=axes[0].transAxes)

    if label is not None:
        axes[0].text(0.05, 0.95, label, color='C3', ha='left', va='top', transform=axes[0].transAxes)

    ylims = [0, 1.3 * torch.max(spec * ( w > 0)).item()]
    axes[1].set_ylim(ylims)
    axes[1].set_xlabel(r'Observed $\lambda\,[\AA]$')
    axes[1].set_ylabel('Flux')

    # plot xaxis is in spectral elements, show observed wavelength instead
    n_pix = len(spec)
    axes[1].set_xlim(0, n_pix)
    wave_np = wave_obs.cpu().numpy()
    new_tick_labels = np.array([ 4000, 5000, 6000, 7000, 8000, 9000 ])
    new_tick_locations = [ np.argmin(np.abs(wave_np - tick)) for tick in new_tick_labels]
    axes[1].set_xticks(new_tick_locations);
    axes[1].set_xticklabels(new_tick_labels);

    # add Grad-CAM like plot
    # sum all channels
    att_sum = att_grad_.sum(0)
    # find the approximate wavelength region for each segment
    l_s = n_pix // n_segment # elements per segment
    # smoothing window: receptive field expressed in segments
    # (named `window` so the weight argument `w` is not overwritten)
    window = int(np.floor(receptive_field / l_s)) + 1
    att_sum_ = np.convolve(att_sum, np.ones(window), 'same') / window
    norm = plt.Normalize(att_sum_.min(), 2*att_sum_.max())

    spec_np = spec.cpu().numpy()
    segments = [ np.stack((np.arange(s*l_s, (s+1)*l_s), spec_np[s*l_s: (s+1)*l_s]), axis=1) for s in range(n_segment) ]
    lc = LineCollection(segments, cmap='hot', norm=norm, rasterized=True)
    # Set the values used for colormapping
    lc.set_array(att_sum_)
    lc.set_linewidth(1)
    axes[1].add_collection(lc)

    plate, mjd, fiberid = ids
    title = f'ID {plate}-{mjd}-{fiberid}'
    axes[1].text(0.05, 0.9, title, ha='left', va='top', transform=axes[1].transAxes)

    fig.tight_layout()
    return fig

# spectrum of the batch to visualise (108 is kept as an alternative choice)
idx = 14 # 108
fig = att_plot(att[idx], att_grad[idx], spec[idx], w[idx], ids[idx], model.wave_obs, label=r"H$\alpha$ emission")

In [None]:
plt.figure(figsize=(6,6))
# attention times gradient for the whole batch, indexed as (spectrum, channel, segment)
att_grad_ = att * att_grad
# total response of channel 106 (hard-coded) per spectrum, summed over segments
channel_ = att_grad_[:, 106, :].sum(-1)
norm = plt.Normalize(channel_.min(), 2*channel_.max())

cmap = plt.get_cmap('hot')

# spectra ranked by channel response (strongest first); every 8th one is taken, six in total,
# and each is drawn with a vertical offset of 5 and a colour that encodes its response
for i, idx in enumerate(torch.argsort(channel_, 0, descending=True)[::8][:6]):
    plt.plot(model.wave_rest.cpu(), spec_rest[idx].cpu()+5*i, c=cmap(norm(channel_[idx].cpu())), lw=1)
    plate, mjd, fiberid = ids[idx]
    plt.text(9100, 5*i+3, f'ID {plate}-{mjd}-{fiberid}', color=cmap(norm(channel_[idx].cpu())), ha='right', va='top')
plt.ylim(0, 30)
plt.xlabel(r'Restframe $\lambda\,[\mathrm{\AA}]$')
# the y-axis only carries the offsets, so no ticks
plt.yticks([])
plt.xlim(3600,9200)

## Embedding the latents

In [None]:
# data loader with all spectra
dataloader = sdss.get_data_loader(data_path, tag="variable", batch_size=batch_size)
dataloader = accelerator.prepare(dataloader)

# one list per quantity; every batch appends a CPU tensor, concatenated after the loop
ss, losses, halphas, zs, norms, ids = [], [], [], [], [], []
with torch.no_grad():
    # NOTE(review): the loop variable `id` shadows the builtin of the same name
    for spec, w, z, id, norm, zerr in dataloader:
        # need the latents, of course
        s = model.encode(spec)
        # everything else is to color code the UMap
        s, spec_rest, spec_reco = model._forward(spec, z=z, s=s) # reuse latents
        # per-spectrum loss and per-spectrum H-alpha excess (summed over the last axis)
        loss = model._loss(spec, w, spec_reco, individual=True)
        halpha = l_halpha(model, spec_reco, spec_rest, dim=-1)
        
        ss.append(s.cpu())
        losses.append(loss.cpu())
        halphas.append(halpha.cpu())
        zs.append(z.cpu())
        norms.append(norm.cpu())
        ids.append(id.cpu())

# stack the per-batch pieces along the first axis
ss = np.concatenate(ss, axis=0) # converts to numpy
losses = np.concatenate(losses, axis=0)
halphas = np.concatenate(halphas, axis=0)
zs = np.concatenate(zs, axis=0)
norms = np.concatenate(norms, axis=0)
ids = np.concatenate(ids, axis=0)

In [None]:
import corner
# corner plot of the collected latents `ss`
fig = corner.corner(ss);

In [None]:
import umap
# fit a UMAP reducer with default settings on the latents and keep it for later transforms
reducer = umap.UMAP()
embedding = reducer.fit_transform(ss)

In [None]:
!pip install scikit-learn
!pip install astropy

import os
# get stellar masses from JHU-MPA catalog
# download from here and place in data_path:
# https://www.sdss3.org/dr8/spectro/spectro_access.php
filename = os.path.join(data_path, "galSpecExtra-dr8.fits")

from astropy.io import fits
hdu = fits.open(filename)
t_plate = hdu[1].data['PLATEID']
t_mjd = hdu[1].data['MJD']
t_fiber = hdu[1].data['FIBERID']
t_sm = hdu[1].data['LGM_FIB_P50']

# pairwise match based on plate, mjd, fiberid
from sklearn.neighbors import KDTree
tree = KDTree(np.stack((t_plate, t_mjd, t_fiber), axis=-1))
d, idx = tree.query(ids, k=1)
found = (d.reshape(-1) == 0)
print ("found", found.sum(), "of", len(found), "matches")
sms = np.zeros(len(ids))
sms[found] = t_sm[idx[found]].reshape(-1)

In [None]:
def load_n_embed(model, label):
    """Load a labelled subset of spectra and project it into the fitted UMAP.

    The IDs are read from ``<data_path>/<label>_ids.npy``, the spectra are
    fetched with ``sdss.make_batch``, encoded with ``model.encode`` and the
    latents are passed to the already fitted ``reducer``.

    Parameters
    ----------
    model : object providing `encode`
    label : str
        One of 'agn', 'agn_broad', 'starburst', 'starburst_broad',
        'starforming', 'starforming_broad'.

    Returns
    -------
    spec, w, z, norm, embed
        The batch fields and the 2D embedding of the latents.
    """
    assert label in ['agn', 'agn_broad', 'starburst', 'starburst_broad', 'starforming', 'starforming_broad']

    # make batch for each of the subsets
    filename = os.path.join(data_path, f'{label}_ids.npy')
    ids = np.load(filename)

    # get batch; the ID and redshift-error fields are not needed here
    spec, w, z, _, norm, _ = sdss.make_batch(data_path, ids)

    with torch.no_grad():
        s = model.encode(spec)
    # hand UMAP a CPU numpy array so this also works when the latents live
    # on an accelerator device
    embed = reducer.transform(s.cpu().numpy())

    return spec, w, z, norm, embed

# 2x2 figure: the same UMAP embedding, colour-coded in four different ways
fig, axes = plt.subplots(2, 2, sharey="row", figsize=(10,10))

# no coloring, but subsamples added
sc = axes[0][0].scatter(embedding[:,0], embedding[:,1], s=5, c='#e0e0e0', rasterized=True)

from matplotlib.lines import Line2D
legend_elements = []
for label,marker,vmax,cmap in zip(['starforming', 'starburst', 'agn_broad'], ['o', 'v', 's'], [0.2, 0.15, 0.4], ['Reds', 'Blues', 'YlGn']):
    spec_, w_, z_, norm_, embed_ = load_n_embed(model, label)
    # plt.get_cmap is the supported lookup (matplotlib.cm.get_cmap is deprecated)
    cmap = plt.get_cmap(cmap)
    sc = axes[0][0].scatter(embed_[:,0], embed_[:,1], s=40, c=z_, cmap=cmap, vmin=0, vmax=vmax, marker=marker, label=label)
    legend_elements.append(Line2D([0], [0], marker=marker, color=cmap(150), label=label, lw=0))
axes[0][0].legend(handles=legend_elements, loc='lower right', frameon=False)
# the colorbar is attached to the last subset drawn in the loop
cb = plt.colorbar(sc, ax=axes[0][0], orientation='horizontal', ticks=[0.0, 0.1, 0.2, 0.3, 0.4], aspect=50, fraction=0.05, pad=0.02)
cb.set_label('$z$')
axes[0][0].set_xticks([])
axes[0][0].set_yticks([])

# Halpha coloring (floored at 1e-2 before taking the logarithm)
sc = axes[0][1].scatter(embedding[:,0], embedding[:,1], s=3, c=np.log10(np.maximum(1e-2, halphas)), rasterized=True, cmap='inferno')
axes[0][1].set_xticks([])
axes[0][1].set_yticks([])
cb = plt.colorbar(sc, ax=axes[0][1], orientation='horizontal', ticks=[-1, 0, 1, 2], aspect=50, fraction=0.05, pad=0.02)
cb.set_label(r'$\log_{10}(l_{\mathrm{H}\alpha})$')

# redshift coloring
sc = axes[1][0].scatter(embedding[:,0], embedding[:,1], s=3, c=zs, rasterized=True, cmap='Spectral_r')
axes[1][0].set_xticks([])
axes[1][0].set_yticks([])
cb = plt.colorbar(sc, ax=axes[1][0], orientation='horizontal', aspect=50, fraction=0.05, pad=0.02)
cb.set_label('$z$')

# stellar mass (only entries with a value above 5 are shown)
sel = sms > 5
sc = axes[1][1].scatter(embedding[sel,0], embedding[sel,1], s=5, c=sms[sel], cmap='viridis', rasterized=True)
axes[1][1].set_xticks([])
axes[1][1].set_yticks([])
cb = plt.colorbar(sc, ax=axes[1][1], orientation='horizontal', aspect=50, fraction=0.05, pad=0.02)
# raw string: \l and \s are not valid Python escape sequences
cb.set_label(r'$\log_{10}(M_\star)$')

fig.tight_layout()
fig.subplots_adjust(wspace=0.02)