## Description

This notebook investigates how different convergence properties are related to each other. The goal is to find the best-defined properties that are material-independent and can be used to predict the convergence of a PP. 

The properties that are investigated are:
- For pressure, compare the complex-defined SSSP v1 residual volume and the vanilla hydrostatic pressure
- For EOS metrics, compare nu wrt AE and nu with ref 200Ry. (Check and assure the guess that delta' and nu are correlated)
- Compare pressure and EOS metrics (nu ref 200Ry)
- Other pairs: check whether they are correlated or not

My expectation is that, as I tune the criteria of the properties, there will be a crossover from A > B to B > A. The difference between A and B being correlated or not is whether there will be a state where A and B are highly linearly correlated. 

The test data is generated by running the full convergence test on the grid [20:5:200] Ry for all the different property calculation methods; the property data can then be extracted and constructed from the outputs.
The tested PPs are Hg, Ga, N, Cs and Mn from gbrv, dojo, psl-paw-high and jth, in order to cover PPs generated by different codes and different types of elements.

The AiiDA data is stored at group `SI/convergence-properties-compare`. 

In [1]:
from aiida import load_profile
import typing as t

load_profile("2023-08-07")

from aiida import orm

In [4]:
from aiida_sssp_workflow.workflows.convergence.pressure import helper_get_volume_from_pressure_birch_murnaghan
from aiida_sssp_workflow.calculations.calculate_bands_distance import get_bands_distance
from aiida_sssp_workflow.calculations.calculate_metric import rel_errors_vec_length

def extract_data_scan_list1(node):
    """Collect the pressure-convergence data of ``node``, keyed by cutoff.

    Looks through ``node.called`` for ``ConvergencePressureWorkChain``
    children (if there are several, the last one wins), gathers the
    ``helper_pressure_difference`` calls that appear before the first
    ``convergence_analysis`` call, and evaluates ``get_conv_data1`` on the
    inputs of each of them.

    Args:
        node: process node whose ``called`` children are inspected.

    Returns:
        tuple: ``(data, scan_list)``. ``data`` maps a cutoff to the value
        returned by ``get_conv_data1``; ``scan_list`` is the list of cutoffs
        that the helper calls were paired with, by position.

    Raises:
        ValueError: if ``node`` called no ``ConvergencePressureWorkChain``.
    """
    helper_calls = None
    real_scan_list = []
    for wf in node.called:
        if wf.process_label != 'ConvergencePressureWorkChain':
            # parse_pseudo_info or _CachingConvergenceWorkChain
            continue

        helper_calls = []
        for wf2 in wf.called:
            if wf2.process_label == 'helper_pressure_difference':
                helper_calls.append(wf2)
            if wf2.process_label == 'convergence_analysis':
                break

        real_scan_list = wf.outputs.output_parameters_wfc_test.get_dict()['ecutwfc']

    if helper_calls is None:
        # Without this guard the loop below failed with an opaque
        # UnboundLocalError because the helper list was never bound.
        raise ValueError(
            f"node {node.pk} called no ConvergencePressureWorkChain"
        )

    expected_scan_list = list(range(20, 201, 5))
    # Cutoffs that were expected but are absent from the real scan.
    missing = sorted(set(expected_scan_list) - set(real_scan_list))
    if missing:
        print(f"Warning - the following cutoffs are missing from node {node.pk}: {missing}")
        scan_list = real_scan_list
    else:
        scan_list = expected_scan_list

    data = {}
    for i, wf in enumerate(helper_calls):
        # The i-th helper call is assumed to belong to the i-th cutoff.
        cutoff = scan_list[i]

        i_para = wf.inputs.input_parameters.get_dict()
        r_para = wf.inputs.ref_parameters.get_dict()
        V0 = wf.inputs.V0.value
        B0 = wf.inputs.B0.value
        B1 = wf.inputs.B1.value

        data[cutoff] = get_conv_data1(i_para, r_para, V0, B0, B1)

    return data, scan_list

def get_conv_data1(i_para, r_para, V0, B0, B1) -> float:
    """Turn the hydrostatic-stress error w.r.t. the reference into the pressure metric.

    The absolute difference between the ``hydrostatic_stress`` entries of
    ``i_para`` and ``r_para`` is handed, together with ``V0``, ``B0`` and
    ``B1``, to ``helper_get_volume_from_pressure_birch_murnaghan``; whatever
    that helper returns is returned unchanged.

    NOTE(review): presumably V0/B0/B1 are Birch-Murnaghan fit parameters and
    the result is a relative volume difference -- confirm against the helper.
    """
    pressure_gap = abs(i_para["hydrostatic_stress"] - r_para["hydrostatic_stress"])
    return helper_get_volume_from_pressure_birch_murnaghan(
        pressure_gap, V0, B0, B1,
    )

def extract_data_scan_list2(node):
    """Collect the EOS (delta) convergence data of ``node``, keyed by cutoff.

    Looks through ``node.called`` for ``ConvergenceDeltaWorkChain`` children
    (if there are several, the last one wins), gathers the
    ``helper_delta_difference`` calls that appear before the first
    ``convergence_analysis`` call, and evaluates ``get_conv_data2`` on the
    inputs of each of them.

    Args:
        node: process node whose ``called`` children are inspected.

    Returns:
        tuple: ``(data, scan_list)``. ``data`` maps a cutoff to the value
        returned by ``get_conv_data2``; ``scan_list`` is the list of cutoffs
        that the helper calls were paired with, by position.

    Raises:
        ValueError: if ``node`` called no ``ConvergenceDeltaWorkChain``.
    """
    helper_calls = None
    real_scan_list = []
    for wf in node.called:
        if wf.process_label != 'ConvergenceDeltaWorkChain':
            # parse_pseudo_info or _CachingConvergenceWorkChain
            continue

        helper_calls = []
        for wf2 in wf.called:
            if wf2.process_label == 'helper_delta_difference':
                helper_calls.append(wf2)
            if wf2.process_label == 'convergence_analysis':
                break

        real_scan_list = wf.outputs.output_parameters_wfc_test.get_dict()['ecutwfc']

    if helper_calls is None:
        # Without this guard the loop below failed with an opaque
        # UnboundLocalError because the helper list was never bound.
        raise ValueError(
            f"node {node.pk} called no ConvergenceDeltaWorkChain"
        )

    expected_scan_list = list(range(20, 201, 5))
    # Cutoffs that were expected but are absent from the real scan.
    missing = sorted(set(expected_scan_list) - set(real_scan_list))
    if missing:
        print(f"Warning - the following cutoffs are missing from node {node.pk}: {missing}")
        scan_list = real_scan_list
    else:
        scan_list = expected_scan_list

    data = {}
    for i, wf in enumerate(helper_calls):
        # The i-th helper call is assumed to belong to the i-th cutoff.
        cutoff = scan_list[i]

        i_para = wf.inputs.input_parameters.get_dict()
        r_para = wf.inputs.ref_parameters.get_dict()

        data[cutoff] = get_conv_data2(i_para, r_para)

    return data, scan_list

def get_conv_data2(i_para, r_para) -> float:
    """Compare the ``birch_murnaghan_results`` of a run with those of the reference.

    Both parameter dicts must hold a three-element ``birch_murnaghan_results``
    entry, unpacked here as (V0, B0, B1). The reference triple followed by the
    tested triple is passed to ``rel_errors_vec_length`` and its result is
    returned as is.
    """
    # Reference first, then the tested calculation (same order as the call).
    ref_V0, ref_B0, ref_B1 = r_para['birch_murnaghan_results']
    V0, B0, B1 = i_para['birch_murnaghan_results']

    return rel_errors_vec_length(ref_V0, ref_B0, ref_B1, V0, B0, B1)


In [5]:
g = 'SI/convergence-properties-compare'
# Materialise the nodes of the AiiDA group into a plain list.
gs_nodes = list(orm.Group.collection.get(label=g).nodes)

# Per node pk: the extracted property values and the cutoffs they belong to,
# one dict for the pressure property and one for the EOS property.
all_data1 = {}
all_data2 = {}
for node in gs_nodes:
    try:
        data1, scan_list1 = extract_data_scan_list1(node)
        data2, scan_list2 = extract_data_scan_list2(node)
    except Exception as e:
        # Show which node failed, then let the error propagate.
        print(node.pk)
        raise e

    # Only reached when both extractions succeeded.
    all_data1[node.pk] = {"data": data1, "scan_list": scan_list1}
    all_data2[node.pk] = {"data": data2, "scan_list": scan_list2}



In [6]:
def extract_cutoff(data, scan_list, criteria):
    """Return the converged cutoff of one property for a given criterion.

    The cutoffs of ``scan_list`` are visited from the last to the first.
    Cutoffs without an entry in ``data`` are skipped. The walk stops at the
    first value that is strictly greater than ``criteria``; the cutoff
    returned is the last one visited before that point.

    Args:
        data (dict): maps a cutoff to the property value at that cutoff.
        scan_list (list): the cutoffs that were scanned (presumably in
            ascending order -- the walk goes through them in reverse).
        criteria (float): threshold a value must not exceed to count as
            converged.

    Returns:
        The converged cutoff, or 200 when no visited cutoff satisfies the
        criterion (including empty ``data`` / ``scan_list``).
    """
    cut = 200
    for cutoff in reversed(scan_list):
        try:
            p = data[cutoff]
        except KeyError:
            # No result for this cutoff: skip it. Only the missing-key case
            # is ignored so that unrelated errors are not swallowed.
            continue

        if p > criteria:
            break

        cut = cutoff

    return cut

In [7]:
def compute_cutoff(data12_tuple, criteria):
    """Compute the cutoffs of both properties for every stored node.

    Args:
        data12_tuple: pair ``(all_data1, all_data2)``; each maps a node pk to
            a dict with the keys ``'data'`` and ``'scan_list'``.
        criteria: pair of thresholds; the first is applied to ``all_data1``
            and the second to ``all_data2``.

    Returns:
        tuple: two lists of cutoffs, one entry per node in the iteration
        order of the respective dict.
    """
    def _cutoffs(per_node, threshold):
        # One extract_cutoff call per node of the given data set.
        return [
            extract_cutoff(per_node[pk]['data'], per_node[pk]['scan_list'], threshold)
            for pk in per_node
        ]

    all_data1, all_data2 = data12_tuple[0], data12_tuple[1]
    criteria1, criteria2 = criteria[0], criteria[1]

    return _cutoffs(all_data1, criteria1), _cutoffs(all_data2, criteria2)

In [8]:
# Initial cutoff lists for the scatter plot, with both criteria set to 0.1
cut_A_lst, cut_B_lst = compute_cutoff(
    data12_tuple=(all_data1, all_data2),
    criteria=(0.1, 0.1),
)

In [9]:
import ipywidgets as ipw
import plotly.graph_objects as go

# Per-node cutoffs of the two properties plotted against each other,
# together with the x=y guide line.
trace_corr_scatter = go.Scatter(
    x=cut_A_lst,
    y=cut_B_lst,
    mode='markers',
    name='cutoff correlation',
)
trace_xy_line = go.Scatter(x=[0, 200], y=[0, 200], name='x=y')
g = go.FigureWidget(data=[trace_corr_scatter, trace_xy_line])
g.update_layout(xaxis_title='cutoff pA', yaxis_title='cutoff pB')


In [16]:
# One slider per criterion; the two properties use different ranges and steps.
pA_slider = ipw.FloatSlider(
    value=2,
    min=0.00,
    max=10.0,
    step=0.1,
    description='pA',
)
pB_slider = ipw.FloatSlider(
    value=0.1,
    min=0.00,
    max=0.2,
    step=0.01,
    description='pB',
)

def response(change):
    """Slider callback: recompute the cutoffs and refresh the scatter trace.

    ``change`` is not used; the current values are read from the two sliders.
    """
    thresholds = (pA_slider.value, pB_slider.value)
    new_x, new_y = compute_cutoff(
        data12_tuple=(all_data1, all_data2),
        criteria=thresholds,
    )
    # Apply both coordinate updates inside a single batch.
    with g.batch_update():
        scatter = g.data[0]
        scatter.x = new_x
        scatter.y = new_y
        
# Recompute the scatter whenever either criterion slider changes.
pA_slider.observe(response, names="value")
pB_slider.observe(response, names="value")

# The observers only fire on a change, while the figure was built with the
# criteria (0.1, 0.1). Sync the plot once so that it matches the sliders'
# initial values from the start.
response(None)

slider_widgets = ipw.HBox([pA_slider, pB_slider])
app = ipw.VBox([slider_widgets, g])
app

VBox(children=(HBox(children=(FloatSlider(value=2.0, description='pA', max=10.0), FloatSlider(value=0.1, descr…

In [18]:
# heatmap
import numpy as np

def compute_correlation(x, y):
    """Return the summed cutoff mismatch of the two properties.

    ``x`` is the criterion applied to the module-level ``all_data1`` and ``y``
    the one applied to ``all_data2``. The result is the sum over all nodes of
    ``|cut_A - cut_B| / 40.0``; 0 means both properties give identical
    cutoffs. Despite the name this is not a correlation coefficient.

    NOTE(review): the 40.0 scale factor is not explained here -- confirm.
    """
    cuts = compute_cutoff(data12_tuple=(all_data1, all_data2), criteria=(x, y))
    first, second = cuts
    mismatch = np.abs(np.array(first) - np.array(second)) / 40.0
    return np.sum(mismatch)

# Criteria grids: xxs for the first (pressure) property, yys for the second.
xxs = np.linspace(0.0, 5, 201)
yys = np.linspace(0.0, 0.1, 201)

# The cutoffs of the first property depend only on x and those of the second
# only on y. Evaluate each criterion once (len(xxs) + len(yys) passes over the
# nodes) instead of recomputing both lists for every (x, y) grid point.
cuts_A = np.array([
    [extract_cutoff(entry['data'], entry['scan_list'], x) for entry in all_data1.values()]
    for x in xxs
])
cuts_B = np.array([
    [extract_cutoff(entry['data'], entry['scan_list'], y) for entry in all_data2.values()]
    for y in yys
])

# zzs[j, i] = sum over nodes of |cut_A(xxs[i]) - cut_B(yys[j])| / 40.0,
# i.e. the same value as compute_correlation(xxs[i], yys[j]).
zzs = np.sum(
    np.abs(cuts_A[np.newaxis, :, :] - cuts_B[:, np.newaxis, :]) / 40.0,
    axis=-1,
)

fig = go.FigureWidget(data=go.Heatmap(z=zzs, x=xxs, y=yys, zmax=10, zmin=0))
fig.layout.xaxis.title = 'cutoff from complex pressure'
fig.layout.yaxis.title = 'cutoff from nu'
fig

FigureWidget({
    'data': [{'type': 'heatmap',
              'uid': '99446e5f-ee6d-4522-9452-ba1b5c9600c0',
              'x': array([0.   , 0.025, 0.05 , ..., 4.95 , 4.975, 5.   ]),
              'y': array([0.    , 0.0005, 0.001 , ..., 0.099 , 0.0995, 0.1   ]),
              'z': array([[ 0.25 , 19.875, 31.   , ..., 77.25 , 77.25 , 77.375],
                          [14.625, 18.   , 19.125, ..., 62.875, 62.875, 63.   ],
                          [21.875, 14.   , 13.375, ..., 55.625, 55.625, 55.75 ],
                          ...,
                          [71.5  , 51.375, 40.25 , ...,  6.5  ,  6.5  ,  6.625],
                          [71.5  , 51.375, 40.25 , ...,  6.5  ,  6.5  ,  6.625],
                          [71.5  , 51.375, 40.25 , ...,  6.5  ,  6.5  ,  6.625]]),
              'zmax': 10,
              'zmin': 0}],
    'layout': {'template': '...',
               'xaxis': {'title': {'text': 'cutoff from complex pressure'}},
               'yaxis': {'title': {'text': 'cutoff f