Notebook for development and testing of code for the second version of fitting latent regression models across multiple subjects with variational inference.  The main advance in version 2.0 of the code is the ability to support distributions across additional model parameters (not just the modes). 

In particular we generate models of how one neural population drives another as follows:

1) The user specifies a number of subjects and how many neurons are in each population for each of those subjects. Neuron locations for each subject are then randomly drawn from a uniform distribution on the unit square. 

2) Our models include only neural dynamics (no stimulus input or behavioral output) and we use an identity mapping in 
the low-dimensional space.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import torch

from janelia_core.ml.datasets import TimeSeriesDataset
from janelia_core.ml.latent_regression.group_maps import GroupLinearTransform, IdentityMap
from janelia_core.ml.latent_regression.subject_models import LatentRegModel, SharedMLatentRegModel
from janelia_core.ml.latent_regression.vi import MultiSubjectVIFitter
from janelia_core.ml.latent_regression.vi import PriorCollection
from janelia_core.ml.latent_regression.vi import SubjectVICollection
from janelia_core.ml.latent_regression.vi import predict_with_truth
from janelia_core.ml.torch_distributions import CondGaussianDistribution
from janelia_core.ml.torch_distributions import CondMatrixHypercubePrior
from janelia_core.ml.torch_distributions import CondMatrixProductDistribution
from janelia_core.ml.torch_distributions import MatrixGaussianProductDistribution
from janelia_core.ml.torch_parameter_penalizers import ScalarPenalizer
from janelia_core.ml.utils import torch_mod_to_fcn
from janelia_core.ml.utils import list_torch_devices
from janelia_core.visualization.image_generation import generate_dot_image_3d
from janelia_core.visualization.image_visualization import visualize_2d_function


In [3]:
%matplotlib notebook

## Parameters and model specification go here

In [4]:
# Here we specify the number of subjects (by the length of the list) and the number of neurons that will be
# present in each population for each subject.  Each tuple is (n neurons in the p population, n neurons in the
# u population): entry 0 sizes the model input and entry 1 sizes the model output when the true models are built.

n_subj_neurons = [(10000, 10000),
                  (9000, 9000),
                  (11000, 11000)]

# Number of samples of data to generate for each subject
n_smps = 20000

# True if we should use shared posteriors among subjects; when False each subject gets its own posteriors and
# the priors are enforced during fitting
use_shared_posts = False 


### Parameters for creating hypercube functions

In [5]:
# Settings handed (as both mn_hc_params and std_hc_params) to every CondMatrixHypercubePrior created below.
# Note that dim_ranges extends slightly beyond the unit square in which neuron locations are drawn.
# NOTE(review): presumably this is a 50 x 50 grid of hypercubes over the two position dimensions - confirm
# the exact meaning of each key against CondMatrixHypercubePrior.
hc_fcn_params = {'n_divisions_per_dim': [50, 50], 
                 'dim_ranges': np.asarray([[-.1, 1.1], [-.1, 1.1]]), 
                 'n_div_per_hc_side_per_dim': [1, 1]}

### Here we specify the mean and standard deviation functions for the different parameters of the models

#### Specify some helper functions

In [6]:
class exp2d(torch.nn.Module):
    """ An exponential bump, evaluated row-wise on its input.

    For each row x_i of the input, computes gain*exp(-sum_d (x_i[d] - ctr[d])**2 / std[d]) + offset.

    Note that the squared distance is divided by std itself (not by std**2), so std acts as a width parameter
    on the scale of a variance.
    """

    def __init__(self, ctr, std, gain, offset):
        """ Creates a new exp2d object.

        Args:
            ctr: Tensor with the center of the bump

            std: Tensor with the width of the bump along each dimension

            gain: Tensor with the height of the bump above the offset

            offset: Tensor with the value added to the bump everywhere

        All four values are stored as torch parameters of the module.
        """
        super().__init__()
        for p_name, p_vl in (('ctr', ctr), ('std', std), ('gain', gain), ('offset', offset)):
            setattr(self, p_name, torch.nn.Parameter(p_vl))

    def forward(self, x):
        """ Evaluates the function on each row of x, returning a tensor with one column. """
        scaled_sq_dist = torch.sum((x - self.ctr)**2/self.std, dim=1)
        row_vls = self.gain*torch.exp(-1*scaled_sq_dist) + self.offset
        return row_vls.unsqueeze(1)

class constantF(torch.nn.Module):
    """ A function which returns the same constant value for every row of its input. """

    def __init__(self, vl):
        """ Creates a new constantF object.

        Args:
            vl: The constant value the function should return
        """
        super().__init__()
        self.vl = vl
        
    def forward(self, x):
        """ Returns a tensor of shape [x.shape[0], 1] filled with the constant value.

        The output is created on the same device as x, so the function can be evaluated on inputs that do
        not live on the default device.
        """
        return self.vl*torch.ones([x.shape[0], 1], device=x.device)

#### Specify the distributions over p and u modes

Here we implicitly define the number of modes by the number of distributions we define

In [7]:
# One conditional Gaussian distribution is created per mode: the mean is an exponential bump centered at the
# listed location and the standard deviation is constant.  The number of centers sets the number of modes.
p_ctrs = [torch.tensor([.1, .1]), torch.tensor([.9, .9])]
p_mode_dists = []
for mode_ctr in p_ctrs:
    mode_mn_f = exp2d(ctr=mode_ctr, std=torch.tensor([1.0, 1.0]), gain=torch.tensor(1.0),
                      offset=torch.tensor(0.0))
    p_mode_dists.append(CondGaussianDistribution(mn_f=mode_mn_f, std_f=constantF(.1)))
true_p_dists = CondMatrixProductDistribution(p_mode_dists)

# The u modes are specified in exactly the same way as the p modes
u_ctrs = [torch.tensor([.1, .1]), torch.tensor([.9, .9])]
u_mode_dists = []
for mode_ctr in u_ctrs:
    mode_mn_f = exp2d(ctr=mode_ctr, std=torch.tensor([1.0, 1.0]), gain=torch.tensor(1.0),
                      offset=torch.tensor(0.0))
    u_mode_dists.append(CondGaussianDistribution(mn_f=mode_mn_f, std_f=constantF(.1)))
true_u_dists = CondMatrixProductDistribution(u_mode_dists)


#### Specify the distributions over scales and offsets and direct connections

In [8]:
# Each distribution below is a conditional Gaussian whose mean is an exponential bump over neuron position
# and whose standard deviation is constant.  The mean functions are built first and then wrapped.

