In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [31]:
import sys
import os 
sys.path.append(os.path.abspath('../'))
import torch
from time import time
from tqdm import tqdm
import drjit as dr
import mitsuba as mi
import numpy as np
import matplotlib.pyplot as plt 
from IPython.display import clear_output
import torch.nn as nn
import torch.nn.functional as F
import re


from convolutions import *
from utils_fns import *
from utils_general import update_sigma_linear, run_scheduler_step, plt_errors, show_with_error
from optimizations import *
from utils_optim import run_optimization, run_grad_optimization, run_cg_optimization, run_bfgs_optimization
from utils_general import run_scheduler_step
from utils_mitsuba import get_mts_rendering, render_smooth
from read_scenes import create_scene_from_xml

# Pick the torch device used by every cell below. `device` must always be
# defined, otherwise the first `device=device` raises NameError on a machine
# without CUDA.
if torch.cuda.is_available():
    device = 'cuda'
    print("is available")
    mi.set_variant('cuda_ad_rgb')
else:
    # NOTE(review): no Mitsuba variant is selected on this path; the cells in
    # this notebook only need torch. Pick a CPU variant here if rendering is needed.
    device = 'cpu'

is available


# Setup f(x), g(x)

In [3]:
def layer_f(weights, x):
    '''
    Bias-free linear layer: the matrix product of ``weights`` and ``x``.
    '''
    projected = weights @ x
    return projected
    
def loss_f(input):
    '''
    Quartic loss: the sum of the fourth powers of every entry of ``input``.

    A fourth power is used so that the second derivative is non-zero.
    '''
    fourth_powers = input.pow(4)
    return fourth_powers.sum()
    
def f(weights, x):
    '''
    Scalar objective: ``loss_f`` evaluated on the output of ``layer_f(weights, x)``.
    '''
    layer_out = layer_f(weights, x)
    return loss_f(layer_out)

def ddloss_f(input):
    '''
    Second derivative of the loss with respect to its input.

    The input is flattened and the result is returned as a dense diagonal
    matrix whose diagonal entries are ``12 * input_i ** 2``.
    '''
    flat = input.flatten()
    diagonal = 12 * flat.pow(2)
    return torch.diag(diagonal)

def dlayer_f(weights, x):
    '''
    Jacobian of ``layer_f`` with respect to ``weights``, via reverse-mode autodiff.
    '''
    jacobian_wrt_weights = torch.func.jacrev(layer_f, argnums=0)
    return jacobian_wrt_weights(weights, x)

def jvp_layer_f(weights, x, v):
    '''
    Jacobian-vector product of ``layer_f`` with respect to ``weights`` in direction ``v``.

    ``x`` is paired with a zero tangent, so only the ``weights`` direction contributes.
    '''
    primals = (weights, x)
    tangents = (v, torch.zeros_like(x))
    _, tangent_out = torch.func.jvp(layer_f, primals, tangents)
    return tangent_out
# def analytical_dfdx(x):
    


# Low dim test

In [4]:
# Small problem so the full Hessian can be printed and compared entry by entry.
input_shape = 3
output_shape = 4

x = torch.rand((input_shape, 1), device=device, requires_grad=True)
weights = torch.rand((output_shape, input_shape), device=device, requires_grad=True)
# Deterministic alternatives, handy when debugging by hand:
# x = torch.tensor([[1.0], [2.0], [3.0]], device=device, requires_grad=True)
# weights = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], device=device, requires_grad=True)
# weights = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], device=device, requires_grad=True)
# weights = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], device=device, requires_grad=True)
y = layer_f(weights, x)
loss = loss_f(y)
# print(f'y: {y}')    
# print(f'loss: {loss}')

# Analytical Hessian w.r.t. the flattened weights: J^T (d2loss/dy2) J.
# layer_f is linear in the weights, so there is no additional curvature term.
ddloss = ddloss_f(y)
dlayer = dlayer_f(weights, x).reshape(output_shape, weights.numel())
weights_sod = dlayer.T @ ddloss @ dlayer
# print(f'd2loss/dy2: \n{ddloss.cpu().detach().numpy()}')
# print(f'dy/dx: \n{dlayer.cpu().detach().numpy()}')

