In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import emcee
import corner
from minijpas_LF_and_puricomp import effective_volume
from scipy.optimize import minimize
from scipy.stats import multivariate_normal
from my_functions import schechter
from multiprocessing import Pool

In [None]:
%matplotlib inline

In [None]:
# Build the combined luminosity function (LF) from the per-NB-range LF files.
# Each file is rescaled by its own effective volume, summed over `nbs_list`,
# and finally normalised by `total_volume`.

nb_min = 1
nb_max = 25
nbs_list = [[1, 4]]
# NOTE(review): total_volume spans nb_min..nb_max (1..25) while nbs_list only
# holds the range [1, 4]; confirm that the normalisation volume is meant to be
# larger than the volume actually summed in the loop below.
total_volume = effective_volume(nb_min, nb_max, 'both')

for i, [nb1, nb2] in enumerate(nbs_list):
    pathname = f'Luminosity_functions/LF_r17-24_nb{nb1}-{nb2}_ew15_ewoth400_nb'
    filename = f'{pathname}/LFs.pkl'
    # pickle.load can execute arbitrary code: only open files produced by
    # this project.
    with open(filename, 'rb') as file:
        this_LF_dict = pickle.load(file)

    this_volume = effective_volume(nb1, nb2, 'both')

    # First iteration: create the accumulators and read the binning.
    if i == 0:
        LF_bins = this_LF_dict['LF_bins']
        # Only the first spacing is used, i.e. uniform bins are assumed.
        bin_width = np.diff(LF_bins)[0]

        # LF scaled by volume and bin width (presumably counts per bin --
        # TODO confirm the units stored in the pickle).
        this_LF_raw = this_LF_dict['LF_total'] * this_volume * bin_width
        LF_raw = this_LF_raw / bin_width
        LF_uncorr = this_LF_dict['LF_total_raw'] * this_volume

        # Squared scaled error minus this_LF_raw; the matching
        # `LF_raw * bin_width` term is added back when LF_dict is built.
        # A negative difference gives NaN here, which is zeroed after the loop.
        LF_err = ((np.array(this_LF_dict['LF_total_err']) *
                  this_volume * bin_width) ** 2 - this_LF_raw) ** 0.5
    # Later iterations: accumulate in place, reusing bin_width from the first.
    # NOTE(review): LF_err is summed linearly (not in quadrature) and a NaN
    # from any single iteration propagates through `+=` -- confirm intent.
    else:
        this_LF_raw = this_LF_dict['LF_total'] * this_volume * bin_width
        LF_raw += this_LF_raw / bin_width
        LF_err += ((np.array(this_LF_dict['LF_total_err']) *
                  this_volume * bin_width) ** 2 - this_LF_raw) ** 0.5
        LF_uncorr += this_LF_dict['LF_total_raw'] * this_volume

# Replace NaN/inf entries (e.g. square root of a negative number) with zero.
LF_err = np.array(LF_err)
LF_err[~np.isfinite(LF_err)] = 0

# Normalise the accumulated quantities by the total volume; the error adds
# back the `LF_raw * bin_width` term that was subtracted inside the loop.
LF_dict = {
    'LF_bins': LF_bins,
    'LF_total': LF_raw / total_volume,
    'LF_total_uncorr': LF_uncorr / total_volume,
    'LF_total_err': (LF_err ** 2 + LF_raw * bin_width) ** 0.5 / total_volume / bin_width
}

# Convenience aliases used by the fitting and plotting cells.
LF_phi = LF_dict['LF_total']
LF_bin = LF_dict['LF_bins']
# Rows 0 and 1 of 'LF_total_err' are used as the lower and upper y errors.
LF_yerr_minus = LF_dict['LF_total_err'][0]
LF_yerr_plus = LF_dict['LF_total_err'][1]
# Horizontal error bar is half a bin width; row 2 only provides the shape.
LF_xerr = np.ones(LF_dict['LF_total_err'][2].shape) * bin_width * 0.5

In [None]:
# The fitting curve
def sch_fit(Lx, Phistar, Lstar):
    """Evaluate the model curve that is fitted to the luminosity function.

    Parameters
    ----------
    Lx : luminosity value(s) at which the model is evaluated.
    Phistar, Lstar : base-10 logarithms of the two free parameters; they are
        exponentiated before being handed to ``schechter``.

    Returns
    -------
    ``schechter(Lx, 10**Phistar, 10**Lstar, -1.5)`` multiplied by
    ``Lx * ln(10)`` (presumably the conversion from a per-luminosity to a
    per-dex density -- confirm against ``my_functions.schechter``).
    """
    # The fourth argument is held fixed at -1.5 and is not fitted.
    phi_linear = schechter(Lx, 10 ** Phistar, 10 ** Lstar, -1.5)
    return phi_linear * Lx * np.log(10)

In [None]:
def prior_f(theta):
    """Flat log-prior over a rectangular box in parameter space.

    Parameters
    ----------
    theta : sequence whose first two entries are (Phistar, Lstar).

    Returns
    -------
    0.0 when -10 < Phistar < -4 and 42 < Lstar < 47 (bounds exclusive),
    ``-np.inf`` otherwise.
    """
    log_phistar, log_lstar = theta[0], theta[1]
    inside_box = (-10 < log_phistar < -4) and (42 < log_lstar < 47)
    return 0. if inside_box else -np.inf

