In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import init
import torch.autograd as autograd

import warnings
from torch.nn.modules.utils import _single, _pair
import math
import copy
# Experimental end-to-end optimization of the backward path with MNIST/CIFAR10 and nn.Linear.
# When calling loss.backward(), both grad_weight and grad_weight_feedback ('YY' only) are computed in the customized modules.
# First, run a control experiment against BP and the regular (built-in) modules.
# Linear adapted from: https://pytorch.org/docs/stable/notes/extending.html
# Conv adapted from: https://github.com/pytorch/pytorch/blob/master/torch/nn/grad.py
# autograd: https://pytorch.org/docs/stable/autograd.html#torch.autograd.backward

import torch.autograd as autograd
class ReLUGrad(nn.Module):
    """Apply the ReLU derivative to an incoming gradient.

    ``forward(grad_output, input)`` returns a new tensor equal to
    ``grad_output``, with every entry zeroed where ``input`` is strictly
    negative. ``grad_output`` itself is not modified.
    """

    def forward(self, grad_output, input):
        # Work on a copy so the caller's gradient tensor is left untouched.
        masked = grad_output.clone()
        masked[input.lt(0)] = 0
        return masked

class LinearFunction2(autograd.Function):

    """
    Autograd function for a linear layer with asymmetric feedforward and feedback pathways.

    forward  : output = input @ weight.T (+ bias, if given)
    backward : the gradient w.r.t. ``input`` is routed through ``weight_feedback``
               instead of ``weight``; the gradients for ``weight`` (and, for 'YY',
               ``weight_feedback``) are computed from ``input2`` rather than ``input``.

    ``algorithm`` is a plain string; only 'YY' returns a gradient for
    ``weight_feedback``. ``input2`` never receives a gradient.
    """

    @staticmethod
    # same as reference linear function, but with additional tensors used only in backward
    def forward(context, input, input2, weight, weight_feedback, bias=None, algorithm='BP'):
        context.save_for_backward(input, input2, weight, weight_feedback, bias)
        # Non-tensor state cannot go through save_for_backward; keep it on the context.
        context.algorithm = algorithm
        output = input.mm(weight.t())

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(context, grad_output):
        # saved_tensors holds exactly the five tensors given to save_for_backward;
        # the algorithm name is read from context.algorithm instead.
        input, input2, weight, weight_feedback, bias = context.saved_tensors
        grad_input = grad_weight = grad_weight_feedback = grad_bias = None

        if context.needs_input_grad[0]:
            # all of the logic of FA resides in this one line:
            # propagate through the feedback matrix rather than the forward weight
            grad_input = grad_output.mm(weight_feedback)
        if context.needs_input_grad[2]:
            # weight gradient uses the second input instead of the forward input
            grad_weight = grad_output.t().mm(input2)
        if context.algorithm == 'YY' and context.needs_input_grad[3]:
            # only YY trains the backward weights; same (sorta) Hebbian form as above
            grad_weight_feedback = grad_output.t().mm(input2)
        if bias is not None and context.needs_input_grad[4]:
            # no squeeze: keeps shape (out_features,) even when out_features == 1
            grad_bias = grad_output.sum(0)

        # One entry per forward() input:
        # input, input2, weight, weight_feedback, bias, algorithm
        return grad_input, None, grad_weight, grad_weight_feedback, grad_bias, None

class Linear2(nn.Module):

    """
    A linear layer with asymmetric feedforward and feedback pathways.

    forward  : weight
    backward : weight_feedback (used by LinearFunction2 to propagate the
               gradient w.r.t. the layer input)

    Args:
        input_features: size of each input sample.
        output_features: size of each output sample.
        bias: if truthy, a learnable bias of shape (output_features,) is created.
        algorithm: feedback algorithm, one of 'BP', 'FA', 'YY'.
            'BP' - weight_feedback is overwritten with a copy of weight at
                   construction and on every forward call.
            'FA' - weight_feedback is a fixed (requires_grad=False) parameter.
            'YY' - weight_feedback is a trainable parameter.
    """

    def __init__(self, input_features, output_features, bias, algorithm):
        super(Linear2, self).__init__()
        implemented_algorithms = ['BP', 'FA', 'YY']
        assert algorithm in implemented_algorithms, \
            'feedback algorithm %s is not implemented' % algorithm

        self.algorithm = algorithm

        # Forward weight is stored as (out, in), like nn.Linear, and transposed in forward.
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        # As in torch/nn/modules/linear, init scaling is based on weight.size(1).
        # Since weight_feedback plays the role of the transpose, an alternative
        # scaling would be 1. / math.sqrt(self.weight.size(0)) (currently unused).
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)

        # Feedback matrix shares weight's (out, in) layout; only 'YY' trains it.
        self.weight_feedback = nn.Parameter(
            torch.empty(output_features, input_features),
            requires_grad=(self.algorithm == 'YY'),
        )

        self.reset_parameters()
        if self.algorithm == 'BP':
            self._sync_feedback_with_weight()

    def reset_parameters(self):
        """Kaiming-uniform init for both weight matrices; uniform init for the bias."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        init.kaiming_uniform_(self.weight_feedback, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def _sync_feedback_with_weight(self):
        """Replace weight_feedback's data with an independent copy of weight ('BP')."""
        # Reassigning .data (rather than an in-place copy_) presumably avoids
        # bumping the autograd version counter of weight_feedback, which earlier
        # forward calls in the same graph have saved for backward.
        self.weight_feedback.data = self.weight.detach().clone()

    def forward(self, input, input2):
        """Compute input @ weight.T (+ bias).

        ``input2`` is forwarded to LinearFunction2, which uses it in place of
        ``input`` when computing the weight gradients.
        """
        if self.algorithm == 'BP':
            # Backprop: the feedback path must mirror the current forward weights.
            self._sync_feedback_with_weight()

        return LinearFunction2.apply(input, input2, self.weight, self.weight_feedback, self.bias, self.algorithm)
        
        
