In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
from utils import load_ckpt, visualize_depth
from collections import defaultdict
import matplotlib.pyplot as plt

from models.rendering import *
from models.nerf import *

import metrics
from PIL import Image
import os
from torchvision import transforms as T

from datasets import dataset_dict
import json

# Enable the cuDNN autotuner: it benchmarks candidate kernels and caches the fastest per input shape.
# NOTE(review): the NeRF networks here are presumably MLP-only, so this likely has little effect.
torch.backends.cudnn.benchmark = True

In [None]:
# Change to your settings...
# These values are read as notebook globals by the cells below (e.g. `f` reads N_samples,
# N_importance, use_disp and chunk; the plotting code reads beta_min).
############################
N_vocab = 800  # rows of the appearance/transient nn.Embedding tables; presumably must match training (checkpoint shapes) — confirm
encode_appearance = True  # build + load the appearance embedding ('a') and pass it to the fine NeRF
N_a = 48  # appearance embedding width (NeRF in_channels_a)
encode_transient = True  # build + load the transient embedding ('t') and pass it to the fine NeRF
N_tau = 16  # transient embedding width (NeRF in_channels_t)
beta_min = 0.1 # doesn't have effect in testing; still passed to NeRF('fine') and subtracted from beta in the uncertainty plot
ckpt_path = 'ckpts/IC_patient000_segment001_scale3_nerfw_mask/last.ckpt'

N_emb_xyz = 10  # number of positional-encoding frequencies for xyz (second PosEmbedding argument)
N_emb_dir = 4  # number of positional-encoding frequencies for view directions
N_samples = 256  # passed to render_rays; presumably coarse samples per ray
N_importance = 256  # passed to render_rays; presumably extra (fine) importance samples per ray
use_disp = False  # passed to render_rays; presumably toggles sampling in disparity space
chunk = 1024*32  # rays per render_rays call in `f`
#############################

# Positional encoders; the first argument is presumably the max log2 frequency — confirm in models.rendering.
embedding_xyz = PosEmbedding(N_emb_xyz-1, N_emb_xyz)
embedding_dir = PosEmbedding(N_emb_dir-1, N_emb_dir)
embeddings = {'xyz': embedding_xyz, 'dir': embedding_dir}

# Learned per-image embeddings are restored from the same checkpoint as the networks.
if encode_appearance:
    embedding_a = torch.nn.Embedding(N_vocab, N_a).cuda()
    load_ckpt(embedding_a, ckpt_path, model_name='embedding_a')
    embeddings['a'] = embedding_a

if encode_transient:
    embedding_t = torch.nn.Embedding(N_vocab, N_tau).cuda()
    load_ckpt(embedding_t, ckpt_path, model_name='embedding_t')
    embeddings['t'] = embedding_t


# Input widths 6*N+3: presumably sin and cos per frequency for 3 coordinates, plus the raw input.
nerf_coarse = NeRF('coarse',
                   in_channels_xyz=6*N_emb_xyz+3,
                   in_channels_dir=6*N_emb_dir+3).cuda()

# Only the fine network receives the appearance/transient options.
nerf_fine = NeRF('fine',
                 in_channels_xyz=6*N_emb_xyz+3,
                 in_channels_dir=6*N_emb_dir+3,
                 encode_appearance=encode_appearance,
                 in_channels_a=N_a,
                 encode_transient=encode_transient,
                 in_channels_t=N_tau,
                 beta_min=beta_min).cuda()

load_ckpt(nerf_coarse, ckpt_path, model_name='nerf_coarse')
load_ckpt(nerf_fine, ckpt_path, model_name='nerf_fine')

models = {'coarse': nerf_coarse, 'fine': nerf_fine}

