# In [None]:
import numpy as np
np.set_printoptions(precision=3, suppress=True)
import math 
import re
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.colors import LinearSegmentedColormap
from parser import read_band_energies_and_klist_from_PROCAR, get_tot_index_from_procar,orbvis_orbital_specific_band_data_from_PROCAR
from utils import angle_between,clean_kpoints,dist_bw_two_kpoints,compute_kpoint_distances,merge_close_ticks,orbital_labels, base_colors

# In [None]:
class orbvisband:
    """Orbital-projected band-structure scatter plots built from a PROCAR file.

    On construction the band energies and k-point list are read with
    ``read_band_energies_and_klist_from_PROCAR``, high-symmetry points are
    detected with ``clean_kpoints`` and the user is prompted (via ``input``)
    for their labels -- see :meth:`set_labels`. Orbital projections are then
    chosen with :meth:`set_orbital_data` or :meth:`set_orbital_data_with_help`
    and drawn with :meth:`orbscatter`.

    Parameters
    ----------
    path : str
        Location passed unchanged to the PROCAR parsing helpers.
    ispin : int
        1 for a single spin channel; 2 for two channels, which
        :meth:`orbscatter` draws side by side.

    Raises
    ------
    ValueError
        If ``ispin`` is not 1 or 2.
    """

    def __init__(self, path, ispin):
        if ispin not in (1, 2):
            raise ValueError(f"ispin must be 1 or 2, got {ispin!r}")
        self.path = path
        self.ispin = ispin
        self.x_scale = 3  # forwarded to compute_kpoint_distances
        self.E_fermi = None  # stored by set_fermi_energy(); not used by orbscatter yet

        self.bs_data, self.kpoints = read_band_energies_and_klist_from_PROCAR(self.path, self.ispin)
        # `kl` is the name the k-point helpers and set_labels() use; it was
        # previously referenced without ever being assigned.
        self.kl = self.kpoints

        # Index of the 'tot' orbital column; used to validate orbital numbers.
        self.tot_ind = get_tot_index_from_procar(self.path)
        self.data = None  # orbital projection spec, see set_orbital_data()

        # Initial k-point processing.
        self.kl_new, self.hs = clean_kpoints(self.kl)
        _, self.reduced_data = compute_kpoint_distances(self.kl_new, self.x_scale)

        # x positions of the high-symmetry points; build the lookup dict once
        # instead of once per point.
        distance_of = dict(self.reduced_data)
        self.tick_vals = [distance_of[i] for i in self.hs]

        # Default labels K0, K1, ...; set_labels() may replace them and always
        # (re)computes merge_tick_vals / merge_tick_labels.
        self.num_labels = len(self.hs)
        self.labels = [f'K{i}' for i in range(self.num_labels)]
        self.tick_labels = self.labels
        self.set_labels()

        # Band energies and x coordinates at the k-points kept by
        # compute_kpoint_distances (column 0 = k index, column 1 = distance).
        self.idx_selected = self.reduced_data[:, 0].astype(int)
        if self.ispin == 1:
            self.bs_selected = self.bs_data[:, self.idx_selected]
        else:
            self.bs_selected = self.bs_data[:, :, self.idx_selected]
        self.x_arr = self.reduced_data[:, 1]

    # ------------------------------------------------------------------ labels

    @staticmethod
    def _decode_labels(label_input):
        """Split ``label_input`` on whitespace and decode ``\\uXXXX`` escapes.

        Characters typed directly (e.g. ``Γ``) are first turned back into
        escapes, so they survive the ``unicode_escape`` round-trip instead of
        being mangled as UTF-8 bytes read as Latin-1.

        Raises
        ------
        UnicodeDecodeError
            If a token contains a malformed escape such as ``\\u12``.
        """
        return [
            token.encode("latin-1", "backslashreplace").decode("unicode_escape")
            for token in label_input.split()
        ]

    def _prompt_labels(self):
        """Interactively read ``self.num_labels`` labels into ``self.labels``.

        Entering ``0`` or an empty line keeps the current labels.
        """
        print("The following high symmetry points were found:\n")
        for hs_point in self.hs:
            print(self.kl[hs_point][1:4])
        # Unicode help message to aid the user in entering data
        unicode_help = r"""
        Enter the high-symmetry point labels separated by spaces 
        For example, enter:\u0393 X M \u0393 for Γ X M Γ
        Escape codes are decoded into the corresponding characters. Common codes:
        - \u0393 → Γ (Gamma)
        - \u0394 → Δ (Delta)
        - \u03a3 → Σ (Sigma)
        - \u039B → Λ (Lambda)
        You can also combine like: \u039B1 → Λ1

        Press Enter without typing anything to use default labels: K0, K1, K2, ...
        Or type '0' to quit — default labels will still be used.
        """
        print(unicode_help)

        while True:
            label_input = input(f"Enter {self.num_labels} high-symmetry point labels (or 0 to quit): ").strip()
            if label_input == "0":
                print("User chose to quit. Using default labels:", self.labels)
                return
            if label_input == "":
                print("No input provided. Using default labels:", self.labels)
                return
            try:
                user_labels = self._decode_labels(label_input)
            except UnicodeDecodeError as err:
                print(f"Could not decode an escape sequence ({err}). Please try again.\n")
                continue
            # Validate the labels the user just typed (previously the
            # default list was checked, so any count was accepted).
            if len(user_labels) != self.num_labels:
                print(f"Please enter exactly {self.num_labels} labels. You entered {len(user_labels)}.\n")
                continue
            self.labels = user_labels
            return

    def set_labels(self, labels=None):
        """Set the high-symmetry point labels and refresh the merged x-ticks.

        Parameters
        ----------
        labels : sequence of str, optional
            Labels to use as-is (no escape decoding). When omitted the user
            is prompted interactively, as before.

        Raises
        ------
        ValueError
            If ``labels`` is given with a length other than ``self.num_labels``.
        """
        if labels is None:
            self._prompt_labels()
        else:
            labels = list(labels)
            if len(labels) != self.num_labels:
                raise ValueError(
                    f"Expected {self.num_labels} labels, got {len(labels)}."
                )
            self.labels = labels

        self.tick_labels = self.labels

        # Final label print
        print("High-symmetry point labels set to:")
        print(self.labels)
        print("\n")

        # Label merging must follow every label change. tick_vals may not
        # exist yet when the instance was created without __init__.
        if getattr(self, "tick_vals", None) is not None:
            self.merge_tick_vals, self.merge_tick_labels = merge_close_ticks(
                self.tick_vals, self.tick_labels, tol=1e-5
            )

    # ---------------------------------------------------------- orbital input

    def set_orbital_data_with_help(self):
        """Print usage help and read the orbital projection spec from ``input``.

        The input is parsed with :func:`ast.literal_eval`, which accepts only
        Python literals (lists, tuples, numbers, strings), never arbitrary
        expressions. On empty input or a parse error ``self.data`` is left
        unchanged.
        """
        import ast

        input_data_help = r"""----Orbital Projection Data Input----
        Enter the entire data array at once.
        Each Item should be a list or tuple like
        
        [atom_numbers (list of ints), element_name (str), orbital_numbers (list of ints)]
        
        atom_numbers start from 0 and follow the order in POSCAR
        orbital_numbers such as 0 for $s$, 1 for $p_y$. For more info on orbital numbers,
        check first few lines of PROCAR or go through tutorials/documentation.
        
        You can also enter a combination of orbitals for a combination of different atoms.
        
        Examples: [([0, 1],'N', [0, 1]), ([2],'O', [2, 3])]
                               or
                [ [ [0, 1],'N', [0, 1] ], [ [2],'O', [2, 3] ] ]
        
        means, you want $s$ and $p_y$ orbitals of the 0(actually first because 0 indexed) and 1(second) atoms in POSCAR, which are nitrogen atoms
        and also $p_z$ and $p_x$ orbital contributions of 2nd(actually 3rd because zero indexed) atom in POSCAR which is oxygen.
        
        for getting total contribution of an atom(s), check the tot_ind attribute of this object
        for getting contribution of a particular element, enter all the atom indices from POSCAR in zero indexed(starting from 0) form from POSCAR

        Press Enter without typing anything to cancel.

        """
        print(input_data_help)
        user_input = input("Enter full projection data: ").strip()

        if user_input == "":
            print("No input provided. Orbital data left unchanged.")
            return

        try:
            # literal_eval instead of eval: eval with stripped builtins is
            # still escapable and must not be run on typed-in text.
            self.data = ast.literal_eval(user_input)
            print("Data parsed successfully")
        except (ValueError, SyntaxError, TypeError) as e:
            print(f"Error parsing input: {e}")

    def set_orbital_data(self, data):
        """Store the projection spec: items of ``[atom_list, element_name, orbital_list]``."""
        self.data = data

    def set_fermi_energy(self, e_f):
        """Store the Fermi energy in ``self.E_fermi`` (not yet used for plotting)."""
        self.E_fermi = e_f

    # ---------------------------------------------------------------- plotting

    @staticmethod
    def _validate_orbital_entry(orbital_list, tot_ind):
        """Raise ValueError if an orbital index exceeds ``tot_ind`` or 'tot' is combined."""
        for orb in orbital_list:
            if orb > tot_ind:
                raise ValueError(
                    f"Please check orbital numbers you entered. It should be less than or equal to {tot_ind}. "
                    f"Also, if you want 'tot', enter index = {tot_ind} and don't club it with other orbitals."
                )
        if len(orbital_list) > 1 and tot_ind in orbital_list:
            raise ValueError(
                f"Please check orbital numbers you entered. If you want 'tot', enter index = {tot_ind} "
                "and don't club it with other orbitals."
            )

    @staticmethod
    def _entry_label(element_name, atom_list, orbital_list, tot_ind):
        """Legend label: element name followed by the orbital labels joined with '$+$'."""
        if len(atom_list) == 0 or len(orbital_list) == 0:
            return element_name
        if orbital_list[0] == tot_ind:
            return element_name + r"$tot$"
        return element_name + r"$+$".join(orbital_labels[int(orb)] for orb in orbital_list)

    def _collect_projections(self):
        """Validate ``self.data`` and sum the PROCAR weights of every entry.

        Returns
        -------
        (list of str, list of ndarray)
            Legend labels and summed weights (same shape as ``self.bs_data``).
        """
        labels, projections = [], []
        for atom_list, element_name, orbital_list in self.data:
            orbital_list = list(orbital_list)
            self._validate_orbital_entry(orbital_list, self.tot_ind)
            summed = np.zeros_like(self.bs_data, dtype=float)
            for atom in atom_list:
                for orbital in orbital_list:
                    summed += orbvis_orbital_specific_band_data_from_PROCAR(self.path, atom, orbital, self.ispin)
            labels.append(self._entry_label(element_name, atom_list, orbital_list, self.tot_ind))
            projections.append(summed)
        return labels, projections

    @staticmethod
    def _entry_colors(n, colscheme=0):
        """Sample ``n`` evenly spaced colours from palette ``base_colors[colscheme]``."""
        cmap = LinearSegmentedColormap.from_list("custom_gradient", base_colors[colscheme])
        if n == 1:
            # Avoid dividing by zero when only one entry is plotted.
            return [cmap(0.0)]
        return [cmap(i / (n - 1)) for i in range(n)]

    def orbscatter(self, title="orb scatter", colscheme=0, ylim=(-6, 6), linewidth=0.5, dpi=300,
                   scale=1, transparency=70, save_as="orb_scatter.jpg"):
        """Draw the band structure with orbital-weighted markers and save it.

        Parameters
        ----------
        title : str
            Figure title.
        colscheme : int
            Index of the palette in ``base_colors`` used for the entries.
        ylim : (float, float)
            Energy window of the y-axis.
        linewidth : float
            Width of the black band lines.
        dpi : int
            Resolution of the saved image.
        scale : float
            Marker area is ``scale`` times the summed projection weight.
        transparency : int
            Accepted for compatibility but currently unused (markers use
            alpha=0.6).
        save_as : str
            Output file name passed to ``Figure.savefig``.

        Raises
        ------
        ValueError
            If no orbital data has been set, or an orbital index is invalid.
        """
        if self.data is None:
            raise ValueError(
                "No orbital data set. Call set_orbital_data() or set_orbital_data_with_help() first."
            )
        all_labels, all_procar_data = self._collect_projections()
        colors = self._entry_colors(len(all_labels), colscheme)

        if self.ispin == 1:
            fig, ax = plt.subplots()
            axes = [ax]
            band_sets = [self.bs_selected]
            weight_sets = [[data[:, self.idx_selected] for data in all_procar_data]]
        else:
            # One panel per spin channel (index 0 and 1 of the leading axis).
            fig, ax = plt.subplots(1, 2, figsize=(8, 3))
            axes = list(ax)
            band_sets = [self.bs_selected[0], self.bs_selected[1]]
            weight_sets = [
                [data[spin][:, self.idx_selected] for data in all_procar_data]
                for spin in (0, 1)
            ]

        for axis, bands, weights in zip(axes, band_sets, weight_sets):
            for band in bands:
                axis.plot(self.x_arr, band, linewidth=linewidth, color="black")
            for entry_weights, color in zip(weights, colors):
                for band, band_weight in zip(bands, entry_weights):
                    axis.scatter(self.x_arr, band, s=scale * band_weight, alpha=0.6, color=color)
            axis.set_xticks(self.merge_tick_vals)
            axis.set_xticklabels(self.merge_tick_labels)
            axis.set_xlim(self.x_arr.min(), self.x_arr.max())
            axis.set_ylim(ylim[0], ylim[1])

        # Proxy artists so each entry appears once in the legend.
        handles = [
            Line2D([0], [0], color=color, marker='o', linestyle="", markersize=5, label=label)
            for label, color in zip(all_labels, colors)
        ]
        axes[-1].legend(handles=handles, loc='lower right', framealpha=0.3)
        fig.suptitle(title)
        fig.savefig(save_as, dpi=dpi)

        

