In [None]:
import sys
sys.path.append("..")
from figutils import *
import matplotlib.patches as patches
from tol_colors import tol_cmap, tol_cset

In [None]:
def l1(x, y):
    """Mean absolute difference between ``x`` and ``y``.

    Evaluates ``dr.mean(dr.abs(x - y))`` and returns element 0 of that
    reduction result. ``dr`` is not imported in this cell; it is expected
    to be provided by ``from figutils import *`` (presumably Dr.Jit --
    TODO confirm).
    """
    mean_abs_diff = dr.mean(dr.abs(x - y))
    return mean_abs_diff[0]

In [None]:
def ape(x, y):
    """Absolute error ``|x - y|``, computed with ``dr.abs``.

    No reduction is applied: the value returned is whatever ``dr.abs``
    yields for the difference ``x - y``.
    """
    difference = x - y
    return dr.abs(difference)

In [None]:
# ---------------------------------------------------------------------------
# Load the rendered results that the figure cell below consumes.
#
# For every scene three runs are collected, always in this order:
#   0: <scene>/baseline   1: <scene>_high_spp/baseline   2: <scene>/cv_ps
# `method_names` provides the display label for each of these three columns.
# ---------------------------------------------------------------------------
scenes = ['janga', 'rover', 'dust_devil']
scene_names = ['Janga', 'Rover', 'Dust Devil']

methods = ["baseline", "cv_ps"]
method_names = ["Equal time", "Equal quality", "Ours"]

# The image width is divided into `sensor_counts[i]` equal column strips and
# strip number `sensors[i]` is the one kept for display.
sensor_counts = [8, 10, 8]
sensors = [4, 1, 0]

# Index of the intermediate image ("img/<it:04d>.exr") loaded for each run.
it = 39

# Per-scene outputs; the inner lists hold one entry per run (see order above).
imgs = []           # final images (column strip, or the "_re" image for dust_devil)
error_imgs = []     # ape() between the displayed final image and its reference
noisy_imgs = []     # column strip of the intermediate image at iteration `it`
results = []        # loaded "result.npz" archives
final_losses = []   # l1() between "img_final.exr" and "img_ref_display.exr"
ref_imgs = []       # one displayed reference image per scene
error_scales = [0, 0, 0]  # running maximum of each scene's error images


def _read_exr(*parts):
    """Load the bitmap at OUTPUT_DIR/<parts...> and wrap it in ``mi.TensorXf``."""
    return mi.TensorXf(mi.Bitmap(os.path.join(OUTPUT_DIR, *parts)))


def _column_strip(image, strip_width, index):
    """Return columns ``[strip_width*index, strip_width*(index+1))`` of ``image``."""
    return image[:, strip_width * index:strip_width * (index + 1)]


for scene_idx, scene_dir in enumerate(scenes):
    is_dust_devil = scene_dir == "dust_devil"
    strip_count = sensor_counts[scene_idx]
    strip_idx = sensors[scene_idx]

    # Register this scene's (still empty) per-run lists with the outputs.
    scene_imgs, scene_errors, scene_losses = [], [], []
    scene_results, scene_noisy = [], []
    imgs.append(scene_imgs)
    error_imgs.append(scene_errors)
    final_losses.append(scene_losses)
    results.append(scene_results)
    noisy_imgs.append(scene_noisy)

    reference = _read_exr(scene_dir, "img_ref_display.exr")
    if is_dust_devil:
        # dust_devil is displayed from separate "_re" images instead of strips.
        reference_viz = _read_exr(scene_dir, "img_ref_re.exr")

    for method in methods:
        # Only the baseline has an additional "<scene>_high_spp" run.
        if method == "baseline":
            run_dirs = [scene_dir, f"{scene_dir}_high_spp"]
        else:
            run_dirs = [scene_dir]

        for run_dir in run_dirs:
            primal = _read_exr(run_dir, method, "img", f"{it:04d}.exr")
            final = _read_exr(run_dir, method, "img_final.exr")

            # The loss is always measured on the full "img_final.exr".
            scene_losses.append(l1(final, reference))
            scene_results.append(np.load(os.path.join(OUTPUT_DIR, run_dir, method, "result.npz")))
            strip_w = final.shape[1] // strip_count

            if is_dust_devil:
                final = _read_exr(run_dir, method, "img_final_re.exr")
                scene_imgs.append(final)
                scene_errors.append(ape(final, reference_viz))
            else:
                final_strip = _column_strip(final, strip_w, strip_idx)
                scene_imgs.append(final_strip)
                scene_errors.append(ape(final_strip, _column_strip(reference, strip_w, strip_idx)))

            error_scales[scene_idx] = max(error_scales[scene_idx], dr.max(scene_errors[-1])[0])

            # The intermediate image is strip-cropped for every scene,
            # dust_devil included.
            primal_strip_w = primal.shape[1] // strip_count
            scene_noisy.append(_column_strip(primal, primal_strip_w, strip_idx))

    # `strip_w` still holds the value from the last run loaded above.
    if is_dust_devil:
        ref_imgs.append(reference_viz)
    else:
        ref_imgs.append(_column_strip(reference, strip_w, strip_idx))