# Scales: narrow, tall bump in the middle of the unit square
scale_mn_f = exp2d(ctr=torch.tensor([.5, .5]), std=torch.tensor([.5, .5]),
                   gain=torch.tensor(10.0), offset=torch.tensor(0.0))
true_scale_dist = CondGaussianDistribution(mn_f=scale_mn_f, std_f=constantF(.1))

# Offsets: wider bump with the same center and height
offset_mn_f = exp2d(ctr=torch.tensor([.5, .5]), std=torch.tensor([1.0, 1.0]),
                    gain=torch.tensor(10.0), offset=torch.tensor(0.0))
true_offset_dist = CondGaussianDistribution(mn_f=offset_mn_f, std_f=constantF(.1))

# Psi: very broad, low bump sitting on a positive baseline
psi_mn_f = exp2d(ctr=torch.tensor([.5, .5]), std=torch.tensor([6.0, 6.0]),
                 gain=torch.tensor(.2), offset=torch.tensor(.1))
true_psi_dist = CondGaussianDistribution(mn_f=psi_mn_f, std_f=constantF(.01))

# Direct mappings: same shape as psi, but centered away from the middle
direct_map_mn_f = exp2d(ctr=torch.tensor([.8, .5]), std=torch.tensor([6.0, 6.0]),
                        gain=torch.tensor(.2), offset=torch.tensor(.1))
true_direct_map_dist = CondGaussianDistribution(mn_f=direct_map_mn_f, std_f=constantF(.01))

## Here we generate our true subject models and data

In [9]:
# The number of modes is set implicitly by the number of mode distributions defined above, and the number of
# subjects by the number of entries in n_subj_neurons
n_modes = len(true_p_dists.dists)
n_subjs = len(n_subj_neurons)
true_subj_models = [None]*n_subjs
true_data = [None]*n_subjs

# For each subject: draw neuron locations, sample the subject's parameters from the true distributions
# conditioned on those locations, build the true model and generate data from it.  Nothing here needs
# gradients, so everything is done under torch.no_grad().
for s_i in range(n_subjs):
    
    with torch.no_grad():
        # Generate neuron locations (uniform on the unit square; one row per neuron)
        p_neuron_locs = torch.rand(size=[n_subj_neurons[s_i][0], 2])
        u_neuron_locs = torch.rand(size=[n_subj_neurons[s_i][1], 2])
    
        # Generate modes
        p_modes = true_p_dists.form_standard_sample(true_p_dists.sample(p_neuron_locs))
        u_modes = true_u_dists.form_standard_sample(true_u_dists.sample(u_neuron_locs))
        
        # Generate scales and offsets (conditioned on the locations of the u neurons)
        scales = true_scale_dist.form_standard_sample(true_scale_dist.sample(u_neuron_locs)).squeeze()
        offsets = true_offset_dist.form_standard_sample(true_offset_dist.sample(u_neuron_locs)).squeeze()
        
        # Generate direct maps
        direct_mappings = true_direct_map_dist.form_standard_sample(true_direct_map_dist.sample(u_neuron_locs)).squeeze()
        
        # Generate psi
        # NOTE(review): psi is presumably a per-neuron noise variance (the fitter compares it to min_var),
        # which is why we check that all sampled values are positive - confirm against LatentRegModel
        psi = true_psi_dist.form_standard_sample(true_psi_dist.sample(u_neuron_locs)).squeeze()
        assert(torch.all(psi > 0))
    
        s_mdl = LatentRegModel(d_in = [n_subj_neurons[s_i][0]], d_out = [n_subj_neurons[s_i][1]], 
                               d_proj=[n_modes], d_trans=[n_modes], 
                               m=IdentityMap(),
                               s=[torch.nn.Identity()], 
                               use_scales=True,
                               use_offsets=True,
                               direct_pairs=[(0,0)], 
                               assign_direct_pair_mappings=True)
    
        # Overwrite the parameters of the new model with the values we sampled above
        s_mdl.u[0].data = u_modes
        s_mdl.p[0].data = p_modes
        s_mdl.offsets[0].data = offsets
        s_mdl.scales[0].data = scales
        s_mdl.psi[0].data = psi
        s_mdl.direct_mappings[0].data = direct_mappings
    
        true_subj_models[s_i] = {'mdl': s_mdl, 'p_neuron_locs': p_neuron_locs, 'u_neuron_locs': u_neuron_locs}
    
        
        # Inputs are standard normal noise; outputs are generated from them by the true model
        p_data = [torch.randn(size=[n_smps, n_subj_neurons[s_i][0]])]
        u_data = s_mdl.generate(p_data)
        
        # Delay u data with respect to p data (since we model u_{t+1} as a function of p_t).  After dropping
        # the first row of p and the last row of u, row i of u_data was generated from row i-1 of p_data.
        p_data[0] = p_data[0][1:,:]
        u_data[0] = u_data[0][0:-1, :]
        
        
        
        true_data[s_i] = (p_data, u_data)

## Now we set things up for fitting with variational inference

### Define prior distributions

In [10]:
# Every prior is a CondMatrixHypercubePrior built with the same hypercube settings and minimum standard
# deviation; only the number of columns and (for some) the initial mean differ
hc_prior_kwargs = {'mn_hc_params': hc_fcn_params, 'std_hc_params': hc_fcn_params, 'min_std': .00001}

# Priors over modes have one column per mode
p_prior = CondMatrixHypercubePrior(n_cols=n_modes, **hc_prior_kwargs)
u_prior = CondMatrixHypercubePrior(n_cols=n_modes, **hc_prior_kwargs)

# Priors over per-neuron values have a single column
scales_prior = CondMatrixHypercubePrior(n_cols=1, mn_init=1.0, **hc_prior_kwargs)
offsets_prior = CondMatrixHypercubePrior(n_cols=1, mn_init=0.0, **hc_prior_kwargs)
direct_mappings_prior = CondMatrixHypercubePrior(n_cols=1, mn_init=0.0, **hc_prior_kwargs)

# No prior is placed over psi
prior_collection = PriorCollection(p_dists=[p_prior],
                                   u_dists=[u_prior],
                                   psi_dists=[None],
                                   scale_dists=[scales_prior],
                                   offset_dists=[offsets_prior],
                                   direct_mapping_dists=[direct_mappings_prior])

### Define subject models and posteriors for each subject

