In [1]:
import numpy as np
import torch

import matplotlib.pyplot as plt
%matplotlib inline

In [364]:
import torch
import warnings

class EKFAC(torch.optim.Optimizer):
    """ Implements the Eigenvalue-corrected Kronecker-factored Optimized Curvature preconditioner 
    
    See details at https://arxiv.org/pdf/1806.03884.pdf"""
    
    def __init__(self, 
                 network, 
                 recompute_KFAC_steps=1,
                 epsilon=0.1):
        """
        Arguments:
            network - the network to operate on
            recompute_KFAC_steps (integer) - the number of steps between successive recomputations of the 
                                             Kronecker factors of the layer-wise Fisher matrix
            epsilon (float) - the damping parameter used to avoid infinities"""
        
        self.epsilon = epsilon
        
        self.params_by_layer = []
        
        self.modules_with_weights = [torch.nn.Bilinear, 
                        torch.nn.Conv1d,
                        torch.nn.Conv2d,
                        torch.nn.Conv3d,
                        torch.nn.ConvTranspose1d,
                        torch.nn.ConvTranspose2d,
                        torch.nn.ConvTranspose3d,
                        torch.nn.Linear,                        
                       ]
        
        self.stored_items = {}
        
        # need to keep track of iteration because we only recompute KFAC matrices every 'self.recompute_KFAC_steps' steps 
        self.iteration_number = 0 
        self.recompute_KFAC_steps = recompute_KFAC_steps
        
        tracked_modules_count = 0
        for layer in network.modules():
            if type(layer) in self.modules_with_weights: 
                if type(layer) != torch.nn.Linear:
                    warnings.warn('Have not tested this for any module type other than linear')
                    
                # add functions to the module such that for all layers with weights
                layer.register_forward_pre_hook(self.store_input)
                layer.register_backward_hook(self.store_grad_output)
                
                # add parameters to the list, grouped by layer 
                self.params_by_layer.append({'params': [layer.weight]})
                if layer.bias is not None:
                    self.params_by_layer[-1]['params'].append(layer.bias)
        
                # make a label for the module and add it to the keys of the stored_items dictionary
                tracked_modules_count += 1
                self.stored_items[layer] = {} 
                
        default_options = {}
        super(EKFAC, self).__init__(self.params_by_layer, default_options)
       
    def step(self):
        
        if self.iteration_number % self.recompute_KFAC_steps == 0:
            self.compute_Kronecker_matrices()
            
        self.compute_scalings()
#         self.precondition()
        
        self.iteration_number += 1
        
    def store_input(self, module, inputs_to_module):
        """ When called before running each layer with weights, this function stores
        the input to the layer"""
        
        self.stored_items[module]['input'] = inputs_to_module[0]
    
    def store_grad_output(self, module, grad_wrt_input, grad_wrt_output):
        """ When called after the backward pass of each layer with weights, this function
        stores the gradient of the backwards-running function (usually the loss function) with respect
        to the pre-activations, i.e. the output of the layer"""
        
        # We have to scale by the batch size, because the grad_wrt_output which is passed to the 
        # function is already scaled down by batch_size, even though we did not do any reduction
        self.stored_items[module]['grad_wrt_output'] = grad_wrt_output[0] * grad_wrt_output[0].size(0) 
        
    def compute_Kronecker_matrices(self):
        """ For each layer (or, more properly, parameter group), computes the Kronecker-factored matrices, 
        where the Kronecker factors are defined by 
        A = E[input_to_layer @ input_to_layer.T]
        B = E[grad_wrt_output @ grad_wrt_output.T]
        """
        
        for layer, stored_values in self.stored_items.items():
            # notation follows the EKFAC paper
            h = stored_values['input'].t()
            delta = stored_values['grad_wrt_output']
            
            # We want E[ h @ h.T]
            # h should always be of size (n_inputs, batch_size)
            # delta should be of size (batch_size, n_outputs)
            with torch.no_grad():
                A = h @ h.transpose(1,0) / h.shape[1]
                B = delta.transpose(1,0) @ delta / delta.shape[0]
            
            # Eigendecompose A and B to get UA and UB, which contain the eigenvectors
            # UA @ diag(EvalsA) @ UA.t() = A
            EvalsA, UA = torch.symeig(A, eigenvectors=True)
            EvalsB, UB = torch.symeig(B, eigenvectors=True)
            
            self.stored_items[layer]['UA'] = UA
            self.stored_items[layer]['UB'] = UB
            
    def compute_scalings(self):
        
        for layer, stored_values in self.stored_items.items():
            UA = stored_values['UA']
            UB = stored_values['UB']
            h = stored_values['input'].t()
            delta = stored_values['grad_wrt_output']
            
            with torch.no_grad():
                batch_size = h.shape[1]
                # TODO Check that this is correct
                # Because delta and h contain information for each training example in the mini-batch,
                # when we do the matrix multiplication in the middle, we are averaging over the mini-batch.
                # So, we need to square the values first, so we can square-then-average, not average-then-square.
                scalings = ((UB.t() @ delta.t())**2) @ ((h.t() @ UA)**2) / batch_size
                
            stored_values['scalings'] = scalings
            
    def precondition(self):
        for layer, stored_values in self.stored_items.items():
            
            UA = stored_values['UA']
            UB = stored_values['UB']
            
            S = stored_values['scalings']
            
            grad_mb = layer.weight.grad.data # mb stands for 'mini-batch'
            grad_mb_kfe = UB @ grad_mb @ UA.t()
            grad_mb_kfe_scaled = grad_mb_kfe / (S + self.epsilon)
            grad_mb_orig = UB.t() @ grad_mb @ UA # back to original basis 
            
            layer.weight.grad.data = grad_mb_orig
            
    def approximate_Fisher_matrix(self, to_return=False):
        """ For testing/debugging, compute the layer-wise approximation to the empirical Fisher matrix 
            to compare to the Fischer information matrix """
        approximate_Fisher_matrices = []
        for layer, stored_values in self.stored_items.items():
            
            UA = stored_values['UA'].numpy()
            UB = stored_values['UB'].numpy()
            S = np.diag(stored_values['scalings'].numpy().reshape(-1))
            
            UAkronUB = np.kron(UA, UB)
            
            approximate_Fisher = UAkronUB @ S @ UAkronUB.T
            approximate_Fisher_matrices.append(approximate_Fisher)
            
            stored_values['aproximate_Fisher'] = torch.tensor(approximate_Fisher)
        
        if to_return:
            return approximate_Fisher_matrices
            
    def compute_hdeltaT(self):
        """ For testing/debugging, compute the layer-wise h delta T product.
        The minibatch-averaged h delta^T product should be equal to the gradient of the 
        weigt matrix for each linear layer."""
        
        for layer, stored_values in self.stored_items.items():
            h = stored_values['input']
            delta = stored_values['grad_wrt_output']
            stored_values['hdeltaT'] = h.t() @ delta / h.size(0)
            
    def compute_empirical_Fisher_matrix(self, to_return=False):
        """ For testing/debugging, compute empirical Fisher matrix """
        
        empirical_fisher_matrices = []
        for layer, stored_values in self.stored_items.items():
            h = stored_values['input']
            delta = stored_values['grad_wrt_output']

            with torch.no_grad():
                empirical_fisher_matrix = empirical_fisher(h, delta)    
                stored_values['empirical_fisher'] = empirical_fisher_matrix
                empirical_fisher_matrices.append(empirical_fisher_matrix)
        
        if to_return:
            return empirical_fisher_matrices
        
