In [1]:
# Automatically reload edited Python modules (e.g. janelia_core) before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from janelia_core.ml.extra_torch_modules import Bias
from janelia_core.ml.latent_regression import LatentRegModel
from janelia_core.ml.latent_regression import LinearMap

## Model parameters

In [3]:
# Dimensionality of each input group (one random input matrix is generated per entry)
# and of each output group (one list of Bias modules is built per entry).
d_in = [5, 6, 7]
d_out = [5, 8, 9]

# Latent dimensionalities passed to LinearMap and LatentRegModel.
d_proj = [1, 1, 1]
d_trans = [1, 1, 2]

# Number of samples to simulate.
n_smps = 20000

# Seed the global RNGs so that module initialization, the simulated data and
# (presumably) minibatch selection during fitting are reproducible under
# Restart & Run All. numpy is seeded too in case anything downstream draws from it.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
## Create the ground-truth model (used below to simulate data)
# S is indexed by output group, so the group count must come from d_out, not d_in.
n_output_groups = len(d_out)

M = LinearMap(d_proj, d_trans)
# One Bias(1) module per output dimension, grouped by output group.
S = [[Bias(1) for _ in range(n_dims)] for n_dims in d_out]
mdl = LatentRegModel(d_in, d_out, d_proj, d_trans, M, S, direct_pairs=None)

## Generate input (x) data

In [5]:
# One standard-normal input matrix of shape (n_smps, n_dims) per input group.
x = [torch.randn(n_smps, n_dims) for n_dims in d_in]

## Run the ground-truth model forward and generate outputs (y)

In [6]:
# The ground-truth model is never optimized, so don't record autograd graphs for it;
# otherwise the full-dataset forward pass keeps an unused graph alive in memory.
with torch.no_grad():
    # Forward pass of the ground-truth model.
    y_pred = mdl(x)
    # Outputs produced by mdl.generate; used as the fitting targets below.
    y = mdl.generate(x)

## Fit a new model to the generated data

In [7]:
# New model instance with the same constructor arguments as the ground-truth model.
M_fitted = LinearMap(d_proj, d_trans)
# Iterate over d_out directly: one list of Bias(1) modules per output group,
# independent of how n_output_groups was computed.
S_fitted = [[Bias(1) for _ in range(n_dims)] for n_dims in d_out]
fitted_mdl = LatentRegModel(d_in, d_out, d_proj, d_trans, M_fitted, S_fitted, direct_pairs=None)

In [8]:
# Fit to the simulated (x, y); progress is printed every update_int iterations.
log = fitted_mdl.fit(
    x, y,
    max_its=10000,
    batch_size=100,
    update_int=1000,
    min_var=0.01,
)

0: Elapsed fitting time 0.0, vl: 6628.026, lr: 0.01
1000: Elapsed fitting time 3.4695756435394287, vl: 12.79518, lr: 0.01
2000: Elapsed fitting time 6.693968772888184, vl: -5.688615, lr: 0.01
3000: Elapsed fitting time 9.905351638793945, vl: -8.532728, lr: 0.01
4000: Elapsed fitting time 13.126751899719238, vl: -10.888018, lr: 0.01
5000: Elapsed fitting time 16.47635793685913, vl: -12.217566, lr: 0.01
6000: Elapsed fitting time 19.871877908706665, vl: -12.651084, lr: 0.01
7000: Elapsed fitting time 23.266780853271484, vl: -14.891417, lr: 0.01
8000: Elapsed fitting time 26.68431830406189, vl: -13.93134, lr: 0.01
9000: Elapsed fitting time 30.09081244468689, vl: -14.4672, lr: 0.01
9999: Elapsed fitting time 33.49433898925781, vl: -14.49562


In [9]:
%matplotlib qt
plt.figure()
plt.plot(log['elapsed_time'], log['obj'])

[<matplotlib.lines.Line2D at 0x1b35cad8160>]

## Compare predictions of the fitted model to those of the ground-truth model

In [10]:
# Plot at most the first 1000 samples so individual points stay legible;
# min() avoids an IndexError when fewer samples were simulated.
plt_smps = np.arange(min(1000, x[0].shape[0]))

# Predictions are only plotted, so skip recording autograd graphs.
with torch.no_grad():
    true_y_hat = mdl(x)
    fitted_y_hat = fitted_mdl(x)

# One figure per output group (iterating d_out, so the count matches the outputs),
# one panel per output dimension: red dots = ground-truth model, blue line = fitted model.
for g, n_dims in enumerate(d_out):
    fig, axes = plt.subplots(1, n_dims, squeeze=False)
    fig.suptitle(f'Output group {g}: ground truth (red) vs. fitted (blue)')
    for d, ax in enumerate(axes[0]):
        true_plt_values = true_y_hat[g][plt_smps, d].detach().numpy()
        fitted_plt_values = fitted_y_hat[g][plt_smps, d].detach().numpy()
        ax.plot(true_plt_values, 'ro')
        ax.plot(fitted_plt_values, 'b-')
        ax.set_title(f'dim {d}')