In [None]:
import matplotlib.pyplot as pp
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import numpy as np
import pandas as pd
import analyse_methods as am
from importlib import reload
%matplotlib widget

In [None]:
# Lower-case chemical symbols of all 118 elements, in order of atomic number
# (hydrogen first, oganesson last), laid out ten per row.
elements = """
    h  he li be b  c  n  o  f  ne
    na mg al si p  s  cl ar k  ca
    sc ti v  cr mn fe co ni cu zn
    ga ge as se br kr rb sr y  zr
    nb mo tc ru rh pd ag cd in sn
    sb te i  xe cs ba la ce pr nd
    pm sm eu gd tb dy ho er tm yb
    lu hf ta w  re os ir pt au hg
    tl pb bi po at rn fr ra ac th
    pa u  np pu am cm bk cf es fm
    md no lr rf db sg bh hs mt ds
    rg cn nh fl mc lv ts og
""".split()

In [None]:
# Atoms to plot, grouped by periodic-table block label. plot_all_elements draws
# one subplot per key, in insertion order ("s", "d", "p").
# NOTE(review): "h" is not listed and "he" is grouped under "p" — confirm intended.
blocks_dict = dict(
    s="li be na mg k ca".split(),
    d="sc ti v cr mn fe co ni cu zn".split(),
    p=(
        "he b c n o f ne "
        "al si p s cl ar "
        "ga ge as se br kr"
    ).split(),
)



In [None]:
# Provenance: plot_all_elements in this cell is based on a ChatGPT suggestion.
# Re-execute analyse_methods (imported as `am`) so edits made to that module since
# it was first imported are used without restarting the kernel.
reload(am)

def _grid_shape(nblocks):
    """Return ``(nrows, ncols)`` for a near-square grid with room for ``nblocks`` panels.

    ``ncols`` is sqrt(nblocks) rounded half-up and ``nrows`` is the smallest row
    count that fits every panel, e.g. 3 blocks -> (2, 2) with one spare slot.

    Raises:
        ValueError: if ``nblocks`` is less than 1.
    """
    if nblocks < 1:
        raise ValueError("blocks_dict must contain at least one block")
    ncols = int(nblocks**0.5 + 0.5)         # round(sqrt(nblocks)); >= 1 for nblocks >= 1
    nrows = (nblocks + ncols - 1) // ncols  # ceil(nblocks / ncols)
    return nrows, ncols


def _plot_block(fig, ax, block, data_list, basis, cmap, ylim):
    """Draw one block's atoms on ``ax`` and return that subplot's pick metadata.

    For each atom in ``data_list`` this plots ``data.E - min(data.E)`` against
    ``data.GEM_BETA`` and a dashed vertical line at ``am.get_min(data)``
    (presumably the energy-minimising gamma — confirm in analyse_methods). Both
    lines use a colour taken from ``cmap`` by the atom's position in the list.
    Atoms for which ``am.get_evsgamma`` returns None are skipped, but they still
    appear as colorbar ticks. A colorbar labelled with the atom names is attached,
    and its tick labels are made pickable.

    Returns:
        dict with keys ``"ax"``, ``"line_map"`` (atom -> curve),
        ``"vline_map"`` (atom -> dashed marker), ``"active_element"``
        (``{"name": None}``) and ``"ticklabels"`` (colorbar tick Text objects).
    """
    ax.set_title(f"{block}-block {basis}")
    n = len(data_list)
    # Clamp so vmin <= vmax even for an empty block (n == 0).
    norm = mcolors.Normalize(vmin=0, vmax=max(n - 1, 0))

    line_map = {}
    vline_map = {}
    for i, atom in enumerate(data_list):
        data = am.get_evsgamma(atom, basis)
        if data is None:
            continue
        color = cmap(norm(i))
        line, = ax.plot(
            data.GEM_BETA, data.E - min(data.E),
            label=atom,
            color=color,
            lw=1.5,
        )
        line_map[atom] = line
        vline_map[atom] = ax.axvline(am.get_min(data), color=color, linestyle="--")

    ax.set_xlabel("gamma")
    ax.set_ylabel("energy")

    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    # Explicit fig.colorbar rather than pyplot's implicit "current figure".
    cbar = fig.colorbar(sm, ax=ax, ticks=range(n))
    cbar.set_label("Element")
    cbar.set_ticklabels(data_list)

    # The atom names on the colorbar act as click targets for highlighting.
    ticklabels = cbar.ax.get_yticklabels()
    for label in ticklabels:
        label.set_picker(True)

    ax.set_ylim(*ylim)

    return {
        "ax": ax,
        "line_map": line_map,
        "vline_map": vline_map,
        "active_element": {"name": None},
        "ticklabels": ticklabels,
    }


