In [None]:
import ast
import cmasher as cmr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from opportunistic_planning import processing, visualization
from scipy.stats import friedmanchisquare, wilcoxon

In [None]:
# Load the unique task sequences; sequences are not read as strings
# (use_string_for_seq=False is handed straight to processing.read_data).
df = processing.read_data(
    'all_task_environments_list_2022-09-12_unique_sequences.csv',
    use_string_for_seq=False,
)

In [None]:
# Keep only the first 186 entries. The two baseline result lists are cut to the
# same length in later cells, presumably so all three line up one-to-one.
df = df[:186]
len(df)  # cell output: number of entries kept

In [None]:
# Read the stored "NN spatial info" baseline results. The file holds a Python
# literal, which ast.literal_eval parses without executing arbitrary code.
nn_results_path = '../../model_evaluation/neural_net/results/nn_spatialinfo_prequential_summed_fitted_on_ts_results_all_2022-06-07.txt'
with open(nn_results_path, 'r') as results_file:
    raw_text = results_file.read()
    nn_results = ast.literal_eval(raw_text)

In [None]:
# Truncate to the first 186 entries, the same cut-off applied to `df`; the
# values are later plotted against and tested pairwise with the model errors.
nn_results = nn_results[:186]
len(nn_results)  # cell output: number of entries kept

In [None]:
# Read the stored PyTorch RNN baseline results. The file holds a Python
# literal, which ast.literal_eval parses without executing arbitrary code.
rnn_results_path = '../../model_evaluation/neural_net/results/pytorch_rnn_prequential_summed_fitted_on_ts_all_2022-06-29.txt'
with open(rnn_results_path, 'r') as results_file:
    raw_text = results_file.read()
    rnn_results = ast.literal_eval(raw_text)

In [None]:
# Truncate to the first 186 entries, the same cut-off applied to `df`; the
# values are later plotted against and tested pairwise with the model errors.
rnn_results = rnn_results[:186]
len(rnn_results)  # cell output: number of entries kept

In [None]:
# Build the distances lookup from the loaded sequences; it is passed to the
# prediction-error computation in the next cell. %time prints the run time.
%time distances_dict = processing.generate_distances_dict(df)

In [None]:
# Compute prediction errors with the 'prequential' error function; %time prints the run time.
# NOTE(review): the meaning of n=100 and of the dimensions spec [[2, 'xy']] is
# defined inside processing.calculate_prediction_error — confirm there.
%time results_sum = processing.calculate_prediction_error(df, distances_dict, error_function='prequential', n=100, dimensions=[[2, 'xy']])

In [None]:
# Unpack the four values returned by get_lowest_error. `lowest_mean_idx` is used
# below to index into results_sum / results_median, `lowest_median` appears in
# the plot legend, and `results_median` is exported to CSV.
lowest_mean, lowest_mean_idx, lowest_median, results_median = processing.get_lowest_error(results_sum)

In [None]:
#processing.save_results(results_median, 'results/results_mean_2D_n100_2022-04-19.csv')

In [None]:
# Display the index/indices reported for the lowest mean error
# (presumably the best parameter combination — verify in get_lowest_error).
lowest_mean_idx

In [None]:
# Display the results_median entries selected by lowest_mean_idx.
results_median[lowest_mean_idx]

In [None]:
# Persist the median results table: header row written, index column omitted.
results_csv_path = 'results/results_tablesetting_2D_n100_fitted_to_ts_2023-06-22.csv'
results_median.to_csv(results_csv_path, index=False, header=True)

### Plot best model with baselines

In [None]:
#%matplotlib inline
%matplotlib qt

IDs = df['ID']

error = df['error']

#seqs= results_sum['sequence'][:-2].values

res = results_sum[lowest_mean_idx[0]][:-2].values
median = [np.nanmedian(res)] * len(res)

x = [x for x in range (0,len(res))]

# plot CPT baseline
#plt.scatter(x, cpt, marker='o', s=20, c='dodgerblue', alpha=0.3, 
#            label=str('CPT baseline median: ') + str(round(np.median(cpt),3)))
#plt.plot(x, [np.median(cpt)] * len(x), '-', c='dodgerblue', alpha=0.9, linewidth=2)
#plt.plot((x_original,x_original),(res_original,distances), '--', c='darkgreen', alpha=0.6)
#plt.fill_between(x, cpt, alpha=0.3, color='dodgerblue')

# plot RNN text prediction baseline
plt.scatter(x, rnn_results, marker='v', s=20, c='darkgreen', alpha=0.7, 
            label=str('NN text prediction (median: ') + str(round(np.median(rnn_results),3)) + ')')
plt.plot(x, [np.median(rnn_results)] * len(x), '-', c='darkgreen', alpha=0.9, linewidth=2)
#plt.plot((x_original,x_original),(res_original,distances), '--', c='darkgreen', alpha=0.6)
plt.fill_between(x, rnn_results, [14 for x in range(0,len(res))], alpha=0.2, 
                 color='limegreen')


# plot NN baseline
plt.scatter(x, nn_results, marker='D', s=20, c='navy', alpha=0.8, 
            label=str('NN spatial info (median: ') + str(round(np.median(nn_results),3)) + ')')
plt.plot(x, [np.median(nn_results)] * len(x), '-', c='dodgerblue', alpha=0.9, linewidth=2)
#plt.plot((x_original,x_original),(res_original,distances), '--', c='darkgreen', alpha=0.6)
plt.fill_between(x, nn_results, alpha=0.3, color='dodgerblue')


# plot scatter + lines for simulations
plt.scatter(x, res, marker='o', s=26, c='darkred', alpha=0.8, 
            label=str('model (median: ') + str(round(lowest_median,3)) + ')')
#plt.plot(x, res, c='blue', alpha=0.6)
plt.plot(x, median, c='darkred', alpha=0.95, linewidth=2)
#plt.fill_between(x, res, alpha=0.3, color='darkviolet')


plt.xticks(x, labels=IDs, rotation=90, fontsize=5)
#plt.xticklabels(IDs, rotation=90, fontsize=6)

plt.ylabel('accumulated prediction error', fontsize=22)
plt.xlabel('sequence', fontsize=22)
#plt.ylim(0.0, 0.51)
plt.title('best model tablesetting (parameters: k: 0.2,0.3,1.7; c: 1.8)', fontsize=24, pad=20)
plt.margins(0.01)

lgnd = plt.legend(fontsize=20, framealpha=0.8, loc='upper right', markerscale=2.5)
lgnd.legendHandles[0]._sizes = [70]
lgnd.legendHandles[1]._sizes = [70]
lgnd.legendHandles[2]._sizes = [70]

#plt.savefig('plot_median_editdist_individualerrors_diff.png', bbox_inches='tight')
plt.show()

### Statistical analysis

In [None]:
# model vs NN
# Wilcoxon signed-rank test on the paired per-sequence errors of the
# NN spatial-info baseline and the model.
stat, p = wilcoxon(x=nn_results, y=res, zero_method='wilcox')
summary_line = 'Wilcoxon: W = %.3f, p = %.5f' % (stat, p)
print(summary_line)

In [None]:
# model vs RNN
# Wilcoxon signed-rank test on the paired per-sequence errors of the model
# and the RNN baseline.
stat, p = wilcoxon(x=res, y=rnn_results, zero_method='wilcox')
summary_line = 'Wilcoxon: W = %.3f, p = %.5f' % (stat, p)
print(summary_line)