In [1]:
import sys
import time
import random
import csv
import numpy as np

import scipy.io
from scipy.io import loadmat

import torch
import pickle

from tqdm import tqdm
from tqdm import notebook

import matplotlib
import matplotlib.pyplot as plt

import importlib
import utils

# Enable autoreload
%load_ext autoreload
%autoreload 2
importlib.reload(utils)

# Set seeds for reproducibility
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)  # if you are using multi-GPU.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark     = False


TORCH_DTYPE = torch.float64 #NB: Basically all of the matrices in Spatial_GP have 1.e-7 added to the diagonal, to be changed if we want to use float64
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
torch.set_default_dtype(TORCH_DTYPE)
torch.set_default_device(device)
print(f'Device is: {device}')


Using device: cuda:0 (from utils.py)
Using device: cuda:0 (from utils.py)
Device is: cuda:0


### How this works:

We choose initial training points to be considered a random set of images to present to the retina. On this set, the complete GP will be run , to find STA paramaters eps_0: center, beta: width , rho: smoothness

We run the algorithm saving the seed used for picking the training images, so that this set can be changed and different runs with different initial conditions can be averaged.

1. Import the dataset and create a total training set X,R
2. Pick the cell and the initial training points, extracted randomly. These correspond also to the number of inducing points
3. Save the seed so you can keep the initial training set, and the fitted model
3. 


### Parameters of the training

In [2]:
rand_xtilde = True # If True, xtilde (inducing points) are chosen randomly, if False, xtilde is chosen from the first ntilde images

cellid       = 8   # Choose cell
ntrain_start = 50  # Number of first training data points

kernfun = 'acosker' # Choose kernel function

nMstep      = 6   # Total number of M-steps iterations. 
nEstep      = 8   # Total number of E-steps iterations.
nFparamstep = 5  
maxiter     = 10  # Iterations of the optimization algorithm comprising M and E steps

ntilde  = ntrain_start

### Import dataset and generate starting dataset

Create starting dataset on which to train with m step with ntilde = ntrain_start

In [3]:
# Open the .pkl dataset file for reading in binary mode (rb)
with open('/home/idv-eqs8-pza/IDV_code/Variational_GP/spatial_GP/Data/data2_41mixed_tr28.pkl', 'rb') as file:
    # Load the data from the file
    loaded_data = pickle.load(file)
    # loaded_data is a Dataset object from module Data with attributes "images_train, _val, _test" as well as responses

X_train = torch.tensor(loaded_data.images_train).to(device, dtype=TORCH_DTYPE) #shape (2910,108,108,1) where 108 is the number of pixels. 2910 is the amount of training points
X_val   = torch.tensor(loaded_data.images_val).to(device, dtype=TORCH_DTYPE)
X_test  = torch.tensor(loaded_data.images_test).to(device, dtype=TORCH_DTYPE) # shape (30,108,108,1) # nimages, npx, npx

R_train = torch.tensor(loaded_data.responses_train).to(device, dtype=TORCH_DTYPE) #shape (2910,41) 2910 is the amount of training data, 41 is the number of cells
R_val   = torch.tensor(loaded_data.responses_val).to(device, dtype=TORCH_DTYPE)
R_test  = torch.tensor(loaded_data.responses_test).to(device, dtype=TORCH_DTYPE) # shape (30,30,42) 30 repetitions, 30 images, 42 cells

# Create the complete dataset
X = torch.cat( (X_train, X_val), axis=0,) #shape (3160,108,108,1)
R = torch.cat( (R_train, R_val), axis=0,)
n_px_side = X.shape[1]  

# Reshape images to 1D vector and choose a cell
X = torch.reshape(X, ( X.shape[0], X.shape[1]*X.shape[2])) 
R = R[:,cellid] # shape (nt,) where nt is the number of trials

# Choose a random subset of the data and save the idx
torch.manual_seed(0)
torch.cuda.manual_seed(0)

all_idx  = torch.arange(0, X.shape[0])                   # Indices of the whole dataset
rndm_idx = torch.randint(0, X.shape[0], (ntrain_start,)) # These will be the indices of the initial training
combined = torch.cat( (all_idx, rndm_idx) )                   # Combine the indices, there are now some duplicates
unique, counts = combined.unique(return_counts=True)

remaining_idx = unique[counts==1]                        # Indices of the remaining data in the dataset (not yet used for training)
start_idx     = unique[counts>1]                         # Indices of the data used for the initial training
used_idx      = start_idx                                # Indices of the data used for training, including the initial training

# Set the starting set
xtilde_start  = X[start_idx,:]                           # In the simplest case the starting points are all inducing points
X_remaining   = X[remaining_idx,:]
X_used        = X[used_idx,:]

