# Creating the Energy vs. Volume (E vs. V) Figure for IrO2 and IrO3
---

TODO: Eliminate duplicates correctly
TODO: Add a color bar label

# Import Modules

In [None]:
import os
import ase
from ase.db import connect
import sys
import copy
import pickle

import numpy as np
import pandas as pd

from plotly.subplots import make_subplots

import chart_studio.plotly as py
import plotly.graph_objs as go
import plotly.express as px
import plotly.io as plio


from inputs import (
    Ir_ref,
    O_ref,
    coord_env_style)

from layout import layout



# Labels drawn next to selected structures in the E vs V plot, keyed by
# structure_id. Built from (structure_id, label) pairs so the insertion
# order -- which is the order the annotations are created in later -- is
# explicit. Disabled candidates are kept as comments.
structure_id_map = dict([
    # IrO2 ####################################################################
    ('64cg6j9any', 'i (rutile)'),     # rutile
    ('cg8p7fxq65', 'anatase'),        # anatase
    ('m2bs8w82x5', 'brookite'),       # brookite
    ('n36axdbw65', 'ii'),             # 2nd stable, columbite-like?
    ('85z4msnl6o', 'iii (pyrite)'),   # pyrite
    # ('myc4ng73xh', 'v'),            # Fm3m (disabled)
    ('zizr7rvpxs', 'vi'),             # porous
    ('b49kx4c19q', 'v (columbite)'),  # columbite
    ('nscdbpmdct', 'iv'),             # P63 (layered)
    # IrO3 ####################################################################
    ('mp6lno9jzr', 'i'),              # 482_2d
    ('9i6ixublcr', 'iii'),            # porous
    ('v2blxebixh', 'ii'),             # sg=2
    ('nrml6dms9l', 'iv'),             # 472_mplowest _63
    # ('xozr8f7p7g', 'iv'),           # Mp 2nd sg=38 (disabled)
    ('6tmjv4myvg', 'v'),              # 1D sg=1
    # ('9lmkmh8s8r', ''),             # 489_alpha (disabled)
    ('b5cgvsb16w', '(3)'),            # rutile-like
    ('8p8evt9pcg', '(1)'),            # alpha
    ('zimixdvdxd', '(2)'),            # P6_322 (also noted as 492_alpha_like)
    ('mj7wbfb5nt', '(4)'),            # sg=52, battery?
    ('949rnem5z2', '(5)'),            # sg=53
])

# Half-width of the coordination-number window used to pick out structures
# whose mean coordination is close to 4 and 6 (also passed to
# get_color_scale).
dx = 0.2

# Read Data

In [None]:
# /mnt/f/Dropbox/01_norskov/00_git_repos/PROJ_IrOx_Active_Learning_OER/

# "workflow/ml_modelling/energy_vs_volume/kirsten_E_vs_V_analysis/scripts/out_data"

In [None]:
%%capture

# #############################################################################
# Structural Analysis db file
FinalStructuresdb_file = os.path.join(
    # os.environ["PROJ_irox_2"],
    os.environ["PROJ_irox"],
    "workflow/ml_modelling/energy_vs_volume/kirsten_E_vs_V_analysis/scripts/out_data",
    "FinalStructures_1.db")

    # "FIGS_IrOx_Active_Learning_OER/01_figures/00_main_publ_figs/03_E_vs_V_coord/raul_work/scripts",
    # "out_data",
    # # "FinalStructures2.db",
    # "FinalStructures_1.db")

db = connect(FinalStructuresdb_file)

# #############################################################################
# Duplicates list
duplicates = pickle.load(open("../duplicates.pickle", "rb"))


# #############################################################################
# Bulk DFT Dataframe
sys.path.insert(0, os.path.join(
    os.environ["PROJ_irox"], "workflow/ml_modelling"))
from ml_methods import get_data_for_al

data_dict = get_data_for_al(stoich="AB2", drop_too_many_atoms=True)
df_bulk_dft_ab2 = data_dict["df_bulk_dft"]

data_dict = get_data_for_al(stoich="AB3", drop_too_many_atoms=True)
df_bulk_dft_ab3 = data_dict["df_bulk_dft"]

# Combine AB2/3 Dataframes
df_bulk_dft = pd.concat([df_bulk_dft_ab2, df_bulk_dft_ab3])
df_bulk_dft = df_bulk_dft[df_bulk_dft.source == "raul"]

In [None]:
# assert False

In [None]:
# Echo the resolved database path as cell output.
FinalStructuresdb_file

In [None]:
# db.count()

# Construct DataFrame



In [None]:
# Echo the installed ASE version as cell output, for reference.
ase.__version__