In [None]:
# ---------------------------------------------------------------------------
# Build and save the "volumes" figure.
#
# One row per scene, each row split into four panels:
#   [0] reference image with the crop rectangle drawn on it
#   [1] 2x4 grid of crops: top row = intermediate ("Primal") images, bottom row
#       = error maps (or final images when `error_maps` is False); the last
#       column holds the reference crop
#   [2] bar chart of the final L1 loss per run
#   [3] bar chart of the runtime per run
# Column titles are only drawn below the last row.
# ---------------------------------------------------------------------------
from matplotlib.ticker import FormatStrFormatter

aspect = 4.3/len(scenes)
n_rows = 3

# Layout: aspect ratios are derived inside-out (insets -> row -> whole figure)
# so that the figure height follows from PAGE_WIDTH.
inset_hspace = 0.02
inset_wspace = 0.02
inset_aspect = gridspec_aspect(2, 4, 1, 1, wspace=inset_wspace, hspace=inset_hspace)

row_wspace = 0.2
bar_r = 0.5
width_ratios = [1, inset_aspect, bar_r, bar_r]
row_aspect = gridspec_aspect(1, 4, width_ratios, 1, wspace=row_wspace)

outer_hspace = 0.1
outer_aspect = gridspec_aspect(n_rows, 1, row_aspect, 1, hspace=outer_hspace)

fig = plt.figure(1, figsize=(PAGE_WIDTH, PAGE_WIDTH / outer_aspect))
outer = fig.add_gridspec(n_rows, 1, hspace=outer_hspace)

nbins = 4
sns.set_palette(sns.color_palette("colorblind"))
# True: bottom inset row shows error maps; False: it shows the final images.
error_maps = True

# Per-scene crop [wx, wy, ws]: left = wx * width, top = wy * height and the
# square's side = ws * height, all relative to the displayed image.
crops = [
    [0.21, 0.3, 0.25],
    [0.5, 0.35, 0.25],
    [0.4, 0.5, 0.25]
]