In [11]:
# For every subject, create the model that will be fit together with the posterior distributions over its
# parameters, and package these with the subject's data and neuron locations in a SubjectVICollection
vi_collections = [None]*n_subjs
for s_i in range(n_subjs):
    
    # Create subject model for fitting.  Parameters for which we fit distributions are not assigned in the
    # model itself (the assign_* flags are False).
    subject_specific_m = GroupLinearTransform(d=[n_modes], nonnegative_scale=True, 
                                              v_mn=1.0, v_std=.01, o_mn=0.0, o_std=.01)
    s_mdl = SharedMLatentRegModel(d_in = [n_subj_neurons[s_i][0]], d_out = [n_subj_neurons[s_i][1]], 
                                  d_proj=[n_modes], d_trans=[n_modes], specific_m=subject_specific_m,
                                  shared_m=IdentityMap(), s=[torch.nn.Identity()],
                                  use_scales=True, use_offsets=True, direct_pairs=[(0,0)],
                                  assign_p_modes=False, assign_u_modes=False, assign_scales=False, assign_offsets=False,
                                  assign_direct_pair_mappings=False,
                                  assign_psi=True) # We will fit point estimates for psi (and not distributions)    
    
    # Create posterior distributions 
    if use_shared_posts:
        # Shared posteriors are conditioned on neuron location, like the priors, and a single set is used for
        # all subjects.  They are created for the first subject only; for later subjects we simply keep using
        # the objects created for subject 0.
        if s_i == 0:
            p_post = CondMatrixHypercubePrior(n_cols=n_modes, mn_hc_params=hc_fcn_params, 
                                              std_hc_params=hc_fcn_params, min_std=.00001, 
                                              mn_init=.1)
            u_post = CondMatrixHypercubePrior(n_cols=n_modes, mn_hc_params=hc_fcn_params, 
                                              std_hc_params=hc_fcn_params, min_std=.00001,
                                              mn_init=.1)
            # Scales, offsets and direct mappings have one value per neuron, so these posteriors have a
            # single column - matching their priors and the per-subject posteriors created below
            scale_post = CondMatrixHypercubePrior(n_cols=1, mn_hc_params=hc_fcn_params, 
                                                  std_hc_params=hc_fcn_params, min_std=.00001, 
                                                  mn_init=1.0)
            offset_post = CondMatrixHypercubePrior(n_cols=1, mn_hc_params=hc_fcn_params, 
                                                   std_hc_params=hc_fcn_params, min_std=.00001,
                                                   mn_init=0.0)
            direct_mappings_post = CondMatrixHypercubePrior(n_cols=1, mn_hc_params=hc_fcn_params, 
                                                            std_hc_params=hc_fcn_params, min_std=.00001,
                                                            mn_init=0.0)
    else:
        # Each subject gets its own posteriors, with one row per neuron
        p_post = MatrixGaussianProductDistribution(shape=[n_subj_neurons[s_i][0], n_modes], mn_mn=.01, mn_std=.001)
        u_post = MatrixGaussianProductDistribution(shape=[n_subj_neurons[s_i][1], n_modes], mn_mn=.01, mn_std=.001)
        scale_post = MatrixGaussianProductDistribution(shape=[n_subj_neurons[s_i][1], 1], mn_mn=1.0, mn_std=.001)
        offset_post = MatrixGaussianProductDistribution(shape=[n_subj_neurons[s_i][1], 1], mn_mn=0.0, mn_std=.001)
        direct_mappings_post = MatrixGaussianProductDistribution(shape=[n_subj_neurons[s_i][1], 1], mn_mn=0.0, mn_std=.001)
    
    # Package data: group 0 holds the p (input) data and group 1 the u (output) data
    data = TimeSeriesDataset([true_data[s_i][0][0], true_data[s_i][1][0]])[:]
    
    # props holds the neuron locations; the *_props arguments pick which entry of props goes with each
    # distribution (0: p neuron locations, 1: u neuron locations)
    # NOTE(review): that reading of the *_props arguments is inferred from how they are used here - confirm
    # against SubjectVICollection
    vi_collections[s_i] = SubjectVICollection(s_mdl=s_mdl, p_dists=[p_post], u_dists=[u_post], psi_dists=[None],
                                        scale_dists=[scale_post], 
                                        offset_dists=[offset_post],
                                        direct_mappings_dists=[direct_mappings_post],
                                        data=data, input_grps=[0], output_grps=[1], 
                                        props=[true_subj_models[s_i]['p_neuron_locs'], 
                                               true_subj_models[s_i]['u_neuron_locs']],
                                        p_props = [0], u_props=[1], psi_props=[None], 
                                        scale_props=[1], offset_props=[1], 
                                        direct_mapping_props=[1], min_var=[.0001])

## Generate penalizers

In [12]:
# Gather the scale (v) and offset (o) parameters of each subject's subject-specific m transform
subj_v_params = []
subj_o_params = []
for subj_coll in vi_collections:
    subj_v_params.append(subj_coll.s_mdl.specific_m.v[0])
    subj_o_params.append(subj_coll.s_mdl.specific_m.o[0])

# Heavily penalize the scales for moving away from 1 and the offsets for moving away from 0; the centers of
# the penalizers are fixed (not learnable)
v_penalizer = ScalarPenalizer(params=subj_v_params,
                              w=1e7,
                              init_ctr=1.0,
                              learnable_parameters=False,
                              description='m scales')
o_penalizer = ScalarPenalizer(params=subj_o_params,
                              w=1e7,
                              init_ctr=0.0,
                              learnable_parameters=False,
                              description='m offsets')

## Create the fitter 

In [13]:
# The fitter ties together the per-subject collections, the shared priors and the two penalizers on the
# subject-specific m transforms; no input modules are used
fitter = MultiSubjectVIFitter(s_collections=vi_collections, prior_collection=prior_collection,
                              penalizers=[v_penalizer, o_penalizer], input_modules=None)

## Fit the model

In [14]:
# Get the devices available for fitting; the second returned value is not needed here
devices, _ = list_torch_devices()

Found 1 GPUs


In [15]:
# Distribute what the fitter manages, including the data, across the devices found above
fitter.distribute(devices, distribute_data=True)

In [None]:
# Fit in two rounds of 1000 epochs: first with the larger learning rate schedule, then with a smaller, fixed
# learning rate.  Priors are only enforced when each subject has its own posteriors.
logs0 = fitter.fit(n_epochs=1000,
                   n_batches=2,
                   update_int=10,
                   learning_rates=[(0, .1, {'fast': 1})],
                   enforce_priors=not use_shared_posts)