In [None]:
#for row in db.select():
#    print(row)

In [None]:
# One record per row of the ASE database: the row's energy and volume plus
# all of its custom key/value pairs. Because the pairs are passed through
# dict(...), a pair named "energy" or "volume" raises TypeError instead of
# silently overriding the stored value.
data_list = [
    dict(
        energy=row.get("energy"),
        volume=row.get("volume"),
        **row.key_value_pairs,
        )
    for row in db.select()
    ]

df = pd.DataFrame(data_list)

# Keep only rows that carry a stoichiometry.
df = df.loc[df["stoich"].notna()]
#df = df[~df["coor_env"].isna()]

print("Total df rows:", df.structure_id.size)
print("Unique structure ids:", df.structure_id.unique().size)
print("")

In [None]:
# Number of distinct structure_ids. Series.unique() keeps NaN as a value,
# unlike nunique().
len(df.structure_id.unique())

In [None]:
# #############################################################################
# Drop duplicates #############################################################
# Remove structures flagged as duplicates -- AB3 first, then AB2 -- printing
# the dataframe shape before and after each step.
is_ab3_duplicate = df.structure_id.isin(duplicates["AB3"])
print(df.shape)
df = df[~is_ab3_duplicate]
print(df.shape)

is_ab2_duplicate = df.structure_id.isin(duplicates["AB2"])
df = df[~is_ab2_duplicate]
print(df.shape)

# Index by structure_id so the merge below aligns on it.
df = df.set_index("structure_id")


# #############################################################################
# Merge dataframes together ###################################################
# Drop columns that would duplicate df_bulk_dft columns (or are no longer
# needed) before the index-aligned inner merge; "volume" is kept.
df = df.drop(columns=["energy", "stoich", "id_old"])

df = df.merge(df_bulk_dft, left_index=True, right_index=True)


# #############################################################################
# Calculate volume per atom for all systems
def method(row_i):
    """Return the volume per atom of one dataframe row.

    Parameters
    ----------
    row_i : pandas.Series
        Row with an ``atoms`` entry (an ASE ``Atoms``; anything supporting
        ``len()`` works) and a ``volume`` entry.

    Returns
    -------
    float
        ``row_i.volume`` divided by the number of atoms.
    """
    # len(atoms) is the supported atom count; Atoms.get_number_of_atoms()
    # is deprecated in ASE.
    return row_i.volume / len(row_i.atoms)

# Add the volume per atom as a new column, computed row by row with `method`.
df["volume_peratom"] = df.apply(
    method,
    axis=1)

In [None]:
# Scratch arithmetic, presumably differences between dataframe row counts at
# different processing stages; only the last expression is shown as output.
701 - 643

643 - 566

In [None]:
# (rows, columns) of the merged dataframe.
df.shape

In [None]:
# Halts "Run All" here, before the dataframe is processed and plotted.
# NOTE(review): looks like a leftover debugging stop -- none of the plotting
# cells below run on a full execution; confirm before removing.
assert False

# Process Dataframe

In [None]:
# Make modules in the current working directory (the local `colors` helper)
# importable.
sys.path.insert(0, ".")
from colors import get_color_scale

# Colorscale for the marker colouring by coordination number, built from the
# data and the window half-width dx (construction lives in colors.py).
colorscale_i = get_color_scale(df=df, dx=dx)

# Sorting data to bring out 4- and 6-fold coordination (plotted on top)

In [None]:
# Highlight structures whose mean coordination is close to 4 or 6 by moving
# them to the end of the frame, so they are plotted last (plotly draws points
# in data order, later ones on top).
df = df.sort_values("mean_coor")

# Rows inside the open windows (4 - dx, 4 + dx) and (6 - dx, 6 + dx), in
# that order.
df_tmp = pd.concat([
    df[(df.mean_coor < target + dx) & (df.mean_coor > target - dx)]
    for target in (4, 6)
    ])

# All other rows (still sorted by mean_coor) follow the highlighted ones;
# reversing the whole frame then leaves the ~4-coordinated rows last.
remaining_ids = list(df.index[~df.index.isin(df_tmp.index)])
df = pd.concat([df_tmp, df.loc[remaining_ids]])

df = df.reindex(index=df.index[::-1])

# Per-stoichiometry dH values used for the summary counts below.
dH_ab2 = df.loc[df.stoich == "AB2", "dH"]
dH_ab3 = df.loc[df.stoich == "AB3", "dH"]

print('Total IrO2:', len(dH_ab2))
print('Total IrO3:', len(dH_ab3))

