In [1]:
%load_ext autoreload
%autoreload 2

In [32]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from janelia_core.ml.extra_torch_modules import Bias
from janelia_core.ml.latent_regression import LatentRegModel
from janelia_core.ml.latent_regression import LinearMap
from janelia_core.ml.latent_regression import IdentityMap
from janelia_core.visualization.matrix_visualization import cmp_n_mats

## Model and simulation parameters

In [3]:
# Dimensions of each input/output group (one list entry per group).
d_in = [100, 100]     # number of input variables in each input group (columns of x[g])
d_out = [1000, 1000]  # number of output variables in each output group (columns of y[g])

d_proj = [2, 2]       # presumably the latent projection dimension per input group
d_trans = [2, 2]      # presumably the latent transformed dimension per output group

n_smps = 10000        # number of simulated samples

# Seed the random number generators so that model initialization, the simulated
# data and the fit are reproducible under Restart & Run All.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

In [4]:
## Create the ground-truth model here
# Number of OUTPUT groups; later cells index y[g] / model outputs by this, so it
# must be derived from d_out rather than d_in.
n_output_groups = len(d_out)

# Latent mapping module.  LinearMap(d_proj, d_trans) is an alternative choice.
M = IdentityMap()
S = [Bias(d_o) for d_o in d_out]  # one Bias module per output group
mdl = LatentRegModel(d_in, d_out, d_proj, d_trans, M, S, direct_pairs=None, w_gain=50,
                    noise_range=[50.0, 100.0])

In [28]:
## Modify u[0] of the ground-truth model to make it sparse
# NOTE: this must run BEFORE y is generated below, otherwise the simulated data
# come from the dense u[0].
# Entries with magnitude below this threshold are set to zero.
u_sparse_thresh = 4

# Update the model parameter explicitly and in place.  no_grad is needed because
# in-place writes to a parameter are not allowed while autograd is tracking it.
with torch.no_grad():
    mdl.u[0][mdl.u[0].abs() < u_sparse_thresh] = 0

u_orig = mdl.u[0].detach().numpy()
plt.imshow(u_orig, aspect='auto')
plt.title('Ground-truth u[0] after sparsification')
plt.colorbar();

<matplotlib.colorbar.Colorbar at 0x1290c734860>

## Generate input data `x`

In [7]:
# One standard-normal input matrix of shape (n_smps, d_in[g]) per input group.
x = []
for n_in_vars in d_in:
    x.append(torch.randn([n_smps, n_in_vars]))

## Run the ground-truth model forward to simulate outputs `y`

In [8]:
# Ground-truth outputs.  Presumably mdl(x) gives the noise-free predictions and
# mdl.generate(x) draws noisy samples (see noise_range above); confirm against
# LatentRegModel.  These are fixed training targets, so they are computed
# without autograd tracking.
with torch.no_grad():
    y_pred = mdl(x)
    y = mdl.generate(x)

## Fit a new model

In [18]:
# Freshly initialized model with the same architecture as the ground truth,
# to be fit to (x, y).  LinearMap(d_proj, d_trans) is an alternative mapping.
M_fitted = IdentityMap()
S_fitted = [Bias(n_out) for n_out in d_out]
fitted_mdl = LatentRegModel(d_in, d_out, d_proj, d_trans,
                            M_fitted, S_fitted,
                            direct_pairs=None)

## Move the model to CUDA and pin the data in host memory

In [19]:
# Parameters go to the GPU.  Data stay on the host in page-locked (pinned)
# memory, presumably so fit() can copy batches to the GPU quickly.
fitted_mdl = fitted_mdl.cuda()
x = list(map(torch.Tensor.pin_memory, x))
y = list(map(torch.Tensor.pin_memory, y))

In [20]:
# Fit the new model.  l1_u_lambda presumably gives per-group L1 weights on u:
# a strong penalty on u[0] (sparse in the ground truth) and none on u[1].
log = fitted_mdl.fit(
    x, y,
    max_its=10000,
    batch_size=10,
    send_size=10,
    update_int=1000,
    min_var=.01,
    learning_rates=.1,
    l1_u_lambda=[1000., 0],
)

