In [None]:
import astropy.units as u
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from exo_finder.data_pipeline.generation.dataset_generation_types import TransitProfile
from exo_finder.data_pipeline.generation.time_generation import generate_time_days_of_length
from exo_finder.data_pipeline.generation.transit_generation import PlanetType, PeriodFrequency,generate_transits_from_params, generate_transit_parameters
from exo_finder.default_datasets import gaia_dataset

In [None]:
# Keep only the star identifier and the stellar parameters used by the cells
# below, and discard every star for which any of them is missing.
fields = ["gaia_id", "radius", "mass_flame", "teff_mean"]
gaia_df = (
    gaia_dataset.load_gaia_parameters_dataset()
    .to_pandas()[fields]
    .dropna()
)
print("Dataset size:", len(gaia_df))

In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to the project sources are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Preview: draw one random star, generate one Mini-Neptune and one Jupiter
# transit signal around it, and plot both on a shared time axis.
# The star is sampled without a seed, so every run shows a different one.
sample = gaia_df.sample(1)
radius, mass, teff = sample["radius"].item(), sample["mass_flame"].item(), sample["teff_mean"].item()
print(f"Star {sample['gaia_id'].item()}: Radius {radius}, Mass {mass}, Teff {teff}")

# Attach astropy units to the raw scalars (solar mass, solar radius, Kelvin).
mass = mass * u.solMass
radius = radius * u.solRad
teff = teff * u.K

params_e = generate_transit_parameters(
    planet_type=PlanetType.MINI_NEPTUNE,
    orbital_period_interval=PeriodFrequency.THREE_TO_TEN_DAYS,
    star_radius=radius,
    star_mass=mass,
    transit_midpoint_range=(1, 2),
)

# Degenerate interval (1000, 1000): presumably pins the orbital period to a
# single value -- TODO confirm against generate_transit_parameters.
params_j = generate_transit_parameters(
    planet_type=PlanetType.JUPITER,
    orbital_period_interval=(1000, 1000),
    star_radius=radius,
    star_mass=mass,
    transit_midpoint_range=(0, 5),
)

# Shared time axis; `x` is reused by the dataset-generation cell further down.
x = generate_time_days_of_length(2**12)
transits_e = generate_transits_from_params(params_e, x)
transits_j = generate_transits_from_params(params_j, x)

# Label both curves so the figure can be read on its own.
fig, ax = plt.subplots(figsize=(15, 3))
ax.plot(x, transits_e, label="Mini-Neptune")
ax.plot(x, transits_j, label="Jupiter")
ax.set(
    title="Synthetic transits for the sampled star",
    xlabel="Time (days)",
    ylabel="Transit signal",
)
ax.legend()
plt.show()

### Generate dataset and study the distribution of parameters

In [None]:
# Hot Jupyters: short period, at least 2 transits
# NOTE(review): this profile is not referenced by the generation code below;
# it is kept (once -- the original cell defined it twice verbatim) in case
# other cells rely on it.
hot_jupyters = TransitProfile(
    planet_type=PlanetType.JUPITER,
    transit_period_range=(1, 10),
    transit_midpoint_range=(0, 5),
    weight=1,
)
n = 1000

# Same seed for the star sampling and the parameter generator, so the whole
# cell is reproducible.
rndgen = np.random.default_rng(8)
sample_stars = gaia_df.sample(n, replace=True, random_state=8)

# One parameter set per (star, planet type) pair. Stars are the outer loop and
# planet types the inner loop, so the shared random generator is consumed in
# the same order as a nested for-loop over rows would consume it.
all_params = [
    generate_transit_parameters(
        planet_type=planet_type,
        orbital_period_interval=(0.5, 10),
        transit_midpoint_range=(0, 5),
        star_radius=star_radius * u.solRad,
        star_mass=star_mass * u.solMass,
        rnd_generator=rndgen,
    )
    for star_radius, star_mass in zip(sample_stars["radius"], sample_stars["mass_flame"])
    for planet_type in PlanetType
]

# Evaluate every parameter set on the time axis `x` defined in the preview
# cell above, stacking the results row-wise into a single array.
all_transits = np.vstack([generate_transits_from_params(p, x) for p in all_params])

# Tabulate the generated parameters; the bare expression displays the frame.
params_df = pd.DataFrame([p.to_dict() for p in all_params])
params_df

In [None]:
# Scatter the generated orbital period against planet mass and planet radius.
# Two panels side by side: one row, two columns.
fig, (ax_mass, ax_radius) = plt.subplots(1, 2, figsize=(15, 5))

# Left panel: period vs planet mass.
sns.scatterplot(x="period_d", y="planet_mass_solmass", data=params_df, s=5, alpha=0.7, ax=ax_mass)
ax_mass.set_title('Period vs Planet Mass')
ax_mass.set_xlabel('Period (days)')
ax_mass.set_ylabel('Planet Mass (solar masses)')

# Right panel: period vs planet radius.
sns.scatterplot(x="period_d", y="planet_radius_solrad", data=params_df, s=5, alpha=0.7, ax=ax_radius)
ax_radius.set_title('Period vs Planet Radius')
ax_radius.set_xlabel('Period (days)')
ax_radius.set_ylabel('Planet Radius (solar radii)')

# Adjust layout to prevent overlap
plt.tight_layout()
plt.show()

In [None]:
# Summary statistics and histograms of the generated orbital periods and
# first-transit midpoints.
periods = params_df["period_d"].to_numpy()
midpoints = params_df["first_transit_midpoint_d"].to_numpy()
print(f"Periods stats: Min: {min(periods)}, max: {max(periods)}, median: {np.median(periods)}, 90% interval: {np.quantile(periods, q=(0.05, 0.95))}")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Left panel: first-transit midpoints, with the x-axis fixed to (-1, 5).
ax1.hist(midpoints, bins=100)
ax1.set_title("Distribution of first transit midpoints")
ax1.set_xlabel("First transit (days)")
ax1.set_xlim(-1, 5)

# Right panel: orbital periods; values of 50 or more are left out of the histogram.
short_periods = periods[periods < 50]
ax2.hist(short_periods, bins=100)
ax2.set_title("Distribution of orbital periods")
ax2.set_xlabel("Orbital Period (days)")

plt.tight_layout()
plt.show()