# In [None]:
# Matplotlib mathtext label for each orbital index (0 = s, 1 = p_y, ...),
# used to build legend entries.
orbital_labels = {
    # s
    0: r"$s$",
    # p
    1: r"$p_y$",
    2: r"$p_z$",
    3: r"$p_x$",
    # d
    4: r"$d_{xy}$",
    5: r"$d_{yz}$",
    6: r"$d_{z^2}$",
    7: r"$d_{xz}$",
    8: r"$d_{x^2 - y^2}$",
    # f
    9: r"$f_{y(3x^2 - y^2)}$",
    10: r"$f_{xyz}$",
    11: r"$f_{yz^2}$",
    12: r"$f_{z^3}$",
    13: r"$f_{xz^2}$",
    14: r"$f_{z(x^2 - y^2)}$",
    15: r"$f_{x(x^2 - 3y^2)}$",
}

# Colour palettes, selected by index; each is turned into a gradient colormap
# from which one colour per projection entry is sampled.
base_colors = [
    ["r", "g", "b", "orange", "yellow", "lightblue"],
    ["#8ecae6", "#219ebc", "#023047", "#ffb703", "#fb8500"],
    ["#264653", "#2a9d8f", "#e9c46a", "#f4a261", "#e76f51"],
    ["#ff595e", "#ffca3a", "#8ac926", "#1982c4", "#6a4c93"],
    ["#2b2d42", "#8d99ae", "#edf2f4", "#ef233c", "#d90429"],
    ["#0b3954", "#087e8b", "#bfd7ea", "#ff5a5f", "#c81d25"],
    ["#220901", "#621708", "#941b0c", "#bc3908", "#f6aa1c"],
    ["#355070", "#6d597a", "#b56576", "#e56b6f", "#eaac8b"],
    ["#003049", "#d62828", "#f77f00", "#fcbf49", "#eae2b7"],
    ["#eae2b7", "#fe7f2d", "#fcca46", "#a1c181", "#619b8a"],
]


