# Analysing the Spectral Energy Distribution and Visualising Emission and Absorption

<div class="alert alert-info">

**Note:** 

This notebook is only a sample demonstrating some of the features of the `SDECPlotter` class: it builds the SDEC plot by hand from TARDIS utility functions. To use additional features, work with the [SDECPlotter](https://github.com/tardis-sn/tardis/blob/master/tardis/visualization/tools/sdec_plot.py#L419) class directly. The remaining features of the `SDECPlotter` class are demonstrated [here](docs/analysing_tardis_outputs/visualization/how_to_sdec_plot.ipynb).
</div>

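This notebook analyses and visualises the spectral energy distribution of a supernova simulated with TARDIS, decomposing the spectrum into the emission and absorption contributions of each species (a Spectral element DEComposition, or SDEC, plot).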

In [None]:
import matplotlib.cm as cm
import matplotlib.colors as clr
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from astropy import units as u

from tardis.util.base import atomic_number2element_symbol
from tardis.visualization import plot_util as pu
from tardis.visualization.sdec.util import (
    calculate_absorption_luminosities,
    calculate_emission_luminosities,
)


Every simulation run requires [atomic data](io/configuration/components/atomic/atomic_data.rst) and a [configuration file](io/configuration/index.rst). 

## Atomic Data

We recommend using the [kurucz_cd23_chianti_H_He.h5](https://github.com/tardis-sn/tardis-regression-data/raw/main/atom_data/kurucz_cd23_chianti_H_He.h5) dataset.

In [None]:
from tardis.io.atom_data import download_atom_data

# We download the atomic data needed to run the simulation
download_atom_data('kurucz_cd23_chianti_H_He')

## Example Configuration File

In [None]:
!wget -q -nc https://raw.githubusercontent.com/tardis-sn/tardis/master/docs/tardis_example.yml

In [None]:
!cat tardis_example.yml

## Loading Data

### Running simulation

To run the simulation, import the `run_tardis` function and call it with the example configuration file to create the `simulation` object.

<div class="alert alert-info">

**Note:**

Get more information about the [progress bars](io/output/progress_bars.rst), [logging configuration](io/optional/tutorial_logging_configuration.ipynb), and [convergence plots](io/visualization/tutorial_convergence_plot.ipynb).

</div>


In [None]:
from tardis import run_tardis

simulation = run_tardis(
    "tardis_example.yml",
    virtual_packet_logging=True,
    show_convergence_plots=True,
    export_convergence_plots=True,
    log_level="INFO",
)

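### HDF

Alternatively, saved results can be loaded from an HDF file instead of rerunning the simulation. The cell below is left commented out; the rest of this notebook uses the `simulation` object created above.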

In [None]:
# hdf_path = 'simulation_result.hdf'
# sim = from_hdf(hdf_path, packets_mode='virtual')

## Data Processing

### Wavelength and Frequency Grid Setup

In [None]:
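# Frequency bin edges of the real-packet spectrum
plot_frequency_bins = (
    simulation.spectrum_solver.spectrum_real_packets._frequency
)
# Wavelength of each spectrum bin
plot_wavelength = simulation.spectrum_solver.spectrum_real_packets.wavelength
# Drop the last edge to keep one frequency per bin
plot_frequency = plot_frequency_bins[:-1]

# No wavelength range is restricted, so every bin is selected
packet_wvl_range_mask = np.ones(plot_wavelength.size, dtype=bool)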
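
The grid above stores frequency bin edges, so dropping the last edge leaves one frequency per bin. The sketch below illustrates the idea with plain NumPy, assuming each bin's wavelength follows from its frequency via $\lambda = c/\nu$; the bin edges are made up, the speed of light is hard-coded in Å/s instead of using astropy units, and the window mask is only a hypothetical example of what a restricted `packet_wvl_range` would select.

```python
import numpy as np

# Speed of light in angstrom per second (2.99792458e10 cm/s, 1 cm = 1e8 Å)
C_ANGSTROM_PER_S = 2.99792458e18

# Hypothetical frequency bin edges in Hz: 6 edges define 5 bins
frequency_bin_edges = np.linspace(3e14, 1e15, 6)

# Drop the last edge so there is one frequency per bin,
# mirroring plot_frequency = plot_frequency_bins[:-1]
frequency = frequency_bin_edges[:-1]

# Convert each frequency to a wavelength: lambda = c / nu
wavelength = C_ANGSTROM_PER_S / frequency

# Select every bin, like packet_wvl_range_mask when no range is given
wvl_range_mask = np.ones(wavelength.size, dtype=bool)

# Hypothetical alternative: keep only bins between 4000 and 8000 Å
window_mask = (wavelength >= 4000) & (wavelength <= 8000)

print(frequency.size, wavelength[wvl_range_mask].size)
print(np.round(wavelength[window_mask]))
```

Note that wavelength decreases as frequency increases, so the wavelength grid runs in the opposite direction to the frequency grid.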

### Luminosity Calculations

In [None]:
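# Per-species emission and absorption luminosities of the real packets,
# computed over the full wavelength range (packet_wvl_range=None)
(emission_luminosities_df, emission_species) = calculate_emission_luminosities(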
    simulation,
    "real",
    packet_wvl_range=None,
)
(
    absorption_luminosities_df,
    absorption_species,
) = calculate_absorption_luminosities(
    simulation,
    packets_mode="real",
    packet_wvl_range=None,
)

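# Combined contribution of each species; "noint" (no interaction) and
# "escatter" (electron scattering only) have no absorption counterpart
total_luminosities_df = (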
    absorption_luminosities_df
    + emission_luminosities_df.drop(["noint", "escatter"], axis=1)
)

# Species identifiers (atomic numbers) and their element symbols
species = np.array(list(total_luminosities_df.keys()))
species_name = [
    atomic_number2element_symbol(atomic_num)
    for atomic_num in species
]
species_length = len(species_name)

# Luminosity density of the real-packet spectrum in the selected bins
modeled_spectrum_luminosity = (
    simulation.spectrum_solver.spectrum_real_packets.luminosity_density_lambda[
        packet_wvl_range_mask
    ]
)
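
The addition above relies on pandas aligning DataFrames by column label: once `noint` and `escatter` are dropped, emission and absorption are summed per species column and wavelength bin. Below is a minimal sketch with made-up numbers, assuming species columns are keyed by atomic number as in the cell above; the `SYMBOLS` dictionary is a hypothetical stand-in for `atomic_number2element_symbol`.

```python
import numpy as np
import pandas as pd

# Hypothetical wavelength bins (Å) used as the DataFrame index
wavelength = pd.Index([4000.0, 5000.0, 6000.0], name="wavelength")

# Made-up emission luminosities: interaction-free and electron-scattering
# columns plus one column per species, keyed by atomic number
emission_df = pd.DataFrame(
    {
        "noint": [1.0, 2.0, 3.0],
        "escatter": [0.5, 0.5, 0.5],
        14: [0.2, 0.4, 0.1],
        26: [0.3, 0.1, 0.6],
    },
    index=wavelength,
)

# Made-up absorption luminosities for the same species
absorption_df = pd.DataFrame(
    {14: [0.1, 0.3, 0.2], 26: [0.4, 0.2, 0.1]},
    index=wavelength,
)

# Drop the columns with no absorption counterpart, then add by column label
total_df = absorption_df + emission_df.drop(["noint", "escatter"], axis=1)

# Hypothetical stand-in for atomic_number2element_symbol
SYMBOLS = {14: "Si", 26: "Fe"}
species = np.array(list(total_df.keys()))
species_name = [SYMBOLS[atomic_num] for atomic_num in species]

print(total_df)
print(species_name)
```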

## Visualisation

### Matplotlib
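
The Matplotlib cell below gives each species one color, builds a `ListedColormap` from those colors and normalises it over `[0, species_length]`, so each species occupies a unit-wide band of the colorbar with its tick at the band centre (`k + 0.5`). A minimal check of that mapping, assuming made-up named colors in place of the samples the notebook draws from `jet`:

```python
import matplotlib.colors as clr

# Hypothetical species and colors (the notebook samples colors from "jet")
species_name = ["Si", "S", "Ca", "Fe"]
color_values = ["tab:blue", "tab:orange", "tab:green", "tab:red"]
species_length = len(species_name)

# One color per species; values in [0, species_length] map onto [0, 1]
custcmap = clr.ListedColormap(color_values)
norm = clr.Normalize(vmin=0, vmax=species_length)

# Tick positions at the centre of each unit-wide band, as in the notebook
bounds = [k + 0.5 for k in range(species_length)]

# Color that the colorbar shows at each tick
tick_colors = [custcmap(norm(b)) for b in bounds]

for name, color in zip(species_name, tick_colors):
    print(name, clr.to_hex(color))
```

Because the normalised tick value `(k + 0.5) / n` falls strictly inside the k-th of the colormap's `n` equal slots, each label lines up with its own species' color.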

In [None]:
# Create figure and axis
ax = plt.figure(figsize=(12, 7)).add_subplot(111)

# Sample one color per species from the "jet" colormap
cmap = plt.get_cmap("jet", species_length)
color_list = [
    cmap(species_counter / species_length)
    for species_counter in range(species_length)
]
color_values = list(color_list)


# Create a custom colormap for species
custcmap = clr.ListedColormap(color_values)
norm = clr.Normalize(vmin=0, vmax=species_length)
mappable = cm.ScalarMappable(norm=norm, cmap=custcmap)
mappable.set_array(np.linspace(0, species_length, 256))


# Add colorbar for species representation
cbar = plt.colorbar(mappable, ax=ax)
bounds = np.arange(species_length) + 0.5
cbar.set_ticks(bounds)
cbar.set_ticklabels(species_name)


# Plot emission contributions from different interactions
lower_level = np.zeros(emission_luminosities_df.shape[0])
upper_level = lower_level + emission_luminosities_df.noint.to_numpy()
ax.fill_between(
    plot_wavelength.value,
    lower_level,
    upper_level,
    color="#4C4C4C",
    label="No interaction",
)
lower_level = upper_level
upper_level = lower_level + emission_luminosities_df.escatter.to_numpy()
ax.fill_between(
    plot_wavelength.value,
    lower_level,
    upper_level,
    color="#8F8F8F",
    label="Electron Scatter Only",
)


# Plot emission contributions for each species
for species_counter, identifier in enumerate(species):
    lower_level = upper_level
    upper_level = lower_level + emission_luminosities_df[identifier].to_numpy()

    ax.fill_between(
        plot_wavelength.value,
        lower_level,
        upper_level,
        color=color_list[species_counter],
        linewidth=0,
    )


# Plot absorption contributions for each species
lower_level = np.zeros(absorption_luminosities_df.shape[0])
for species_counter, identifier in enumerate(species):
    upper_level = lower_level
    lower_level = (
        upper_level - absorption_luminosities_df[identifier].to_numpy()
    )

    ax.fill_between(
        plot_wavelength.value,
        upper_level,
        lower_level,
        color=color_list[species_counter],
        linewidth=0,
    )


# Plot the modeled spectrum
ax.plot(
    plot_wavelength.value,
    modeled_spectrum_luminosity.value,
    "--b",
    label="Real Spectrum",
    linewidth=1,
)

# Add labels, legend, and formatting
ax.legend(fontsize=12)
ax.set_xlabel(r"Wavelength $[\mathrm{\AA}]$", fontsize=12)
ax.set_ylabel(
    r"$L_{\lambda}$ [erg $\mathrm{s^{-1}}$ $\mathrm{\AA^{-1}}$]",
    fontsize=12,
)
plt.show()
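
The two `fill_between` loops in the cell above keep a running `lower_level`/`upper_level` pair: emission components are stacked upwards from zero and absorption components downwards, so every band boundary is a partial cumulative sum. The sketch below, with made-up luminosities for three wavelength bins, checks that `np.cumsum` reproduces the boundaries of such a loop; it is an equivalent reformulation for illustration, not what the notebook runs.

```python
import numpy as np

# Made-up contributions: rows are components, columns are wavelength bins
# (e.g. no interaction, electron scattering, then one row per species)
emission = np.array(
    [
        [1.0, 2.0, 3.0],
        [0.5, 0.5, 0.5],
        [0.2, 0.4, 0.1],
    ]
)
absorption = np.array(
    [
        [0.1, 0.3, 0.2],
        [0.4, 0.2, 0.1],
    ]
)

# Loop version: each emission band starts where the previous one ended
emission_edges_loop = []
upper_level = np.zeros(emission.shape[1])
for component in emission:
    lower_level = upper_level
    upper_level = lower_level + component
    emission_edges_loop.append(upper_level)

# Loop version: absorption bands grow downwards from zero
absorption_edges_loop = []
lower_level = np.zeros(absorption.shape[1])
for component in absorption:
    upper_level = lower_level
    lower_level = upper_level - component
    absorption_edges_loop.append(lower_level)

# Vectorised version: running totals along the component axis
emission_edges = np.cumsum(emission, axis=0)
absorption_edges = -np.cumsum(absorption, axis=0)

print(np.allclose(emission_edges, emission_edges_loop))
print(np.allclose(absorption_edges, absorption_edges_loop))
```

Prepending a row of zeros to `emission_edges` and passing consecutive rows as the `y1`/`y2` arguments of `fill_between` would draw the same emission bands.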

### Plotly
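
The Plotly cell below builds each species' hover label by concatenating an f-string, which fills in the species name, with a plain string holding Plotly's `%{x}`/`%{y}` placeholders. Inside an f-string, `{x:.2f}` would be parsed as a Python replacement field, so the placeholders must stay outside it or have their braces doubled; the empty `<extra></extra>` tag hides Plotly's secondary trace-name box. A short sketch of both spellings, where `hover_template` and `hover_template_escaped` are hypothetical helpers, not part of TARDIS:

```python
def hover_template(name_of_spec: str, kind: str) -> str:
    """Build a Plotly hovertemplate for one species trace (hypothetical helper)."""
    # f-string part: Python substitutes the species name and trace kind
    header = f"<b>{name_of_spec:s} {kind}</b><br>"
    # Plain string part: left untouched so Plotly can fill in %{x} and %{y}
    body = "(%{x:.2f}, %{y:.3g})<extra></extra>"
    return header + body


def hover_template_escaped(name_of_spec: str, kind: str) -> str:
    """Same template as a single f-string; doubled braces give literal { }."""
    return (
        f"<b>{name_of_spec} {kind}</b><br>"
        f"(%{{x:.2f}}, %{{y:.3g}})<extra></extra>"
    )


print(hover_template("Fe", "Emission"))
print(hover_template_escaped("Fe", "Absorption"))
```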

In [None]:
# Create figure
fig = go.Figure()

# Traces that share a stackgroup are summed in order by Plotly,
# which produces a stacked area chart
fig.add_trace(
    go.Scatter(
        x=emission_luminosities_df.index,
        y=emission_luminosities_df.noint,
        mode="none",
        name="No interaction",
        fillcolor="#4C4C4C",
        stackgroup="emission",
        hovertemplate="(%{x:.2f}, %{y:.3g})",
    )
)

fig.add_trace(
    go.Scatter(
        x=emission_luminosities_df.index,
        y=emission_luminosities_df.escatter,
        mode="none",
        name="Electron Scatter Only",
        fillcolor="#8F8F8F",
        stackgroup="emission",
        hoverlabel={"namelength": -1},
        hovertemplate="(%{x:.2f}, %{y:.3g})",
    )
)


# Contribution from each element
for species_counter, (identifier, name_of_spec) in enumerate(
    zip(species, species_name)
):
    fig.add_trace(
        go.Scatter(
            x=emission_luminosities_df.index,
            y=emission_luminosities_df[identifier],
            mode="none",
            name=name_of_spec + " Emission",
            hovertemplate=f"<b>{name_of_spec:s} Emission</b><br>"  # noqa: ISC003
            + "(%{x:.2f}, %{y:.3g})<extra></extra>",
            fillcolor=pu.to_rgb255_string(color_list[species_counter]),
            stackgroup="emission",
            showlegend=False,
            hoverlabel={"namelength": -1},
        )
    )
    # Plot absorption part
    fig.add_trace(
        go.Scatter(
            x=absorption_luminosities_df.index,
            # to plot absorption luminosities along negative y-axis
            y=absorption_luminosities_df[identifier] * -1,
            mode="none",
            name=name_of_spec + " Absorption",
            hovertemplate=f"<b>{name_of_spec:s} Absorption</b><br>"  # noqa: ISC003
            + "(%{x:.2f}, %{y:.3g})<extra></extra>",
            fillcolor=pu.to_rgb255_string(color_list[species_counter]),
            stackgroup="absorption",
            showlegend=False,
            hoverlabel={"namelength": -1},
        )
    )

# Plot modeled spectrum
fig.add_trace(
    go.Scatter(
        x=plot_wavelength.value,
        y=modeled_spectrum_luminosity.value,
        mode="lines",
        line={
            "color": "blue",
            "width": 1,
        },
        name="Real Spectrum",
        hovertemplate="(%{x:.2f}, %{y:.3g})",
        hoverlabel={"namelength": -1},
    )
)

# Split the [0, 1] range into one bin per element (n + 1 bin edges)
colorscale_bins = np.linspace(0, 1, num=len(species_name) + 1)

# Build a categorical colorscale, i.e. a list of (reference point, color)
# pairs, by repeating each interior bin edge twice in a row with different
# colors (https://plotly.com/python/colorscales/#constructing-a-discrete-or-discontinuous-color-scale)
categorical_colorscale = []
for species_counter in range(len(species_name)):
    color = pu.to_rgb255_string(cmap(colorscale_bins[species_counter]))
    categorical_colorscale.append((colorscale_bins[species_counter], color))
    categorical_colorscale.append((colorscale_bins[species_counter + 1], color))


coloraxis_options = {
    "colorscale": categorical_colorscale,
    "showscale": True,
    "cmin": 0,
    "cmax": len(species_name),
    "colorbar": {
        "title": "Elements",
        "tickvals": np.arange(0, len(species_name)) + 0.5,
        "ticktext": species_name,
        # to change length and position of colorbar
        "len": 0.75,
        "yanchor": "top",
        "y": 0.75,
    },
}

# Add an invisible single-point scatter trace so that the colorbar is shown
scatter_point_idx = pu.get_mid_point_idx(plot_wavelength)
fig.add_trace(
    go.Scatter(
        x=[plot_wavelength[scatter_point_idx].value],
        y=[0],
        mode="markers",
        name="Colorbar",
        showlegend=False,
        hoverinfo="skip",
        marker=dict(color=[0], opacity=0, **coloraxis_options),
    )
)

# Set axis labels and other layout options
xlabel = pu.axis_label_in_latex("Wavelength", u.AA)
ylabel = pu.axis_label_in_latex(
    "L_{\\lambda}", u.Unit("erg/(s AA)"), only_text=False
)
fig.update_layout(
    xaxis={
        "title": xlabel,
        "exponentformat": "none",
    },
    yaxis={"title": ylabel, "exponentformat": "e"},
    height=600,
)
fig.show()
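
Plotly interpolates between colorscale stops, so the colorscale loop above repeats each interior bin edge with two different colors to get sharp steps instead of gradients. The same construction as a standalone function, assuming placeholder color strings instead of `pu.to_rgb255_string(...)` output and plain Python instead of `np.linspace`; `categorical_colorscale` here is a hypothetical helper:

```python
def categorical_colorscale(colors):
    """Return a step-wise Plotly colorscale with one flat band per color."""
    n = len(colors)
    # n + 1 evenly spaced bin edges over [0, 1]
    bin_edges = [i / n for i in range(n + 1)]
    scale = []
    for i, color in enumerate(colors):
        # Same color at both edges of its bin, so interior edges appear
        # twice: once closing one bin and once opening the next
        scale.append((bin_edges[i], color))
        scale.append((bin_edges[i + 1], color))
    return scale


colors = ["rgb(0, 0, 128)", "rgb(0, 255, 0)", "rgb(128, 0, 0)"]
for stop in categorical_colorscale(colors):
    print(stop)
```

With `cmin=0` and `cmax` equal to the number of elements, as in `coloraxis_options`, ticks at `k + 0.5` then fall in the middle of each flat band.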