In [None]:
import sys
sys.path.append("..")
from figutils import *

In [None]:
# Index of the saved intermediate frame to display; the loading cell reads
# "<method>/img/{i:04d}.exr" and "<method>/weights/{i:04d}.exr" with it.
i = 49
# All inputs for this figure live under OUTPUT_DIR/<scene_name>.
scene_name = 'weights'
output_dir = os.path.join(OUTPUT_DIR, scene_name)

In [None]:
# Method directory names (under output_dir) and the LaTeX panel labels used
# for them in the figure; both lists are index-aligned.
methods = ["baseline", "cv_pss", "cv_ps"]
method_names = [r"\textbf{(a)} Baseline", r"\textbf{(b)} CV-PSS", r"\textbf{(c)} CV-PS"]

# Per-method outputs; each list gets exactly one entry per method, in the
# same order as `methods`.
results = []
final_states = []
intermediate_states = []
textures = []
weights = []

ref_img = mi.Bitmap(os.path.join(output_dir, "img_ref.exr"))
start_img = mi.Bitmap(os.path.join(output_dir, "img_start.exr"))

for method in methods:
    method_dir = os.path.join(output_dir, method)
    results.append(np.load(os.path.join(method_dir, "result.npz")))
    final_states.append(mi.Bitmap(os.path.join(method_dir, "img_final.exr")))
    frame_file = f"{i:04d}.exr"
    intermediate_states.append(mi.Bitmap(os.path.join(method_dir, "img", frame_file)))
    # The baseline has no weight map, and none is read for frames i <= 1;
    # store None so `weights` stays index-aligned with `methods`.
    has_weights = method != 'baseline' and i > 1
    weights.append(mi.Bitmap(os.path.join(method_dir, "weights", frame_file)) if has_weights else None)


In [None]:
# Plain white background with no grid for all figures below.
sns.set_style(style="white")

In [None]:
# Figure layout: one column per method. The top row shows the primal image at
# the selected frame; the bottom row shows that method's weight map, or "N/A"
# when the method has none. A shared colorbar sits right of the last panel.
base_size = 4
w, h = ref_img.size()

n_cols = len(method_names)
n_rows = 2
# Aspect ratio of the whole grid, so each panel keeps the image's aspect.
aspect = w * n_cols / h / n_rows

# clear=True: without it, re-running this cell (or running outside the inline
# backend) draws new axes on top of those already attached to figure 1.
fig = plt.figure(1, figsize=(TEXT_WIDTH, TEXT_WIDTH / aspect), clear=True)
wspace = 0.01
# Scale the vertical gap by the aspect ratio so it visually matches the
# horizontal gap.
gs = fig.add_gridspec(n_rows, n_cols, wspace=wspace, hspace=wspace * aspect)

# Loop with `col` rather than `i`: `i` is the global frame index that the
# loading cell uses, and overwriting it here would make a later re-run of that
# cell silently load a different frame.
im = None  # most recent weight image; drives the shared colorbar
for col, label in enumerate(method_names):
    ax = fig.add_subplot(gs[0, col])
    ax.imshow(mi.util.convert_to_bitmap(intermediate_states[col]), interpolation='none')
    disable_ticks(ax)

    if col == 0:
        ax.set_ylabel("Primal")

    ax = fig.add_subplot(gs[1, col])
    if weights[col] is not None:
        # Only the first channel of the weight bitmap is displayed.
        weight = mi.TensorXf(weights[col])[:, :, 0]
        # Fixed vmin/vmax so all weight panels share one color scale.
        im = ax.imshow(weight, cmap='Reds_r', vmin=0, vmax=1, interpolation='none')
    else:
        ax.text(0.5, 0.5, "N/A", ha="center", va="center", color="darkgrey")
        disable_border(ax)

    if col == 0:
        ax.set_ylabel("Weights")

    # Negative y places the method label below the bottom-row panel.
    ax.set_title(label, y=-0.25)
    disable_ticks(ax)
    # Only draw the colorbar if at least one weight map was plotted; otherwise
    # there is nothing to map (and `im` would be undefined or stale).
    if col == n_cols - 1 and im is not None:
        cbax = ax.inset_axes([1.02, 0, 0.04, 1], transform=ax.transAxes)
        cbar = fig.colorbar(im, cax=cbax, ticks=[0, 0.5, 1])
        cbar.outline.set_visible(False)
        cbar.ax.tick_params(size=0)
save_fig("weights")
