In [None]:
import os
import sys
sys.path.insert(0, '../')
import torch
import numpy as np
import imageio
import glob
from einops import rearrange
from matplotlib import pyplot as plt

# Optional runtime install of ipywidgets. `piplite` (JupyterLite's in-browser
# package installer) is presumably only importable in JupyterLite/Pyodide kernels;
# on a regular kernel the ImportError is swallowed and ipywidgets is expected to
# be installed already. The bare top-level `await` relies on the notebook kernel
# supporting top-level await — this cell will not run as a plain .py script.
try:
    import piplite
    await piplite.install(['ipywidgets'])
except ImportError:
    pass
import ipywidgets as widgets

In [None]:
from engine.trainer import Trainer
from engine.eval import evaluation_path
from data import dataset_dict
from utils.opt import config_parser
from utils.vis import plot_palette_colors, visualize_depth_numpy, visualize_palette_components_numpy
from utils.color import rgb2hex, hex2rgb
from utils.ray import get_rays, ndc_rays_blender

## Utils

In [None]:
def print_divider():
    """Write a single blank line to stdout to separate groups of cell output."""
    print('', end='\n')

                    
def render_one_view(test_dataset, tensorf, c2w, renderer, N_samples=-1,
                    white_bg=False, ndc_ray=False, palette=None, device='cuda',
                    chunk=2048):
    """Render a single camera view and convert the results to displayable images.

    Args:
        test_dataset: dataset providing ``near_far``, ``img_wh`` (as ``(W, H)``),
            ``directions`` and, for NDC scenes, ``focal``.
        tensorf: model to render. If ``palette`` is None and the model exposes
            ``get_palette_array``, that palette is used.
        c2w: camera-to-world pose of the view (numpy array or tensor).
        renderer: callable invoked as ``renderer(rays, tensorf, chunk=..., ...)``
            that returns a dict with ``'rgb_map'``, ``'depth_map'`` and,
            optionally, ``'opaque_map'``.
        N_samples: forwarded to ``renderer``.
        white_bg: forwarded to ``renderer``.
        ndc_ray: if True, rays are converted with ``ndc_rays_blender`` first.
        palette: optional palette tensor forwarded to ``renderer`` and used for
            the decomposition visualization.
        device: forwarded to ``renderer``.
        chunk: ray batch size forwarded to ``renderer`` (default 2048).

    Returns:
        ``(rgb_map, depth_map, plt_decomp)``:
        - ``rgb_map`` is an ``(H, W, 3)`` uint8 array.
        - ``depth_map`` is the first output of ``visualize_depth_numpy``.
        - ``plt_decomp`` is a uint8 palette-decomposition image, or ``None`` when
          no palette or ``'opaque_map'`` is available.
    """
    torch.cuda.empty_cache()

    near_far = test_dataset.near_far

    if palette is None and hasattr(tensorf, 'get_palette_array'):
        # detach() so the .numpy() call below also works outside torch.no_grad().
        palette = tensorf.get_palette_array().detach().cpu()

    W, H = test_dataset.img_wh

    # as_tensor accepts numpy arrays and tensors of any float dtype/device,
    # unlike the legacy torch.FloatTensor constructor (which rejects double or
    # CUDA tensors). The pose is kept on the CPU as before.
    c2w = torch.as_tensor(c2w, dtype=torch.float32, device='cpu')
    rays_o, rays_d = get_rays(test_dataset.directions, c2w)  # both (h*w, 3)
    if ndc_ray:
        rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)
    rays = torch.cat([rays_o, rays_d], 1)  # (h*w, 6)

    res = renderer(rays, tensorf, chunk=chunk, N_samples=N_samples, palette=palette,
                   ndc_ray=ndc_ray, white_bg=white_bg, device=device, ret_opaque_map=True)

    rgb_map = res['rgb_map'].detach().reshape(H, W, 3).cpu()
    depth_map = res['depth_map'].detach().reshape(H, W).cpu()

    # Clamp before the uint8 cast: values slightly outside [0, 1] would
    # otherwise overflow/wrap (e.g. -0.01 * 255 does not map to 0).
    rgb_map = (rgb_map.clamp(0.0, 1.0).numpy() * 255).astype('uint8')

    depth_map, _ = visualize_depth_numpy(depth_map.numpy(), near_far)

    plt_decomp = None
    if palette is not None and 'opaque_map' in res:
        opaque = rearrange(res['opaque_map'].detach(), '(h w) c-> h w c', h=H, w=W).cpu()
        # The palette may come from the caller on any device / with grad enabled.
        palette_np = palette.detach().cpu().numpy()
        plt_decomp = visualize_palette_components_numpy(opaque.numpy(), palette_np)
        # Same guard as rgb_map: keep the uint8 conversion from wrapping.
        plt_decomp = (np.clip(plt_decomp, 0.0, 1.0) * 255).astype('uint8')

    return rgb_map, depth_map, plt_decomp


