In [None]:
from definitions import ROOT_DIR
import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import seaborn as sns
import matplotlib.lines as lines

### Load prior data from disk

In [None]:
# Prior files to compare; each loaded file becomes one row of the figure below.
# Commented-out entries are alternative runs that can be toggled on.
paths_to_dict = [

    'priors/test_model__neural_net__data_grammar_8_run_1__Bi_LSTM_Measurement_Encoder__Transformer_Encoder_String250__Endgame__.json',
    #'priors/test_model__neural_net__data_grammar_8_run_1__Bi_LSTM_Measurement_Encoder__Transformer_Encoder_String125__Endgame__.json',
    #'priors/test_model__neural_net__data_grammar_8_run_1__Bi_LSTM_Measurement_Encoder__Transformer_Encoder_String50__Endgame__.json',
    'priors/test_model__neural_net__data_grammar_8_run_1__Bi_LSTM_Measurement_Encoder__Transformer_Encoder_String5__Endgame__.json',
    'priors/test_model__neural_net__data_grammar_8_run_1__Bi_LSTM_Measurement_Encoder__Transformer_Encoder_Stringsupervised__Endgame__.json',
    'priors/test_model__neural_net__data_grammar_8_run_1__MeasurementEncoderDummy__Transformer_Encoder_Stringsupervised__Endgame__.json',
    'priors/test_model__neural_net__data_grammar_8_run_1__Bi_LSTM_Measurement_Encoder__EquationEncoderDummysupervised__Endgame__.json',
    'priors/test_model__uniform.json',
    #'priors/test_model_token__neural_net__data_grammar_8_run_1__DatasetTransformer__Transformer_Encoder_Stringsupervised__Endgame__.json'
]
# Row (y-axis) label for each prior file, keyed by the file's stem (name without '.json').
y_axis_label = {
    'test_model__uniform': 'normalized',
    'test_model__neural_net__data_grammar_8_run_1__Bi_LSTM_Measurement_Encoder__Transformer_Encoder_String250__Endgame__': 'MCTS 250',
    'test_model__neural_net__data_grammar_8_run_1__Bi_LSTM_Measurement_Encoder__Transformer_Encoder_String125__Endgame__': 'MCTS 125',
    'test_model__neural_net__data_grammar_8_run_1__Bi_LSTM_Measurement_Encoder__Transformer_Encoder_String50__Endgame__': 'MCTS 50',
    'test_model__neural_net__data_grammar_8_run_1__Bi_LSTM_Measurement_Encoder__Transformer_Encoder_String5__Endgame__': 'MCTS 5',
    'test_model__neural_net__data_grammar_8_run_1__Bi_LSTM_Measurement_Encoder__Transformer_Encoder_Stringsupervised__Endgame__': 'Complete',
    'test_model__neural_net__data_grammar_8_run_1__MeasurementEncoderDummy__Transformer_Encoder_Stringsupervised__Endgame__': 'No Dataset',
    'test_model__neural_net__data_grammar_8_run_1__Bi_LSTM_Measurement_Encoder__EquationEncoderDummysupervised__Endgame__': 'No Tree',
    # NOTE(review): this key does not match the stem of the commented-out token path in
    # `paths_to_dict` ('...__DatasetTransformer__...'); align them before enabling both.
    #'test_model_token__neural_net__data_grammar_8_run_1__Bi_LSTM_Measurement__Transformer_Encoder_Stringsupervised__Endgame__': 'Token_Supervised NPT',
}

# Load every prior file, keyed by its file stem so it can be matched to `y_axis_label`.
priors_dict = {}
for path in paths_to_dict:
    # Path(...) makes the join work whether ROOT_DIR is a str or a pathlib.Path;
    # JSON is read as UTF-8 regardless of the platform's default encoding.
    with open(Path(ROOT_DIR) / path, 'r', encoding='utf-8') as file:
        loaded_data = json.load(file)
        priors_dict[Path(path).stem] = loaded_data

# Fail here, not halfway through plotting, if a loaded file has no row label.
missing_labels = [stem for stem in priors_dict if stem not in y_axis_label]
assert not missing_labels, f'No y_axis_label entry for: {missing_labels}'

### Visualize predicted priors per action

In [None]:
def forward(value):
    """Piecewise-linear y-axis transform that stretches the low-probability range.

    Data values in [0, 0.1] are mapped onto [0, 0.5] (slope 5); values above 0.1
    are mapped onto (0.5, ...] with slope 5/9, so 1.0 lands at 1.0. Evaluated
    element-wise via `np.where`, so scalars and numpy arrays are both accepted.
    """
    in_low_range = value <= 0.1
    return np.where(in_low_range, value * 5, 0.5 + (value - 0.1) * 5 / 9)


