In [20]:
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


from utils import savefig

plt.rcParams['font.size'] = 14

colors = sns.color_palette()

## compare encoding and recall phase

In [18]:
enc_exp_path = Path("./experiments/CondQA/MemTransModel/figures/DeepValueMemoryTransformGRU/setup_encq_thresdecay")
rec_exp_path = Path("./experiments/CondQA/MemTransModel/figures/DeepValueMemoryTransformGRU/setup_recq_thresdecay")
timestep_each_phase = 4
fig_path = Path("./figures_qa/enc_rec")

# accuracy
accs_encq = []
accs_recq = []
for i in range(3):
    with open(enc_exp_path / "{}/accuracy.csv".format(i), "r") as f:
        accs_encq.append(np.array(f.read().split(",")[0], dtype=float))
    with open(rec_exp_path / "{}/accuracy.csv".format(i), "r") as f:
        accs_recq.append(np.array(f.read().split(",")[0], dtype=float))
accs_encq = np.array(accs_encq)
accs_recq = np.array(accs_recq)

mean_acc_encq = np.mean(accs_encq)
std_acc_encq = np.std(accs_encq)
mean_acc_recq = np.mean(accs_recq)
std_acc_recq = np.std(accs_recq)


def bar_plot(mean, std, baseline, ylabel, filename, figsize=(2.8, 3), baseline_text_pos="bottom"):
    plt.figure(figsize=figsize, dpi=200)
    plt.bar(["encoding", "decision"], mean, yerr=std, color=colors, capsize=5)
    plt.axhline(baseline, color='black', linestyle='--')
    if baseline_text_pos == "top":
        plt.text(1.5, baseline+0.005, "chance level", fontsize=12, ha='right', va='bottom')
    else:
        plt.text(1.5, baseline-0.015, "chance level", fontsize=12, ha='right', va='top')
    plt.ylabel(ylabel)
    plt.ylim(0, 1)
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.xticks(fontsize=12)
    plt.tight_layout()
    savefig(fig_path, filename)

bar_plot([mean_acc_encq, mean_acc_recq], [std_acc_encq, std_acc_recq], 0.5, "task performance", "performance")



# decoding results
identity_accs_encq = []
identity_accs_recq = []
feaure_accs_encq = []
feaure_accs_recq = []
random_accs_encq = []
random_accs_recq = []
mean_feature_accs_encq = []
mean_feature_accs_recq = []

for i in range(10):
    if os.path.exists(enc_exp_path / "{}/enc_ridge.csv".format(i)):
        with open(enc_exp_path / "{}/enc_ridge.csv".format(i), "r") as f:
            data = f.read().split(",")
            identity_accs_encq.append(data[0])
            feaure_accs_encq.append(data[1])
            random_accs_encq.append(data[2])
    if os.path.exists(enc_exp_path / "{}/enc_ridge_feature.csv".format(i)):
        with open(enc_exp_path / "{}/enc_ridge_feature.csv".format(i), "r") as f:
            data = f.read().split(",")
            data = np.array(data, dtype=float)
            mean_feature_accs_encq.append(np.mean(data))
    if os.path.exists(rec_exp_path / "{}/enc_ridge.csv".format(i)):
        with open(rec_exp_path / "{}/enc_ridge.csv".format(i), "r") as f:
            data = f.read().split(",")
            identity_accs_recq.append(data[0])
            feaure_accs_recq.append(data[1])
            random_accs_recq.append(data[2])
    if os.path.exists(rec_exp_path / "{}/enc_ridge_feature.csv".format(i)):
        with open(rec_exp_path / "{}/enc_ridge_feature.csv".format(i), "r") as f:
            data = f.read().split(",")
            data = np.array(data, dtype=float)
            mean_feature_accs_recq.append(np.mean(data))