R_remaining = R[remaining_idx]
R_used      = R[used_idx]



### Select cell, starting hyperparameters and firing rate parameters

In [4]:
# For details on the hyperparameters choice see one_cell_fit.ipynb
logbetaexpr = utils.fromlogbetasam_to_logbetaexpr( logbetasam=torch.tensor(5.5) )# Logbetaexpr in this code is equal to logbeta in Samuele's code. Samuele's code set logbeta to 5.5
logrhoexpr  = utils.fromlogrhosam_to_logrhoexpr( logrhosam=torch.tensor(5)) 
# logbetaexpr = torch.tensor(4.65)
# logrhoexpr = torch.tensor(4.3)
logsigma_0 = torch.tensor(0) 
sigma_0    = torch.exp(logsigma_0)
Amp        = torch.tensor(1.0) 
eps_0x = torch.tensor(0.0001)
eps_0y = torch.tensor(0.0001)
# Hypermarameters, if needed to be set manually
theta = {'sigma_0': sigma_0, 'eps_0x':eps_0x, 'eps_0y':eps_0y, '-2log2beta': logbetaexpr, '-log2rho2': logrhoexpr, 'Amp': Amp }

# Set the gradient of the hyperparemters to be updatable 
for key, value in theta.items():
    # to exclude a single hyperparemeters from the optimization ( to exclude them all just set nmstep=0 and dont do the M-step)
    # if key == 'Amp':
        # continue
    theta[key] = value.requires_grad_()

# If hyperparameters are set manually:
hyperparams_tuple = utils.generate_theta( x=X_used, r=R_used, n_px_side=n_px_side, display=True, **theta)
# If hyperparameters are set based on the STAs:
# hyperparams_tuple = utils.generate_theta( x=X, r=r, n_px_side=n_px_side, display=True)

A        = torch.tensor(0.007)
logA     = torch.log(A)
# lambda0  = torch.tensor(0.31)
lambda0  = torch.tensor(1)
f_params = {'logA': logA, 'loglambda0':torch.log(lambda0)}

for key, value in f_params.items():
    f_params[key] = value.requires_grad_()

args = {
        'cellid': cellid,                       # sending also cellid to the fit function to make sure we dont lose track of it
        'ntilde':  ntilde,
        'maxiter': maxiter,
        'nMstep':  nMstep,
        'nEstep':  nEstep,
        'nFparamstep': nFparamstep,
        'kernfun': kernfun,
        'n_px_side': n_px_side,
        'display_prog':  False,
        'hyperparams_tuple': hyperparams_tuple,
        'f_params': f_params,
        'xtilde': xtilde_start,
        'm': torch.zeros( (ntilde) )
        # 'm': torch.ones( (ntilde) )
        #'V': dont initialize V if you want it to be initialized as K_tilde and projected _exactly_ as K_tilde_b for stabilisation
    }



updated sigma_0 to 1.0000
updated eps_0x to 0.0001
updated eps_0y to 0.0001
updated -2log2beta to 4.8069
updated -log2rho2 to 4.3069
updated Amp to 1.0000


### Fit the starting model
And save it needed to start a new active fit

In [5]:
theta_fit, f_params_fit, m_b_fit, V_b_fit, C_fit, mask_fit, K_tilde_b_fit, K_tilde_inv_b_fit, K_b, Kvec, B, fit_parameters, values_track_fit, err_dict = utils.varGP(X_used, 
                                                                                                                                                          R_used, 
                                                                                                                                                          **args)
# Save the model. All of the matrices are projected in the eigenspace of big eigenvalues of K_tilde. Indicated by _b
model_start = {
    'theta': theta_fit,
    'f_params': f_params_fit,
    'm_b': m_b_fit,
    'V_b': V_b_fit,
    'C'  : C_fit,
    'mask': mask_fit,
    'K_tilde_inv_b': K_tilde_inv_b_fit,
    'K_tilde_b': K_tilde_b_fit,
    'K_b': K_b,
    'Kvec': Kvec,
    'B': B,
    'values_track': values_track_fit,
}

for key in fit_parameters.keys():
    model_start[key] = fit_parameters[key]

if err_dict['is_error']:
    print('Error in the fit')
    raise err_dict['error']