logs1 = fitter.fit(n_epochs=1000,
                   n_batches=2,
                   update_int=10,
                   learning_rates=.01,
                   enforce_priors=not use_shared_posts)

	nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
	nonzero(Tensor input, *, bool as_tuple) (Triggered internally at  /opt/conda/conda-bld/pytorch_1595629403081/work/torch/csrc/utils/python_arg_parser.cpp:766.)
  small_psi_inds = torch.nonzero(s_mdl.psi[h] < s_min_var[h])


*****************************************************
Epoch 0 complete.  Obj: 5.33e+14, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 1.77e+14, s_1: 1.42e+14, s_2: 2.14e+14
Subj P KLs:  s_0: 4.77e+05, s_1: 4.29e+05, s_2: 5.68e+05
Subj U KLs:  s_0: 9.11e+03, s_1: 8.23e+03, s_2: 1.00e+04
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 8.47e+05, s_1: 7.53e+05, s_2: 9.16e+05
Subj Offsets KLs:  s_0: 8.00e+05, s_1: 7.18e+05, s_2: 8.88e+05
Subj Direct Mappings KLs:  s_0: 7.96e+05, s_1: 7.29e+05, s_2: 8.70e+05
Penalties:  p_0: 6.14e+05, p_1: 5.64e+05
m scales state
 Center: [1.]
 Last Penalty: 613611.8125
m offsets state
 Center: [0.]
 Last Penalty: 564233.3671875
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 0.5881752967834473
*****************************************************
Epoch 10 complete.  Obj: 1.39e+13, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 4.53e+12, s_1: 3.65e+12, s_2: 5.74e+12
Subj P KLs:  s_0: 2.62e

*****************************************************
Epoch 110 complete.  Obj: 8.48e+11, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 2.75e+11, s_1: 2.24e+11, s_2: 3.49e+11
Subj P KLs:  s_0: 6.51e+04, s_1: 5.84e+04, s_2: 7.08e+04
Subj U KLs:  s_0: 6.28e+04, s_1: 5.56e+04, s_2: 6.79e+04
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 3.69e+04, s_1: 3.32e+04, s_2: 4.06e+04
Subj Offsets KLs:  s_0: 4.51e+04, s_1: 4.55e+04, s_2: 5.65e+04
Subj Direct Mappings KLs:  s_0: 4.54e+04, s_1: 4.08e+04, s_2: 5.05e+04
Penalties:  p_0: 2.84e+07, p_1: 1.86e+08
m scales state
 Center: [1.]
 Last Penalty: 28441736.0
m offsets state
 Center: [0.]
 Last Penalty: 186404976.0
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 35.87657880783081
*****************************************************
Epoch 120 complete.  Obj: 7.62e+11, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 2.40e+11, s_1: 2.06e+11, s_2: 3.16e+11
Subj P KLs:  s_0: 6.90e+0

*****************************************************
Epoch 220 complete.  Obj: 6.00e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 1.96e+09, s_1: 1.51e+09, s_2: 2.53e+09
Subj P KLs:  s_0: 9.22e+04, s_1: 7.43e+04, s_2: 9.00e+04
Subj U KLs:  s_0: 9.67e+04, s_1: 7.82e+04, s_2: 9.62e+04
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 3.54e+04, s_1: 3.20e+04, s_2: 3.90e+04
Subj Offsets KLs:  s_0: 4.73e+04, s_1: 4.45e+04, s_2: 5.32e+04
Subj Direct Mappings KLs:  s_0: 3.42e+04, s_1: 3.10e+04, s_2: 3.79e+04
Penalties:  p_0: 3.07e+07, p_1: 1.68e+08
m scales state
 Center: [1.]
 Last Penalty: 30738052.0
m offsets state
 Center: [0.]
 Last Penalty: 167663836.0
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 70.0114758014679
*****************************************************
Epoch 230 complete.  Obj: 5.47e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 1.77e+09, s_1: 1.40e+09, s_2: 2.31e+09
Subj P KLs:  s_0: 9.24e+04

*****************************************************
Epoch 330 complete.  Obj: 3.65e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 1.20e+09, s_1: 9.60e+08, s_2: 1.49e+09
Subj P KLs:  s_0: 9.52e+04, s_1: 7.84e+04, s_2: 9.50e+04
Subj U KLs:  s_0: 1.01e+05, s_1: 8.25e+04, s_2: 1.02e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 3.60e+04, s_1: 3.25e+04, s_2: 3.97e+04
Subj Offsets KLs:  s_0: 4.81e+04, s_1: 4.64e+04, s_2: 5.52e+04
Subj Direct Mappings KLs:  s_0: 2.56e+04, s_1: 2.31e+04, s_2: 2.84e+04
Penalties:  p_0: 3.05e+07, p_1: 1.44e+08
m scales state
 Center: [1.]
 Last Penalty: 30527713.0
m offsets state
 Center: [0.]
 Last Penalty: 144407896.0
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 104.58137249946594
*****************************************************
Epoch 340 complete.  Obj: 3.56e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 1.16e+09, s_1: 9.53e+08, s_2: 1.44e+09
Subj P KLs:  s_0: 9.55e+

*****************************************************
Epoch 440 complete.  Obj: 2.92e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 9.52e+08, s_1: 7.90e+08, s_2: 1.18e+09
Subj P KLs:  s_0: 9.85e+04, s_1: 8.15e+04, s_2: 9.89e+04
Subj U KLs:  s_0: 1.05e+05, s_1: 8.57e+04, s_2: 1.05e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 3.77e+04, s_1: 3.41e+04, s_2: 4.15e+04
Subj Offsets KLs:  s_0: 4.91e+04, s_1: 4.86e+04, s_2: 5.74e+04
Subj Direct Mappings KLs:  s_0: 1.58e+04, s_1: 1.43e+04, s_2: 1.76e+04
Penalties:  p_0: 3.03e+07, p_1: 1.20e+08
m scales state
 Center: [1.]
 Last Penalty: 30268243.0
m offsets state
 Center: [0.]
 Last Penalty: 120067848.0
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 138.85844135284424
*****************************************************
Epoch 450 complete.  Obj: 2.85e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 9.45e+08, s_1: 7.63e+08, s_2: 1.14e+09
Subj P KLs:  s_0: 9.88e+