identity_accs_encq_mean = np.mean(np.array(identity_accs_encq, dtype=float))
identity_accs_encq_std = np.std(np.array(identity_accs_encq, dtype=float))
identity_accs_recq_mean = np.mean(np.array(identity_accs_recq, dtype=float))
identity_accs_recq_std = np.std(np.array(identity_accs_recq, dtype=float))
feature_accs_encq_mean = np.mean(np.array(feaure_accs_encq, dtype=float))
feature_accs_encq_std = np.std(np.array(feaure_accs_encq, dtype=float))
feature_accs_recq_mean = np.mean(np.array(feaure_accs_recq, dtype=float))
feature_accs_recq_std = np.std(np.array(feaure_accs_recq, dtype=float))
random_accs_encq_mean = np.mean(np.array(random_accs_encq, dtype=float))
random_accs_encq_std = np.std(np.array(random_accs_encq, dtype=float))
random_accs_recq_mean = np.mean(np.array(random_accs_recq, dtype=float))
random_accs_recq_std = np.std(np.array(random_accs_recq, dtype=float))
mean_feature_accs_encq_mean = np.mean(np.array(mean_feature_accs_encq, dtype=float))
mean_feature_accs_encq_std = np.std(np.array(mean_feature_accs_encq, dtype=float))
mean_feature_accs_recq_mean = np.mean(np.array(mean_feature_accs_recq, dtype=float))
mean_feature_accs_recq_std = np.std(np.array(mean_feature_accs_recq, dtype=float))

bar_plot([identity_accs_encq_mean, identity_accs_recq_mean], [identity_accs_encq_std, identity_accs_recq_std], 1.0/16, 
"item decoding accuracy", "identity_decoding", baseline_text_pos="top")

bar_plot([feature_accs_encq_mean, feature_accs_recq_mean], [feature_accs_encq_std, feature_accs_recq_std], 0.5,
"task-related feature\ndecoding accuracy", "feature_decoding")

bar_plot([random_accs_encq_mean, random_accs_recq_mean], [random_accs_encq_std, random_accs_recq_std], 0.5,
"random feature\ndecoding accuracy", "random_feature_decoding")

bar_plot([mean_feature_accs_encq_mean, mean_feature_accs_recq_mean], [mean_feature_accs_encq_std, mean_feature_accs_recq_std], 0.5,
"individual feature\ndecoding accuracy", "mean_feature_decoding")


## question in recall phase
### load data

In [21]:
exp_path = Path("./experiments/CondQA/MemTransModel/figures/sep_matched/DeepValueMemoryTransformGRU/setup_recq_thresdecay")
timestep_each_phase = 4
fig_path = Path("./figures_qa/recalling")


accuracy_by_timestep = []
accuracy_by_num_retrieved_memories = []
accuracy_by_num_matched_memories = []
num_retrieved_memories = []
num_matched_memories_retrieved = []
answer_timesteps_num_by_matched_memories = []
accuracy_sep_by_recall_all_matched = []

for i in range(10):
    print(i)
    file_path = exp_path / str(i) / "data"
    accuracy_by_timestep.append(np.load(file_path / "accuracy_by_timestep_all.npy"))
    accuracy_by_num_retrieved_memories.append(np.load(file_path / "accuracy_by_num_retrieved_memories.npy"))
    accuracy_by_num_matched_memories.append(np.load(file_path / "accuracy_by_num_matched_memories.npy"))
    num_retrieved_memories.append(np.load(file_path / "num_retrieved_memories_mean_by_timestep.npy"))
    num_matched_memories_retrieved.append(np.load(file_path / "num_matched_memories_retrieved_mean_by_timestep.npy"))
    answer_timesteps_num_by_matched_memories.append(np.load(file_path / "answer_timesteps_num_by_num_matched_memories.npy"))
    accuracy_sep_by_recall_all_matched.append(np.load(file_path / "performance_sep_by_recall_all_matched.npy"))


accuracy_by_timestep_mean = np.mean(accuracy_by_timestep, axis=0)
accuracy_by_timestep_std = np.std(accuracy_by_timestep, axis=0)

