In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import emcee
import corner
from minijpas_LF_and_puricomp import effective_volume
from scipy.optimize import minimize
from scipy.stats import multivariate_normal
from my_functions import schechter
from multiprocessing import Pool

In [None]:
%matplotlib inline

In [None]:
# My LF: stack the per-field, per-NB-range histogram realisations into a
# single luminosity function (LF) with asymmetric error bars.

# Alternative configuration: restrict to a single NB range.
# nbs_list = [[20, 24]]
nbs_list = [[1, 4], [4, 8], [8, 12], [12, 16], [16, 20], [20, 24]]

survey_list = [f'minijpasAEGIS00{i}' for i in range(1, 4 + 1)] + ['jnep']

# Effective volume summed over every NB range; normalises the stacked LF.
total_volume = sum(
    effective_volume(nb_min, nb_max, 'both') for nb_min, nb_max in nbs_list
)

# Bin edges are stored in linear L; everything below works in log10(L).
L_binning = np.load('npy/L_nb_err_binning.npy')
b = np.log10(L_binning)
LF_bins = (b[:-1] + b[1:]) / 2  # bin centres in log10(L)
bin_width = np.diff(b)          # bin widths in dex

qso_factor = 0.5

hist_mat = None  # stacked histogram realisations, per unit volume and dex
LF_raw = None    # sum over NB ranges of LF_total_raw * that range's volume
for nb1, nb2 in nbs_list:
    # The output directory depends only on the NB range, not on the field.
    pathname = f'Luminosity_functions/LF_r17-24_nb{nb1}-{nb2}_ew30_ewoth100_nb_{qso_factor:0.1f}'
    this_volume = effective_volume(nb1, nb2)
    # NOTE(review): `total_volume` is built with effective_volume(..., 'both')
    # while this call relies on the default third argument; confirm both
    # describe the same volume, otherwise 'LF_total_uncorr' is mis-normalised.

    # Add up the histogram realisations of every survey field, then
    # normalise by the total volume and the bin width.
    this_hist = sum(
        np.load(f'{pathname}/hist_i_mat_{survey_name}.npy')
        for survey_name in survey_list
    )
    this_hist = this_hist / total_volume / bin_width

    hist_mat = this_hist if hist_mat is None else hist_mat + this_hist

    # pickle.load can execute arbitrary code: only use trusted, locally
    # generated files here.
    with open(f'{pathname}/LFs.pkl', 'rb') as file:
        this_LF_raw = pickle.load(file)['LF_total_raw'] * this_volume
    LF_raw = this_LF_raw if LF_raw is None else LF_raw + this_LF_raw

# 16th/50th/84th percentiles along axis 0 of the stacked histograms
# (presumably the realisation axis -- confirm against the saved files).
L_LF_err_percentiles = np.percentile(hist_mat, [16, 50, 84], axis=0)
LF_err_plus = L_LF_err_percentiles[2] - L_LF_err_percentiles[1]
LF_err_minus = L_LF_err_percentiles[1] - L_LF_err_percentiles[0]
hist_median = L_LF_err_percentiles[1]

# Total error = sqrt(hist_median / volwid + LF_err**2): a counting term
# (presumably Poisson) added in quadrature to the percentile spread.
volwid = total_volume * bin_width
yerr_plus = (hist_median + volwid * (LF_err_plus) ** 2) ** 0.5 * volwid ** -0.5
yerr_minus = (hist_median + volwid * (LF_err_minus) ** 2) ** 0.5 * volwid ** -0.5

LF_dict = {
    'LF_bins': LF_bins,
    'LF_total': hist_median,
    'LF_total_uncorr': LF_raw / total_volume,
    'LF_total_err': [yerr_minus, yerr_plus]
}

LF_yerr_minus = yerr_minus
LF_yerr_plus = yerr_plus
LF_phi = LF_dict['LF_total']
LF_bin = LF_dict['LF_bins']

In [None]:
# The fitting curve: a straight line in log-log space.
def power_fit(Lx, A, B):
    """Evaluate the power-law model log10(Phi) = A * log10(Lx) + B.

    Parameters
    ----------
    Lx : float or array_like
        Luminosity in linear units (callers pass 10 ** log10(L)).
    A : float or array_like
        Slope of the line in log-log space.
    B : float or array_like
        Intercept of the line in log-log space.

    Returns
    -------
    float or ndarray
        10 ** (A * log10(Lx) + B), broadcast over the inputs.
    """
    log_phi = A * np.log10(Lx) + B
    return 10 ** log_phi

In [None]:
def prior_f(theta):
    """Flat log-prior on the power-law parameters theta = (A, B).

    Returns 0. when -1.6 < A < 0 and 0 < B < 50 (all bounds exclusive),
    and -inf otherwise, so the posterior built from it rejects the proposal.
    """
    slope, intercept = theta[0], theta[1]
    inside_box = (-1.6 < slope < 0) and (0 < intercept < 50)
    return 0. if inside_box else -np.inf

def log_likelihood(theta, Lx, Phi, yerr):
    """Gaussian log-likelihood of the LF points under the power-law model.

    Each point contributes (model - Phi)**2 / yerr**2 + log(yerr**2), and the
    sum is scaled by -0.5. No log(2*pi) term is included; it would only add
    a constant independent of theta.
    """
    slope, intercept = theta[0], theta[1]
    variance = yerr**2
    residual = power_fit(Lx, slope, intercept) - Phi
    return -0.5 * np.sum(residual ** 2 / variance + np.log(variance))

