In [1]:
import numpy as np
import lasagne
import matplotlib.pyplot as plt
from theano import tensor as tnsr
from theano import function
from theano import map as tmap
from theano import shared

Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN not available)


In [70]:
class batch_feature_map_decoder(object):
    
    def __init__(self,
                 init_feature_maps,
                 trn_model, val_model,
                 trn_model_input_tnsr_dict, val_model_input_tnsr_dict,
                 trn_data, val_data,
                 epochs = 1,
                 check_every=10,
                 num_iters=10,
                 learning_rate = 1.0,
                 learn_these_feature_maps = None,
                 is_movie = False,
                 check_dims = True,
                 print_stuff=False):

        '''
    batch_feature_map_decoder(   init_feature_maps,
                                 trn_model, val_model,
                                 trn_model_input_tnsr_dict, trn_model_input_tnsr_dict,
                                 trn_data, val_data,
                                 epochs = 1,
                                 check_every=10,
                                 num_iters=10,
                                 learning_rate = 1.0,
                                 learn_these_feature_maps = None,
                                 is_movie = False,
                                 check_dims = True,
                                 print_stuff=False)

        a class for gradient descent on independent feature maps, given encoding models and activity patterns.

        inputs:
                init_feature_maps   ~ dict map_name/tensor pairs. 
                                      tensors have shape (T,D,S,S), where T is temporal length of stimulus.
                                      for stacks of T static images, optimization is performed independently
                                      for each image.
                                      for movies of T frames, we assume we have model weights that have a temporal
                                      dimension, so optimization will be for entire movie (i.e., frames not indepen't.)
                trn_model,val_model ~ lasagne models
   trn_/val_model_input_tensor_dict ~ dicts of theano tensors that are inputs to trn/val_model, resp.
                           trn_data ~ tensor of voxel activity, shape=(num_stimuli, T, num_trn_voxels)
                           val_data ~ tensor of voxel activity, shape=(num_stimuli, T, num_val_voxels)
                        check_every ~ int. how many gradients steps until validation loss is checked. default = 10
                          num_iters ~ number of gradient steps per batch before tranistioning to next batch. default = 100
                      learning_rate ~ size of gradient step. default = 1
           learn_these_feature_maps ~ list of names of feature maps to learn. 
                                      default = None, meaning learn all feature maps
                           is_movie ~ encoding model includes temporal kernel, so that output at any one point
                                      depends on many of the T frames, set this to True. Otherwise, will treat each
                                      of the T time-points as independent frames and optimize them separately.
                         check_dims ~ if True, run potentially slow sanity check on dimensions of inputs
        outputs:
        decoded_feature_maps ~ original l_model with params all trained up. this is a convenience, as params are learned in-place
              final_val_loss ~ the final validation loss for each of the models
                 trn_history ~ array showing training err history
                 val_history ~ history showing number of voxels with decreased validat'n loss at each iteration

        '''
        
        ##record all the inputs
        self.init_feature_maps = init_feature_maps
        self.trn_model = trn_model
        self.val_model = val_model
        self.learn_these_feature_maps = learn_these_feature_maps
        self.is_movie = is_movie
        self.trn_model_input_tnsr_dict = trn_model_input_tnsr_dict
        self.val_model_input_tnsr_dict = val_model_input_tnsr_dict
        self.trn_data = trn_data
        self.val_data = val_data
        self.epochs = epochs
        self.check_every=check_every
        self.num_iters = num_iters
        self.learning_rate = learning_rate
        self.print_stuff = print_stuff
        self.check_dims = check_dims

        ##inspect one batch of input to sanity check dimensions
        if self.check_dims:
            ##check data dimensions
            self.check_consistent()
    
        ##get learned feature maps: stores as a list of shared variables just like output of lasage.layers.get_all_params
        self.get_learned_maps()
        
        ##construct a gradient update and loss functions to be called iteratively by the learn method
        self.decoding_kernel = self.construct_decoding_kernel()


    def check_consistent( self ):
        ##read first batches to get some dimensions. assumes first dimension of input is time, last dimension is voxel 
        num_trn_images,trn_batch_size = self.trn_data.shape[0],self.trn_data.shape[-1]
        num_val_images,val_batch_size = self.val_data.shape[0],self.val_data.shape[-1]   
        assert num_trn_images == num_val_images, "number of trn/val images don't match"

        
        

    def get_learned_maps( self ):
        '''
        NOTE: Should probably establish learned_maps as shared variables?
        return list of theano variables
        '''
        if self.learn_these_feature_maps is not None:            
            self.learned_maps = [v for v in self.trn_model_input_tnsr_dict.values() if v.name in self.learn_these_feature_maps] ##<<unpack the dict.
        else:
            self.learned_maps = [v for v in self.trn_model_input_tnsr_dict.values()] ##<<unpack the dict.
        print 'will solve for: %s' %(self.learned_maps)


                
    ##construct a gradient update and loss functions to be called iteratively by the learn method
    def construct_decoding_kernel( self ):
        
        if self.is_movie:
            raise NotImplementedError('movies are not implemented yet.')
        
        ##===express training error
        trn_activity_tensor = tnsr.matrix('trn_activity')
        
        trn_pred_expr = lasagne.layers.get_output(self.trn_model)

        trn_diff = trn_activity_tensor-trn_pred_expr  ##difference tensor: (T x V)
        
        ##first sum over voxels, then over time-points
        trn_loss_expr = (trn_diff*trn_diff).sum(axis=1).sum()

        ##Training loss is scalar because of sum() on the trn_loss_expr: this is a sum over T dimension
        self.trn_loss_func = function([trn_activity_tensor]+self.trn_model_input_tnsr_dict.values(), trn_loss_expr)

        ##===express validation error
        val_activity_tensor = tnsr.matrix('val_activity')
        val_pred_expr = lasagne.layers.get_output(self.val_model)
        val_diff = val_activity_tensor-val_pred_expr  ##difference tensor: (T x V)
        val_loss_expr = (val_diff*val_diff).sum(axis=1) ##sum-sqaured-diffs tensor: SUM OVER VOXELS

        ##Validation loss has T distinct outputs, not scalar.
        ##self.loss(activity_matrix, fmap0, ... , fmapN)
        self.loss = function([val_activity_tensor]+self.val_model_input_tnsr_dict.values(), val_loss_expr)

        ##===build gradient w.r.t. input vars
        grad_expr = tnsr.grad(trn_loss_expr, wrt=self.learned_maps)

        ##a list of feature map gradients.
        ##each gradient in the list should be like T,D,S,S tensor4
        print 'compiling...'
        grad_func = function([trn_activity_tensor]+self.trn_model_input_tnsr_dict.values(), grad_expr)
        
        ##closure
        def decoding_kernel():
            args = [self.trn_data]+self.init_feature_maps.values()
            map_grads = grad_func(*args)
            ##update each feature map using names of theano variables ininput_var_dict
            ##note: we can enumerate because dict.values() and dict.keys() has same order...I think...
            ##note: probably a way we can vectorize this.
            for ii,k in enumerate(self.learned_maps.keys()):
                ##We assume (hope!) that 
                self.init_feature_maps[k] -= learning_rate*map_grads[ii]
            return self.trn_loss_func(self.trn_data,*self.init_feature_maps.values())
        
        return decoding_kernel
        
        
    ##if the last gradient step made things better for some the time-points, update the feature maps for that time-point
    def update_best_time_points(self, best_fmap_dict, improved_time_points):
        
        ##this implicitly checks for consistency between keys of learned maps, init_feature_maps, and best_fmap_dict
        for k in self.learned_maps.keys():
            ##this indexing should work because time points is always the first dimension
            best_fmap_dict[k][improved_time_points] = np.copy(self.init_feature_maps[k][improved_time_points])
        return best_fmap_dict        
    
    ##iteratively perform gradient descent
    def learn(self):
        
        ##initialize best parameters to whatever they are
        best_fmap_dict = {k:np.copy(v) for k,v in self.init_feature_maps.iteritem()}
        
        ##initalize validation loss to whatever you get from initial fmaps
        val_loss_was = self.loss(val_data, *best_fmap_dict.values())

        ##initialize train loss to whatever, we only report the difference    
        trn_loss_was = 0.0 ##we keep track of total across voxels as sanity check, since it *must* decrease
  
        val_history = []
        trn_history = []
        
        ##descend and validate
        step_count = 0
        while step_count < self.num_iters:
            
            ##update fmaps, output training loss
            trn_loss_is = self.decoding_kernel(trn_data, self.init_feature_maps.values())                    
            if step_count % self.check_every == 0:

                ##check for improvements
                val_loss_is = 0
                val_loss_is = self.loss(val_data, *self.init_feature_maps.values())
                improved = (val_loss_is < val_loss_was) ##improved ~ (T,)

                ##update val loss
                val_loss_was[improved] = val_loss_is[improved]

                ##update best feature maps
                best_fmap_values = self.update_best_time_points(best_fmap_dict, improved)

                ##report on loss history
                val_history.append(improved.sum())
                trn_history.append(trn_loss_is)
                if self.print_stuff:
                    print '====iter: %d' %(step_count)
                    print 'number of improved models: %d' %(val_history[-1])
                    print 'trn error: %0.6f' %(trn_history[-1])

            step_count += 1       

        return best_fmap_dict, val_loss_was, val_history, trn_history

