In [None]:
import sys

import matplotlib.gridspec as gridspec
import matplotlib.patches as patches

sys.path.append("..")
from figutils import *

In [None]:
def _mean_pointwise_error(x, y, pointwise):
    """Apply ``pointwise`` to the residual ``x - y``, average it with ``dr.mean``,
    and return entry ``[0]`` of that reduction."""
    residual = x - y
    averaged = dr.mean(pointwise(residual))
    return averaged[0]


def l1(x, y):
    """Mean absolute difference between ``x`` and ``y`` (entry 0 of ``dr.mean``)."""
    return _mean_pointwise_error(x, y, dr.abs)


def l2(x, y):
    """Mean squared difference between ``x`` and ``y`` (entry 0 of ``dr.mean``)."""
    return _mean_pointwise_error(x, y, dr.sqr)

In [None]:

# Experiment grid. Each (loss, method, denoising) run stores its final image at
# OUTPUT_DIR/<scene>_<loss>[_denoised]/<method>/img_final.exr.
losses = ["l1", "l2"]
loss_names = [r"$\mathcal{L}^1$", r"$\mathcal{L}^2$"]
methods = ["baseline", "cv_ps"]
method_names = ["Baseline", "Ours"]
use_denoising = [False, True]
scene_name = "ajar"


def load_output_exr(*parts):
    """Read ``OUTPUT_DIR/<parts...>`` with ``mi.Bitmap`` and wrap it as an ``mi.TensorXf``."""
    return mi.TensorXf(mi.Bitmap(os.path.join(OUTPUT_DIR, *parts)))


def run_dir_name(loss, denoised):
    """Directory name of one run, e.g. ``ajar_l2`` or ``ajar_l2_denoised``."""
    return f"{scene_name}_{loss}{'_denoised' if denoised else ''}"


# The reference image is shared by all runs; it is read from the non-denoised L1 run's folder.
img_ref = load_output_exr(f"{scene_name}_l1", "img_ref.exr")

# imgs[i][j][k] and final_losses[i][j][k] are indexed by
# (losses[i], methods[j], use_denoising[k]).
# Each final_losses entry is (l1(img, img_ref), l2(img, img_ref)).
imgs = []
final_losses = []
for loss in losses:
    imgs_by_method = []
    errors_by_method = []
    for method in methods:
        run_imgs = []
        run_errors = []
        for denoised in use_denoising:
            img = load_output_exr(run_dir_name(loss, denoised), method, "img_final.exr")
            # The figure crops every result using img_ref's pixel coordinates,
            # so a resolution mismatch must be caught here with a clear message.
            assert img.shape == img_ref.shape, (
                f"{run_dir_name(loss, denoised)}/{method}: image shape {img.shape} "
                f"does not match reference shape {img_ref.shape}"
            )
            run_imgs.append(img)
            run_errors.append((l1(img, img_ref), l2(img, img_ref)))
        imgs_by_method.append(run_imgs)
        errors_by_method.append(run_errors)
    imgs.append(imgs_by_method)
    final_losses.append(errors_by_method)



In [None]:
# Use seaborn's plain "white" axes style for the figure built below.
sns.set_style(style="white")

In [None]:
# Figure layout, left to right:
#   [full reference | L1 insets | L2 insets | reference insets]
# Each inset grid has one row per crop and one column per (method, denoising) pair.
#
# A crop is (rx, ry, s): the top-left corner as fractions of image width / height,
# and the side length as a fraction of image width (crops are square in pixels).
crop1 = [0.4, 0.22, 0.1]
crop2 = [0.5, 0.52, 0.12]
crops = [crop1, crop2]
crop_colors = ["r", "g"]

n_rows = len(crops)
n_inset_cols = len(methods) * len(use_denoising)

h, w, _ = img_ref.shape
img_r = w / h
inset_r = 1.0
inner_wspace = 0.05
inner_hspace = inner_wspace * inset_r

# Aspect ratio of each outer column (presumably width / height as computed by
# figutils.gridspec_aspect — TODO confirm); used below to derive the figure height
# from PAGE_WIDTH.
insets_r = gridspec_aspect(n_rows, n_inset_cols, 1, 1, wspace=inner_wspace, hspace=inner_hspace)
ref_r = gridspec_aspect(n_rows, 1, 1, 1, hspace=inner_hspace)

outer_wspace = 0.05
width_ratios = [img_r, insets_r, insets_r, ref_r]
outer_aspect = gridspec_aspect(1, 4, width_ratios, 1, wspace=outer_wspace)


