In [None]:
# Move the working directory up one level so that the relative "./data/..."
# paths used in later cells resolve (assumes this notebook lives one directory
# below the project root -- TODO confirm).
# Not idempotent: every re-run of this cell moves up one more level.
%cd ..

In [None]:
from gym.spaces import Space
from gym import spaces


In [None]:
from memory.utils import read_json

# Load the dataset statistics file; the next cell indexes the result by the
# 'episodic_location_counts' key.
stats = read_json('./data/dataset-stats.json')

In [None]:
# Count the distinct keys stored under 'episodic_location_counts'.
len(set(stats['episodic_location_counts']))

In [96]:
import random
from memory.utils import read_json
import time
from pprint import pformat

# Largest signed 32-bit integer (2**31 - 1). MemorySpace uses it as the
# inclusive upper bound for the fourth element of every memory it samples or
# validates.
MAX_INT_32 = 2147483647


class MemorySpace(Space):
    """Gym space whose elements are whole memory systems.

    An element is a list of ``capacity + 1`` individual memories. Each memory
    is a 4-item sequence ``[head, relation, tail, number]`` where ``head`` is
    one of ``possible_objects``, ``relation`` is the fixed string
    ``"AtLocation"``, ``tail`` is one of ``possible_locations`` and ``number``
    satisfies ``0 < number <= MAX_INT_32``.
    """

    def __init__(self, capacity: int, memory_type: str) -> None:
        """Read the dataset statistics and build the lists of legal values.

        Args
        ----
        capacity: memory capacity
        memory_type: either episodic or semantic.

        Raises
        ------
        AssertionError: if ``memory_type`` is not one of the two supported
            values, or if the number of objects / locations found in
            ``./data/dataset-stats.json`` disagrees with the
            ``num_unique_*`` counters stored in the same file.

        """
        assert memory_type in ["episodic", "semantic"]
        self.memory_type = memory_type
        self.capacity = capacity

        dataset_stats = read_json("./data/dataset-stats.json")

        # Sorted rather than list(set(...)): set iteration order of strings
        # varies between interpreter runs (hash randomisation), which made
        # seeded sampling and cross-process equality checks unreproducible.
        self.possible_objects = sorted(
            dataset_stats[f"{self.memory_type}_object_counts"]
        )
        assert (
            len(self.possible_objects)
            == dataset_stats[f"num_unique_{self.memory_type}_objects"]
        )
        self.relation = "AtLocation"  # hard coded
        self.possible_locations = sorted(
            dataset_stats[f"{self.memory_type}_location_counts"]
        )
        assert (
            len(self.possible_locations)
            == dataset_stats[f"num_unique_{self.memory_type}_locations"]
        )

        super().__init__()

    def sample_memory(self):
        """Draw one random individual memory.

        Returns
        -------
        ``[head, relation, tail, number]`` with ``head`` / ``tail`` picked
        uniformly from the possible objects / locations and ``number`` a
        random integer in ``[1, MAX_INT_32]``.

        NOTE(review): this draws from the global ``random`` module, so seeding
        the space itself presumably has no effect on it -- confirm.
        """
        head = random.choice(self.possible_objects)
        tail = random.choice(self.possible_locations)

        time_or_num_generalized_memories = random.randint(1, MAX_INT_32)

        return [head, self.relation, tail, time_or_num_generalized_memories]

    def sample(self):
        """Draw a random memory system: a list of ``capacity + 1`` memories."""
        return [self.sample_memory() for _ in range(self.capacity + 1)]

    def contains(self, x):
        """Return True iff ``x`` is a valid element of this space.

        ``x`` must be a list of exactly ``capacity + 1`` memories, each of
        which passes the per-field checks described on the class. Malformed
        input yields ``False`` instead of raising.
        """
        if not isinstance(x, list):
            return False

        if len(x) != (self.capacity + 1):
            return False

        for memory in x:
            # Reject entries that are not 4-item sequences up front so the
            # field checks below can never raise on them.
            if not isinstance(memory, (list, tuple)) or len(memory) != 4:
                return False

            head, relation, tail, number = memory

            if head not in self.possible_objects:
                return False
            if relation != self.relation:
                return False
            if tail not in self.possible_locations:
                return False

            try:
                in_range = 0 < number <= MAX_INT_32
            except TypeError:
                # Non-numeric fourth field (e.g. None or a string).
                return False
            if not in_range:
                return False

        return True

    def __repr__(self):
        """Pretty-printed dump of every instance attribute."""
        return pformat(vars(self), indent=4, width=1)

    def __eq__(self, other):
        """Two spaces are equal when type, capacity and legal values match."""
        # Guard first: comparing against an unrelated object used to raise
        # AttributeError instead of simply being unequal.
        if not isinstance(other, MemorySpace):
            return False

        return (
            self.memory_type == other.memory_type
            and self.capacity == other.capacity
            and self.possible_objects == other.possible_objects
            and self.relation == other.relation
            and self.possible_locations == other.possible_locations
        )


