# 00. Basic Cosmology (pyccl)
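Sweep each parameter in `SETTINGS['parameters']` across its range and inspect background observables (comoving distance, growth factor) together with the linear/nonlinear matter power spectra and the halo mass function.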


In [None]:
from pathlib import Path
import re
import matplotlib.pyplot as plt
import numpy as np

try:
    import pyccl as ccl
except Exception as e:
    raise RuntimeError('Install pyccl first: pip install pyccl') from e

print('pyccl version:', getattr(ccl, '__version__', 'unknown'))

# Notebook-wide configuration.
#
# Every entry of 'parameters' is a (low, fiducial, high) triple: fiducial_params
# reads index 1, and the sweep samples 'n_points' values between low and high.
SETTINGS = dict(
    output_dir=Path('/home/z/Zekang.Zhang/cosmo_playground/outputs/pyccl_basic_cosmology'),
    # Number of samples taken across each parameter range.
    n_points=7,
    # Regex filters applied to parameter names by filter_params (empty = keep all).
    include_regex=[],
    exclude_regex=[],
    # Multipole and redshift grids built by make_grids.
    ell_min=20,
    ell_max=3000,
    n_ell=120,
    z_min=0.001,
    z_max=3.0,
    n_z=600,
    # NOTE(review): presumably the shape arguments for smail_nz and the
    # tomographic bin edges for build_tomo_nz; they are not read by the
    # cells visible here -- confirm against the rest of the notebook.
    alpha=2.0,
    beta=1.5,
    z0=0.5,
    source_bins=[(0.20, 0.55), (0.55, 0.85), (0.85, 1.30), (1.30, 2.00)],
    lens_bins=[(0.20, 0.45), (0.45, 0.70), (0.70, 0.95), (0.95, 1.20)],
    # Forwarded to ccl.Cosmology by build_cosmology.
    transfer_function='eisenstein_hu',
    parameters={
        # Cosmology. build_cosmology uses Omega_c = omega_m - omega_b whenever
        # omega_m is present in the parameter dict; omega_c is only read when
        # omega_m is absent.
        'omega_m': (0.24, 0.31, 0.40),
        'omega_c': (0.20, 0.26, 0.34),
        'omega_b': (0.040, 0.049, 0.060),
        'h': (0.60, 0.68, 0.76),
        'n_s': (0.92, 0.965, 1.02),
        'sigma8': (0.70, 0.82, 0.94),
        'w0': (-1.2, -1.0, -0.8),
        'wa': (-0.6, 0.0, 0.6),
        'm_nu': (0.0, 0.06, 0.30),
        'Omega_k': (-0.02, 0.0, 0.02),
        'Neff': (3.046, 3.046, 3.7),
        # Intrinsic-alignment amplitude and redshift slope (see ia_bias_curve).
        'a_ia': (-1.0, 0.5, 3.0),
        'eta_ia': (-2.0, 0.0, 2.0),
        # Constant bias per lens bin (see build_tracers).
        'b_lens_1': (0.8, 1.2, 2.2),
        'b_lens_2': (0.8, 1.3, 2.3),
        'b_lens_3': (0.8, 1.4, 2.4),
        'b_lens_4': (0.8, 1.5, 2.5),
        # Per-bin nuisance parameters that share one range across the four
        # bins; the names are generated in bin order 1..4.
        **{f'm_source_{i}': (-0.03, 0.00, 0.03) for i in range(1, 5)},
        **{f'dz_source_{i}': (-0.08, 0.00, 0.08) for i in range(1, 5)},
        **{f'dz_lens_{i}': (-0.05, 0.00, 0.05) for i in range(1, 5)},
    },
)


In [None]:
def fiducial_params(settings):
    """Return ``{name: fiducial value}`` for every configured parameter.

    The fiducial value is the middle element (index 1) of each triple in
    ``settings['parameters']``, converted to ``float``.
    """
    fiducial = {}
    for name, triple in settings['parameters'].items():
        fiducial[name] = float(triple[1])
    return fiducial


def filter_params(settings, names):
    """Select the entries of ``names`` that pass the configured filters.

    A name is kept when it is a key of ``settings['parameters']``, matches at
    least one pattern of ``settings['include_regex']`` (if that list is
    non-empty) and matches no pattern of ``settings['exclude_regex']``.
    The input order is preserved.
    """

    def is_selected(name):
        # Settings are read per name so nothing is looked up for empty input.
        if name not in settings['parameters']:
            return False
        include = settings['include_regex']
        if include and not any(re.search(pattern, name) for pattern in include):
            return False
        exclude = settings['exclude_regex']
        if exclude and any(re.search(pattern, name) for pattern in exclude):
            return False
        return True

    return [name for name in names if is_selected(name)]


def make_grids(settings):
    """Build the multipole and redshift grids described by ``settings``.

    Returns ``(ell, z)``: ``ell`` holds the unique integer truncations of
    ``n_ell`` log-spaced values between ``ell_min`` and ``ell_max`` (so it can
    be shorter than ``n_ell``); ``z`` holds ``n_z`` linearly spaced values
    between ``z_min`` and ``z_max``.
    """
    log_spaced = np.geomspace(settings['ell_min'], settings['ell_max'], settings['n_ell'])
    # Truncating to int creates duplicates at low ell; np.unique drops them.
    ell = np.unique(log_spaced.astype(int))
    z = np.linspace(settings['z_min'], settings['z_max'], settings['n_z'])
    return ell, z


def smail_nz(z, alpha, beta, z0):
    """Return ``z**alpha * exp(-(z / z0)**beta)`` normalized to unit area.

    Parameters
    ----------
    z : array-like
        Redshift samples (1-D); converted to a float array.
    alpha, beta, z0 : float
        Shape parameters of the distribution.

    Returns
    -------
    numpy.ndarray
        The distribution on ``z``, forced to zero where ``z <= 0`` and divided
        by its trapezoidal integral over ``z``. When that integral is not
        positive the unnormalized values are returned.
    """
    z = np.asarray(z, dtype=float)
    nz = np.where(z <= 0, 0.0, z**alpha * np.exp(-(z / z0) ** beta))

    # np.trapezoid only exists in NumPy >= 2.0; older releases call it np.trapz.
    integrate = getattr(np, 'trapezoid', None)
    if integrate is None:
        integrate = np.trapz

    area = integrate(nz, z)
    return nz / area if area > 0 else nz


def top_hat_bin(z, zmin, zmax):
    """Return a window on ``z`` that is 1 for ``zmin <= z < zmax`` and 0 elsewhere.

    The result has the same shape and dtype as ``z``.
    """
    inside = np.logical_and(z >= zmin, z < zmax)
    window = np.zeros_like(z)
    np.putmask(window, inside, 1.0)
    return window


def shift_nz(z, nz, dz):
    """Shift a redshift distribution by ``dz`` and renormalize it.

    The shifted distribution is ``nz`` evaluated at ``z - dz`` by linear
    interpolation, so a positive ``dz`` moves it towards higher redshift.
    Values requested outside the sampled ``z`` range are set to zero.

    Parameters
    ----------
    z : array-like
        Redshift samples; ``np.interp`` requires them to be increasing.
    nz : array-like
        Distribution sampled on ``z``.
    dz : float
        Shift applied to the distribution.

    Returns
    -------
    numpy.ndarray
        The shifted distribution divided by its trapezoidal integral over
        ``z``, or left unnormalized when that integral is not positive.
    """
    z = np.asarray(z, dtype=float)
    out = np.interp(z - dz, z, nz, left=0.0, right=0.0)

    # np.trapezoid only exists in NumPy >= 2.0; older releases call it np.trapz.
    integrate = getattr(np, 'trapezoid', None)
    if integrate is None:
        integrate = np.trapz

    area = integrate(out, z)
    return out / area if area > 0 else out


def build_tomo_nz(z, bins, base_nz, dz_prefix, params):
    """Cut ``base_nz`` into tomographic bins and apply the per-bin shift.

    For bin ``i`` (counted from 1) with edges ``(zmin, zmax)`` the base
    distribution is multiplied by ``top_hat_bin`` and passed to ``shift_nz``
    with the shift ``params[f'{dz_prefix}_{i}']`` (0.0 when missing).
    Returns one distribution per bin, in bin order.
    """

    def one_bin(index, edges):
        lower, upper = edges
        windowed = base_nz * top_hat_bin(z, lower, upper)
        shift = float(params.get(f'{dz_prefix}_{index}', 0.0))
        return shift_nz(z, windowed, shift)

    return [one_bin(index, edges) for index, edges in enumerate(bins, start=1)]


def ia_bias_curve(z, params, pivot_z=0.62):
    """Evaluate ``a_ia * ((1 + z) / (1 + pivot_z)) ** eta_ia``.

    ``a_ia`` and ``eta_ia`` are read from ``params`` and default to 0.0 when
    missing, which yields zero everywhere.
    """
    amplitude = float(params.get('a_ia', 0.0))
    ratio = (1.0 + z) / (1.0 + pivot_z)
    slope = float(params.get('eta_ia', 0.0))
    return amplitude * ratio**slope


def build_cosmology(params, settings):
    """Create a ``ccl.Cosmology`` from a flat parameter dict.

    ``params`` must provide ``omega_b``, ``h``, ``n_s``, ``sigma8`` and either
    ``omega_m`` or ``omega_c``. When ``omega_m`` is present it wins and
    ``Omega_c = omega_m - omega_b``; ``omega_c`` is only read otherwise.
    ``w0``/``wa`` default to -1.0/0.0, and ``m_nu``, ``Omega_k`` and ``Neff``
    are forwarded only when present. The transfer function comes from
    ``settings['transfer_function']`` (default ``'eisenstein_hu'``).

    Raises ``ValueError`` when the resulting ``Omega_c`` is not positive.
    """
    baryon_density = float(params['omega_b'])
    if 'omega_m' in params:
        cdm_density = float(params['omega_m']) - baryon_density
    else:
        cdm_density = float(params['omega_c'])

    if cdm_density <= 0:
        raise ValueError(f'Unphysical Omega_c={cdm_density:.4g}; check omega_m and omega_b ranges.')

    ccl_kwargs = {
        'Omega_c': cdm_density,
        'Omega_b': baryon_density,
        'h': float(params['h']),
        'n_s': float(params['n_s']),
        'sigma8': float(params['sigma8']),
        'w0': float(params.get('w0', -1.0)),
        'wa': float(params.get('wa', 0.0)),
        'transfer_function': settings.get('transfer_function', 'eisenstein_hu'),
    }

    # Optional parameters are only passed on when the caller supplied them,
    # so ccl.Cosmology keeps its own defaults otherwise.
    for optional in ('m_nu', 'Omega_k', 'Neff'):
        if optional in params:
            ccl_kwargs[optional] = float(params[optional])

    return ccl.Cosmology(**ccl_kwargs)


def build_tracers(cosmo, z, src_nz, lens_nz, params):
    """Create the weak-lensing and number-counts tracers for all bins.

    Each source distribution becomes a ``ccl.WeakLensingTracer`` whose
    ``ia_bias`` is the shared ``ia_bias_curve(z, params)``. Each lens
    distribution becomes a ``ccl.NumberCountsTracer`` without RSD whose bias
    is constant in ``z`` and equal to ``params[f'b_lens_{i}']`` (1.5 when
    missing), with ``i`` counted from 1. Returns ``(sources, lenses)``.
    """
    ia_amplitude = ia_bias_curve(z, params)

    sources = []
    for nz in src_nz:
        sources.append(ccl.WeakLensingTracer(cosmo, dndz=(z, nz), ia_bias=(z, ia_amplitude)))

    def lens_tracer(index, nz):
        bias_value = float(params.get(f'b_lens_{index}', 1.5))
        flat_bias = np.full_like(z, bias_value)
        return ccl.NumberCountsTracer(cosmo, has_rsd=False, dndz=(z, nz), bias=(z, flat_bias))

    lenses = [lens_tracer(index, nz) for index, nz in enumerate(lens_nz, start=1)]
    return sources, lenses


In [None]:
# Parameters selected for the sweep, in alphabetical order.
PARAMS_ALL = filter_params(SETTINGS, sorted(SETTINGS['parameters']))
if len(PARAMS_ALL) == 0:
    raise ValueError('No parameters selected after include/exclude filters.')

# Physical grouping used to lay out the figures: one figure per group and
# quantity, one panel per parameter.
PHYS_GROUPS = dict(
    expansion_geometry=['omega_m', 'omega_b', 'h', 'Omega_k', 'w0', 'wa'],
    growth_primordial=['sigma8', 'n_s', 'm_nu', 'Neff'],
    intrinsic_alignment=['a_ia', 'eta_ia'],
    lens_bias=['b_lens_1', 'b_lens_2', 'b_lens_3', 'b_lens_4'],
    shear_calibration=['m_source_1', 'm_source_2', 'm_source_3', 'm_source_4'],
    photoz_source=['dz_source_1', 'dz_source_2', 'dz_source_3', 'dz_source_4'],
    photoz_lens=['dz_lens_1', 'dz_lens_2', 'dz_lens_3', 'dz_lens_4'],
)

# Keep only the selected members of each group and drop groups left empty.
GROUPED_PARAMS = {
    group: members
    for group, candidates in PHYS_GROUPS.items()
    if (members := [name for name in candidates if name in PARAMS_ALL])
}

# Selected parameters that belong to no predefined group end up in 'other'.
assigned = {name for members in GROUPED_PARAMS.values() for name in members}
leftover = [name for name in PARAMS_ALL if name not in assigned]
if leftover:
    GROUPED_PARAMS['other'] = leftover

print('Parameter groups:')
for group_name, members in GROUPED_PARAMS.items():
    print(f'  {group_name}: {len(members)} params')

ell, z = make_grids(SETTINGS)
params0 = fiducial_params(SETTINGS)

# Sampling grids shared by every cosmology of the sweep (loop-invariant, so
# they are built once; the stored curves reference these same arrays).
a = 1.0 / (1.0 + z)
k = np.geomspace(1e-3, 10.0, 220)
m = np.geomspace(1e11, 1e15, 120)

# The halo mass function object does not depend on the cosmology, so it is
# created once. A failure here is reported instead of being swallowed, and the
# HMF curves are then stored as NaN so the remaining quantities still plot.
try:
    mass_def = ccl.halos.MassDef200m
    # NOTE(review): some pyccl releases appear to expose MassDef200m as a
    # ready-made instance rather than a factory -- confirm against the
    # installed version. It is only called when it is callable.
    if callable(mass_def):
        mass_def = mass_def()
    mf = ccl.halos.MassFuncTinker08(mass_def=mass_def)
except Exception as e:
    print(f'[warn] halo mass function unavailable ({type(e).__name__}: {e}); HMF curves will be NaN.')
    mf = None

# One-at-a-time sweep: every parameter is varied over its (low, high) range
# while all others stay at their fiducial values.
# results[pname] = {'vals': sampled values, <quantity>: [(x, y), ...]} with one
# (x, y) curve per sampled value, in the same order as 'vals'.
# NOTE: parameters that build_cosmology does not read (IA, lens bias, shear
# calibration and photo-z shifts) leave every curve below unchanged.
results = {}
for pname in PARAMS_ALL:
    lo, _, hi = SETTINGS['parameters'][pname]
    vals = np.linspace(lo, hi, SETTINGS['n_points'])

    curves = {key: [] for key in ('distance', 'growth', 'transfer', 'pk_lin', 'pk_nl', 'hmf')}

    for val in vals:
        p = dict(params0)
        p[pname] = float(val)
        if pname == 'omega_c':
            # build_cosmology derives Omega_c from omega_m whenever omega_m is
            # present, which would turn the omega_c sweep into a no-op; drop
            # omega_m so the swept omega_c value is actually used.
            p.pop('omega_m', None)

        try:
            cosmo = build_cosmology(p, SETTINGS)
        except ValueError as e:
            # Keep one (all-NaN) curve per sampled value so the curves stay
            # aligned with 'vals' and with the colour scale in the plots.
            print(f'[skip] {pname}={val:.4g}: {e}')
            curves['distance'].append((z, np.full(z.shape, np.nan)))
            curves['growth'].append((z, np.full(z.shape, np.nan)))
            curves['transfer'].append((k, np.full(k.shape, np.nan)))
            curves['pk_lin'].append((k, np.full(k.shape, np.nan)))
            curves['pk_nl'].append((k, np.full(k.shape, np.nan)))
            curves['hmf'].append((m, np.full(m.shape, np.nan)))
            continue

        chi = ccl.comoving_radial_distance(cosmo, a)
        growth = ccl.growth_factor(cosmo, a)

        pk_lin = ccl.linear_matter_power(cosmo, k, a=1.0)
        pk_nl = ccl.nonlin_matter_power(cosmo, k, a=1.0)
        # Transfer-shape proxy: divide out the primordial k**n_s tilt and
        # normalize to the lowest k; the floor guards the square root.
        t_eff = np.sqrt(np.maximum(pk_lin, 1e-60) / (k ** float(p['n_s'])))
        t_eff /= t_eff[0]

        curves['distance'].append((z, chi))
        curves['growth'].append((z, growth))
        curves['transfer'].append((k, t_eff))
        curves['pk_lin'].append((k, pk_lin))
        curves['pk_nl'].append((k, pk_nl))

        if mf is None:
            hmf = np.full(m.shape, np.nan)
        else:
            try:
                hmf = mf(cosmo, m, 1.0)
            except Exception as e:
                # Best effort: report the failure and keep the sweep running.
                print(f'[warn] HMF failed for {pname}={val:.4g} ({type(e).__name__}: {e}); storing NaN.')
                hmf = np.full(m.shape, np.nan)
        curves['hmf'].append((m, hmf))

    results[pname] = {'vals': vals, **curves}


def plot_quantity_group(quantity_key, title_prefix, xlab, ylab, group_name, pnames, logx=False, logy=False):
    """Draw one figure with one panel per parameter for a swept quantity.

    Each panel overlays the curves stored in ``results[pname][quantity_key]``,
    coloured along the viridis colormap in the order of the sampled values,
    and carries a colorbar spanning the minimum to maximum sampled value.
    Panels are laid out three per row and unused grid cells are hidden.
    Nothing is drawn when ``pnames`` is empty.
    """
    if not pnames:
        return

    n_panels = len(pnames)
    n_cols = 3
    n_rows = int(np.ceil(n_panels / n_cols))
    fig, grid = plt.subplots(n_rows, n_cols, figsize=(5.0 * n_cols, 3.8 * n_rows), squeeze=False)
    panels = grid.ravel()
    cmap = plt.get_cmap('viridis')

    for ax, pname in zip(panels, pnames):
        sweep = results[pname]
        vals = sweep['vals']
        shades = cmap(np.linspace(0, 1, len(vals)))

        for shade, (xx, yy) in zip(shades, sweep[quantity_key]):
            ax.plot(xx, yy, color=shade, lw=1.3)

        if logx:
            ax.set_xscale('log')
        if logy:
            ax.set_yscale('log')

        ax.set_title(pname)
        ax.set_xlabel(xlab)
        ax.set_ylabel(ylab)
        ax.grid(alpha=0.2)

        # The colorbar maps curve colour back to the swept parameter value.
        value_range = plt.Normalize(vmin=float(vals.min()), vmax=float(vals.max()))
        mappable = plt.cm.ScalarMappable(cmap=cmap, norm=value_range)
        fig.colorbar(mappable, ax=ax, fraction=0.046, pad=0.02, label=pname)

    # Hide the grid cells that have no parameter assigned.
    for unused in panels[n_panels:]:
        unused.axis('off')

    fig.suptitle(f'{title_prefix} | group: {group_name}', y=1.02)
    fig.tight_layout()
    plt.show()


# One row per plotted quantity:
# (results key, figure title, x label, y label, log-x axis, log-y axis).
_QUANTITY_PANELS = (
    ('distance', 'Comoving distance (panels by parameter)', 'z', 'chi(z) [Mpc]', False, False),
    ('growth', 'Growth D(z) (panels by parameter)', 'z', 'D(z)', False, False),
    ('transfer', 'Transfer-shape proxy from P_lin (panels by parameter)', 'k [1/Mpc]', 'T_proxy(k)', True, False),
    ('pk_lin', 'Linear P(k) (panels by parameter)', 'k [1/Mpc]', 'P_lin(k)', True, True),
    ('pk_nl', 'Nonlinear P(k) (panels by parameter)', 'k [1/Mpc]', 'P_nl(k)', True, True),
    ('hmf', 'Halo mass function (panels by parameter)', 'M [Msun]', 'dn/dlog10M', True, True),
)

# For every parameter group, draw all quantities in the order listed above.
for gname, plist in GROUPED_PARAMS.items():
    print(f'\n=== Group: {gname} ===')
    for quantity, title, x_label, y_label, use_logx, use_logy in _QUANTITY_PANELS:
        plot_quantity_group(quantity, title, x_label, y_label, gname, plist, logx=use_logx, logy=use_logy)