accuracy_by_num_retrieved_memories_mean = np.mean(accuracy_by_num_retrieved_memories, axis=0)

accuracy_by_num_retrieved_memories_std = np.std(accuracy_by_num_retrieved_memories, axis=0)

# accuracy_by_num_matched_memories = np.array(accuracy_by_num_matched_memories)
# accuracy_by_num_matched_memories = np.nan_to_num(accuracy_by_num_matched_memories, nan=0)
# accuracy_by_num_matched_memories_mean = np.mean(accuracy_by_num_matched_memories, axis=0)
# accuracy_by_num_matched_memories_std = np.std(accuracy_by_num_matched_memories, axis=0)
accuracy_by_num_matched_memories_mean = np.nanmean(accuracy_by_num_matched_memories, axis=0)
accuracy_by_num_matched_memories_std = np.nanstd(accuracy_by_num_matched_memories, axis=0)

num_retrieved_memories_mean = np.mean(num_retrieved_memories, axis=0)
num_retrieved_memories_std = np.std(num_retrieved_memories, axis=0)

num_matched_memories_retrieved = np.array(num_matched_memories_retrieved)
num_matched_memories_retrieved = np.nan_to_num(num_matched_memories_retrieved, nan=0)
num_matched_memories_retrieved_mean = np.mean(num_matched_memories_retrieved, axis=0)
num_matched_memories_retrieved_std = np.std(num_matched_memories_retrieved, axis=0)

answer_timesteps_num_by_matched_memories_mean = np.mean(answer_timesteps_num_by_matched_memories, axis=0)
answer_timesteps_num_by_matched_memories_std = np.std(answer_timesteps_num_by_matched_memories, axis=0)

accuracy_sep_by_recall_all_matched_mean = np.mean(accuracy_sep_by_recall_all_matched, axis=0)
accuracy_sep_by_recall_all_matched_std = np.std(accuracy_sep_by_recall_all_matched, axis=0)

0
1
2
3
4
5
6
7
8
9


  accuracy_by_num_matched_memories_mean = np.nanmean(accuracy_by_num_matched_memories, axis=0)
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,


### accuracy by the number of time steps taken

In [12]:
plt.figure(figsize=(4.7, 3), dpi=180)
for i in range(0, timestep_each_phase):
    plt.plot(np.arange(1, timestep_each_phase+1), accuracy_by_timestep_mean[i], label="{} matched".format(i+1), marker="o")
    plt.fill_between(np.arange(1, timestep_each_phase+1), accuracy_by_timestep_mean[i]-accuracy_by_timestep_std[i], accuracy_by_timestep_mean[i]+accuracy_by_timestep_std[i], alpha=0.2)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel("time in recall phase")
plt.ylabel("task performance")
plt.legend(fontsize=10, bbox_to_anchor=(1.05, 1.0), loc='upper left')
plt.tight_layout()
savefig(fig_path, "accuracy_timesteps")

### number of retrieved memory and accuracy, grouped by number of matched item

In [22]:
plt.figure(figsize=(4.7, 3), dpi=180)
for i in range(1, timestep_each_phase+1):
    plt.plot(np.arange(1, timestep_each_phase+1), num_retrieved_memories_mean[i], label="{} matched".format(i), marker="o")
    plt.fill_between(np.arange(1, timestep_each_phase+1), num_retrieved_memories_mean[i]-num_retrieved_memories_std[i], num_retrieved_memories_mean[i]+num_retrieved_memories_std[i], alpha=0.2)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel("time in recall phase")
plt.ylabel("number of unique\nmemories retrieved")
plt.legend(fontsize=10, bbox_to_anchor=(1.05, 1.0), loc='upper left')
plt.tight_layout()
savefig(fig_path, "num_retrieved_memories")

