In [9]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

from utils import savefig

# Global default font size for every figure drawn in this notebook.
plt.rcParams.update({'font.size': 14})

In [10]:
exp_path = Path("./experiments/CondQA/MemTransModel/figures/sep_matched/DeepValueMemoryTransformGRU/setup_recq_thresdecay")
timestep_each_phase = 4
fig_path = Path("./figures_qa/recalling")
# Number of independent runs; run r's arrays are read from exp_path/<r>/data.
num_runs = 10


def load_runs(root, filename, n_runs):
    """Load ``filename`` from the ``data`` directory of every run.

    Args:
        root: experiment directory containing one sub-directory per run,
            named ``0`` .. ``n_runs - 1``.
        filename: name of the ``.npy`` file inside each run's ``data`` directory.
        n_runs: number of runs to load.

    Returns:
        A list with one loaded array per run, in run order.
    """
    return [np.load(root / str(run) / "data" / filename) for run in range(n_runs)]


def mean_std_over_runs(per_run):
    """Return the element-wise mean and std across runs (along the run axis).

    ``np.stack`` raises a clear error if the runs have mismatched shapes or if
    ``per_run`` is empty, instead of producing NaNs or a coerced ragged array.
    The std is numpy's default population std (``ddof=0``), as before.
    """
    stacked = np.stack(per_run)
    return stacked.mean(axis=0), stacked.std(axis=0)


# Per-run arrays (kept as lists, one entry per run).
accuracy_by_timestep = load_runs(exp_path, "accuracy_by_timestep_all.npy", num_runs)
accuracy_by_num_retrieved_memories = load_runs(exp_path, "accuracy_by_num_retrieved_memories.npy", num_runs)
num_retrieved_memories = load_runs(exp_path, "num_retrieved_memories_mean_by_timestep.npy", num_runs)
num_matched_memories_retrieved = load_runs(exp_path, "num_matched_memories_retrieved_mean_by_timestep.npy", num_runs)

# Mean / std across runs, used for the curves and shaded bands below.
accuracy_by_timestep_mean, accuracy_by_timestep_std = mean_std_over_runs(accuracy_by_timestep)
accuracy_by_num_retrieved_memories_mean, accuracy_by_num_retrieved_memories_std = mean_std_over_runs(accuracy_by_num_retrieved_memories)
num_retrieved_memories_mean, num_retrieved_memories_std = mean_std_over_runs(num_retrieved_memories)
num_matched_memories_retrieved_mean, num_matched_memories_retrieved_std = mean_std_over_runs(num_matched_memories_retrieved)

In [11]:
# Accuracy across the recall phase: one curve per row of accuracy_by_timestep_mean,
# shaded band = mean ± std across runs.
# NOTE(review): rows are read from index 0 and labelled "1 matched", "2 matched", ...,
# whereas the next cell reads its arrays from row 1 — confirm the layouts really differ.
recall_steps = np.arange(1, timestep_each_phase + 1)
fig, ax = plt.subplots(figsize=(4.7, 3), dpi=180)
for row in range(timestep_each_phase):
    acc_mean = accuracy_by_timestep_mean[row]
    acc_std = accuracy_by_timestep_std[row]
    ax.plot(recall_steps, acc_mean, label="{} matched".format(row + 1), marker="o")
    ax.fill_between(recall_steps, acc_mean - acc_std, acc_mean + acc_std, alpha=0.2)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.set_xlabel("time in recall phase")
ax.set_ylabel("accuracy")
ax.legend(fontsize=10, bbox_to_anchor=(1.05, 1.0), loc='upper left')
fig.tight_layout()
savefig(fig_path, "accuracy_timesteps")

In [13]:
def plot_by_num_matched(mean, std, xlabel, ylabel, name, figsize=(4.7, 3)):
    """Draw one mean ± std curve per number of matched memories, then save the figure.

    Rows 1..timestep_each_phase of ``mean``/``std`` are plotted against
    x = 1..timestep_each_phase and labelled "<row> matched"; row 0 is skipped.
    NOTE(review): presumably row k holds the "k matched" condition (so row 0 would
    be "0 matched") — confirm against the code that writes these .npy files.

    Args:
        mean: array of across-run means, indexed by row then x.
        std: array of across-run stds, same layout as ``mean``.
        xlabel: x-axis label.
        ylabel: y-axis label.
        name: figure name passed, together with ``fig_path``, to the project's ``savefig``.
        figsize: matplotlib figure size in inches.
    """
    x = np.arange(1, timestep_each_phase + 1)
    fig, ax = plt.subplots(figsize=figsize, dpi=180)
    for num_matched in range(1, timestep_each_phase + 1):
        row_mean = mean[num_matched]
        row_std = std[num_matched]
        ax.plot(x, row_mean, label="{} matched".format(num_matched), marker="o")
        ax.fill_between(x, row_mean - row_std, row_mean + row_std, alpha=0.2)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend(fontsize=10, bbox_to_anchor=(1.05, 1.0), loc='upper left')
    fig.tight_layout()
    savefig(fig_path, name)


plot_by_num_matched(
    num_retrieved_memories_mean,
    num_retrieved_memories_std,
    xlabel="time in recall phase",
    ylabel="number of\nretrieved memories",
    name="num_retrieved_memories",
)

plot_by_num_matched(
    num_matched_memories_retrieved_mean,
    num_matched_memories_retrieved_std,
    xlabel="time in recall phase",
    ylabel="number of matched\nmemories retrieved",
    name="num_matched_memories_retrieved",
)

# NOTE(review): the x-axis here is "number of retrieved memories", but x is still
# 1..timestep_each_phase — this assumes the array's columns correspond to 1..timestep_each_phase
# retrieved memories; confirm.
plot_by_num_matched(
    accuracy_by_num_retrieved_memories_mean,
    accuracy_by_num_retrieved_memories_std,
    xlabel="number of retrieved\nmemories",
    ylabel="accuracy",
    name="accuracy_num_retrieved_memories",
    figsize=(4.7, 3.2),
)
