In [None]:
from typing import Optional

import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
from matplotlib import pyplot as plt
from exo_finder.default_datasets import exo_dataset

# Known exoplanet dataset exploration

In [None]:
# Load the default known-exoplanet records (including Gaia star data) and convert them to pandas.
exo_db = (
    exo_dataset
    .load_known_exoplanets_dataset(with_gaia_star_data=True)
    .get_default_records()
)
exo_df = exo_db.to_pandas()
exo_df

#### Discoveries per telescope

In [None]:
# Discoveries by facility: how many planets each facility found and the year of its first discovery,
# most productive facility first.
(
    exo_df
    .groupby("disc_facility")
    .agg(
        year_of_first_discovery=("disc_year", "min"),
        total_discoveries=("disc_facility", "count"),
    )
    .sort_values(by="total_discoveries", ascending=False)
)

#### Cumulative discoveries over the years, and discovery methods

In [None]:
def discoveries_by_year(df: Optional[pd.DataFrame] = None) -> pd.DataFrame:
    """Cumulative number of exoplanet discoveries per year.

    Every calendar year between the first and the last discovery year is present.
    Years without discoveries repeat the previous cumulative total, so the curve
    has no gaps and never decreases.

    Args:
        df: Frame with a ``disc_year`` column. Defaults to the notebook-level ``exo_df``.

    Returns:
        DataFrame with columns ``disc_year`` (int) and ``count`` (cumulative
        discoveries up to and including that year), sorted by year. It is empty
        when no discovery year is known.
    """
    source = exo_df if df is None else df
    per_year = source["disc_year"].dropna().astype(int).value_counts().sort_index()
    if per_year.empty:
        return pd.DataFrame({"disc_year": pd.Series(dtype=int), "count": pd.Series(dtype=int)})

    # Fill gap years with zero new discoveries before accumulating. This replaces the
    # previous hard-coded 1993 row, which was wrong whenever the data had other gaps.
    all_years = pd.RangeIndex(int(per_year.index.min()), int(per_year.index.max()) + 1, name="disc_year")
    cumulative = per_year.reindex(all_years, fill_value=0).cumsum()
    return cumulative.rename("count").reset_index()