In [None]:
@torch.no_grad()
def f(rays, ts, **kwargs):
    """Render `rays` in chunks of `chunk` rays and concatenate the per-chunk outputs.

    Args:
        rays: tensor whose first dimension indexes rays; sliced into chunks of `chunk`.
        ts: per-ray embedding indices, sliced alongside `rays`.
        **kwargs: only 'a_embedded' is used. When present it is sliced per chunk like
            `rays` and forwarded to `render_rays` (e.g. to override the appearance
            embedding). Any other key is ignored.

    Returns:
        defaultdict mapping each output key of `render_rays` to the dim-0 concatenation
        of its per-chunk tensors. It is empty when `rays` has no rows.

    Reads notebook globals: models, embeddings, N_samples, use_disp, N_importance,
    chunk, and dataset (for `white_back`).
    """
    B = rays.shape[0]
    results = defaultdict(list)
    for i in range(0, B, chunk):
        kwargs_ = {}
        if 'a_embedded' in kwargs:
            kwargs_['a_embedded'] = kwargs['a_embedded'][i:i+chunk]
        # NOTE(review): the two literal 0s are presumably perturb and noise_std
        # (deterministic sampling at test time) — confirm against render_rays' signature.
        rendered_ray_chunks = \
            render_rays(models,
                        embeddings,
                        rays[i:i+chunk],
                        ts[i:i+chunk],
                        N_samples,
                        use_disp,
                        0,
                        0,
                        N_importance,
                        chunk,
                        dataset.white_back,
                        test_time=True,
                        **kwargs_)

        for k, v in rendered_ray_chunks.items():
            results[k].append(v)
        # Drop the per-chunk dict inside the loop. Deleting it after the loop raised
        # UnboundLocalError when `rays` was empty and the loop never ran.
        del rendered_ray_chunks

    for k, v in results.items():
        results[k] = torch.cat(v, 0)
    torch.cuda.empty_cache()

    return results

In [None]:
root_dir = '/home/jupyter/data/IC_patient_000/segment_001/datasets/test/'

# 'test_train' split: the cells below evaluate on `dataset.img_ids_train`.
# NOTE(review): img_downscale presumably has to match the checkpoint's training setting.
PhototourismDataset = dataset_dict['phototourism']
dataset = PhototourismDataset(
    root_dir,
    split='test_train',
    img_downscale=3,
    use_cache=False,
    exp_name='IC_patient000_segment001_scale3_nerfw_mask',
)



In [None]:
# Evaluate on the first few training-image ids of the loaded split.
n_eval_images = 5
img_ids_test = dataset.img_ids_train[:n_eval_images]
len(img_ids_test)


In [None]:
mask_dir = '/home/jupyter/data/IC_patient_000/segment_001/datasets/train/mask.png'
# NOTE(review): presumably must equal the dataset's img_downscale (3) so that the mask
# matches the rendered image size — keep the two in sync.
mask_downscale = 3

# Open via a context manager so the file handle is released as soon as the pixels are
# loaded. convert() returns a new, fully loaded image that outlives the file.
with Image.open(mask_dir) as mask_file:
    mask = mask_file.convert('L')

mask_w, mask_h = mask.size
mask_w = mask_w // mask_downscale
mask_h = mask_h // mask_downscale

mask = mask.resize((mask_w, mask_h), Image.Resampling.LANCZOS)
# ToTensor turns the single-channel image into a float tensor in [0, 1] of shape (1, H, W).
mask = T.ToTensor()(mask)



In [None]:

def masked_img(mask, img):
    """Multiply a channels-last image by a channels-first mask.

    Args:
        mask: tensor of shape (1, H, W) or (C, H, W), e.g. the output of ``T.ToTensor()``.
            It is permuted to (H, W, 1|C) so that it broadcasts over the image.
        img: image of shape (H, W, C), given as a torch tensor or a numpy array. A numpy
            array is wrapped with ``torch.from_numpy``, which shares its memory.

    Returns:
        Tensor ``img * mask`` of shape (H, W, C).
    """
    if not torch.is_tensor(img):
        img = torch.from_numpy(img)
    hwc_mask = mask.permute(1, 2, 0)
    return img * hwc_mask



def _image_quality(img_gt, img_pred):
    """Return (PSNR, SSIM, LPIPS) of `img_pred` against `img_gt` as Python floats."""
    psnr = metrics.psnr(img_gt, img_pred).item()
    ssim = metrics.ssim(img_gt, img_pred).item()
    lpips = metrics.lpips_score(img_gt, img_pred).item()
    return psnr, ssim, lpips


