# online nets

Deep learning is powerful but computationally expensive, frequently requiring massive compute budgets. In persuit of cost-effective-yet-powerful AI, this work explores and evaluates a heuristic which should lend to more-efficient use of data through online learning.

Goal: evaluate a deep learning alternative capable of true online learning. Solution requirements:

1. catastrophic forgetting should be impossible;
2. all data is integrated into sufficient statistics of fixed dimension;
3. and our solution should have predictive power comparable to deep learning.

## modeling strategy

We will not attempt to derive sufficient statistics for an entire deep net, but instead leverage well-known sufficient statistics for least squares models, 
so will have sufficient statistics per deep net layer. If this can be empirically shown effective, we'll build-out the theory afterwards. 


# model code definitions

In [20]:
import torch

class OnlineDenseLayer:
    '''
    A single dense net, formulated as a least squares model. 
    '''
    def __init__(self, p, q, activation=lambda x:x, activation_inverse=lambda x:x, lam=1.): 
        '''
        inputs:
        - p: input dimension
        - q: output dimension
        - activation: non-linear function, from R^p to R^q. Default is identity
        - activation_inverse: inverse of the activation function. Default is identity. 
        - lam: regularization term 
        '''
        self.__validate_inputs(p=p, q=q, lam=lam) 
        self.p = p
        self.q = q
        self.activation = activation 
        self.activation_inverse = activation_inverse 
        self.lam = lam 
        self.xTy = torch.zeros(p,q) 
        self.yTx = torch.zeros(p,q) 
        self.xTx_inv = torch.diag(torch.tensor([lam]*p)) 
        self.yTy_inv = torch.diag(torch.tensor([lam]*q)) 
        self.betaT_forward = torch.matmul(self.xTx_inv, self.xTy)
        self.betaT_forward = torch.transpose(self.betaT_forward, 0, 1) 
        self.betaT_backward = torch.matmul(self.yTy_inv, self.yTx) 
        self.betaT_backward = torch.transpose(self.betaT_backward, 0, 1) 
        self.x_forward = None 
        self.y_forward = None 
        self.x_backward = None 
        self.y_backward = None 
        pass 
    def forward_predict(self, x): 
        'creates and stores x_forward and y_forward, then returns activation(y_forward)'
        self.__validate_inputs(x=x, p=self.p)
        ## TODO prepend 1 for intercept 
        self.x_forward = x 
        self.y_forward = torch.matmul(self.betaT_forward, x) 
        return self.activation(self.y_forward) 
    def backward_predict(self, y):
        'creates and stores x_backward and y_backward, then returns y_backward'
        y = self.activation_inverse(y) 
        self.__validate_inputs(y=y, q=self.q)
        self.y_backward = y
        self.x_backward = torch.matmul(self.betaT_backward, y) 
        return self.x_backward 
    def forward_fit(self): 
        'uses x_forward and y_backward to update forward model, then returns Sherman Morrison denominator'
        self.__validate_inputs(x=self.x_forward, y=self.y_backward) 
        self.xTx_inv, sm_denom = self.__sherman_morrison(self.xTx_inv, self.x_forward, self.x_forward) 
        self.xTy += torch.matmul(self.x_forward, torch.transpose(self.y_backward, 0, 1)) 
        self.betaT_forward = torch.matmul(self.xTx_inv, self.xTy) 
        self.betaT_forward = torch.transpose(self.betaT_forward, 0, 1) 
        return sm_denom 
    def backward_fit(self):
        'uses x_backward and y_forward to update backward model, then returns Sherman Morrison denominator'
        self.yTy_inv, sm_denom = self.__sherman_morrison(self.yTy_inv, self.y_forward, self.y_forward) 
        self.yTx += torch.matmul(self.y_forward, torch.transpose(self.x_backward, 0, 1)) 
        self.betaT_forward = torch.matmul(self.yTy_inv, self.yTx)
        self.betaT_forward = torch.transpose(self.betaT_forward, 0, 1) 
        return sm_denom 
    def __sherman_morrison(self, inv_mat, vec1, vec2):
        '''
        applies Sherman Morrison updates, (mat + vec1 vec2^T)^{-1}
        inputs:
        - inv_mat: an inverted matrix 
        - vec1: a column vector 
        - vec2: a column vector 
        returns:
        - updated matrix
        - the Sherman Morrison denominator, for tracking numerical stability 
        '''
        v2t = torch.transpose(vec2, 0, 1)
        denominator = 1. + torch.matmul(torch.matmul(v2t, inv_mat), vec1) 
        numerator = torch.matmul(torch.matmul(inv_mat, vec1), torch.matmul(v2t, inv_mat)) 
        updated_inv_mat = inv_mat - numerator / denominator 
        return updated_inv_mat, float(denominator) 
    def __validate_inputs(self, p=None, q=None, lam=None, x=None, y=None):
        'raises value exceptions if provided parameters are invalid'
        if q is not None:
            if not isinstance(q, int):
                raise ValueError('`q` must be int!')
            if q <= 0:
                raise ValueError('`q` must be greater than zero!')
        if p is not None:
            if not isinstance(p, int): 
                raise ValueError('`p` must be int!')
            if p <= 0: 
                raise ValueError('`p` must be greater than zero!')
        if lam is not None:
            if not (isinstance(lam, float) or isinstance(lam, int)):
                raise ValueError('`lam` must be float or int!')
            if lam < 0:
                raise ValueError('`lam` must be non-negative!')
        if x is not None and p is not None: 
            if type(x) != torch.tensor:
                raise ValueError('`x` must be of type `torch.tensor`!') 
            if list(x.shape) != [p,1]: 
                raise ValueError('`x.shape` must be `[p,1]`') 
            pass 
        if y is not None and q is not None: 
            if type(y) != torch.tensor:
                raise ValueError('`y` must be of type `torch.tensor`!') 
            if list(y.shape) != [q,1]: 
                raise ValueError('`y.shape` must be `[q,1]`') 
            pass 
        pass 
    pass

