Idea: memorize the $Q$ function à la [Model-Free Episodic Control](https://arxiv.org/abs/1606.04460).

# CMT API

From the paper we have:

1. $(u, z) \leftarrow \text{Query}(x)$ where $z = \{ (x_n, \omega_n) \}$ is an ordered set of retrieved key-value pairs.
1. $\text{Update}(x, (x_n, \omega_n), r, u)$ provides feedback reward $r$ for retrieval of $(x_n, \omega_n)$ for query $x$.
   1. Must be compatible with self-consistency; otherwise supervised and unsupervised updates conflict.
1. $\text{Insert}(x, \omega)$ creates a new memory.

In [321]:
class CMT:
    """Memory tree mapping hashable keys ``x`` to stored values ``omega``.

    Internal nodes hold a router ``g`` (built by ``routerFactory``) whose
    ``g.predict(x) > 0`` sends a key to the right child, otherwise to the
    left.  Leaves hold up to ``c`` memories in a dict; a full leaf is turned
    into an internal node and its memories are re-inserted below it.

    Collaborator protocols, as used by this class:
      * router: ``predict(x) -> number`` and ``update(x, sign, weight)``
      * scorer: ``predict(x, (key, value)) -> sortable`` and
        ``update(x, z, r)``
      * randomState: ``uniform``, ``randint`` (upper bound inclusive, as in
        ``random.Random``), ``choice`` and ``sample``
    """

    class Node:
        """One tree node.

        ``n`` is the number of memories stored at (leaf) or routed through
        (internal node) this node; ``memories`` is only populated on leaves
        and ``g`` only on internal nodes.
        """

        def __init__(self, parent, left=None, right=None, g=None):
            self.parent = parent
            self.isLeaf = left is None
            self.n = 0
            self.memories = {}
            self.left = left
            self.right = right
            self.g = g

        def makeInternal(self, g):
            """Turn this leaf into an internal node routed by ``g``.

            Two fresh empty leaves become the children and the count is reset
            to zero.  Returns the dict of memories the leaf used to hold so
            the caller can re-insert them.
            """
            assert self.isLeaf

            self.isLeaf = False
            self.left = CMT.Node(parent=self)
            self.right = CMT.Node(parent=self)
            self.n = 0
            self.g = g

            mem = self.memories
            self.memories = {}

            return mem

        def replaceNode(self, replacement):
            """Overwrite this node in place with the contents of ``replacement``.

            ``self.parent`` is left untouched; the adopted children are
            re-parented to ``self``.
            """
            if self is not replacement:
                self.isLeaf = replacement.isLeaf
                self.n = replacement.n
                self.memories = replacement.memories
                self.left = replacement.left
                if self.left:
                    self.left.parent = self
                self.right = replacement.right
                if self.right:
                    self.right.parent = self
                self.g = replacement.g

        def topk(self, x, k, f):
            """Return at most ``k`` ``(key, value)`` pairs of this leaf, best first.

            Pairs are ranked by ``f.predict(x, pair)`` in descending order.
            """
            assert self.isLeaf
            ranked = sorted(self.memories.items(),
                            key=lambda z: f.predict(x, z),
                            reverse=True)
            return ranked[:max(k, 0)]

        def randk(self, k, randomState):
            """Return up to ``k`` distinct ``(key, value)`` pairs chosen at random.

            The items view is materialised because ``random.sample`` requires
            a sequence, and ``k`` is capped at the leaf size so that a sparse
            (or empty) leaf yields fewer pairs instead of raising, matching
            ``topk``.
            """
            assert self.isLeaf
            population = list(self.memories.items())
            return randomState.sample(population, min(k, len(population)))

    class Path:
        """Result of routing a key: ``nodes`` is the list of
        ``(internal node, chosen child, probability)`` triples visited and
        ``leaf`` is the leaf reached."""

        def __init__(self, nodes, leaf):
            self.nodes = nodes
            self.leaf = leaf

    def __init__(self, routerFactory, scorer, alpha, c, d, randomState):
        """
        routerFactory: zero-argument callable returning a new router.
        scorer:        scorer used to rank memories inside a leaf.
        alpha:         weight of the balance term (log-ratio of child counts)
                       against the router prediction / reward estimate.
        c:             leaf capacity; inserting into a leaf holding ``c``
                       memories splits it.
        d:             number of random re-routes after each insert/update.
        randomState:   source of randomness (see class docstring).
        """
        self.routerFactory = routerFactory
        self.f = scorer
        self.alpha = alpha
        self.c = c
        self.d = d
        self.leafbykey = {}      # key -> leaf currently holding it
        self.root = CMT.Node(None)
        self.randomState = randomState
        self.allkeys = []        # dense list of keys, for uniform sampling
        self.allkeysindex = {}   # key -> position in allkeys
        self.splitting = False   # guard flag checked in __insertLeaf

    def displaynode(self, node, indent):
        """Recursively print ``node`` and its subtree, one ``*`` per depth level."""
        if node is not None:
            from pprint import pformat
            print(indent, pformat((node, node.__dict__)))
            self.displaynode(node.left, indent + "*")
            self.displaynode(node.right, indent + "*")

    def displaytree(self):
        """Print the whole tree starting at the root."""
        self.displaynode(self.root, indent="")

    def __path(self, x, v):
        """Route ``x`` greedily from node ``v`` down to a leaf."""
        nodes = []
        while not v.isLeaf:
            a = v.right if v.g.predict(x) > 0 else v.left
            nodes.append((v, a, 1))
            v = a

        return CMT.Path(nodes, v)

    def query(self, x, k, epsilon):
        """Retrieve up to ``k`` memories for ``x``.

        Returns ``(u, z)`` where ``z`` is a list of ``(key, value)`` pairs.
        With probability ``1 - epsilon`` the greedy leaf's top-k is returned
        and ``u`` is ``None``.  Otherwise one position on the greedy path is
        picked uniformly:
          * an internal node ``v``: a child ``a`` is chosen uniformly, ``x``
            is routed greedily below it, and ``u == (v, a, 1/2)``;
          * the leaf itself: random memories are returned and
            ``u == (leaf, None, None)``.
        ``u`` is what ``update`` expects back.
        """
        path = self.__path(x, self.root)
        q = self.randomState.uniform(0, 1)
        if q >= epsilon:
            return (None, path.leaf.topk(x, k, self.f))

        # randint is inclusive, so i == len(path.nodes) selects the leaf.
        i = self.randomState.randint(0, len(path.nodes))
        if i < len(path.nodes):
            v = path.nodes[i][0]
            a = self.randomState.choice((v.left, v.right))
            explored = self.__path(x, a)
            return ((v, a, 1/2), explored.leaf.topk(x, k, self.f))

        return ((path.leaf, None, None), path.leaf.randk(k, self.randomState))

    def update(self, u, x, z, r):
        """Feed back reward ``r`` for retrieving ``z`` on query ``x``.

        ``u`` is the first element returned by ``query``; ``None`` (greedy
        retrieval) is a no-op.  For a leaf exploration the scorer is updated
        with ``(x, z, r)``; for a router exploration the router of the
        explored node is updated towards the sign of a mix of the
        importance-weighted reward and the balance term.  Afterwards ``d``
        random memories are re-routed.
        """
        if u is None:
            return

        (v, a, p) = u
        # a is None marks a leaf exploration; v.isLeaf additionally covers a
        # node that no longer has children (and hence no router) by now.
        if a is None or v.isLeaf:
            self.f.update(x, z, r)
        else:
            from math import log

            rhat = (r/p) * (1 if a is v.right else -1)
            y = (1 - self.alpha) * rhat + self.alpha * (log(1e-2 + v.left.n) - log(1e-2 + v.right.n))
            signy = 1 if y > 0 else -1
            absy = signy * y
            v.g.update(x, signy, absy)

        for _ in range(self.d):
            self.__reroute()

    def delete(self, x):
        """Remove the memory stored under key ``x``.

        Walks from the holding leaf to the root decrementing counts.  When an
        internal node's count drops to zero, its parent is overwritten in
        place by that node's sibling.

        NOTE(review): an internal *root* whose count reaches zero fails the
        ``v.parent is not None`` assertion below.
        """
        # deleting something not in the memory ...
        assert x in self.allkeysindex, 'deleting something not in the memory'

        # Swap-remove x from the dense key list.
        ind = self.allkeysindex.pop(x)
        lastx = self.allkeys.pop()
        if ind < len(self.allkeys):
            self.allkeys[ind] = lastx
            self.allkeysindex[lastx] = ind

        v = self.leafbykey.pop(x)

        while v is not None:
            try:
                assert v.left is None or v.isLeaf is False
                assert v.left is None or v.left.parent is v
                assert v.right is None or v.right.parent is v
            except Exception:
                from pprint import pformat
                print(pformat((v, v.__dict__)))
                raise

            v.n -= 1
            if v.isLeaf:
                try:
                    v.memories.pop(x)
                except Exception:
                    print('x is {}'.format(x))
                    print('v.memories is {}'.format(v.memories))
                    raise
            else:
                if v.n == 0:
                    try:
                        assert v.parent is not None

                        assert v is v.parent.left or v is v.parent.right
                        other = v.parent.left if v is v.parent.right else v.parent.right
                        assert v is not other
                        assert v.parent is not other
                        assert v.parent is other.parent
                        # v.parent.n has not been decremented yet, hence the 1.
                        assert 1 + other.n == v.parent.n, '{} + {} ?= {}'.format(v.n, other.n, v.parent.n)
                        if other.isLeaf:
                            # The sibling's memories will live in v.parent.
                            for xprime in other.memories.keys():
                                self.leafbykey[xprime] = v.parent

                        assert v.parent.left is None or v.parent.isLeaf is False
                        assert v.parent.left is None or v.parent.left.parent is v.parent
                        assert v.parent.right is None or v.parent.right.parent is v.parent

                        v.parent.replaceNode(other)

                        assert v.parent.left is None or v.parent.isLeaf is False
                        assert v.parent.left is None or v.parent.left.parent is v.parent
                        assert v.parent.right is None or v.parent.right.parent is v.parent

                        # The parent's count was just set to other.n, so step
                        # onto it here and past it below.
                        v = v.parent
                    except Exception:
                        from pprint import pformat
                        print(pformat((v, v.__dict__)))
                        raise

            assert v.n >= 0
            v = v.parent

    def __insertLeaf(self, x, omega, v):
        """Store ``(x, omega)`` in leaf ``v``, splitting it first if it is full.

        NOTE(review): a split that would itself trigger another split (all
        re-inserted memories landing in one child) trips the
        ``not self.splitting`` assertion.
        """
        assert v.isLeaf

        if x not in self.allkeysindex:
            self.allkeysindex[x] = len(self.allkeys)
            self.allkeys.append(x)

        if v.n < self.c:
            assert x not in self.leafbykey
            self.leafbykey[x] = v
            assert x not in v.memories
            v.memories[x] = omega
            v.n += 1
            assert v.n == len(v.memories)
        else:
            assert not self.splitting
            self.splitting = True
            mem = v.makeInternal(g=self.routerFactory())

            # Re-insert the old memories, then the new one, below v without
            # triggering re-routes (d=0).
            while mem:
                xprime, omegaprime = mem.popitem()
                del self.leafbykey[xprime]
                self.insert(xprime, omegaprime, v, 0)

            self.insert(x, omega, v, 0)

        assert v.parent is None or (v is v.parent.left or v is v.parent.right)

    def insert(self, x, omega, v=None, d=None):
        """Insert key ``x`` with value ``omega``.

        ``v`` is the node to start routing from (default: the root) and ``d``
        overrides the number of re-routes performed afterwards (default:
        ``self.d``).  On the way down every router is updated towards the
        sign of a mix of its own prediction and the balance term, then
        consulted again to pick the child.  Duplicate keys are not supported.
        """
        from math import log

        # duplicate memory ... need to merge values ...
        assert x not in self.leafbykey, 'duplicate memory ... need to merge values'

        if v is None:
            v = self.root
            self.splitting = False

        while not v.isLeaf:
            assert v.parent is None or (v is v.parent.left or v is v.parent.right)
            assert v.left is None or v.left.parent is v
            assert v.right is None or v.right.parent is v

            B = log(1e-2 + v.left.n) - log(1e-2 + v.right.n)
            y = (1 - self.alpha) * v.g.predict(x) + self.alpha * B
            signy = 1 if y > 0 else -1
            v.g.update(x, signy, 1)
            v.n += 1
            v = v.right if v.g.predict(x) > 0 else v.left

        self.__insertLeaf(x, omega, v)

        for _ in range(self.d if d is None else d):
            self.__reroute()

    def __reroute(self):
        """Delete a uniformly chosen memory and insert it again.

        Does nothing when the memory is empty.  The surrounding assertions
        check the count and parent-pointer invariants after each step.
        """
        if not self.allkeys:
            return

        x = self.randomState.choice(self.allkeys)
        omega = self.leafbykey[x].memories[x]
        assert self.root.n == len(self.leafbykey), '{} ?= {}'.format(self.root.n, len(self.leafbykey))
        assert self.root.left is None or self.root.left.parent is self.root
        assert self.root.right is None or self.root.right.parent is self.root
        self.delete(x)
        assert self.root.left is None or self.root.left.parent is self.root
        assert self.root.right is None or self.root.right.parent is self.root
        assert self.root.n == len(self.leafbykey), '{} ?= {}'.format(self.root.n, len(self.leafbykey))
        self.insert(x, omega, d=0)
        assert self.root.n == len(self.leafbykey), '{} ?= {}'.format(self.root.n, len(self.leafbykey))
        assert self.root.left is None or self.root.left.parent is self.root
        assert self.root.right is None or self.root.right.parent is self.root

        for key, leaf in self.leafbykey.items():
            assert key in leaf.memories

In [320]:
class CMTTests:
    """Smoke tests for CMT together with the router and scorer they use."""

    class LinearModel:
        """Router backed by scikit-learn's ``SGDRegressor``, trained online."""

        def __init__(self, *args, **kwargs):
            from sklearn import linear_model

            self.model = linear_model.SGDRegressor(*args, **kwargs)

        def predict(self, x):
            """Return the model's prediction for ``x``, or 0 before any update."""
            from sklearn.exceptions import NotFittedError
            try:
                return self.model.predict(X=[x])[0]
            except NotFittedError:
                return 0

        def update(self, x, y, w):
            """Take one ``partial_fit`` step on example ``x`` with target ``y`` and weight ``w``."""
            self.model.partial_fit(X=[x], y=[y], sample_weight=[w])

    class NormalizedLinearProduct:
        """Fixed scorer: cosine similarity between the query and a memory's key."""

        def __init__(self):
            pass

        def predict(self, x, z):
            """Return the cosine similarity between ``x`` and the key of ``z``.

            ``z`` is a ``(key, value)`` pair; the value is ignored.
            """
            import numpy as np
            from math import sqrt

            (xprime, omegaprime) = z

            xa = np.array(x)
            xprimea = np.array(xprime)

            return np.inner(xa, xprimea) / sqrt(np.inner(xa, xa) * np.inner(xprimea, xprimea))

        def update(self, x, y, w):
            """No-op: this scorer has nothing to learn."""
            pass

    @staticmethod
    def selfconsistent():
        """Check that a memory just inserted is the top result for its own key.

        Inserts 200 random 3-dimensional keys with random values and, after
        each insert, queries greedily (``epsilon=0``) with ``k=1``.  On any
        failure the tree is dumped before the exception is re-raised.
        """
        import random

        routerFactory = lambda: CMTTests.LinearModel()
        scorer = CMTTests.NormalizedLinearProduct()
        randomState = random.Random()
        randomState.seed(45)
        cmt = CMT(routerFactory=routerFactory, scorer=scorer, alpha=0.5, c=10, d=0, randomState=randomState)

        for _ in range(200):
            try:
                x = tuple(randomState.uniform(0, 1) for _ in range(3))
                omega = randomState.uniform(0, 1)

                cmt.insert(x, omega)
                u, [ (xprime, omegaprime) ] = cmt.query(x, k=1, epsilon=0)
                assert omega == omegaprime, '({}, [({}, {})]) = cmt.query({}) != {}'.format(u, xprime, omegaprime, x, omega)
            except Exception:
                print("--------------")
                cmt.displaytree()
                print("--------------")
                raise

        print('selfconsistent test pass')
            
# Run the self-consistency check; it is a static method, so no instance is needed.
CMTTests.selfconsistent()

selfconsistent test pass


# Memorized Q pseudocode -- Attempt 3

Basic idea:
* Estimate value of $a$ in context $x$ by stored value associated with first memory retrieved from CMT queried with $(x, a)$.
* Play $\epsilon$-greedy with greedy action being the maximum estimated value. 
* Play action $a$ in context $x$ and observe reward $r$.
* Reward memory system just like a parametric direct method, i.e., using regression loss such as squared loss.
   * Update the memory $((x', a'), r')$ retrieved by query $(x, a)$ using reward $-(r - r')^2$.
* Insert key $(x, a)$ with value $r$.
* Conjecture: compatible with self-consistency assuming no reward variance.
   * Update reward is maximized by retrieving a memory with $r = r'$.
   * Exact match response does this.
   * Censorship issue: only argmax key is updated, does this matter?

# Memorized Q pseudocode -- Attempt 2

Basic idea:
* Estimate value of $a$ in context $x$ by stored value associated with first memory retrieved from CMT queried with $(x, a)$.
* Play $\epsilon$-greedy with greedy action being the maximum estimated value. 
* Play action $a$ in context $x$ and observe reward $r$.
* For each action $a'$, update the memory retrieved with query $(x, a')$ using the observed reward as feedback reward.
* Insert key $(x, a)$ with value $r$.
* Conjecture: compatible with self-consistency assuming no reward variance.
   * Update reward is maximized by identifying the correct argmax.
   * Exact match responses to all queries identify the correct argmax.
   * No censorship issues: all retrieved keys are updated.

