In [1]:
import sys
import os
# Move to the repository root so that `memory` and `train_RL` import as top-level modules.
# Guarded so that re-running this cell does not climb one more directory per execution:
# `memory` is imported later as a package (`memory.environments`), so its directory being
# present in the CWD indicates we are already at the root.
if not os.path.isdir("memory"):
    os.chdir(os.path.join(sys.path[0], '..'))
# Set before any project import; presumably read by the project's logging setup — TODO confirm.
os.environ['LOGLEVEL'] = "ERROR"

In [11]:
from memory.environments import OQAGenerator
from memory import EpisodicMemory
import time

max_history = 1024  # observations (and questions) generated per run
commonsense_prob = 0.5
capacities = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
episodic_policies = [
    {"forget": "oldest", "answer": "latest"},
    {"forget": "random", "answer": "latest"},
]
# One record per (policy, capacity) run, so results can be tabulated or plotted after the
# loop instead of being recoverable only from (possibly truncated) printed output.
accs = []

# NOTE(review): no RNG seed is set, and OQAGenerator / forget_random() look stochastic,
# so the numbers can change between runs. TODO seed whichever RNG the `memory` package uses.
for policy in episodic_policies:
    forget_policy = policy["forget"].lower()
    answer_policy = policy["answer"].lower()
    # Validate once, up front: an unsupported forget policy would otherwise only fail when
    # the memory first reports `is_kinda_full` (or never, if it never fills up).
    if forget_policy not in ("oldest", "random") or answer_policy not in ("latest", "random"):
        raise NotImplementedError(f"unsupported episodic policy: {policy}")

    for capacity in capacities:
        oqag = OQAGenerator(max_history=max_history, commonsense_prob=commonsense_prob)
        M_e = EpisodicMemory(capacity=capacity)
        rewards = 0

        for _ in range(max_history):
            ob, question_answer = oqag.generate(generate_qa=True, recent_more_likely=True)
            mem_epi = M_e.ob2epi(ob)
            M_e.add(mem_epi)
            # Apply the forget policy whenever the store reports it is (nearly) full.
            if M_e.is_kinda_full:
                if forget_policy == "oldest":
                    M_e.forget_oldest()
                else:
                    M_e.forget_random()

            if answer_policy == "latest":
                reward, _, _ = M_e.answer_latest(question_answer)
            else:
                reward, _, _ = M_e.answer_random(question_answer)
            rewards += reward

        acc = rewards / max_history
        accs.append({"policy": policy, "capacity": capacity, "rewards": rewards, "acc": acc})
        print(f"capacity: {capacity}, policy: {policy}, rewards: {rewards}, acc: {acc}")