class OnlineNet: 
    'online, sequential dense net' 
    def __init__(self, layer_list): 
        ## validate inputs 
        if type(layer_list) != list: 
            raise ValueError('`layer_list` must be of type list!') 
        for layer in layer_list: 
            if not issubclass(type(layer), OnlineDenseLayer):
                raise ValueError('each item in `layer_list` must be an instance of a subclass of `OnlineDenseLayer`!') 
        ## assign 
        self.layer_list = layer_list 
        pass 
    def forward(self, x): 
        'predict forward'
        for layer in self.layer_list:
            x = layer.forward(x) 
        return x 
    def backward(self, y):
        'predict backward'
        for layer in reversed(self.layer_list): 
            y = layer.backward(y) 
        return y 
    def fit(self): 
        'assumes layers x & y targets have already been set. Returns Sherman Morrison denominators per layer in (forward, backward) pairs in a list'
        sherman_morrison_denominator_list = [] 
        for layer in self.layer_list:
            forward_smd = layer.forward_fit() 
            backward_smd = layer.backward_fit() 
            sherman_morrison_denominator_list.append((forward_smd, backward_smd))
        return sherman_morrison_denominator_list 
    def __reduce_sherman_morrison_denominator_list(self, smd_pair_list):
        'returns the value closest to zero'
        if type(smd_pair_list) != list: 
            raise ValueError('`smd_pair_list` must be of type `list`!')
        if len(smd_pair_list) == 0:
            return None 
        smallest_smd = None 
        for smd_pair in smd_pair_list:
            if type(smd_pair) != tuple:
                raise ValueError('`smd_pair_list` must be list of tuples!')
            if smallest_smd is None: 
                smallest_smd = smd_pair[0] 
            if abs(smallest_smd) > abs(smd_pair[0]): 
                smallest_smd = smd_pair[0] 
            if abs(smallest_smd) > abs(smd_pair[1]):
                smallest_smd = smd_pair[1] 
        return smallest_smd 
    def __call__(self, x, y=None): 
        '''
        If only x is given, a prediction is made and returned.
        If x and y are given, then the model is updated, and returns
        - the prediction
        - the sherman morrison denominator closest to zero, for tracking numerical stability
        '''
        y_hat = self.forward(x) 
        if y is None: 
            return y_hat 
        self.backward(y) 
        self.x_forward = x 
        self.x_backward = x 
        self.y_forward = y 
        self.y_backward = y 
        smd_pair_list = self.fit() 
        smallest_smd = self.__reduce_sherman_morrison_denominator_list(smd_pair_list) 
        return y_hat, smallest_smd 

# first experiment: mnist classification

In [21]:
from tqdm import tqdm
from torchvision import datasets, transforms

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])

dataset1 = datasets.MNIST('../../data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('../../data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset1)
test_loader = torch.utils.data.DataLoader(dataset2)

