Idea: memorize the $Q$ function in the style of [Model-Free Episodic Control](https://arxiv.org/abs/1606.04460).

# CMT API

From the paper we have:

1. $(u, z) \leftarrow \text{Query}(x)$ where $z = \{ (x_n, \omega_n) \}$ is an ordered set of retrieved key-value pairs.
1. $\text{Update}(x, (x_n, \omega_n), r, u)$ provides feedback reward $r$ for retrieval of $(x_n, \omega_n)$ for query $x$.
   1. Must be compatible with self-consistency; otherwise supervised and unsupervised updates conflict.
1. $\text{Insert}(x, \omega)$ creates a new memory.

In [372]:
class CMT:
    """Contextual Memory Tree: a binary tree of routers over key/value memories.

    Internal nodes hold a router ``g``; ``g.predict(x) > 0`` sends ``x`` to the
    right child, anything else to the left child.  Leaves hold memories in a
    ``{key: value}`` dict.  Keys are used as dictionary keys throughout, so they
    must be hashable.
    """

    class Node:
        """A tree node: a leaf when created without children, otherwise internal."""

        def __init__(self, parent, left=None, right=None, g=None):
            self.parent = parent
            self.isLeaf = left is None
            self.n = 0            # number of memories stored at or below this node
            self.memories = {}    # key -> value; only leaves store entries here
            self.left = left
            self.right = right
            self.g = g            # router; only meaningful on internal nodes

        def makeInternal(self, g):
            """Turn this leaf into an internal node routed by ``g``.

            Two fresh empty leaves become the children and the count is reset.
            Returns the memories the leaf used to hold so the caller can
            re-insert them below this node.
            """
            assert self.isLeaf

            self.isLeaf = False
            self.left = CMT.Node(parent=self)
            self.right = CMT.Node(parent=self)
            self.n = 0
            self.g = g

            mem = self.memories
            self.memories = {}

            return mem

        def makeLeaf(self):
            """Reset this node to an empty leaf, dropping its children and router."""
            self.isLeaf = True
            self.n = 0
            self.memories = {}
            self.left = None
            self.right = None
            self.g = None

        def replaceNode(self, replacement):
            """Overwrite this node's contents with ``replacement``'s, keeping our own parent.

            The adopted children get their ``parent`` pointer redirected to this node.
            """
            if self is replacement:
                return

            self.isLeaf = replacement.isLeaf
            self.n = replacement.n
            self.memories = replacement.memories
            self.left = replacement.left
            if self.left:
                self.left.parent = self
            self.right = replacement.right
            if self.right:
                self.right.parent = self
            self.g = replacement.g

        def topk(self, x, k, f):
            """Return up to ``k`` ``(key, value)`` pairs of this leaf, best ``f.predict(x, pair)`` first."""
            assert self.isLeaf
            ranked = sorted(self.memories.items(),
                            key=lambda z: f.predict(x, z),
                            reverse=True)
            return ranked[:max(k, 0)]

        def randk(self, k, randomState):
            """Return up to ``k`` distinct ``(key, value)`` pairs of this leaf, chosen at random.

            Fewer than ``k`` pairs are returned when the leaf holds fewer than ``k``.
            """
            assert self.isLeaf
            # random.sample needs a real sequence (dict views are rejected by
            # recent Pythons) and a sample size no larger than the population.
            population = list(self.memories.items())
            return randomState.sample(population, max(0, min(k, len(population))))

    class Path:
        """Outcome of routing: the internal nodes visited, in order, and the leaf reached."""

        def __init__(self, nodes, leaf):
            self.nodes = nodes
            self.leaf = leaf

    def __init__(self, routerFactory, scorer, alpha, c, d, randomState):
        """Create an empty tree.

        routerFactory: zero-argument callable returning a router that offers
            ``predict(x)`` and ``update(x, y, w)``.
        scorer: object offering ``predict(x, (key, value))`` and ``update(x, z, r)``.
        alpha: weight of the balance term (log-ratio of child counts) against
            the router/reward term when forming a router's training signal.
        c: leaf capacity; inserting into a leaf that already holds ``c``
            memories splits it.
        d: number of random reroutes performed after each insert/update.
        randomState: ``random.Random``-like object; ``uniform``, ``randint``,
            ``choice`` and ``sample`` are used.
        """
        self.routerFactory = routerFactory
        self.f = scorer
        self.alpha = alpha
        self.c = c
        self.d = d
        self.leafbykey = {}       # key -> leaf currently holding it
        self.root = CMT.Node(None)
        self.randomState = randomState
        self.allkeys = []         # dense list of keys, for uniform random selection
        self.allkeysindex = {}    # key -> position in allkeys
        self.splitting = False    # True while a leaf split is re-inserting memories

    def nodeForeach(self, f, node=None):
        """Call ``f`` on every node below (and including) ``node``, parent before children, left before right."""
        stack = [self.root if node is None else node]
        while stack:
            current = stack.pop()
            f(current)
            # Push right first so that the left subtree is visited first.
            if current.right:
                stack.append(current.right)
            if current.left:
                stack.append(current.left)

    def __path(self, x, v):
        """Route ``x`` from node ``v`` down to a leaf and return the resulting Path."""
        nodes = []
        while not v.isLeaf:
            a = v.right if v.g.predict(x) > 0 else v.left
            nodes.append(v)
            v = a

        return CMT.Path(nodes, v)

    def query(self, x, k, epsilon):
        """Retrieve up to ``k`` memories for ``x``; returns ``(u, z)``.

        With probability ``1 - epsilon`` the greedy path is followed and ``u``
        is None.  Otherwise one position on the path is explored:
        * an internal node: a child is chosen uniformly, routing continues
          greedily from it, and ``u = (node, child, 1/2)``;
        * the leaf: ``k`` memories are drawn at random and
          ``u = (leaf, None, None)``.
        ``z`` is a list of ``(key, value)`` pairs.
        """
        path = self.__path(x, self.root)
        q = self.randomState.uniform(0, 1)
        if q >= epsilon:
            return (None, path.leaf.topk(x, k, self.f))

        # randint is inclusive at both ends: index len(path.nodes) means "the leaf".
        i = self.randomState.randint(0, len(path.nodes))
        if i < len(path.nodes):
            a = self.randomState.choice( (path.nodes[i].left, path.nodes[i].right) )
            l = self.__path(x, a).leaf
            return ((path.nodes[i], a, 1/2), l.topk(x, k, self.f))

        return ((path.leaf, None, None), path.leaf.randk(k, self.randomState))

    def update(self, u, x, z, r):
        """Feed reward ``r`` back for the retrieval ``z`` of query ``x``.

        ``u`` is the first element returned by ``query``; nothing happens when
        it is None (greedy retrieval).  Leaf exploration updates the scorer,
        router exploration updates that router.  ``d`` reroutes follow.
        """
        if u is None:
            return

        (v, a, p) = u
        if v.isLeaf:
            self.f.update(x, z, r)
        else:
            from math import log

            # Importance-weighted reward, signed by the direction that was explored.
            rhat = (r/p) * (1 if a is v.right else -1)
            y = (1 - self.alpha) * rhat + self.alpha * (log(1e-2 + v.left.n) - log(1e-2 + v.right.n))
            signy = 1 if y > 0 else -1
            absy = signy * y
            v.g.update(x, signy, absy)

        for _ in range(self.d):
            self.__reroute()

    def delete(self, x):
        """Remove the memory stored under key ``x``.

        Raises AssertionError when ``x`` is not in the memory.
        """
        assert x in self.allkeysindex, 'deleting something not in the memory'

        # Swap-remove x from the dense key list.
        ind = self.allkeysindex.pop(x)
        lastx = self.allkeys.pop()
        if ind < len(self.allkeys):
            self.allkeys[ind] = lastx
            self.allkeysindex[lastx] = ind

        v = self.leafbykey.pop(x)

        while v is not None:
            v.n -= 1
            if v.isLeaf:
                del v.memories[x]
            elif v.n == 0:
                if v.parent is None:
                    # The internal root just became empty: there is no parent
                    # to splice a sibling into, so fall back to an empty leaf.
                    v.makeLeaf()
                else:
                    # An empty internal subtree is spliced out: its parent
                    # takes over the contents of the sibling.
                    other = v.parent.left if v is v.parent.right else v.parent.right
                    if other.isLeaf:
                        for xprime in other.memories:
                            self.leafbykey[xprime] = v.parent

                    v.parent.replaceNode(other)
                    # The parent's count now equals the sibling's, which is
                    # already its post-delete value; skip decrementing it.
                    v = v.parent

            assert v.n >= 0
            v = v.parent

    def __insertLeaf(self, x, omega, v):
        """Store ``(x, omega)`` in leaf ``v``, splitting the leaf first when it is full."""
        assert v.isLeaf

        if x not in self.allkeysindex:
            self.allkeysindex[x] = len(self.allkeys)
            self.allkeys.append(x)

        if v.n < self.c:
            assert x not in self.leafbykey
            self.leafbykey[x] = v
            assert x not in v.memories
            v.memories[x] = omega
            v.n += 1
            assert v.n == len(v.memories)
        else:
            # A split that itself triggers another split is not supported.
            assert not self.splitting
            self.splitting = True
            mem = v.makeInternal(g=self.routerFactory())

            while mem:
                xprime, omegaprime = mem.popitem()
                del self.leafbykey[xprime]
                self.insert(xprime, omegaprime, v, 0)

            self.insert(x, omega, v, 0)

    def insert(self, x, omega, v=None, d=None):
        """Insert memory ``(x, omega)``.

        ``v`` is the node to start routing from (the root by default) and ``d``
        overrides the number of reroutes performed afterwards; both are used
        internally while splitting.  Every router on the way down is trained
        towards the sign of its own prediction mixed with the balance term.

        Raises AssertionError when ``x`` is already stored.
        """
        from math import log

        assert x not in self.leafbykey, 'duplicate memory ... need to merge values'

        if v is None:
            v = self.root
            self.splitting = False

        while not v.isLeaf:
            B = log(1e-2 + v.left.n) - log(1e-2 + v.right.n)
            y = (1 - self.alpha) * v.g.predict(x) + self.alpha * B
            signy = 1 if y > 0 else -1
            v.g.update(x, signy, 1)
            v.n += 1
            # Route with the router as it stands after the update above.
            v = v.right if v.g.predict(x) > 0 else v.left

        self.__insertLeaf(x, omega, v)

        for _ in range(self.d if d is None else d):
            self.__reroute()

    def __reroute(self):
        """Delete one randomly chosen memory and insert it again; a no-op on an empty memory."""
        if not self.allkeys:
            return

        x = self.randomState.choice(self.allkeys)
        omega = self.leafbykey[x].memories[x]
        self.delete(x)
        self.insert(x, omega, d=0)

        # Debug-only consistency check, written as one assert so that it costs
        # nothing when assertions are disabled.
        assert all(k in leaf.memories for k, leaf in self.leafbykey.items())

In [373]:
class CMTTests:
    """Self-tests for CMT: structural invariants and self-consistency of retrieval."""

    class LinearModel:
        """Router backed by sklearn's SGDRegressor, trained one example at a time."""

        def __init__(self, *args, **kwargs):
            from sklearn import linear_model

            self.model = linear_model.SGDRegressor(*args, **kwargs)

        def predict(self, x):
            """Return the regressor's prediction for ``x``, or 0 before the first update."""
            from sklearn.exceptions import NotFittedError
            try:
                return self.model.predict(X=[x])[0]
            except NotFittedError:
                return 0

        def update(self, x, y, w):
            """Take one SGD step on example ``(x, y)`` with sample weight ``w``."""
            self.model.partial_fit(X=[x], y=[y], sample_weight=[w])

    class NormalizedLinearProduct:
        """Scorer returning the cosine similarity between the query and a memory's key."""

        def __init__(self):
            pass

        def predict(self, x, z):
            """Score memory ``z = (key, value)`` against query ``x``; the value is ignored."""
            import numpy as np
            from math import sqrt

            (xprime, omegaprime) = z

            xa = np.array(x)
            xprimea = np.array(xprime)

            return np.inner(xa, xprimea) / sqrt(np.inner(xa, xa) * np.inner(xprimea, xprimea))

        def update(self, x, y, w):
            """This scorer has nothing to learn; feedback is ignored."""
            pass

    @staticmethod
    def displaynode(node, indent):
        """Print ``node`` and its subtree, one extra ``*`` of indent per level."""
        if node is not None:
            from pprint import pformat
            print(indent, pformat((node, node.__dict__)))
            CMTTests.displaynode(node.left, indent + "*")
            CMTTests.displaynode(node.right, indent + "*")

    @staticmethod
    def displaytree(cmt):
        """Print the whole tree of ``cmt``."""
        CMTTests.displaynode(cmt.root, indent="")

    @staticmethod
    def structureValid():
        """Apply 200 random inserts/deletes/query-updates and check the tree's bookkeeping after each."""
        import random

        routerFactory = lambda: CMTTests.LinearModel()
        scorer = CMTTests.NormalizedLinearProduct()
        randomState = random.Random()
        randomState.seed(2112)
        cmt = CMT(routerFactory=routerFactory, scorer=scorer, alpha=0.5, c=10, d=0, randomState=randomState)

        def checkNodeInvariants(node):
            assert node.parent is None or node.parent.left is node or node.parent.right is node
            assert node.left is None or node.n == node.left.n + node.right.n
            assert node.left is None or node.left.parent is node
            assert node.right is None or node.right.parent is node
            assert node.left is not None or node.n == len(node.memories)

        stuff = {}   # reference copy of what the tree should contain

        for _ in range(200):
            try:
                if stuff and randomState.uniform(0, 1) < 0.1:
                    # delete
                    x, omega = stuff.popitem()
                    cmt.delete(x)
                elif stuff and randomState.uniform(0, 1) < 0.1:
                    # query/update, using the key that was just drawn
                    somex = randomState.choice(list(stuff.keys()))
                    u, z = cmt.query(somex, 1, 0.1)
                    cmt.update(u, somex, z, randomState.uniform(0, 1))
                else:
                    # insert
                    x = tuple([ randomState.uniform(0, 1) for _ in range(3)])
                    omega = randomState.uniform(0, 1)
                    cmt.insert(x, omega)
                    stuff[x] = omega

                assert cmt.root.n == len(stuff)
                assert cmt.root.n == len(cmt.leafbykey)
                assert cmt.root.n == len(cmt.allkeys)
                assert cmt.root.n == len(cmt.allkeysindex)
                for key in stuff:
                    assert key in cmt.leafbykey[key].memories
                    assert key in cmt.allkeysindex
                    assert cmt.allkeys[cmt.allkeysindex[key]] is key
                cmt.nodeForeach(checkNodeInvariants)
            except Exception:
                # Dump the tree to make the failure debuggable, then propagate it.
                print("--------------")
                CMTTests.displaytree(cmt)
                print("--------------")
                raise

        print('structureValid test pass')

    @staticmethod
    def selfconsistent():
        """Check that querying a key right after inserting it retrieves that key's value."""
        import random

        routerFactory = lambda: CMTTests.LinearModel()
        scorer = CMTTests.NormalizedLinearProduct()
        randomState = random.Random()
        randomState.seed(45)
        cmt = CMT(routerFactory=routerFactory, scorer=scorer, alpha=0.5, c=10, d=0, randomState=randomState)

        for _ in range(200):
            try:
                x = tuple([ randomState.uniform(0, 1) for _ in range(3)])
                omega = randomState.uniform(0, 1)

                cmt.insert(x, omega)
                u, [ (xprime, omegaprime) ] = cmt.query(x, k=1, epsilon=0)
                assert omega == omegaprime, '({}, [({}, {})]) = cmt.query({}) != {}'.format(u, xprime, omegaprime, x, omega)
            except Exception:
                # Dump the tree to make the failure debuggable, then propagate it.
                print("--------------")
                CMTTests.displaytree(cmt)
                print("--------------")
                raise

        print('selfconsistent test pass')

    @staticmethod
    def all():
        """Run every test in this class."""
        CMTTests.structureValid()
        CMTTests.selfconsistent()

# Run both CMT self-tests; each prints a "... test pass" line when it succeeds.
CMTTests().all()

structureValid test pass
selfconsistent test pass


# Fully Observed Covertype

## Quadratic Classifier

In [386]:
class FOC:
    """Fully observed Covertype baseline: online logistic regression on quadratic features."""

    class EasyAcc:
        """Running mean; ``acc += value`` adds one observation."""

        def __init__(self):
            self.n = 0
            self.sum = 0

        def __iadd__(self, other):
            self.n += 1
            self.sum += other
            return self

        def mean(self):
            """Return the mean of the observations so far (0 when there are none)."""
            return self.sum / max(self.n, 1)

    @staticmethod
    def doit():
        """Train on 90% of Covertype in blocks of 32, reporting progressive loss, then score the held-out 10%.

        Progressive loss is measured by predicting each block before training
        on it.  Test accuracy is reported as the 5%/50%/95% quantiles over 16
        bootstrap resamples of the test set.
        """
        from collections import Counter
        from sklearn.datasets import fetch_covtype
        from sklearn.decomposition import PCA
        from sklearn.linear_model import SGDClassifier
        from sklearn.metrics import accuracy_score
        from math import ceil
        import numpy as np

        def quadratic(t):
            # All pairwise products of t with (t, 1): quadratic plus linear terms.
            return np.outer(t, np.append(t, [1])).ravel()

        cov = fetch_covtype()
        cov.data = PCA(whiten=True).fit_transform(cov.data)
        classes = np.unique(cov.target - 1)
        print(Counter(cov.target - 1))
        ndata = len(cov.target)
        order = np.random.RandomState(seed=42).permutation(ndata)
        ntrain = ceil(0.9 * ndata)
        Object = lambda **kwargs: type("Object", (), kwargs)()
        train = Object(data = cov.data[order[:ntrain]], target = cov.target[order[:ntrain]] - 1)
        test = Object(data = cov.data[order[ntrain:]], target = cov.target[order[ntrain:]] - 1)

        blocksize = 32
        for lr in [0.1]:
            print("*** lr = {} ***".format(lr), flush=True)
            print('{:8.8s}\t{:8.8s}\t{:10.10s}\t{:10.10s}'.format(
                'n', 'emp loss', 'since last', 'pred')
            )

            # NOTE(review): newer scikit-learn releases spell this loss
            # 'log_loss' -- confirm against the installed version.
            cls = SGDClassifier(loss='log', shuffle=False, learning_rate='invscaling', eta0=lr)
            loss = FOC.EasyAcc()
            sincelast = FOC.EasyAcc()

            for pno in range(1):
                order = np.random.RandomState(seed=42+pno).permutation(len(train.data))
                # zip over one shared iterator yields consecutive full blocks;
                # a trailing partial block is dropped.
                for n, ind in enumerate(zip(*(iter(order),)*blocksize)):
                    v = np.array([ quadratic(train.data[z]) for z in ind ])
                    actual = [ train.target[z] for z in ind ]
                    if n == 0:
                        # Nothing has been learned yet: score the first block as
                        # if class 0 had been predicted.
                        for a in actual:
                            loss += 0 if a == 0 else 1
                            sincelast += 0 if a == 0 else 1
                    else:
                        pred = cls.predict(v)
                        for p, a in zip(pred, actual):
                            loss += 0 if p == a else 1
                            sincelast += 0 if p == a else 1

                        # Report whenever the block index is a power of two.
                        if (n & (n - 1) == 0): # and n & 0xAAAAAAAA == 0):
                            print('{:<8d}\t{:<8.3f}\t{:<10.3f}\t{:<10d}'.format(
                                        loss.n, loss.mean(), sincelast.mean(), pred[-1]),
                                  flush=True)

                            sincelast = FOC.EasyAcc()

                    cls.partial_fit(v, actual, classes=classes)

                print('{:<8d}\t{:<8.3f}\t{:<10.3f}\t{:<10d}'.format(
                             loss.n, loss.mean(), sincelast.mean(), pred[-1]),
                       flush=True)
                sincelast = FOC.EasyAcc()

                preds = cls.predict(np.array([ quadratic(d) for d in test.data ]))
                print(Counter(preds))
                ascores = []
                for b in range(16):
                    bootie = np.random.RandomState(90210+b).choice(len(test.target), replace=True, size=len(test.target))
                    ascores.append(accuracy_score(y_true=test.target[bootie], y_pred=preds[bootie]))

                print("test accuracy: {}".format(np.quantile(ascores, [0.05, 0.5, 0.95])))
                        
# Run the quadratic-feature SGD baseline on Covertype; prints the progress table and test accuracy.
FOC.doit()

Counter({1: 283301, 0: 211840, 2: 35754, 6: 20510, 5: 17367, 4: 9493, 3: 2747})
*** lr = 0.1 ***
n       	emp loss	since last	pred      
64      	0.531   	0.531     	6         
96      	0.542   	0.562     	1         
160     	0.494   	0.422     	1         
288     	0.472   	0.445     	1         
544     	0.428   	0.379     	0         
1056    	0.407   	0.385     	1         
2080    	0.371   	0.333     	0         
4128    	0.352   	0.333     	2         
8224    	0.325   	0.298     	0         
16416   	0.307   	0.289     	0         
32800   	0.286   	0.265     	0         
65568   	0.273   	0.260     	1         
131104  	0.263   	0.253     	0         
262176  	0.257   	0.251     	0         
522880  	0.252   	0.248     	0         
Counter({1: 29667, 0: 20845, 2: 3529, 5: 1750, 6: 1703, 4: 320, 3: 287})
test accuracy: [0.74952669 0.75278395 0.75523227]


## Contextual Memory Tree

In [None]:
class FOC:
    """Fully observed Covertype experiment using a Contextual Memory Tree as the predictor."""

    class EasyAcc:
        """Running mean; ``acc += value`` adds one observation."""

        def __init__(self):
            self.n = 0
            self.sum = 0

        def __iadd__(self, other):
            self.n += 1
            self.sum += other
            return self

        def mean(self):
            """Return the mean of the observations so far (0 when there are none)."""
            return self.sum / max(self.n, 1)

    class LogisticModel:
        """Router backed by sklearn's SGDClassifier with logistic loss over classes -1/+1."""

        def __init__(self, *args, **kwargs):
            from sklearn import linear_model

            # NOTE(review): newer scikit-learn releases spell this loss
            # 'log_loss' -- confirm against the installed version.
            kwargs['loss'] = 'log'
            self.model = linear_model.SGDClassifier(*args, **kwargs)

        def predict(self, x):
            """Return the predicted class for ``x``, or 0 before the first update."""
            from sklearn.exceptions import NotFittedError
            try:
                return self.model.predict(X=[x])[0]
            except NotFittedError:
                return 0

        def update(self, x, y, w):
            """Take one SGD step on example ``(x, y)`` with sample weight ``w``."""
            self.model.partial_fit(X=[x], y=[y], sample_weight=[w], classes=(-1, 1))

    class LinearModel:
        """Router backed by sklearn's SGDRegressor, trained one example at a time."""

        def __init__(self, *args, **kwargs):
            from sklearn import linear_model

            self.model = linear_model.SGDRegressor(*args, **kwargs)

        def predict(self, x):
            """Return the regressor's prediction for ``x``, or 0 before the first update."""
            from sklearn.exceptions import NotFittedError
            try:
                return self.model.predict(X=[x])[0]
            except NotFittedError:
                return 0

        def update(self, x, y, w):
            """Take one SGD step on example ``(x, y)`` with sample weight ``w``."""
            self.model.partial_fit(X=[x], y=[y], sample_weight=[w])

    class EuclideanDistance:
        """Scorer returning the negated Euclidean distance between the query and a memory's key."""

        def __init__(self):
            pass

        def predict(self, x, z):
            """Score memory ``z = (key, value)`` against query ``x``; closer keys score higher."""
            import numpy as np

            (xprime, omegaprime) = z

            return -np.linalg.norm(np.array(x) - np.array(xprime))

        def update(self, x, y, w):
            """This scorer has nothing to learn; feedback is ignored."""
            pass

    @staticmethod
    def doit():
        """Stream 90% of Covertype through a CMT, reporting progressive loss, then score the held-out 10%.

        Each example is predicted with the value of the first retrieved memory
        (class 0 when nothing is retrieved).  Test accuracy is reported as the
        5%/50%/95% quantiles over 16 bootstrap resamples of the test set.
        """
        from collections import Counter
        from sklearn.datasets import fetch_covtype
        from sklearn.decomposition import PCA
        from sklearn.metrics import accuracy_score
        from math import ceil
        import numpy as np
        import random

        def quadratic(t):
            # All pairwise products of t with (t, 1), as a hashable tuple so it
            # can serve as a memory key.
            return tuple(np.outer(t, np.append(t, [1])).ravel())

        cov = fetch_covtype()
        cov.data = PCA(whiten=True).fit_transform(cov.data)
        print(Counter(cov.target - 1))
        ndata = len(cov.target)
        order = np.random.RandomState(seed=42).permutation(ndata)
        ntrain = ceil(0.9 * ndata)
        Object = lambda **kwargs: type("Object", (), kwargs)()
        train = Object(data = cov.data[order[:ntrain]], target = cov.target[order[:ntrain]] - 1)
        test = Object(data = cov.data[order[ntrain:]], target = cov.target[order[ntrain:]] - 1)

        routerFactory = lambda: FOC.LogisticModel(eta0=0.1, learning_rate='invscaling')
        scorer = FOC.EuclideanDistance()
        randomState = random.Random()
        randomState.seed(45)
        cmt = CMT(routerFactory=routerFactory, scorer=scorer, alpha=0.25, c=10, d=1, randomState=randomState)

        print('{:8.8s}\t{:8.8s}\t{:10.10s}\t{:10.10s}'.format(
            'n', 'emp loss', 'since last', 'last pred')
        )

        loss = FOC.EasyAcc()
        sincelast = FOC.EasyAcc()

        for pno in range(1):
            order = np.random.RandomState(seed=42+pno).permutation(len(train.data))
            for n, ind in enumerate(order):
                x = quadratic(train.data[ind])
                actual = train.target[ind]

                if n == 0:
                    pred = 0
                else:
                    u, z = cmt.query(x, k=1, epsilon=0.0)
                    pred = z[0][1] if len(z) else 0

                loss += 0 if pred == actual else 1
                sincelast += 0 if pred == actual else 1

                # Report whenever the example index is a power of two (or zero).
                if (n & (n - 1) == 0): # and n & 0xAAAAAAAA == 0):
                    print('{:<8d}\t{:<8.3f}\t{:<10.3f}\t{:<10d}'.format(
                                loss.n, loss.mean(), sincelast.mean(), pred),
                          flush=True)

                    sincelast = FOC.EasyAcc()

                if n > 0:
                    # TODO: feedback is currently disabled:
                    #   cmt.update(u, x, z, 1 if pred == actual else -1)
                    cmt.insert(x, actual)

            print('{:<8d}\t{:<8.3f}\t{:<10.3f}\t{:<10d}'.format(
                         loss.n, loss.mean(), sincelast.mean(), pred),
                   flush=True)
            sincelast = FOC.EasyAcc()

            # Held-out evaluation: predict with the tree itself, greedily.
            def predict(key):
                _, retrieved = cmt.query(key, k=1, epsilon=0.0)
                return retrieved[0][1] if len(retrieved) else 0

            preds = np.array([ predict(quadratic(d)) for d in test.data ])
            print(Counter(preds))
            ascores = []
            for b in range(16):
                bootie = np.random.RandomState(90210+b).choice(len(test.target), replace=True, size=len(test.target))
                ascores.append(accuracy_score(y_true=test.target[bootie], y_pred=preds[bootie]))

            print("test accuracy: {}".format(np.quantile(ascores, [0.05, 0.5, 0.95])))
                        

# Run the Contextual Memory Tree experiment on Covertype; prints the progressive-loss table.
FOC.doit()

Counter({1: 283301, 0: 211840, 2: 35754, 6: 20510, 5: 17367, 4: 9493, 3: 2747})
n       	emp loss	since last	last pred 
1       	1.000   	1.000     	0         
2       	0.500   	0.000     	0         
3       	0.333   	0.000     	0         
5       	0.400   	0.500     	0         
9       	0.333   	0.250     	1         
17      	0.529   	0.750     	5         
33      	0.424   	0.312     	1         
65      	0.462   	0.500     	1         
129     	0.488   	0.516     	1         
257     	0.475   	0.461     	1         
513     	0.431   	0.387     	1         
1025    	0.395   	0.359     	1         
2049    	0.375   	0.355     	1         
4097    	0.361   	0.346     	1         


# Memorized Q pseudocode -- Attempt 3

Basic idea:
* Estimate value of $a$ in context $x$ by stored value associated with first memory retrieved from CMT queried with $(x, a)$.
* Play $\epsilon$-greedy with greedy action being the maximum estimated value. 
* Play action $a$ in context $x$ and observe reward $r$.
* Reward memory system just like a parametric direct method, i.e., using regression loss such as squared loss.
   * Update the memory $((x', a'), r')$ retrieved by query $(x, a)$ using reward $-(r - r')^2$.
* Insert key $(x, a)$ with value $r$.
* Conjecture: compatible with self-consistency assuming no reward variance.
   * Update reward is maximized by retrieving a memory with $r = r'$.
   * Exact match response does this.
   * Censorship issue: only argmax key is updated, does this matter?

# Memorized Q pseudocode -- Attempt 2

Basic idea:
* Estimate value of $a$ in context $x$ by stored value associated with first memory retrieved from CMT queried with $(x, a)$.
* Play $\epsilon$-greedy with greedy action being the maximum estimated value. 
* Play action $a$ in context $x$ and observe reward $r$.
* For each action $a'$, update the memory retrieved with query $(x, a')$ using the observed reward as feedback reward.
* Insert key $(x, a)$ with value $r$.
* Conjecture: compatible with self-consistency assuming no reward variance.
   * Update reward is maximized by identifying the correct argmax.
   * Exact match responses to all queries identify the correct argmax.
   * No censorship issues: all retrieved keys are updated.

