### Core algorithm code and functions, presented for understanding the algorithm
Note that this section of code is for demonstration only. Some parts of the code, such as data preprocessing, network creation, and result logging, are omitted because they are not necessary for understanding our algorithm.

### Training Functions :

1. Function : [train_step_mtl] is the function that updates the weights for one mini-batch
        
        images->[batch_size x image_w x image_h x 6]
        labels->[batch_size, trans(3), rot(3), flow[flow_w x flow_h x 2]]
        --> input : images ,labels
        --> Calculate Gradients for main and auxiliary task.
        --> Applies clipped gradients to the weights of the model 
2. Function : [train_step_main] is the training function for main task
        
        --> input : images, labels [batch_size, trans(3)] or [batch_size, rot(3)]
        --> calculates rotation and translation task gradients from corresponding loss
        
3. Function : [train_step_aux] is the auxiliary task training function (The PRF method)
        
        --> input : 
                  -> images, labels [batch_size, flow[flow_w x flow_h x 2]]
                  -> Gradient variance for main task and auxiliary task
                  -> Mean gradients for main task
        --> Calculates optical flow task gradients from corresponding loss.
        --> Returns probability ratio x Auxiliary task gradients
        


In [None]:
@tf.function
def train_step_main(_images, _labels,_task):
    """Compute the weighted gradients of one main task for a single mini-batch.

    Args:
        _images: Batch of inputs fed to ``model``.
        _labels: Per-task labels, indexed in the key order of
            ``loss_funcs_dict``.
        _task: Name of the task; must be a key of ``loss_funcs_dict`` and
            ``loss_weights_dict``.

    Returns:
        A one-element list holding the ``__dict__`` of a fresh
        ``train_params_tracker`` filled with the task loss, the weighted
        training loss, the raw gradients and the weighted gradients.
    """
    _mtp = train_params_tracker(_task)
    _loss_funcs = list(loss_funcs_dict.keys())
    _n = _loss_funcs.index(_task)
    loss_name = _loss_funcs[_n]
    loss_func = loss_funcs_dict[loss_name]
    loss_weight = loss_weights_dict[loss_name]

    # Only the forward pass and the loss need to be recorded by the tape.
    with tf.GradientTape() as tape:
        _predictions = model(_images, training=True)
        _loss_out = loss_func(_labels[_n], _predictions[_n])
        # The loss functions in this file return (loss, abs_loss); accept a
        # bare tensor as well so either convention works.
        if isinstance(_loss_out, (tuple, list)):
            _batch_loss = _loss_out[0]
        else:
            _batch_loss = _loss_out
        _loss = tf.reduce_mean(_batch_loss)

    # Taking the gradient after leaving the context keeps the backward ops
    # from being recorded on the tape.
    _gradients = tape.gradient(_loss, model.trainable_variables)
    del tape

    # Log necessary variables
    _mtp._loss_dict[loss_name + '_loss'] = _loss
    _mtp._loss_dict['train_loss'] = loss_weight * _loss
    _mtp._grad_dict[loss_name + '_grad'] = _gradients

    # Variables that do not influence this task have a None gradient and are
    # left as None in total_gradients.
    for i, _grad in enumerate(_gradients):
        if _grad is not None:
            _mtp.total_gradients[i] = tf.multiply(loss_weight, _grad)

    return [_mtp.__dict__]