def outer_prod_individual(M1, M2):
    """ takes the outer product of M1 and M2, where M1 is NxA, and M2 is NxB
    """
    return torch.einsum('ij,ik->ijk', M1, M2)

def vectorize_individual(M):
    """ Given a tensor M, with size (A, B, C), vectorizes this tensor, leaving the first dimension intact,
    by stacking columns, resulting in a tensor of size (A, BC)"""
    Mt = M.transpose(1,2)
    return Mt.contiguous().view(Mt.size(0), -1)

def empirical_fisher(h, delta):
    """ given h, representing the input to a layer, and delta, representing the gradient with respect to its output,
    computes the empirical fisher matrix, averaged over the minibatch
    
    Arguments:
    h - torch.tensor, dimension (batch_size) * (n_inputs)
    delta - torch.tensor, dimension (batch_size) * (n_outputs)"""
    
    grad_individual = outer_prod_individual(h, delta)
    vec_grad_individual = vectorize_individual(grad_individual)
    fisher_individual = outer_prod_individual(vec_grad_individual, vec_grad_individual)
    return fisher_individual.mean(0)

In [361]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
Nbatch, D_in, H, D_out = 20, 2, 4, 1

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H, bias=False),
    torch.nn.Sigmoid(),
    torch.nn.Linear(H, D_out, bias=False),
)

In [297]:
EKFAC_one = EKFAC(model)

In [298]:
# Create random Tensors to hold inputs and outputs
x = torch.randn(Nbatch, D_in)
y = torch.randn(Nbatch, D_out)

In [299]:
y_mod = model(x)
loss_fun = torch.nn.MSELoss(reduction='sum')
z = loss_fun(y, y_mod)
z.backward()

In [300]:
EKFAC_one.step()
EKFAC_one.compute_hdeltaT()
EKFAC_one.compute_empirical_Fisher_matrix()

In [362]:
### code for testing empirical Fisher matrix implementation

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
Nbatch, D_in, H, D_out = 1, 2, 9, 1

model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H, bias=False),
    torch.nn.Sigmoid(),
    torch.nn.Linear(H, D_out, bias=False)
)

EKFAC_one = EKFAC(model)

x = torch.randn(Nbatch, D_in)
y = torch.randn(Nbatch, D_out)

# compute Fisher matrix altogether
y_mod = model(x)
loss_fun = torch.nn.MSELoss(reduction='mean')
z = loss_fun(y, y_mod)
z.backward()

empirical_Fisher_matrix = EKFAC_one.compute_empirical_Fisher_matrix(to_return=True)
EKFAC_one.step()
approximate_Fisher_matrix = EKFAC_one.approximate_Fisher_matrix(to_return=True)