In [71]:
##construct models
T,D,S,Vtrn,Vval = 11,17,12,44,22

input_tnsr = tnsr.tensor4('input_tnsr')
input_layer = lasagne.layers.InputLayer((T,D,S,S),input_var=input_tnsr)

true_feature_map = {'fmap':np.random.random(size=(T,D,S,S)).astype('float32')}
init_feature_map = {'fmap':np.random.random(size=(T,D,S,S)).astype('float32')}

trn_model = lasagne.layers.DenseLayer(input_layer,Vtrn)
val_model = lasagne.layers.DenseLayer(input_layer,Vval)

true_trn_activity = function([input_tnsr],lasagne.layers.get_output(trn_model))(true_feature_map['fmap'])
true_val_activity = function([input_tnsr],lasagne.layers.get_output(val_model))(true_feature_map['fmap'])

trn_model_input_tnsr_dict = {'fmap':input_tnsr}
val_model_input_tnsr_dict = {'fmap':input_tnsr}
trn_data = np.random.random(size=(T,Vtrn)).astype('float32')
val_data = np.random.random(size=(T,Vval)).astype('float32')
epochs = 1,
check_every=10,
num_iters=10,
learning_rate = 1.0,
learn_these_feature_maps = None,
is_movie = False,
check_dims = True,
print_stuff=False

decoder = batch_feature_map_decoder(init_feature_map,trn_model,val_model,trn_model_input_tnsr_dict,val_model_input_tnsr_dict,trn_data,val_data)

will solve for: [input_tnsr]
compiling...


In [72]:
##test val loss
L = decoder.loss(true_val_activity,*true_feature_map.values())
assert L.shape[0]==T, 'wrong bitch, shape is %d, not %d' %(L.shape[0], T)

##test trn loss
assert decoder.trn_loss_func(true_trn_activity,*true_feature_map.values())==0, 'wrong bitch!'
assert decoder.trn_loss_func(true_trn_activity,*init_feature_map.values())!=0, 'wrong bitch!'

In [73]:
##test decoding kernel
print decoder.decoding_kernel()

AttributeError: 'list' object has no attribute 'keys'

int