## Description

This notebook investigates how different convergence properties relate to each other. The goal is to find the property definitions that are material-independent and can be used to predict the convergence of a pseudopotential (PP).

The properties that are investigated are:
- For pressure, compare the more complex residual volume defined in SSSP v1 with the vanilla hydrostatic pressure.
- For EOS metrics, compare nu with respect to AE and nu with respect to the 200 Ry reference. (Check and confirm the guess that delta' and nu are correlated.)
- Compare pressure with the EOS metric (nu with the 200 Ry reference).
- For other pairs, check whether they are correlated or not.

My expectation is that, as the criteria of the properties are tuned, there will be a crossover from A > B to B > A. What distinguishes correlated from uncorrelated A and B is whether there is a state in which A and B are highly linearly correlated.
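
The crossover idea can be illustrated with a small sketch. Everything in it is an assumption made for illustration: the two error curves are synthetic exponentials (not data from this notebook), and `converged_cutoff` is a hypothetical helper that scans the grid from the highest cutoff downward.

```python
import numpy as np

cutoffs = np.arange(20, 201, 5)  # the [20:5:200] Ry grid

# Synthetic, monotonically decaying errors for two properties (NOT real data).
err_A = 50.0 * np.exp(-cutoffs / 15.0)
err_B = 20.0 * np.exp(-cutoffs / 25.0)


def converged_cutoff(errors, threshold):
    """Lowest cutoff from which the error stays at or below the threshold."""
    cut = int(cutoffs[-1])
    for c, e in zip(cutoffs[::-1], errors[::-1]):
        if e > threshold:
            break
        cut = int(c)
    return cut


# Keep the criterion of B fixed and tighten the criterion of A.
cut_B = converged_cutoff(err_B, 0.1)
for p_A in (1.0, 0.1, 0.01, 0.001):
    cut_A = converged_cutoff(err_A, p_A)
    relation = "A > B" if cut_A > cut_B else "A <= B"
    print(f"pA={p_A:<6} cut_A={cut_A:<4} cut_B={cut_B:<4} {relation}")
```

With these synthetic curves, the cutoff of A starts below the cutoff of B for a loose criterion and ends above it for a tight one, which is the kind of crossover described above.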

The test data are generated by running the full convergence test on the grid [20:5:200] Ry for every property calculation method; the property data are then extracted and constructed from the outputs.
The tested PPs are Hg, Ga, N, Cs, and Mn from gbrv, dojo, psl-paw-high, and jth, chosen to cover PPs generated by different codes and different types of elements.
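
As a quick sanity check on the size of the scan, the grid and the element/library pairs can be enumerated as below. This is only a sketch: the variable names are made up here, and it assumes every element is available in every library, which may not hold for the actual group.

```python
from itertools import product

# [20:5:200] Ry: start 20, step 5, end 200 inclusive.
ecutwfc_grid = list(range(20, 201, 5))

elements = ["Hg", "Ga", "N", "Cs", "Mn"]
libraries = ["gbrv", "dojo", "psl-paw-high", "jth"]

# Upper bound on the number of tested PPs (assumes all pairs exist).
pp_candidates = list(product(elements, libraries))

print(len(ecutwfc_grid), "cutoffs per convergence test")
print(len(pp_candidates), "element/library pairs at most")
```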

The AiiDA data are stored in the group `SI/convergence-properties-compare`.

In [1]:
from aiida import load_profile
import typing as t

load_profile("2023-08-07")

from aiida import orm

In [2]:
from aiida_sssp_workflow.workflows.convergence.pressure import helper_get_volume_from_pressure_birch_murnaghan
from aiida_sssp_workflow.calculations.calculate_bands_distance import get_bands_distance

def extract_data_scan_list1(node):
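    real_scan_list = []
    # Initialise here so the loop below does not fail with a NameError
    # when the node has no ConvergenceBandsWorkChain among its children.
    lst = []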
    for wf in node.called:
        if wf.process_label == 'ConvergenceBandsWorkChain':
            lst = []
            for wf2 in wf.called:

                if wf2.process_label == 'helper_bands_distence_difference':
                    lst.append(wf2)
                if wf2.process_label == 'convergence_analysis':
                    break
            
            real_scan_list = wf.outputs.output_parameters_wfc_test.get_dict()['ecutwfc']

        else:
            # parse_pseudo_info or _CachingConvergenceWorkChain
            continue
        
    expected_scan_list = list(range(20, 201, 5))
    # Report any expected cutoffs that are absent from the actual scan.
    missing = sorted(set(expected_scan_list) - set(real_scan_list))
    if missing:
        print(f"Warning - the following cutoffs are missing from node {node.pk}: {missing}")
        scan_list = real_scan_list
    else:
        scan_list = expected_scan_list

    data = {}

    for i, wf in enumerate(lst):
        cutoff = scan_list[i]
        band_structure_a = wf.inputs.band_structure_a
        band_parameters_a = wf.inputs.band_parameters_a.get_dict()
        band_structure_b = wf.inputs.band_structure_b
        band_parameters_b = wf.inputs.band_parameters_b.get_dict()
        smearing = wf.inputs.smearing.value
        fermi_shift = wf.inputs.fermi_shift.value
        do_smearing = wf.inputs.do_smearing.value
        spin = wf.inputs.spin.value

        bandsdata_a = {
            "number_of_electrons": band_parameters_a["number_of_electrons"],
            "number_of_bands": band_parameters_a["number_of_bands"],
            "fermi_level": band_parameters_a["fermi_energy"],
            "bands": band_structure_a.get_bands(),
            "kpoints": band_structure_a.get_kpoints(),
            "weights": band_structure_a.get_array("weights"),
        }
        bandsdata_b = {
            "number_of_electrons": band_parameters_b["number_of_electrons"],
            "number_of_bands": band_parameters_b["number_of_bands"],
            "fermi_level": band_parameters_b["fermi_energy"],
            "bands": band_structure_b.get_bands(),
            "kpoints": band_structure_b.get_kpoints(),
            "weights": band_structure_b.get_array("weights"),
        }

        res = get_conv_data1(
            bandsdata_a,
            bandsdata_b,
            smearing,
            fermi_shift,
            do_smearing,
            spin,
        )
        data[cutoff] = res

    return data, scan_list

def get_conv_data1(bandsdata_a, bandsdata_b, smearing, fermi_shift, do_smearing, spin) -> float:
    res = get_bands_distance(
        bandsdata_a,
        bandsdata_b,
        smearing,
        fermi_shift,
        do_smearing,
        spin,
    )

    return res['eta_c']

def extract_data_scan_list2(node):
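    real_scan_list = []
    # Initialise here so the loop below does not fail with a NameError
    # when the node has no ConvergenceBandsWorkChain among its children.
    lst = []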
    for wf in node.called:
        if wf.process_label == 'ConvergenceBandsWorkChain':
            lst = []
            for wf2 in wf.called:

                if wf2.process_label == 'helper_bands_distence_difference':
                    lst.append(wf2)
                if wf2.process_label == 'convergence_analysis':
                    break
            
            real_scan_list = wf.outputs.output_parameters_wfc_test.get_dict()['ecutwfc']

        else:
            # parse_pseudo_info or _CachingConvergenceWorkChain
            continue
        
    expected_scan_list = list(range(20, 201, 5))
    # Report any expected cutoffs that are absent from the actual scan.
    missing = sorted(set(expected_scan_list) - set(real_scan_list))
    if missing:
        print(f"Warning - the following cutoffs are missing from node {node.pk}: {missing}")
        scan_list = real_scan_list
    else:
        scan_list = expected_scan_list

    data = {}

    for i, wf in enumerate(lst):
        cutoff = scan_list[i]
        band_structure_a = wf.inputs.band_structure_a
        band_parameters_a = wf.inputs.band_parameters_a.get_dict()
        band_structure_b = wf.inputs.band_structure_b
        band_parameters_b = wf.inputs.band_parameters_b.get_dict()
        smearing = wf.inputs.smearing.value
        fermi_shift = wf.inputs.fermi_shift.value
        do_smearing = wf.inputs.do_smearing.value
        spin = wf.inputs.spin.value

        bandsdata_a = {
            "number_of_electrons": band_parameters_a["number_of_electrons"],
            "number_of_bands": band_parameters_a["number_of_bands"],
            "fermi_level": band_parameters_a["fermi_energy"],
            "bands": band_structure_a.get_bands(),
            "kpoints": band_structure_a.get_kpoints(),
            "weights": band_structure_a.get_array("weights"),
        }
        bandsdata_b = {
            "number_of_electrons": band_parameters_b["number_of_electrons"],
            "number_of_bands": band_parameters_b["number_of_bands"],
            "fermi_level": band_parameters_b["fermi_energy"],
            "bands": band_structure_b.get_bands(),
            "kpoints": band_structure_b.get_kpoints(),
            "weights": band_structure_b.get_array("weights"),
        }

        res = get_conv_data2(
            bandsdata_a,
            bandsdata_b,
            smearing,
            fermi_shift,
            do_smearing,
            spin,
        )
        data[cutoff] = res

    return data, scan_list

def get_conv_data2(bandsdata_a, bandsdata_b, smearing, fermi_shift, do_smearing, spin) -> float:
    res = get_bands_distance(
        bandsdata_a,
        bandsdata_b,
        smearing,
        fermi_shift,
        do_smearing,
        spin,
    )

    return res['eta_v']


In [4]:
g = 'SI/convergence-properties-compare/DC'
gs_nodes = []
gs_nodes.extend(orm.Group.collection.get(label=g).nodes)
    
all_data1 = {}
all_data2 = {}
for node in gs_nodes:
    # For each node, extract the eta_c (data1) and eta_v (data2) band
    # distances at every cutoff of its convergence test.
    try:
        data1, scan_list1 = extract_data_scan_list1(node)
        data2, scan_list2 = extract_data_scan_list2(node)
        all_data1[node.pk] = {
            "data": data1,
            "scan_list": scan_list1,
        }
        all_data2[node.pk] = {
            "data": data2,
            "scan_list": scan_list2,
        }
    except Exception:
        # Report the offending node, then re-raise with the original traceback.
        print(node.pk)
        raise

In [5]:
def extract_cutoff(data, scan_list, criteria):
    """Extract the converged cutoff of one property for one node.

    The scan runs from the highest cutoff downward and stops at the first
    value that exceeds the criteria, so isolated low values below that point
    are not counted as converged.

    Args:
        data (dict): maps each cutoff to the property value at that cutoff
        scan_list (list): the cutoffs used in the convergence test, ascending
        criteria (float): the threshold the property value must not exceed

    Returns:
        int: the lowest cutoff from which all scanned values satisfy the
            criteria, or 200 if the value at the highest cutoff already
            exceeds it.
    """
    cut = 200
    for cutoff in reversed(scan_list):
        try:
            p = data[cutoff]
        except KeyError:
            # No data point at this cutoff; skip it.
            continue

        if p > criteria:
            break

        cut = cutoff

    return cut
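
The downward scan matters when the property is not monotonic in the cutoff. The sketch below repeats the same logic in a standalone helper (the name `first_converged_cutoff` and the toy values are made up for illustration) and shows that an isolated low value at 35 Ry is ignored because the value at 40 Ry is above the threshold again.

```python
def first_converged_cutoff(data, scan_list, threshold, default=200):
    """Lowest cutoff from which every scanned value stays <= threshold."""
    cut = default
    for cutoff in reversed(scan_list):
        if cutoff not in data:
            continue  # missing data point, skip it
        if data[cutoff] > threshold:
            break  # first violation seen from the top: stop here
        cut = cutoff
    return cut


# Toy, non-monotonic values (NOT taken from the AiiDA group).
toy = {20: 4.0, 25: 1.0, 30: 0.5, 35: 0.07, 40: 0.8, 45: 0.2, 50: 0.05, 55: 0.04}
grid = list(range(20, 56, 5))

print(first_converged_cutoff(toy, grid, threshold=0.1))   # 35 Ry is not picked
print(first_converged_cutoff(toy, grid, threshold=1.0))
print(first_converged_cutoff(toy, grid, threshold=0.01))  # falls back to default
```

A forward scan that returned the first cutoff below the threshold would report 35 Ry for the 0.1 criterion, which would be misleading for such data.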

In [6]:
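def compute_cutoff(data12_tuple, criteria):
    """Compute the converged cutoffs of properties A and B for every node.

    Args:
        data12_tuple (tuple): (all_data1, all_data2), each mapping a node pk
            to a dict with the keys 'data' and 'scan_list'
        criteria (tuple): (criteria for A, criteria for B)

    Returns:
        tuple: (cut_A_lst, cut_B_lst), one cutoff per node in each list.
    """
    cut_A_lst = []
    cut_B_lst = []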

    all_data1 = data12_tuple[0]
    all_data2 = data12_tuple[1]
    criteria1 = criteria[0]
    criteria2 = criteria[1]

    for node_pk in all_data1:
        data = all_data1[node_pk]['data']
        scan_list = all_data1[node_pk]['scan_list']
        cut_A = extract_cutoff(data, scan_list, criteria1)
        cut_A_lst.append(cut_A)
    
    for node_pk in all_data2:
        data = all_data2[node_pk]['data']
        scan_list = all_data2[node_pk]['scan_list']
        cut_B = extract_cutoff(data, scan_list, criteria2)
        cut_B_lst.append(cut_B)
        
    return cut_A_lst, cut_B_lst
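
Because the cutoffs of A depend only on the first criteria and those of B only on the second, a scan over a 2D grid of criteria does not need to recompute both lists for every grid point. The following is a minimal sketch of that idea under stated assumptions: `cutoff_for`, `cutoff_table` and `mean_abs_diff_map` are hypothetical names, the toy dictionaries only mimic the `{'data': ..., 'scan_list': ...}` layout used above, and both dictionaries are assumed to list the nodes in the same order.

```python
import numpy as np


def cutoff_for(data, scan_list, threshold, default=200):
    """Same downward scan as extract_cutoff, repeated here to be standalone."""
    cut = default
    for c in reversed(scan_list):
        if c not in data:
            continue
        if data[c] > threshold:
            break
        cut = c
    return cut


def cutoff_table(all_data, thresholds):
    """Array of shape (n_thresholds, n_nodes) with one cutoff per entry."""
    return np.array([
        [cutoff_for(e["data"], e["scan_list"], t) for e in all_data.values()]
        for t in thresholds
    ])


def mean_abs_diff_map(all_a, all_b, xs, ys):
    """z[j, i] = mean over nodes of |cut_A(xs[i]) - cut_B(ys[j])|."""
    table_a = cutoff_table(all_a, xs)  # computed once per x
    table_b = cutoff_table(all_b, ys)  # computed once per y
    return np.abs(table_b[:, None, :] - table_a[None, :, :]).mean(axis=2)


# Toy data for two nodes (NOT real results).
scan = [20, 25, 30]
toy_a = {1: {"data": {20: 1.0, 25: 0.5, 30: 0.05}, "scan_list": scan},
         2: {"data": {20: 0.3, 25: 0.08, 30: 0.01}, "scan_list": scan}}
toy_b = {1: {"data": {20: 2.0, 25: 0.09, 30: 0.02}, "scan_list": scan},
         2: {"data": {20: 0.6, 25: 0.4, 30: 0.03}, "scan_list": scan}}

z = mean_abs_diff_map(toy_a, toy_b, xs=[0.1, 0.6], ys=[0.1, 0.5])
print(z)
```

For an n by n grid this needs 2n cutoff extractions per node instead of 2n², at the cost of holding two small tables in memory.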

In [7]:
# Get data for plotting
cut_A_lst, cut_B_lst = compute_cutoff(data12_tuple=(all_data1, all_data2), criteria=(0.1, 0.1))

In [8]:
import ipywidgets as ipw
import plotly.graph_objects as go

trace_corr_scatter = go.Scatter(x=cut_A_lst, y=cut_B_lst, mode='markers', name='cutoff correlation')
trace_xy_line = go.Scatter(x=[0, 200], y=[0, 200], name='x=y')
g = go.FigureWidget(data=[trace_corr_scatter, trace_xy_line])
g.layout.xaxis.title = 'cutoff pA'
g.layout.yaxis.title = 'cutoff pB'


In [12]:
pA_slider = ipw.FloatSlider(value=0.1, min=0.00, max=10.0, step=0.005, description='pA')
pB_slider = ipw.FloatSlider(value=0.1, min=0.00, max=10.0, step=0.005, description='pB')

def response(change):
    cut_A_lst, cut_B_lst = compute_cutoff(data12_tuple=(all_data1, all_data2), criteria=(pA_slider.value, pB_slider.value))
    with g.batch_update():
        g.data[0].x = cut_A_lst
        g.data[0].y = cut_B_lst
        
pA_slider.observe(response, names="value")
pB_slider.observe(response, names="value")

slider_widgets = ipw.HBox([pA_slider, pB_slider])
app = ipw.VBox([slider_widgets, g])
app


In [24]:
# heatmap
import numpy as np

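def compute_correlation(x, y):
    """Mean absolute difference between the A and B cutoffs over all nodes.

    Despite the name, this is not a correlation coefficient: it is 0 when the
    two cutoffs agree for every node and grows as they diverge.
    """
    cut_A_lst, cut_B_lst = compute_cutoff(data12_tuple=(all_data1, all_data2), criteria=(x, y))
    arr_A = np.array(cut_A_lst)
    arr_B = np.array(cut_B_lst)

    return np.mean(np.abs(arr_A - arr_B))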

xxs = np.linspace(0.0, 20, 201)
yys = np.linspace(0.0, 20, 201)
# get z map from xxs and yys
zzs = np.zeros((201, 201))
for i, x in enumerate(xxs):
    for j, y in enumerate(yys):
        zzs[j, i] = compute_correlation(x, y)
        
fig = go.FigureWidget(data=go.Heatmap(z=zzs, x=xxs, y=yys, zmax=10, zmin=0))
fig.layout.xaxis.title = 'correlate map'
fig

FigureWidget({
    'data': [{'type': 'heatmap',
              'uid': 'b5de888f-c113-49f7-9614-91c44535bd1d',
              'x': array([ 0. ,  0.1,  0.2, ..., 19.8, 19.9, 20. ]),
              'y': array([ 0. ,  0.1,  0.2, ..., 19.8, 19.9, 20. ]),
              'z': array([[  0.  ,  79.25,  90.  , ..., 159.  , 159.  , 159.25],
                          [ 93.  ,  14.25,  15.  , ...,  68.5 ,  68.5 ,  68.75],
                          [105.5 ,  26.25,  16.  , ...,  56.  ,  56.  ,  56.25],
                          ...,
                          [163.  ,  83.75,  73.  , ...,   4.  ,   4.  ,   3.75],
                          [163.  ,  83.75,  73.  , ...,   4.  ,   4.  ,   3.75],
                          [163.  ,  83.75,  73.  , ...,   4.  ,   4.  ,   3.75]]),
              'zmax': 10,
              'zmin': 0}],
    'layout': {'template': '...', 'xaxis': {'title': {'text': 'correlate map'}}}
})
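
The quantity mapped above is a mean absolute difference between the two cutoff lists, which measures agreement rather than linear correlation. As a hedged sketch of how the two notions differ, the snippet below compares it with the Pearson coefficient from `np.corrcoef` on made-up cutoff lists; the helper names are hypothetical and the Pearson coefficient is offered as one possible alternative, not as the metric used in this notebook.

```python
import numpy as np


def mean_abs_difference(cut_a, cut_b):
    """Average |A - B| over nodes: 0 only when the cutoffs coincide."""
    a = np.asarray(cut_a, dtype=float)
    b = np.asarray(cut_b, dtype=float)
    return float(np.mean(np.abs(a - b)))


def pearson_r(cut_a, cut_b):
    """Pearson correlation; NaN when either list has zero variance."""
    a = np.asarray(cut_a, dtype=float)
    b = np.asarray(cut_b, dtype=float)
    if a.std() == 0 or b.std() == 0:
        return float("nan")
    return float(np.corrcoef(a, b)[0, 1])


# Made-up cutoffs: B is A shifted by a constant 20 Ry.
cut_a = [40, 60, 80, 100]
cut_b = [60, 80, 100, 120]

print(mean_abs_difference(cut_a, cut_b))  # large, the lists disagree
print(pearson_r(cut_a, cut_b))            # still perfectly linear
```

A constant offset between A and B therefore shows up as a nonzero value in the heatmap even though the two cutoffs are perfectly linearly related.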

In [14]:
all_data1

{8281190: {'data': {20: 4.013179350836147,
   25: 0.9965167522862066,
   30: 0.5045872925641953,
   35: 0.06620550621382218,
   40: 0.7926243722305165,
   45: 0.2094187203240178,
   50: 0.6046380876998521,
   55: 0.13945123696640654,
   60: 0.30437171467709456,
   65: 0.1306219947121766,
   70: 0.16271431020328814,
   75: 0.16064727334611334,
   80: 0.10456885772726097,
   85: 0.07931129136178008,
   90: 0.09903409116504117,
   95: 0.06468584590251572,
   100: 0.07447303520526857,
   105: 0.0783951038277987,
   110: 0.029488779453027324,
   115: 0.0526562824289508,
   120: 0.04142748520461174,
   125: 0.03154006548871088,
   130: 0.06337447847977215,
   135: 0.06336693885513911,
   140: 0.06361365465797474,
   145: 0.05739670157647199,
   150: 0.052960036121034504,
   155: 0.046543752140924204,
   160: 0.04568876684156522,
   165: 0.02838677472385196,
   170: 0.02887086609977334,
   175: 0.021871859055550204,
   180: 0.025153564024654282,
   185: 0.045494140756374965,
   190: 0.0324340