for i, scene_name in enumerate(scene_names):
    gs_row = gridspec.GridSpecFromSubplotSpec(1, 4, subplot_spec=outer[i], wspace=row_wspace, hspace=0.0, width_ratios=width_ratios)

    gs_insets = gridspec.GridSpecFromSubplotSpec(2, 4, subplot_spec=gs_row[1], wspace=inset_wspace, hspace=inset_hspace)
    gs_loss = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs_row[2], wspace=0.0, hspace=0.0)
    gs_rt = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs_row[3], wspace=0.0, hspace=0.0)

    # Panel 0: reference image.
    ax = fig.add_subplot(gs_row[0])
    ax.imshow(mi.util.convert_to_bitmap(ref_imgs[i]), interpolation='none')

    disable_ticks(ax)
    if i == n_rows - 1:
        ax.set_title("Reference", y=-0.15)
    ax.set_ylabel(scene_name)

    # Panels 2 and 3: bar-chart axes; the bars are added in the run loop below.
    ax_loss = fig.add_subplot(gs_loss[0])
    ax_loss.tick_params(bottom=False, labelbottom=False)
    ax_loss.locator_params(nbins=nbins)
    ax_loss.yaxis.set_major_formatter(FormatStrFormatter('%.1e'))

    if i == n_rows - 1:
        ax_loss.set_title(r"$\mathcal{L}^1$", y=-0.15)
    ax_rt = fig.add_subplot(gs_rt[0])
    ax_rt.tick_params(bottom=False, labelbottom=False)
    ax_rt.locator_params(nbins=nbins)
    if i == n_rows - 1:
        ax_rt.set_title("Runtime (s)", y=-0.15)

    # Crop window in pixels of the final/reference image.
    wx, wy, ws = crops[i]
    h, w, _ = imgs[i][0].shape
    left = int(wx*w)
    top = int(wy*h)
    size = int(ws*h)

    # `ax` is still the reference panel here: outline the crop on it.
    rect = patches.Rectangle((left, top), size, size, linewidth=1.0, edgecolor='g', facecolor='none')
    ax.add_patch(rect)

    # Reference crop in the last inset column (top cell when error maps are
    # shown, bottom cell otherwise).
    ax = fig.add_subplot(gs_insets[0 if error_maps else 1, -1])
    ax.imshow(mi.util.convert_to_bitmap(ref_imgs[i][top:top+size, left:left+size]), interpolation='none')
    disable_ticks(ax)
    if i == n_rows - 1:
        with sns.axes_style('white'):
            if error_maps:
                # Empty, border-less axis that only carries the column title.
                ax = fig.add_subplot(gs_insets[1, -1])
                disable_border(ax)
                disable_ticks(ax)
                ax.set_title("Reference", y=-0.3)
            else:
                ax.set_title("Reference", y=-0.3)

    # Crop window in pixels of the intermediate ("Primal") image, which may
    # have a different resolution than the final image.
    h1, w1, _ = noisy_imgs[i][0].shape
    if scene_name == "Dust Devil":
        # For this scene the final image is the separate "_re" image while the
        # primal image is a column strip, so the horizontal crop position is
        # remapped. NOTE(review): 280 and 720 are presumably the x-offset and
        # width (px) of the strip's footprint in the "_re" image -- confirm.
        wx = (left - 280) / 720
    left1 = int(wx*w1)
    top1 = int(wy*h1)
    size1 = int(ws*h1)

    for j, method in enumerate(method_names):
        # Top inset row: crop of the intermediate image.
        ax = fig.add_subplot(gs_insets[0, j])
        ax.imshow(mi.util.convert_to_bitmap(noisy_imgs[i][j][top1:top1+size1, left1:left1+size1]), interpolation='none')
        disable_ticks(ax)
        if j == 0:
            ax.set_ylabel("Primal")

        # Bottom inset row: error map (first channel only) or final image crop.
        ax = fig.add_subplot(gs_insets[1, j])
        if error_maps:
            im = ax.imshow(error_imgs[i][j][top:top+size, left:left+size, 0], interpolation='none', cmap='inferno', vmin=0, vmax=0.1)
        else:
            ax.imshow(mi.util.convert_to_bitmap(imgs[i][j][top:top+size, left:left+size]), interpolation='none')
        disable_ticks(ax)
        if i == n_rows - 1:
            ax.set_title(method, y=-0.3)
        if j == 0:
            ax.set_ylabel(r"$\mathcal{L}^1$ Error" if error_maps else "Final")
        elif j == 2 and error_maps:
            # `im` only exists when error maps are drawn, so the colorbar is
            # skipped otherwise (previously this raised a NameError).
            cbax = ax.inset_axes([1.03, 0.0, 0.05, 1.0], transform=ax.transAxes)
            cbar = fig.colorbar(im, cax=cbax, ticks=[0, 0.05, 0.09])
            cbar.outline.set_visible(False)
            cbar.ax.tick_params(size=0)
            cbar.ax.locator_params(nbins=nbins)

        # Bars: x positions are 0, 0.8 and 1.8; palette entries 3, 1 and 2.
        bar_x = j*(0.8 if j < 2 else 0.9)
        bar_color = sns.color_palette()[1+(2+j)%3]
        runtime = results[i][j]["runtime"]

        ax_loss.bar(bar_x, final_losses[i][j], color=bar_color)
        # Hatched bar = sum over the whole runtime array; the solid bar drawn
        # over it = sum of column 0 (labelled 'Primal' in the legend below).
        # The 1e-3 factor presumably converts ms to s -- TODO confirm.
        ax_rt.bar(bar_x, runtime.sum() * 1e-3, alpha=0.5, color=bar_color, hatch='////')
        ax_rt.bar(bar_x, runtime[:,0].sum() * 1e-3, label=method, alpha=1.0, color=bar_color)

    if i == 0:
        # Legend mapping bar colours to method names.
        ax_rt.legend(loc='upper left', bbox_to_anchor=(1.0, 1))
    elif i == 1:
        # Legend explaining the solid vs. hatched parts of the runtime bars.
        # NOTE(review): the hatched swatch uses alpha=0.75 whereas the hatched
        # bars are drawn with alpha=0.5 -- confirm this is intentional.
        legend_elements = [patches.Patch(facecolor=sns.color_palette()[3], alpha=1.0),
                            patches.Patch(facecolor=sns.color_palette()[3], alpha=0.75, hatch='////') ]
        ax_rt.legend(legend_elements, ['Primal', 'Adjoint'], loc='upper left', bbox_to_anchor=(1.0, 1.8))
save_fig("volumes")