In [2]:
# Load IPython's autoreload extension; mode 2 re-imports edited modules before
# each cell runs, so changes to the project's .py files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [7]:
import torch
import numpy as np
from utils import load_ckpt, visualize_depth
from collections import defaultdict
import matplotlib.pyplot as plt

from models.rendering import *
from models.nerf import *

import metrics
from PIL import Image
import os
from torchvision import transforms as T
from torchvision.utils import save_image

from datasets import dataset_dict
import json


# Let cuDNN benchmark its convolution algorithms and reuse the fastest one
# (PyTorch's documented meaning of this flag).
torch.backends.cudnn.benchmark = True

In [8]:
# Change to your settings...
############################
N_vocab = 800              # number of rows in the appearance / transient embedding tables
encode_appearance = True
N_a = 48                   # width of one appearance embedding vector
encode_transient = True
N_tau = 16                 # width of one transient embedding vector
beta_min = 0.1 # doesn't have effect in testing

# TODO: change to correct optimization ckpt
# `ckpt_path` provides the NeRF weights and the transient embedding;
# `embedding_ckpt_path` provides the separately optimized appearance embedding.
ckpt_path = '/home/lukas/exjobb-nerf/nerf_pl/ckpts/IC_scale3_nerfw_correct_cam_params/last.ckpt'
embedding_ckpt_path = '/home/lukas/exjobb-nerf/ckpts/IC_scale3_nerfw_optimize_appearance_top_bottom/last.ckpt'

N_emb_xyz = 10
N_emb_dir = 4
N_samples = 256
N_importance = 256
use_disp = False
chunk = 1024*32            # number of rays rendered per call in `f`
#############################

# Positional encodings for sample positions and view directions.
embedding_xyz = PosEmbedding(N_emb_xyz-1, N_emb_xyz)
embedding_dir = PosEmbedding(N_emb_dir-1, N_emb_dir)
embeddings = {'xyz': embedding_xyz, 'dir': embedding_dir}

if encode_appearance:
    embedding_a = torch.nn.Embedding(N_vocab, N_a).cuda()
    # load embedding from optimization
    load_ckpt(embedding_a, embedding_ckpt_path, model_name='embedding_a')
    embeddings['a'] = embedding_a

if encode_transient:
    embedding_t = torch.nn.Embedding(N_vocab, N_tau).cuda()
    load_ckpt(embedding_t, ckpt_path, model_name='embedding_t')
    embeddings['t'] = embedding_t

# Input widths shared by both networks (sin/cos per frequency per axis + raw input).
in_channels_xyz = 6*N_emb_xyz+3
in_channels_dir = 6*N_emb_dir+3

# The coarse model is given the appearance width explicitly so that it stays in
# sync with `embedding_a` when N_a is changed, instead of silently relying on
# the constructor default.
# NOTE(review): assumes the coarse checkpoint was trained with in_channels_a == N_a.
nerf_coarse = NeRF('coarse',
                   in_channels_xyz=in_channels_xyz,
                   in_channels_dir=in_channels_dir,
                   encode_appearance=encode_appearance,
                   in_channels_a=N_a,
                   ).cuda()

nerf_fine = NeRF('fine',
                 in_channels_xyz=in_channels_xyz,
                 in_channels_dir=in_channels_dir,
                 encode_appearance=encode_appearance,
                 in_channels_a=N_a,
                 encode_transient=encode_transient,
                 in_channels_t=N_tau,
                 beta_min=beta_min).cuda()

load_ckpt(nerf_coarse, ckpt_path, model_name='nerf_coarse')
load_ckpt(nerf_fine, ckpt_path, model_name='nerf_fine')

models = {'coarse': nerf_coarse, 'fine': nerf_fine}