def plot_print_results(sample, result, mask, encode_transient):
    """Plot one rendered view next to its ground truth and print image-quality scores.

    Args:
        sample: dataset item. Reads 'img_wh' (width, height) and 'rgbs'; when
            `encode_transient` is true, also reads 'valid_mask'.
        result: output of `f` for this sample. Reads 'rgb_fine' and 'depth_fine'; when
            `encode_transient` is true, also reads 'beta', 'rgb_fine_static',
            '_rgb_fine_transient' and 'depth_fine_static'.
        mask: mask tensor passed to `masked_img` (e.g. the (1, H, W) tensor built from mask.png).
        encode_transient: if true, also plot the static/transient decomposition.

    Returns:
        (psnr, psnr_mask, ssim, ssim_mask, lpips, lpips_mask) as Python floats. The
        `_mask` scores compare the ground truth with the masked prediction.
    """
    img_wh = tuple(sample['img_wh'].numpy())
    img_gt = sample['rgbs'].view(img_wh[1], img_wh[0], 3)
    img_pred = result['rgb_fine'].view(img_wh[1], img_wh[0], 3).cpu().numpy()
    depth_pred = result['depth_fine'].view(img_wh[1], img_wh[0])

    masked_pred = masked_img(mask, img_pred).numpy()

    # Use plt.figure rather than plt.subplots: plt.subplots() adds a full-figure Axes that
    # the grid panels would be drawn on top of (recent Matplotlib no longer removes it).
    plt.figure(figsize=(15, 8))
    plt.subplot(241)
    plt.title('GT')
    plt.imshow(img_gt)
    plt.subplot(242)
    plt.title('pred')
    plt.imshow(img_pred)
    plt.subplot(243)
    plt.title('pred with mask')
    plt.imshow(masked_pred)
    plt.subplot(244)
    plt.title('depth')
    plt.imshow(visualize_depth(depth_pred).permute(1, 2, 0))
    # tight_layout only arranges Axes that already exist, so call it after plotting.
    plt.tight_layout()
    plt.show()

    psnr, ssim, lpips = _image_quality(img_gt, img_pred)
    # NOTE(review): the GT is not masked here; confirm this is intended (e.g. GT is
    # already black outside the mask).
    psnr_mask, ssim_mask, lpips_mask = _image_quality(img_gt, masked_pred)

    print(f'{"#"*15} Scores for real prediction {"#"*15}')
    print('PSNR between GT and pred:', psnr)
    print('SSIM between GT and pred:', ssim)
    print('LPIPS between GT and pred:', lpips, '\n')

    print(f'{"#"*15} Scores for masked prediction {"#"*15}')
    print('PSNR between GT and pred:', psnr_mask)
    print('SSIM between GT and pred:', ssim_mask)
    print('LPIPS between GT and pred:', lpips_mask)

    if encode_transient:
        print('Decomposition--------------------------------------------' + 
            '---------------------------------------------------------' +
            '---------------------------------------------------------' + 
            '---------------------------------------------------------')
        valid_mask_static = sample['valid_mask'].view(img_wh[1], img_wh[0])
        beta = result['beta'].view(img_wh[1], img_wh[0]).cpu().numpy()
        img_pred_static = result['rgb_fine_static'].view(img_wh[1], img_wh[0], 3).cpu().numpy()
        img_pred_transient = result['_rgb_fine_transient'].view(img_wh[1], img_wh[0], 3).cpu().numpy()
        depth_pred_static = result['depth_fine_static'].view(img_wh[1], img_wh[0])
        plt.figure(figsize=(15, 8))
        plt.subplot(231)
        plt.title('static')
        plt.imshow(img_pred_static)
        plt.subplot(232)
        plt.title('transient')
        plt.imshow(img_pred_transient)
        plt.subplot(233)
        plt.title('uncertainty (beta)')
        # beta_min is the notebook-level config value.
        plt.imshow(beta-beta_min, cmap='gray')
        plt.subplot(234)
        plt.title('static depth')
        plt.imshow(visualize_depth(depth_pred_static).permute(1, 2, 0))
        plt.subplot(235)
        plt.title('valid mask')
        plt.imshow(valid_mask_static, cmap='gray')
        plt.tight_layout()
        plt.show()

    return psnr, psnr_mask, ssim, ssim_mask, lpips, lpips_mask


In [None]:
# Render every evaluation view once. The dataset samples are kept alongside the outputs so
# that the scoring cell can reuse them without re-rendering.
results = {}
samples = {}
for img_id in img_ids_test:
    print(f'Calculating results for id: {img_id}...')
    samples[img_id] = sample = dataset[img_id]

    rays = sample['rays'].cuda()
    ts = sample['ts'].cuda()
    results[img_id] = f(rays, ts)



In [None]:
# Score every rendered view. plot_print_results displays each prediction and returns its
# (psnr, psnr_mask, ssim, ssim_mask, lpips, lpips_mask) scores.
psnrs = []
ssims = []
lpips_scores = []