def log_p(theta, Lx, Phi, yerr):
    """Log-posterior handed to emcee: log-prior plus log-likelihood.

    The prior is evaluated first. When it is not finite (theta outside the
    flat prior box) -inf is returned straight away, so the model is never
    evaluated at arbitrarily large, out-of-prior parameters; inside the box
    the result equals log_likelihood(...) + prior as before.
    """
    log_prior = prior_f(theta)
    if not np.isfinite(log_prior):
        return -np.inf
    return log_likelihood(theta, Lx, Phi, yerr) + log_prior

In [None]:
## MCMC parameters ##
N_walkers = 10000
N_steps = 1000

# Symmetrised error (mean of the upper and lower bars). Bins whose median LF
# is zero get an infinite error so the isfinite() cut below drops them.
yerr = (LF_yerr_plus + LF_yerr_minus) * 0.5
yerr[LF_phi == 0] = np.inf

# Fit only bins with a finite error and 43 < log10(L) < 44.7.
where_fit = np.isfinite(yerr) & (LF_bins > 43) & (LF_bins < 44.7)

# Every walker starts in a tight Gaussian ball around an initial (A, B) guess.
A_init, B_init, init_spread = -0.58, 20.26, 1e-4
theta0 = np.column_stack([
    np.random.normal(A_init, init_spread, N_walkers),
    np.random.normal(B_init, init_spread, N_walkers),
])

# Extra arguments forwarded by emcee to log_p: (Lx in linear units, Phi, yerr).
args = (
    10**LF_bins[where_fit],
    LF_phi[where_fit],
    yerr[where_fit],
)

n_dim = 2  # fitted parameters: A, B
with Pool() as pool:
    sampler = emcee.EnsembleSampler(N_walkers, n_dim, log_p, args=args, pool=pool)
    sampler.run_mcmc(theta0, N_steps, progress=True);

# Optional convergence check:
# print(f'Autocorrelation time: {sampler.get_autocorr_time()}')

In [None]:
# Drop the first 4/5 of every walker's chain as burn-in and thin by 15.
flat_samples = sampler.get_chain(discard=N_steps // 5 * 4, thin=15, flat=True)
# One label per fitted parameter: the model has only A and B.
labels = ['A', 'B']

# Posterior median and 16th/84th percentiles of each parameter.
[A_fit, B_fit] = np.median(flat_samples, axis=0)
[[A_perc16, B_perc16], [A_perc84, B_perc84]] = np.percentile(flat_samples, [16, 84], axis=0)
# Half the 16-84 percentile range as a symmetric uncertainty.
A_fit_err = (A_perc84 - A_perc16) * 0.5
B_fit_err = (B_perc84 - B_perc16) * 0.5

# A is the log-log slope and B the log-log intercept of the power law.
print(f'Best fit: A = {A_fit:0.2f} ± {A_fit_err:0.2f}, B = {B_fit:0.2f} ± {B_fit_err:0.2f}')

fig = corner.corner(flat_samples, labels=labels,
                    truths=[A_fit, B_fit])
plt.show()

In [None]:
from my_functions import double_schechter

# Maximum number of posterior samples used to build the fit uncertainty band.
N_band_samples = 10_000

fig, ax = plt.subplots(figsize=(5, 4))

# Evaluate the power law for the last N_band_samples posterior samples at
# once via broadcasting: rows are samples, columns are luminosities.
Lx = np.logspace(42, 46, 1000)
band_samples = flat_samples[::-1][:N_band_samples]
Phi_fit_i = power_fit(Lx[np.newaxis, :], band_samples[:, [0]], band_samples[:, [1]])
Phi_fit_84 = np.percentile(Phi_fit_i, 84, axis=0)
Phi_fit_16 = np.percentile(Phi_fit_i, 16, axis=0)

# My fit: median-parameter curve plus its 16-84 percentile band.
Phi_fit = power_fit(Lx, A_fit, B_fit)
ax.plot(np.log10(Lx), Phi_fit, label='Power-law fit')
ax.fill_between(np.log10(Lx), Phi_fit_16, Phi_fit_84, alpha=0.3, color='C0')

# Filled squares: bins used in the fit; hollow squares: every bin.
ax.errorbar(LF_bin[where_fit], LF_phi[where_fit],
            yerr=[LF_yerr_minus[where_fit], LF_yerr_plus[where_fit]],
            fmt='s', color='r', capsize=4, label='My points (fitted)')
ax.errorbar(LF_bin, LF_phi,
            yerr=[LF_yerr_minus, LF_yerr_plus],
            fmt='s', color='r', capsize=4,
            markerfacecolor='none', label='My points (all)')

# Reference double-Schechter LFs. `params` is passed positionally to
# double_schechter as (phistar1, Lstar1, alpha1, phistar2, Lstar2, alpha2),
# with phistar and Lstar already in linear units. The result is multiplied by
# Lx * ln(10), presumably converting dPhi/dL to dPhi/dlog10(L).
# TODO: replace the generic labels with the source of each reference LF.
reference_LFs = [
    dict(params=(3.33e-6, 10 ** 44.65, -1.35, 10 ** -3.45, 10 ** 42.93, -1.93),
         zorder=1, color='C6', label='Reference LF 1'),
    dict(params=(10 ** -3.41, 10 ** 42.87, -1.7, 10 ** -5.85, 10 ** 44.6, -1.2),
         zorder=0, color='C7', label='Reference LF 2'),
]
for ref in reference_LFs:
    Phi_center = double_schechter(Lx, *ref['params']) * Lx * np.log(10)
    ax.plot(
        np.log10(Lx), Phi_center, ls='-.', alpha=0.7,
        zorder=ref['zorder'], color=ref['color'], label=ref['label']
    )

ax.set_yscale('log')
ax.set_ylim(1e-8, 5e-3)
ax.set_xlim(42.5, 45.5)
ax.set_xlabel(r'$\log_{10} L$')
ax.set_ylabel(r'$\Phi$')
ax.legend(fontsize='small')
plt.show()