## Config

In [None]:
# Paths in args.txt are written relative to the repository root, but this
# notebook runs one directory below it. Each entry is:
#   (option name, path as written in the config, path to use from here)
path_redirect = [
    ('palette_path', './data_palette', '../data_palette'),
]

In [None]:
# Training run to load; args.txt and checkpoints/ are read from this directory.
run_dir = '../logs/chair/'
# Optional checkpoint override; None keeps the `ckpt` value from args.txt.
ckpt_path = None
# Edited palettes and renderings are written under the run directory.
out_dir = os.path.join(run_dir, 'demo_out')

print(f'Run dir: {run_dir}')
print(f'Demo output dir: {out_dir}')

## Load and Setup

In [None]:
# Read args saved alongside the run. Fail loudly if they are missing: continuing
# would either raise NameError below or, worse, silently reuse a stale `args`
# left in the kernel by a previous run.
parser = config_parser()
config_path = os.path.join(run_dir, 'args.txt')
if not os.path.exists(config_path):
    raise FileNotFoundError(f'Cannot read args in {run_dir}: {config_path} does not exist.')

with open(config_path, 'r') as f:
    config_contents = f.read()
args, remainings = parser.parse_known_args(args=[], config_file_contents=config_contents)

# Override ckpt path when one was set in the config cell.
if ckpt_path is not None:
    args.ckpt = ckpt_path

# Redirect config paths so they resolve from this notebook's directory.
# Only a leading prefix is rewritten: a substring replace would also match
# inside an already-redirected path ('../data_palette' -> '.../data_palette').
# Options that are unset/None are left alone instead of crashing.
for option_name, config_prefix, redirected_prefix in path_redirect:
    option_value = getattr(args, option_name, None)
    if isinstance(option_value, str) and option_value.startswith(config_prefix):
        setattr(args, option_name, redirected_prefix + option_value[len(config_prefix):])

print('Args loaded:', args)
print_divider()


# Setup trainer
print('Initializing trainer and model...')
ckpt_dir = os.path.join(run_dir, 'checkpoints')
tb_dir = os.path.join(run_dir, 'tensorboard')
trainer = Trainer(args, run_dir, ckpt_dir, tb_dir)
model = trainer.build_network()
model.eval()
print_divider()


# Create downsampled dataset: twice the training downsample factor, used for
# the single-view preview render below.
dataset = dataset_dict[args.dataset_name]
ds_test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train * 2., is_stack=True)
print('Downsampled dataset loaded')

## Palette Editing

In [None]:
def _to_numpy(tensor):
    """Detach a tensor from autograd, move it to the host and return a numpy array."""
    return tensor.detach().cpu().numpy()


# Palette prior stored on the trainer, and the palette the model currently holds.
palette_prior = _to_numpy(trainer.palette_prior)
palette = _to_numpy(model.renderModule.palette.get_palette_array())

In [None]:
# Show the trainer's palette prior for comparison with the optimized palette.
print('Initial palette prior:', end='\n')
plot_palette_colors(palette_prior)