*****************************************************
Epoch 550 complete.  Obj: 2.50e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 8.34e+08, s_1: 6.82e+08, s_2: 9.79e+08
Subj P KLs:  s_0: 1.01e+05, s_1: 8.42e+04, s_2: 1.02e+05
Subj U KLs:  s_0: 1.08e+05, s_1: 8.84e+04, s_2: 1.09e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 3.92e+04, s_1: 3.54e+04, s_2: 4.32e+04
Subj Offsets KLs:  s_0: 5.03e+04, s_1: 5.04e+04, s_2: 5.94e+04
Subj Direct Mappings KLs:  s_0: 8.51e+03, s_1: 7.69e+03, s_2: 9.50e+03
Penalties:  p_0: 3.00e+07, p_1: 9.66e+07
m scales state
 Center: [1.]
 Last Penalty: 30013779.0
m offsets state
 Center: [0.]
 Last Penalty: 96585692.0
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 172.99810194969177
*****************************************************
Epoch 560 complete.  Obj: 2.44e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 8.12e+08, s_1: 6.56e+08, s_2: 9.77e+08
Subj P KLs:  s_0: 1.02e+0

*****************************************************
Epoch 660 complete.  Obj: 2.23e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 7.38e+08, s_1: 6.10e+08, s_2: 8.83e+08
Subj P KLs:  s_0: 1.04e+05, s_1: 8.65e+04, s_2: 1.05e+05
Subj U KLs:  s_0: 1.11e+05, s_1: 9.08e+04, s_2: 1.12e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 4.06e+04, s_1: 3.67e+04, s_2: 4.47e+04
Subj Offsets KLs:  s_0: 5.16e+04, s_1: 5.17e+04, s_2: 6.09e+04
Subj Direct Mappings KLs:  s_0: 5.34e+03, s_1: 4.82e+03, s_2: 5.99e+03
Penalties:  p_0: 2.98e+07, p_1: 7.51e+07
m scales state
 Center: [1.]
 Last Penalty: 29756853.0
m offsets state
 Center: [0.]
 Last Penalty: 75069560.0
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 207.7509081363678
*****************************************************
Epoch 670 complete.  Obj: 2.23e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 7.46e+08, s_1: 5.93e+08, s_2: 8.87e+08
Subj P KLs:  s_0: 1.04e+05

*****************************************************
Epoch 770 complete.  Obj: 2.13e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 6.97e+08, s_1: 5.74e+08, s_2: 8.57e+08
Subj P KLs:  s_0: 1.06e+05, s_1: 8.86e+04, s_2: 1.07e+05
Subj U KLs:  s_0: 1.13e+05, s_1: 9.30e+04, s_2: 1.14e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 4.18e+04, s_1: 3.78e+04, s_2: 4.61e+04
Subj Offsets KLs:  s_0: 5.27e+04, s_1: 5.26e+04, s_2: 6.20e+04
Subj Direct Mappings KLs:  s_0: 4.52e+03, s_1: 4.09e+03, s_2: 5.03e+03
Penalties:  p_0: 2.95e+07, p_1: 5.67e+07
m scales state
 Center: [1.]
 Last Penalty: 29499820.0
m offsets state
 Center: [0.]
 Last Penalty: 56744500.0
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 241.41956090927124
*****************************************************
Epoch 780 complete.  Obj: 2.11e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 7.02e+08, s_1: 5.73e+08, s_2: 8.39e+08
Subj P KLs:  s_0: 1.07e+0

*****************************************************
Epoch 880 complete.  Obj: 2.06e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 6.81e+08, s_1: 5.46e+08, s_2: 8.35e+08
Subj P KLs:  s_0: 1.08e+05, s_1: 9.05e+04, s_2: 1.10e+05
Subj U KLs:  s_0: 1.15e+05, s_1: 9.49e+04, s_2: 1.17e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 4.29e+04, s_1: 3.87e+04, s_2: 4.73e+04
Subj Offsets KLs:  s_0: 5.36e+04, s_1: 5.31e+04, s_2: 6.28e+04
Subj Direct Mappings KLs:  s_0: 4.48e+03, s_1: 4.07e+03, s_2: 4.99e+03
Penalties:  p_0: 2.92e+07, p_1: 4.11e+07
m scales state
 Center: [1.]
 Last Penalty: 29225159.0
m offsets state
 Center: [0.]
 Last Penalty: 41127806.0
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 275.19532656669617
*****************************************************
Epoch 890 complete.  Obj: 2.05e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 6.83e+08, s_1: 5.44e+08, s_2: 8.19e+08
Subj P KLs:  s_0: 1.09e+0

*****************************************************
Epoch 990 complete.  Obj: 2.06e+09, LR: [0.1 {'fast': 1}]
Model NLLs:  s_0: 6.84e+08, s_1: 5.41e+08, s_2: 8.38e+08
Subj P KLs:  s_0: 1.10e+05, s_1: 9.23e+04, s_2: 1.12e+05
Subj U KLs:  s_0: 1.17e+05, s_1: 9.67e+04, s_2: 1.19e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 4.39e+04, s_1: 3.96e+04, s_2: 4.84e+04
Subj Offsets KLs:  s_0: 5.42e+04, s_1: 5.32e+04, s_2: 6.31e+04
Subj Direct Mappings KLs:  s_0: 4.50e+03, s_1: 4.11e+03, s_2: 5.06e+03
Penalties:  p_0: 2.89e+07, p_1: 2.89e+07
m scales state
 Center: [1.]
 Last Penalty: 28935884.0
m offsets state
 Center: [0.]
 Last Penalty: 28926218.5
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 309.5412526130676
*****************************************************
Epoch 0 complete.  Obj: 3.17e+10, LR: [0.01]
Model NLLs:  s_0: 9.68e+09, s_1: 9.10e+09, s_2: 1.29e+10
Subj P KLs:  s_0: 1.11e+05, s_1: 9.25e+

*****************************************************
Epoch 100 complete.  Obj: 1.17e+09, LR: [0.01]
Model NLLs:  s_0: 3.95e+08, s_1: 3.26e+08, s_2: 4.43e+08
Subj P KLs:  s_0: 1.18e+05, s_1: 9.92e+04, s_2: 1.20e+05
Subj U KLs:  s_0: 1.26e+05, s_1: 1.05e+05, s_2: 1.28e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 4.82e+04, s_1: 4.34e+04, s_2: 5.30e+04
Subj Offsets KLs:  s_0: 4.82e+04, s_1: 4.65e+04, s_2: 5.44e+04
Subj Direct Mappings KLs:  s_0: 3.12e+03, s_1: 2.90e+03, s_2: 3.47e+03
Penalties:  p_0: 2.84e+07, p_1: 1.07e+07
m scales state
 Center: [1.]
 Last Penalty: 28426984.0
