In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os 
sys.path.append(os.path.abspath('../'))
import torch
from time import time
from tqdm import tqdm
import drjit as dr
import mitsuba as mi
import numpy as np
import matplotlib.pyplot as plt 
from IPython.display import clear_output
import torch.nn as nn
import torch.nn.functional as F


from convolutions import *
from utils_fns import *
from utils_general import update_sigma_linear, run_scheduler_step, plt_errors, show_with_error
from optimizations import *
from utils_optim import run_optimization, run_grad_optimization, run_cg_optimization, run_bfgs_optimization
from utils_general import run_scheduler_step
from utils_mitsuba import get_mts_rendering, render_smooth
from read_scenes import create_scene_from_xml

# Select the compute backend once, so that `device` is always bound for the
# tensor constructions in later cells.
if torch.cuda.is_available():
    device = 'cuda'
    print("is available")
    mi.set_variant('cuda_ad_rgb')
else:
    # Previously nothing was assigned here, so a CPU-only machine hit a
    # NameError on the first use of `device`.
    device = 'cpu'
    print("CUDA is not available, falling back to CPU")
    # set_variant picks the first variant from the list that is supported.
    mi.set_variant('llvm_ad_rgb', 'scalar_rgb')

is available


# Setup f(x), g(x)

In [3]:
def layer_f(weights, x):
    '''
    Bias-free linear layer: returns the matrix product of weights and x.
    In this notebook weights has shape (4,3) and x is a 3d column vector,
    so the output is 4d.
    '''
    return torch.matmul(weights, x)
    
def loss_f(input):
    '''
    Quartic loss: the sum of the fourth powers of every entry of input.
    A fourth power is used so that the second derivative is non-zero.
    '''
    fourth_powers = torch.pow(input, 4)
    return fourth_powers.sum()
    
def f(weights, x):
    '''
    Scalar objective: loss_f applied to the output of layer_f(weights, x).
    '''
    layer_out = layer_f(weights, x)
    return loss_f(layer_out)

def ddloss_f(input):
    '''
    Second derivative of loss_f wrt its input, as a dense diagonal matrix.
    The input is flattened first; diagonal entry i is 12 * input_i**2.
    '''
    flat = torch.flatten(input)
    diagonal = 12 * torch.pow(flat, 2)
    return torch.diag(diagonal)

def dlayer_f(weights, x):
    '''
    Jacobian of layer_f wrt weights, computed with reverse-mode autodiff.
    The result is indexed by the dimensions of the layer output followed
    by the dimensions of weights.
    '''
    jacobian_wrt_weights = torch.func.jacrev(layer_f, argnums=0)
    return jacobian_wrt_weights(weights, x)
# def analytical_dfdx(x):
    


In [4]:
# Inputs: a fixed 3d column vector and a random (4,3) weight matrix.
x = torch.tensor([1.0, 2.0, 3.0], device=device, requires_grad=True).reshape(3,1)
weights = torch.rand((4,3), device=device, requires_grad=True)
# Deterministic weight choices that are handy when debugging:
# weights = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device=device, requires_grad=True)
# weights = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], device=device, requires_grad=True)
# weights = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], device=device, requires_grad=True)

# Forward pass through the layer and the loss.
y = layer_f(weights, x)
loss = loss_f(y)
# print(f'y: {y}')
# print(f'loss: {loss}')

# Building blocks of the analytical Hessian: the loss curvature wrt the layer
# output, and the layer Jacobian with the weights flattened to 12 columns.
ddloss = ddloss_f(y)
dlayer = dlayer_f(weights, x).reshape(4,12)
# print(f'd2loss/dy2: \n{ddloss.cpu().detach().numpy()}')
# print(f'dy/dx: \n{dlayer.cpu().detach().numpy()}')

# Analytical second-order derivative: J^T * H_loss * J, multiplied left to right.
weights_sod = torch.matmul(torch.matmul(dlayer.T, ddloss), dlayer)

# Reference value from PyTorch's own Hessian of f wrt the weights.
hess_f = torch.func.hessian(f, argnums=0)
weights_sod_torch = hess_f(weights, x).reshape(12,12)

# Print both matrices unwrapped and with two decimals for side-by-side reading.
with np.printoptions(linewidth=np.inf, precision=2):
    print(f'analytical second order derivatives for weights: \n{weights_sod.cpu().detach().numpy()}')
    print(f'Computed from pytorch: \n{weights_sod_torch.cpu().detach().numpy()}')

analytical second order derivatives for weights: 
[[ 20.68  41.36  62.04   0.     0.     0.     0.     0.     0.     0.     0.     0.  ]
 [ 41.36  82.72 124.08   0.     0.     0.     0.     0.     0.     0.     0.     0.  ]
 [ 62.04 124.08 186.12   0.     0.     0.     0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.    74.32 148.64 222.95   0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.   148.64 297.27 445.91   0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.   222.95 445.91 668.86   0.     0.     0.     0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.    57.69 115.38 173.07   0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.   115.38 230.76 346.15   0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.   173.07 346.15 519.22   0.     0.     0.  ]
 [  0.     0.     0.     0.     0.     0.     0.     0.     0.    35.39  70.78 106.17]
 [  0.     0.     0.     0.     0.     0.     0.     0.     0.    70.78 141.56 2