In [None]:
from definitions import ROOT_DIR
import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import seaborn as sns
import matplotlib.lines as lines

### Load data from disk

In [None]:
# Each prior file (path relative to ROOT_DIR) paired with the row label it gets in the figure.
# Keeping path and label together stops the file list and the label dict from drifting apart
# (the label dict below is keyed by file stem, so both are derived from this single list).
_RUN_PREFIX = 'priors/test_model__neural_net__data_grammar_8_run_1__'
prior_files_with_labels = [
    (_RUN_PREFIX + 'Bi_LSTM_Measurement_Encoder__Transformer_Encoder_String250__Endgame__.json', '250'),
    (_RUN_PREFIX + 'Bi_LSTM_Measurement_Encoder__Transformer_Encoder_String125__Endgame__.json', '125'),
    (_RUN_PREFIX + 'Bi_LSTM_Measurement_Encoder__Transformer_Encoder_String50__Endgame__.json', '50'),
    # NOTE(review): the file name says 'String5' but the row label is '10' — confirm which is right.
    (_RUN_PREFIX + 'Bi_LSTM_Measurement_Encoder__Transformer_Encoder_String5__Endgame__.json', '10'),
    (_RUN_PREFIX + 'Bi_LSTM_Measurement_Encoder__Transformer_Encoder_Stringsupervised__Endgame__.json', 'Complete'),
    (_RUN_PREFIX + 'MeasurementEncoderDummy__Transformer_Encoder_Stringsupervised__Endgame__.json', ' No Dataset'),
    (_RUN_PREFIX + 'Bi_LSTM_Measurement_Encoder__EquationEncoderDummysupervised__Endgame__.json', 'No  Tree '),
    ('priors/test_model__uniform.json', 'uniform'),
    # Disabled token-level model; uncomment to add it as an extra row.
    # ('priors/test_model_token__neural_net__data_grammar_8_run_1__DatasetTransformer__Transformer_Encoder_Stringsupervised__Endgame__.json', 'Token_Supervised NPT'),
]

# File paths in plotting order, and the row label for each file stem.
paths_to_dict = [path for path, _label in prior_files_with_labels]
y_axis_label = {Path(path).stem: label for path, label in prior_files_with_labels}
# Duplicate stems would silently collapse rows in `priors_dict` while the figure still
# allocates one row per path.
assert len(y_axis_label) == len(paths_to_dict), 'prior files must have unique file stems'

# Load every prior file; keys are file stems so they line up with `y_axis_label`.
# Path(ROOT_DIR) accepts ROOT_DIR as either a str or a Path; JSON is read as UTF-8
# rather than the platform's locale encoding.
priors_dict = {}
for path in paths_to_dict:
    with open(Path(ROOT_DIR) / path, encoding='utf-8') as file:
        priors_dict[Path(path).stem] = json.load(file)

### Visualize priors per action

In [None]:
def forward(value):
    """Piecewise-linear y-axis transform that stretches the low end of [0, 1].

    Values in [0, 0.1] are mapped linearly onto [0, 0.5]; values above 0.1
    follow a second line through (0.1, 0.5) and (1, 1). Used as the forward
    half of a matplotlib 'function' scale.
    """
    stretched_low = value * 5
    compressed_high = (value - 0.1) * 5 / 9 + 0.5
    return np.where(value <= 0.1, stretched_low, compressed_high)


def inverse(value):
    """Inverse of the piecewise-linear `forward` y-axis transform.

    Axis coordinates in [0, 0.5] map back onto [0, 0.1]; coordinates above 0.5
    map back along the line through (0.5, 0.1) and (1, 1). matplotlib's
    'function' scale requires this to be the exact inverse of the forward map.
    """
    # Solving y = (x - 0.1) * 5 / 9 + 0.5 for x gives x = (y - 0.5) * 9 / 5 + 0.1,
    # so the offset is +0.1 (subtracting it would send 1 -> 0.8 instead of 1 -> 1).
    return np.where(value <= 0.5, value / 5, (value - 0.5) * 9 / 5 + 0.1)


# Columns of the grid: actions 0-20 followed by actions 26 and 27.
# NOTE(review): why actions 21-25 are left out is not visible here — confirm.
ACTIONS_TO_PLOT = list(range(21)) + [26, 27]

