# Code demonstrating basic non-linear reduced rank models

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import numpy.linalg
import torch

from janelia_core.ml.reduced_rank_models import RRLinearModel

## Setup ground truth model

In [3]:
d_in, d_out, d_latent, n_smps = 10, 5, 1, 1000

In [4]:
true_model = RRLinearModel(d_in, d_out, d_latent)
true_model.generate_random_model()

## Generate data

In [5]:
x = 10*torch.rand([n_smps, d_in])
y = true_model.generate(x)

## Create model we will fit to data

In [6]:
fitted_model = RRLinearModel(d_in, d_out, d_latent)
fitted_model.init_weights(y)
fitted_model.to('cpu')

RRLinearModel()

## Fit model to data

In [7]:
fitted_model.fit(x,y, batch_size=1000, max_its=10000, update_int=1000)

0: Elapsed time 0.23435187339782715, vl: 8.144268
1000: Elapsed time 1.107351541519165, vl: 1.8671443
2000: Elapsed time 1.9531753063201904, vl: 1.8671438
3000: Elapsed time 2.782068967819214, vl: 1.8673681
4000: Elapsed time 3.62626314163208, vl: 1.8671432
5000: Elapsed time 4.486172199249268, vl: 1.8671435
6000: Elapsed time 5.3130927085876465, vl: 1.8673743
7000: Elapsed time 6.141708135604858, vl: 1.8672189
8000: Elapsed time 7.001779556274414, vl: 1.8671439
9000: Elapsed time 7.830469846725464, vl: 1.8671443


## See fitted model results

In [13]:
%matplotlib qt
true_model.standardize()
fitted_model.standardize()
RRLinearModel.compare_models(true_model, fitted_model, x, [0, 1])