In [None]:
# Clip the optimized palette to [0, 1]; this clipped copy is the editable palette.
new_palette = np.clip(palette, 0., 1.)
print('Optimized palette:')
plot_palette_colors(new_palette)

In [None]:
# One editable color picker per palette entry, initialized from new_palette.
color_pickers = [
    widgets.ColorPicker(
        concise=False,
        description=f'Color {idx}',
        value=rgb2hex(new_palette[idx]),
        disabled=False,
    )
    for idx in range(palette.shape[0])
]

# Lay the pickers out in a grid of four equal-width columns.
box_layout = widgets.Layout(
    width='100%',
    grid_template_rows='auto',
    grid_template_columns='25% 25% 25% 25%',
)
box_auto = widgets.GridBox(children=color_pickers, layout=box_layout)
display(box_auto)

In [None]:
print('Palette for rendering:')

# Read back the current picker values; hex2rgb's components are scaled by 1/255.
new_palette = np.asarray([hex2rgb(picker.value) for picker in color_pickers],
                         dtype=np.float32) / 255.
plot_palette_colors(new_palette)

## Rendering

In [None]:
# Modify this to change the rendering view
render_cam_idx = 1

c2w = ds_test_dataset.poses[render_cam_idx]
white_bg = ds_test_dataset.white_bg
ndc_ray = args.ndc_ray

with torch.no_grad():
    
    rgb, depth, plt_decomps = render_one_view(ds_test_dataset, model, c2w, trainer.renderer, palette=torch.from_numpy(new_palette),
                                              N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, device=trainer.device)

fig, axes = plt.subplots(1, 2, figsize=(16, 16))
axes[0].set_axis_off()
axes[0].imshow(rgb)
axes[1].set_axis_off()
axes[1].imshow(depth)

fig, axes = plt.subplots(1, 1, figsize=(16, 8))
axes.set_axis_off()
axes.imshow(plt_decomps)


In [None]:
# Run the cells below to save this editing

# Modify this to name this editing
edit_name = 'red_chair'

assert edit_name, '`edit_name` must be a non-empty string.'

out_fn = f'rgb_palette{"_" + edit_name if edit_name else ""}'
out_path = os.path.join(out_dir, f'{out_fn}.npy')

# exist_ok avoids the check-then-create race and makes re-runs idempotent.
os.makedirs(out_dir, exist_ok=True)

# Never overwrite a previously saved edit.
if os.path.exists(out_path):
    print('Error: file exists. Please specify another `edit_name`.')
else:
    np.save(out_path, new_palette)
    print('Save palette to', out_path)


In [None]:
# Choose between 'test' (the test-split poses) / 'path' (the dataset's render_path)
cam_poses = 'test'
# Any other value used to fall through silently to render_path.
if cam_poses not in ('test', 'path'):
    raise ValueError(f"cam_poses must be 'test' or 'path', got {cam_poses!r}")

save_dir = os.path.join(out_dir, f'render_{cam_poses}{"_" + edit_name if edit_name else ""}')

# Never overwrite a previous rendering.
if os.path.exists(save_dir):
    print('Error: directory exists. Please specify another `edit_name`.')
else:
    c2ws = trainer.test_dataset.poses if cam_poses == 'test' else trainer.test_dataset.render_path
    if cam_poses == 'test' and args.dataset_name == 'llff':
        # Render only every 8th LLFF test pose (presumably to keep render time down).
        c2ws = c2ws[::8, ...]
    white_bg = trainer.test_dataset.white_bg
    ndc_ray = trainer.args.ndc_ray

    print('Save renderings to', save_dir)
    print('=== render path ======>', c2ws.shape)
    with torch.no_grad():
        evaluation_path(trainer.test_dataset, model, c2ws, trainer.renderer, save_dir,
                        palette=torch.from_numpy(new_palette),
                        N_samples=-1, white_bg=white_bg, ndc_ray=ndc_ray, save_video=True,
                        device=trainer.device)