0: Elapsed fitting time 0.0, vl: 217626060.0, lr: 0.1
1000: Elapsed fitting time 9.821837902069092, vl: 36857.582, lr: 0.1
2000: Elapsed fitting time 19.287567377090454, vl: 44719.543, lr: 0.1
3000: Elapsed fitting time 28.655341625213623, vl: 36784.48, lr: 0.1
4000: Elapsed fitting time 37.593058347702026, vl: 39045.504, lr: 0.1
5000: Elapsed fitting time 46.574596643447876, vl: 32046.197, lr: 0.1
6000: Elapsed fitting time 55.66513514518738, vl: 32770.348, lr: 0.1
7000: Elapsed fitting time 64.70155382156372, vl: 24451.375, lr: 0.1
8000: Elapsed fitting time 73.79538345336914, vl: 21560.908, lr: 0.1
9000: Elapsed fitting time 82.80181837081909, vl: 21265.014, lr: 0.1
9999: Elapsed fitting time 92.07241249084473, vl: 21131.854


In [21]:
%matplotlib qt
plt.figure()
plt.plot(log['elapsed_time'], log['obj'])

[<matplotlib.lines.Line2D at 0x1290ea2e198>]

## Compare predictions from fitted model to ground truth

In [22]:
fitted_mdl = fitted_mdl.to('cpu')
x = [x_g.to('cpu') for x_g in x]
y = [y_h.to('cpu') for y_h in y]

plt_smps = np.arange(1000)

%matplotlib qt
true_y_hat = mdl(x)
fitted_y_hat = fitted_mdl(x)

for g in range(n_output_groups):
    plt.figure()
    for d in range(5):
        plt.subplot(1, 5, d+1)
        true_obs_values = y[g][plt_smps, d].detach().numpy()
        true_plt_values = true_y_hat[g][plt_smps, d].detach().numpy()
        fitted_plt_values = fitted_y_hat[g][plt_smps, d].detach().numpy()
        plt.plot(true_obs_values, 'go')
        plt.plot(true_plt_values, 'ro')
        plt.plot(fitted_plt_values, 'b-')

## Visualize parameters

In [37]:
# Magnitudes of u[0] for the fitted and ground-truth models.  Taking abs()
# removes sign flips.  NOTE(review): other latent-space ambiguities (e.g.
# rotations) are presumably not resolved by this comparison.
u_est = np.abs(fitted_mdl.u[0].detach().cpu().numpy())
u_true = np.abs(mdl.u[0].detach().cpu().numpy())

In [38]:
# Side-by-side view of estimated vs. true |u[0]| and their difference,
# drawn on a shared color scale by cmp_n_mats.
plt.figure()
mats_to_compare = [u_est, u_true, u_est - u_true]
panel_titles = ['Estimated', 'True', 'Diff']
cmp_n_mats(mats_to_compare, titles=panel_titles)

[<matplotlib.axes._subplots.AxesSubplot at 0x1290d31da90>,
 <matplotlib.axes._subplots.AxesSubplot at 0x1290d3c8c18>,
 <matplotlib.axes._subplots.AxesSubplot at 0x1291fac6048>]

In [35]:
# Exploratory: display the documentation of cmp_n_mats.
cmp_doc = cmp_n_mats.__doc__
print(cmp_doc)

 Produces a figuring comparing matrices.

    Each matrix will be plotted in a separate axes, and all axes will use the same color scaling.

    Args:
        mats: A list of matrices to compare.

        clim: An list of length 2 of color limits to apply to all images.  If None is provided, will be set to
        [min_v, max_v] where min_v and max_v are the min and max values in all of the matrices

        show_colorbars: True if colorbars should be shown next to each plot

        titles: A list of tiles for each matrix in mats.  If None, no titles will be generated

        grid_info: A dictionary with information about how to layout the matrices in a grid.  It should have the entries:
            grid_spec: The matplotlib.gridspec.GridSpec to use for the grid
            cell_info: A list the same length as mats.  cell_info[i] contains:
                loc: The location for the subplot for the i^th matrix
                rowspan: The row span for the i^th matrix
                co