# Memorized Q pseudocode

Basic idea:
* Estimate value of $a$ in context $x$ by stored value associated with first memory retrieved from CMT queried with $(x, a)$.
* Play $\epsilon$-greedy with greedy action being the maximum estimated value. 
* After playing action $a$ in context $x$ and observing reward $r$:
   * Update the memory retrieved with query $(x, a)$ using feedback reward of $r$.
   * Store memory with key $(x, a)$ and value $r$.
* Conjecture: not compatible with self-consistency because of update frequency issues.
   * By returning a non-exact match with high stored reward, a key can capture more updates.
   * Counter argument: the memory system as a whole receives largest possible reward if argmax is correct, which exact match ensures.
   * Counter Counter argument: but reward is associated to particular keys with different frequency, does that matter?
   * **Confused**

In [358]:
def MemorizedQ():
    """Pseudocode sketch (not runnable as written) of epsilon-greedy control backed by a CMT.

    For each observed context the memory is queried once per action, the action
    whose first retrieved memory has the largest stored reward is treated as
    greedy, an action is sampled epsilon-greedily, and the observed reward is
    inserted under key (x, a).  The feedback reward passed to Update is still
    an open question (see the comments below).
    """
    mem = CMT()
    env = Environment()  # distribution over (x, r) pairs.

    Actions = set(...)   # fixed set of actions for now
    epsilon = ...        # epsilon-greedy exploration

    while True:
        x = env.Observe()
        # action -> (u, first retrieved memory); actions that retrieve nothing are absent
        querySet = { a: (u, ((xprime, aprime), rprime))
                     for a in Actions
                     for (u, z) in [ mem.Query(key=(x, a)) ]
                     if len(z) > 0
                     for ((xprime, aprime), rprime) in [ z[0] ]
                   }
        if len(querySet) > 0:
            greedy, _ = max(querySet.items(), key=lambda kv: kv[1][1][1]) # action with largest first retrieved reward
        else:
            greedy = next(iter(Actions))                                  # if memory is completely empty, play an arbitrary action

        pa = (1 - epsilon) * IndicatorDistribution(greedy) + epsilon * UniformDistribution(Actions)
        a = pa.sample()
        r = env.ObserveReward(a)

        if a in querySet:
            # question: what's the feedback reward?
            # question: do we only do this when we take the greedy action?

            u, ((xprime, aprime), rprime) = querySet[a]
            mem.Update(key=(x, a), retrieved=((xprime, aprime), rprime), feedbackreward=None, u=u)

        mem.Insert(key=(x, a), value=r)

### Is this compatible with self-consistency?

I'm not sure.  Suppose there is no reward variance, so we are just dealing with the partial feedback issue.
* Any memory retrieved when querying on $(x, a)$ will be updated with feedback reward $r$.
* Conditional on calling `Update()`, feedback reward is constant.
* Except that some retrieved memories will "win the argmax" and some will lose, changing frequency of `Update()`.
* Consider the memory retrieved by `Query(key=(x, a))`.
   * Possible inserted $((x, a), r)$ pair will lose the argmax after additional inserts.
   * This could be appropriate as another action $a'$ might be better in a neighborhood of $x$ but hadn't been observed yet.
   * However retrieving $((x'', a''), r'')$ with $r'' > r$ would win the argmax and receive reward $r$.
   
**Idea**: this could be self-consistent if we update all the actions.