In [9]:
@torch.no_grad()
def f(rays, ts, **kwargs):
    """Do batched inference on rays using chunk.

    The rays are rendered in slices of at most ``chunk`` (notebook global) rows
    and the per-slice outputs of ``render_rays`` are concatenated along dim 0.

    Args:
        rays: tensor whose first dimension indexes rays.
        ts: per-ray tensor, sliced with the same indices as ``rays``.
        **kwargs: only ``a_embedded`` is used; when present it is sliced per
            chunk like ``rays`` and forwarded to ``render_rays``.

    Returns:
        A ``defaultdict`` mapping each key returned by ``render_rays`` to the
        concatenation of that key's chunks. Empty when ``rays`` has no rows.
    """
    B = rays.shape[0]
    results = defaultdict(list)
    # Bind the name up front so the cleanup below is valid even when B == 0
    # and the loop body never runs.
    rendered_ray_chunks = None
    for i in range(0, B, chunk):
        kwargs_ = {}
        if 'a_embedded' in kwargs:
            kwargs_['a_embedded'] = kwargs['a_embedded'][i:i+chunk]
        rendered_ray_chunks = \
            render_rays(models,
                        embeddings,
                        rays[i:i+chunk],
                        ts[i:i+chunk],
                        N_samples,
                        use_disp,
                        0,
                        0,
                        N_importance,
                        chunk,
                        dataset.white_back,
                        test_time=True,
                        **kwargs_)

        for k, v in rendered_ray_chunks.items():
            results[k].append(v)

    for k, v in results.items():
        results[k] = torch.cat(v, 0)
    # Drop the last reference to the final chunk before releasing cached blocks.
    del rendered_ray_chunks
    torch.cuda.empty_cache()

    return results

In [10]:
# Root folder of the dataset; the results written by later cells go below it too.
root_dir = '/home/jupyter/data/IC_patient_000/segment_001/datasets/'

# Build the evaluation dataset through the project's dataset registry.
phototourism_cls = dataset_dict['phototourism']
dataset = phototourism_cls(
    root_dir,
    split='test_train',
    img_downscale=3,
    use_cache=False,
    tsv_file='IC_scale3_nerfw.tsv',
    exp_name='eval_IC_scale3_nerfw_top_bottom',
    mask_path=f'{root_dir}/mask.png',
)



In [7]:
# Dataset indices of the held-out images, followed by the file each one maps to.
img_ids_test = dataset.img_ids_test
print(len(img_ids_test))
for idx in img_ids_test:
    test_image_path = dataset.image_paths[idx]
    print(test_image_path)

25
00785.png
00749.png
00714.png
00683.png
00646.png
00613.png
00580.png
00548.png
00521.png
00491.png
00462.png
00433.png
00404.png
00164.png
00134.png
00102.png
00028.png
00042.png
00074.png
00206.png
00240.png
00272.png
00301.png
00327.png
00359.png


In [8]:
# Render every test image once and cache the sample plus its CPU-side outputs,
# keyed by dataset index. The evaluation cell further down looks up
# `samples[idx]` / `results[idx]` for *every* id in `img_ids_test`, so the loop
# must run to completion.