# Save the model
# utils.save_model(model_start, f'models/starting_models_active_learning/cell:{cellid}_nstart:{ntrain_start}', additional_description='Starting model for active learning')


 Before overloading
 Hyperparameters have been SET as  : beta = 0.0586, rho = 0.0293
 Samuele hyperparameters           : logbetasam = 4.9822, logrhosam = 7.0617

 After overloading
 Dict of learnable hyperparameters : sigma_0 = 1.0000, eps_0x = 0.0000, eps_0y = 0.0000, -2log2beta = 4.2891, -log2rho2 = 6.3685, Amp = 1.0000
 Hyperparameters from the logexpr  : beta = 0.0586, rho = 0.0293
 Samuele hyperparameters           : logbetasam = 4.9822, logrhosam = 7.0617
Initialization took: 0.0470 seconds

*Iteration*: 0 E-step took: 0.2106s, M-step took: 0.0521s
*Iteration*: 1 E-step took: 0.0424s, M-step took: 0.0501s
*Iteration*: 2 E-step took: 0.0454s, M-step took: 0.0567s
*Iteration*: 3 E-step took: 0.0512s, M-step took: 0.0494s
*Iteration*: 4 E-step took: 0.0367s, M-step took: 0.0496s
*Iteration*: 5 E-step took: 0.0318s, M-step took: 0.0528s
*Iteration*: 6 E-step took: 0.0325s, M-step took: 0.0509s
*Iteration*: 7 E-step took: 0.0183s, M-step took: 0.0521s
*Iteration*: 8 E-step took: 0.01

### Infer mean and variance of the remaining datapoints, estimate utility of each.

Find the most useful image and its ID

In [None]:
model_start = utils.load_model('models/starting_models_active_learning/cell:8_nstart:50')

xstar = X_remaining

with torch.no_grad():
# reshape the remaining dataset to
# X_remaining = torch.reshape(Xremaining, (Xremaining.shape[0], Xremaining.shape[0]*Xremaining.shape[1]))

    kernfun       = model_start['kernfun']
    mask          = model_start['mask']
    C             = model_start['C']
    B             = model_start['B']
    K_tilde_b     = model_start['K_tilde_b']
    K_tilde_inv_b = model_start['K_tilde_inv_b']
    K_b           = model_start['K_b']
    Kvec          = model_start['Kvec']
    m_b           = model_start['m_b']
    V_b           = model_start['V_b']    
    f_params_fit  = model_start['f_params']
    A             = torch.exp(f_params_fit['logA'])
    lambda0       = torch.exp(f_params_fit['loglambda0'])

    start_time = time.time()
    # Calculate the matrices to compute the lambda moments. They are referred to the unseen images xstar
    Kvec = utils.acosker(theta, xstar[:,mask], x2=None, C=C, dC=None, diag=True)
    K    = utils.acosker(theta, xstar[:,mask], x2=xtilde_start[:,mask], C=C, dC=None, diag=False)
    K_b  = K @ B 

    lambda_m_t, lambda_var_t = utils.lambda_moments( xstar[:,mask], K_tilde_b, K_b@K_tilde_inv_b, Kvec, K_b, C, m_b, V_b, theta, kernfun)  
    u_t                      = torch.zeros( X_remaining.shape[0] )
    logf_mean_t              = torch.zeros( X_remaining.shape[0] )
    logf_var_t               = torch.zeros( X_remaining.shape[0] )
    print(f'Elapsed time for lambda moments: {time.time()-start_time:.2f} seconds')
    start_time = time.time()

    r_masked = torch.arange(0, 100, dtype=TORCH_DTYPE)
    for i, x_idx in enumerate(remaining_idx):

        logf_mean = A*lambda_m_t[i] + lambda0
        logf_var  = A**2 * lambda_var_t[i]

        u = utils.utility(logf_var, logf_mean, r_masked )

        # plt.subplot(111)
        # plt.scatter(logf_var.item(), u.item(), color=colors[r_cutoffs.index(r_cutoff)], s=10-r_cutoffs.index(r_cutoff)*5)
        # plt.title(f'r_cutoff: {r_cutoff}')
        print(f'Utility: {u.item():<8.4f} |  logf_mean: {logf_mean.item():8.4f} |  logf_var: {logf_var.item():6.4f}') 

        u_t[i]         = u
        logf_mean_t[i] = logf_mean
        logf_var_t[i]  = logf_var
        if i == 0 or u > u_best:
            x_idx_best = x_idx
            i_best     = i
            u_best     = u_t[i_best]
            
print(f'\nElapsed time for utility: {time.time()-start_time:.2f} seconds')  
# print(f'Utility: {u_t[i_best].item():<8.4f} |  Best image ID: {i_best}')
# # print(f'logf_mean: {loff_meanbest].item():8.4f} |  logf_var:, {logf_var_t[ i_best].item():.4f}')
# print(f'f_mean:{torch.exp(logf_mean_t[i_best]).item():8.4f} ')