# NOTE(review): -0.33 / -0.34 are rounded versions of the metastability
# lines drawn on the figure (y = -0.33286 for AB2, -0.34385 for AB3), so
# these counts can differ slightly from what the lines show -- confirm.
print('Metastable IrO2:', int((dH_ab2.values < -0.33).sum()))
print('Metastable IrO3:', int((dH_ab3.values < -0.34).sum()))


# Plotting

## Shared scatter attributes

In [None]:
# Template trace holding the styling shared by the AB2 and AB3 scatter
# traces. It is not plotted itself; its properties are merged into each trace
# with `trace.update(**scatter_shared.to_plotly_json())`, so marker settings
# here replace the same keys set on an individual trace.
scatter_shared = go.Scatter(
    mode="markers",
    hoverinfo="text",  # hover shows only each trace's `text` values
    marker=dict(
        symbol="circle",
        size=4,
        opacity=0.8,
        line=dict(
            color="black",
            # width=1,
            width=0.,  # zero width: no visible marker outline
            ),
        # Marker colours are mapped through this colorscale (from
        # get_color_scale).
        colorscale=colorscale_i,

        # Colour bar for the coordination-number colouring. Properties set to
        # None are left at plotly's defaults.
        colorbar=dict(
            bordercolor="green",
            outlinecolor="black",
            tickcolor="black",
            xanchor="right",
            # x=1.091,
            # x=1.1,
            x=1.15,  # right edge of the bar, in plot fractions
            len=1.16,  # bar length, in plot fractions (see lenmode)
            lenmode="fraction",
            # #################################################################
            thickness=15,
            thicknessmode=None,
            tickprefix=None,
            ticks="outside",
            # #################################################################
            tickvals = [2, 4, 6, 8, 10, 12],  # coordination numbers to label
            y=0.50005,
            yanchor="middle",  # bar is vertically centred on y
            ypad=10,
            borderwidth=None,

            title=go.scatter.marker.colorbar.Title(
                font=None,
                side="right",  # ['right', 'top', 'bottom']
                text="Ir-O Coord. Num.",
                ),

            # titlefont=None,
            # titleside=None,

            ),

        ),
    )

## Create AB2/3 traces

In [None]:
%%capture

df_i = df[df.stoich == "AB2"]
trace_ab2 = go.Scatter(
    x=df_i.volume_peratom,
    y=df_i.dH,
    text=[str(i) for i in df_i.mean_coor.tolist()],
    marker=dict(color=df_i.mean_coor, size=3))
trace_ab2.update(**scatter_shared.to_plotly_json())



# #############################################################################
df_i = df[df.stoich == "AB3"]
trace_ab3 = go.Scatter(
    x=df_i.volume_peratom,
    y=df_i.dH,
    text=[str(i) for i in df_i.mean_coor.tolist()],
    marker=dict(color=df_i.mean_coor))
trace_ab3.update(**scatter_shared.to_plotly_json())

# Shapes

In [None]:
from shapes import get_plot_shapes

# x-ranges of the AB2 (inset 0) and AB3 (inset 1) zoomed-in insets.
inset_range_0_x = [9.5, 17.]
inset_range_1_x = [9.5, 17.5]

# Project helper; provides the figure's shape list and the inset y-ranges.
out_dict = get_plot_shapes(
    df=df,
    inset_range_0_x=inset_range_0_x,
    inset_range_1_x=inset_range_1_x,
    )
shapes_list = out_dict["shapes_list"]
inset_range_0_y = out_dict["inset_range_0_y"]
inset_range_1_y = out_dict["inset_range_1_y"]

# Grey horizontal metastability-limit lines spanning x = 0..40 on the main
# AB2 (x1/y1) and AB3 (x2/y2) panels. The hard-coded y values correspond to
# ab2_min_e + metastability_limit (and the AB3 equivalent).
shape_inset_metastability_ab2, shape_inset_metastability_ab3 = (
    go.layout.Shape(
        type="line",
        x0=0,
        x1=40,
        y0=y_limit,
        y1=y_limit,
        xref=x_ref,
        yref=y_ref,
        line=dict(color="grey", width=1),
        )
    for y_limit, x_ref, y_ref in (
        (-0.33285956787756277, "x1", "y1"),
        (-0.3438547784081729, "x2", "y2"),
        )
    )

shapes_list += [shape_inset_metastability_ab2, shape_inset_metastability_ab3]

## Create subplot

In [None]:
%%capture