# In [11]:
def _prompt_hs_labels(num_labels):
    """Interactively read ``num_labels`` high-symmetry labels.

    Returns the default labels K0, K1, ... when the user enters ``0`` or an
    empty line. ``\\uXXXX`` escapes are decoded; characters typed directly
    (e.g. ``Γ``) are preserved.
    """
    # Unicode help message
    unicode_help = r"""
    Enter the high-symmetry point labels separated by spaces (e.g., \u0393 X M \u0393)
    Escape codes are decoded into the corresponding characters. Common codes:
    - \u0393 → Γ Gamma
    - \u0394 → Δ Delta
    - \u03a3 → Σ Sigma
    - \u039B → Λ Lambda
    You can also combine like: \u039B1 → Λ1

    Press Enter without typing anything to use default labels: K0, K1, K2, ...
    Or type '0' to quit — default labels will still be used.
    """
    print(unicode_help)

    defaults = [f'K{i}' for i in range(num_labels)]
    while True:
        label_input = input(f"Enter {num_labels} high-symmetry point labels (or 0 to quit): ").strip()

        if label_input == "0":
            print("User chose to quit. Using default labels:", defaults)
            labels = defaults
            break
        if label_input == "":
            print("No input provided. Using default labels:", defaults)
            labels = defaults
            break

        try:
            # Round-trip through backslash escapes so that directly typed
            # non-ASCII characters are not mangled by 'unicode_escape'.
            labels = [
                token.encode("latin-1", "backslashreplace").decode("unicode_escape")
                for token in label_input.split()
            ]
        except UnicodeDecodeError as err:
            print(f"Could not decode an escape sequence ({err}). Please try again.\n")
            continue
        if len(labels) != num_labels:
            print(f"Please enter exactly {num_labels} labels. You entered {len(labels)}.\n")
        else:
            break

    # Final label print
    print("High-symmetry point labels:")
    print(labels)
    print("\n")
    return labels