m offsets state
 Center: [0.]
 Last Penalty: 10674125.25
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 32.385454416275024
*****************************************************
Epoch 110 complete.  Obj: 1.21e+09, LR: [0.01]
Model NLLs:  s_0: 4.52e+08, s_1: 3.21e+08, s_2: 4.36e+08
Subj P KLs:  s_0: 1.19e+05, s_1: 9.98e+04, s_2

*****************************************************
Epoch 210 complete.  Obj: 1.07e+09, LR: [0.01]
Model NLLs:  s_0: 3.31e+08, s_1: 3.06e+08, s_2: 4.31e+08
Subj P KLs:  s_0: 1.24e+05, s_1: 1.05e+05, s_2: 1.27e+05
Subj U KLs:  s_0: 1.33e+05, s_1: 1.11e+05, s_2: 1.36e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 5.15e+04, s_1: 4.65e+04, s_2: 5.67e+04
Subj Offsets KLs:  s_0: 4.25e+04, s_1: 3.72e+04, s_2: 4.27e+04
Subj Direct Mappings KLs:  s_0: 2.70e+03, s_1: 2.63e+03, s_2: 3.13e+03
Penalties:  p_0: 2.82e+07, p_1: 2.75e+06
m scales state
 Center: [1.]
 Last Penalty: 28212862.0
m offsets state
 Center: [0.]
 Last Penalty: 2749569.203125
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 66.2759370803833
*****************************************************
Epoch 220 complete.  Obj: 1.03e+09, LR: [0.01]
Model NLLs:  s_0: 3.30e+08, s_1: 2.95e+08, s_2: 4.09e+08
Subj P KLs:  s_0: 1.24e+05, s_1: 1.05e+05, s_

*****************************************************
Epoch 320 complete.  Obj: 2.26e+09, LR: [0.01]
Model NLLs:  s_0: 1.42e+09, s_1: 4.61e+08, s_2: 3.82e+08
Subj P KLs:  s_0: 1.28e+05, s_1: 1.09e+05, s_2: 1.31e+05
Subj U KLs:  s_0: 1.38e+05, s_1: 1.15e+05, s_2: 1.41e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 5.40e+04, s_1: 4.87e+04, s_2: 5.93e+04
Subj Offsets KLs:  s_0: 3.69e+04, s_1: 2.92e+04, s_2: 3.33e+04
Subj Direct Mappings KLs:  s_0: 4.06e+03, s_1: 3.26e+03, s_2: 3.38e+03
Penalties:  p_0: 2.79e+07, p_1: 5.49e+05
m scales state
 Center: [1.]
 Last Penalty: 27911699.0
m offsets state
 Center: [0.]
 Last Penalty: 548661.9482421875
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 100.66066431999207
*****************************************************
Epoch 330 complete.  Obj: 1.71e+09, LR: [0.01]
Model NLLs:  s_0: 7.96e+08, s_1: 4.59e+08, s_2: 4.55e+08
Subj P KLs:  s_0: 1.29e+05, s_1: 1.09e+0

*****************************************************
Epoch 430 complete.  Obj: 1.09e+09, LR: [0.01]
Model NLLs:  s_0: 2.89e+08, s_1: 2.55e+08, s_2: 5.49e+08
Subj P KLs:  s_0: 1.31e+05, s_1: 1.12e+05, s_2: 1.35e+05
Subj U KLs:  s_0: 1.41e+05, s_1: 1.18e+05, s_2: 1.45e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 5.56e+04, s_1: 5.03e+04, s_2: 6.13e+04
Subj Offsets KLs:  s_0: 2.91e+04, s_1: 2.41e+04, s_2: 2.86e+04
Subj Direct Mappings KLs:  s_0: 2.18e+03, s_1: 2.29e+03, s_2: 2.92e+03
Penalties:  p_0: 2.79e+07, p_1: 9.78e+04
m scales state
 Center: [1.]
 Last Penalty: 27880144.0
m offsets state
 Center: [0.]
 Last Penalty: 97783.88348388672
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 135.72918391227722
*****************************************************
Epoch 440 complete.  Obj: 1.02e+09, LR: [0.01]
Model NLLs:  s_0: 2.88e+08, s_1: 2.53e+08, s_2: 4.79e+08
Subj P KLs:  s_0: 1.32e+05, s_1: 1.12e+0

*****************************************************
Epoch 540 complete.  Obj: 1.31e+09, LR: [0.01]
Model NLLs:  s_0: 3.08e+08, s_1: 6.28e+08, s_2: 3.72e+08
Subj P KLs:  s_0: 1.34e+05, s_1: 1.14e+05, s_2: 1.38e+05
Subj U KLs:  s_0: 1.44e+05, s_1: 1.21e+05, s_2: 1.49e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 5.72e+04, s_1: 5.18e+04, s_2: 6.30e+04
Subj Offsets KLs:  s_0: 2.52e+04, s_1: 2.24e+04, s_2: 2.66e+04
Subj Direct Mappings KLs:  s_0: 3.17e+03, s_1: 3.41e+03, s_2: 3.03e+03
Penalties:  p_0: 2.77e+07, p_1: 1.47e+04
m scales state
 Center: [1.]
 Last Penalty: 27749606.0
m offsets state
 Center: [0.]
 Last Penalty: 14733.672973632812
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 169.35400390625
*****************************************************
Epoch 550 complete.  Obj: 1.02e+09, LR: [0.01]
Model NLLs:  s_0: 3.35e+08, s_1: 3.25e+08, s_2: 3.53e+08
Subj P KLs:  s_0: 1.34e+05, s_1: 1.14e+05,

*****************************************************
Epoch 650 complete.  Obj: 9.50e+08, LR: [0.01]
Model NLLs:  s_0: 2.90e+08, s_1: 2.47e+08, s_2: 4.13e+08
Subj P KLs:  s_0: 1.36e+05, s_1: 1.16e+05, s_2: 1.40e+05
Subj U KLs:  s_0: 1.47e+05, s_1: 1.23e+05, s_2: 1.51e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 5.84e+04, s_1: 5.28e+04, s_2: 6.44e+04
Subj Offsets KLs:  s_0: 2.30e+04, s_1: 2.14e+04, s_2: 2.63e+04
Subj Direct Mappings KLs:  s_0: 1.71e+03, s_1: 2.13e+03, s_2: 2.78e+03
Penalties:  p_0: 2.75e+07, p_1: 9.99e+02
m scales state
 Center: [1.]
 Last Penalty: 27524110.0