results = {}
samples = {}
for idx in img_ids_test:
    print(f"Calculating results for id: {idx}...")
    sample = dataset[idx]
    sample['dataset_path'] = dataset.image_paths[idx]
    # Plain ints keep the shape arithmetic below independent of tensor semantics.
    img_w, img_h = (int(v) for v in sample["img_wh"])

    rays = sample["rays"].cuda()
    ts = sample["ts"].cuda()

    # Flattened ground-truth pixels of the top half of the image. Slicing
    # (instead of torch.split with two equal sizes) also works for odd heights.
    rgbs = sample["rgbs"].view(img_h, img_w, -1)
    rgbs_top = rgbs[: img_h // 2].contiguous().view(-1)

    sample["rgbs_top"] = rgbs_top
    samples[idx] = sample
    # Move the outputs off the GPU right away so they do not accumulate there.
    results[idx] = {k: v.cpu() for k, v in f(rays, ts).items()}

Calculating results for id: 500...


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.00 GiB. GPU 0 has a total capacty of 15.89 GiB of which 1.95 GiB is free. Process 2156156 has 11.47 GiB memory in use. Including non-PyTorch memory, this process has 2.47 GiB memory in use. Of the allocated memory 2.15 GiB is allocated by PyTorch, and 31.64 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# assumes dataset.mask is channel-first (C, H, W), so the permute yields (H, W, C) — TODO confirm
mask = dataset.mask.permute(1,2,0)
# After the permute dim 0 is the height and dim 1 the width.
img_h, img_w = mask.shape[0], mask.shape[1]

# Split along the height; slicing also handles odd heights (the bottom half
# then gets the extra row).
mask_top, mask_bottom = mask[: img_h // 2], mask[img_h // 2 :]

fig, axes = plt.subplots(1, 3, figsize=(15, 8))
panels = (('mask', mask), ('mask top', mask_top), ('mask bottom', mask_bottom))
for ax, (title, panel) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(panel, cmap='gray')
fig.tight_layout()
plt.show()


In [None]:

def masked_img(mask, img):
    """Return ``img`` multiplied element-wise by ``mask``.

    ``img`` may be a torch tensor or a NumPy array; an array is first wrapped
    with ``torch.from_numpy``. The result is always a tensor.
    """
    if not torch.is_tensor(img):
        img = torch.from_numpy(img)
    return img * mask

def plot_print_results(sample, result,mask, encode_transient):
    """Save, plot and score the prediction for one test image.

    Reads the notebook globals ``root_dir`` (output location) and ``beta_min``
    (only in the transient branch).

    Args:
        sample: mapping with ``'img_wh'`` (width, height), ``'rgbs'``,
            ``'rgbs_top'`` and ``'dataset_path'``; ``'valid_mask'`` is also
            read when ``encode_transient`` is true.
        result: mapping with ``'rgb_fine'`` and ``'depth_fine'``; when
            ``encode_transient`` is true also ``'beta'``, ``'rgb_fine_static'``,
            ``'_rgb_fine_transient'`` and ``'depth_fine_static'``.
        mask: tensor multiplied onto (img_h, img_w, 3) images; its first
            ``img_h // 2`` rows are used as the top-half mask.
            # assumes layout (H, W, C) — TODO confirm against caller
        encode_transient: when true, additionally plot the static/transient
            decomposition.

    Returns:
        ``(psnr, psnr_mask, psnr_unmask, ssim, ssim_mask, lpips, lpips_mask)``
        computed on the top half of the image.
    """
    img_w, img_h = (int(v) for v in sample['img_wh'])
    half_h = img_h // 2

    # Derive the top-half mask from the argument rather than from a global, so
    # the function only depends on what it is given.
    mask_top = mask[:half_h]

    # get whole GT and top half
    img_gt = sample['rgbs'].reshape(img_h, img_w, 3)
    img_gt_top = sample['rgbs_top'].reshape(half_h, img_w, 3)

    # get whole pred and its top half
    img_pred = result['rgb_fine'].reshape(img_h, img_w, 3)
    img_pred_top = img_pred[:half_h]

    depth_pred = result['depth_fine'].reshape(img_h, img_w)

    # mask GTs and preds
    img_gt = masked_img(mask,img_gt)
    img_gt_top = masked_img(mask_top,img_gt_top)

    masked_pred = masked_img(mask,img_pred)
    masked_pred_top = masked_img(mask_top,img_pred_top)

    # save results
    img_dir = os.path.join(root_dir, 'results', 'images')
    depth_dir = os.path.join(root_dir, 'results', 'depth')

    # create dirs if not exist
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(depth_dir, exist_ok=True)

    img_name = sample['dataset_path']
    # Strip the image extension so the depth file is "<stem>_depth.png".
    depth_name = f'{os.path.splitext(img_name)[0]}_depth.png'

    # save_image expects channel-first images, hence the permute.
    save_image(img_pred.permute(2, 0, 1), os.path.join(img_dir, img_name))
    # Depth is not confined to [0, 1]; min/max-normalise it so the saved image
    # is not clamped to a constant.
    save_image(depth_pred, os.path.join(depth_dir, depth_name), normalize=True)


    # plot results
    plt.subplots(figsize=(15, 8))
    plt.axis("off")
    plt.tight_layout()

    plt.subplot(241)
    plt.title("GT masked")
    plt.imshow(img_gt)

    plt.subplot(242)
    plt.title("pred")
    plt.imshow(img_pred)

    plt.subplot(243)
    plt.title("pred with mask")
    plt.imshow(masked_pred)

    plt.subplot(244)
    plt.title("depth")
    plt.imshow(visualize_depth(depth_pred).permute(1, 2, 0))

    plt.subplot(245)
    plt.title("GT top")
    plt.imshow(img_gt_top)

    plt.subplot(246)
    plt.title("pred top")
    plt.imshow(img_pred_top)

    plt.subplot(247)
    plt.title("pred top masked")
    plt.imshow(masked_pred_top)

    plt.show()

    # valid mask: True where the top-half mask is non-zero, repeated 3x on the last dim
    valid_mask = (mask_top != 0).repeat(1,1,3)

    # calculate scores on top half of image
    psnr_ = metrics.psnr(img_gt_top, img_pred_top).item()
    ssim_ = metrics.ssim(img_gt_top, img_pred_top).item()
    lpips_ = metrics.lpips_score(img_gt_top, img_pred_top).item()

    psnr_mask_ = metrics.psnr(img_gt_top, masked_pred_top).item()
    psnr_unmask_ = metrics.psnr(img_gt_top, masked_pred_top, valid_mask=valid_mask).item()
    ssim_mask_ = metrics.ssim(img_gt_top, masked_pred_top).item()
    lpips_mask_ = metrics.lpips_score(img_gt_top, masked_pred_top).item()

    print(f'{"#"*15} Scores for real predicition {"#"*15}')
    print('PSNR between GT and pred:', psnr_)
    print('SSIM between GT and pred:', ssim_)
    print('LPIPS between GT and pred:',lpips_, '\n') 

    print(f'{"#"*15} Scores for masked predicition {"#"*15}')
    print('PSNR between GT and pred:', psnr_mask_)
    print('UNMASKED PSNR between GT and pred:', psnr_unmask_)
    print('SSIM between GT and pred:', ssim_mask_)
    print('LPIPS between GT and pred:',lpips_mask_) 



    if encode_transient:
        print('Decomposition--------------------------------------------' + 
            '---------------------------------------------------------' +
            '---------------------------------------------------------' + 
            '---------------------------------------------------------')
        # All per-pixel outputs are reshaped with the same (height, width) as above.
        valid_mask_static = sample['valid_mask'].reshape(img_h, img_w)
        beta = result['beta'].reshape(img_h, img_w)
        img_pred_static = result['rgb_fine_static'].reshape(img_h, img_w, 3)
        img_pred_transient = result['_rgb_fine_transient'].reshape(img_h, img_w, 3)
        depth_pred_static = result['depth_fine_static'].reshape(img_h, img_w)
        plt.subplots(figsize=(15, 8))
        plt.tight_layout()
        plt.subplot(231)
        plt.title('static')
        plt.imshow(img_pred_static)
        plt.subplot(232)
        plt.title('transient')
        plt.imshow(img_pred_transient)
        plt.subplot(233)
        plt.title('uncertainty (beta)')
        plt.imshow(beta-beta_min, cmap='gray')
        plt.subplot(234)
        plt.title('static depth')
        plt.imshow(visualize_depth(depth_pred_static).permute(1,2,0))
        plt.subplot(235)
        plt.title('valid mask')
        plt.imshow(valid_mask_static, cmap='gray')
        plt.show()

    return psnr_,psnr_mask_,psnr_unmask_, ssim_, ssim_mask_, lpips_, lpips_mask_

In [None]:
# Per-image scores on the raw prediction ...
psnrs = []
ssims = []
lpips_scores = []

# ... and on the masked prediction.
psnrs_mask = []
psnrs_unmask = []
ssims_mask = []
lpips_scores_mask = []

# don't evaluate on images with transient objects
# skipped_images = [0,1,2,3]
skipped_images = []  # positions within img_ids_test (not dataset ids) to leave out
for i, idx in enumerate(img_ids_test):
    if i in skipped_images:
        continue
    sample = samples[idx]
    result = results[idx]

    psnr, psnr_mask, psnr_unmask, ssim, ssim_mask, lpips, lpips_mask = plot_print_results(
        sample=sample, result=result, mask=mask, encode_transient=False)

    psnrs.append(psnr)
    psnrs_unmask.append(psnr_unmask)
    ssims.append(ssim)
    lpips_scores.append(lpips)

    psnrs_mask.append(psnr_mask)
    ssims_mask.append(ssim_mask)
    lpips_scores_mask.append(lpips_mask)


def _summary(values, ndigits):
    """Return (mean, median, max, min) of ``values``, each rounded to ``ndigits``."""
    return tuple(round(stat(values), ndigits)
                 for stat in (np.mean, np.median, np.max, np.min))


# Masked

# PSNR
mean_psnr, median_psnr, max_psnr, min_psnr = _summary(psnrs_mask, 2)
mean_psnr_unmask, median_psnr_unmask, max_psnr_unmask, min_psnr_unmask = _summary(psnrs_unmask, 2)

# SSIM
mean_ssim, median_ssim, max_ssim, min_ssim = _summary(ssims_mask, 3)

# LPIPS -- aggregate the collected list, not the scalar left over from the last
# loop iteration.
mean_lpips, median_lpips, max_lpips, min_lpips = _summary(lpips_scores_mask, 4)


print(f'{"#"*15} Mean scores for masked prediction {"#"*15}')
print("Mean PSNR: ", mean_psnr)
print("Mean UNMASKED PSNR: ", mean_psnr_unmask)
print("Mean SSIM: ", mean_ssim)
print("Mean LPIPS: ", mean_lpips)

data = {
    'mean_psnr': mean_psnr,
    'median_psnr': median_psnr,
    'max_psnr': max_psnr,
    'min_psnr': min_psnr,
    'mean_psnr_unmasked': mean_psnr_unmask,
    'median_psnr_unmasked': median_psnr_unmask,
    'max_psnr_unmasked': max_psnr_unmask,
    'min_psnr_unmasked': min_psnr_unmask,
    'mean_ssim': mean_ssim,
    'median_ssim': median_ssim,
    'max_ssim': max_ssim,
    'min_ssim': min_ssim,
    'mean_lpips': mean_lpips,
    'median_lpips': median_lpips,
    'max_lpips': max_lpips,
    'min_lpips': min_lpips
}

json_obj = json.dumps(data)
results_file = f'results_{dataset.exp_name}.json' if not skipped_images else f'results_{dataset.exp_name}_skipped_images.json'
# Make sure the target folder exists before writing the summary.
results_dir = os.path.join(root_dir, 'results')
os.makedirs(results_dir, exist_ok=True)
with open(os.path.join(results_dir, results_file), 'w') as file:
    file.write(json_obj)





In [None]:
# Clear memory
#del samples
#del results
#torch.cuda.empty_cache()

# Interpolate embedding for appearance change (Fig 8 in the paper)
### left image: train number 53; right image: train number 111
### The pose is fixed to that of the right image, and the appearance embedding is interpolated

In [None]:
# left_sample = dataset[53]
# right_sample = dataset[111]

# right_rays = right_sample['rays'].cuda()
# right_ts = right_sample['ts'].cuda()
# left_a_embedded = embedding_a(left_sample['ts'][0].cuda())
# right_a_embedded = embedding_a(right_sample['ts'].cuda())

# results_list = [left_sample]

# for i in range(5):
#     kwargs = {'a_embedded': right_a_embedded*i/4+left_a_embedded*(1-i/4)}
#     results_list += [f(right_rays, right_ts, **kwargs)]

# results_list += [right_sample]

In [None]:
# plt.subplots(figsize=(20, 10))
# for i, results in enumerate(results_list):
#     if i == 0:
#         img_wh = tuple(results['img_wh'].numpy())
#         left_GT = results['rgbs'].view(img_wh[1], img_wh[0], 3).cpu().numpy()
#         plt.subplot(241)
#         plt.axis('off')
#         plt.title('left GT')
#         plt.imshow(left_GT)
#     elif i == 6:
#         img_wh = tuple(results['img_wh'].numpy())
#         right_GT = results['rgbs'].view(img_wh[1], img_wh[0], 3).cpu().numpy()
#         plt.subplot(247)
#         plt.axis('off')
#         plt.title('right GT')
#         plt.imshow(right_GT)
#     else:
#         img_wh = tuple(right_sample['img_wh'].numpy())
#         img_pred = results['rgb_fine'].view(img_wh[1], img_wh[0], 3).cpu().numpy()
#         plt.subplot(2, 4, i+1)
#         plt.axis('off')
#         plt.imshow(img_pred)
# plt.tight_layout()
# plt.show()

In [None]:
# Build the channel-last mask from the dataset instead of permuting the current
# `mask` again: re-permuting a mask that is already channel-last would scramble
# its axes, and this form gives the same result however often the cell is run.
# assumes dataset.mask is channel-first (C, H, W) — TODO confirm
mask = dataset.mask.permute(1,2,0)

In [None]:
# Load one source frame, downscale it by 3 and show it next to the mask and
# the masked frame.
# The context manager closes the file handle; convert() returns an independent
# in-memory copy, so `img` stays usable afterwards.
with Image.open(os.path.join(root_dir, "images", '00000.png')) as raw_img:
    img = raw_img.convert("RGB")

img_w, img_h = img.size
img_w = img_w // 3
img_h = img_h // 3
img = img.resize((img_w, img_h), Image.LANCZOS)

# ToTensor gives a channel-first tensor; permute to channel-last for plotting
# and for multiplying with the channel-last `mask`.
img_tensor = T.ToTensor()(img).permute(1,2,0)
print(img_tensor.shape)

fig, axes = plt.subplots(1, 3, figsize=(15, 8))
panels = (('image', img), ('mask', mask), ('masked image', img_tensor*mask))
for ax, (title, panel) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(panel)
# Lay out only after the axes exist; calling it earlier creates an empty figure.
fig.tight_layout()
plt.show()