Elapsed time for lambda moments: 0.01 seconds
Utility: 0.1048   |  logf_mean:  -1.0160 |  logf_var: 0.5231


In [20]:
# lamnda moments for xstar
# xstar = X_remaining[i_best,:][None,:]
xstar = X_remaining
# Kstar = utils.acosker(theta, xstar[:,mask], xtilde_start[:,mask], C=C, dC=None, diag=False) # shape (nt, ntilde) in this case nt=1 (xstar is a single point)
# Kstar = Kstar @ B # All of the quantities in mu sig

# lambda_star, lambda_var = utils.lambda_moments_star(xstar[:,mask], xtilde_start[:,mask], C, theta, K_tilde_b, K_tilde_inv_b, m_b, V_b, B, kernfun)
# print(lambda_star)

start_time = time.time()

Kvec = utils.acosker(theta, xstar[:,mask], x2=None, C=C, dC=None, diag=True)
K    = utils.acosker(theta, xstar[:,mask], x2=xtilde_start[:,mask], C=C, dC=None, diag=False)
K_b  = K @ B 
# lambda_m, lambda_var = lambda_moments( x[:,mask], K_tilde_b, KKtilde_inv_b, Kvec, K_b, C, m_b, V_b, theta, kernfun=kernfun)        
lambda_m, lambda_var = utils.lambda_moments( xstar[:,mask], K_tilde_b, K_b@K_tilde_inv_b, Kvec, K_b, C, m_b, V_b, theta, kernfun)        
# lambda_m[i_best]
elapsed_time = time.time() - start_time
print(f'Elapsed time: {elapsed_time:.2f} seconds')
for l in lambda_m:
    print(f'lambda_star: {l.item():.4f}')
# lambda_m

Elapsed time: 0.03 seconds
lambda_star: -6.3282
lambda_star: -4.7341
lambda_star: -18.4920
lambda_star: -1.0950
lambda_star: -3.6952
lambda_star: -13.1677
lambda_star: -19.2033
lambda_star: -1.9126
lambda_star: -1.6674
lambda_star: -18.5475
lambda_star: -3.7471
lambda_star: -2.1771
lambda_star: -4.3133
lambda_star: -4.9543
lambda_star: 0.3284
lambda_star: 8.0920
lambda_star: -3.4693
lambda_star: -2.1290
lambda_star: 0.2832
lambda_star: -4.2928
lambda_star: -7.2099
lambda_star: 0.6206
lambda_star: -2.3041
lambda_star: -20.2102
lambda_star: -4.6108
lambda_star: -19.8725
lambda_star: -0.4028
lambda_star: -8.9779
lambda_star: -4.7919
lambda_star: -2.6390
lambda_star: -8.7125
lambda_star: -4.2765
lambda_star: -8.5441
lambda_star: -9.6030
lambda_star: -1.8046
lambda_star: -6.1663
lambda_star: -32.1507
lambda_star: -15.2567
lambda_star: 0.8789
lambda_star: 3.1196
lambda_star: 3.0626
lambda_star: 1.7219
lambda_star: -2.3248
lambda_star: -11.1120
lambda_star: -8.2088
lambda_star: -2.0743
lambda

### Update the indexes for the training set adding the index of the most useful image 

In [8]:
print(f'Utility: {u_t[i_best].item():<8.4f} |  Best image ID: {i_best}')
# print(f'logf_mean: {loff_meanbest].item():8.4f} |  logf_var:, {logf_var_t[ i_best].item():.4f}')
print(f'f_mean:{torch.exp(logf_mean_t[i_best]).item():8.4f} ')

used_idx      = torch.cat( (used_idx, x_idx_best[None]) )
remaining_idx = all_idx[~torch.isin( all_idx, used_idx )]

X_used      = X[used_idx]
R_used      = R[used_idx] 
X_remaining = X[remaining_idx]
R_remaining = R[remaining_idx]

# The added images are used as inducing points as long as the number of inducing points is less than 200
if used_idx.shape[0] < 200:
    ntilde = used_idx.shape[0]
    xtilde_updated = X[used_idx]


Utility: 0.2113   |  Best image ID: 707
f_mean:  5.1709 


##### The training set has been updated and so has the xtilde

In [11]:
print(f'Number of starting points {X[start_idx].shape[0]}, Number of to be used points    {X_used.shape[0]}')
print(f'Number of initial ntilde  {xtilde_start.shape[0]}, Number of to be updated ntilde {xtilde_updated.shape[0]}')

Number of starting points 50, Number of to be used points    51
Number of initial ntilde  50, Number of to be updated ntilde 51
