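# Demonstration of basic non-linear reduced-rank models

This notebook generates data from a randomly initialized ground-truth `RRReluModel`, fits a second `RRReluModel` to that data, and compares the fitted model with the ground truth.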

In [1]:
# Reload imported modules (such as janelia_core) automatically before code is
# executed, so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg
import torch

from janelia_core.ml.reduced_rank_models import RRReluModel

## Set up the ground-truth model

In [3]:
# Problem dimensions: input size, output size, latent size (presumably the rank
# of the reduced-rank map -- confirm in RRReluModel), and number of samples.
d_in, d_out, d_latent, n_smps = 100, 100, 10, 10000

# Seed the random number generators so reruns draw the same ground-truth model
# and data. Both torch and NumPy are seeded because it is not visible here which
# one janelia_core uses internally. (Fitting on a GPU may still be
# nondeterministic.)
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
# Ground-truth model whose parameters are drawn at random by
# generate_random_model; the data below are generated from it.
# (See RRReluModel for the exact meaning of var_range.)
true_var_range = [0.1, 1]
true_model = RRReluModel(d_in, d_out, d_latent)
true_model.generate_random_model(var_range=true_var_range)

## Generate data

In [5]:
# Inputs: i.i.d. standard-normal draws scaled by 3 (standard deviation 3).
# Outputs: produced by the ground-truth model's generate() method.
x = torch.randn(n_smps, d_in) * 3
y = true_model.generate(x)

## Create the model to fit to the data

In [37]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

fitted_model = RRReluModel(d_in, d_out, d_latent)
# Initialize the weights from the observed outputs y with a small weight gain.
fitted_model.init_weights(y, w_gain=.0001)
# Trailing ';' suppresses the module repr in the cell output.
fitted_model.to(device);

RRReluModel()

## Fit model to data

In [38]:
# Optimizer settings passed to fit(). For Adam, betas are the decay rates of the
# running averages of the gradient (first moment) and its square (second moment).
adam_params = {'lr': 1e-3, 'betas': [0.9, 0.999]}
fitted_model.fit(
    x, y,
    batch_size=100, send_size=100,
    max_its=5000, update_int=1000,
    adam_params=adam_params,
)

0: Elapsed time 0.01566171646118164, vl: 381.3066796875
1000: Elapsed time 6.203508138656616, vl: 150.381474609375
2000: Elapsed time 12.174859285354614, vl: 139.591357421875
3000: Elapsed time 18.187827110290527, vl: 132.56421875
4000: Elapsed time 24.21998167037964, vl: 126.295751953125


## Compare the fitted model with the ground truth

In [39]:
%matplotlib qt
true_model.standardize()
fitted_model.standardize()
fitted_model.to('cpu')
RRReluModel.compare_models(true_model, fitted_model, x[0:100,:], [0,1,2,3])
fitted_model.to('cuda')

RRReluModel()

In [33]:
# Copy each model's weight tensors to host-side NumPy arrays.
w0_true, w1_true = (w.cpu().detach().numpy() for w in (true_model.w0, true_model.w1))
w0_fitted, w1_fitted = (w.cpu().detach().numpy() for w in (fitted_model.w0, fitted_model.w1))

# Product w1 @ w0^T for each model, compared below. NOTE(review): presumably
# (d_out, d_in) if w1 is (d_out, d_latent) and w0 is (d_in, d_latent) -- confirm.
w_true = w1_true @ w0_true.T
w_fitted = w1_fitted @ w0_fitted.T

In [34]:
from janelia_core.visualization.matrix_visualization import cmp_n_mats

In [21]:
# Show the top-left (at most) n_show x n_show corner of the true matrix, the
# fitted matrix and their difference side by side.
n_show = 100
w_true_show = w_true[:n_show, :n_show]
w_fitted_show = w_fitted[:n_show, :n_show]
# Trailing ';' suppresses the list-of-axes repr in the cell output.
cmp_n_mats([w_true_show, w_fitted_show, w_true_show - w_fitted_show]);

[<matplotlib.axes._subplots.AxesSubplot at 0x1c4fe32b240>,
 <matplotlib.axes._subplots.AxesSubplot at 0x1c5086b4940>,
 <matplotlib.axes._subplots.AxesSubplot at 0x1c5086e6198>]