# L1grads = []
# L2grads = []

# for i in range(Nbatch):
#     EKFAC_one.zero_grad()
    
#     x_test = x[i]
#     y_test = y[i]
    
#     y_mod = model(x_test)
#     z = loss_fun(y_test, y_mod)
#     z.backward()
    
#     for ind, module in enumerate(model.modules()):
#         if ind == 1:
#             L1grads.append(module.weight.grad.clone().numpy())
#         elif ind == 3:
#             L2grads.append(module.weight.grad.clone().numpy())

In [363]:
approximate_Fisher_matrix[1] - empirical_Fisher_matrix[1].numpy()

array([[-0.03250405, -0.17115346, -0.12354335, -0.10530421, -0.16963662,
        -0.2080802 , -0.22366273, -0.15949368, -0.15731049],
       [-0.17115346, -0.04094686, -0.0724525 , -0.18210094, -0.38260227,
        -0.37862158, -0.26161712, -0.2489062 , -0.20018882],
       [-0.12354335, -0.07245251, -0.00435159, -0.19718952, -0.09267814,
        -0.11479546, -0.16185072, -0.13504566, -0.13439783],
       [-0.10530421, -0.18210094, -0.19718952,  0.07043797, -0.51852846,
        -0.44541976, -0.28361455, -0.27673942, -0.19396444],
       [-0.16963664, -0.38260227, -0.09267814, -0.51852846,  0.14828075,
        -0.02879338, -0.29201096, -0.28335407, -0.24485016],
       [-0.2080802 , -0.37862158, -0.11479546, -0.44541976, -0.02879336,
        -0.04450651, -0.22582628, -0.21239723, -0.30395103],
       [-0.22366273, -0.26161712, -0.16185072, -0.28361455, -0.292011  ,
        -0.22582628, -0.03130716, -0.25572693, -0.19974801],
       [-0.15949368, -0.24890621, -0.13504566, -0.27673942, -0

array([[2.8731704, 1.9741668, 2.728336 , 2.3566275, 2.5245583, 1.8503456,
        2.7413404, 2.5156274, 2.0007644],
       [1.9741668, 1.3564578, 1.8746506, 1.6192482, 1.734634 , 1.2713798,
        1.8835859, 1.7284976, 1.3747332],
       [2.728336 , 1.8746506, 2.5908027, 2.2378318, 2.3972974, 1.7570713,
        2.6031516, 2.3888166, 1.8999074],
       [2.3566275, 1.6192482, 2.2378318, 1.9329494, 2.0706894, 1.5176878,
        2.248498 , 2.0633643, 1.6410639],
       [2.5245583, 1.734634 , 2.3972974, 2.0706894, 2.2182446, 1.6258366,
        2.4087238, 2.2103975, 1.7580044],
       [1.8503456, 1.2713798, 1.7570713, 1.5176878, 1.6258366, 1.1916381,
        1.7654461, 1.6200851, 1.288509 ],
       [2.7413404, 1.8835859, 2.6031516, 2.248498 , 2.4087238, 1.7654461,
        2.615559 , 2.4002028, 1.9089631],
       [2.5156274, 1.7284976, 2.3888166, 2.0633643, 2.2103975, 1.6200851,
        2.4002028, 2.2025778, 1.7517854],
       [2.0007644, 1.3747332, 1.8999074, 1.6410639, 1.7580044, 1.288509 

In [346]:
EKFAC_one.stored_items

{Linear(in_features=2, out_features=9, bias=False): {'UA': tensor([[-0.0587, -0.9983],
          [-0.9983,  0.0587]]),
  'UB': tensor([[ 4.8361e-01,  2.2956e-01, -1.6497e-03, -7.2406e-02, -3.5454e-02,
            5.7071e-01, -7.2590e-02,  4.4047e-01,  4.2653e-01],
          [-7.1136e-02, -3.8768e-01, -3.4171e-01, -8.0973e-01,  1.2047e-01,
            2.1952e-01,  9.8885e-03, -4.3114e-02, -8.6967e-02],
          [-6.7117e-01, -3.8924e-01,  1.1366e-01,  2.5546e-01,  9.1139e-03,
            3.8695e-01, -5.1582e-02,  8.5658e-02,  4.0005e-01],
          [ 6.9080e-02, -5.5708e-02, -8.7970e-01,  4.3574e-01, -2.1809e-02,
            1.2946e-01,  1.0069e-02, -7.4586e-02, -7.4077e-02],
          [ 4.1001e-01, -5.4350e-01,  2.4375e-01,  2.8049e-01,  5.4681e-01,
            1.7453e-01,  3.1730e-02, -4.8832e-02, -2.5604e-01],
          [ 2.4913e-01, -4.7546e-01, -1.0046e-01, -5.3937e-03, -1.5684e-01,
           -5.9822e-01, -5.1746e-02,  2.2947e-01,  5.1375e-01],
          [-1.7178e-01,  3.4700e-01

In [None]:
### It seems like when I only have a single entry in the mini-batch, the last layer's approximate FIsher is equal to the 
# exact Fisher