# Code demonstrating basic non-linear reduced rank regression

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg
import torch

from janelia_core.ml.non_linear_rr_regression import NonLinearRRRegresion

## Setup ground truth model

In [3]:
d_in, d_out, d_latent, n_smps = 5, 2, 1, 1000

In [4]:
true_model = NonLinearRRRegresion(d_in, d_out, d_latent)
true_model.init_weights()

## Generate data

In [5]:
x = 3*torch.rand([n_smps, d_in])
y = true_model.generate(x)

## Create model we will fit to data

In [6]:
fitted_model = NonLinearRRRegresion(d_in, d_out, d_latent)
fitted_model.init_weights()

## Fit model to data

### Setup optimizer

In [7]:
optimizer = torch.optim.Adam(fitted_model.parameters(), lr=.0001)

### Fit the model

In [23]:
for i in range(100000):
    optimizer.zero_grad()
    mns = fitted_model(x.data)
    nll = fitted_model.neg_log_likelihood(y.data, mns)
    nll.backward()
    optimizer.step()
    if i % 10000 == 0: 
        print(str(i) + ': ' + str(nll))

0: tensor(-596.9268, grad_fn=<ThAddBackward>)
10000: tensor(-596.9265, grad_fn=<ThAddBackward>)
20000: tensor(-596.9265, grad_fn=<ThAddBackward>)
30000: tensor(-596.9264, grad_fn=<ThAddBackward>)
40000: tensor(-596.9267, grad_fn=<ThAddBackward>)
50000: tensor(-596.9263, grad_fn=<ThAddBackward>)
60000: tensor(-596.9265, grad_fn=<ThAddBackward>)
70000: tensor(-596.9265, grad_fn=<ThAddBackward>)
80000: tensor(-596.9266, grad_fn=<ThAddBackward>)
90000: tensor(-596.9266, grad_fn=<ThAddBackward>)


## Standardize fitted model

In [26]:
true_model.standardize()
fitted_model.standardize()

In [33]:
true_model.o2

Parameter containing:
tensor([[ 0.1590],
        [-0.6123]], requires_grad=True)

In [51]:
fitted_y_pred = fitted_model(x)
true_y_pred = true_model(x)

In [52]:
%matplotlib qt
for i in range(d_out):
    plt.plot(fitted_y_pred.detach().numpy()[:, i])
    plt.plot(true_y_pred.detach().numpy()[:,i], 'o')

In [46]:
y_pred.detach().numpy()[:,0]

array([2.5213258, 2.5004637, 2.4576912, 2.4506242, 2.39124  , 2.4571948,
       2.524058 , 2.5083914, 2.5068321, 2.5083055, 2.523325 , 2.5241075,
       2.4992495, 2.5077837, 2.4875846, 2.433589 , 2.5199747, 2.5229208,
       2.5186596, 2.4962378, 2.522494 , 2.5205393, 2.5138154, 2.5179913,
       2.5214345, 2.521766 , 2.5154588, 2.5217237, 2.500745 , 2.5228767,
       2.444665 , 2.5174341, 2.5151272, 2.523124 , 2.5184429, 2.41749  ,
       2.518446 , 2.517745 , 2.5129542, 2.5240316, 2.5220778, 2.51514  ,
       2.5020812, 2.5236766, 2.5240703, 2.5214515, 2.5225945, 2.3544936,
       2.5129848, 2.5239315, 2.5193515, 2.5182672, 2.485341 , 2.517187 ,
       2.5227256, 2.509751 , 2.5214734, 2.287788 , 2.5162997, 2.5240479,
       2.5216334, 2.37201  , 2.4891155, 2.2969468, 2.488935 , 2.5217392,
       2.4590187, 2.448059 , 2.4874482, 2.5232246, 2.5178142, 2.5207422,
       2.5236647, 2.5197964, 2.5233173, 2.5236292, 2.4151235, 2.5233717,
       2.5073595, 2.52051  , 2.3309822, 2.4420936, 

In [None]:
torch.matmul(fitted_model.w1, torch.t(fitted_model.w0))

In [None]:
true_model.g

In [41]:
y_pred.detach().numpy()

array([[ 2.5213258 , -0.599764  ],
       [ 2.5004637 , -0.5994371 ],
       [ 2.4576912 , -0.5944052 ],
       ...,
       [ 2.438923  , -0.5884206 ],
       [ 2.5121608 , -0.59970623],
       [ 2.400144  , -0.5603526 ]], dtype=float32)

In [42]:
y_pred.shape

torch.Size([1000, 2])