In [None]:
import copy
import os
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive, fixed, interact_manual, Layout
import ipywidgets as widgets
from gaslight.grid import Grid
from synthesizer.line import (
    get_diagram_labels,
    get_ratio_label,)
from synthesizer import line_ratios

In [None]:

# Directory holding the gaslight grid files. Set the GASLIGHT_GRID_DIR
# environment variable to run this notebook on another machine; when it is
# unset the original hard-coded location is used.
grid_dir = os.environ.get(
    'GASLIGHT_GRID_DIR',
    '/Users/sw376/Dropbox/Research/data/gaslight/grids',
)

# Alternative grids -- uncomment exactly one grid_name to switch.
# grid_name = 'bpass-2.2.1-bin_chabrier03-0.1,300.0-ages:6.,7.-metallicities:0.0001,0.001,0.01-c23.01-test'
# grid_name = 'bpass-2.2.1-bin_chabrier03-0.1,300.0-ages:6.,7.,8.-c23.01-reduced'
grid_name = 'bpass-2.2.1-bin_chabrier03-0.1,300.0-ages:6.,7.,8.-c23.01-full'
# grid_name = 'agnsed-limited-c23.01-agn-limited'
grid = Grid(grid_dir=grid_dir, grid_name=grid_name)

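## Diagram explorer

Choose a line-ratio diagram and fix the remaining grid axes with the sliders; the plotted curve traces that diagram across every metallicity in the grid.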

In [None]:


# Per-diagram axis windows as [xlim, ylim], in the log10 space that
# plot_diagram draws in. Diagrams missing from this mapping use
# plot_diagram's generic default window instead.
diagram_limits = dict([
    ('BPT-NII', [[-4., 1.], [-4., 1]]),
])


def plot_diagram(diagram_id=None, **kwargs):
    """Plot a line-ratio diagram traced across every grid metallicity.

    For each value in ``grid.metallicity``, the grid point nearest to
    ``kwargs`` plus that metallicity is looked up. Its line collection is
    asked for the two ratios of ``diagram_id``, and the resulting points are
    drawn as one curve in log10-log10 space.

    Args:
        diagram_id: name of the diagram (e.g. 'BPT-NII'); passed to
            ``lines.get_diagram`` and ``get_diagram_labels`` and used to look
            up axis limits in ``diagram_limits``.
        **kwargs: values for the remaining grid axes, keyed by axis name
            (as supplied by the slider widgets).
    """

    x = []
    y = []

    for metallicity in grid.metallicity:

        # The loop metallicity must take precedence over any 'metallicity'
        # passed in kwargs; otherwise every iteration would select the same
        # grid point and the curve would collapse.
        grid_value_dict = kwargs | {'metallicity': metallicity}

        grid_point = grid.get_nearest_grid_point(grid_value_dict)
        lines = grid.get_line_collection(grid_point)
        x_, y_ = lines.get_diagram(diagram_id)

        x.append(x_)
        y.append(y_)

    fig, ax = plt.subplots()
    ax.plot(np.log10(x), np.log10(y))

    # Diagram-specific window if one is configured, else a generic default.
    xlim, ylim = diagram_limits.get(diagram_id, ([-5., 1.5], [-3., 1.5]))
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    x_label, y_label = get_diagram_labels(diagram_id)

    # add axes labels
    ax.set_xlabel(rf'${x_label}$')
    ax.set_ylabel(rf'${y_label}$')

    # Record the fixed (non-metallicity) axis values so the figure stands
    # alone when skimmed; fall back to the diagram name if there are none.
    fixed_values = ', '.join(f'{name}={value}' for name, value in kwargs.items())
    ax.set_title(fixed_values or str(diagram_id))

    # show
    plt.show()

# Diagram selection widget: one entry per diagram in line_ratios.
diagram_id = widgets.Dropdown(
    options=line_ratios.available_diagrams,
    value='BPT-NII',
    description='diagram',
    disabled=False,
)

# Keys of this dict become plot_diagram's keyword arguments.
widget_dict = {'diagram_id': diagram_id}

# One slider per grid axis except metallicity, which plot_diagram sweeps
# over itself. Filtering with a comprehension avoids deep-copying grid.axes
# and does not raise if 'metallicity' is absent from it.
axes = [axis for axis in grid.axes if axis != 'metallicity']

for axis in axes:

    axis_values = grid.axes_values[axis]

    widget = widgets.SelectionSlider(
        options=axis_values,
        value=axis_values[0],
        description=axis,
        style={'description_width': '50%'},
        layout={'width': '700px'},
    )

    # add widget to dictionary
    widget_dict[axis] = widget

widget_list = list(widget_dict.values())

# define UI: the dropdown and sliders stacked vertically
ui = widgets.VBox(widget_list)

# Re-run plot_diagram whenever any control changes; its output is captured
# in `out`.
out = widgets.interactive_output(plot_diagram, widget_dict)

# display
display(ui, out)