In [None]:
import numpy as np
import lasagne
import matplotlib.pyplot as plt
from fwrf.models import *
from fwrf.utils import *
from fwrf.utils import make_rf_table
from scipy.stats.stats import pearsonr

In [None]:
def decode_feature_maps(trn_fwrf_model, val_fwrf_model, trn_activity, val_activity, feature_map_dict, error_diff_thresh,
                        number_of_steps=100, learning_rate=1.0):
    '''
    decode_feature_maps(trn_fwrf_model, val_fwrf_model, trn_activity, val_activity,
                        feature_map_dict, error_diff_thresh, number_of_steps=100, learning_rate=1.0)

    decode feature maps by gradient descent on the *inputs* of an fwrf encoding model.

    the feature maps fed to trn_fwrf_model are treated as free variables and moved down the
    gradient of the training loss (squared prediction error summed over voxels and stimuli).
    after every step the validation loss is computed *per stimulus* with val_fwrf_model, and
    the best-so-far feature maps are kept separately for each stimulus.

    inputs:
        trn_fwrf_model, val_fwrf_model ~ objects exposing `pred_expr` (symbolic (T x V) prediction)
                                         and `input_var_dict` (feature map name -> symbolic input tensor).
                                         the val model must have an input for every name in the trn model.
                                         presumably the two models cover different voxels -- TODO confirm.
        trn_activity, val_activity     ~ (T x V) activity matrices for the trn/val model voxels, resp.
        feature_map_dict               ~ dict of initial feature maps with (at least) the keys of
                                         trn_fwrf_model.input_var_dict. stimulus axis must be first
                                         (assumed (T,D,S,S) -- TODO confirm). not modified.
        error_diff_thresh              ~ stop early once the total validation loss changes by less
                                         than this between two consecutive steps
        number_of_steps                ~ maximum number of gradient steps. default = 100
        learning_rate                  ~ size of gradient step. default = 1.0
    outputs:
        decoded_feature_maps ~ dict, name -> best feature map values found for each stimulus
        best_val_loss        ~ (T,) array, best validation loss found for each stimulus
        val_history          ~ list, number of stimuli whose validation loss improved at each step
        trn_history          ~ list, total training loss at each step (evaluated before that step's update)
    '''
    ##fix one ordering of the feature maps so symbolic inputs and numeric values always line up
    map_names = sorted(trn_fwrf_model.input_var_dict.keys())
    trn_input_vars = [trn_fwrf_model.input_var_dict[name] for name in map_names]
    val_input_vars = [val_fwrf_model.input_var_dict[name] for name in map_names]

    ##training loss: squared error summed over VOXELS (axis=1), then over stimuli.
    ##theano needs a scalar cost for gradients; summing over stimuli should be harmless as long as
    ##each stimulus's prediction depends only on its own feature maps (assumed -- TODO confirm).
    trn_activity_tensor = tnsr.matrix('trn_activity')
    trn_diff = trn_activity_tensor - trn_fwrf_model.pred_expr   ##difference tensor: (T x V)
    trn_loss_expr = (trn_diff*trn_diff).sum(axis=1).sum()       ##scalar
    grad_exprs = tnsr.grad(trn_loss_expr, wrt=trn_input_vars)   ##one gradient per feature map, same shape as the map

    ##one call returns the training loss followed by the gradient for each feature map
    trn_step_func = function([trn_activity_tensor] + trn_input_vars, [trn_loss_expr] + list(grad_exprs))

    ##validation loss: summed over VOXELS only, so the output has one entry per stimulus (T,)
    val_activity_tensor = tnsr.matrix('val_activity')
    val_diff = val_activity_tensor - val_fwrf_model.pred_expr   ##difference tensor: (T x V)
    val_loss_expr = (val_diff*val_diff).sum(axis=1)
    val_loss_func = function([val_activity_tensor] + val_input_vars, val_loss_expr)

    ##working copies of the initial maps (cast to the dtype the theano inputs expect) + best-so-far copies
    map_values = [np.array(feature_map_dict[name], dtype=var.dtype) for name, var in zip(map_names, trn_input_vars)]
    best_map_values = [np.copy(values) for values in map_values]

    val_loss_prev = np.asarray(val_loss_func(val_activity, *map_values), dtype=float)
    best_val_loss = np.copy(val_loss_prev)
    val_history = []
    trn_history = []

    for step in range(number_of_steps):
        ##gradient step on every feature map (arrays are updated in place)
        outputs = trn_step_func(trn_activity, *map_values)
        trn_history.append(float(outputs[0]))
        for values, grad in zip(map_values, outputs[1:]):
            values -= learning_rate*grad

        ##test against val set; keep the new maps only for stimuli whose val loss improved
        val_loss_is = np.asarray(val_loss_func(val_activity, *map_values), dtype=float)
        improved = val_loss_is < best_val_loss
        best_val_loss[improved] = val_loss_is[improved]
        for best, values in zip(best_map_values, map_values):
            best[improved] = values[improved]
        val_history.append(int(improved.sum()))

        ##stop once the total validation loss has (nearly) stopped changing
        err_diff = np.abs(val_loss_is.sum() - val_loss_prev.sum())
        val_loss_prev = val_loss_is
        if err_diff < error_diff_thresh:
            break

    return dict(zip(map_names, best_map_values)), best_val_loss, val_history, trn_history

