In [None]:
import mitsuba as mi
mi.set_variant("llvm_ad_rgb")
import drjit as dr

import os
# Root output directory for all rendered images and comparison figures.
base_dir = 'estimator_comparison_attached_reparam'
# exist_ok=True makes this idempotent on re-runs and avoids the
# check-then-create race of `if not os.path.exists(...): os.makedirs(...)`.
os.makedirs(base_dir, exist_ok=True)
    
import numpy as np
import matplotlib.pyplot as plt
import cmap_diff
%config InlineBackend.figure_formats = ['svg']
%matplotlib inline
    
# Set Mitsuba's log threshold to Warn so lower-severity log messages do not
# clutter the cell outputs of the rendering loop below.
mi.Thread.thread().logger().set_log_level(mi.LogLevel.Warn)

In [None]:
# Scenes to test: (scene file stem under scenes/, key of the differentiated parameter).
scene_names = [
    ('attached_disc_roughness', 'plane.bsdf.alpha.data'),
    ('attached_disc_normalmap', 'plane.bsdf.normalmap.data'),
]

# Gradient estimators to compare: (display name, extra properties merged into the
# 'estimator_comparison' integrator dict). Each 'method' value is also used as the
# output file name '<method>.exr'.
methods = [
    ('Detached BSDF sampling',           {'method': 'bs_detached'}),
    ('Attached BSDF sampling (naive)',   {'method': 'bs_attached'}),
    ('Attached BSDF sampling (reparam)', {'method': 'bs_attached_reparam',
                                          'reparam_kappa': 1e6, 'reparam_rays': 48}),
]

# Rendering settings shared by all scenes and methods.
RESOLUTION = 256  # passed to the scene file as `res`
SPP_PRIMAL = 32   # samples per pixel for the primal reference image
SPP_GRAD = 256    # samples per pixel for each gradient estimate


def _param_gradient(scene, params, param_key, integrator, spp):
    """Return the gradient of the summed rendered image w.r.t. ``params[param_key]``.

    ``dr.backward(image)`` seeds every image element with a gradient of one, so the
    result is d(sum of image)/d(parameter), with the parameter's own shape. Gradient
    tracking and the accumulated gradient are reset afterwards, so the next method
    starts from a clean state.
    """
    dr.enable_grad(params[param_key])
    image = mi.render(scene, params, integrator=integrator, seed=0, spp=spp,
                      antithetic_pass=False)
    dr.backward(image)
    param_grad = dr.grad(params[param_key])
    dr.set_grad(params[param_key], 0.0)
    dr.disable_grad(params[param_key])
    return param_grad


for scene_name, param_key in scene_names:
    print("*", scene_name)
    scene_dir = "{}/{}".format(base_dir, scene_name)
    os.makedirs(scene_dir, exist_ok=True)

    scene = mi.load_file('scenes/{}.xml'.format(scene_name), res=RESOLUTION)
    params = mi.traverse(scene)
    params.keep([param_key])
    # Clamp the differentiated parameter from below at 0.1 (done for both scenes).
    # NOTE(review): presumably meant to keep roughness away from near-specular values;
    # for the normal-map scene this also alters the texture contents — confirm intended.
    params[param_key] = dr.maximum(0.1, params[param_key])
    # Mitsuba requires update() after assigning scene parameters so the owning
    # scene objects are notified of the new value.
    params.update()

    # Primal rendering (reference image displayed next to the gradients)
    integrator = mi.load_dict({'type': 'estimator_comparison', 'method': 'primal_mis', 'hide_emitters': True})
    image = mi.render(scene, params, integrator=integrator, seed=0, spp=SPP_PRIMAL)
    mi.util.convert_to_bitmap(image, uint8_srgb=False).write('{}/primal.exr'.format(scene_dir))

    # Compare the gradient estimators; method-specific keys override the base dict.
    for method_name, method_dict in methods:
        print("   ", method_name)
        integrator = mi.load_dict({'type': 'estimator_comparison', 'hide_emitters': True, **method_dict})
        param_grad = _param_gradient(scene, params, param_key, integrator, SPP_GRAD)
        mi.util.convert_to_bitmap(param_grad, uint8_srgb=False).write(
            '{}/{}.exr'.format(scene_dir, method_dict['method']))

In [None]:
# Symmetric color range (+/- value) for the gradient images, one per entry of `scene_names`.
scales = [0.5, 2]
# Short panel titles, one per entry of `methods` (same order).
method_titles = ["Detached", "Attached (naïve)", "Attached (reparam.)"]

assert len(scales) == len(scene_names), "need one color scale per scene"
assert len(method_titles) == len(methods), "need one panel title per method"

for (scene_name, _), vminmax in zip(scene_names, scales):
    scene_dir = "{}/{}".format(base_dir, scene_name)

    # One panel for the primal image plus one per gradient estimator.
    fig, axes = plt.subplots(ncols=1 + len(methods), figsize=(10, 3))
    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])

    # Primal image: clip to [0, 1] before the fractional power (negative values would
    # otherwise become NaN), then apply a crude 2.2 gamma for display.
    image_primal = np.array(mi.Bitmap("{}/primal.exr".format(scene_dir)))
    axes[0].imshow(np.clip(image_primal, 0.0, 1.0) ** (1 / 2.2))
    axes[0].set_title("Primal")

    # Gradient images: only the first channel is shown, on a symmetric range with the
    # 'diff' colormap (presumably registered by the `cmap_diff` import).
    for ax, (_, method_dict), title in zip(axes[1:], methods, method_titles):
        image_grad = np.array(mi.Bitmap("{}/{}.exr".format(scene_dir, method_dict['method'])))
        grad_channel = image_grad[:, :, 0] if image_grad.ndim == 3 else image_grad
        ax.imshow(grad_channel, cmap='diff', vmin=-vminmax, vmax=+vminmax)
        ax.set_title(title)

    fig.suptitle(scene_name, weight='bold', size=14)
    # Lay out before saving so the written file matches the figure shown inline.
    fig.tight_layout()
    fig.savefig('{}/comparison.jpg'.format(scene_dir), dpi=300, pad_inches=0.1, bbox_inches='tight')
    plt.show()