# In[ ]:
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

import h5py
import matplotlib.pyplot as plt
import numpy as np
import requests

# Set base URL and headers for the TNG API.
# HTTPS keeps the API key (sent as a request header) out of clear text.
baseUrl = 'https://www.tng-project.org/api/'
# Prefer a key supplied through the TNG_API_KEY environment variable.
# NOTE(review): the literal fallback is a personal credential committed to
# source; it should be rotated and removed once users set TNG_API_KEY.
headers = {"api-key": os.environ.get("TNG_API_KEY", "128de4248c745e040927ee558a9bcd62")}

def get(url, params=None):
    """
    Send an authenticated GET request to the TNG API.

    ``params`` (optional dict) is sent as the query string and the module-level
    ``headers`` (carrying the API key) are attached. Returns the decoded JSON
    body when the response's media type is ``application/json``, otherwise the
    raw ``requests.Response``. Raises ``requests.HTTPError`` on a 4xx/5xx
    status.
    """
    r = requests.get(url, params=params, headers=headers)
    r.raise_for_status()
    # Compare only the media type: servers may append parameters such as
    # "; charset=utf-8", and the header may be missing altogether.
    content_type = r.headers.get('content-type', '')
    if content_type.split(';', 1)[0].strip().lower() == 'application/json':
        return r.json()
    return r

def get_data(url, params=None):
    """
    Download cutout data from the given URL.

    Returns the decoded JSON body if the server answers with JSON. If the
    response carries a ``content-disposition`` filename, the body is written
    to that file in the current working directory and the filename is
    returned. Otherwise the raw ``requests.Response`` is returned. Raises
    ``requests.HTTPError`` on a 4xx/5xx status.
    """
    r = requests.get(url, params=params, headers=headers)
    r.raise_for_status()
    # Match on the media type only (the header may carry "; charset=...").
    content_type = r.headers.get('content-type', '')
    if content_type.split(';', 1)[0].strip().lower() == 'application/json':
        return r.json()
    disposition = r.headers.get('content-disposition', '')
    if 'filename=' in disposition:
        # Keep only the filename parameter's value: drop any following
        # parameters and surrounding quotes.
        raw_name = disposition.split('filename=', 1)[1].split(';', 1)[0]
        # The name comes from the server; strip directory components so it
        # cannot point outside the working directory.
        filename = os.path.basename(raw_name.strip().strip('"\''))
        if filename and filename not in ('.', '..'):
            with open(filename, 'wb') as f:
                f.write(r.content)
            return filename
    return r

def search_mergers(mergers, bhid):
    """
    Select the merger records that involve one black hole.

    ``mergers`` is indexed by the field names ``'id1'`` and ``'id2'`` (a NumPy
    structured array in this script). A row is kept when ``bhid`` equals
    either of those IDs, and the kept rows are returned in their original
    order.
    """
    involves_bh = np.logical_or(mergers['id1'] == bhid, mergers['id2'] == bhid)
    return mergers[involves_bh]

def frequency(m1, m2, z):
    """
    Return the frequency assigned to a merger of masses m1 and m2 at redshift z.

    Evaluates  3.9 * (10**4 / (m1 + m2)) * (1 / (1 + z)),
    with m1 and m2 in solar masses and z the redshift.
    """
    total_mass = m1 + m2
    mass_term = (10**4) / total_mass
    redshift_term = 1 / (1 + z)
    return 3.9 * mass_term * redshift_term

# Redshifts at which merger frequencies are tabulated and plotted
# (13 values, 0 to 6 in steps of 0.5).
allowed_redshifts = [0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5, 6]

# Load the merger catalogue ('TNG100_mergers_withID.npy' must exist in data_dir).
# os.path.join works whether or not data_dir ends with a path separator.
data_dir = "./"
tng100_mergers = np.load(os.path.join(data_dir, 'TNG100_mergers_withID.npy'))

# --- Locate the TNG100-1 simulation record ---------------------------------
# The API root lists every simulation; find TNG100-1 by name and follow its
# URL to the full metadata (names.index raises ValueError if it is absent).
r = get(baseUrl)
names = [entry['name'] for entry in r['simulations']]
i = names.index('TNG100-1')
sim = get(r['simulations'][i]['url'])

# --- Pick the snapshot to analyse ------------------------------------------
snaps = get(sim['url'] + 'snapshots/')

# Each snapshot entry is assumed to expose its index under the key 'number'.
snap_num = 50
snap_list = list(filter(lambda s: s.get('number') == snap_num, snaps))
if len(snap_list) == 0:
    raise ValueError("Snapshot number %d not found." % snap_num)
snap = snap_list[0]

# Retrieve the snapshot's subhalo list page by page, ordered by decreasing
# stellar mass. Each page provides its entries under 'results' and the URL
# of the following page under 'next' (presumably None on the last page).
#
# Every listed subhalo later costs one cutout download. Set max_subhalos to
# an integer to keep only the first (most stellar-massive) N subhalos;
# None fetches the complete list.
max_subhalos = None
subhalos = []
sub_url = snap['url'] + 'subhalos/'
params = {'limit': 100, 'order_by': '-mass_stars'}
while sub_url and (max_subhalos is None or len(subhalos) < max_subhalos):
    sub_data = get(sub_url, params)
    subhalos.extend(sub_data['results'])
    sub_url = sub_data.get('next')
    params = None  # subsequent pages already include the parameters
