In [74]:
import torch
import torchvision.models as models
import numpy as np
from torch import nn

#example of loss and weights
def loss(x):
  return x.pow(3).sum()
weights = torch.rand(2, 2)

In [181]:
# hessian and loss are reverred to a minimum point

#calculate hessian
def calculate_hessian(loss, weights):
  hessian = torch.autograd.functional.hessian(loss, weights)
  hessian = hessian.sum(dim = 1).sum(dim = 0)
  return hessian.numpy()

#max eigenvalue
def max_eigenvalue(hessian):
  return max(np.linalg.eigvals(hessian))

#mean eigenvalue
def mean_eigenvalue(hessian):
  return np.linalg.eigvals(hessian).mean()

# measure based on hessian norm2 as in https://arxiv.org/pdf/1703.04933.pdf
def norm2_hessian(loss, hessian):
  norma = np.linalg.norm(hessian)
  return norma/(2+2*loss)

#measure based on hessian nuclear norm
def normn_hessian(loss, hessian):
  norma = np.linalg.norm(hessian, ord="nuc")
  return norma/(2+2*loss)

#measure trace hessian
def trace_hessian(hessian):
  return sum(np.linalg.eigvals(hessian))

#measure determinant hessian
def determinant_hessian(hessian):
  return np.linalg.det(hessian)

# possible to add pag 11 https://arxiv.org/pdf/1901.04653.pd

In [182]:
#tests
hessian = calculate_hessian(loss, weights)
hessian

array([[3.2199879, 4.332014 ],
       [5.7130795, 1.9056337]], dtype=float32)

In [183]:
np.linalg.eigvals(hessian)

array([ 7.58088  , -2.4552588], dtype=float32)

In [184]:
eps_sharpness(0.1, loss(weights), hessian)

tensor(0.0167)

In [185]:
norm2_hessian(loss(weights),hessian)

tensor(1.6666)

In [186]:
normn_hessian(loss(weights), hessian)

tensor(2.0877)

In [187]:
trace_hessian(hessian)

5.125621318817139

In [189]:
determinant_hessian(hessian)

-18.613024