In [1]:
import sys
import time
import random
import csv
import numpy as np

import scipy.io
from scipy.io import loadmat

import torch
import pickle

from tqdm import tqdm
from tqdm import notebook

import matplotlib
import matplotlib.pyplot as plt

import importlib
import utils
importlib.reload(utils)
TORCH_DTYPE = torch.float64 #NB: Basically all of the matrices in Spatial_GP have 1.e-7 added to the diagonal, to be changed if we want to use float64
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
torch.set_default_dtype(TORCH_DTYPE)
torch.set_default_device(device)
print(f'Device is: {device}')


Using device: cuda:0 (from utils.py)
Using device: cuda:0 (from utils.py)
Device is: cuda:0


### How this works:

We choose initial training points to be considered a random set of images to present to the retina. On this set, the complete GP will be run , to find STA paramaters eps_0: center, beta: width , rho: smoothness

We run the algorithm saving the seed used for picking the training images, so that this set can be changed and different runs with different initial conditions can be averaged.

1. Import the dataset and create a total training set X,R
2. Pick the cell and the initial training points, extracted randomly. These correspond also to the number of inducing points
3. Save the seed so you can keep the initial training set, and the fitted model
3. 


### Parameters of the training

In [2]:
rand_xtilde = True # If True, xtilde (inducing points) are chosen randomly, if False, xtilde is chosen from the first ntilde images

cellid      = 8   # Choose cell
ntrain_start = 50 # Number of first training data points

kernfun     = utils.acosker # Choose kernel function

Nmstep  = 6   # Total number of M-steps iterations. 
Nestep  = 8   # Total number of E-steps iterations.
Maxiter = 10  # Iterations of the optimization algorithm comprising M and E steps

ntilde       = ntrain_start

### Import dataset and generate starting dataset

Create starting dataset on which to train with m step with ntilde = ntrain_start

In [3]:
# Open the .pkl dataset file for reading in binary mode (rb)
with open('/home/idv-eqs8-pza/IDV_code/Variational_GP/spatial_GP/Data/data2_41mixed_tr28.pkl', 'rb') as file:
    # Load the data from the file
    loaded_data = pickle.load(file)
    # loaded_data is a Dataset object from module Data with attributes "images_train, _val, _test" as well as responses

X_train = torch.tensor(loaded_data.images_train).to(device, dtype=TORCH_DTYPE) #shape (2910,108,108,1) where 108 is the number of pixels. 2910 is the amount of training points
X_val   = torch.tensor(loaded_data.images_val).to(device, dtype=TORCH_DTYPE)
X_test  = torch.tensor(loaded_data.images_test).to(device, dtype=TORCH_DTYPE) # shape (30,108,108,1) # nimages, npx, npx

R_train = torch.tensor(loaded_data.responses_train).to(device, dtype=TORCH_DTYPE) #shape (2910,41) 2910 is the amount of training data, 41 is the number of cells
R_val   = torch.tensor(loaded_data.responses_val).to(device, dtype=TORCH_DTYPE)
R_test  = torch.tensor(loaded_data.responses_test).to(device, dtype=TORCH_DTYPE) # shape (30,30,42) 30 repetitions, 30 images, 42 cells

# Create the complete dataset
X = torch.cat( (X_train, X_val), axis=0,) #shape (3160,108,108,1)
R = torch.cat( (R_train, R_val), axis=0,)
n_px_side = X.shape[1]  

# Reshape images to 1D vector  
X = torch.reshape(X, ( X.shape[0], X.shape[1]*X.shape[2])) 

# Choose a random subset of the data and save the idx
torch.manual_seed(0)
torch.cuda.manual_seed(0)

arange = torch.arange(0, X.shape[0])
rndm   = torch.randint(0, X.shape[0], (ntrain_start,)) # these will be the indices of the initial training
combined = torch.cat( (arange, rndm) )
unique, counts = combined.unique(return_counts=True)

remaining_idx = unique[counts==1] # Indices of the remaining data in the dataset (not yet used for training)
start_idx = unique[counts>1]      # Indices of the data used for the initial training
used_idx = start_idx              # Indices of the data used for training, including the initial training
xtilde = X[start_idx,:]
Xremaining = X[remaining_idx,:]
Xused = X[used_idx,:]

# Choose a cell
r = R[:,cellid] # shape (nt,) where nt is the number of trials

### Select cell, starting hyperparameters and firing rate parameters

In [4]:
# For details on the hyperparameters choice see one_cell_fit.ipynb
logbetaexpr = utils.fromlogbetasam_to_logbetaexpr( logbetasam=torch.tensor(5.5) )# Logbetaexpr in this code is equal to logbeta in Samuele's code. Samuele's code set logbeta to 5.5
logrhoexpr  = utils.fromlogrhosam_to_logrhoexpr( logrhosam=torch.tensor(5)) 
# logbetaexpr = torch.tensor(4.65)
# logrhoexpr = torch.tensor(4.3)
logsigma_0 = torch.tensor(0) 
sigma_0    = torch.exp(logsigma_0)
Amp        = torch.tensor(1.0) 