@tf.function
def train_step_aux(_images, _labels,_task,_mtp,_grad_var_dict):
    """Compute the (re-weighted) gradients of the auxiliary task for one mini-batch.

    For variables that receive a gradient from both the main task and the
    auxiliary task, the auxiliary gradient is scaled according to the global
    experiment mode ``exp``:

    * ``'ATG'``: ``beta * g_aux``
    * ``'PRF'``: ``beta * ratio * g_aux`` with
      ``ratio = sqrt(var_aux / var_main)
      * exp(-0.5 * (g_main - g_aux)**2 / var_main)``

    For any other value of ``exp`` those entries stay ``None``. Variables that
    only receive an auxiliary gradient always get ``beta * g_aux``.

    Args:
        _images: Batch of inputs fed to ``model``.
        _labels: Per-task labels, indexed in the key order of
            ``loss_funcs_dict``.
        _task: Name of the auxiliary task (key of ``loss_funcs_dict``).
        _mtp: Tracker dict of the main task; its ``'_grad_dict'`` entry must
            contain ``'<main_task>_grad'``.
        _grad_var_dict: Maps task name to a per-variable list of gradient
            variances. Only read when ``exp == 'PRF'``.

    Returns:
        A one-element list holding the ``__dict__`` of a fresh
        ``train_params_tracker`` for the auxiliary task.
    """
    _atp = train_params_tracker(_task)
    _loss_funcs = list(loss_funcs_dict.keys())
    _main_task = _loss_funcs[_mtl_main_task_index]
    _n = _loss_funcs.index(_task)

    # Extract main task gradients
    _main_task_gradients = _mtp['_grad_dict'][_main_task + '_grad']

    # Extract main task and auxiliary task variances (needed by PRF only)
    _use_prf = exp == 'PRF'
    _main_task_grad_var = _grad_var_dict[_main_task] if _use_prf else None
    _aux_task_grad_var = _grad_var_dict[_task] if _use_prf else None

    # Get loss function and loss weight beta
    loss_name = _loss_funcs[_n]
    loss_func = loss_funcs_dict[loss_name]
    loss_weight = loss_weights_dict[loss_name]

    # Calculate auxiliary task gradients
    with tf.GradientTape() as tape:
        _predictions = model(_images, training=True)
        _batch_loss, _ = loss_func(_labels[_n], _predictions[_n])
        _loss_b = tf.reduce_mean(_batch_loss)

    _aux_task_gradients = tape.gradient(_loss_b, model.trainable_variables)
    del tape

    # Log loss and gradients
    _atp._loss_dict[loss_name + '_loss'] = _loss_b
    _atp._loss_dict['train_loss'] = loss_weight * _loss_b
    _atp._grad_dict[loss_name + '_grad'] = _aux_task_gradients

    _beta_aux = loss_weight
    _min_std = 10e-20  # added to the variances to avoid division by zero

    for k, _aux_grad in enumerate(_aux_task_gradients):
        if _aux_grad is None:
            # The auxiliary task does not touch this variable.
            continue

        _main_grad = _main_task_gradients[k]
        if _main_grad is None:
            # Update task specific gradients
            _atp.total_gradients[k] = _beta_aux * _aux_grad
        elif exp == 'ATG':
            # Vanilla auxiliary task guidance: plain loss-weighted gradient.
            _atp.total_gradients[k] = tf.multiply(_beta_aux, _aux_grad)
        elif _use_prf:
            _var_main_task = _main_task_grad_var[k] + _min_std
            _var_aux_task = _aux_task_grad_var[k] + _min_std

            _conf_ratio = tf.sqrt(_var_aux_task / _var_main_task)
            _task_similarity = tf.square(_main_grad - _aux_grad) / _var_main_task
            _task_similarity = tf.exp(-0.5 * _task_similarity)
            _prob_ratio = tf.multiply(_conf_ratio, _task_similarity)
            _atp.total_gradients[k] = tf.multiply(_beta_aux * _prob_ratio, _aux_grad)

    return [_atp.__dict__]
    
#@tf.function
def train_step_mtl(_images,_labels,_grad_var_dict=None):
    """Run one multi-task update on a mini-batch and return the logged losses.

    Gradients are computed for the 'trans' and 'rot' tasks and for the
    'flow_4' auxiliary task, summed per variable, clipped and applied through
    ``apply_clipped_grads``.

    Args:
        _images: Batch of inputs.
        _labels: Per-task labels for the batch.
        _grad_var_dict: Per-task gradient variances forwarded to
            ``train_step_aux``. ``None`` is treated as an empty dict.

    Returns:
        The accumulated loss dictionary of the combined tracker.
    """
    if _grad_var_dict is None:
        _grad_var_dict = {}

    # Order matters: _mtl_main_task_index is used below to pick the main
    # task's entry out of this list.
    _train_params = []
    _train_params += train_step_main(_images, _labels, 'trans')
    _train_params += train_step_main(_images, _labels, 'rot')
    _train_params += train_step_aux(
        _images,
        _labels,
        'flow_4',
        _train_params[_mtl_main_task_index],
        _grad_var_dict,
    )

    _updated_train_params = train_params_tracker()
    _updated_train_params.update_all(_train_params)
    apply_clipped_grads(_updated_train_params.total_gradients)

    return _updated_train_params._loss_dict

### Utility functions

1. Function : [get_grad_var] calculates the variance of the gradients. (The gradient variance is calculated here for PRF.)
        
        --> takes input samples and labels
        --> Calculates batch of gradients and corresponding variances

2. Function : [apply_clipped_grads]  for applying clipped gradients

3. Function : [train_params_tracker] a utility function for keeping track of gradients