def inverse(value):
    """Inverse of the piecewise-linear `forward` axis transform.

    Maps display values in [0, 0.5] back to [0, 0.1] (divide by 5) and values
    above 0.5 back to (0.1, ...] via (y - 0.5) * 9/5 + 0.1, so that
    inverse(forward(x)) == x. Evaluated element-wise via `np.where`.
    """
    # The offset must be +0.1 (it undoes the `- 0.1` in `forward`); the previous
    # `- 0.1` mapped e.g. 1.0 back to 0.8.
    return np.where(value <= 0.5, value / 5, (value - 0.5) * 9 / 5 + 0.1)

# For every loaded prior file (one row each) and every plotted action (one column
# each), draw mean ± std of the prior that action receives, for each equation in
# `equations`. Actions 21-25 are skipped, so the 23 columns show actions 0-20
# followed by actions 26 and 27.
N_PLOTTED_ACTIONS = 23
SKIPPED_ACTIONS = range(21, 26)  # actions 21..25 are left out of the figure
plotted_actions = [
    action
    for action in range(N_PLOTTED_ACTIONS + len(SKIPPED_ACTIONS))
    if action not in SKIPPED_ACTIONS
]
# Marker / colour per equation, index-aligned with `equations`.
MARKERS = ['.', 'o', 'v', 'H']
COLORS = ['black', 'dimgrey', 'gray', 'darkgrey']

# Equation key in the prior JSON -> LaTeX legend label. Raw strings so that
# `\sin` is not parsed as an (invalid) escape sequence.
equations = {
    " + c  sin x_0 + c  sin Variable": r'$ c + \sin(x_0)$',
    " + c  sin x_1 + c  sin Variable": r'$ c + \sin(x_1)$',
}
assert len(equations) <= len(MARKERS), 'add more MARKERS/COLORS for extra equations'


def _prior_stats(architecture_priors, equation, action):
    """Return (mean, std) of entry `action` over all prior vectors stored for `equation`.

    Assumes `architecture_priors[equation]` maps arbitrary keys to indexable prior
    vectors, as loaded from the prior JSON files (not validated here).
    """
    values = [prior[action] for prior in architecture_priors[equation].values()]
    return np.mean(values), np.std(values)


n_rows = len(priors_dict)
# squeeze=False keeps `axs` 2-D even when only a single prior file is loaded.
fig, axs = plt.subplots(n_rows, len(plotted_actions), figsize=(8, n_rows),
                        sharey=True, sharex=True, squeeze=False)
for row, architecture in enumerate(priors_dict):
    for column, action in enumerate(plotted_actions):
        ax = axs[row, column]
        ax.set_yscale('function', functions=(forward, inverse))
        for i, (equation, label) in enumerate(equations.items()):
            prior_mean, prior_std = _prior_stats(priors_dict[architecture], equation, action)
            ax.errorbar(x=i,
                        y=prior_mean,
                        yerr=prior_std,
                        fmt=MARKERS[i],
                        ls='none',
                        color=COLORS[i],
                        label=label)

        # Hide x tick labels (the action id is shown as the x label instead). This
        # avoids the FixedFormatter warning from set_xticklabels without set_xticks.
        ax.tick_params(axis='x', labelbottom=False)
        ax.set_xlabel(action)
        if column == 0:
            # Fall back to the raw file stem if no readable label is configured.
            ax.set_ylabel(y_axis_label.get(architecture, architecture))
        ax.grid(visible=False, axis='x', which='both')
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.set_ylim((-0.01, 1.15))
        ax.set_yticks([0, 0.1, 1])


fig.text(0.5, 0, 'Actions')
# NOTE(review): the separator lines and group labels below use hard-coded figure
# coordinates that assume the current set and order of rows in `paths_to_dict`;
# re-tune them whenever rows are added, removed, or reordered.
fig.add_artist(lines.Line2D([-0.02, 0.05], [0.68, 0.68], color='k'))
fig.add_artist(lines.Line2D([-0.02, 0.05], [0.2, 0.2], color='k'))
fig.text(-0.02, 0.78, 'Prior', rotation=90)
fig.text(-0.02, 0.35, 'Prior Supervised ', rotation=90)
fig.text(-0.02, 0.1, 'Qsa', rotation=90)

# Every subplot carries the same equation labels, so the first one's handles suffice.
handles, labels = axs[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', ncol=5)
fig.tight_layout()
fig.savefig('priors_intermediate_node.pdf', bbox_inches='tight')
plt.show()
        