psnrs_mask = []
ssims_mask = []
lpips_scores_mask = []

for img_id in img_ids_test:
    psnr, psnr_mask, ssim, ssim_mask, lpips, lpips_mask = plot_print_results(
        sample=samples[img_id], result=results[img_id], mask=mask, encode_transient=False)

    psnrs.append(psnr)
    ssims.append(ssim)
    lpips_scores.append(lpips)

    psnrs_mask.append(psnr_mask)
    ssims_mask.append(ssim_mask)
    lpips_scores_mask.append(lpips_mask)


def summarize_scores(scores, ndigits):
    """Return mean, median, max and min of `scores`, each rounded to `ndigits` decimals.

    Args:
        scores: non-empty sequence of per-image scores.
        ndigits: number of decimals passed to `round`.

    Returns:
        dict with keys 'mean', 'median', 'max', 'min', in that order.
    """
    return {
        'mean': round(np.mean(scores), ndigits),
        'median': round(np.median(scores), ndigits),
        'max': round(np.max(scores), ndigits),
        'min': round(np.min(scores), ndigits),
    }


# Summaries of the masked scores, flattened to '<stat>_<metric>' keys: mean/median/max/min
# for psnr, then ssim, then lpips. Each metric is summarized from its own per-image list.
data = {}
for metric_name, scores, ndigits in (('psnr', psnrs_mask, 2),
                                     ('ssim', ssims_mask, 3),
                                     ('lpips', lpips_scores_mask, 3)):
    for stat, value in summarize_scores(scores, ndigits).items():
        data[f'{stat}_{metric_name}'] = value

# print('\n',f'{"#"*15} Mean scores for true prediction {"#"*15}')
# print("Mean PSNR: ", np.mean(psnrs))
# print("Mean SSIM: ", np.mean(ssims))
# print("Mean LPIPS: ", np.mean(lpips_scores), '\n')

print(f'{"#"*15} Mean scores for masked prediction {"#"*15}')
print("Mean PSNR: ", data['mean_psnr'])
print("Mean SSIM: ", data['mean_ssim'])
print("Mean LPIPS: ", data['mean_lpips'])

with open(os.path.join(root_dir, 'results_nerfw.json'), 'w') as file:
    json.dump(data, file)





In [None]:
# Clear memory: drop the rendered outputs, then ask PyTorch's CUDA caching allocator to
# release its unused cached blocks.
# NOTE: re-running the scoring cell after this requires re-running the rendering cell first.
del results
torch.cuda.empty_cache()

# Interpolate embedding for appearance change (Fig 8 in the paper)
### left image: train number 53; right image: train number 111
### The pose is fixed to that of the right image, and the appearance embedding is interpolated

In [None]:
# left_sample = dataset[53]
# right_sample = dataset[111]

# right_rays = right_sample['rays'].cuda()
# right_ts = right_sample['ts'].cuda()
# left_a_embedded = embedding_a(left_sample['ts'][0].cuda())
# right_a_embedded = embedding_a(right_sample['ts'].cuda())

# results_list = [left_sample]

# for i in range(5):
#     kwargs = {'a_embedded': right_a_embedded*i/4+left_a_embedded*(1-i/4)}
#     results_list += [f(right_rays, right_ts, **kwargs)]

# results_list += [right_sample]

In [None]:
# plt.subplots(figsize=(20, 10))
# for i, results in enumerate(results_list):
#     if i == 0:
#         img_wh = tuple(results['img_wh'].numpy())
#         left_GT = results['rgbs'].view(img_wh[1], img_wh[0], 3).cpu().numpy()
#         plt.subplot(241)
#         plt.axis('off')
#         plt.title('left GT')
#         plt.imshow(left_GT)
#     elif i == 6:
#         img_wh = tuple(results['img_wh'].numpy())
#         right_GT = results['rgbs'].view(img_wh[1], img_wh[0], 3).cpu().numpy()
#         plt.subplot(247)
#         plt.axis('off')
#         plt.title('right GT')
#         plt.imshow(right_GT)
#     else:
#         img_wh = tuple(right_sample['img_wh'].numpy())
#         img_pred = results['rgb_fine'].view(img_wh[1], img_wh[0], 3).cpu().numpy()
#         plt.subplot(2, 4, i+1)
#         plt.axis('off')
#         plt.imshow(img_pred)
# plt.tight_layout()
# plt.show()