In [None]:
class batch_feature_map_decoder(object):
    '''
    gradient descent on independent feature maps (the *inputs* of an encoding model), given
    encoding models and activity patterns.

    the feature maps are kept as numpy arrays with the stimulus axis first. training loss is the
    squared prediction error of trn_model on each batch from trn_data_generator. validation loss is
    computed *per stimulus* with val_model over all batches of val_data_generator, and the
    best-so-far feature maps are kept separately for each stimulus.
    '''

    def __init__(self,
                 init_feature_maps,
                 trn_model, val_model,
                 trn_model_input_tnsr_dict, val_model_input_tnsr_dict,
                 trn_data_generator, val_data_generator,
                 epochs = 1,
                 check_every=10,
                 num_iters=10,
                 learning_rate = 1.0,
                 learn_these_maps = None,
                 check_dims = True,
                 print_stuff=False):

        '''
    batch_feature_map_decoder(   init_feature_maps,
                                 trn_model, val_model,
                                 trn_model_input_tnsr_dict, val_model_input_tnsr_dict,
                                 trn_data_generator, val_data_generator,
                                 epochs = 1,
                                 check_every=10,
                                 num_iters=10,
                                 learning_rate = 1.0,
                                 learn_these_maps = None,
                                 check_dims = True,
                                 print_stuff=False)

        inputs:
                init_feature_maps   ~ list of feature_map_dicts (name -> array). length of list = num_stimuli.
                                      a single dict is treated as a list of length 1. each array either has a
                                      leading stimulus axis of length 1, or has one fewer dim than the model
                                      input tensor (the stimulus axis is then added). not modified.
                trn_model,val_model ~ lasagne models (output layers) predicting (num_stimuli, num_voxels).
                                      presumably the two models cover different voxels -- TODO confirm.
   trn_/val_model_input_tensor_dict ~ dicts of theano tensors that are inputs to trn/val_model, resp.
                                      (assumed to be the input_vars of the models' input layers)
                 trn_data_generator ~ callable; trn_data_generator() returns an iterable of (batch_in, batch_out).
                                      batch_out is a (num_stimuli, batch_size) voxel activity matrix (iteration is
                                      over voxels). batch_in is a dict supplying every model input that is NOT
                                      in init_feature_maps. you must write it.
                 val_data_generator ~ generator for validation data, same format
                             epochs ~ number of times through all batches in the generator
                        check_every ~ int. how many gradient steps until validation loss is checked. default = 10
                          num_iters ~ number of gradient steps per batch before transitioning to next batch. default = 10
                      learning_rate ~ size of gradient step. default = 1
                   learn_these_maps ~ list of names of feature maps to learn.
                                      default = None, meaning learn every feature map that is an input to trn_model
                         check_dims ~ if True, read one trn/val batch and sanity check its dimensions
                        print_stuff ~ if True, print progress at every validation check
        raises:
                ValueError if init_feature_maps is empty or no/unknown feature maps are selected for learning.
        outputs (of the learn method):
               decoded_feature_maps ~ dict, name -> decoded feature map values (stimulus axis first)
                     final_val_loss ~ (num_stimuli,) array, best validation loss for each stimulus
                        val_history ~ number of stimuli with decreased validation loss at each check
                        trn_history ~ training loss at each check
        '''

        ##record all the inputs
        if not isinstance(init_feature_maps, (list, tuple)):
            init_feature_maps = [init_feature_maps]
        if len(init_feature_maps) == 0:
            raise ValueError('init_feature_maps must contain at least one feature map dict')
        self.init_feature_maps = list(init_feature_maps)
        self.num_stimuli = len(self.init_feature_maps)
        self.trn_model = trn_model
        self.val_model = val_model
        self.learn_these_feature_maps = learn_these_maps
        self.trn_model_input_tnsr_dict = trn_model_input_tnsr_dict
        self.val_model_input_tnsr_dict = val_model_input_tnsr_dict
        self.trn_data_generator = trn_data_generator
        self.val_data_generator = val_data_generator
        self.epochs = epochs
        self.check_every = check_every
        self.num_iters = num_iters
        self.learning_rate = learning_rate
        self.print_stuff = print_stuff
        self.check_dims = check_dims

        ##inspect one batch of input to sanity check dimensions
        if self.check_dims:
            self.grab_example_batch()
            self.check_consistent()

        ##stack the initial feature maps and decide which ones to learn
        self.get_learned_params()

        ##check that every feature map has one entry per stimulus along its first axis
        self.find_batch_dimension()

        ##compile gradient and loss functions called iteratively by the learn method
        self.construct_training_kernel()

    def grab_example_batch( self ):
        '''read the first (batch_in, batch_out) pair from each of the trn/val generators.'''
        self.trn_in, self.trn_out = next(iter(self.trn_data_generator()))
        self.val_in, self.val_out = next(iter(self.val_data_generator()))

    def check_consistent( self ):
        '''assert that trn/val batches have one row per stimulus. assumes first dimension of output is stimulus/time.'''
        num_trn_images = self.trn_out.shape[0]
        num_val_images = self.val_out.shape[0]
        assert num_trn_images == num_val_images, "number of trn/val images don't match"
        assert num_trn_images == self.num_stimuli, "number stimuli in trn/val does not match specified number of stimuli"

    def get_learned_params( self ):
        '''
        stack the initial feature maps (stimulus axis first) and choose the maps to learn.
        sets self.feature_map_values (name -> array, updated in place during learning)
        and self.learned_names (list of names, in a fixed order).
        '''
        model_input_names = set(self.trn_model_input_tnsr_dict) | set(self.val_model_input_tnsr_dict)
        map_names = sorted(name for name in model_input_names if name in self.init_feature_maps[0])
        self.feature_map_values = dict((name, self._stack_feature_map(name)) for name in map_names)

        if self.learn_these_feature_maps is None:
            self.learned_names = [name for name in map_names if name in self.trn_model_input_tnsr_dict]
        else:
            self.learned_names = list(self.learn_these_feature_maps)
            unknown = [name for name in self.learned_names
                       if name not in self.feature_map_values or name not in self.trn_model_input_tnsr_dict]
            if unknown:
                raise ValueError('cannot learn %s: each must be a trn_model input with an initial value in init_feature_maps' % (unknown,))
        if not self.learned_names:
            raise ValueError('no feature maps to learn')
        print('will solve for: %s' % (self.learned_names,))

    def _stack_feature_map( self, name ):
        '''concatenate one feature map across stimuli along axis 0, cast to the input tensor's dtype.'''
        tnsr_var = self.trn_model_input_tnsr_dict.get(name)
        if tnsr_var is None:
            tnsr_var = self.val_model_input_tnsr_dict[name]
        per_stimulus = []
        for feature_map_dict in self.init_feature_maps:
            values = np.asarray(feature_map_dict[name], dtype=tnsr_var.dtype)
            ##map supplied without its stimulus axis: add it so concatenation stacks stimuli
            if values.ndim == tnsr_var.ndim - 1:
                values = values[np.newaxis]
            per_stimulus.append(values)
        ##np.concatenate always copies, so the caller's init_feature_maps are never modified
        return np.concatenate(per_stimulus, axis=0)

    def find_batch_dimension( self ):
        '''record the stimulus dimension (axis 0) of each feature map after checking its length.'''
        self.stimulus_dims = {}
        for name, values in self.feature_map_values.items():
            if values.shape[0] != self.num_stimuli:
                raise ValueError('feature map %s has %d entries along axis 0; expected one per stimulus (%d)'
                                 % (name, values.shape[0], self.num_stimuli))
            self.stimulus_dims[name] = 0

    ##construct gradient and loss functions to be called iteratively by the learn method
    def construct_training_kernel( self ):
        '''
        compile self.trn_kernel(trn_out, *trn inputs) -> [training loss, grad for each learned map]
        and self.loss(val_out, *val inputs) -> validation loss per stimulus, shape (num_stimuli,).
        model inputs are passed in sorted-name order (self.trn_input_names / self.val_input_names).
        '''
        self.trn_input_names = sorted(self.trn_model_input_tnsr_dict.keys())
        self.val_input_names = sorted(self.val_model_input_tnsr_dict.keys())
        trn_input_tnsrs = [self.trn_model_input_tnsr_dict[name] for name in self.trn_input_names]
        val_input_tnsrs = [self.val_model_input_tnsr_dict[name] for name in self.val_input_names]
        learned_tnsrs = [self.trn_model_input_tnsr_dict[name] for name in self.learned_names]

        ##training loss: squared error summed over voxels AND stimuli (gradients need a scalar cost).
        ##summing over stimuli should be harmless if each stimulus's prediction depends only on its own maps (assumed).
        trn_data_tnsr = tnsr.matrix('trn_data_tnsr')  ##voxel data tensor: (T x V)
        trn_diff = trn_data_tnsr - lasagne.layers.get_output(self.trn_model)
        trn_loss_expr = (trn_diff*trn_diff).sum()
        grad_exprs = tnsr.grad(trn_loss_expr, wrt=learned_tnsrs)
        print('will update wrt: %s' % (self.learned_names,))

        ##validation loss: summed over voxels only, so it has len = num_stimuli
        val_data_tnsr = tnsr.matrix('val_data_tnsr')  ##voxel data tensor: (T x V)
        val_diff = val_data_tnsr - lasagne.layers.get_output(self.val_model)
        val_loss_expr = (val_diff*val_diff).sum(axis=1)

        print('compiling...')
        self.trn_kernel = function([trn_data_tnsr] + trn_input_tnsrs, [trn_loss_expr] + list(grad_exprs))
        self.loss = function([val_data_tnsr] + val_input_tnsrs, val_loss_expr)

    def _model_inputs( self, names, batch_in ):
        '''numeric values for the named model inputs: current feature maps where known, else taken from batch_in.'''
        return [self.feature_map_values[name] if name in self.feature_map_values else batch_in[name]
                for name in names]

    def _gradient_step( self, trn_in, trn_out ):
        '''take one gradient step on every learned feature map (in place); return the pre-step training loss.'''
        outputs = self.trn_kernel(trn_out, *self._model_inputs(self.trn_input_names, trn_in))
        for name, grad in zip(self.learned_names, outputs[1:]):
            self.feature_map_values[name] -= self.learning_rate*grad
        return float(outputs[0])

    def _validation_loss( self ):
        '''per-stimulus validation loss summed over all validation batches (a fresh float array).'''
        total = None
        for val_in, val_out in self.val_data_generator():
            batch_loss = np.asarray(self.loss(val_out, *self._model_inputs(self.val_input_names, val_in)), dtype=float)
            total = batch_loss if total is None else total + batch_loss
        if total is None:
            raise ValueError('val_data_generator yielded no batches')
        return np.array(total, dtype=float)

    ##if the last gradient step made things better for some stimuli, record their feature maps
    def update_best_param_values(self, best_param_values, improved_voxels):
        '''
        copy current values of the learned maps into best_param_values (name -> array)
        for the stimuli selected by the boolean mask improved_voxels; returns best_param_values.
        '''
        for name in self.learned_names:
            dim = self.stimulus_dims[name]
            s = [slice(None)]*best_param_values[name].ndim  ##slice everything ...
            s[dim] = improved_voxels                          ##... except the stimulus dim
            s = tuple(s)                                      ##numpy requires a tuple, not a list, of indices
            best_param_values[name][s] = self.feature_map_values[name][s]
        return best_param_values

    ##iteratively perform gradient descent
    def learn(self):
        '''
        run gradient descent on the learned feature maps.

        returns:
            decoded_feature_maps ~ dict, name -> best feature map values found (stimulus axis first)
                  final_val_loss ~ (num_stimuli,) array, best validation loss for each stimulus
                     val_history ~ number of stimuli whose validation loss improved at each check
                     trn_history ~ training loss (before that step's update) at each check
        '''
        ##initialize best values to whatever they are
        best_param_values = dict((name, np.copy(self.feature_map_values[name])) for name in self.learned_names)

        ##initialize validation loss (per stimulus) to whatever you get from the initial maps
        val_loss_was = self._validation_loss()

        val_history = []
        trn_history = []

        ##descend and validate
        for epoch_count in range(self.epochs):
            print('=======epoch: %d' % (epoch_count,))
            for trn_in, trn_out in self.trn_data_generator():
                for step_count in range(self.num_iters):
                    trn_loss_is = self._gradient_step(trn_in, trn_out)
                    if step_count % self.check_every == 0:
                        ##check for improvements, stimulus by stimulus
                        val_loss_is = self._validation_loss()
                        improved = val_loss_is < val_loss_was
                        val_loss_was[improved] = val_loss_is[improved]
                        best_param_values = self.update_best_param_values(best_param_values, improved)

                        ##report on loss history
                        val_history.append(int(improved.sum()))
                        trn_history.append(trn_loss_is)
                        if self.print_stuff:
                            print('====iter: %d' % (step_count,))
                            print('number of improved models: %d' % (val_history[-1],))
                            print('trn error: %0.6f' % (trn_history[-1],))

        ##restore best values of learned feature maps
        for name, values in best_param_values.items():
            self.feature_map_values[name][...] = values

        decoded_feature_maps = dict((name, np.copy(values)) for name, values in self.feature_map_values.items())
        return decoded_feature_maps, val_loss_was, val_history, trn_history