In [1]:
# Automatically reload edited modules (e.g. janelia_core) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from janelia_core.ml.extra_torch_modules import Bias
from janelia_core.ml.latent_regression import LatentRegModel
from janelia_core.ml.latent_regression import LinearMap
from janelia_core.ml.latent_regression import IdentityMap

## Model parameters

In [3]:
# Dimensionality of each input group and each output group (one entry per group).
d_in = [100000, 100000]
d_out = [100000, 100000]

# Per-group latent dimensions passed to LatentRegModel as d_proj / d_trans;
# their exact meaning is defined by LatentRegModel.
d_proj = [2, 2]
d_trans = [2, 2]

# Number of simulated samples.
n_smps = 10000

# Seed torch's RNG so the simulated inputs (torch.randn) are reproducible under
# Restart & Run All. Model initialization presumably also draws from torch's RNG
# — confirm in LatentRegModel.
seed = 0
torch.manual_seed(seed)

In [4]:
## Create the model here
# Number of output groups. The comparison plots index model outputs with it
# (y[g], true_y_hat[g]), so it is derived from d_out rather than d_in.
n_output_groups = len(d_out)

# Ground-truth latent mapping. Alternative previously used here:
#   M = LinearMap(d_proj, d_trans)
M = IdentityMap()
# One Bias module per output group, constructed from that group's output dimension.
S = [Bias(d_o) for d_o in d_out]
# NOTE(review): w_gain and noise_range are simulation settings forwarded to
# LatentRegModel; their precise semantics are defined there — confirm.
mdl = LatentRegModel(d_in, d_out, d_proj, d_trans, M, S, direct_pairs=None, w_gain=50,
                     noise_range=[5.0, 10.0])

## Generate input data x

In [None]:
# Draw one standard-normal input matrix of shape (n_smps, d) per input group.
x = [torch.randn(n_smps, n_dims) for n_dims in d_in]

## Run the ground-truth model forward and generate outputs

In [None]:
# The ground-truth model is only evaluated here, not trained, so disable autograd
# tracking. This avoids building graphs for the large forward passes. It also
# keeps the generated targets `y` from carrying gradient history into fitting,
# in case generate() returns tensors that require grad. Values are unchanged.
with torch.no_grad():
    y_pred = mdl(x)
    y = mdl.generate(x)

## Fit a new model

In [None]:
# Fresh model to be fit to (x, y), built the same way as the ground-truth model.
# The LinearMap alternative for the latent mapping is:
#   M_fitted = LinearMap(d_proj, d_trans)
M_fitted = IdentityMap()
S_fitted = [Bias(n_out) for n_out in d_out]
fitted_mdl = LatentRegModel(
    d_in, d_out, d_proj, d_trans,
    M_fitted, S_fitted,
    direct_pairs=None,
)

## Move the model to CUDA

In [None]:
# Put the model being fit on the GPU. Page-lock (pin) the host-side data so that
# host-to-device copies made while fitting can be faster and asynchronous.
fitted_mdl = fitted_mdl.cuda()
x = list(map(torch.Tensor.pin_memory, x))
y = list(map(torch.Tensor.pin_memory, y))

In [None]:
# Fit the new model to (x, y). Progress lines are printed every `update_int`
# iterations (see output below); the returned log is plotted in the next cell.
log = fitted_mdl.fit(
    x, y,
    max_its=50000,
    batch_size=10,
    send_size=10,
    update_int=1000,
    min_var=.01,
    learning_rates=.001,
)

0: Elapsed fitting time 0.0, vl: 270558820.0, lr: 0.001
1000: Elapsed fitting time 73.83841872215271, vl: 62343080.0, lr: 0.001
2000: Elapsed fitting time 144.23129653930664, vl: 33904956.0, lr: 0.001
3000: Elapsed fitting time 214.32604789733887, vl: 11670952.0, lr: 0.001
4000: Elapsed fitting time 283.36871576309204, vl: 4858177.5, lr: 0.001
5000: Elapsed fitting time 354.6586682796478, vl: 4859257.5, lr: 0.001
6000: Elapsed fitting time 424.3720681667328, vl: 3494480.8, lr: 0.001
7000: Elapsed fitting time 493.31621646881104, vl: 3569555.2, lr: 0.001
8000: Elapsed fitting time 561.1847276687622, vl: 3108593.2, lr: 0.001
9000: Elapsed fitting time 629.228794336319, vl: 2456560.8, lr: 0.001
10000: Elapsed fitting time 699.0594162940979, vl: 1770930.9, lr: 0.001
11000: Elapsed fitting time 768.0774261951447, vl: 1655595.2, lr: 0.001
12000: Elapsed fitting time 837.0835249423981, vl: 1211581.0, lr: 0.001
13000: Elapsed fitting time 907.0052354335785, vl: 1069917.5, lr: 0.001


In [None]:
%matplotlib qt
plt.figure()
plt.plot(log['elapsed_time'], log['obj'])

## Compare predictions from fitted model to ground truth

In [None]:
fitted_mdl = fitted_mdl.to('cpu')
x = [x_g.to('cpu') for x_g in x]
y = [y_h.to('cpu') for y_h in y]

plt_smps = np.arange(1000)

%matplotlib qt
true_y_hat = mdl(x)
fitted_y_hat = fitted_mdl(x)

for g in range(n_output_groups):
    plt.figure()
    for d in range(5):
        plt.subplot(1, 5, d+1)
        true_obs_values = y[g][plt_smps, d].detach().numpy()
        true_plt_values = true_y_hat[g][plt_smps, d].detach().numpy()
        fitted_plt_values = fitted_y_hat[g][plt_smps, d].detach().numpy()
        plt.plot(true_plt_values, 'go')
        plt.plot(true_plt_values, 'ro')
        plt.plot(fitted_plt_values, 'b-')