def _check_orbital_list(orbital_list, tot_ind):
    """Raise ValueError if an orbital index exceeds ``tot_ind`` or 'tot' is combined."""
    for orb in orbital_list:
        if orb > tot_ind:
            raise ValueError(
                f"Please check orbital numbers you entered. It should be less than or equal to {tot_ind}. "
                f"Also, if you want 'tot', enter index = {tot_ind} and don't club it with other orbitals."
            )
    if len(orbital_list) > 1 and tot_ind in orbital_list:
        raise ValueError(
            f"Please check orbital numbers you entered. If you want 'tot', enter index = {tot_ind} "
            "and don't club it with other orbitals."
        )


def _orbital_entry_label(element_name, atom_list, orbital_list, tot_ind):
    """Legend label: element name followed by the orbital labels joined with '$+$'."""
    if len(atom_list) == 0 or len(orbital_list) == 0:
        return element_name
    if orbital_list[0] == tot_ind:
        return element_name + r"$tot$"
    return element_name + r"$+$".join(orbital_labels[int(orb)] for orb in orbital_list)


def plot_orb_scatter(path, data, ispin=1, scale=20, transparency=70, colmap_id=0, ylim=(-6, 3)):
    """Interactively plot an orbital-projected band structure from a PROCAR file.

    Reads band energies, k-points and the 'tot' column index from ``path``,
    prompts for high-symmetry point labels, sums the requested projections
    and shows a scatter plot whose marker area is ``scale`` times the summed
    weight.

    Parameters
    ----------
    path : str
        Location passed unchanged to the PROCAR parsing helpers.
    data : list of lists or tuples
        Each item is ``[atom_numbers (list of ints), element_name (str),
        orbital_numbers (list of ints)]``.
    ispin : int, optional
        1 for one panel, 2 for two spin channels side by side (default 1).
    scale : float, optional
        Marker-size multiplier (default 20).
    transparency : int, optional
        Accepted for compatibility but currently unused (markers use
        alpha=0.6).
    colmap_id : int, optional
        Index into ``base_colors``, from 0 to ``len(base_colors) - 1``
        (default 0).
    ylim : (float, float), optional
        Energy window of the y-axis (default ``(-6, 3)``, the previously
        hard-coded range).

    Raises
    ------
    ValueError
        If ``ispin`` is not 1 or 2, or an orbital index is invalid.
    """
    if ispin not in (1, 2):
        raise ValueError(f"ispin must be 1 or 2, got {ispin!r}")

    custom_cmap = LinearSegmentedColormap.from_list("custom_gradient", base_colors[colmap_id])

    bs, kl = read_band_energies_and_klist_from_PROCAR(path, ispin)
    tot_ind = get_tot_index_from_procar(path)
    kl_new, hs = clean_kpoints(kl)

    print("The following high symmetry points were found:\n")
    for hs_point in hs:
        print(kl[hs_point][1:4])

    labels = _prompt_hs_labels(len(hs))

    # reduced_data columns: 0 = k-point index, 1 = distance along the path.
    _, reduced_data = compute_kpoint_distances(kl_new, x_scale=3)
    distance_of = dict(reduced_data)
    tick_vals = [distance_of[i] for i in hs]
    # Merge coincident ticks (e.g. path breaks), as the orbvisband class does.
    tick_vals, tick_labels = merge_close_ticks(tick_vals, labels, tol=1e-5)

    idx_selected = reduced_data[:, 0].astype(int)
    if ispin == 1:
        bs_selected = bs[:, idx_selected]
    else:
        bs_selected = bs[:, :, idx_selected]
    x_arr = reduced_data[:, 1]

    # Validate each entry and sum its PROCAR weights over atoms and orbitals.
    all_labels, all_procar_data = [], []
    for atom_list, element_name, orbital_list in data:
        orbital_list = list(orbital_list)
        _check_orbital_list(orbital_list, tot_ind)
        summed = np.zeros_like(bs, dtype=float)
        for atom in atom_list:
            for orbital in orbital_list:
                summed += orbvis_orbital_specific_band_data_from_PROCAR(path, atom, orbital, ispin)
        all_labels.append(_orbital_entry_label(element_name, atom_list, orbital_list, tot_ind))
        all_procar_data.append(summed)

    # One colour per entry; a single entry would otherwise divide by zero.
    n_entries = len(all_labels)
    if n_entries == 1:
        colors = [custom_cmap(0.0)]
    else:
        colors = [custom_cmap(i / (n_entries - 1)) for i in range(n_entries)]

    if ispin == 1:
        fig, ax = plt.subplots()
        axes = [ax]
        band_sets = [bs_selected]
        weight_sets = [[d[:, idx_selected] for d in all_procar_data]]
    else:
        fig, ax = plt.subplots(1, 2, figsize=(8, 3))
        axes = list(ax)
        band_sets = [bs_selected[0], bs_selected[1]]
        weight_sets = [[d[spin][:, idx_selected] for d in all_procar_data] for spin in (0, 1)]

    for axis, bands, weights in zip(axes, band_sets, weight_sets):
        for band in bands:
            axis.plot(x_arr, band, linewidth=0.3, color="black")
        for entry_weights, color in zip(weights, colors):
            for band, band_weight in zip(bands, entry_weights):
                axis.scatter(x_arr, band, s=scale * band_weight, alpha=0.6, color=color)
        axis.set_xticks(tick_vals)
        axis.set_xticklabels(tick_labels)
        axis.set_xlim(x_arr.min(), x_arr.max())
        axis.set_ylim(ylim[0], ylim[1])

    # Proxy artists so each entry appears once in the legend.
    handles = [
        Line2D([0], [0], color=color, marker='o', linestyle="", markersize=5, label=label)
        for label, color in zip(all_labels, colors)
    ]
    axes[-1].legend(handles=handles, loc='lower right', framealpha=0.3)
    plt.show()