Idea: memorize the $Q$ function qua [Model-Free Episodic Control](https://arxiv.org/abs/1606.04460).

# CMT API

From the paper we have:

1. $(u, z) \leftarrow \text{Query}(x)$ where $z = \{ (x_n, \omega_n) \}$ is an ordered set of retrieved key-value pairs.
1. $\text{Update}(x, (x_n, \omega_n), r, u)$ provides feedback reward $r$ for retrieval of $(x_n, \omega_n)$ for query $x$.
   1. Must be compatible with self-consistency or supervised and unsupervised updates conflict.
1. $\text{Insert}(x, \omega)$ creates a new memory.

In [1]:
class CMT:
    class Node:
        def __init__(self, parent, left=None, right=None, g=None):
            self.parent = parent
            self.isLeaf = left is None
            self.n = 0        
            self.memories = {}
            self.left = left
            self.right = right
            self.g = g
            
        def makeInternal(self, g):       
            assert self.isLeaf
            
            self.isLeaf = False
            self.left = CMT.Node(parent=self)
            self.right = CMT.Node(parent=self)
            self.n = 0
            self.g = g
            
            mem = self.memories
            self.memories = {}
            
            return mem
        
        def replaceNode(self, replacement):
            if self is not replacement:
                self.isLeaf = replacement.isLeaf
                self.n = replacement.n
                self.memories = replacement.memories
                self.left = replacement.left
                if self.left:
                    self.left.parent = self
                self.right = replacement.right
                if self.right:
                    self.right.parent = self
                self.g = replacement.g

        def topk(self, x, k, f):
            assert self.isLeaf
            return [ z for _, z in zip(range(k), 
                                       sorted(self.memories.items(),
                                              key=lambda z: f.predict(x, z),
                                              reverse=True
                                             )
                                      ) 
                   ]
        
        def randk(self, k, randomState):
            assert self.isLeaf
            return [ z[1] for _, z in zip(range(k),
                                          sorted( (randomState.uniform(0, 1), m) for m in self.memories.items() 
                                                )
                                      )
                   ]
    
    class Path:
        def __init__(self, nodes, leaf):
            self.nodes = nodes
            self.leaf = leaf
    
    class LRU:
        def __init__(self):
            self.entries = []
            self.entry_finder = set()
            self.n = 0
        
        def add(self, x):
            from heapq import heappush
            
            assert x not in self.entry_finder
            
            entry = (self.n, x)
            self.entry_finder.add(x)
            heappush(self.entries, entry)
            self.n += 1
            
        def __len__(self):
            return len(self.entry_finder)
        
        def __contains__(self, x):
            return x in self.entry_finder
        
        def peek(self):
            from heapq import heappop
            
            while self.entries[0][1] not in self.entry_finder:
                heappop(self.entries)
                
            return self.entries[0][1]
        
        def remove(self, x):
            self.entry_finder.remove(x)
    
    def __init__(self, routerFactory, scorer, alpha, c, d, randomState, maxMemories=None):
        self.routerFactory = routerFactory
        self.f = scorer
        self.alpha = alpha
        self.c = c
        self.d = d
        self.leafbykey = {}
        self.root = CMT.Node(None)
        self.randomState = randomState        
        self.allkeys = []
        self.allkeysindex = {}
        self.maxMemories = maxMemories 
        self.keyslru = CMT.LRU()
        self.rerouting = False
        self.splitting = False

    def nodeForeach(self, f, node=None):
        if node is None:
            node = self.root
            
        f(node)
        if node.left:
            self.nodeForeach(f, node.left)
        if node.right:
            self.nodeForeach(f, node.right)
        
    def __path(self, x, v):          
        nodes = []
        while not v.isLeaf:
            a = v.right if v.g.predict(x) > 0 else v.left
            nodes.append(v)
            v = a
            
        return CMT.Path(nodes, v)
        
    def query(self, x, k, epsilon):
        path = self.__path(x, self.root)
        q = self.randomState.uniform(0, 1)
        if q >= epsilon:
            return (None, path.leaf.topk(x, k, self.f))
        else:
            i = self.randomState.randint(0, len(path.nodes))
            if i < len(path.nodes):
                a = self.randomState.choice( (path.nodes[i].left, path.nodes[i].right) )
                l = self.__path(x, a).leaf
                return ((path.nodes[i], a, 1/2), l.topk(x, k, self.f))
            else:
                return ((path.leaf, None, None), path.leaf.randk(k, self.randomState))
            
    def update(self, u, x, z, r):
        if u is None:
            pass
        else:
            (v, a, p) = u
            if v.isLeaf:
                self.f.update(x, z, r)
            else:
                from math import log

                rhat = (r/p) * (1 if a == v.right else -1)
                y = (1 - self.alpha) * rhat + self.alpha * (log(1e-2 + v.left.n) - log(1e-2 + v.right.n)) 
                signy = 1 if y > 0 else -1
                absy = signy * y
                v.g.update(x, signy, absy)
                
            for _ in range(self.d):
                self.__reroute()
                
    def delete(self, x):
        if x not in self.allkeysindex:
            # deleting something not in the memory ...
            assert False
                    
        ind = self.allkeysindex.pop(x)
        lastx = self.allkeys.pop()
        if ind < len(self.allkeys):
            self.allkeys[ind] = lastx
            self.allkeysindex[lastx] = ind
                
        if not self.rerouting:
            self.keyslru.remove(x)
                
        v = self.leafbykey.pop(x)
        
        while v is not None:
            v.n -= 1
            if v.isLeaf:
                omega = v.memories.pop(x)
            else:
                if v.n == 0:
                    other = v.parent.left if v is v.parent.right else v.parent.right
                    if other.isLeaf:
                        for xprime in other.memories.keys():
                            self.leafbykey[xprime] = v.parent

                    v.parent.replaceNode(other)
                    v = v.parent
                    
            assert v.n >= 0
            v = v.parent
            
    def __insertLeaf(self, x, omega, v):
        from math import log
        
        assert v.isLeaf

        if x not in self.allkeysindex:          
            self.allkeysindex[x] = len(self.allkeys)
            self.allkeys.append(x)
        
        if not self.rerouting and not self.splitting:
            self.keyslru.add(x)
                        
        if self.splitting or v.n < self.c:
            assert x not in self.leafbykey
            self.leafbykey[x] = v
            assert x not in v.memories
            v.memories[x] = omega
            v.n += 1
            assert v.n == len(v.memories)
        else:
            self.splitting = True
            mem = v.makeInternal(g=self.routerFactory())
            
            while mem:
                xprime, omegaprime = mem.popitem()
                del self.leafbykey[xprime]
                self.insert(xprime, omegaprime, v)
                
            self.insert(x, omega, v)
            self.splitting = False
            
        if not self.rerouting and not self.splitting:
            daleaf = self.leafbykey[x]
            dabest = daleaf.topk(x, 2, self.f)
            if len(dabest) > 1:
                other = dabest[1] if dabest[0][0] == x else dabest[0] 
                z = [(x, omega), other]
                self.f.update(x, z, 1)
                     
    def insert(self, x, omega, v=None):
        from math import log
        
        if x in self.leafbykey:
            # duplicate memory ... need to merge values ...
            assert False
            
        if v is None:
            v = self.root
        
        while not v.isLeaf:
            B = log(1e-2 + v.left.n) - log(1e-2 + v.right.n)
            y = (1 - self.alpha) * v.g.predict(x) + self.alpha * B
            signy = 1 if y > 0 else -1
            v.g.update(x, signy, 1)
            v.n += 1
            v = v.right if v.g.predict(x) > 0 else v.left
            
        self.__insertLeaf(x, omega, v)
        
        if not self.rerouting and not self.splitting:
            if self.maxMemories is not None and len(self.keyslru) > self.maxMemories:
                oldest = self.keyslru.peek()
                self.delete(oldest)

            for _ in range(self.d):
                self.__reroute()
                            
    def __reroute(self):
        x = self.randomState.choice(self.allkeys)
        omega = self.leafbykey[x].memories[x]
        self.rerouting = True
        self.delete(x)
        self.insert(x, omega)
        self.rerouting = False
        
        for k in self.leafbykey.keys():
            assert k in self.leafbykey[k].memories

In [2]:
class CMTTests:
    class LinearModel:
        def __init__(self, *args, **kwargs):
            from sklearn import linear_model
            
            self.model = linear_model.SGDRegressor(*args, **kwargs)
            
        def predict(self, x):
            from sklearn.exceptions import NotFittedError 
            try:
                return self.model.predict(X=[x])[0]
            except NotFittedError:
                return 0
        
        def update(self, x, y, w):
            self.model.partial_fit(X=[x], y=[y], sample_weight=[w])
            
    class NormalizedLinearProduct:
        def __init__(self):
            pass
        
        def predict(self, x, z):
            import numpy as np
            from math import sqrt
            
            (xprime, omegaprime) = z
            
            xa = np.array(x)
            xprimea = np.array(xprime)
                        
            return np.inner(xa, xprimea) / sqrt(np.inner(xa, xa) * np.inner(xprimea, xprimea))
        
        def update(self, x, y, w):
            pass
 
    @staticmethod
    def displaynode(node, indent):
        if node is not None:
            from pprint import pformat
            print(indent, pformat((node, node.__dict__)))
            CMTTests.displaynode(node.left, indent + "*")
            CMTTests.displaynode(node.right, indent + "*")

    @staticmethod
    def displaytree(cmt):
        CMTTests.displaynode(cmt.root, indent="")

    @staticmethod
    def structureValid():
        import random
        
        routerFactory = lambda: CMTTests.LinearModel()
        scorer = CMTTests.NormalizedLinearProduct()
        randomState = random.Random()
        randomState.seed(2112)
        cmt = CMT(routerFactory=routerFactory, scorer=scorer, alpha=0.5, c=10, d=0, randomState=randomState)

        def checkNodeInvariants(node):
            assert node.parent is None or node.parent.left is node or node.parent.right is node
            assert node.left is None or node.n == node.left.n + node.right.n
            assert node.left is None or node.left.parent is node
            assert node.right is None or node.right.parent is node
            assert node.left is not None or node.n == len(node.memories)
    
        stuff = {}
        
        for _ in range(200):
            try:
                if stuff and randomState.uniform(0, 1) < 0.1:
                    # delete
                    x, omega = stuff.popitem()
                    cmt.delete(x)
                elif stuff and randomState.uniform(0, 1) < 0.1:
                    # query/update
                    somex = randomState.choice(list(stuff.keys()))
                    u, z = cmt.query(x, 1, 0.1)
                    cmt.update(u, x, z, randomState.uniform(0, 1))
                else:
                    # insert
                    x = tuple([ randomState.uniform(0, 1) for _ in range(3)])
                    omega = randomState.uniform(0, 1)
                    cmt.insert(x, omega)
                    stuff[x] = omega

                assert cmt.root.n == len(stuff)
                assert cmt.root.n == len(cmt.leafbykey)
                assert cmt.root.n == len(cmt.allkeys)
                assert cmt.root.n == len(cmt.allkeysindex)
                for x in stuff.keys():
                    assert x in cmt.leafbykey[x].memories
                    assert x in cmt.allkeysindex
                    assert cmt.allkeys[cmt.allkeysindex[x]] is x
                cmt.nodeForeach(checkNodeInvariants)
            except:
                print("--------------")
                CMTTests.displaytree(cmt)
                print("--------------")
                raise
                
        print('structureValid test pass')           
                       
    @staticmethod
    def selfconsistent():
        import random
        
        routerFactory = lambda: CMTTests.LinearModel()
        scorer = CMTTests.NormalizedLinearProduct()
        randomState = random.Random()
        randomState.seed(45)
        cmt = CMT(routerFactory=routerFactory, scorer=scorer, alpha=0.5, c=10, d=0, randomState=randomState)
        
        for _ in range(200):
            try:
                x = tuple([ randomState.uniform(0, 1) for _ in range(3)])
                omega = randomState.uniform(0, 1)

                cmt.insert(x, omega)
                u, [ (xprime, omegaprime) ] = cmt.query(x, k=1, epsilon=0)
                assert omega == omegaprime, '({}, [({}, {})]) = cmt.query({}) != {}'.format(u, xprime, omegaprime, x, omega)
            except:
                print("--------------")
                CMTTests.displaytree(cmt)
                print("--------------")
                raise
                
        print('selfconsistent test pass')
        
    @staticmethod
    def maxmemories():
        import random
        
        routerFactory = lambda: CMTTests.LinearModel()
        scorer = CMTTests.NormalizedLinearProduct()
        randomState = random.Random()
        randomState.seed(45)
        maxM = 100
        cmt = CMT(routerFactory=routerFactory, scorer=scorer, alpha=0.5, c=10, d=0, randomState=randomState, maxMemories=maxM)
        
        for _ in range(200):
            try:
                x = tuple([ randomState.uniform(0, 1) for _ in range(3)])
                omega = randomState.uniform(0, 1)

                cmt.insert(x, omega)
                assert len(cmt.leafbykey) <= maxM
            except:
                print("--------------")
                CMTTests.displaytree(cmt)
                print("--------------")
                raise
                
        print('maxmemories test pass')
       
    @staticmethod
    def all():
        CMTTests.structureValid()
        CMTTests.selfconsistent()
        CMTTests.maxmemories()

CMTTests().all()

structureValid test pass
selfconsistent test pass
maxmemories test pass


# Value difference from deletion

## Fully Supervised Case

Suppose queries are drawn IID from $D$, every query is followed by an update, and we have a fixed CMT we are operating in greedy mode.   Then the expected reward for the CMT is 
$$
V(\text{CMT}) = \mathbb{E}_{\substack{x \sim D \\ z \sim \text{CMT}(x) \\ r \sim \text{Update}(x, z)}}\left[ 1^\top r \right]
$$
Note the reward on $z$ is the sum over the rewards on each returned item (NB: this structure is forced upon us by the update call).
Suppose now we delete meme $\alpha$ from the system, denote the resulting tree $\text{CMT}_{\setminus \alpha}$.  The reward is conditionally independent of the CMT given $z$, therefore $$
\begin{aligned}
V\left( \text{CMT}_{\setminus \alpha} \right) &= \mathbb{E}_{\substack{x \sim D \\ z_{\setminus \alpha} \sim \text{CMT}_{\setminus \alpha}(x) \\ r_{\setminus \alpha} \sim \text{Update}(x, z_{\setminus \alpha})}}\left[ 1^\top r_{\setminus \alpha} \right] \\
\Delta V(\alpha) \doteq V(\text{CMT}) - V\left( \text{CMT}_{\setminus \alpha} \right) &= \mathbb{E}_{\substack{x \sim D \\ z \sim \text{CMT}(x) \\ z_{\setminus \alpha} \sim \text{CMT}_{\setminus \alpha}(x) \\ r \sim \text{Update}(x, z) \\ r_{\setminus \alpha} \sim \text{Update}(x, z_{\setminus \alpha})}}\left[ 1^\top (r - r_{\setminus \alpha}) 1_{z \neq z_{\setminus \alpha}} \right] \\
\end{aligned}
$$ For a particular query $x$, if $\alpha \not \in z$ than $z = z_{\setminus \alpha}$.  So to estimate the value difference from deletion we only need to consider queries whose retrieval contains $\alpha$.  Therefore, delta in reward when $\alpha \in z$ is ($\mathbb{E}\left[r | x, \alpha \right]$ - $\mathbb{E}\left[r | x, \text{k+1}^\text{th}\right]$) where $\mathbb{E}\left[r | x, \text{k+1}^\text{th}\right]$ is the estimated value of the "first meme not returned in greedy mode".  How to estimate (a lower bound on) this?

* **Idea**: Approximate $\mathbb{E}\left[r | x, \alpha\right] \approx \mathbb{E}\left[r | \text{leaf}(x), \alpha \right]$.
   * For each meme, maintain a scalar (CI) reward conditional on retrieval.
   * A valuable memory: 
       1. Is frequently retrieved, and
       1. Has a high reward conditioned upon being retrieved, and
       1. The $\text{k+1}^\text{th}$ meme in the same leaf has low reward conditioned on being retrieved.
       1. $\Delta V(\alpha) = \mathrm{Pr}(\text{$\alpha$ retrieved}) \left( \mathbb{E}\left[r | \text{leaf}(x), \alpha \right] - \mathbb{E}\left[r | \text{leaf}(x), \text{k+1}^\text{th}\right]\right)$.

## Partially Supervised or Unsupervised Case

Denote an &ldquo;no-update&rdquo; reward as $\emptyset$ so that rewards are always specified. Suppose updates are missing at random such that $\mathbb{E}\left[r 1_{r \neq \emptyset} | x, \alpha\right] = \mathbb{E}\left[r | x, \alpha, r \neq \emptyset\right] \mathbb{E}\left[1_{r \neq \emptyset}\right]$.  Then it makes sense to condition on $r \neq \emptyset$ and apply the fully supervised strategy.

## Fully Unsupervised Case

In this case we only know two things:
1. Exact match retrieval is presumed to have maximum $r = 1$. 
1. Each insert comes with an implicit exact match update.

Therefore, if contexts never repeat all memories have the same estimated value (everybody has exactly 1 update with $r = 1$) and all memes have equal value, suggesting at-random deletion is a good strategy.

If contexts repeat than the CI for a frequently exactly matched meme  will be tighter below $r = 1$ than for an infrequently exactly matched meme, in which case the least frequently exactly matched meme is best to delete.

# Fully Observed Covertype

## Linear Classifier

In [3]:
class FOC:
    class EasyAcc:
        def __init__(self):
            self.n = 0
            self.sum = 0
            
        def __iadd__(self, other):
            self.n += 1
            self.sum += other
            return self
            
        def mean(self):
            return self.sum / max(self.n, 1)

    import torch
    class LogisticRegressor(torch.nn.Module):        
        def __init__(self, input_dim, output_dim, eta0=0.1):
            import torch
            
            super(FOC.LogisticRegressor, self).__init__()
            self.linear = torch.nn.Linear(input_dim, output_dim)
            self.loss = torch.nn.CrossEntropyLoss()
            self.optimizer = torch.optim.Adam(self.linear.parameters(), lr=eta0)
            self.eta0 = eta0
            self.n = 0
            
        def forward(self, x):
            import numpy as np
            import torch

            return self.linear(torch.autograd.Variable(torch.from_numpy(x)))
        
        def predict(self, X):
            import torch
            
            return torch.argmax(self.forward(X), dim=1).numpy()
        
        def set_lr(self):
            from math import sqrt
            lr = self.eta0 / sqrt(self.n)
            for g in self.optimizer.param_groups:
                g['lr'] = lr
            
        def partial_fit(self, X, y, sample_weight=None, **kwargs):
            import torch
            
            self.optimizer.zero_grad()
            yhat = self.forward(X)
            if sample_weight is None:
                loss = self.loss(yhat, torch.from_numpy(y))
            else:
                loss = sample_weight * self.loss(yhat, torch.from_numpy(y))
            loss.backward()
            self.n += X.shape[0]
            self.set_lr()
            self.optimizer.step() 
        
    def doit():
        from collections import Counter
        from sklearn.datasets import fetch_covtype
        from sklearn.decomposition import PCA
        from sklearn.linear_model import SGDClassifier
        from sklearn.metrics import accuracy_score
        from math import ceil
        import numpy as np

        cov = fetch_covtype()
        cov.data = PCA(whiten=True).fit_transform(cov.data).astype(np.float32)
        classes = np.unique(cov.target - 1)
        print(Counter(cov.target - 1))
        ndata = len(cov.target)
        order = np.random.RandomState(seed=42).permutation(ndata)
        ntrain = ceil(0.9 * ndata)
        Object = lambda **kwargs: type("Object", (), kwargs)()
        train = Object(data = cov.data[order[:ntrain]], target = cov.target[order[:ntrain]] - 1)
        test = Object(data = cov.data[order[ntrain:]], target = cov.target[order[ntrain:]] - 1)

        blocksize = 32
        for lr in [0.1]:
            print("*** lr = {} ***".format(lr), flush=True)
            print('{:8.8s}\t{:8.8s}\t{:10.10s}\t{:10.10s}'.format(
                'n', 'emp loss', 'since last', 'pred')
            )
            
            input_dim = train.data[0].shape[0]
            cls = FOC.LogisticRegressor(input_dim, output_dim=len(classes), eta0=lr)
#             cls = SGDClassifier(loss='log', shuffle=False, learning_rate='invscaling', eta0=lr)
            loss = FOC.EasyAcc()
            sincelast = FOC.EasyAcc()

            for pno in range(1):
                order = np.random.RandomState(seed=42+pno).permutation(len(train.data))
                for n, ind in enumerate(zip(*(iter(order),)*blocksize)):
                    v = np.array([ t
                                   for z in ind for t in ( train.data[z], ) ],
                                 dtype='float32')
                    actual = np.array([ train.target[z] for z in ind ], dtype='int')
                    if n == 0:
                        for a in actual:
                            loss += 0 if a == 0 else 1
                            sincelast += 0 if a == 0 else 1
                    if n > 0:
                        pred = cls.predict(v)
                        for p, a in zip(pred, actual):
                            loss += 0 if p == a else 1
                            sincelast += 0 if p == a else 1  

                        if (n & (n - 1) == 0): # and n & 0xAAAAAAAA == 0):
                            print('{:<8d}\t{:<8.3f}\t{:<10.3f}\t{:<10d}'.format(
                                        loss.n, loss.mean(), sincelast.mean(), pred[-1]),
                                  flush=True)

                            sincelast = FOC.EasyAcc()

                    cls.partial_fit(v, actual, classes=classes)

                print('{:<8d}\t{:<8.3f}\t{:<10.3f}\t{:<10d}'.format(
                             loss.n, loss.mean(), sincelast.mean(), pred[-1]),
                       flush=True)                
                sincelast = FOC.EasyAcc()

                preds = cls.predict(test.data.astype('float32'))
                print(Counter(preds))
                ascores = []
                for b in range(16):
                    bootie = np.random.RandomState(90210+b).choice(len(test.target), replace=True, size=len(test.target))
                    ascores.append(accuracy_score(y_true=test.target[bootie], y_pred=preds[bootie]))

                print("test accuracy: {}".format(np.quantile(ascores, [0.05, 0.5, 0.95])))

def flass():
    import timeit
    print(timeit.timeit(FOC.doit, number=1))
    
flass()

Counter({1: 283301, 0: 211840, 2: 35754, 6: 20510, 5: 17367, 4: 9493, 3: 2747})
*** lr = 0.1 ***
n       	emp loss	since last	pred      
64      	0.641   	0.641     	3         
96      	0.625   	0.594     	4         
160     	0.656   	0.703     	6         
288     	0.615   	0.562     	1         
544     	0.570   	0.520     	1         
1056    	0.518   	0.463     	1         
2080    	0.455   	0.391     	6         
4128    	0.405   	0.354     	2         
8224    	0.365   	0.324     	0         
16416   	0.332   	0.300     	0         
32800   	0.312   	0.291     	0         
65568   	0.301   	0.290     	1         
131104  	0.293   	0.286     	0         
262176  	0.289   	0.285     	1         
522880  	0.287   	0.284     	0         
Counter({1: 30335, 0: 20781, 2: 4847, 6: 1503, 5: 422, 3: 202, 4: 11})
test accuracy: [0.71633449 0.71908401 0.7219368 ]
19.190191100002266


## Contextual Memory Tree (Linear Routers)

In [5]:
class FOC:
    class EasyAcc:
        def __init__(self):
            self.n = 0
            self.sum = 0
            
        def __iadd__(self, other):
            self.n += 1
            self.sum += other
            return self
            
        def mean(self):
            return self.sum / max(self.n, 1)

    import torch
    class LogisticRegressor(torch.nn.Module):        
        def __init__(self, input_dim, output_dim, eta0):
            import torch
            
            super(FOC.LogisticRegressor, self).__init__()
            self.linear = torch.nn.Linear(input_dim, output_dim)
            self.loss = torch.nn.CrossEntropyLoss()
            self.optimizer = torch.optim.Adam(self.linear.parameters(), lr=eta0)
            self.eta0 = eta0
            self.n = 0
            
        def forward(self, X):
            import numpy as np
            import torch

            return self.linear(torch.autograd.Variable(torch.from_numpy(X)))
        
        def predict(self, X):
            import torch
            
            return torch.argmax(self.forward(X), dim=1).numpy()
        
        def set_lr(self):
            from math import sqrt
            lr = self.eta0 / sqrt(self.n)
            for g in self.optimizer.param_groups:
                g['lr'] = lr

        def partial_fit(self, X, y, sample_weight=None, **kwargs):
            import torch
            
            self.optimizer.zero_grad()
            yhat = self.forward(X)
            if sample_weight is None:
                loss = self.loss(yhat, torch.from_numpy(y))
            else:
                loss = torch.from_numpy(sample_weight) * self.loss(yhat, torch.from_numpy(y))
            loss.backward()
            self.n += X.shape[0]
            self.set_lr()
            self.optimizer.step() 

    class LogisticModel:
        def __init__(self, *args, **kwargs):
            kwargs['output_dim'] = 2
            self.model = FOC.LogisticRegressor(*args, **kwargs)
            
        def predict(self, x):
            import numpy as np
            
            F = self.model.forward(X=np.array([x], dtype='float32')).detach().numpy()
            dF = F[:,1] - F[:,0]
            return -1 + 2 * dF          
        
        def update(self, x, y, w):
            import numpy as np
            
            assert y == 1 or y == -1
            
            self.model.partial_fit(X=np.array([x], dtype='float32'), 
                                   y=(1 + np.array([y], dtype='int')) // 2, 
                                   sample_weight=np.array([w], dtype='float32'),
                                   classes=(0, 1))

    class LearnedEuclideanDistance:
        def __init__(self, *args, **kwargs):
            kwargs['output_dim'] = 2
            self.model = FOC.LogisticRegressor(*args, **kwargs)
            self.model.linear.weight.data[0,:].fill_(0.01 / kwargs['input_dim'])
            self.model.linear.weight.data[1,:].fill_(-0.01 / kwargs['input_dim'])
            self.model.linear.bias.data.fill_(0.0)
            self.model.linear.bias.requires_grad = False
        
        def predict(self, x, z):
            import numpy as np
            
            (xprime, omegaprime) = z
            
            dx = np.array([x], dtype='float32')
            dx -= [xprime]
            dx *= dx
            
            F = self.model.forward(dx).detach().numpy()
            dist = F[0,1] - F[0,0]
            return dist
        
        def update(self, x, z, r):
            import numpy as np
            
            if r == 1 and len(z) > 1 and z[0][1] != z[1][1]:
                dx = np.array([ z[0][0], z[1][0] ], dtype='float32')
                dx -= [x]
                dx *= dx
                y = np.array([1, 0], dtype='int')    
                self.model.partial_fit(X=dx,
                                       y=y,
                                       sample_weight=None, # (?)
                                       classes=(0, 1))
            
    def doit():
        from collections import Counter
        from sklearn.datasets import fetch_covtype
        from sklearn.decomposition import PCA
        from sklearn.linear_model import SGDClassifier
        from sklearn.metrics import accuracy_score
        from math import ceil
        import numpy as np
        import random
        import torch

        cov = fetch_covtype()
        cov.data = PCA(whiten=True).fit_transform(cov.data)
        classes = np.unique(cov.target - 1)
        print(Counter(cov.target - 1))
        ndata = len(cov.target)
        order = np.random.RandomState(seed=42).permutation(ndata)
        ntrain = ceil(0.9 * ndata)
        Object = lambda **kwargs: type("Object", (), kwargs)()
        train = Object(data = cov.data[order[:ntrain]], target = cov.target[order[:ntrain]] - 1)
        test = Object(data = cov.data[order[ntrain:]], target = cov.target[order[ntrain:]] - 1)
        
        input_dim = train.data[0].shape[0]
        routerFactory = lambda: FOC.LogisticModel(eta0=0.1, input_dim=input_dim)
        scorer = FOC.LearnedEuclideanDistance(eta0=1e-4, input_dim=input_dim)
        randomState = random.Random()
        randomState.seed(45)
        torch.manual_seed(2112)
        cmt = CMT(routerFactory=routerFactory, scorer=scorer, alpha=0.25, c=10, d=1, randomState=randomState, 
                  maxMemories=1000)

        print('{:8.8s}\t{:8.8s}\t{:10.10s}\t{:10.10s}'.format(
            'n', 'emp loss', 'since last', 'last pred')
        )

        loss = FOC.EasyAcc()
        sincelast = FOC.EasyAcc()

        for pno in range(1):
            order = np.random.RandomState(seed=42+pno).permutation(len(train.data))
            for n, ind in enumerate(order):
                t = train.data[ind]
                x = tuple(t)
                actual = train.target[ind]
                
                if n == 0:
                    pred = 0
                else:
                    u, z = cmt.query(x, k=1, epsilon=0.0)
                    pred = z[0][1] if len(z) else 0
                    
                loss += 0 if pred == actual else 1
                sincelast += 0 if pred == actual else 1
                
                if (n & (n - 1) == 0): # and n & 0xAAAAAAAA == 0):
                    print('{:<8d}\t{:<8.3f}\t{:<10.3f}\t{:<10d}'.format(
                                loss.n, loss.mean(), sincelast.mean(), pred),
                          flush=True)

                    sincelast = FOC.EasyAcc()
                    
                if n > 0:
                    u, z = cmt.query(x, k=2, epsilon=1.0)
                    if len(z):
                        r = 1 if z[0][1] == actual else -1
                        cmt.update(u, x, z, r)

                cmt.insert(x, actual)

            print('{:<8d}\t{:<8.3f}\t{:<10.3f}\t{:<10d}'.format(
                         loss.n, loss.mean(), sincelast.mean(), pred),
                   flush=True)                
            sincelast = FOC.EasyAcc()
            
#             preds = cls.predict(np.array([np.outer(d, np.append(d, [1])).ravel() for d in test.data]))
#             print(Counter(preds))
#             ascores = []
#             for b in range(16):
#                 bootie = np.random.RandomState(90210+b).choice(len(test.target), replace=True, size=len(test.target))
#                 ascores.append(accuracy_score(y_true=test.target[bootie], y_pred=preds[bootie]))

#             print("test accuracy: {}".format(np.quantile(ascores, [0.05, 0.5, 0.95])))

            pass
                        
def flass():
    import timeit
    print(timeit.timeit(FOC.doit, number=1))
    
flass()

Counter({1: 283301, 0: 211840, 2: 35754, 6: 20510, 5: 17367, 4: 9493, 3: 2747})
n       	emp loss	since last	last pred 
1       	1.000   	1.000     	0         
2       	1.000   	1.000     	6         
3       	1.000   	1.000     	6         
5       	0.800   	0.500     	0         
9       	0.556   	0.250     	1         
17      	0.588   	0.625     	1         
33      	0.515   	0.438     	6         
65      	0.523   	0.531     	1         
129     	0.426   	0.328     	1         
257     	0.440   	0.453     	1         
513     	0.427   	0.414     	1         
1025    	0.390   	0.354     	1         
2049    	0.370   	0.351     	0         
4097    	0.373   	0.376     	1         
8193    	0.363   	0.352     	2         
16385   	0.359   	0.355     	6         
32769   	0.362   	0.365     	0         


KeyboardInterrupt: 