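# Demonstration of exponential reduced-rank models

This notebook draws a random ground-truth `RRExpModel`, generates synthetic data from it, fits a second `RRExpModel` to that data with Adam, and compares the fitted model against the ground truth.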

In [1]:
# Automatically reload imported modules before each cell runs, so edits to local
# packages (such as janelia_core) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg
import torch

from janelia_core.ml.reduced_rank_models import RRExpModel

## Set up the ground-truth model

In [3]:
# Problem size: input dimension, output dimension, latent dimension, number of samples.
d_in, d_out, d_latent, n_smps = 10, 5, 2, 12000

# Seed the global torch and numpy RNGs so a fresh Restart & Run All reproduces the
# generated inputs x. This assumes RRExpModel draws from these global RNGs when it
# builds the ground-truth model and fits. TODO: confirm.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
# Ground-truth model with the dimensions configured above; the fitted model is compared against it later.
true_model = RRExpModel(d_in, d_out, d_latent)
# Presumably fills the model's parameters with randomly drawn values in place. TODO: confirm in RRExpModel.
true_model.generate_random_model()

## Generate data

In [5]:
# Inputs: n_smps independent standard-normal vectors of length d_in, scaled to std 0.1.
x = torch.randn(n_smps, d_in) * 0.1
# Outputs: produced by the ground-truth model for these inputs (presumably sampled; see RRExpModel.generate).
y = true_model.generate(x)

## Create the model to fit to the data

In [6]:
# A fresh model with the same dimensions as the ground truth; this is the model we fit.
fitted_model = RRExpModel(d_in, d_out, d_latent)
# Initialize its weights from the observed outputs y (the strategy is defined by RRExpModel.init_weights).
fitted_model.init_weights(y)
# The model stays on the CPU. To fit on a GPU, call fitted_model.to('cuda') here.
# TODO: confirm that fit() moves the CPU tensors x and y to the model's device.

## Fit the model to the data

In [7]:
# Adam settings, presumably passed through to torch.optim.Adam. betas are the decay rates
# for the running averages of the gradient and of its square (Adam's documented type is a tuple).
adam_params = {'lr': .001, 'betas': (.9, .999)}
# NOTE(review): in the saved run the logged `vl` is lowest near iteration 7000 and rises
# afterwards. Consider fewer iterations or a smaller learning rate.
fitted_model.fit(x, y, batch_size=100, send_size=100,
                 max_its=20000, update_int=1000, adam_params=adam_params)

0: Elapsed time 0.002688884735107422, vl: 24.297294921875
1000: Elapsed time 1.022035837173462, vl: 4.206455688476563
2000: Elapsed time 2.0445728302001953, vl: 3.502809753417969
3000: Elapsed time 3.1456828117370605, vl: 2.7750323486328123
4000: Elapsed time 4.1845996379852295, vl: 2.0274948120117187
5000: Elapsed time 5.173861742019653, vl: 1.6162147521972656
6000: Elapsed time 6.176631689071655, vl: 1.2902523803710937
7000: Elapsed time 7.209240674972534, vl: 1.2164424133300782
8000: Elapsed time 8.26312780380249, vl: 1.4680133056640625
9000: Elapsed time 9.253959894180298, vl: 1.2944786071777343
10000: Elapsed time 10.311578750610352, vl: 1.4995962524414062
11000: Elapsed time 11.311465978622437, vl: 1.68470458984375
12000: Elapsed time 12.340572834014893, vl: 1.4539814758300782
13000: Elapsed time 13.428228855133057, vl: 1.5150009155273438
14000: Elapsed time 14.375159740447998, vl: 1.649640350341797
15000: Elapsed time 15.455029964447021, vl: 1.6051492309570312
16000: Elapsed tim

## Compare the fitted and ground-truth models

In [8]:
# Switch matplotlib to the Qt backend, so figures open in separate interactive windows.
# NOTE(review): Qt figures are not embedded in the saved notebook. Consider `%matplotlib inline`
# (or `widget`) if the comparison plots should render on GitHub/nbviewer.
%matplotlib qt
# Put both models into a standardized form before comparing them. This is presumably done in place
# (the return values are discarded) and presumably removes parameter ambiguities; TODO: confirm in
# RRExpModel.standardize. Because both models are modified, re-running earlier cells may be needed.
true_model.standardize()
fitted_model.standardize()
# Make sure the fitted model is on the CPU before comparing it with CPU-resident data
# (presumably a no-op if the model was never moved to a GPU).
fitted_model.to('cpu')
# Compare the two models on the first 100 inputs. The meaning of [0, 1] is defined by
# compare_models. TODO: confirm.
# NOTE(review): the models are passed explicitly, so compare_models looks like a static method;
# calling it as RRExpModel.compare_models(...) would make that clearer.
true_model.compare_models(true_model, fitted_model, x[0:100,:], [0, 1])

In [9]:
# Overlay column 0 of w1 for the true and fitted models (both were standardized in the cell above).
fig, ax = plt.subplots()
ax.plot(true_model.w1[:, 0].detach().numpy(), label='true')
ax.plot(fitted_model.w1[:, 0].detach().numpy(), label='fitted')
ax.set(title='w1[:, 0]: true vs. fitted', xlabel='index', ylabel='w1[:, 0]')
# The trailing semicolon suppresses the Legend repr in the cell output.
ax.legend();

[<matplotlib.lines.Line2D at 0x1319d85f8>]