In [117]:
import gym
from gym import error, spaces, utils
from gym.utils import seeding
from memory.utils import read_data, select_a_question, load_questions
from memory import EpisodicMemory, SemanticMemory


class MemoryEnv(gym.Env):
    """Gym environment in which the agent chooses which memory to forget.

    Each action is an index into the current memory entries; the chosen entry
    is forgotten, a question is answered from what remains, and the answer's
    score is returned as the reward.
    """

    metadata = {"render.modes": ["console"]}
    # Not referenced inside this class; presumably the reward values produced
    # by the memory's answer method -- confirm.
    CORRECT = 1
    WRONG = 0

    def __init__(
        self,
        memory_type: str,
        capacity: int,
        split: str,
        data_path: str = "./data/data.json",
        question_path: str = "./data/questions.json",
    ) -> None:
        """

        Args
        ----
        memory_type: either episodic or semantic.
        capacity: memory capacity
        split: train, val, or test
        data_path: path to data.
        question_path: path to the questions file.

        Raises
        ------
        AssertionError: if ``memory_type`` or ``split`` is not a known value.
        NotImplementedError: if ``memory_type`` is "semantic". The rest of
            this class only drives an episodic memory, so such an environment
            could be constructed before but failed with AttributeError on the
            first reset/step/render.

        """
        assert memory_type in ["episodic", "semantic"]
        if memory_type != "episodic":
            raise NotImplementedError(
                "MemoryEnv currently supports only the 'episodic' memory type."
            )
        self.memory_type = memory_type
        self.capacity = capacity

        self.M_dummy = EpisodicMemory(self.capacity)

        assert split in ["train", "val", "test"]
        self.split = split
        self.data = read_data(data_path)
        self.questions = load_questions(question_path)
        # One action per memory slot that can be forgotten.
        n_actions = self.capacity + 1
        self.action_space = spaces.Discrete(n_actions)
        self.observation_space = MemorySpace(capacity, memory_type)

    def reset(self):
        """Start a new episode and return the initial memory entries.

        The memory is rebuilt from scratch and fed observations from the
        chosen split until it reports ``is_kinda_full``.
        """
        # Fresh memory per episode: without this a second reset() kept the
        # entries of the previous episode and tripped the assertion below.
        self.M_dummy = EpisodicMemory(self.capacity)

        self.idx = 0
        for idx, ob in enumerate(self.data[self.split]):
            self.idx = idx
            mem_epi = self.M_dummy.ob2epi(ob)
            self.M_dummy.add(mem_epi)

            if self.M_dummy.is_kinda_full:
                break

        # Checked on self.idx (always bound) rather than the loop variable,
        # which does not exist when the split is empty.
        assert (
            self.idx == self.capacity
        ), "could not fill the memory from the selected split"

        return self.M_dummy.entries

    def step(self, action):
        """Forget the memory at index ``action`` and advance one observation.

        Returns
        -------
        (entries, reward, done, info): the current memory entries, the value
        returned by ``answer_latest`` for the selected question, whether the
        split has been exhausted, and an empty info dict.

        NOTE(review): on the terminal step nothing is added after forgetting,
        so the returned entries presumably hold one memory fewer than the
        observation space expects -- confirm.
        """
        mem_to_forget = self.M_dummy.entries[action]
        self.M_dummy.forget(mem_to_forget)

        question = select_a_question(self.idx, self.data, self.questions, self.split)
        reward = self.M_dummy.answer_latest(question)

        self.idx += 1
        done = self.idx == len(self.data[self.split])

        if not done:
            # Take in the next observation so the agent again has one entry
            # too many and must pick something to forget.
            ob = self.data[self.split][self.idx]
            mem_epi = self.M_dummy.ob2epi(ob)
            self.M_dummy.add(mem_epi)

            assert self.M_dummy.is_kinda_full

        info = {}

        return self.M_dummy.entries, reward, done, info

    def render(self, mode="console"):
        """Print the current memory entries; only "console" mode exists."""
        if mode != "console":
            raise NotImplementedError()
        else:
            print(self.M_dummy.entries)

    def close(self):
        """Nothing to release."""
        pass


