In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
# Global figure styling: seaborn "deep" palette, with all text typeset by LaTeX
# in a serif face (Computer Modern Roman, LaTeX's default font).
plt.style.use("seaborn-v0_8-deep")
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Computer Modern Roman"]

In [2]:
# Load the averaged QRNN benchmark results and keep only the configuration shown
# in the comparison figure ('Number Qubits' == 4, 'Sequence Length' == 4),
# split by ansatz (with / without reset).
QRNN_RESULTS_CSV = '/tikhome/tfellner/Projects/VQA_timeseries_benchmark/Results/qrnn_paper_averaged_ids.csv'
N_QUBITS = 4
SEQUENCE_LENGTH = 4

df_qrnn = pd.read_csv(QRNN_RESULTS_CSV)
# Shared configuration filter, built once instead of being repeated per ansatz.
config_mask = (df_qrnn['Number Qubits'] == N_QUBITS) & (df_qrnn['Sequence Length'] == SEQUENCE_LENGTH)
df_qrnn_reset = df_qrnn[config_mask & (df_qrnn['Ansatz'] == 'paper_reset')]
df_qrnn_no_reset = df_qrnn[config_mask & (df_qrnn['Ansatz'] == 'paper_no_reset')]
# Fail loudly (e.g. on a typo in an ansatz name) rather than silently plotting nothing.
assert not df_qrnn_reset.empty, "no 'paper_reset' rows for the selected configuration"
assert not df_qrnn_no_reset.empty, "no 'paper_no_reset' rows for the selected configuration"

In [3]:
def find_best(df, metric='MSE Validation Median',
              group_cols=('Prediction Step', 'Data', 'Sequence Length')):
    """Return, for each group, the row with the smallest ``metric`` value.

    Parameters
    ----------
    df : pandas.DataFrame
        Results table containing ``metric`` and every column in ``group_cols``.
    metric : str, optional
        Column to minimise within each group (default: ``'MSE Validation Median'``).
    group_cols : sequence of str, optional
        Columns that define the groups
        (default: ``('Prediction Step', 'Data', 'Sequence Length')``).

    Returns
    -------
    pandas.DataFrame
        One row of ``df`` per group that has at least one non-NaN ``metric``
        value, keeping the original index labels. Groups whose metric is
        entirely NaN are omitted instead of breaking ``idxmin``.
    """
    # Rows without a metric can never be the minimum; dropping them first also
    # prevents all-NaN groups from making ``idxmin`` return NaN or raise.
    scored = df.dropna(subset=[metric])
    best_idx = scored.groupby(list(group_cols), observed=True)[metric].idxmin()
    # ``.loc`` with a list-like indexer already returns a new frame, so no copy is needed.
    return df.loc[best_idx]

In [4]:
# For each ansatz, keep the row with the lowest 'MSE Validation Median' per
# (prediction step, dataset, sequence length) group.
best_qrnn_reset, best_qrnn_no_reset = map(find_best, (df_qrnn_reset, df_qrnn_no_reset))

In [5]:
# 'Prediction Step' values shown on the x-axis of each dataset panel.
prediction_steps_mackey = [1, 70, 140]
prediction_steps_henon = [1, 2, 4]
prediction_steps_lorenz = [1, 13, 25]

# Model label -> (best-configuration results frame, plot colour).
models = {
    'QRNN reset': (best_qrnn_reset, '#33AA88'),
    'QRNN no reset': (best_qrnn_no_reset, '#FFBF46'),
}
# Value of the 'Data' column -> (prediction steps to plot, panel title).
data = {
    'mackey_1000': (prediction_steps_mackey, 'Mackey-Glass'),
    'henon_1000': (prediction_steps_henon, 'Hénon'),
    'lorenz_1000': (prediction_steps_lorenz, 'Lorenz'),
}

In [None]:
# Figure: 'MSE Testing Median' (error bars: 'MSE Testing Mad', log y-axis) of the
# selected QRNN configurations, one panel per dataset in `data`, one colour per
# entry in `models`.
FIGURE_PATH = '/tikhome/tfellner/Projects/VQA_timeseries_benchmark/Plots/All_models/qrnn_comparison_new.pdf'
X_PADDING = 0.15  # fraction of the x-range added on each side of every panel

# squeeze=False keeps `axs` indexable even when `data` has a single entry.
fig, axs = plt.subplots(1, len(data), figsize=(4, 3), squeeze=False)
axs = axs.ravel()

for ax, (data_label, (prediction_steps, data_title)) in zip(axs, data.items()):
    for model, (df, color) in models.items():
        for prediction_step in prediction_steps:
            df_filtered = df[(df['Data'] == data_label) & (df['Prediction Step'] == prediction_step)]
            df_filtered = df_filtered.sort_values('Sequence Length')
            if df_filtered.empty:
                continue  # no result for this (dataset, step): leave the slot blank
            ax.errorbar(prediction_step, df_filtered['MSE Testing Median'],
                        yerr=df_filtered['MSE Testing Mad'], marker='d', label=model,
                        color=color, capsize=7, markeredgecolor='black', markeredgewidth=0.5)
    # Per-panel formatting, applied once per axes rather than once per plotted point.
    ax.set_yscale('log')
    ax.set_title(data_title, fontsize=11)
    ax.set_xticks(prediction_steps)
    ax.grid(True)
    ax.tick_params(axis='both', labelsize=11)

# Legend proxies are built from `models`, so legend colours always match the markers.
legend_handles = [
    Line2D([0], [0], color=color, marker='d', linestyle='None',
           markeredgecolor='black', markeredgewidth=0.5, label=model)
    for model, (_, color) in models.items()
]
fig.legend(handles=legend_handles, loc='lower center', bbox_to_anchor=(0.5, 0.002),
           ncol=len(legend_handles), fontsize=11)
# Shared axis labels, placed in figure coordinates.
fig.text(0.01, 0.54, 'Median MSE', va='center', rotation='vertical', fontsize=11)
fig.text(0.57, 0.15, 'Prediction Step', ha='center', fontsize=11)
# Explicit margins leave room for the legend at the bottom. A tight_layout() call
# here would be a no-op: subplots_adjust overrides every parameter it sets.
fig.subplots_adjust(left=0.18, right=0.98, top=0.92, bottom=0.28, wspace=0.62, hspace=0.45)
for ax in axs:
    x_min, x_max = ax.get_xlim()
    pad = (x_max - x_min) * X_PADDING
    ax.set_xlim(x_min - pad, x_max + pad)
fig.savefig(FIGURE_PATH)