# Memorized Q pseudocode

Basic idea:
* Estimate value of $a$ in context $x$ by stored value associated with first memory retrieved from CMT queried with $(x, a)$.
* Play $\epsilon$-greedy with greedy action being the maximum estimated value. 
* After playing action $a$ in context $x$ and observing reward $r$:
   * Update the memory retrieved with query $(x, a)$ using feedback reward of $r$.
   * Store memory with key $(x, a)$ and value $r$.
* Conjecture: not compatible with self-consistency because of update frequency issues.
   * By returning a non-exact match with high stored reward, a key can capture more updates.
   * Counter argument: the memory system as a whole receives largest possible reward if argmax is correct, which exact match ensures.
   * Counter Counter argument: but reward is associated to particular keys with different frequency, does that matter?
   * **Confused**

In [5]:
def MemorizedQ():
    """Pseudocode sketch of epsilon-greedy control with a memorized Q function.

    Each round: query the memory once per action with key ``(x, a)``, treat
    the value of the first retrieved memory as the estimate for that action,
    play epsilon-greedy on those estimates, then update the retrieved memory
    for the played action and insert ``((x, a), r)``.

    This is a sketch, not runnable code: ``Environment``, the distribution
    helpers and the ``Query``/``Update``/``Insert`` memory API are
    placeholders, and the feedback reward is still an open question.
    """
    mem = CMT()
    env = Environment()  # distribution over (x, r) pairs.

    Actions = set(...)   # fixed set of actions for now
    epsilon = ...        # epsilon-greedy exploration

    while True:
        x = env.Observe()
        # action -> (u, first retrieved memory), for actions that retrieved anything
        querySet = { a: (u, ((xprime, aprime), rprime))
                     for a in Actions
                     for (u, z) in [ mem.Query(key=(x, a)) ]
                     if len(z) > 0
                     for ((xprime, aprime), rprime) in [ z[0] ]
                   }
        if len(querySet) > 0:
            greedy, _ = max(querySet.items(), key=lambda kv: kv[1][1][1]) # action with largest first retrieved reward
        else:
            greedy = next(iter(Actions))                                  # if memory is completely empty, play an arbitrary action

        pa = (1 - epsilon) * IndicatorDistribution(greedy) + epsilon * UniformDistribution(Actions)
        a = pa.sample()
        r = env.ObserveReward(a)

        if a in querySet:
            # question: what's the feedback reward?
            # question: do we only do this when we take the greedy action?

            u, ((xprime, aprime), rprime) = querySet[a]
            mem.Update(key=(x, a), retrieved=((xprime, aprime), rprime), feedbackreward=None, u=u)

        mem.Insert(key=(x, a), value=r)

### Is this compatible with self-consistency?

I'm not sure.  Suppose there is no reward variance, so we are just dealing with the partial feedback issue.
* Any memory retrieved when querying on $(x, a)$ will be updated with feedback reward $r$.
* Conditional on calling `Update()`, feedback reward is constant.
* Except that some retrieved memories will "win the argmax" and some will lose, changing frequency of `Update()`.
* Consider the memory retrieved by `Query(key=(x, a))`.
   * Possible inserted $((x, a), r)$ pair will lose the argmax after additional inserts.
   * This could be appropriate as another action $a'$ might be better in a neighborhood of $x$ but hadn't been observed yet.
   * However retrieving $((x'', a''), r'')$ with $r'' > r$ would win the argmax and receive reward $r$.
   
**Idea**: this could be self-consistent if we update all the actions.