In [122]:
from stable_baselines.common.env_checker import check_env

In [None]:
# Build an episodic environment (capacity 4) on the training split and run the
# stable-baselines interface checker against it.
env = MemoryEnv('episodic', 4, 'train')
# If the environment doesn't follow the interface, an error will be thrown.
check_env(env, warn=True)

In [None]:
# Roll out one episode on the validation split, always taking action 0
# (i.e. always forgetting the entry at index 0).
env = MemoryEnv('episodic', 4, 'val')

obs = env.reset()
env.render()

# Show the spaces and one random action for reference.
print(env.observation_space)
print(env.action_space)
print(env.action_space.sample())

# Upper bound on the number of steps; the loop stops earlier once `done`.
n_steps = 1200

for step in range(n_steps):
    print("Step {}".format(step + 1))
    obs, reward, done, info = env.step(0)
    print('obs=', obs, 'reward=', reward, 'done=', done)
    env.render()
    if not done:
        continue
    # Episode finished: report the last reward and stop stepping.
    print("Goal reached!", "reward=", reward)
    break

In [139]:
info

{}

In [132]:
done

False

In [None]:
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import torch

# Load the pretrained 'roberta-base' tokenizer and the matching
# sequence-classification model.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaForSequenceClassification.from_pretrained('roberta-base')

# Tokenize one example sentence, asking for PyTorch tensors ("pt").
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
# Labels are passed along with the inputs; the loss and the logits are then
# read off the returned output object.
outputs = model(**inputs, labels=labels)
loss = outputs.loss
logits = outputs.logits

In [None]:
from memory import EpisodicMemory

import logging

from memory.utils import read_data, select_a_question, load_questions

# Presumably (episodic capacity, semantic capacity) pairs -- only the first
# element is used below. The first two rows previously read "2," and "4, 0),"
# which made the whole cell a SyntaxError.
capacities = [
    (2, 0),
    (4, 0),
    (8, 0),
    (16, 0),
    (32, 0),
    (64, 0),
    (128, 0),
    (256, 0),
    (512, 0),
    (1024, 0),
]

# NOTE(review): `data`, `questions` and `policy` were used by this cell but
# never defined anywhere in the notebook. The paths below are the defaults of
# MemoryEnv; the policy values are picked from the options handled in the loop
# ("oldest"/"random" for forgetting, "latest"/"random" for answering) --
# adjust as needed.
data = read_data("./data/data.json")
questions = load_questions("./data/questions.json")
policy = {"forget": "oldest", "answer": "latest"}

# results[episodic_capacity][split] -> rewards / num_samples / final memories
results = {}

for episodic_capacity, _ in capacities:
    results[episodic_capacity] = {"train": {}, "val": {}, "test": {}}

    for split in ["train", "val", "test"]:

        M_e = EpisodicMemory(episodic_capacity)
        rewards = 0

        for step, ob in enumerate(data[split]):
            mem_epi = M_e.ob2epi(ob)
            M_e.add(mem_epi)
            # Make room again as soon as the memory reports it is over-full.
            if M_e.is_kinda_full:
                if policy["forget"].lower() == "oldest":
                    M_e.forget_oldest()
                elif policy["forget"].lower() == "random":
                    M_e.forget_random()
                else:
                    raise NotImplementedError

            question = select_a_question(step, data, questions, split)

            if policy["answer"].lower() == "latest":
                reward = M_e.answer_latest(question)
            elif policy["answer"].lower() == "random":
                reward = M_e.answer_random(question)
            else:
                raise NotImplementedError

            rewards += reward

        results[episodic_capacity][split]["rewards"] = rewards
        results[episodic_capacity][split]["num_samples"] = len(data[split])
        results[episodic_capacity][split]["episodic_memories"] = M_e.entries

        logging.info(f"results so far: {results}")

logging.info("episodic only training done!")

In [None]:
# episodic only

# capacity