# Two side-by-side panels (AB2 left, AB3 right), each with a zoomed-in inset.
# The inset padding l=0.5, b=0.5 is a fraction of the cell, so each inset
# starts at its cell's centre and (with the default width/height) extends to
# the cell's top-right corner.
inset_attr = dict(l=0.5, b=0.5)
fig = make_subplots(
    rows=1, cols=2,
    shared_xaxes=True,
    shared_yaxes=True,
    specs = [[{}, {}]],
    insets=[
        {
            "cell": (1,1),
            **inset_attr,
            },

        {
            "cell": (1,2),
            **inset_attr,
            },
        ],
    horizontal_spacing=0.04)


# #############################################################################
# Add traces ##################################################################
fig.add_trace(trace_ab2, row=1, col=1)
fig.add_trace(trace_ab3, row=1, col=2)

# Inset axes come after the main ones: x3/y3 is the AB2 inset and x4/y4 the
# AB3 inset. `update` modifies a trace in place, so deep copies are used to
# leave trace_ab2/trace_ab3 themselves unchanged.
fig.add_trace(copy.deepcopy(trace_ab2).update(xaxis="x3", yaxis="y3"))
fig.add_trace(copy.deepcopy(trace_ab3).update(xaxis="x4", yaxis="y4"))

In [None]:
# #############################################################################
# Update Layout ###############################################################
# Apply the shared project layout, then its axis styling to every x/y axis
# (main panels and insets alike).
fig.update_layout(layout)
fig.update_xaxes(layout["xaxis"])
fig.update_yaxes(layout["yaxis"])

# The right panel shares its y-axis with the left one; drop its title.
fig.layout.yaxis2.title = None

# Modifying inset props
# Insets carry no axis titles.
fig.layout.xaxis3.title = None
fig.layout.yaxis3.title = None

fig.layout.xaxis4.title = None
fig.layout.yaxis4.title = None



# Zoom windows: x-ranges set in the shapes cell, y-ranges from get_plot_shapes.
fig.layout.xaxis3.range = inset_range_0_x
fig.layout.yaxis3.range = inset_range_0_y


fig.layout.xaxis4.range = inset_range_1_x
fig.layout.yaxis4.range = inset_range_1_y


# Smaller inset tick labels; 6 * (4 / 3) presumably expresses 6 pt in px.
# fig.layout.xaxis3
fig.layout.xaxis3.tickfont.size = 6 * (4 / 3)
fig.layout.yaxis3.tickfont.size = 6 * (4 / 3)

fig.layout.xaxis4.tickfont.size = 6 * (4 / 3)
fig.layout.yaxis4.tickfont.size = 6 * (4 / 3)

# Tick spacing: main panels every 5 (volume) / 0.5 (dH), insets every 2.
fig.layout.xaxis.dtick = 5
fig.layout.yaxis.dtick = 0.5

fig.layout.xaxis2.dtick = 5
fig.layout.yaxis2.dtick = 0.5

fig.layout.xaxis3.dtick = 2

fig.layout.xaxis4.dtick = 2

fig.layout.xaxis3.ticklen = 3
fig.layout.xaxis4.ticklen = 3

# COMBAK (author's "come back" marker)
# Inset y ticks are listed explicitly rather than derived from dtick.
fig.layout.yaxis3.tickmode = "array"
fig.layout.yaxis3.tickvals = [-0.8, -0.7, -0.6, -0.5, -0.4]
fig.layout.yaxis3.ticklen = 3

fig.layout.yaxis4.tickmode = "array"
fig.layout.yaxis4.tickvals = [-0.65, -0.6, -0.55, -0.5]#[-0.7, -0.6, -0.5, -0.4, -0.3]
fig.layout.yaxis4.ticklen = 3

In [None]:
%%capture

#fig.layout.update(dict(
annotations=[
        go.layout.Annotation(
            x=9.4,
            y=1.76,
            xref="x",
            yref="y",
            text="IrO<sub>2</sub>",
            showarrow=False,

            bgcolor="rgba(255,255,255,0.7)",
            font=go.layout.annotation.Font(
                color="black",
                family=None,
                size=10 * (4/3),                
                ),

            ax=0,
            ay=0,
            ),


        go.layout.Annotation(
            x=9.4,
            y=1.76,
            xref="x2",
            yref="y2",
            text="IrO<sub>3</sub>",
            showarrow=False,

            bgcolor="rgba(255,255,255,0.7)",
            font=go.layout.annotation.Font(
                color="black",
                family=None,
                size=10 * (4/3),                
                ),

            ax=0,
            ay=0,
            ),


]
#    ))