def plot_discoveries_by_year_and_type():
    """Plot cumulative discoveries per year (left) and discoveries per detection method (right).

    The right panel uses a logarithmic y-axis. Bars are ordered from most to
    least common method, and each bar is labelled with its count.
    """
    disc_by_year = discoveries_by_year()
    count_per_type = (
        exo_df.groupby("discoverymethod", observed=False)
        .size()
        .rename("count")
        .reset_index()
        .sort_values(by="count", ascending=False)
    )
    # `observed=False` suggests discoverymethod is categorical. For categorical data,
    # seaborn orders bars by the category order, not the row order, so the sort above
    # only takes effect if the order is passed explicitly.
    method_order = count_per_type["discoverymethod"].tolist()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))
    ax1.plot(disc_by_year["disc_year"], disc_by_year["count"], marker='o')
    ax1.set(title="Cumulative discoveries over the years", xlabel="Year", ylabel="Exoplanets discovered", xticks=disc_by_year['disc_year'].unique())
    ax1.xaxis.set_tick_params(rotation=45)

    sns.barplot(
        x="discoverymethod", y="count", hue="discoverymethod", data=count_per_type,
        order=method_order, hue_order=method_order, palette="viridis", legend=False, ax=ax2,
    )
    ax2.set_yscale('log')

    # Label each bar with its count. Zero or missing heights cannot be placed on a log axis.
    for bar in ax2.patches:
        height = bar.get_height()
        if not np.isfinite(height) or height <= 0:
            continue
        ax2.annotate(format(height, '.0f'),
                     (bar.get_x() + bar.get_width() / 2., height),
                     ha='center', va='center', xytext=(0, 9),
                     textcoords='offset points', fontsize=12)

    ax2.set_xlabel('', fontsize=14)
    ax2.set_ylabel('Count (log scale)', fontsize=14)
    ax2.set_title('Count of Exoplanets Detected by Method')
    # Rotate method names so long labels do not overlap.
    plt.setp(ax2.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()
    plt.show()

plot_discoveries_by_year_and_type()

## Planet and host star statistics

In [None]:
# How many planets per star: the distribution of the number of known planets per host.
planets_per_star = exo_df.groupby("hostname").size()
# A bare Series is plotted in plotly's wide-form mode. Label the value axis and hide
# the meaningless single-entry legend ("variable=0").
px.histogram(
    planets_per_star,
    title="How many planets per star",
    labels={"value": "Planets per host star"},
).update_layout(showlegend=False).show()

In [None]:
# The x-axis is stellar *mass* (st_mass), not stellar size; the title and labels say so.
px.scatter(
    exo_df,
    x="st_mass",
    y="pl_rade",
    log_x=True,
    log_y=True,
    labels={"st_mass": "Stellar mass", "pl_rade": "Planet radius"},
    title="Stellar mass vs planet radius",
).show()

### Closer look at orbital period vs planet mass and radius

In [None]:
from exo_finder.visualization.planet_stars_plotting import plot_planet_period_radius_mass


# Keep only planets that have both an orbital period and the second quantity (mass or radius).
orbit_mass = exo_df.dropna(subset=["pl_orbper", "pl_masse"])[["pl_orbper", "pl_masse"]].to_numpy()
orbit_radius = exo_df.dropna(subset=["pl_orbper", "pl_rade"])[["pl_orbper", "pl_rade"]].to_numpy()
print(f"Total planets: {len(orbit_mass)}, {len(orbit_radius)}")
plot_planet_period_radius_mass(period_mass=orbit_mass, period_radius=orbit_radius)

In [None]:
def plot_orbital_period_distribution(mass_range: Optional[tuple[float, float]] = None, short_period_cutoff: float = 50):
    """Plot the distribution of orbital periods, optionally restricted to a planet-mass range.

    Left panel: histogram of log10(period). Right panel: linear histogram of the
    periods shorter than ``short_period_cutoff`` days. Summary statistics are printed first.

    Args:
        mass_range: Optional ``(low, high)`` bounds on ``pl_masse``. Planets with
            ``low <= pl_masse < high`` are kept, so adjacent ranges such as
            (1, 10) and (10, 100) neither overlap nor drop the shared boundary.
        short_period_cutoff: Exclusive upper bound, in days, for the right-hand panel.
    """
    if mass_range:
        m_low, m_high = mass_range
        with_mass = exo_df[["pl_orbper", "pl_masse"]].dropna()
        in_range = (with_mass["pl_masse"] >= m_low) & (with_mass["pl_masse"] < m_high)
        periods = with_mass.loc[in_range, "pl_orbper"].to_numpy()
        label = f"{m_low} <= pl_masse < {m_high}"
    else:
        # Select a Series (not a one-column frame) so `periods` is 1-D like the branch above.
        periods = exo_df["pl_orbper"].dropna().to_numpy()
        label = "all planets"

    if periods.size == 0:
        # min/median/quantile are undefined on an empty selection.
        print(f"No orbital periods for {label}; nothing to plot.")
        return

    print(f"Periods stats: Min: {periods.min()}, max: {periods.max()}, median: {np.median(periods)}, 90% interval: {np.quantile(periods, q=(0.05, 0.95))}")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    ax1.hist(np.log10(periods), bins=100)
    ax2.hist(periods[periods < short_period_cutoff], bins=100)
    ax1.set(title="Distribution of orbital periods", xlabel="Orbital Period log_10(days)", xlim=(-1, 5))
    ax2.set(title="Distribution of orbital periods", xlabel="Orbital Period (days)")
    fig.suptitle(f"Orbital periods ({label})")
    fig.tight_layout()
    plt.show()

plot_orbital_period_distribution()
for mass_bin in [(0, 1), (1, 10), (10, 100), (100, 1000)]:
    plot_orbital_period_distribution(mass_range=mass_bin)

### Planet radius vs transit depth

In [None]:
from exo_finder.visualization.planet_stars_plotting import plot_transit_depth

# Planets with a radius, a measured transit depth and a host-star radius.
# (Previously named `period_depth`, although it holds no orbital period.)
transit_sample = exo_df[["pl_rade", "pl_trandep", "st_rad"]].dropna()
print(len(transit_sample))

# Select by column name rather than position so reordering the selection above cannot mix up arguments.
plot_transit_depth(
    pl_radius=transit_sample["pl_rade"],
    transit_depth=transit_sample["pl_trandep"],
    stellar_radius=transit_sample["st_rad"],
)

## Star statistics

In [None]:
# One row per host star. A host can appear on several planet rows, so its numeric
# properties are averaged; the spectral type is taken from the first row.
unique_stars = exo_df.groupby('hostname').agg({
    "st_mass": "mean",
    "st_lum": "mean",
    "st_age": "mean",
    "sy_dist": "mean",
    "st_rad": "mean",
    "st_teff": "mean",
    "sy_tmag": "mean",
    "sy_kepmag": "mean",
    "sy_vmag": "mean",
    "st_spectype": "first"
})

# Histogram every averaged column. Drop the spectral type by name rather than
# relying on it being the last column.
cols = unique_stars.columns.drop("st_spectype")
nrows = (1 + len(cols)) // 2  # ceiling division for a 2-column grid
# squeeze=False keeps `axes` a 2-D array even when the grid has a single row.
fig, axes = plt.subplots(nrows=nrows, ncols=2, figsize=(14, nrows * 4), squeeze=False)
axes = axes.flatten()

for ax, column in zip(axes, cols):
    # Zeros are excluded (presumably placeholders for unknown values — confirm); NaNs are dropped explicitly.
    values = unique_stars.loc[unique_stars[column] != 0, column].dropna()
    ax.hist(values, bins=50, color='skyblue', edgecolor='black')
    ax.set_title(column)
    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency')

# Hide grid cells left over when the number of columns is odd.
for ax in axes[len(cols):]:
    ax.set_visible(False)

fig.tight_layout()
plt.show()

## Planet statistics

In [None]:
planet_parameters = ["pl_rade", "pl_trandur", "pl_orbper", "pl_masse", "pl_orbsmax", "pl_orbincl", "pl_orbeccen", "pl_orblper", "pl_orbtper"]
# Histogram each planet parameter, trimming the largest 5% of values (by magnitude) as outliers.
cols = planet_parameters
ncols = 3
nrows = (ncols - 1 + len(cols)) // ncols  # ceiling division for any ncols
# squeeze=False keeps `axes` a 2-D array even for a single-row grid.
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, nrows * 4), squeeze=False)
axes = axes.flatten()

