# Blazar Synchrotron Peak Estimator (BlaSE)

BlaSE is an ensemble of neural networks that estimates the synchrotron peak of a blazar, together with a prediction interval, from its spectral energy distribution (SED) as output by the VOUBlazar tool. The goal of this notebook is to train the networks. It does not feature any hyperparameter tuning, since that was already done as part of my bachelor thesis.

As a special quirk, the tool uses a double ensemble: bagging is applied to create different subsets to train on, and for each subset an ensemble is trained to improve the prediction, as described by Blundell et al. 2017.

Finally, the bagging makes it possible to apply the model to the out-of-bag data, i.e. data the respective networks never saw during training, to hopefully increase the quality of the data set. The final tool is expected to be applied to unseen data and thus uses all available ensembles.

## Data Preparation

We start by parsing the data set. It is simply a zip archive containing SEDs as produced by the VOUBlazar tool; each file looks like the following:

```
   1  matched source  227.16230  -49.88400  99
 Frequency     nufnu     nufnu unc.  nufnu unc. start time   end time   Catalog     Reference
    Hz       erg/cm2/s     upper       lower        MJD         MJD   
---------------------------------------------------------------------------------------------------------------------------
 2.418E+17   2.185E-13   3.139E-13   1.230E-13  55000.0000  55000.0000  RASS        Boller et al. 2016, A&A, 103, 1                                                                                                                                                                         
 2.418E+17   5.085E-13   6.281E-13   3.889E-13  58150.0000  58150.1016  OUSXB       Giommi et al. 2019, Accepted for publication in A&A  
```

Thus, there are 4 header lines before the actual data begins. After that, we are only interested in the first 4 entries per row.
The uncertainty is given separately for each direction, but we would prefer a symmetric one, so we use the mean squared error instead.

Next, we need the target value, the synchrotron peak (nupeak). It is stored in the filename, as it was determined by hand. Consider an example file:
```
SED_11.78_227.1623_-49.8840_PMNJ1508-49_1.551
```
The peak is stored in the first float (11.78 in this example).
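
As a small illustration, the file name can be split into its numeric parts as sketched below. The helper name `parse_sed_name` is made up for this example. The second and third floats are read as right ascension and declination, matching how the positions are used later in this notebook; the remaining fields are ignored.

```python
from os.path import basename

def parse_sed_name(filename):
    """Return (peak, right ascension, declination) encoded in an SED file name."""
    parts = basename(filename).split("_")
    # parts[0] is the literal "SED"; the peak is the first float after it
    return float(parts[1]), float(parts[2]), float(parts[3])

peak, ra, dec = parse_sed_name("SEDs/SED_11.78_227.1623_-49.8840_PMNJ1508-49_1.551")
print(peak, ra, dec)
```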

Next, we have to sanitize the data. We encountered the following problems:
 - zero frequency
 - negative or zero flux
 - flux outside upper/lower bound
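
The checks listed above can be sketched on a single SED as follows. This is a stand-alone illustration using plain Python rather than the NumPy code of this notebook. The function name `parse_sed` is an assumption, and the last two table rows are invented to trigger the checks.

```python
SED_TEXT = """\
   1  matched source  227.16230  -49.88400  99
 Frequency     nufnu     nufnu unc.  nufnu unc. start time   end time   Catalog     Reference
    Hz       erg/cm2/s     upper       lower        MJD         MJD
-----------------------------------------------------------------
 2.418E+17   2.185E-13   3.139E-13   1.230E-13  55000.0000  55000.0000  RASS        Boller et al. 2016, A&A, 103, 1
 2.418E+17   5.085E-13   6.281E-13   3.889E-13  58150.0000  58150.1016  OUSXB       Giommi et al. 2019, Accepted for publication in A&A
 1.000E+10   9.000E-13   5.000E-13   1.000E-13  58000.0000  58000.0000  MADEUP      hypothetical row, flux above its upper bound
 0.000E+00   1.000E-13   2.000E-13   5.000E-14  58000.0000  58000.0000  MADEUP      hypothetical row, zero frequency
"""

def parse_sed(text):
    """Return the (frequency, flux) pairs that pass the sanity checks."""
    points = []
    for line in text.splitlines()[4:]:      # skip the 4 header lines
        freq, flux, upper, lower = (float(v) for v in line.split()[:4])
        if freq <= 0.0 or flux <= 0.0:      # zero frequency, non-positive flux
            continue
        bounds_given = upper != 0.0 and lower != 0.0
        if bounds_given and not (lower <= flux <= upper):
            continue                        # flux outside its own bounds
        points.append((freq, flux))
    return points

points = parse_sed(SED_TEXT)
print(len(points), "of 4 rows kept")
```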

Finally, the data is binned. This is necessary since the neural network is a simple fully connected network and thus expects an input of constant size. There are gaps between the bins to leave out biased data: some frequency ranges are only populated for specific target values. The neural network would merely check whether such a bin is populated and thus fail to generalize, making these bins useless for unseen data. The actual bin edges were determined beforehand such that the bins are approximately equally densely filled.
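
To make the binning concrete, here is a minimal sketch in plain Python. The bin edges are hypothetical (the real ones are loaded from `blase/bins.txt`) and the helper name `bin_sed` is made up. Points that fall into a gap between bins are dropped, and empty bins are encoded as 0.0.

```python
from math import log10
from statistics import mean

# Hypothetical bin edges in log10(frequency / Hz); note the gaps between the bins.
BIN_EDGES = [(8.0, 10.0), (14.0, 15.0), (17.0, 18.0)]

def bin_sed(points, bin_edges=BIN_EDGES):
    """Average log10 flux per bin; empty bins are encoded as 0.0."""
    log_points = [(log10(freq), log10(flux)) for freq, flux in points]
    binned = []
    for lo, hi in bin_edges:
        inside = [y for x, y in log_points if lo <= x <= hi]
        binned.append(mean(inside) if inside else 0.0)
    return binned

# The second point (1e12 Hz) lies in a gap and is ignored.
points = [(1e9, 1e-13), (1e12, 1e-12), (2.418e17, 1e-13), (2.418e17, 1e-11)]
binned = bin_sed(points)
print(binned)
```

Regardless of how many points an SED has, the result always has one entry per bin, which is what the fully connected network needs.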

In [None]:
import numpy as np
from os.path import basename, splitext
from zipfile import ZipFile

data = []
label = []
pos = []
seds = ZipFile('SEDs.zip')
bin_edges = np.loadtxt('blase/bins.txt')
n_bins = bin_edges.shape[0]

#data loading
def getpeak(filename):
    filename = basename(filename)
    return float(filename.split("_")[1])
def getpos(filename):
    name = splitext(basename(filename))[0]
    parts = name.split('_')
    return float(parts[2]) ,float(parts[3])
def sanitize(_data):
    #drop rows with non-positive frequency or flux, since log10 is undefined there
    _data = _data[_data[:,0] > 0]
    _data = _data[_data[:,1] > 0]
    return _data
def bin_data(_data):
    result = []
    for sed in _data:
        line = []
        for lo, hi in bin_edges: #each row holds the lower and upper edge of one bin
            inside = (sed[:,0] >= lo) & (sed[:,0] <= hi)
            flux = sed[inside][:,1]
            line.append(np.mean(flux) if len(flux) > 0 else 0.0)
        result.append(line)
    return np.array(result)
def loadfile(file):
    _data = []
    for line in file.readlines()[4:]:
        entries = line.split()
        x = float(entries[0])
        y = float(entries[1])
        up = float(entries[2])
        lo = float(entries[3])
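        #sanity check: skip points whose flux lies outside its own error bounds
        #(the check is skipped if either bound is exactly 0.0)
        if (up < y or y < lo) and up != 0.0 and lo != 0.0:
            continue #Skip this entry
        _data.append([x, y])
    _data = np.array(_data).reshape(-1, 2) #keep two columns even if no point survived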
    _data = sanitize(_data)
    #we want log10
    _data = np.log10(_data)
    assert(np.isfinite(_data).all())
    return _data

for filename in seds.namelist():
    #Check if filename is a folder
    if filename[-1] == '/':
        continue
    label.append(getpeak(filename))
    pos.append(getpos(filename))
    with seds.open(filename) as file:
        data.append(loadfile(file))
data = bin_data(data)
label = np.array(label)
pos = np.array(pos)

print(f"{len(data)} data entries loaded")
print(f"{len(label)} labels loaded")

The next step is to create the bags. There are about 3,800 samples in the data set; with 5 bags, each bag gets about 760 samples.

In [None]:
# We need to remember which SED is in which bag, so instead of standard methods we pick the bags ourselves
bag_idx = np.random.randint(5, size=len(label))
bagged_data = [data[bag_idx == i] for i in range(5)]
bagged_label = [label[bag_idx == i] for i in range(5)]
for i in range(5):
    print(f'Size of bag {i}: {len(bagged_label[i])}')

Save the positions together with the bag index. We need this later to check automatically whether an SED was used for training.

In [None]:
np.save('blase/bag_index.npy', np.hstack((pos, bag_idx.reshape(-1,1))))

Another problem of the data set is that not all synchrotron peaks and bins are equally represented. We solve this with data augmentation, which also yields a bigger training set. The augmentation oversamples SEDs, especially those with an uncommon peak, while undersampling, i.e. deleting, overrepresented bins in the copies.
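
The deletion probability of a bin grows with how often that bin is filled, raised to a power (`power_adjust` in the cell that follows). The plain-Python sketch below shows the effect of that exponent on made-up fill counts. The helper name `deletion_probabilities` is an assumption for illustration only.

```python
def deletion_probabilities(counts, power=4.0):
    """Turn per-bin fill counts into probabilities proportional to count**power."""
    weights = [c ** power for c in counts]
    total = sum(weights)
    return [w / total for w in weights]

# Hypothetical fill counts of three bins within one label range
counts = [10, 20, 40]
p_linear = deletion_probabilities(counts, power=1.0)
p_boosted = deletion_probabilities(counts, power=4.0)
print([round(p, 3) for p in p_linear])
print([round(p, 3) for p in p_boosted])
```

With a larger exponent, almost all deletions hit the densest bin, which is the intended boost for overrepresented bins.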

In [None]:
label_edges = np.loadtxt('label_edges.txt')
_n_labels = len(label_edges) - 1
target_bin_size = 90 #~10k training set size
max_deletions = 15
power_adjust = 4.0 #further boosts deletion probability of dense bins

def augment(data, label):
    #See which bins are actually filled (assumes negative log10 fluxes; empty bins are 0.0)
    filled = np.where(data < 0.0)
    #explicit integer edges so that bin index j always lands in column j
    hist, *_ = np.histogram2d(label[filled[0]], filled[1], bins=(label_edges, np.arange(n_bins + 1)))

    #calculate probability of a bin to be deleted
    p = (hist.T**power_adjust/np.sum(hist**power_adjust, axis=1)).T
    a = np.arange(n_bins)
    binned_label = np.digitize(label, label_edges, right=True) - 1 #since bin zero gets 1

    target_n = target_bin_size * _n_labels
    result_data = np.zeros((target_n, n_bins))
    result_label = np.zeros((target_n,))
    #copy originals at back of output
    result_data[-len(label):,:] = data
    result_label[-len(label):] = label

    for i in range(_n_labels): #binned labels
        inside = np.where(binned_label == i)[0]
        n = target_bin_size - len(inside) #nr to be copied
        assert(n > 0)

        #copy originals
        _ii = target_bin_size*i
        result_data[_ii:_ii+len(inside),:] = data[inside]
        result_label[_ii:_ii+len(inside)] = label[inside]

        #make copies
        copy_sources = np.random.choice(inside, n) #which to copy (possibly multiple times)
        _ii += len(inside)
        result_data[_ii:_ii+n,:] = data[copy_sources]
        result_label[_ii:_ii+n] = label[copy_sources]

        #which bins to delete
        for ii in range(_ii, _ii+n):
            _del = np.random.choice(a, (max_deletions,), True, p[i])
            result_data[ii, _del] = 0.0

    #remove all rows with fewer than 5 bins populated
    fainted = np.where((result_data != 0.0).sum(axis=1) < 5)[0]
    result_data = np.delete(result_data, fainted, axis=0)
    result_label = np.delete(result_label, fainted, axis=0)
    
    return result_data, result_label

In [None]:
augmented = [augment(*b) for b in zip(bagged_data, bagged_label)]

## Data Loading

Next step is to make the data available for training.

In [None]:
import torch
import torch.optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
print(torch.__version__)

import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

from sklearn.model_selection import train_test_split

To improve the performance, we standardize the input per bin as well as the label. We use the whole augmented data set for that.

Because of the standardization, 0.0 becomes a valid flux value and can no longer mark empty bins unambiguously. Whether a bin is empty is therefore appended after the bins as a mask, where zero denotes an empty bin and one a filled bin.
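
The following sketch shows the idea on a tiny, made-up array: statistics are computed over filled bins only, empty bins are reset to 0.0 after scaling, and the mask is appended as extra input columns. It assumes a NumPy version that supports the `where` argument of `np.mean` and `np.std`, as the cells below do.

```python
import numpy as np

# Hypothetical binned log10 fluxes: 3 SEDs, 2 bins, 0.0 marks an empty bin
binned = np.array([[-12.0,   0.0],
                   [-14.0, -11.0],
                   [  0.0, -13.0]])

mask = (binned != 0.0).astype(float)  # 1 = filled, 0 = empty
mean = np.mean(binned, axis=0, where=(binned != 0.0))
scale = np.std(binned, axis=0, where=(binned != 0.0))

# Standardize, reset empty bins to 0.0 and append the mask as extra columns
features = np.concatenate(((binned - mean) / scale * mask, mask), axis=1)
print(features)
```

The network input is therefore twice as wide as the number of bins.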

In [None]:
_data = np.concatenate([augmented[i][0] for i in range(5)])
_label = np.concatenate([augmented[i][1] for i in range(5)])
print(_data.shape)

bin_mean = np.mean(_data, axis=0, where=(_data != 0.0))
bin_scale = np.std(_data, axis=0, where=(_data != 0.0))

label_mean = np.mean(_label, axis=0)
label_scale = np.std(_label, axis=0)

In [None]:
np.savez('blase/scaling.npz',
    bin_mean=bin_mean,
    bin_scale=bin_scale,
    label_mean=label_mean,
    label_scale=label_scale
)

In [None]:
class DataModule(pl.LightningDataModule):
    def __init__(self, bag):
        super().__init__()

        _data = np.concatenate([augmented[i][0] for i in range(5) if i != bag])
        _mask = (_data != 0.0).astype(float)
        _label = np.concatenate([augmented[i][1] for i in range(5) if i != bag])

        train_data, val_data, train_mask, val_mask, train_label, val_label = \
            train_test_split(_data, _mask, _label, test_size=1500)

        self.train_data = (train_data - bin_mean) / bin_scale * train_mask
        self.val_data = (val_data - bin_mean) / bin_scale * val_mask

        self.train_mask = train_mask
        self.val_mask = val_mask

        self.train_label = (train_label - label_mean) / label_scale
        self.val_label = (val_label - label_mean) / label_scale
    
    def setup(self, stage=None):
        self.trainSet = TensorDataset(
            torch.tensor(np.concatenate((self.train_data, self.train_mask), axis=1), dtype=torch.float),
            torch.tensor(self.train_label, dtype=torch.float))
        self.valSet = TensorDataset(
            torch.tensor(np.concatenate((self.val_data, self.val_mask), axis=1), dtype=torch.float),
            torch.tensor(self.val_label, dtype=torch.float))
    
    def train_dataloader(self):
        return DataLoader(self.trainSet, batch_size=64, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.valSet, batch_size=64, shuffle=False)

## Model

Next up is the model. Once again, the architecture has already been tuned.

In [None]:
class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.lr = 0.0018397297637578883
        self.weight_decay = 1.634587061861498e-05
        self.model = nn.Sequential(
            nn.Linear(52, 152),
            nn.BatchNorm1d(152),
            nn.Dropout(0.11623816061109485),
            nn.ReLU(),
            nn.Linear(152, 80),
            nn.BatchNorm1d(80),
            nn.Dropout(0.14953177977171542),
            nn.ReLU(),
            nn.Linear(80, 72),
            nn.BatchNorm1d(72),
            nn.Dropout(0.024569432237666035),
            nn.ReLU(),
            nn.Linear(72, 48),
            nn.BatchNorm1d(48),
            nn.Dropout(0.03208157605345701),
            nn.ReLU(),
            nn.Linear(48, 2)
        )

    def forward(self, X):
        out = self.model(X.float())
        #the second output is the variance
        #use softplus to force the variance into (0, inf)
        mean, var = torch.unbind(out, dim=1) #dim 0 is the batch
        var = F.softplus(var) #enforce > 0
        return mean, var 

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10)
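        return {
            'optimizer' : optim,
            #Lightning expects the key 'lr_scheduler'; ReduceLROnPlateau also needs a metric to monitor
            'lr_scheduler' : {
                'scheduler' : scheduler,
                'monitor' : 'val_loss'
            }
        }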

    def training_step(self, batch, batch_idx):
        X, y = batch
        y = y.squeeze()
        mean, var = self(X)
        #Gaussian negative log-likelihood: -2ln[p(y|x)] up to an additive constant
        losses = torch.log(var) + (y - mean)**2/var
        loss = torch.mean(torch.unsqueeze(losses, 0)) #mean instead of sum keeps losses comparable regardless of batch size
        self.log('loss', loss)
        return loss
  
    def validation_step(self, batch, batch_idx):
        X, y = batch
        y = y.squeeze()
        mean, var = self(X)
        losses = torch.log(var) + (y - mean)**2/var
        #for whatever reason, mean needs an extra dimension...
        loss = torch.mean(torch.unsqueeze(losses, 0))
        self.log('val_loss', loss)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        self.log('avg_loss', avg_loss)
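
The loss used in `training_step` and `validation_step` is the negative log-likelihood of a Gaussian with the predicted mean and variance, scaled by two and shifted by a constant. The stand-alone check below verifies this numerically for arbitrary example values. The function names are chosen for this sketch only.

```python
from math import exp, log, pi, sqrt

def gaussian_pdf(y, mean, var):
    """Density of a normal distribution with the given mean and variance."""
    return exp(-(y - mean) ** 2 / (2.0 * var)) / sqrt(2.0 * pi * var)

def loss_term(y, mean, var):
    """Per-sample loss as in training_step: log(var) + (y - mean)^2 / var."""
    return log(var) + (y - mean) ** 2 / var

y, mean, var = 0.3, 0.1, 0.25  # arbitrary example values
lhs = loss_term(y, mean, var)
rhs = -2.0 * log(gaussian_pdf(y, mean, var)) - log(2.0 * pi)
print(lhs, rhs)

# An overconfident prediction (tiny variance, large miss) is penalized heavily
confident_miss = loss_term(1.0, 0.0, 0.01)
cautious_miss = loss_term(1.0, 0.0, 1.0)
print(confident_miss, cautious_miss)
```

Minimizing this loss therefore pushes the network to report a larger variance where its mean prediction tends to be off.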

## Training

Now the model can be trained. We train 5 models for each bag. This way we later have an ensemble of 20 models for each out-of-bag estimation and even 25 for unseen data.
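
This notebook does not show how `blase.Estimator` merges the outputs of its member networks. One common choice, used for example in the deep ensembles approach of Lakshminarayanan, Pritzel and Blundell (2017), is to treat the members as an equally weighted mixture of Gaussians. The sketch below assumes that choice; the function name and the numbers are hypothetical.

```python
def combine_ensemble(means, variances):
    """Mean and variance of an equally weighted mixture of Gaussians."""
    n = len(means)
    mean = sum(means) / n
    # total variance = average predicted variance + spread of the member means
    var = sum(v + m * m for m, v in zip(means, variances)) / n - mean * mean
    return mean, var

# Hypothetical outputs of three ensemble members for one SED
mean, var = combine_ensemble([14.0, 14.2, 13.8], [0.04, 0.05, 0.03])
print(mean, var)
```

Under this assumption, disagreement between the members widens the combined uncertainty even if each member is confident on its own.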

In [None]:
ensemble_size = 5

for bag in range(5):
    dataset = DataModule(bag)

    for i in range(ensemble_size):
        model = Model()
        trainer = pl.Trainer(
            max_epochs=200,
            logger = pl.loggers.TensorBoardLogger('logs/', name=f'bag{bag}', version=i),
            progress_bar_refresh_rate=0,#disable progress bar
            callbacks=[ModelCheckpoint(
                dirpath='models/',
                filename=f'{bag}.{i}_{{epoch}}_{{val_loss:.6f}}',
                save_top_k=3,
                monitor='val_loss',
                mode='min',
                every_n_val_epochs=1)])
        trainer.fit(model, dataset)

Since we saved full checkpoints, we now extract the models' weights to reduce the disk footprint.

In [None]:
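from os import listdir
from os.path import abspath, join, splitext

best = {} #'bag.id' -> (val_loss, state_dict)

for f in listdir('models/'):
    key = f[0:3] #bag.id
    #ModelCheckpoint keeps up to 3 checkpoints per model; only keep the best one.
    #Assumes the file name ends in 'val_loss=<value>.ckpt' as configured above.
    val_loss = float(splitext(f)[0].rsplit('=', 1)[1])
    if key not in best or val_loss < best[key][0]:
        m = Model.load_from_checkpoint(abspath(join('models/', f)))
        best[key] = (val_loss, m.state_dict())
dic = {key: state for key, (_, state) in best.items()}
torch.save(dic, 'blase/models.pth') #save all models in one file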

## Evaluation

Now that our model is done, let's evaluate it.

In [None]:
#Load blase
from blase import Estimator

estimator = Estimator()

truth = label
estimate = np.zeros_like(label)
error = np.zeros_like(label)
#estimate out of bag
for i in range(5):
    bag_mask = bag_idx == i
    _estimate, _err = estimator(data[bag_mask], i)
    estimate[bag_mask] = _estimate
    error[bag_mask] = _err

In [None]:
#calculate some metric
print(f'Median Absolute Error: {np.median(np.abs(truth-estimate)):.3f}')
print(f'         25% quantile: {np.quantile(np.abs(truth-estimate),0.25):.3f}')
print(f'         75% quantile: {np.quantile(np.abs(truth-estimate),0.75):.3f}')
print('')
print(f'Median PI Width: {np.median(2*error):.3f}')
print(f'   25% quantile: {np.quantile(2*error, 0.25):.3f}')
print(f'   75% quantile: {np.quantile(2*error, 0.75):.3f}')
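
A complementary check is the empirical coverage, i.e. the fraction of sources whose catalogue value lies inside the prediction interval. The sketch below treats `error` as the half-width of the interval, as the factor 2 in the cell above suggests. The values are hypothetical and the helper name is made up.

```python
def coverage(truth, estimate, error):
    """Fraction of samples whose true value lies within estimate +/- error."""
    inside = [abs(t - e) <= err for t, e, err in zip(truth, estimate, error)]
    return sum(inside) / len(inside)

# Hypothetical values for four sources
truth_demo = [13.0, 14.5, 15.2, 16.0]
estimate_demo = [13.2, 14.4, 16.1, 15.8]
error_demo = [0.5, 0.3, 0.6, 0.4]
cov = coverage(truth_demo, estimate_demo, error_demo)
print(f"coverage: {cov:.2f}")
```

For a well-calibrated interval, the coverage should be close to the nominal level of the interval.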

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import binned_statistic

sns.set_theme()
hist_cmap = sns.color_palette('PuBu', as_cmap=True)

p = (np.min(label), np.max(label))

plt.figure(figsize=(8,6))
_, _label_edges, *_ = plt.hist2d(truth, estimate, bins=30, cmap=hist_cmap)
_bins = (_label_edges[:-1] + _label_edges[1:]) / 2
_median, *_ = binned_statistic(truth, estimate, statistic='median', bins=_label_edges)
_lower, *_ = binned_statistic(truth, estimate, bins=_label_edges, statistic=lambda x: np.percentile(x, 10))
_upper, *_ = binned_statistic(truth, estimate, bins=_label_edges, statistic=lambda x: np.percentile(x, 90))
plt.plot(p, p, color='gray', linewidth=1.35)
plt.plot(_bins, _median, color='black', linewidth=1.2)
plt.plot(_bins, _lower, color='black', linestyle='dashed', linewidth=1)
plt.plot(_bins, _upper, color='black', linestyle='dashed', linewidth=1)
plt.xlabel('Ground Truth')
plt.ylabel('Predictions')
plt.show()

sort = error.argsort()
_estimate = estimate[sort]
_error = error[sort]
_truth = truth[sort] - _estimate
N = len(truth)
_below = (_truth < -_error).sum() / N * 100
_above = (_truth > _error).sum() / N * 100

plt.figure(figsize=(8,6))
plt.fill_between(np.arange(N), _error, -_error)
plt.scatter(np.arange(N), _truth, s=2.5)
plt.text(0,3,f'{_above:.2f}%')
plt.text(0,-3,f'{_below:.2f}%')
plt.ylabel('Prediction interval with ground truth (centered)')
plt.xlabel('ordered samples')
plt.ylim(-4,4)
plt.show()

In [None]:
from scipy.interpolate import splev, splrep

_min, _max = np.min(truth), np.max(truth) #range of the ground truth peaks
sample_points = np.arange(_min, _max, 0.25)
x = np.linspace(_min, _max, 200)
linewidth = 1.5

def moving_apply(data, f, width=0.5):
    result = []
    for s in sample_points:
        mask = (truth >= s - width) & (truth <= s + width)
        result.append(f(data[mask]))
    return np.array(result)

plt.figure(figsize=(8,6))
ax1 = plt.gca()
ax2 = ax1.twinx()
dense_ax = ax1.twinx()

sns.kdeplot(truth, cut=0, ax=dense_ax, color='black', linestyle='dashed', linewidth=1)
dense_ax.grid(False)
dense_ax.set(ylabel=None, yticks=[])

mad = moving_apply(np.abs(truth - estimate), np.median)
mad = splev(x, splrep(sample_points, mad))

ax1.plot(x, mad)
ax1.set_ylabel('Median Absolute Error')
ax1.set_xlabel('Log Synchrotron Peak [Hz]')
ax1.set_yticks(np.linspace(0.1,0.7, 6))
ax1.tick_params(axis='y', labelcolor='C0')
ax1.set_axisbelow(True)

pi = moving_apply(error, np.median)
pi = splev(x, splrep(sample_points, pi))

ax2.plot(x, pi, color='C1')
ax2.set_ylabel('Median PI Width')
ax2.set_yticks(np.linspace(0.7,1.0, 6))
ax2.tick_params(axis='y', labelcolor='C1')
ax2.set_axisbelow(True)

plt.show()

## Refining Data Set

The estimates from the evaluation section can actually be used to refine the data set. Since it is known to contain some outliers, it should be possible to reduce them through the generalization of this tool.

In [None]:
#print estimates as csv
with open('estimates.csv', 'w') as csv:
    csv.write('Right Ascension,Declination,Bag,Catalogue Peak,Estimated Peak,Estimate Error (95%)\n')
    for i in range(len(label)):
        csv.write(f'{pos[i,0]},{pos[i,1]},{bag_idx[i]},{label[i]},{estimate[i]:.2f},{error[i]:.2f}\n')
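
One possible way to use these estimates, sketched here under the assumption that a source deserves a second look when its catalogue peak lies outside the estimated interval, is to flag such rows. The function name, the dictionary keys and all values are hypothetical.

```python
def flag_outliers(rows):
    """Return rows whose catalogue peak lies outside estimate +/- error."""
    return [r for r in rows if abs(r["catalogue"] - r["estimate"]) > r["error"]]

# Hypothetical sources, not taken from the actual data set
rows = [
    {"ra": 10.0, "dec": 20.0, "catalogue": 13.50, "estimate": 13.80, "error": 0.80},
    {"ra": 30.0, "dec": -5.0, "catalogue": 17.50, "estimate": 14.20, "error": 0.90},
]
flagged = flag_outliers(rows)
for r in flagged:
    print(f'{r["ra"]},{r["dec"]}: catalogue {r["catalogue"]}, estimate {r["estimate"]} +/- {r["error"]}')
```

Flagged sources would still need to be inspected by hand, since a large deviation can also mean that the estimate is wrong.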