#annotations = []
for key, val in structure_id_map.items():
    #print(key, val)
    #df_i = df[df.stoich == "AB2"]
    #print(df_i.values)
    try:
        df_i = df.loc[key]
    except:
        print(key, 'not found')
        continue

    y = df_i.dH
    x = df_i.volume_peratom

    if df_i.stoich == 'AB2':
        if y < -0.55:
            sub_x = 'x3'
            sub_y = 'y3'
        else:
            sub_x = 'x'
            sub_y = 'y'
    elif df_i.stoich == 'AB3':
        if y < -0.4 and x < 16.5:
            sub_x = 'x4'
            sub_y = 'y4'
        else:
            sub_x = 'x2'
            sub_y = 'y2'
            sx = 0.5
            sy = 0.5
    arrowshifty = 0
    if len(val) > 8:
        arrowshift = len(val) * 2.5
    elif len(val) > 4:
        arrowshift = len(val) * 3
    elif '(' in val:
        arrowshift = 15
    else:
        arrowshift = 10
    if val =='(2)':
        arrowshift *= -1
    if val == 'iii (pyrite)':
        arrowshifty = -4

    annot_i = go.layout.Annotation(
        x=x,
        y=y,

        # COMBA
        
        xref=sub_x,
        yref=sub_y,
        #align='left',
        text=val,
        showarrow=True,
        arrowhead=1,

        #bgcolor="rgba(255,255,255,0.7)",
        font=go.layout.annotation.Font(
            color="black",
            family=None,
            size=6 * (4/3),
            ),

        ax=arrowshift,
        ay=arrowshifty,
        )

    annotations.append(annot_i)


#print(annotations)

fig.layout.update(annotations=annotations)

## Write/display plot

In [None]:
# Attach every shape (from get_plot_shapes plus the metastability lines) to
# the figure; add_shape appends to any shapes already in the layout.
for shape_i in shapes_list:
    fig.add_shape(shape_i)

#plio.write_html(fig, 'out_plot/E_vs_V_plot_2.html')

#fig.show()

In [None]:
# Export the figure as PDF (static export needs plotly's image engine, e.g.
# kaleido). Flip the `if True:` toggle to skip writing. The output directory
# is created first because write_image does not create missing directories.
if True:
    os.makedirs("out_plot", exist_ok=True)
    fig.write_image('out_plot/E_vs_V_plot_2.pdf')

In [None]:
# Render the finished figure inline.
fig.show()

In [None]:
# Halts "Run All" here; the cells below (larger variant, histogram, structure
# inspection) presumably are meant to be run by hand.
assert False

# Larger plot

In [None]:
%%capture

# Enlarged variant of the figure: reset the layout margins to plotly's
# defaults, resize to 800 x 500 px and enlarge every marker.
# fig.layout.annotations[0]["font"]["size"] = 12 * (4 / 3)
# fig.layout.annotations[1]["font"]["size"] = 12 * (4 / 3)
fig.layout.margin = None

fig.update_layout(
    width=800, height=500, 
    )

fig.update_traces(marker_size=8)

In [None]:
# Halts "Run All" again so the histogram and structure-inspection cells below
# only run when executed individually.
assert False

# Histogram Plot

In [None]:
# Histogram of mean Ir-O coordination, one colour per stoichiometry and
# normalised to percent, with a rug of the individual values as marginal.
# Note: this rebinds `fig`, replacing the E vs V figure object built above.
import plotly.express as px

fig = px.histogram(
    df,
    x="mean_coor",
    color="stoich",
    marginal="rug",  # can be `box`, `violin`
    opacity=0.9,
    nbins=100,
    # barnorm="fraction",
    histnorm="percent",
    # hover_data=tips.columns,
    )

#fig.show()

In [None]:
# COMBAK

In [None]:
from ase_modules.ase_methods import view_in_vesta


# #############################################################################
# AB2 #########################################################################
# Collect the AB2 structures whose dH is less than 0.2 above the lowest AB2
# dH, most stable first, e.g. for viewing with view_in_vesta.
df_i = df[df.stoich == "AB2"]
# ab2_min_e is not defined anywhere else in this notebook, so derive it from
# the data here. NOTE(review): confirm that the minimum over the plotted AB2
# rows is the intended reference energy.
ab2_min_e = df_i.dH.min()
df_tmp = df_i[df_i.dH < ab2_min_e + 0.2].sort_values("dH")
atoms_list = df_tmp.atoms.tolist()

print("Number of metastable AB2: ", len(atoms_list))

# # #############################################################################
# # AB3 #########################################################################
# df_i = df[df.stoich == "AB3"]
# ab3_min_e = df_i.dH.min()
# df_tmp = df_i[df_i.dH < ab3_min_e + 0.2].sort_values("dH")
# atoms_list = df_tmp.atoms.tolist()

# print("Number of metastable AB3: ", len(atoms_list))

# view_in_vesta(atoms_list)