capacity: 2, policy: {'forget': 'oldest', 'answer': 'latest'}, rewards: 18, acc: 0.017578125
capacity: 4, policy: {'forget': 'oldest', 'answer': 'latest'}, rewards: 47, acc: 0.0458984375
capacity: 8, policy: {'forget': 'oldest', 'answer': 'latest'}, rewards: 71, acc: 0.0693359375
capacity: 16, policy: {'forget': 'oldest', 'answer': 'latest'}, rewards: 146, acc: 0.142578125
capacity: 32, policy: {'forget': 'oldest', 'answer': 'latest'}, rewards: 262, acc: 0.255859375
capacity: 64, policy: {'forget': 'oldest', 'answer': 'latest'}, rewards: 387, acc: 0.3779296875
capacity: 128, policy: {'forget': 'oldest', 'answer': 'latest'}, rewards: 602, acc: 0.587890625
capacity: 256, policy: {'forget': 'oldest', 'answer': 'latest'}, rewards: 824, acc: 0.8046875
capacity: 512, policy: {'forget': 'oldest', 'answer': 'latest'}, rewards: 988, acc: 0.96484375
capacity: 1024, policy: {'forget': 'oldest', 'answer': 'latest'}, rewards: 1024, acc: 1.0
capacity: 2, policy: {'forget': 'random', 'answer': 'lates

In [12]:
from memory import SemanticMemory

max_history = 1024  # observations (and questions) generated per run
commonsense_prob = 0.5
capacities = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
semantic_policies = [
    {"forget": "weakest", "answer": "strongest"},
    {"forget": "random", "answer": "strongest"},
]
# One record per (policy, capacity) run, so results can be tabulated or plotted after the loop.
sem_accs = []

# NOTE(review): no RNG seed is set, so results are not reproducible run-to-run.
# TODO seed whichever RNG the `memory` package uses.
for policy in semantic_policies:
    forget_policy = policy["forget"].lower()
    answer_policy = policy["answer"].lower()
    # Validate once, up front: an unsupported forget policy would otherwise only fail when
    # the memory first becomes full while unfrozen (or never, at large capacities).
    if forget_policy not in ("weakest", "random") or answer_policy not in ("strongest", "random"):
        raise NotImplementedError(f"unsupported semantic policy: {policy}")

    for capacity in capacities:
        oqag = OQAGenerator(max_history=max_history, commonsense_prob=commonsense_prob)
        M_s = SemanticMemory(capacity=capacity)
        rewards = 0

        for _ in range(max_history):
            ob, question_answer = oqag.generate(generate_qa=True, recent_more_likely=True)
            mem_sem = M_s.ob2sem(ob)
            # Once the memory reports `is_frozen`, this loop stops writing to it.
            if not M_s.is_frozen:
                M_s.add(mem_sem)
                # `is_frozen` is re-checked after add(); presumably add() can freeze the
                # memory — TODO confirm, and drop the re-check if it cannot.
                if M_s.is_kinda_full and not M_s.is_frozen:
                    if forget_policy == "weakest":
                        M_s.forget_weakest()
                    else:
                        M_s.forget_random()

            # Presumably converts the generator's question into the form SemanticMemory answers.
            question = M_s.eq2sq(question_answer)
            if answer_policy == "strongest":
                reward, _, _ = M_s.answer_strongest(question)
            else:
                reward, _, _ = M_s.answer_random(question)
            rewards += reward

        acc = rewards / max_history
        sem_accs.append({"policy": policy, "capacity": capacity, "rewards": rewards, "acc": acc})
        print(f"capacity: {capacity}, policy: {policy}, rewards: {rewards}, acc: {acc}")

capacity: 2, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 25, acc: 0.0244140625
capacity: 4, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 58, acc: 0.056640625
capacity: 8, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 57, acc: 0.0556640625
capacity: 16, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 146, acc: 0.142578125
capacity: 32, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 268, acc: 0.26171875
capacity: 64, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 403, acc: 0.3935546875
capacity: 128, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 497, acc: 0.4853515625
capacity: 256, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 532, acc: 0.51953125
capacity: 512, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 574, acc: 0.560546875
capacity: 1024, policy: {'forget': 'weakest', 'answer': 'strongest'}, rewards: 558, acc: 0.544921875
capacity:

In [13]:
from train_RL import DQN

In [14]:
# NOTE(review): DQN's argument meanings are not visible here. Judging by the probe cells
# (input (4, 129, 6) -> output (4, 129)) they may be (rows, features, outputs) — confirm in train_RL.
net = DQN(
    129,
    6,
    129,
)

In [15]:
import torch
# Random dummy batch for shape-checking the network; only its shape is inspected below.
# Assumed layout (batch, rows, features) — TODO confirm against DQN.
foo = torch.randn((4, 129, 6))

In [28]:
# Shape probe: push the dummy batch through the row-wise layers by hand.
# NOTE(review): LinearRow1 is applied twice here. That only works if its output is a valid
# input to itself (e.g. equal in/out features). Confirm against DQN.forward whether a single
# LinearRow1 (plus any activations) was intended.
bar = net.LinearRow2(net.LinearRow1(net.LinearRow1(foo)))
bar.shape

torch.Size([4, 129, 1])

In [35]:
# Collapse `bar` from (4, 129, 1) (per the previous cell's output) to (-1, 129) and
# apply the column-wise layer; displays the resulting shape.
net.LinearCol2(bar.view(-1, 129)).size()

torch.Size([4, 129])

In [17]:
# Full forward pass on the dummy batch; its shape should agree with the manual probes above.
net(foo).size()

torch.Size([4, 129])