# TEMP
---

# Import Modules

In [None]:
import os
import sys

import json
import time

import copy
import pickle

import numpy as np
import pandas as pd

import plotly.graph_objects as go
import chart_studio.plotly as py
from chart_studio.grid_objs import Grid, Column

# #############################################################################
from plotting.my_plotly import my_plotly_plot

sys.path.insert(0, os.path.join(os.environ["PROJ_irox"], "data"))
from proj_data_irox import (
    bulk_dft_data_path,
    ids_to_discard__too_many_atoms_path,
    unique_ids_path,
    df_dij_path)

print(os.getcwd())

In [None]:
from plotting.my_plotly import my_plotly_plot


In [None]:
sys.path.insert(0, os.path.join(
    os.environ["PROJ_irox"],
    "python_classes/active_learning"))
from active_learning import (
    ALBulkOpt,
    ALGeneration,
    RegressionModel,
    FingerPrints,
    CandidateSpace,
    )

# Methods

In [None]:
def process_parity_plot(al_data_dict=None, df_bulk_dft=None, df_features_pre=None, df_features_post=None):
    """Return DFT and model-predicted energies for a parity plot.

    The last entry of `al_data_dict` supplies the model (key "model_inst")
    and the ids to evaluate (key "computed_ids"). The three dataframes are
    restricted to the rows whose index appears in those ids, fingerprints
    are rebuilt from the restricted feature frames (`clean_data`, then
    `pca_analysis` with pca_mode="num_comp" and pca_comp=11), and the model
    predicts on `FP.df_test`.

    Args:
        al_data_dict: Mapping whose last value holds "model_inst" and
            "computed_ids".
        df_bulk_dft: DataFrame providing the "energy_pa" column.
        df_features_pre: Feature frame handed to `FingerPrints`.
        df_features_post: Feature frame handed to `FingerPrints`.

    Returns:
        Tuple `(dft_energy, pred_unstd)`: the "energy_pa" values of the
        restricted, index-sorted `df_bulk_dft` as a numpy array, and the
        model's "prediction" output multiplied by the standard deviation of
        those energies and shifted by their mean.
    """
    final_key = list(al_data_dict.keys())[-1]
    final_entry = al_data_dict[final_key]

    model = final_entry["model_inst"]
    ids = final_entry["computed_ids"]

    def _only_computed(frame):
        # Keep the rows of `frame` whose index label is among the computed ids.
        return frame.loc[frame.index.intersection(ids)]

    bulk_computed = _only_computed(df_bulk_dft)
    pre_computed = _only_computed(df_features_pre)
    post_computed = _only_computed(df_features_post)

    FP = FingerPrints(pre_computed, df_features_post=post_computed)

    # NOTE(review): the positional order here is (post, pre), the reverse of
    # the constructor call above -- confirm against FingerPrints.clean_data.
    FP.clean_data(
        post_computed,
        pre_computed,
        clean_variance_flag=True,
        clean_skewness_flag=True,
        clean_infinite_flag=True,
        standardize_data_flag=True)

    FP.pca_analysis(pca_mode="num_comp", pca_comp=11, pca_perc=None)

    # Both the fingerprints and the DFT energies are sorted by index before
    # use; presumably this keeps the two arrays row-aligned -- TODO confirm
    # that both frames share the same index after cleaning.
    test_fingerprints = FP.df_test.sort_index()
    pred = model.predict(test_fp=test_fingerprints, uncertainty=True)

    dft_energy = np.array(bulk_computed.sort_index()["energy_pa"].tolist())

    # Rescale the predictions with the spread and mean of the DFT energies
    # (presumably undoing a standardization of the model's targets).
    energy_std = np.std(dft_energy)
    energy_mean = np.mean(dft_energy)
    pred_unstd = (pred["prediction"] * energy_std) + energy_mean

    return dft_energy, pred_unstd

# Read AL Data

In [None]:
sys.path.insert(0, os.path.join(
    os.environ["PROJ_irox"],
    "workflow/ml_modelling"))
from ml_methods import get_data_for_al


# #############################################################################
# Directory, under the PROJ_irox project root, from which the
# "data_dict_<stoich>_<name>.pickle" active-learning outputs are read.
_al_out_rel_parts = (
    "workflow/ml_modelling",
    "00_ml_workflow/190611_new_workflow",
    "02_gaus_proc/out_data",
)
al_output_data_path_root = os.path.join(
    os.environ["PROJ_irox"], *_al_out_rel_parts)