# Hypermarameters, if needed to be set manually
theta = {'sigma_0': sigma_0, 'eps_0x':torch.tensor(0.), 'eps_0y':torch.tensor(0.), '-2log2beta': logbetaexpr, '-log2rho2': logrhoexpr, 'Amp': Amp }
# Set the gradient of the hyperparemters to be updatable 
for key, value in theta.items():
    # to exclude a single hyperparemeters from the optimization ( to exclude them all just set nmstep=0 and dont do the M-step)
    # if key == 'Amp':
        # continue
    theta[key] = value.requires_grad_()

# If hyperparameters are set manually:
hyperparams_tuple = utils.generate_theta( x=X, r=r, n_px_side=n_px_side, display=True, **theta)
# If hyperparameters are set based on the STAs:
# hyperparams_tuple = utils.generate_theta( x=X, r=r, n_px_side=n_px_side, display=True)

# Set parameters of the firing rate function
A        = torch.tensor(0.007)
logA     = torch.log(A)
f_params = {'logA': logA, 'lambda0':torch.tensor(0.31)}

for key, value in f_params.items():
    f_params[key] = value.requires_grad_()

args = {
        'ntilde':  ntrain_start,
        'Maxiter': Maxiter,
        'Nmstep':  Nmstep,
        'Nestep':  Nestep,
        'kernfun': kernfun,
        'n_px_side': n_px_side,
        'display_prog':  False,
        'hyperparams_tuple': hyperparams_tuple,
        'f_params': f_params,
        'xtilde': xtilde,
        'm': torch.zeros( (ntilde) )
    }



updated sigma_0 to 1.0
updated eps_0x to 0.0
updated eps_0y to 0.0
updated -2log2beta to 4.806852819440055
updated -log2rho2 to 4.306852819440055
updated Amp to 1.0
 Before overloading
 Hyperparameters have been SET as  : beta = 0.05856070, rho = 0.02928035
 Samuele hyperparameters           : logbetasam = 4.9822, logrhosam = 7.0617

 After overloading
 Dict of learnable hyperparameters : sigma_0 = 1.00000000, eps_0x = 0.00000000, eps_0y = 0.00000000, -2log2beta = 4.80685282, -log2rho2 = 4.30685282, Amp = 1.00000000
 Hyperparameters from the logexpr  : beta = 0.04520382, rho = 0.08208500
 Samuele hyperparameters           : logbetasam = 5.5000, logrhosam = 5.0000


### Fit the model
And save it needed to start a new active fit

In [5]:
theta_fit, f_params_fit, m_b_fit, V_b_fit, C_fit, mask_fit, K_tilde_b_fit, K_tilde_inv_b_fit, B, values_track_fit = utils.varGP(X, r, **args)

# Save the model. All of the matrices are projected in the eigenspace of big eigenvalues of K_tilde. Indicated by _b
model_b_starting = {
    'theta': theta_fit,
    'f_params': f_params_fit,
    'm_b': m_b_fit,
    'V_b': V_b_fit,
    'C'  : C_fit,
    'mask': mask_fit,
    'K_tilde_inv_b': K_tilde_inv_b_fit,
    'K_tilde_b': K_tilde_b_fit,
    'B': B,
    'cellid' : cellid,
    'ntilde' : ntilde,
    'Maxiter': Maxiter,
    'Nmstep' : Nmstep,
    'Nestep' : Nestep,
    'kernfun': kernfun.__name__,
    'values_track': values_track_fit,
}

# Save the model
with open('models/pietro_model.pkl', 'wb') as file:
    pickle.dump(model_b_starting, file)


*Iteration*: 0
 2940.8513 = logmarginal in m-step
 2924.4308 = logmarginal in m-step
 2924.2237 = logmarginal in m-step
 2924.2237 = logmarginal in m-step
 2915.6985 = logmarginal in m-step
 2915.6705 = logmarginal in m-step
 2915.6705 = logmarginal in m-step
 2914.7421 = logmarginal in m-step
 2914.7321 = logmarginal in m-step
*Iteration*: 1
 2852.9357 = logmarginal in m-step
 2831.7983 = logmarginal in m-step
 2831.5464 = logmarginal in m-step
 2831.5464 = logmarginal in m-step
 2820.8374 = logmarginal in m-step
 2820.7878 = logmarginal in m-step
 2820.7878 = logmarginal in m-step
 2818.3318 = logmarginal in m-step
 2818.3295 = logmarginal in m-step