In [None]:
@tf.function
def get_grad_var(_images, _labels,_task,_n_shared=36):
    """Estimate the per-element gradient variance of one task over a mini-batch.

    The variance is the mean, over the ``batch_size`` samples, of the squared
    difference between the gradient of the mean batch loss and the gradient
    of each individual sample loss.

    Args:
        _images: Batch of inputs fed to ``model``.
        _labels: Per-task labels, indexed in the key order of
            ``loss_funcs_dict``.
        _task: Name of the task (key of ``loss_funcs_dict``).
        _n_shared: Number of leading entries of ``model.trainable_variables``
            to evaluate. The default of 36 is the total number of shared
            layers (including BatchNorm and Leaky ReLU).

    Returns:
        A list of length ``_n_shared`` with one variance tensor per variable;
        an entry is ``None`` when no gradient is available for that variable.
    """
    _loss_funcs = list(loss_funcs_dict.keys())
    _n = _loss_funcs.index(_task)
    loss_func = loss_funcs_dict[_loss_funcs[_n]]
    _shared_vars = model.trainable_variables[:_n_shared]

    # The tape is persistent because it is queried once for the mean loss and
    # once per sample.
    with tf.GradientTape(persistent=True) as tape:
        _predictions = model(_images, training=True)
        _loss_out = loss_func(_labels[_n], _predictions[_n])
        # The loss functions in this file return (loss, abs_loss); accept a
        # bare tensor as well so either convention works.
        if isinstance(_loss_out, (tuple, list)):
            _batch_loss = _loss_out[0]
        else:
            _batch_loss = _loss_out
        _loss_b = tf.reduce_mean(_batch_loss)
        # The slicing has to be recorded by the tape so that each sample loss
        # stays connected to the variables.
        _sample_losses = [_batch_loss[k] for k in range(batch_size)]

    # Gradients are taken outside the context so the backward passes are not
    # themselves recorded on the persistent tape.
    _mean_grads = tape.gradient(_loss_b, _shared_vars)
    _sample_grads_list = [
        tape.gradient(_sample_loss, _shared_vars) for _sample_loss in _sample_losses
    ]
    del tape

    _grad_variance = [None] * _n_shared
    _scale = 1 / batch_size
    for _sample_grads in _sample_grads_list:
        for k, (_g1, _g2) in enumerate(zip(_mean_grads, _sample_grads)):
            if _g1 is None or _g2 is None:
                # Variable not connected to this task's loss.
                continue
            _sq_dev = _scale * tf.square(_g1 - _g2)
            if _grad_variance[k] is None:
                _grad_variance[k] = _sq_dev
            else:
                _grad_variance[k] = _grad_variance[k] + _sq_dev

    return _grad_variance
            


def apply_clipped_grads(total_gradients):
    """Clip selected gradients by value and apply all gradients to the model.

    When the global ``_grad_clip`` is non-zero, every layer named in
    ``grad_clip_layers`` has its gradient (located through
    ``layer_name_dict``) clipped to ``[-clip, +clip]``. The list
    ``total_gradients`` is modified in place. All gradients are then handed
    to ``curr_optimizer`` together with ``model.trainable_variables``.

    Returns:
        Always 1.
    """
    clipping_enabled = _grad_clip != 0
    if clipping_enabled:
        for _layer, _limit in grad_clip_layers.items():
            _pos = layer_name_dict[_layer]
            _clipped = tf.clip_by_value(total_gradients[_pos], -_limit, _limit)
            total_gradients[_pos] = _clipped

    _grads_and_vars = zip(total_gradients, model.trainable_variables)
    curr_optimizer.apply_gradients(_grads_and_vars)
    return 1