# with open(al_output_data_path, "rb") as fle:
#     al_data_dict = pickle.load(fle)

In [None]:
# Load the inputs needed to build the AB2 parity data.
stoich_i = "AB2"
custom_name = "regular"

out_dict = get_data_for_al(stoich=stoich_i, verbose=False)

# Bulk DFT data: keep only rows whose "source" is "raul", and only the two
# columns used downstream.
df_bulk_dft = out_dict["df_bulk_dft"]
_from_raul = df_bulk_dft["source"] == "raul"
df_bulk_dft = df_bulk_dft[_from_raul][["atoms", "energy_pa"]]

df_features_pre = out_dict["df_features_pre"]
df_features_post = out_dict["df_features_post"]

# Active-learning output for this stoichiometry / run name.
al_output_data_path = os.path.join(
    al_output_data_path_root,
    f"data_dict_{stoich_i}_{custom_name}.pickle")

# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open(al_output_data_path, "rb") as fle:
    al_data_dict = pickle.load(fle)

In [None]:
%%capture
# Parity data for AB2, computed from the al_data_dict / df_* globals
# currently in scope (expected to hold the AB2 inputs at this point).
_parity_ab2 = process_parity_plot(
    al_data_dict=al_data_dict,
    df_bulk_dft=df_bulk_dft,
    df_features_pre=df_features_pre,
    df_features_post=df_features_post,
)
dft_energy_ab2, pred_unstd_ab2 = _parity_ab2

In [None]:
# Load the inputs needed to build the AB3 parity data. This rebinds the same
# global names used for the AB2 inputs.
stoich_i = "AB3"
custom_name = "regular"

out_dict = get_data_for_al(stoich=stoich_i, verbose=False)

# Bulk DFT data: keep only rows whose "source" is "raul", and only the two
# columns used downstream.
df_bulk_dft = out_dict["df_bulk_dft"]
_from_raul = df_bulk_dft["source"] == "raul"
df_bulk_dft = df_bulk_dft[_from_raul][["atoms", "energy_pa"]]

df_features_pre = out_dict["df_features_pre"]
df_features_post = out_dict["df_features_post"]

# Active-learning output for this stoichiometry / run name.
al_output_data_path = os.path.join(
    al_output_data_path_root,
    f"data_dict_{stoich_i}_{custom_name}.pickle")

# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open(al_output_data_path, "rb") as fle:
    al_data_dict = pickle.load(fle)

In [None]:
# Parity data for AB3, computed from the al_data_dict / df_* globals
# currently in scope (expected to hold the AB3 inputs at this point).
_parity_ab3 = process_parity_plot(
    al_data_dict=al_data_dict,
    df_bulk_dft=df_bulk_dft,
    df_features_pre=df_features_pre,
    df_features_post=df_features_post,
)
dft_energy_ab3, pred_unstd_ab3 = _parity_ab3

In [None]:
# Parity plot: DFT energies on x, model predictions on y, one marker trace
# per stoichiometry plus an x=y reference line.

# AB2 ##########################################################################
x_array_ab2, y_array_ab2 = dft_energy_ab2, pred_unstd_ab2
trace_ab2 = go.Scatter(x=x_array_ab2, y=y_array_ab2, mode="markers")

# AB3 ##########################################################################
x_array_ab3, y_array_ab3 = dft_energy_ab3, pred_unstd_ab3
trace_ab3 = go.Scatter(x=x_array_ab3, y=y_array_ab3, mode="markers")

# x=y line #####################################################################
# The line spans the smallest to the largest value found in any of the four
# arrays, so it covers the full extent of both scatter traces.
_parity_arrays = (y_array_ab2, x_array_ab2, y_array_ab3, x_array_ab3)
plot_range = [
    min(arr.min() for arr in _parity_arrays),
    max(arr.max() for arr in _parity_arrays),
]
trace_xy = go.Scatter(x=plot_range, y=plot_range, mode="lines")

# Figure layout comes from the local `layout` module.
from layout import layout

# The reference line is listed first, ahead of the two marker traces.
data = [trace_xy, trace_ab2, trace_ab3]
fig = go.Figure(data=data, layout=layout)

my_plotly_plot(
    figure=fig,
    plot_name="irox_parity_plot",
    write_pdf_svg=True,
    write_html=True,
    write_png=False,
)

fig.show()