def log_likelihood(theta, Lx, Phi, yerr):
    """Gaussian log-likelihood of the data given the model parameters.

    Parameters
    ----------
    theta : sequence whose first two entries are (Phistar, Lstar), passed
        unchanged to ``sch_fit``.
    Lx : luminosities at which the model is evaluated.
    Phi : observed values compared with the model.
    yerr : 1-sigma uncertainties on ``Phi``.

    Returns
    -------
    ``-0.5 * sum((model - Phi)**2 / yerr**2 + log(yerr**2))``.
    """
    log_phistar, log_lstar = theta[0], theta[1]
    predicted = sch_fit(Lx, log_phistar, log_lstar)

    variance = yerr**2
    chi2_terms = (predicted - Phi) ** 2 / variance

    return -0.5 * np.sum(chi2_terms + np.log(variance))

def log_p(theta, Lx, Phi, yerr):
    """Log-posterior (log-prior + log-likelihood) sampled by emcee.

    Parameters
    ----------
    theta : sequence whose first two entries are (Phistar, Lstar).
    Lx, Phi, yerr : data arrays forwarded to ``log_likelihood``.

    Returns
    -------
    ``-np.inf`` when the prior is not finite, otherwise the sum of
    ``log_likelihood`` and ``prior_f``.
    """
    log_prior = prior_f(theta)
    # Skip the model evaluation outside the prior support: it is wasted work
    # and, far from the allowed box, the likelihood can evaluate to NaN,
    # which would turn the posterior into NaN instead of -inf.
    if not np.isfinite(log_prior):
        return -np.inf
    return log_likelihood(theta, Lx, Phi, yerr) + log_prior

In [None]:
## MCMC parameters ##
N_walkers = 1000
N_steps = 500

# Symmetrised error: mean of the lower and upper error bars. Bins with
# Phi == 0 get an infinite error so that the mask below drops them.
yerr = np.where(LF_phi == 0, np.inf, 0.5 * (LF_yerr_minus + LF_yerr_plus))

# Only bins with a finite error take part in the fit
where_fit = np.isfinite(yerr)

# Walkers start in a tight Gaussian ball (width 1e-3) around
# (Phistar, Lstar) = (-6, 44.5); one column per parameter.
theta0 = np.column_stack([
    np.random.normal(-6, 1e-3, N_walkers),
    np.random.normal(44.5, 1e-3, N_walkers),
])

# Data handed to log_p: luminosities (10**bins), LF values and their errors
args = (10**LF_bins[where_fit], LF_phi[where_fit], yerr[where_fit])

with Pool() as pool:
    # Two free parameters, likelihood evaluations spread over the pool
    sampler = emcee.EnsembleSampler(N_walkers, 2, log_p, args=args, pool=pool)
    # Advance every walker N_steps times
    sampler.run_mcmc(theta0, N_steps, progress=True);

# print(f'Autocorrelation time: {sampler.get_autocorr_time()}')

In [None]:
# Drop the first 4/5 of every chain as burn-in, thin the remainder and merge
# all walkers into a single (n_samples, 2) array.
flat_samples = sampler.get_chain(discard=N_steps // 5 * 4, thin=15, flat=True)
# Labels must follow the column order of theta, i.e. [Phistar, Lstar]; both
# parameters are sampled as base-10 logarithms. (The previous list was
# swapped and carried a stale third entry for a parameter that is not fitted.)
labels = ['log10(Phistar)', 'log10(Lstar)']

# Posterior medians are used as the point estimates.
[Phistar_fit, Lstar_fit] = np.median(flat_samples, axis=0)
print(f'Best fit: Phi* = {Phistar_fit:0.2f}, L* = {Lstar_fit:0.2f}')

fig = corner.corner(flat_samples, labels=labels,
                    truths=[Phistar_fit, Lstar_fit])
plt.show()


In [None]:
fig, ax = plt.subplots(figsize=(5, 4))

# Fitted curve evaluated on a fine luminosity grid, drawn against log10(L)
Lx = np.logspace(42, 46, 1000)
Phi_fit = sch_fit(Lx, Phistar_fit, Lstar_fit)
ax.plot(np.log10(Lx), Phi_fit)

# Marker styling shared by both sets of points
points_style = dict(fmt='s', color='r', capsize=4, label='My points')

# Filled squares: only the bins that entered the fit
ax.errorbar(LF_bin[where_fit], LF_phi[where_fit],
            yerr=[LF_yerr_minus[where_fit], LF_yerr_plus[where_fit]],
            xerr=LF_xerr[where_fit], **points_style)
# Open squares: every bin, drawn on top
ax.errorbar(LF_bin, LF_phi, yerr=[LF_yerr_minus, LF_yerr_plus],
            xerr=LF_xerr, markerfacecolor='none', **points_style)

# Logarithmic y axis and fixed plotting window
ax.set(yscale='log', ylim=(1e-8, 5e-3), xlim=(42.5, 45.5))

plt.show()