In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
import os 
sys.path.append(os.path.abspath('../'))
import torch
from time import time
from tqdm import tqdm
import drjit as dr
import mitsuba as mi
import numpy as np
import matplotlib.pyplot as plt 
from IPython.display import clear_output
import torch.nn as nn
import torch.nn.functional as F


from convolutions import *
from utils_fns import *
from utils_general import update_sigma_linear, run_scheduler_step, plt_errors, show_with_error
from optimizations import *
from utils_optim import run_optimization, run_grad_optimization, run_cg_optimization, run_bfgs_optimization
from utils_general import run_scheduler_step
from utils_mitsuba import get_mts_rendering, render_smooth
from read_scenes import create_scene_from_xml

if torch.cuda.is_available():
    device = 'cuda'
    print("is available")
    mi.set_variant('cuda_ad_rgb')

is available


# Setup f(x), g(x)

In [16]:
def layer_f(weights, x):
    '''
    Analytic, for linear 3d x -> 4d output
    weight matrix with shape (4,3)
    '''
    return weights @ x
    
def loss_f(input):
    '''
    Loss function, s.t. second derivative is non-zero
    '''
    return torch.sum(input**4)
    
def f(weights, x):
    return loss_f(layer_f(weights, x))

def ddloss_f(input):
    '''
    Second derivative of loss function wrt input
    '''
    return 12*input**2

def dlayer_f(x):
    '''
    Derivative of layer_f wrt weights
    each column would be x with shape (4,3)
    '''
    x = x.reshape(1,-1)
    return x.expand(4,3)
# def analytical_dfdx(x):
    


In [36]:
x = torch.tensor([1.0, 2.0, 3.0], device=device, requires_grad=True)
x = x.reshape(3,1)
weights = torch.rand((4,3), device=device, requires_grad=True)
# weights = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device=device, requires_grad=True)
# weights = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], device=device, requires_grad=True)
weights = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], device=device, requires_grad=True)
y = layer_f(weights, x)
loss = loss_f(y)
print(f'y: {y}')    
print(f'loss: {loss}')

ddloss = ddloss_f(y)
dlayer = dlayer_f(x)
print(f'd2loss/dy2: {ddloss}')
print(f'dy/dx: {dlayer}')

weights_sod = dlayer @ dlayer.T * ddloss
print(f'analytical second order derivatives for weights: {weights_sod}')
hess_f = torch.func.hessian(f, argnums=0)
# print(weights)
print(hess_f(weights, x))

y: tensor([[6.],
        [6.],
        [6.],
        [6.]], device='cuda:0', grad_fn=<MmBackward0>)
loss: 5184.0
d2loss/dy2: tensor([[432.],
        [432.],
        [432.],
        [432.]], device='cuda:0', grad_fn=<MulBackward0>)
dy/dx: tensor([[1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.]], device='cuda:0', grad_fn=<ExpandBackward0>)
analytical second order derivatives for weights: tensor([[6048., 6048., 6048., 6048.],
        [6048., 6048., 6048., 6048.],
        [6048., 6048., 6048., 6048.],
        [6048., 6048., 6048., 6048.]], device='cuda:0', grad_fn=<MulBackward0>)
tensor([[[[ 432.,  864., 1296.],
          [   0.,    0.,    0.],
          [   0.,    0.,    0.],
          [   0.,    0.,    0.]],

         [[ 864., 1728., 2592.],
          [   0.,    0.,    0.],
          [   0.,    0.,    0.],
          [   0.,    0.,    0.]],

         [[1296., 2592., 3888.],
          [   0.,    0.,    0.],
          [   0.,    0.,    0.],
          [   0.,  