for ax, column in zip(axes, cols):
    values = exo_df.loc[exo_df[column] != 0, column].dropna()
    magnitudes = values.abs()
    # Compare magnitudes against the quantile of the magnitudes, not of the signed values.
    # `<=` keeps values tied at the cutoff instead of dropping every one of them.
    values = values[magnitudes <= magnitudes.quantile(q=0.95)]
    ax.hist(values, bins=100, color='skyblue', edgecolor='black')
    ax.set_title(column)
    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency')

# Hide grid cells left over when len(cols) is not a multiple of ncols.
for ax in axes[len(cols):]:
    ax.set_visible(False)

fig.tight_layout()
plt.show()

In [None]:
def study_transit_midpoints():
    """
    Studies the transit midpoints for systems with multiple planets
    """
    midpoints = (
        exo_df[["pl_tranmid", "disc_year", "hostname"]]
        .dropna()
        .groupby("hostname")
        .agg(
            count=("pl_tranmid", "count"),
            pl_tranmid_min=("pl_tranmid", "min"),
            pl_tranmid_max=("pl_tranmid", "max"),
            diff_pl_tranmid=("pl_tranmid", lambda x: x.max() - x.min()),
            disc_year=("disc_year", "first"))
         .sort_values(by="disc_year", ascending=False))
    
    midpoints = midpoints[midpoints["diff_pl_tranmid"].values.astype(float) > 0]
    #fit_data(midpoints["diff_pl_tranmid"], plot=True, upper_limit=500, include_zeros=False)
    return midpoints

study_transit_midpoints()