d = 1*1*28*28 + 1 + 10 ## data dim, +1 for intercept, +10 for one-hot encoding 
l = 100 
p = d+l 
iters = 10 ## number of recurrent net iterations to run per classification 
n_labels = 10 
regularizer = 1000

model = OnlineNet(
    [OnlineDenseLayer(p=1*1*28*28, q=1000), ## TODO add activation functions 
    OnlineDenseLayer(p=1000, q=5000), 
    OnlineDenseLayer(p=5000, q=100), 
    OnlineDenseLayer(p=100, q=n_labels)] 
)

def build_data(image, label, latent_vec): 
    'format data from iterator for model' 
    y = torch.tensor([1. if int(label[0]) == idx else 0. for idx in range(n_labels)]) ## one-hot representation 
    x = image.reshape([-1]) ## flatten 
    ## shrink so sigmoid inverse is well-defined 
    y = y*.90 + .05 
    return x, y 

errors = [] 
lat_sums = []
pbar = tqdm(train_loader)
for [image, label] in pbar:
    x, y = build_data(image, label, latent_vec) 
    x0 = x 
    ## fit 
    for _ in range(iters): 
        y_pred = model.predict(x) 
        y_pred[d:] = torch.sigmoid(y_pred[d:])
        #print(f'DEBUG 1 y_pred.sum(): {y_pred.sum()}') 
        y_target = model.build_y(x, y_pred, y[:d]) 
        ## update labels before fitting 
        y_target[:n_labels] = y[:n_labels] 
        model.fit(x, y_target) 
        ## recurse 
        x = y_pred 
    ## train error 
    for _ in range(iters): 
        x0 = model.predict(x0) 
    error = (x0[:n_labels] - y[:n_labels]).abs().sum() 
    errors.append(error) 
    lat_sum = latent_vec.abs().sum()
    lat_sums.append(lat_sum)
    pbar.set_description(f'error: {error}, {lat_sum}')
    ## keep latent vec for next iteration 
    latent_vec = torch.sigmoid(x0[(p-l):])

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1000x1000 and 784x1000)

In [19]:
x = torch.tensor([[1.,2.], [3., 5.], [6., 7.]])
list(x.shape) == [3,2]

True

## scratch space

$x_1, X_2, X_3, \ldots, X_{p-1}, x_p = y$

$\beta_1^F, \beta_1^B, \beta_2^F, \beta_2^B, \ldots, \beta_{p-1}^F, \beta_{p-1}^B$

$\hat x_{j+1} = \sigma \left( \beta_j^{FT} x_j \right)$

forward series: $x_1, \hat x_2, \hat x_3, \ldots, \hat x_{p-1}, x_p = y$

$ \hat \beta_{p-1}^{FT} = \text{argmin}_\beta \| x_p - \beta^{T} \hat x_{p-1} \|^2 $

$ \hat \beta_{p-2}^{FT} = \text{argmin}_\beta \| \hat x_{p-1} - \beta^{T} \hat x_{p-2} \|^2 $ We won't do this. 

$ \tilde x_{j-1} = \sigma^{-1}\left( \beta_j^{BT} x_j \right)$

backward series: $x_1, \tilde x_2, \tilde x_3, \ldots, \tilde x_{p-1}, x_p = y$

$ \hat \beta_{p-2}^{F} = \text{argmin}_\beta \| \tilde x_{p-1} - \beta^{T} \hat x_{p-2} \|^2 $

$ \hat \beta_{p-2}^{B} = \text{argmin}_\beta \| \hat x_{p-2} - \beta^{T} \tilde x_{p-1} \|^2 $

$ \hat x_3 = \sigma \left( \beta_2^{FT} \hat x_2 \right) $

$ \hat \beta_2^F = \text{argmin}_\beta \| \sigma^{-1}\left( \hat x_3 \right) - \beta^T \hat x_2 \|^2 $ Useless without $\tilde x$

$ \tilde x_2 = \sigma^{-1}\left( \beta_3^{BT} \tilde x_3 \right) $

$ \hat \beta_2^B = \text{argmin}_\beta \| \sigma\left( \hat x_2 \right) - \beta^T \hat x_3 \|^2 $ Useless without $\hat x$

So, use these estimates instead.

$ \hat \beta_2^F = \text{argmin}_\beta \| \sigma^{-1}\left( \tilde x_3 \right) - \beta^T \hat x_2 \|^2 $

$ \hat \beta_2^B = \text{argmin}_\beta \| \sigma\left( \hat x_2 \right) - \beta^T \tilde x_3 \|^2 $

$\pi$