In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from janelia_core.ml.extra_torch_modules import Bias
from janelia_core.ml.latent_regression import LatentRegModel
from janelia_core.ml.latent_regression import LinearMap
from janelia_core.ml.latent_regression import IdentityMap

## Model parameters

In [3]:
# Input and output dimensionality of each group (one list entry per group).
d_in = [100000, 100000]
d_out = [100000, 100000]

# Per-group latent sizes passed to LatentRegModel as d_proj / d_trans
# (presumably the projected and transformed latent dimensions — confirm
# against LatentRegModel's documentation).
d_proj = [2, 2]
d_trans = [2, 2]

# Number of samples to generate.  NOTE: each input group becomes an
# (n_smps, d_in[g]) tensor, i.e. 1e9 values (~4 GB at torch's default float32)
# per group with the settings above, so data generation is memory-hungry.
n_smps = 10000

# Seed torch's global RNG so the generated data (and, presumably, model
# initialization) are reproducible under Restart & Run All.
seed = 0
torch.manual_seed(seed);

In [4]:
## Create the ground-truth model used to generate synthetic data
# The number of output groups is determined by d_out (one entry per output
# group), not by d_in.
n_output_groups = len(d_out)

# Latent mapping M.  To test a learned linear mapping instead of the identity,
# use LinearMap(d_proj, d_trans) here.
M = IdentityMap()
# One Bias module per output group, sized to that group's output dimension.
S = [Bias(d_o) for d_o in d_out]
# w_gain and noise_range are set only on this ground-truth model; the model that
# is fit below is constructed with LatentRegModel's defaults for them.
mdl = LatentRegModel(d_in, d_out, d_proj, d_trans, M, S, direct_pairs=None,
                     w_gain=50, noise_range=[50.0, 100.0])

## Generate input data (x)

In [5]:
# One standard-normal input matrix of shape (n_smps, d_in[g]) per input group.
x = [torch.randn(n_smps, d_g) for d_g in d_in]

## Run the ground-truth model forward and generate targets (y)

In [6]:
# y_pred: forward pass of the ground-truth model (not used below; kept for
# inspection).  y: targets for fitting, drawn via mdl.generate (presumably
# mdl(x) plus noise from noise_range — confirm in LatentRegModel).
# Evaluation order is left to right: mdl(x) runs before mdl.generate(x).
y_pred, y = mdl(x), mdl.generate(x)

## Create a new model and fit it to the generated data

In [7]:
# The model to be fit has the same dimensions as the ground-truth model but
# uses LatentRegModel's default w_gain / noise_range.  To fit a learned latent
# mapping, use LinearMap(d_proj, d_trans) in place of IdentityMap().
M_fitted = IdentityMap()
S_fitted = [Bias(n_out) for n_out in d_out]
fitted_mdl = LatentRegModel(
    d_in, d_out, d_proj, d_trans,
    M_fitted, S_fitted,
    direct_pairs=None,
)

## Move the model to CUDA and pin the data in host memory

In [8]:
# Parameters go to the GPU.  The data stays on the host but is page-locked
# (pinned), presumably so fit() can copy batches (send_size) to the GPU
# efficiently.
fitted_mdl = fitted_mdl.cuda()
x = [x_grp.pin_memory() for x_grp in x]
y = [y_grp.pin_memory() for y_grp in y]

In [9]:
# Progress is printed every update_int iterations.  In the recorded run, each
# 1000 iterations took roughly 65 time units, so max_its=50000 is a long fit.
# The returned log is read below via its 'elapsed_time' and 'obj' entries.
log = fitted_mdl.fit(
    x, y,
    max_its=50000,
    batch_size=10,
    send_size=10,
    update_int=1000,
    min_var=0.01,
    learning_rates=0.001,
)

0: Elapsed fitting time 0.0, vl: 485553570.0, lr: 0.001
1000: Elapsed fitting time 68.71637725830078, vl: 69868080.0, lr: 0.001
2000: Elapsed fitting time 135.69833779335022, vl: 35667212.0, lr: 0.001
3000: Elapsed fitting time 201.03882336616516, vl: 17077674.0, lr: 0.001
4000: Elapsed fitting time 267.7347774505615, vl: 14029090.0, lr: 0.001
5000: Elapsed fitting time 332.0230073928833, vl: 10224408.0, lr: 0.001
6000: Elapsed fitting time 397.14851212501526, vl: 8705989.0, lr: 0.001
7000: Elapsed fitting time 464.0400605201721, vl: 6943261.0, lr: 0.001
8000: Elapsed fitting time 528.8197736740112, vl: 8925332.0, lr: 0.001
9000: Elapsed fitting time 594.2801003456116, vl: 5345249.5, lr: 0.001
10000: Elapsed fitting time 659.87824010849, vl: 4924152.0, lr: 0.001
11000: Elapsed fitting time 724.3649067878723, vl: 4196821.0, lr: 0.001
12000: Elapsed fitting time 790.3232901096344, vl: 3306787.8, lr: 0.001
13000: Elapsed fitting time 856.0745935440063, vl: 3177473.8, lr: 0.001
14000: Elap

In [10]:
%matplotlib qt
plt.figure()
plt.plot(log['elapsed_time'], log['obj'])

[<matplotlib.lines.Line2D at 0x17637d5af28>]

## Compare predictions from fitted model to ground truth

In [11]:
fitted_mdl = fitted_mdl.to('cpu')
x = [x_g.to('cpu') for x_g in x]
y = [y_h.to('cpu') for y_h in y]

plt_smps = np.arange(1000)

%matplotlib qt
true_y_hat = mdl(x)
fitted_y_hat = fitted_mdl(x)

for g in range(n_output_groups):
    plt.figure()
    for d in range(5):
        plt.subplot(1, 5, d+1)
        true_obs_values = y[g][plt_smps, d].detach().numpy()
        true_plt_values = true_y_hat[g][plt_smps, d].detach().numpy()
        fitted_plt_values = fitted_y_hat[g][plt_smps, d].detach().numpy()
        plt.plot(true_obs_values, 'go')
        plt.plot(true_plt_values, 'ro')
        plt.plot(fitted_plt_values, 'b-')