# online nets

Deep learning is powerful but computationally expensive, frequently requiring massive compute budgets. In pursuit of cost-effective yet powerful AI, this work explores and evaluates a heuristic intended to make more efficient use of data through online learning.

Goal: evaluate a deep learning alternative capable of true online learning. Solution requirements:

1. catastrophic forgetting should be impossible;
2. all data is integrated into sufficient statistics of fixed dimension;
3. and our solution should have predictive power comparable to deep learning.

## alternative model

Least squares regression (LSR) meets solution requirements 1 and 2. To achieve requirement 3, we'll structure our model similarly to deep learning by having LSR models depend on other LSR models, effectively producing a Gaussian Bayes net with nonlinear activations.
Unfortunately, Bayes nets' latent variables result in computationally intractable integrals, so we'll use a heuristic to "observe" all variables.

Our LSR model can be written as $X_n\beta_n = Y_n$, where
- $X_n \in \mathbb{R}^{n \times p}$ is a matrix of regressor columns and $x_i \in \mathbb{R}^{1 \times p}$ observation rows.
- $Y_n \in \mathbb{R}^{n \times q}$ is a matrix of dependent variable columns and $y_i \in \mathbb{R}^{1 \times q}$ observation rows.
- $\beta_n \in \mathbb{R}^{p \times q}$ is our matrix of regression weights.
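
To illustrate requirement 2, here is a minimal NumPy sketch (the notebook's own code uses torch) showing that folding rows one at a time into the fixed-size statistics $\sum_i x_i^T x_i$ and $\sum_i x_i^T y_i$ gives the same least squares fit as a batch solve, and that row order does not matter. The function name `accumulate`, the dimensions, and the random data are illustrative assumptions, and regularization is omitted here for simplicity.

```python
import numpy as np

def accumulate(X, Y):
    """Fold rows of X and Y into fixed-size statistics, one observation at a time."""
    p, q = X.shape[1], Y.shape[1]
    xtx = np.zeros((p, p))  # running sum of x_i^T x_i
    xty = np.zeros((p, q))  # running sum of x_i^T y_i
    for x_i, y_i in zip(X, Y):
        xtx += np.outer(x_i, x_i)
        xty += np.outer(x_i, y_i)
    return xtx, xty

rng = np.random.default_rng(0)
n, p, q = 50, 3, 2
X = rng.normal(size=(n, p))
Y = rng.normal(size=(n, q))

xtx, xty = accumulate(X, Y)
beta_online = np.linalg.solve(xtx, xty)             # uses the statistics only
beta_batch = np.linalg.lstsq(X, Y, rcond=None)[0]   # uses the full data

# Reversing the row order leaves the statistics, and hence the fit, unchanged
xtx_rev, xty_rev = accumulate(X[::-1], Y[::-1])
beta_reversed = np.linalg.solve(xtx_rev, xty_rev)
```

Because the statistics are plain sums, no observation can be "forgotten" by later updates, which is the sense in which requirement 1 holds for LSR.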

First, we'll rephrase our problems as time series problems, so that we may assume our input and output dimensions are equal, i.e. $p = q$. While this category of problems clearly covers reinforcement learning, it can also cover more general problems.
For example, given the right network topology and interpreting $X_{n+k}$ as the $n^{th}$ sample's output, we cover feed-forward classification problems as well. Our network will be used as a recurrent net, so we'll use this additional constraint:
- $X_{n+1} := \sigma(Y_n)$ where $\sigma: \mathbb{R}^p \to \mathbb{R}^p$ is an invertible activation function.
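
No particular $\sigma$ is fixed here. As one hypothetical choice, an elementwise leaky ReLU with slope $0 < a < 1$ is invertible on $\mathbb{R}^p$. The sketch below checks the round trip; the names `sigma` and `sigma_inv` and the slope `a = 0.1` are assumptions made for illustration.

```python
import numpy as np

def sigma(y, a=0.1):
    """Elementwise leaky ReLU: identity for y >= 0, slope `a` for y < 0."""
    return np.where(y >= 0, y, a * y)

def sigma_inv(x, a=0.1):
    """Inverse of the leaky ReLU above (well defined because a > 0)."""
    return np.where(x >= 0, x, x / a)

y_n = np.array([-2.0, -0.5, 0.0, 1.5])
x_next = sigma(y_n)              # X_{n+1} := sigma(Y_n)
y_recovered = sigma_inv(x_next)  # sigma^{-1}(X_{n+1}) recovers Y_n
```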

To accommodate both observed and latent variables, let
- $O_n \in \mathbb{R}^{p - l}$ be our observed variables
- $L_n \in \mathbb{R}^l$ be our latent variables
- $X_n = [O_n \; | \; L_n]$ be a column-concatenation of both observed and latent variables.

Given $(O_n, L_n, O_{n+1})$ observed but $L_{n+1}$ not observed, we now construct a heuristic to guess $L_{n+1}$. Rather simply, we'll run the LSR backwards: $\hat{X}_n := Y_n \beta_n^{-1}$. We'll then use the $\hat{L}_n$ portion of $\hat{X}_n$ as $L_{n+1}$.
Then, we'll fit our model to $Y_n =  \left[ \sigma^{-1}(O_{n+1}) \; \big| \; \hat{L}_n \right]$.
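
A minimal sketch of this construction, assuming an identity activation and a hypothetical, already-invertible weight matrix `beta` (the function name `build_target` is also an assumption): predict forward, overwrite the observed block with the true $O_{n+1}$, back-project through $\beta^{-1}$, and keep the latent block of the result as the latent part of the target.

```python
import numpy as np

def build_target(x_n, o_next, beta, l):
    """Build Y_n = [o_{n+1} | l_hat_n] with identity activation (an assumption)."""
    p = x_n.shape[0]
    y_pred = x_n @ beta                   # forward pass: predicted [o_{n+1} | l_{n+1}]
    y_obs = y_pred.copy()
    y_obs[:p - l] = o_next                # replace the predicted observed block with the truth
    x_hat = y_obs @ np.linalg.inv(beta)   # run the regression backwards
    l_hat = x_hat[p - l:]                 # latent block of the back-projection
    return np.concatenate([o_next, l_hat])

rng = np.random.default_rng(1)
p, l = 5, 2
beta = np.eye(p) + 0.1 * rng.normal(size=(p, p))  # hypothetical invertible weights
x_n = rng.normal(size=p)

# If the observation equals the prediction, back-projection returns the
# original latents unchanged.
o_perfect = (x_n @ beta)[:p - l]
y_same = build_target(x_n, o_perfect, beta, l)

# An observation that differs from the prediction shifts the latent targets.
y_shifted = build_target(x_n, o_perfect + 1.0, beta, l)
```

The second case shows the intended effect: a prediction error on the observed block changes what the latent values "should have been".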

## motivating the heuristic

The advantage of this heuristic over backpropagation is that it supports mathematically guaranteed online learning. However, we must motivate using $\hat{L}_n$ as $L_{n+1}$ in $Y_n$.
The essential idea is that we produce an approximate backpropagation equivalent by allowing $O_{n+1}$ to inform what $L_n$ should have been to produce $O_{n+1}$.
Of course, having $L_{n+1}$ updated requires another iteration to inform $O_{n+2}$, so solutions may require several iterations to converge.

This heuristic is indeed approximate and lacks the mathematical rigour guaranteeing deep learning's success, but it is worth trying because deep learning's computational cost is approaching infeasibility.

## numerical considerations

We'll use the Sherman-Morrison formula (SMF) to derive $\hat{\beta}_n^{-1}$ in a computationally tractable way.

$$(A + uv^T)^{-1} = A^{-1} - \frac{A^{-1} uv^T A^{-1}}{1 + v^TA^{-1}u}$$
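
The identity can be checked numerically. The sketch below uses NumPy with arbitrary random vectors; the helper name `sherman_morrison` and the tolerance `1e-12` are assumptions. It also makes explicit that the formula only applies when the denominator $1 + v^T A^{-1} u$ is nonzero.

```python
import numpy as np

def sherman_morrison(A_inv, u, v):
    """Return (A + u v^T)^{-1} given A^{-1}; u and v are 1-D vectors."""
    denom = 1.0 + v @ A_inv @ u
    if abs(denom) < 1e-12:
        raise ZeroDivisionError("A + u v^T is singular (denominator is ~0)")
    # A^{-1} u v^T A^{-1} is the outer product of (A^{-1} u) and (v^T A^{-1})
    return A_inv - np.outer(A_inv @ u, v @ A_inv) / denom

rng = np.random.default_rng(2)
p = 4
A = 2.0 * np.eye(p)
u = rng.normal(size=p)
v = rng.normal(size=p)

updated_inv = sherman_morrison(np.linalg.inv(A), u, v)  # rank-one update, O(p^2)
direct_inv = np.linalg.inv(A + np.outer(u, v))          # full inversion, O(p^3)
```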

As above, we frame our problems as recurrent nets over time series, so input and output dimensions are equal: $p = q$.

With regularization term $\lambda > 0$, the L2-regularized estimate of $\beta$ is $\left(\sum_{i=1}^n x_i^T x_i + \lambda I \right)^{-1}\sum_{i=1}^n x_i^T y_i$. However, the heuristic also requires $\hat{\beta}^{-1}$, so we regularize the second factor as well.
Take our $\beta$ estimate to be $\hat{\beta}_n = \left(\sum_{i=1}^n x_i^T x_i + \lambda I \right)^{-1}\left(\sum_{i=1}^n x_i^T y_i + \lambda I \right)$. We'll derive our SMF-inverse updates with these definitions:
- $A_n := \sum_{i=1}^n x_i^T x_i + \lambda I_{p \times p}$
- $A_0 := \lambda I_{p \times p}$
- $B_n := \sum_{i=1}^n x_i^T y_i + \lambda I_{p \times p}$
- $B_0 := \lambda I_{p \times p}$

With these definitions, we have that $\hat{\beta}_n = A_n^{-1} B_n$ and also that $\hat{\beta}_{n+1} = \left(A_n + x_{n+1}^T x_{n+1} \right)^{-1} \left( B_n + x_{n+1}^T y_{n+1} \right)$. Applying SMF, we get these identities:
$$A_{n+1}^{-1} = A_n^{-1} - \frac{A_n^{-1} x_{n+1}^T x_{n+1} A_n^{-1}}{1+x_{n+1} A_n^{-1} x_{n+1}^T}, \; \; \; B_{n+1}^{-1} = B_n^{-1} - \frac{B_n^{-1} x_{n+1}^T y_{n+1} B_n^{-1}}{1+y_{n+1} B_n^{-1} x_{n+1}^T}$$

So, we have derived computationally tractable inverse updates for $A_n$ and $B_n$. Now, since $\hat{\beta}_n = A_n^{-1}B_n$, it is trivial to calculate $\hat{\beta}_n^{-1} = B_n^{-1} A_n$.
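
As a sanity check, the sketch below runs both rank-one updates on synthetic data and compares them against direct inversion. The data-generating process (outputs close to inputs) is a hypothetical choice made to keep $B_n$ well conditioned. Unlike $A_n$, $B_n$ is not symmetric positive definite in general, so the denominator $1 + y B_n^{-1} x^T$ is not guaranteed to stay away from zero on arbitrary data.

```python
import numpy as np

rng = np.random.default_rng(3)
p, lam, n = 3, 1.0, 20

# A_0 = B_0 = lambda * I, with their inverses tracked alongside
A = lam * np.eye(p)
A_inv = np.eye(p) / lam
B = lam * np.eye(p)
B_inv = np.eye(p) / lam

for _ in range(n):
    x = rng.normal(size=p)
    y = x + 0.1 * rng.normal(size=p)  # hypothetical near-identity dynamics
    # Sherman-Morrison updates, using the inverses from before this observation
    A_inv = A_inv - np.outer(A_inv @ x, x @ A_inv) / (1.0 + x @ A_inv @ x)
    B_inv = B_inv - np.outer(B_inv @ x, y @ B_inv) / (1.0 + y @ B_inv @ x)
    # Plain sufficient-statistic updates
    A = A + np.outer(x, x)
    B = B + np.outer(x, y)

beta = A_inv @ B       # beta_hat_n = A_n^{-1} B_n
beta_inv = B_inv @ A   # beta_hat_n^{-1} = B_n^{-1} A_n
```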

## the heuristic algorithm for supervised learning

1. Choose $\lambda > 0$, $p \in \mathbb{N}$, $l \in \{0, 1, \ldots, p-1\}$, sampling distribution $(O_{i+1} \; | \; O_i) \sim F(o_{i+1} \; | \; o_i)$ & $X_0 \sim F(x_0)$, and activation function $\sigma: \mathbb{R}^p \to \mathbb{R}^p$.
2. Set $n \gets 0$, $A_0 \gets \lambda I_{p \times p}$, $A_0^{-1} \gets \lambda^{-1}I_{p \times p}$, $B_0 \gets \lambda I_{p \times p}$, and $B_0^{-1} \gets \lambda^{-1} I_{p \times p}$.
3. Sample $O_n, L_n, O_{n+1}$.
4. Calculate the prediction $\hat Y_n = \left[ \sigma^{-1}(\hat O_{n+1}) \; \big|\; \sigma^{-1}( \hat L_{n+1} ) \right] \gets X_n \hat \beta_n = X_n A_n^{-1} B_n$.
5. Set $Y_n \gets \left[ \sigma^{-1}(O_{n+1}) \; \big| \; \sigma^{-1}( \hat L_{n+1} ) \right]$.
6. Calculate $\hat X_n = \left[ \hat O_n \; \big|\; \hat L_n \right] \gets Y_n \hat \beta_n^{-1} = Y_n B_n^{-1} A_n$.
7. Set $Y_n \gets \left[ \sigma^{-1}(O_{n+1}) \; \big|\; \hat L_n \right]$.
8. Set $A_{n+1} \gets A_n + x_i^T x_i$, where $x_i$ is the latest row addition to $X_n$.
9. Set $B_{n+1} \gets B_n + x_i^T y_i$, where $y_i$ is the latest row addition to $Y_n$.
10. Set $A_{n+1}^{-1} \gets A_n^{-1} - \left(A_n^{-1} x_i^T x_i A_n^{-1}\right)/\left(1 + x_i A_n^{-1} x_i^T\right)$. 
11. Set $B_{n+1}^{-1} \gets B_n^{-1} - \left(B_n^{-1} x_i^T y_i B_n^{-1}\right)/\left(1 + y_i B_n^{-1} x_i^T\right)$.
12. Increment $n \gets n + 1$.
13. Return to step 3.


# model code definitions

In [1]:
import torch

class OnlineNetBase:
    '''
    A low-level class defining a recurrent neural net employing
    a heuristic to enable true online learning. No activation 
    functions are used. Latent variables are not abstracted-away. 
    Only the essential math is defined as `y = x a^{-1} b`. The 
    first `p-l` dimensions of `x` are assumed observed, and latent
    thereafter. 
    '''
    def __init__(self, p, l, regularizer=0.001):
        '''
        Initialize an Online Net
        inputs:
         - p: input and output vector dimension
         - l: number of latent variables, `<p`
         - regularizer: L2 regularization value, `>0`
        '''
        ## verify inputs 
        if type(p) != int:
            raise ValueError('`p` must be an int!')
        if type(l) != int:
            raise ValueError('`l` must be an int!')
        if type(regularizer) not in [int, float]:
            raise ValueError('`regularizer` must be an int or float!') 
        if type(regularizer) == int:
            regularizer = float(regularizer)
        if p <= 0:
            raise ValueError('`p` must be `>0`!') 
        if l < 0 or l >= p:
            raise ValueError('`l` must satisfy `0 <= l < p`!')
        if regularizer <= .0:
            raise ValueError('`regularizer` must be `>0`!')
        ## store init variables 
        self.p = p 
        self.l = l 
        self.regularizer = regularizer
        ## construct initial estimates 
        regularizer_vec = torch.tensor([regularizer]*p) 
        inv_regularizer_vec = torch.tensor([1./regularizer]*p) 
        self.a = torch.diag(regularizer_vec)
        self.b = torch.diag(regularizer_vec)
        self.inv_a = torch.diag(inv_regularizer_vec)
        self.inv_b = torch.diag(inv_regularizer_vec)
        pass 
    
    def predict(self, x):
        '''
        Predict y vector from an x vector
        inputs:
         - x: torch.Tensor of shape `[p]`
        outputs:
         - y: torch.Tensor of shape `[p]`
        '''
        ## verify inputs 
        if type(x) != torch.Tensor:
            raise ValueError('`x` must be of type `torch.Tensor`!') 
        if len(x.shape) != 1:
            raise ValueError('`x` must satisfy `len(x.shape) == 1`!') 
        if x.shape[0] != self.p:
            raise ValueError('`x` must satisfy `x.shape[0] == p`!') 
        ## calculate output
        y = x.reshape([1, self.p])
        y = torch.matmul(y, self.inv_a) 
        y = torch.matmul(y, self.b) 
        return y.reshape([self.p]) 
    
    def fit(self, x, y):
        '''
        Update the model to a new observation `y`.
        inputs:
         - x: torch.Tensor of shape `[p]` 
         - y: torch.Tensor of shape `[p]`
        '''
        ## verify inputs 
        if type(x) != torch.Tensor:
            raise ValueError('`x` must be of type `torch.Tensor`!') 
        if len(x.shape) != 1:
            raise ValueError('`x` must satisfy `len(x.shape) == 1`!') 
        if x.shape[0] != self.p:
            raise ValueError('`x` must satisfy `x.shape[0] == p`!') 
        if type(y) != torch.Tensor:
            raise ValueError('`y` must satisfy `type(y) == torch.Tensor`!') 
        if len(y.shape) != 1:
            raise ValueError('`y` must satisfy `len(y.shape) == 1`!') 
        if y.shape[0] != self.p:
            raise ValueError('`y` must satisfy `y.shape[0] == self.p`!') 
        ## update model 
        x_row = x.reshape([1, self.p]) 
        y_row = y.reshape([1, self.p]) 
        xTx = x_row.transpose(0,1).matmul(x_row) 
        xTy = x_row.transpose(0,1).matmul(y_row) 
        self.a = self.a + xTx 
        self.b = self.b + xTy 
        self.inv_a = self.inv_a - (self.inv_a.matmul(xTx).matmul(self.inv_a))/(1. + x_row.matmul(self.inv_a).matmul(x_row.transpose(0,1)))
        self.inv_b = self.inv_b - (self.inv_b.matmul(xTy).matmul(self.inv_b))/(1. + y_row.matmul(self.inv_b).matmul(x_row.transpose(0,1)))
        pass 
    
    def build_y(self, x, y_pred):
        '''
        Applies the heuristic, updating latent values in `y`. 
        inputs:
         - x: torch.Tensor of shape `[p]` 
         - y_pred: torch.Tensor of shape `[p]`, output of `predict`
        outputs:
         - y_updated: torch.Tensor of shape `[p]`
        '''
        ## verify inputs 
        if type(x) != torch.Tensor:
            raise ValueError('`x` must be of type `torch.Tensor`!') 
        if len(x.shape) != 1:
            raise ValueError('`x` must satisfy `len(x.shape) == 1`!') 
        if x.shape[0] != self.p:
            raise ValueError('`x` must satisfy `x.shape[0] == p`!') 
        if type(y_pred) != torch.Tensor:
            raise ValueError('`y_pred` must satisfy `type(y_pred) == torch.Tensor`!') 
        if len(y_pred.shape) != 1:
            raise ValueError('`y_pred` must satisfy `len(y_pred.shape) == 1`!') 
        if y_pred.shape[0] != self.p:
            raise ValueError('`y_pred` must satisfy `y_pred.shape[0] == self.p`!') 
        ## apply heuristic 
        x_pred = y_pred.matmul(self.inv_b).matmul(self.a).reshape([self.p]) 
        y_updated_observed = y_pred[:(self.p - self.l)] 
        y_updated_latent = x_pred[(self.p - self.l):] 
        y_updated = torch.cat([y_updated_observed, y_updated_latent]) 
        return y_updated 
    pass 


# first experiment: mnist classification

In [3]:
from torchvision import datasets, transforms

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])

dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('../data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset1)
test_loader = torch.utils.data.DataLoader(dataset2)

d = 1*1*28*28 + 1 + 10 ## data dim, +1 for intercept, +10 for one-hot encoding 
l = 100 
p = d+l 
iters = 10 ## number of recurrent net iterations to run per classification 
n_labels = 10 

model = OnlineNetBase(p=p, l=l)
latent_vec = torch.tensor([0.]*l)
intercept = torch.tensor([1.]) 
zeros = torch.tensor([0.]*n_labels) 

def build_data(image, label, latent_vec): 
    'format data from iterator for model' 
    label_one_hot = torch.tensor([1. if int(label[0]) == idx else 0. for idx in range(n_labels)]) ## one-hot representation 
    image = image.reshape([-1]) ## flatten 
    x = torch.cat([zeros, intercept, image, latent_vec]) 
    y = torch.cat([label_one_hot, intercept, image, latent_vec]) 
    return x, y 

errors = [] 
for [image, label] in train_loader: 
    x, y = build_data(image, label, latent_vec) 
    x0 = x 
    ## fit 
    for _ in range(iters): 
        y_pred = model.predict(x) 
        print(f'DEBUG 1 y_pred.sum(): {y_pred.sum()}') 
        y_target = model.build_y(x, y_pred) 
        ## update labels before fitting 
        y_target[:n_labels] = y[:n_labels] 
        model.fit(x, y_target) 
        ## recurse 
        x = y_pred 
    ## train error 
    for _ in range(iters): 
        x0 = model.predict(x0) 
    error = (x0[:n_labels] - y[:n_labels]).abs().sum() 
    errors.append(error) 
    print(f'error: {error}') 
    ## keep latent vec for next iteration 
    latent_vec = x0[(p-l):] 

DEBUG 1 y_pred.sum(): 18.761709213256836
DEBUG 1 y_pred.sum(): 25.63625144958496
DEBUG 1 y_pred.sum(): -75.59857177734375
DEBUG 1 y_pred.sum(): -321.5128479003906
DEBUG 1 y_pred.sum(): -8966.6083984375
DEBUG 1 y_pred.sum(): -1024271.875
DEBUG 1 y_pred.sum(): -73408003964928.0
DEBUG 1 y_pred.sum(): -4.9896184247613905e+28
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
error: nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
error: nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
error: nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum(): nan
DEBUG 1 y_pred.sum()

KeyboardInterrupt: 

In [92]:
for _ in range(iters): 
    x0 = model.predict(x0)
x0[:n_labels]

tensor([  -256., 136832.,  36864.,  66304.,  27264.,  72064.,  69248.,   -768.,
         26368.,  71168.])