m offsets state
 Center: [0.]
 Last Penalty: 998.8152465820312
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 204.04998397827148
*****************************************************
Epoch 660 complete.  Obj: 9.39e+08, LR: [0.01]
Model NLLs:  s_0: 2.89e+08, s_1: 2.46e+08, s_2: 4.03e+08
Subj P KLs:  s_0: 1.36e+05, s_1: 1.16e+0

*****************************************************
Epoch 760 complete.  Obj: 9.36e+08, LR: [0.01]
Model NLLs:  s_0: 2.82e+08, s_1: 2.74e+08, s_2: 3.80e+08
Subj P KLs:  s_0: 1.38e+05, s_1: 1.18e+05, s_2: 1.43e+05
Subj U KLs:  s_0: 1.49e+05, s_1: 1.26e+05, s_2: 1.54e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 5.94e+04, s_1: 5.40e+04, s_2: 6.56e+04
Subj Offsets KLs:  s_0: 2.35e+04, s_1: 2.25e+04, s_2: 2.60e+04
Subj Direct Mappings KLs:  s_0: 2.37e+03, s_1: 3.14e+03, s_2: 2.68e+03
Penalties:  p_0: 2.74e+07, p_1: 9.02e+01
m scales state
 Center: [1.]
 Last Penalty: 27394511.0
m offsets state
 Center: [0.]
 Last Penalty: 90.19266784191132
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 239.06087160110474
*****************************************************
Epoch 770 complete.  Obj: 9.28e+08, LR: [0.01]
Model NLLs:  s_0: 2.82e+08, s_1: 2.86e+08, s_2: 3.60e+08
Subj P KLs:  s_0: 1.38e+05, s_1: 1.18e+0

*****************************************************
Epoch 870 complete.  Obj: 1.28e+09, LR: [0.01]
Model NLLs:  s_0: 3.46e+08, s_1: 5.22e+08, s_2: 4.16e+08
Subj P KLs:  s_0: 1.40e+05, s_1: 1.20e+05, s_2: 1.45e+05
Subj U KLs:  s_0: 1.51e+05, s_1: 1.28e+05, s_2: 1.56e+05
Subj Psi KLs:  s_0: 0.00e+00, s_1: 0.00e+00, s_2: 0.00e+00
Subj Scale KLs:  s_0: 6.06e+04, s_1: 5.50e+04, s_2: 6.67e+04
Subj Offsets KLs:  s_0: 2.45e+04, s_1: 2.18e+04, s_2: 2.60e+04
Subj Direct Mappings KLs:  s_0: 3.42e+03, s_1: 2.92e+03, s_2: 2.97e+03
Penalties:  p_0: 2.72e+07, p_1: 2.05e+02
m scales state
 Center: [1.]
 Last Penalty: 27180056.0
m offsets state
 Center: [0.]
 Last Penalty: 204.9793930053711
Device memory allocated:  d_0: 7.46e+09
Device max memory allocated:  d_0: 9.42e+09
Elapsed time: 273.64424419403076


In [None]:
# Plot the log from the first round of fitting (the log of the second round, logs1, is not plotted here)
fitter.plot_log(logs0[0])

## Move everything to cpu

In [None]:
# Move what the fitter manages back to the CPU so results can be inspected and plotted below
fitter.to('cpu')

## Look at predictions the models make on training data

In [None]:
# One entry per subject, computed on that subject's training data; the 'truth' and 'pred' entries of each
# result are plotted below
s_preds = [predict_with_truth(s_coll, s_coll.data) for s_coll in vi_collections]

In [None]:
# Subject, output variable and range of samples to show
plt_s_i = 2
plot_v_i = 3
smp_inds = slice(0, 100)

# Pull out the true and predicted traces for the selected variable
plt_truth = s_preds[plt_s_i]['truth'][0][smp_inds, plot_v_i]
plt_pred = s_preds[plt_s_i]['pred'][0][smp_inds, plot_v_i]

# Truth in blue, prediction in red
plt.figure()
plt.plot(plt_truth, 'b-')
plt.plot(plt_pred, 'r-')

## Look at true and fit offset and scale distributions

In [None]:
# Select which true distribution and fit prior to compare by leaving exactly one pair below uncommented

# Offsets
#true_dist = true_offset_dist 
#fit_dist = offsets_prior.dists[0] 

# Scales (currently selected)
true_dist = true_scale_dist
fit_dist = scales_prior.dists[0]

# Direct mappings
#true_dist = true_direct_map_dist
#fit_dist = direct_mappings_prior.dists[0]

In [None]:
plt.figure()

# Left panel: the mn_f module of the true distribution, wrapped as a plain function.
true_mn_fcn = torch_mod_to_fcn(true_dist.mn_f)
true_ax = plt.subplot(1, 2, 1)
visualize_2d_function(true_mn_fcn, ax=true_ax)
#plt.gca().get_images()[0].set_clim(0.0, 10.0)

# Right panel: the mn_f module of the fit distribution.
# NOTE(review): only this panel restricts both dimensions to [0, .99]; the true panel uses
# the defaults of visualize_2d_function - confirm the two ranges are meant to differ.
fit_mn_fcn = torch_mod_to_fcn(fit_dist.mn_f)
fit_ax = plt.subplot(1, 2, 2)
visualize_2d_function(fit_mn_fcn, ax=fit_ax, dim_0_range=[0, .99], dim_1_range=[0, .99])
#plt.gca().get_images()[0].set_clim(0.0, 10.0)

## Look at true offset, scale or direct mapping values compared to the fit posteriors on a single neuron basis

In [None]:
vis_s_i = 2  # index of the subject whose values are visualized
e_shape = [21, 21, 1]  # passed as ellipse_shape when the dot images are rendered

# Choose which parameter to inspect by leaving exactly one pair of assignments uncommented
# (the same toggle convention used when picking true_dist / fit_dist for the priors).
# Previously the offsets pair was also active, but it was immediately overwritten by the
# scales pair, so it only obscured which parameter was actually shown.

# Offsets
#true_vls = true_subj_models[vis_s_i]['mdl'].offsets[0]
#fit_dist = vi_collections[vis_s_i].offset_dists[0]

# Scales
true_vls = true_subj_models[vis_s_i]['mdl'].scales[0]
fit_dist = vi_collections[vis_s_i].scale_dists[0]

# Direct mappings
#true_vls = true_subj_models[vis_s_i]['mdl'].direct_mappings[0]
#fit_dist = vi_collections[vis_s_i].direct_mapping_dists[0]