*Iteration*: 2
 2771.0141 = logmarginal in m-step
 2746.2591 = logmarginal in m-step
 2745.8120 = logmarginal in m-step
 2745.8120 = logmarginal in m-step
 2729.6119 = logmarginal in m-step
 2729.3478 = logmarginal in m-step
 2729.3478 = logmarginal in m-step
 2718.9554 = logmarginal in m-step
 2718.7884 = logmarginal in m-step
*Iteration

### Infer mean and variance of the remaining datapoints, estimate utility of each.

Find the most useful image and its ID

In [70]:
importlib.reload(utils)

start_time = time.time()
with open('models/pietro_model.pkl', 'rb') as file:
    loaded_model = pickle.load(file)

with torch.no_grad():
# reshape the remaining dataset to
# Xremaining = torch.reshape(Xremaining, (Xremaining.shape[0], Xremaining.shape[0]*Xremaining.shape[1]))

    mask  = loaded_model['mask']
    C     = loaded_model['C']
    B     = loaded_model['B']
    K_tilde      = loaded_model['K_tilde_b']
    K_tilde_inv  = loaded_model['K_tilde_inv_b']
    m            = loaded_model['m_b']
    V            = loaded_model['V_b']    
    f_params_fit = loaded_model['f_params']
    A            = torch.exp(f_params_fit['logA'])
    lambda0      = f_params_fit['lambda0']

    sigma2_arr = torch.zeros( Xremaining.shape[0] )
    i_best = 0

    start_time = time.time()
    diff = []
    for i, x_idx in enumerate(remaining_idx):
    # for i in range( 1000 ):
        xstar = Xremaining[i,:][None,:]

        mu_star, sigma2_star = utils.lambda_moments_star(xstar[:,mask], xtilde[:,mask], C, theta, K_tilde, K_tilde_inv, m, V, B, kernfun)

        # The usility function is supposed to work with the moments of log(firing rate) = log(f), which are
        logf_mu     = A*mu_star + lambda0
        logf_sigma2 = A**2 * sigma2_star

        #region estimating real utility
        r_masked = torch.arange(0, 100, dtype=TORCH_DTYPE)

        u = utils.utility(logf_sigma2, logf_mu, r_masked )


        
        # plt.subplot(111)
        # plt.scatter(logf_sigma2.item(), u.item(), color=colors[r_cutoffs.index(r_cutoff)], s=10-r_cutoffs.index(r_cutoff)*5)
        # plt.title(f'r_cutoff: {r_cutoff}')
        print(f'Utility: {u.item():<8.4f} mu_util: {logf_mu.item():>6.4f} sigma2_star: {sigma2_star.item():<8.4f}')    
        #endregion

        sigma2_arr[i] = sigma2_star
        if sigma2_star > sigma2_arr[i_best]:
            i_best = i
            x_idx_best = x_idx



    print(f'Elapsed time: {time.time()-start_time:.2f} seconds')

    print(f'Best image ID: {i_best}: Utility: {u.item():<8.4f} |sigma2: {sigma2_arr[i_best].item():.4f}')

    used_idx = torch.cat( (used_idx, x_idx_best[None]) )
    remaining_idx = arange[~torch.isin( arange, used_idx )]

    X_updated = X[used_idx]
    R_updated = R[used_idx, cellid] 
    
    # The added images are used as inducing points as long as the number of inducing points is less than 200
    if used_idx.shape[0] < 200:
        ntilde = used_idx.shape[0]
        xtilde_updated = X[used_idx]




Using device: cuda:0 (from utils.py)
Utility: 0.0031   mu_util: -0.2420 sigma2_star: 5.8137  
Utility: 0.0008   mu_util: -0.5029 sigma2_star: 1.9434  
Utility: 0.0107   mu_util: -0.6820 sigma2_star: 30.8665 
Utility: 0.0007   mu_util: -0.2193 sigma2_star: 1.2161  
Utility: 0.0011   mu_util: 0.2708 sigma2_star: 1.2104  
Utility: 0.0009   mu_util: -0.0601 sigma2_star: 1.3350  
Utility: 0.0082   mu_util: -0.7382 sigma2_star: 25.1079 
Utility: 0.0039   mu_util: -0.4249 sigma2_star: 8.8169  
Utility: 0.0018   mu_util: 0.3366 sigma2_star: 1.8997  
Utility: 0.0197   mu_util: -0.9388 sigma2_star: 72.3782 
Utility: 0.0015   mu_util: -0.8314 sigma2_star: 4.9652  
Utility: 0.0008   mu_util: -0.5144 sigma2_star: 1.9825  
Utility: 0.0058   mu_util: -0.8397 sigma2_star: 19.7359 
Utility: 0.0056   mu_util: -0.5816 sigma2_star: 14.5910 
Utility: 0.0008   mu_util: -0.1708 sigma2_star: 1.3553  
Utility: 0.0253   mu_util: 1.6869 sigma2_star: 7.1140  
Utility: 0.0032   mu_util: -1.0658 sigma2_star: 13.805