if max_subhalos is not None:
    del subhalos[max_subhalos:]  # trim the overshoot from the last page

# Hubble parameter h (TNG value), used to convert BH masses to solar masses
# via "* 1e10 / hubble" when reading cutouts.
hubble = 0.6774

def process_subhalo(sub):
    """
    Process a single subhalo: find its largest BH and that BH's mergers.

    Downloads a black-hole-only cutout (ParticleIDs and Masses) for ``sub``,
    picks the most massive BH, and searches ``tng100_mergers`` for merger
    events involving it. The downloaded cutout file is deleted after reading.

    ``sub`` is a subhalo record from the API listing; only ``sub['id']`` is
    used.

    Returns a dict with keys 'subhalo_id', 'bh_id', 'mergers' and 'redshifts'
    ('redshifts' is the set of merger redshifts, rounded to one decimal, that
    appear in ``allowed_redshifts``). Returns None when the download or HDF5
    read fails, the cutout holds no BHs, or the largest BH has no mergers.
    """
    sub_id = sub['id']
    sub_url = f"{snap['url']}subhalos/{sub_id}/"
    cutout_request = {'bhs': 'ParticleIDs,Masses'}
    try:
        cutout_file = get_data(sub_url + "cutout.hdf5", cutout_request)
    except Exception:
        # Best effort: a failed request simply skips this subhalo.
        return None

    try:
        with h5py.File(cutout_file, 'r') as f:
            bh_ids = f['PartType5']['ParticleIDs'][:]
            bh_masses = f['PartType5']['Masses'][:] * 1e10 / hubble  # Convert to solar masses
    except Exception:
        # Unreadable file, a response that is not a file path, or missing BH
        # datasets (e.g. no PartType5 group): skip this subhalo.
        return None
    finally:
        # Only the two arrays above are needed. Remove the cutout so a run over
        # many subhalos does not keep accumulating files in the working dir.
        if isinstance(cutout_file, str):
            try:
                os.remove(cutout_file)
            except OSError:
                pass  # cleanup is best effort; the data is already in memory

    if len(bh_masses) == 0:
        return None

    # Find the most massive BH and get its ID
    largest_bh_id = bh_ids[np.argmax(bh_masses)]

    mergers_this_bh = search_mergers(tng100_mergers, largest_bh_id)
    if len(mergers_this_bh) == 0:
        return None

    # Round each merger redshift once, then keep only the tabulated values.
    rounded_z = (round(m['redshift'], 1) for m in mergers_this_bh)
    redshifts_found = {z for z in rounded_z if z in allowed_redshifts}

    return {
        'subhalo_id': sub_id,
        'bh_id': largest_bh_id,
        'mergers': mergers_this_bh,
        'redshifts': redshifts_found,
    }

# Use a thread pool to process subhalos concurrently. Results are collected in
# completion order (as_completed), so when the loop stops early the contents
# of bh_merger_info can differ between runs.
bh_merger_info = []
collected_redshifts = set()

# You can adjust max_workers; too many threads might overload your connection.
with ThreadPoolExecutor(max_workers=10) as executor:
    futures = {executor.submit(process_subhalo, sub): sub for sub in subhalos}
    try:
        for future in as_completed(futures):
            result = future.result()
            if result is not None:
                bh_merger_info.append(result)
                collected_redshifts.update(result['redshifts'])
            # Stop once every allowed redshift has been seen at least once.
            if len(collected_redshifts) == len(allowed_redshifts):
                break
    finally:
        # Leaving the with-block calls executor.shutdown(wait=True), which
        # would still run every queued task. Cancel the tasks that have not
        # started so an early break (or an exception) really stops the work;
        # cancel() does nothing for tasks already running or finished.
        for pending in futures:
            pending.cancel()

# One (initially empty) list of frequencies per tabulated redshift, so every
# redshift has an entry even if no merger falls on it.
freq_dict = {z: [] for z in allowed_redshifts}

# Bucket each merger by its redshift rounded to one decimal; mergers whose
# rounded redshift is not on the tabulated grid are skipped.
for info in bh_merger_info:
    for merger in info['mergers']:
        merger_z = round(merger['redshift'], 1)
        if merger_z not in allowed_redshifts:
            continue
        f = frequency(merger['m1'], merger['m2'], merger_z)
        freq_dict[merger_z].append(f)

# Mean frequency per tabulated redshift. NaN marks a redshift with no mergers,
# keeping the list aligned one-to-one with allowed_redshifts.
avg_freqs = [
    np.mean(freq_dict[z]) if freq_dict[z] else np.nan
    for z in allowed_redshifts
]

# Plot average merger frequency against redshift on a logarithmic y-axis,
# save the figure to disk, then display it. NaN entries are not drawn.
fig, ax = plt.subplots(figsize=(10, 8))
ax.semilogy(allowed_redshifts, avg_freqs, 'ro-', markersize=8)
ax.set_xlabel("Redshift (z)")
ax.set_ylabel("Average Frequency (Hz)")
ax.set_title("Average Merger Frequency vs Redshift")
ax.set_xticks(allowed_redshifts)  # one tick per tabulated redshift
ax.grid(True, which="both", ls="--")
fig.savefig("average_merger_frequency_redshift.png", bbox_inches="tight")
plt.show()