# Prefix-notation equation strings (keys into each loaded prior file) -> legend label.
# Keys must match the stored data byte-for-byte, including internal spacing.
equations = {
    ' / c x_1 ': '$ c / x_1  $',
    ' + c  sin x_1 ': '$c +  \\sin(x_1)$',
    # Raw string so `\cdot` is not parsed as an (invalid) escape sequence.
    # The key's last term is `* c x_1`, so the legend shows c·x_1 for it.
    ' +  * c  ** 2 x_0    +  * c  ** 3 x_0    * c x_1   ': r'$c \cdot x_0^3 + c \cdot x_0^2 + c \cdot x_1 $',
    ' + c  ** 2 x_0  ': '$c + x_0^2$',
}
# One marker / colour per equation, in the iteration order of `equations`.
EQUATION_MARKERS = ['.', 'o', 'v', 'H']
EQUATION_COLORS = ['black', 'dimgrey', 'gray', 'darkgrey']
assert len(equations) <= len(EQUATION_MARKERS) == len(EQUATION_COLORS)


def _prior_mean_std(architecture_priors, equation, action):
    """Return (mean, std) of the prior for `action` over every entry stored under `equation`.

    NOTE(review): assumes each value of `architecture_priors[equation]` is a sequence
    indexable by action id, as the indexing here implies — confirm against the JSON schema.
    """
    samples = [entry[action] for entry in architecture_priors[equation].values()]
    return np.mean(samples), np.std(samples)


def _style_panel(ax, action, n_equations):
    """Apply the shared look of every panel: frameless, no x tick labels, fixed limits and y ticks."""
    # Hide x tick labels without installing a FixedFormatter on an auto locator (avoids a UserWarning).
    ax.tick_params(axis='x', labelbottom=False)
    ax.set_xlabel(action)
    ax.grid(visible=False, axis='x', which='both')
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    ax.set_ylim((-0.01, 1.15))
    ax.set_xlim((-0.3, n_equations + 0.2))
    ax.set_yticks([0, 0.1, 1])


# One row per loaded prior file, one column per plotted action.
# squeeze=False keeps `axs` 2-D even for a single row or column.
fig, axs = plt.subplots(len(paths_to_dict), len(ACTIONS_TO_PLOT),
                        figsize=(8, len(paths_to_dict) + 1),
                        sharey=True, sharex=False, squeeze=False)

for row, architecture in enumerate(priors_dict):
    for column, action in enumerate(ACTIONS_TO_PLOT):
        ax = axs[row, column]
        # Stretch [0, 0.1] so small prior values remain distinguishable.
        ax.set_yscale('function', functions=(forward, inverse))
        for i, (equation, legend_label) in enumerate(equations.items()):
            prior_mean, prior_std = _prior_mean_std(priors_dict[architecture], equation, action)
            ax.errorbar(x=i,
                        y=prior_mean,
                        yerr=prior_std,
                        fmt=EQUATION_MARKERS[i],
                        ls='none',
                        color=EQUATION_COLORS[i],
                        label=legend_label)
        _style_panel(ax, action, len(equations))
        if column == 0:
            ax.set_ylabel(y_axis_label[architecture])

# Group separators and group titles in figure coordinates.
# NOTE(review): these positions are hand-tuned for the current number of rows.
fig.add_artist(lines.Line2D([-0.02, 0.05], [0.50, 0.50], color='k'))
fig.add_artist(lines.Line2D([-0.02, 0.05], [0.13, 0.13], color='k'))

fig.text(-0.02, 0.7, 'Prior MCTS', rotation=90)
fig.text(-0.02, 0.25, 'Prior Supervised ', rotation=90)
fig.text(-0.02, 0.07, 'Qsa', rotation=90)

fig.text(0.5, 0, 'Actions')
# Every panel plots the same equations, so the first panel's handles serve the whole figure.
handles, labels = axs[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', ncol=len(equations))
fig.tight_layout()
fig.subplots_adjust(wspace=0.06)
fig.savefig('priors_start_node.pdf', bbox_inches='tight', dpi=300)
plt.show()
        