In [None]:
# Neuron locations for the selected subject (the 'u_neuron_locs' entry of its true model dict).
vis_neuron_locs = true_subj_models[vis_s_i]['u_neuron_locs']

# True values, and the fit distribution evaluated at the same locations, as numpy arrays.
# The "offsets" in these names is historical: they hold whichever parameter was selected.
vis_true_offsets = true_vls.detach().numpy()
vis_fit_offsets = fit_dist.dists[0](vis_neuron_locs).detach().numpy()

# Append a zero third coordinate and scale by 1000 so the locations can serve as dot
# centers in the 1001 x 1001 x 1 images generated next.
zero_col = np.zeros([vis_neuron_locs.shape[0], 1])
vis_neuron_locs = 1000*np.hstack([vis_neuron_locs.numpy(), zero_col])

In [None]:
def _render_dot_image(dot_vls):
    """Render one dot per neuron in vis_neuron_locs, valued by dot_vls."""
    return generate_dot_image_3d(image_shape=[1001, 1001, 1], dot_ctrs=vis_neuron_locs,
                                 dot_vls=dot_vls, ellipse_shape=e_shape)


# Same geometry for both images; only the per-neuron values differ.
true_image = _render_dot_image(vis_true_offsets)
fit_image = _render_dot_image(vis_fit_offsets)

In [None]:
# Show the true (left) and fit (right) images side by side, each with its own colorbar.
plt.figure()
for panel_i, panel_image in enumerate((true_image, fit_image), start=1):
    plt.subplot(1, 2, panel_i)
    plt.imshow(np.squeeze(panel_image))
    plt.colorbar()

## Look at true and estimated distributions over modes

In [None]:
# Index of the mode to visualize; used below to select a column of the mode matrices.
vis_m = 1

#### Learn a transformation to align modes

In [None]:
def _modes_as_numpy(mode_tensor):
    """Detach a tensor of modes, move it to the CPU and convert it to a numpy array."""
    return mode_tensor.detach().cpu().numpy()


def _locs_to_image_coords(locs):
    """Append a zero third coordinate to 2-d locations and scale them by 1000."""
    return 1000*np.hstack([locs.numpy(), np.zeros([locs.shape[0], 1])])


# Modes are evaluated at the neuron locations of the first subject only.
u_neuron_locs = true_subj_models[0]['u_neuron_locs']
true_u_modes = _modes_as_numpy(true_u_dists(u_neuron_locs))
est_u_modes = _modes_as_numpy(prior_collection.u_dists[0](u_neuron_locs))
#est_u_modes = _modes_as_numpy(vi_collections[0].u_dists[0](u_neuron_locs))
u_neuron_locs = _locs_to_image_coords(u_neuron_locs)

p_neuron_locs = true_subj_models[0]['p_neuron_locs']
true_p_modes = _modes_as_numpy(true_p_dists(p_neuron_locs))
#true_p_modes = _modes_as_numpy(true_subj_models[0]['mdl'].p[0])
est_p_modes = _modes_as_numpy(prior_collection.p_dists[0](p_neuron_locs))
#est_p_modes = _modes_as_numpy(vi_collections[0].p_dists[0](p_neuron_locs))
#est_p_modes = _modes_as_numpy(vi_collections[0].s_mdl.p[0])
p_neuron_locs = _locs_to_image_coords(p_neuron_locs)

In [None]:
# Least-squares solution T of est_u_modes @ T ~= true_u_modes; lstsq returns a tuple whose
# first entry is the solution, which is all that is kept.
mode_t = np.linalg.lstsq(est_u_modes, true_u_modes, rcond=None)[0]

In [None]:
# Map the estimated u modes into the coordinates of the true modes, and apply the inverse
# transformation to the estimated p modes.
# NOTE(review): if the model computes y = x p u^T, the compensating transform for p would be
# inv(mode_t).T rather than inv(mode_t) - confirm against the model definition.
mode_t_inv = np.linalg.inv(mode_t)
est_u_modes_t = est_u_modes @ mode_t
est_p_modes_t = est_p_modes @ mode_t_inv

In [None]:
# Sanity check: display the shape of the true p mode array.
true_p_modes.shape

In [None]:
e_shape = [21, 21, 1]  # passed as ellipse_shape for every dot image below

# Render mode vis_m of the true and estimated u modes. The estimate shown is the aligned
# one (est_u_modes_t), so its columns are comparable to the true modes.
true_u_image = generate_dot_image_3d(image_shape=[1001, 1001, 1], dot_ctrs=u_neuron_locs,
                                   dot_vls=true_u_modes[:,vis_m],
                                   ellipse_shape=e_shape)

fit_u_image = generate_dot_image_3d(image_shape=[1001, 1001, 1], dot_ctrs=u_neuron_locs,
                                  dot_vls=est_u_modes_t[:,vis_m],
                                   ellipse_shape=e_shape)

# Render the same mode for p. The estimate must also be the aligned one (est_p_modes_t);
# previously the unaligned est_p_modes was rendered here, so column vis_m of the estimate
# was not comparable to column vis_m of the true modes and est_p_modes_t went unused.
# NOTE(review): the p images use a 1002 x 1002 canvas while the u images use 1001 x 1001 -
# confirm whether the difference is intentional.
true_p_image = generate_dot_image_3d(image_shape=[1002, 1002, 1], dot_ctrs=p_neuron_locs,
                                   dot_vls=true_p_modes[:,vis_m],
                                   ellipse_shape=e_shape)

fit_p_image = generate_dot_image_3d(image_shape=[1002, 1002, 1], dot_ctrs=p_neuron_locs,
                                  dot_vls=est_p_modes_t[:,vis_m],
                                  ellipse_shape=e_shape)

In [None]:
def _show_mode_panel(panel_i, image, clim, title):
    """Draw one image with a title and colorbar in panel panel_i of a 1 x 2 grid."""
    plt.subplot(1, 2, panel_i)
    plt.imshow(np.squeeze(image), clim=clim)
    plt.title(title)
    plt.colorbar()


# True vs. estimated u mode.
plt.figure()
_show_mode_panel(1, true_u_image, [0, 1], 'True u mode')
_show_mode_panel(2, fit_u_image, [0, 1], 'Est u mode')

# True vs. estimated p mode; the estimate is shown with a wider color range.
plt.figure()
_show_mode_panel(1, true_p_image, [0, 1], 'True p mode')
_show_mode_panel(2, fit_p_image, [0, 1.4], 'Est p mode')