plt.figure(figsize=(4.7, 3), dpi=180)
for i in range(1, timestep_each_phase+1):
    plt.plot(np.arange(1, timestep_each_phase+1), num_matched_memories_retrieved_mean[i], label="{} matched".format(i), marker="o")
    plt.fill_between(np.arange(1, timestep_each_phase+1), num_matched_memories_retrieved_mean[i]-num_matched_memories_retrieved_std[i], num_matched_memories_retrieved_mean[i]+num_matched_memories_retrieved_std[i], alpha=0.2)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel("time in recall phase")
plt.ylabel("number of matched\nmemories retrieved")
plt.legend(fontsize=10, bbox_to_anchor=(1.05, 1.0), loc='upper left')
plt.tight_layout()
savefig(fig_path, "num_matched_memories_retrieved")

plt.figure(figsize=(4.7, 3.2), dpi=180)
for i in range(1, timestep_each_phase+1):
    plt.plot(np.arange(1, timestep_each_phase+1), accuracy_by_num_retrieved_memories_mean[i], label="{} matched".format(i), marker="o")
    plt.fill_between(np.arange(1, timestep_each_phase+1), accuracy_by_num_retrieved_memories_mean[i]-accuracy_by_num_retrieved_memories_std[i], accuracy_by_num_retrieved_memories_mean[i]+accuracy_by_num_retrieved_memories_std[i], alpha=0.2)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel("number of unique\nmemories retrieved")
plt.ylabel("task performance")
plt.legend(fontsize=10, bbox_to_anchor=(1.05, 1.0), loc='upper left')
plt.tight_layout()
savefig(fig_path, "accuracy_num_retrieved_memories")

plt.figure(figsize=(4.7, 3.2), dpi=180)
for i in range(1, timestep_each_phase+1):
    plt.plot(np.arange(1, i+1), accuracy_by_num_matched_memories_mean[i][:i], label="{} matched".format(i), marker="o")
    plt.fill_between(np.arange(1, i+1), accuracy_by_num_matched_memories_mean[i][:i]-accuracy_by_num_matched_memories_std[i][:i], accuracy_by_num_matched_memories_mean[i][:i]+accuracy_by_num_matched_memories_std[i][:i], alpha=0.2)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel("number of matched\nmemories retrieved")
plt.ylabel("task performance")
plt.legend(fontsize=10, bbox_to_anchor=(1.05, 1.0), loc='upper left')
plt.tight_layout()
savefig(fig_path, "accuracy_num_matched_memories")


### performance before and after retrieving all matched memories

In [23]:
plt.figure(figsize=(2.8, 3), dpi=180)
plt.bar(["partial", "full"], accuracy_sep_by_recall_all_matched_mean, yerr=accuracy_sep_by_recall_all_matched_std, color=colors)
plt.axhline(0.5, color='black', linestyle='--')
plt.text(1.5, 0.5-0.015, "chance level", fontsize=12, ha='right', va='top')
plt.ylabel("task performance")
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
savefig(fig_path, "accuracy_sep_by_recall_matched")


### distribution of number of time steps taken, grouped by number of matched items

In [24]:
plt.figure(figsize=(4.7, 3), dpi=180)
for i in range(0, timestep_each_phase):
    plt.plot(np.arange(1, timestep_each_phase+1), answer_timesteps_num_by_matched_memories_mean[i], label="{} matched".format(i), marker="o")
    plt.fill_between(np.arange(1, timestep_each_phase+1), answer_timesteps_num_by_matched_memories_mean[i]-answer_timesteps_num_by_matched_memories_std[i], 
                     answer_timesteps_num_by_matched_memories_mean[i]+answer_timesteps_num_by_matched_memories_std[i], alpha=0.2)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel("time in recall phase")
plt.ylabel("proportion of trials")
plt.legend(fontsize=10, bbox_to_anchor=(1.05, 1.0), loc='upper left')
plt.tight_layout()
savefig(fig_path, "answer_timesteps_num_by_matched_memories")


## theoretical plots for each hypothesis

1. retrieve all memory without bias on matched memories
    - do not recall the same memory
    - randomly recall

2. recall matched items first

- plots
    - number of memories retrieved by number of time steps
    - number of matched memories retrieved by number of time step