class train_params_tracker():
    """Accumulates gradients, losses and auxiliary statistics across tasks.

    Attributes are plain containers so that an instance's ``__dict__`` can be
    passed around and merged into another tracker with ``update``.
    """

    def __init__(self,_task=None):
        """Create an empty tracker.

        Args:
            _task: Accepted for call-site symmetry; not used by the tracker.
        """
        # One slot per trainable variable; None means "no gradient yet".
        self.total_gradients = [None] * len(model.trainable_variables)
        self._grad_dict = {}
        self._std_dict = {}
        self._loss_dict = {'train_loss':0.0,'rot_loss':0.0,'trans_loss':0.0,'flow_4_loss':0.0,'disp_4_loss':0.0}

    def update(self,obj):
        """Merge another tracker's ``__dict__`` into this tracker.

        Gradients are summed per variable, with a missing (``None``) gradient
        on either side counted as 0.0. Losses are summed per key, and the
        gradient and std dictionaries are merged key-wise.

        Args:
            obj: Mapping with the keys ``'total_gradients'``,
                ``'_loss_dict'``, ``'_grad_dict'`` and ``'_std_dict'``.

        Raises:
            KeyError: If ``obj['_loss_dict']`` lacks a key tracked here.
        """
        for i, (gr1, gr2) in enumerate(zip(self.total_gradients, obj['total_gradients'])):
            # Identity checks: `== None` on a tensor or array goes through
            # the overloaded __eq__ and may be element-wise.
            if gr1 is None:
                gr1 = 0.0
            if gr2 is None:
                gr2 = 0.0
            self.total_gradients[i] = gr1 + gr2
        for key in self._loss_dict:
            self._loss_dict[key] += obj['_loss_dict'][key]
        self._grad_dict.update(obj['_grad_dict'])
        self._std_dict.update(obj['_std_dict'])

    def update_all(self,list_obj):
        """Merge every tracker dict in ``list_obj`` into this tracker, in order."""
        for obj in list_obj:
            self.update(obj)

### Loss Functions

In [None]:
def flow_loss(y_true, y_pred):
    """Per-sample optical-flow loss computed in log space.

    Both inputs are cast to float32, clipped to [0, 1] and mapped through
    ``log(x + 1)``. The reductions run over axes 1, 2 and 3 (assumes rank-4
    inputs with the batch on axis 0 - TODO confirm).

    Returns:
        A tuple ``(loss, abs_loss)``: the root of the mean squared log-space
        difference, and the mean absolute log-space difference.
    """
    _add_offset = 1

    def _to_log_space(_flow):
        # Cast, clip to the unit range, then shift so that log(0) is avoided.
        _unit = tf.clip_by_value(tf.cast(_flow, tf.float32), 0.0, 1.0)
        return tf.math.log(_unit + _add_offset)

    log_true = _to_log_space(y_true)
    log_pred = _to_log_space(y_pred)
    diff = tf.subtract(log_true, log_pred)

    mean_sq = tf.reduce_mean(tf.square(diff), axis=[1,2,3])
    abs_loss = tf.reduce_mean(tf.abs(diff), axis=[1,2,3])
    loss = tf.sqrt(mean_sq)
    return loss, abs_loss


def trans_loss(y_true,y_pred):
    """Translation loss: mean absolute error over the last axis.

    Returns:
        A tuple ``(loss, abs_loss)``; both entries hold the same values.
    """
    abs_err = tf.math.abs(tf.math.subtract(y_true, y_pred))
    per_sample = tf.math.reduce_mean(abs_err, axis=-1)
    return per_sample, per_sample

def rot_loss(y_true,y_pred):
    """Rotation loss: mean absolute error over the last axis.

    Returns:
        A tuple ``(loss, abs_loss)``; both entries hold the same values.
    """
    abs_err = tf.math.abs(tf.math.subtract(y_true, y_pred))
    per_sample = tf.math.reduce_mean(abs_err, axis=-1)
    return per_sample, per_sample

# Position of the main task in the key order of loss_funcs_dict.
_mtl_main_task_index = 1  # rotation

# Task name -> loss function. The key order defines the task indices used to
# pick labels and predictions in the training functions.
loss_funcs_dict = {
    'trans': trans_loss,
    'rot': rot_loss,
    'flow_4': flow_loss,
}

# Task name -> loss weight; keys must match loss_funcs_dict.
loss_weights_dict = {
    'trans': 1.0,
    'rot': 10.0,
    'flow_4': 0.1,
}

### Training (bare minimum code is presented here for demonstration)

In [None]:
# Per-task gradient variances read by the PRF weighting; both entries are
# refreshed on every training step.
_grad_var_dict = {}

for epoch in epochs:

    # Training: estimate the gradient variances on the batch, then update.
    for _ in range(no_of_train_batches):
        _images, _labels = get_next_batch(train_data)
        _grad_var_dict['rot'] = get_grad_var(_images, _labels, 'rot')  # main task
        _grad_var_dict['flow_4'] = get_grad_var(_images, _labels, 'flow_4')  # auxiliary task
        train_loss = train_step_mtl(_images, _labels, _grad_var_dict=_grad_var_dict)

    # Validation
    for _ in range(no_of_val_batches):
        _images, _labels = get_next_batch(val_data)
        val_loss = test_step(_images, _labels)
    