def plot_all_elements(basis, elements, blocks_dict, save=False,
                      ylim=(0, 0.1), cmap="viridis"):
    """Plot energy-vs-gamma curves for one basis, with one subplot per block.

    Clicking an atom's name on a subplot's colorbar highlights that atom: its
    curve and dashed gamma-minimum line stay opaque, and the other atoms in that
    subplot fade. Clicking the same name again restores every atom. Picking
    requires an interactive backend (e.g. ``%matplotlib widget``).

    Args:
        basis: basis-set name passed to ``am.get_evsgamma``; also used in the
            subplot titles and the saved file name.
        elements: not used by this function; kept for call compatibility (the
            atoms plotted come from ``blocks_dict``).
        blocks_dict: mapping of block label -> list of atom symbols. One
            subplot per key, in iteration order.
        save: if True, also write ``energy_vs_gamma_{basis}.svg`` (relative to
            the current working directory).
        ylim: y-axis limits applied to every subplot.
        cmap: colormap name or Colormap used to colour atoms by list position.

    Returns:
        The matplotlib Figure. Per-subplot pick metadata (see ``_plot_block``)
        is stored as a list in ``fig._axis_data``.

    Raises:
        ValueError: if ``blocks_dict`` is empty.
    """
    dimmed_alpha = 0.25  # opacity of non-selected atoms while one is highlighted

    nblocks = len(blocks_dict)
    nrows, ncols = _grid_shape(nblocks)
    fig, axes = pp.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    axes = axes.flatten()
    colormap = pp.get_cmap(cmap)

    fig._axis_data = [
        _plot_block(fig, ax, block, data_list, basis, colormap, ylim)
        for ax, (block, data_list) in zip(axes, blocks_dict.items())
    ]

    # The grid can have more slots than blocks; drop the empty ones.
    for ax in axes[nblocks:]:
        fig.delaxes(ax)

    def on_pick(event):
        # Only colorbar tick labels were made pickable, so only Text artists matter.
        if not isinstance(event.artist, pp.Text):
            return
        picked = event.artist.get_text()
        for info in fig._axis_data:
            if event.artist not in info["ticklabels"]:
                continue
            if picked not in info["line_map"]:
                break  # atom had no data and was never drawn; nothing to toggle
            active = info["active_element"]
            deselect = active["name"] == picked  # second click on the same atom
            for atom, line in info["line_map"].items():
                alpha = 1.0 if deselect or atom == picked else dimmed_alpha
                line.set_alpha(alpha)
                info["vline_map"][atom].set_alpha(alpha)
            for label in info["ticklabels"]:
                label.set_fontweight("normal")
            if not deselect:
                event.artist.set_fontweight("bold")
            active["name"] = None if deselect else picked
            fig.canvas.draw_idle()
            break  # a tick label belongs to exactly one colorbar

    fig.canvas.mpl_connect("pick_event", on_pick)

    if save:
        fig.savefig(f"energy_vs_gamma_{basis}.svg")

    return fig


In [None]:
# Basis sets to compare: one interactive figure per basis, one subplot per block.
bases = ["aug-cc-pvtz", "awcvtz"]

# Keep the returned figures so they can be inspected or saved later,
# e.g. figures["awcvtz"].savefig("awcvtz.svg").
figures = {}
for basis in bases:
    figures[basis] = plot_all_elements(basis, elements, blocks_dict)

In [None]:
# Sanity check: every symbol in blocks_dict must be a known element (catches typos).
unknown_atoms = sorted(
    {atom for atoms in blocks_dict.values() for atom in atoms} - set(elements)
)
assert not unknown_atoms, f"symbols not in `elements`: {unknown_atoms}"

# Total number of atoms plotted across all blocks. A descriptive name replaces
# the old loop counter so generic names like `i` don't leak into the kernel.
n_atoms = sum(len(atoms) for atoms in blocks_dict.values())
n_atoms