# Reference: dense Hessian from autograd, flattened to (n_params, n_params).
hess_f = torch.func.hessian(f, argnums=0)
weights_sod_torch = hess_f(weights, x).reshape(weights.numel(), weights.numel())

# Fail loudly on a mismatch instead of relying on eyeballing the printed matrices.
torch.testing.assert_close(weights_sod.detach(), weights_sod_torch.detach(), rtol=1e-4, atol=1e-5)

with np.printoptions(linewidth=np.inf, precision=2):
    print(f'analytical second order derivatives for weights: \n{weights_sod.cpu().detach().numpy()}')
    print(f'Computed from pytorch: \n{weights_sod_torch.cpu().detach().numpy()}')

analytical second order derivatives for weights: 
[[ 1.46  4.85  3.96  0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
 [ 4.85 16.12 13.16  0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
 [ 3.96 13.16 10.75  0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
 [ 0.    0.    0.    0.26  0.85  0.69  0.    0.    0.    0.    0.    0.  ]
 [ 0.    0.    0.    0.85  2.82  2.31  0.    0.    0.    0.    0.    0.  ]
 [ 0.    0.    0.    0.69  2.31  1.88  0.    0.    0.    0.    0.    0.  ]
 [ 0.    0.    0.    0.    0.    0.    0.1   0.35  0.28  0.    0.    0.  ]
 [ 0.    0.    0.    0.    0.    0.    0.35  1.15  0.94  0.    0.    0.  ]
 [ 0.    0.    0.    0.    0.    0.    0.28  0.94  0.77  0.    0.    0.  ]
 [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.57  1.89  1.54]
 [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    1.89  6.27  5.12]
 [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    1.54  5.12  4.18]]
Computed from pytorch: 
[[ 1.46  4.85  3.96  0.  

In [5]:
# torch hvp: multiply the dense autograd Hessian by a random direction.
# Sizes are taken from `weights` rather than hard-coded, so this cell follows
# whatever problem size the previous cell set up.
random_vec = torch.rand(weights.numel(), device=device)
hvp_torch = weights_sod_torch @ random_vec

# analytical hvp: J^T (d2loss/dy2) (J v), where J v comes from a forward-mode JVP
random_vec = random_vec.reshape(weights.shape)
jvp_torch = jvp_layer_f(weights, x, random_vec)
jvp_analytical = dlayer @ random_vec.flatten()
hvp_analytical = dlayer.T @ ddloss @ jvp_torch

# Check both stages numerically: the JVP against the explicit Jacobian product,
# and the analytical HVP against the brute-force one.
torch.testing.assert_close(jvp_torch.detach().flatten(), jvp_analytical.detach(), rtol=1e-4, atol=1e-5)
torch.testing.assert_close(hvp_analytical.detach().flatten(), hvp_torch.detach(), rtol=1e-4, atol=1e-5)

with np.printoptions(linewidth=np.inf, precision=2):
    # print(f'random vec: {random_vec.cpu().detach().numpy()}')
    print(f'Hessian-vector product: {hvp_torch.cpu().detach().numpy()}')
    # print(f'Jacobian-vector product torch: {jvp_torch.cpu().detach().numpy()}')
    # print(f'Analytical Jacobian-vector product: {jvp_analytical.cpu().detach().numpy()}')
    print(f'Analytical Hessian-vector product: {hvp_analytical.cpu().detach().numpy().flatten()}')
    


Hessian-vector product: [ 9.21 30.63 25.01  0.96  3.19  2.6   0.38  1.28  1.04  2.23  7.42  6.06]
Analytical Hessian-vector product: [ 9.21 30.63 25.01  0.96  3.19  2.6   0.38  1.28  1.04  2.23  7.42  6.06]


# High dim test

## correctness

In [24]:
# Problem size: still small enough that the dense Hessian fits in memory.
input_shape = 1000
output_shape = 5

x = torch.rand((input_shape, 1), device=device, requires_grad=True)
weights = torch.rand((output_shape, input_shape), device=device, requires_grad=True)
y = layer_f(weights, x)
loss = loss_f(y)

# Ingredients of the analytical product J^T (d2loss/dy2) (J v).
ddloss = ddloss_f(y)
dlayer = dlayer_f(weights, x).reshape(output_shape, weights.numel())

# Reference: dense Hessian from autograd, flattened to a square matrix.
hess_f = torch.func.hessian(f, argnums=0)
weights_sod_torch = hess_f(weights, x).reshape(weights.numel(), weights.numel())

# Brute-force HVP against a random direction.
random_vec = torch.rand(weights.numel(), device=device)
hvp_torch = torch.matmul(weights_sod_torch, random_vec).flatten()

# Analytical HVP using the same direction, reshaped like the weights.
random_vec = random_vec.reshape(weights.shape)
jvp_torch = jvp_layer_f(weights, x, random_vec)
hvp_analytical = (dlayer.T @ ddloss @ jvp_torch).flatten()

# Indices where the two results disagree under torch.isclose's default tolerances.
close_elements = torch.isclose(hvp_analytical, hvp_torch)
not_close_indices = torch.nonzero(~close_elements, as_tuple=True)[0]

print(f'For input shape {input_shape} and output shape {output_shape}:')
print(f'Different element between analytical and torch HVP: {not_close_indices}')

For input shape 1000 and output shape 5:
Different element between analytical and torch HVP: tensor([], device='cuda:0', dtype=torch.int64)


## real task size

In [34]:
def get_memory_size_from_error(error):
    '''
    Extract the requested allocation size from an out-of-memory error message.

    Searches ``str(error)`` for the phrase "Tried to allocate <number> <unit>".

    Args:
        error: exception (or any object) whose string form is searched.

    Returns:
        The matched size including its unit (e.g. "14400.00 GiB"), or None
        when the message contains no such phrase.
    '''
    # Accept any size unit, not only GiB: a failed allocation reported in
    # MiB/KiB/bytes would otherwise fall through and be shown as "None".
    allocation_pattern = r"Tried to allocate ([\d.]+ (?:bytes|[KMGT]iB))"
    allocation_match = re.search(allocation_pattern, str(error))
    if allocation_match is None:
        return None
    return allocation_match.group(1)

In [37]:
torch.cuda.empty_cache() # clears cache for large matrix
input_shape = 256*256*3
output_shape = 10

x = torch.rand((input_shape,1), device=device, requires_grad=True)
weights = torch.rand((output_shape, input_shape), device=device, requires_grad=True)
# torch hvp
random_vec = torch.rand(input_shape*output_shape, device=device)

print(f'For input shape {input_shape} and output shape {output_shape} (a likely task size):')

# try block for hessian torch with size of mug task
# The dense Hessian has (input_shape*output_shape)^2 entries, so this is
# expected to run out of memory; only that failure is reported and tolerated.
try:    
    hess_f = torch.func.hessian(f, argnums=0)
    weights_sod_torch = hess_f(weights, x).reshape(input_shape*output_shape,input_shape*output_shape)
    hvp_torch = weights_sod_torch @ random_vec
    hvp_torch = hvp_torch.flatten()
    print(f'Torch HVP computed successfully')
except RuntimeError as e:
    if 'out of memory' in str(e):
        size = get_memory_size_from_error(e)
        print(f"CUDA out of memory({size} needed) for brute force Hessian computation")
    else:
        # Any other RuntimeError is a real bug; do not hide it.
        raise

# try block for H(theta)VP torch with size of mug task
# Builds J^T (d2loss/dy2) (J v) without ever forming the dense Hessian.
try:
    y = layer_f(weights, x)
    loss = loss_f(y)
    ddloss = ddloss_f(y)
    dlayer = dlayer_f(weights, x).reshape(output_shape, input_shape*output_shape)
    # analytical hvp 
    random_vec = random_vec.reshape(output_shape, input_shape)
    jvp_torch = jvp_layer_f(weights, x, random_vec)
    hvp_analytical = dlayer.T @ ddloss @ jvp_torch
    hvp_analytical = hvp_analytical.flatten()
    print(f'Analytical HVP computed successfully')
except RuntimeError as e:
    if 'out of memory' in str(e):
        size = get_memory_size_from_error(e)
        print(f"CUDA out of memory({size} needed) for HVP computation")
    else:
        # Any other RuntimeError is a real bug; do not hide it.
        raise



For input shape 196608 and output shape 10 (a likely task size):
CUDA out of memory(14400.00 GiB needed) for brute force Hessian computation
Analytical HVP computed successfully
