In [63]:
import sys
import os

os.chdir(os.path.join(sys.path[0], ".."))
os.environ["LOGLEVEL"] = "ERROR"

In [64]:
from memory.environment.generator import OQAGenerator
from memory import EpisodicMemory
import time

max_history = 1024
accs = []

for policy in [
    {"forget": "oldest", "answer": "latest"},
    {"forget": "random", "answer": "latest"},
]:

    # for capacity in [2,4,8,16,32,64,128,256,512,1024]:
    for capacity in [256]:

        oqag = OQAGenerator(max_history=max_history, commonsense_prob=0.5)
        M_e = EpisodicMemory(capacity=capacity)
        rewards = 0

        for _ in range(max_history):
            time.sleep(0.01)
            ob, question_answer = oqag.generate(
                generate_qa=True, recent_more_likely=True
            )
            mem_epi = M_e.ob2epi(ob)
            M_e.add(mem_epi)
            if M_e.is_kinda_full:
                if policy["forget"].lower() == "oldest":
                    M_e.forget_oldest()
                elif policy["forget"].lower() == "random":
                    M_e.forget_random()
                else:
                    raise NotImplementedError

            if policy["answer"].lower() == "latest":
                reward, _, _ = M_e.answer_latest(question_answer)
            elif policy["answer"].lower() == "random":
                reward, _, _ = M_e.answer_random(question_answer)
            else:
                raise NotImplementedError

            rewards += reward
        print(
            f"capacity: {capacity}, policy: {policy}, rewards: {rewards}, acc: {rewards / max_history}"
        )

capacity: 256, policy: {'forget': 'oldest', 'answer': 'latest'}, rewards: 821, acc: 0.8017578125
capacity: 256, policy: {'forget': 'random', 'answer': 'latest'}, rewards: 723, acc: 0.7060546875


In [65]:
from memory.environment import MemorySpace

mem_space = MemorySpace(
    {"episodic": 256, "semantic": 0}, space_type="episodic_memory_manage"
)

In [66]:
mem_space.episodic_memory_system_to_numbers(M_e, M_e.capacity + 1).tolist()

[[114.0, 1032.0, 10.0, 114.0, 10074.0, 592758.0],
 [105.0, 1037.0, 10.0, 105.0, 10073.0, 592759.5],
 [108.0, 10099.0, 10.0, 108.0, 10128.0, 592759.5],
 [101.0, 1008.0, 10.0, 101.0, 10137.0, 592759.6875],
 [112.0, 1009.0, 10.0, 112.0, 10034.0, 592760.375],
 [113.0, 1014.0, 10.0, 113.0, 10009.0, 592760.5],
 [105.0, 1035.0, 10.0, 105.0, 10053.0, 592760.5625],
 [116.0, 1003.0, 10.0, 116.0, 10037.0, 592761.125],
 [113.0, 1026.0, 10.0, 113.0, 10071.0, 592761.1875],
 [118.0, 1009.0, 10.0, 118.0, 10104.0, 592761.1875],
 [113.0, 1035.0, 10.0, 113.0, 10065.0, 592761.25],
 [116.0, 1014.0, 10.0, 116.0, 10003.0, 592761.375],
 [102.0, 1043.0, 10.0, 102.0, 10092.0, 592761.4375],
 [115.0, 10089.0, 10.0, 115.0, 10128.0, 592761.75],
 [103.0, 10002.0, 10.0, 103.0, 10113.0, 592761.8125],
 [100.0, 1045.0, 10.0, 100.0, 10125.0, 592761.9375],
 [118.0, 1041.0, 10.0, 118.0, 10001.0, 592762.0],
 [108.0, 1012.0, 10.0, 108.0, 10100.0, 592762.0625],
 [118.0, 1045.0, 10.0, 118.0, 10125.0, 592762.25],
 [110.0, 1023.

In [40]:
oqag.M_e.entries

[["Charles's bird", 'AtLocation', "Charles's sky", 590450.9613595009],
 ["Mary's airplane", 'AtLocation', "Mary's restaurant", 590450.9614241123],
 ["Sarah's bus", 'AtLocation', "Sarah's city", 590450.9614770412],
 ["John's broccoli", 'AtLocation', "John's bakery", 590450.9615254402],
 ["Karen's pizza", 'AtLocation', "Karen's oven", 590450.9615733624],
 ["Patricia's toothbrush",
  'AtLocation',
  "Patricia's suitcase",
  590450.961622715],
 ["Sarah's mouse", 'AtLocation', "Sarah's laboratory", 590450.9616737366],
 ["Jennifer's scissors",
  'AtLocation',
  "Jennifer's backyard",
  590450.9617218971],
 ["Robert's bed", 'AtLocation', "Robert's house", 590450.9617700577],
 ["Susan's spoon", 'AtLocation', "Susan's stadium", 590450.9618165493],
 ["Sarah's car", 'AtLocation', "Sarah's church", 590450.9618647099],
 ["William's keyboard", 'AtLocation', "William's desk", 590450.9619131088],
 ["Patricia's elephant", 'AtLocation', "Patricia's land", 590450.9619715214],
 ["Sarah's toothbrush", 'AtL

In [12]:
from memory import SemanticMemory

max_history = 1024

for policy in [
    {"forget": "weakest", "answer": "strongest"},
    {"forget": "random", "answer": "strongest"},
]:

    for capacity in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:

        oqag = OQAGenerator(max_history=max_history, commonsense_prob=0.5)
        M_s = SemanticMemory(capacity=capacity)
        rewards = 0

        for _ in range(max_history):
            ob, question_answer = oqag.generate(
                generate_qa=True, recent_more_likely=True
            )
            mem_sem = M_s.ob2sem(ob)
            if not M_s.is_frozen:
                M_s.add(mem_sem)
                if M_s.is_kinda_full and (not M_s.is_frozen):
                    if policy["forget"].lower() == "weakest":
                        M_s.forget_weakest()
                    elif policy["forget"].lower() == "random":
                        M_s.forget_random()
                    else:
                        raise NotImplementedError

            question = M_s.eq2sq(question_answer)

            if policy["answer"].lower() == "strongest":
                reward, _, _ = M_s.answer_strongest(question)
            elif policy["answer"].lower() == "random":
                reward, _, _ = M_s.answer_random(question)
            else:
                raise NotImplementedError

            rewards += reward

        print(
            f"capacity: {capacity}, policy: {policy}, rewards: {rewards}, acc: {rewards / max_history}"
        )

capacity: 2, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 25, acc: 0.0244140625
capacity: 4, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 58, acc: 0.056640625
capacity: 8, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 57, acc: 0.0556640625
capacity: 16, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 146, acc: 0.142578125
capacity: 32, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 268, acc: 0.26171875
capacity: 64, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 403, acc: 0.3935546875
capacity: 128, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 497, acc: 0.4853515625
capacity: 256, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 532, acc: 0.51953125
capacity: 512, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 574, acc: 0.560546875
capacity: 1024, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 558, acc: 0.544921875
capacity:

In [13]:
from train_RL import DQN

In [14]:
net = DQN(129, 6, 129)

In [15]:
import torch

foo = torch.randn(4, 129, 6)

In [28]:
bar = net.LinearRow2(net.LinearRow1(net.LinearRow1(foo)))
bar.shape

torch.Size([4, 129, 1])

In [35]:
net.LinearCol2(bar.view(-1, 129)).shape

torch.Size([4, 129])

In [17]:
net(foo).shape

torch.Size([4, 129])