def crop_bounds(crop, width, height):
    """Return the (top, bottom, left, right) pixel bounds of the square crop ``(rx, ry, s)``."""
    rx, ry, s = crop
    left = int(rx * width)
    top = int(ry * height)
    size = int(s * width)
    return top, top + size, left, left + size


# plt.figure(1, ...) would reuse a still-open figure 1 from a previous run (ignoring
# figsize and stacking new axes on the old ones), so start from a fresh figure.
plt.close(1)
fig = plt.figure(1, figsize=(PAGE_WIDTH, PAGE_WIDTH / outer_aspect))
outer = fig.add_gridspec(1, 4, width_ratios=width_ratios, wspace=outer_wspace)
gs_ref = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=outer[0], wspace=outer_wspace)
gs_l1_inset = gridspec.GridSpecFromSubplotSpec(n_rows, n_inset_cols, subplot_spec=outer[1], wspace=inner_wspace, hspace=inner_hspace)
gs_l2_inset = gridspec.GridSpecFromSubplotSpec(n_rows, n_inset_cols, subplot_spec=outer[2], wspace=inner_wspace, hspace=inner_hspace)
gs_ref_inset = gridspec.GridSpecFromSubplotSpec(n_rows, 1, subplot_spec=outer[3], hspace=inner_hspace)
inners = [gs_l1_inset, gs_l2_inset]

# Full reference image. Its x-label names the two error values printed under the
# bottom row of result insets.
ax_ref = fig.add_subplot(gs_ref[:, 0])
ax_ref.imshow(mi.util.convert_to_bitmap(img_ref), interpolation='none')
ax_ref.set_xlabel(r'$\mathcal{L}^1$:'+'\n'+r'$\mathcal{L}^2$:', y=-0.25, fontsize=DEFAULT_FONTSIZE, multialignment='right', loc='right')
disable_ticks(ax_ref)

# Reference crops (rightmost column), each outlined on the full image in its color.
for row, (crop, color) in enumerate(zip(crops, crop_colors)):
    top, bottom, left, right = crop_bounds(crop, w, h)
    size = right - left
    ax = fig.add_subplot(gs_ref_inset[row, 0])
    ax.imshow(mi.util.convert_to_bitmap(img_ref[top:bottom, left:right]), interpolation='none')
    disable_ticks(ax)
    plt.setp(ax.spines.values(), color=color)
    if row == 0:
        ax.set_title("Reference", y=1.025)

    ax_ref.add_patch(patches.Rectangle((left, top), size, size, linewidth=0.5, edgecolor=color, facecolor='none'))
    ax.add_patch(patches.Rectangle((0, 0), size - 1, size - 1, linewidth=1.0, edgecolor=color, facecolor='none'))

# Result crops: one inset grid per loss; columns ordered by (method, denoising).
for i, (loss_name, gs) in enumerate(zip(loss_names, inners)):
    for j, method_name in enumerate(method_names):
        for k, denoised in enumerate(use_denoising):
            img = imgs[i][j][k]
            l1_err, l2_err = final_losses[i][j][k]
            col = j * len(use_denoising) + k
            for row, crop in enumerate(crops):
                top, bottom, left, right = crop_bounds(crop, w, h)
                ax = fig.add_subplot(gs[row, col])
                ax.imshow(mi.util.convert_to_bitmap(img[top:bottom, left:right]), interpolation='none')
                disable_ticks(ax)
                if row == 0:
                    ax.set_title('+ Denoising' if denoised else method_name, y=1.025)
                elif row == n_rows - 1:
                    ax.set_xlabel(f"{l1_err:.2e}\n{l2_err:.2e}", y=-0.25, fontsize=DEFAULT_FONTSIZE)

    # Ghost axes spanning the whole inset grid, used only for the loss title and the
    # horizontal rule next to it (https://stackoverflow.com/a/69117807). The rule is
    # given in axes coordinates so it does not depend on the ghost axes' data limits.
    ax_label = fig.add_subplot(gs[:])
    ax_label.axis('off')
    ax_label.set_title(loss_name, y=1.18)
    ax_label.add_patch(patches.Rectangle((.1, 1.2), 0.8, 0.0, linewidth=0.5, edgecolor='black', facecolor='none', clip_on=False, transform=ax_label.transAxes))

# Ghost axes over the full reference image, only to carry its column title
# (kept separate so `ax_ref` still refers to the image axes).
ax_ref_title = fig.add_subplot(gs_ref[:])
ax_ref_title.axis('off')
ax_ref_title.set_title('Reference', y=1.025)
save_fig("ajar", pad_inches=0.02)
