In [1]:
import argparse
import os
import time

import imageio
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("Agg")

#matplotlib.use("TkAgg")

import numpy as np
import torch
import torchvision
import yaml
from tqdm import tqdm
#from nerf-pytorch import


from nerf import (
    CfgNode,
    get_ray_bundle,
    load_flame_data,
    load_llff_data,
    models,
    get_embedding_function,
    run_one_iter_of_nerf,
    meshgrid_xy
)

### read config file

In [2]:
# Subject config: drives both dataset loading (cfg.dataset) and model construction (cfg.models).
config = 'nerface_dataset/person_2/person_2_config.yml'
cfg = None
with open(config, "r") as f:
    # The config is plain YAML data, so safe_load suffices (no Python-object tags needed).
    cfg_dict = yaml.safe_load(f)
    cfg = CfgNode(cfg_dict)
cfg.dataset.type.lower()

'blender'

In [3]:
# Load images, camera poses and intrinsics for the configured dataset type.
images, poses, render_poses, hwf = None, None, None, None
i_train, i_val, i_test = None, None, None
dataset_type = cfg.dataset.type.lower()
if dataset_type == "blender":
    # NeRFace/FLAME data. With test=True the loader's split output is used directly as the
    # test indices.
    images, poses, render_poses, hwf, i_split, expressions, _, _ = load_flame_data(
        cfg.dataset.basedir,
        half_res=cfg.dataset.half_res,
        testskip=cfg.dataset.testskip,
        test=True,
    )
    i_test = i_split
    H, W, focal = hwf
    H, W = int(H), int(W)
elif dataset_type == "llff":
    images, poses, bds, render_poses, i_test = load_llff_data(
        cfg.dataset.basedir, factor=cfg.dataset.downsample_factor,
    )
    hwf = poses[0, :3, -1]
    H, W, focal = hwf
    # Cast here as well: later cells use H and W as integer image sizes.
    H, W = int(H), int(W)
    hwf = [H, W, focal]
    render_poses = torch.from_numpy(render_poses)
    # NOTE(review): this branch does not define `expressions`, which the rendering cells need.
else:
    raise ValueError(f"Unsupported dataset type: {cfg.dataset.type!r}")


starting data loading
Done with data loading


In [4]:
# Device on which to run: the GPU when one is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# Positional encodings for positions (xyz) and, optionally, view directions. Both are
# configured from the coarse model's section of the config.
encode_position_fn = get_embedding_function(
    num_encoding_functions=cfg.models.coarse.num_encoding_fn_xyz,
    include_input=cfg.models.coarse.include_input_xyz,
    log_sampling=cfg.models.coarse.log_sampling_xyz,
)
encode_direction_fn = (
    get_embedding_function(
        num_encoding_functions=cfg.models.coarse.num_encoding_fn_dir,
        include_input=cfg.models.coarse.include_input_dir,
        log_sampling=cfg.models.coarse.log_sampling_dir,
    )
    if cfg.models.coarse.use_viewdirs
    else None
)

# Coarse-resolution model; the class is looked up by name in `models`.
model_coarse = getattr(models, cfg.models.coarse.type)(
    num_encoding_fn_xyz=cfg.models.coarse.num_encoding_fn_xyz,
    num_encoding_fn_dir=cfg.models.coarse.num_encoding_fn_dir,
    include_input_xyz=cfg.models.coarse.include_input_xyz,
    include_input_dir=cfg.models.coarse.include_input_dir,
    use_viewdirs=cfg.models.coarse.use_viewdirs,
    num_layers=cfg.models.coarse.num_layers,
    hidden_size=cfg.models.coarse.hidden_size,
    include_expression=True,
)
model_coarse.to(device)

# Fine-resolution model, built only when the config declares one.
# NOTE(review): num_layers / hidden_size are read from the *coarse* section, presumably to
# match how the checkpoint was trained. Confirm this before switching to cfg.models.fine.
model_fine = None
if hasattr(cfg.models, "fine"):
    model_fine = getattr(models, cfg.models.fine.type)(
        num_encoding_fn_xyz=cfg.models.fine.num_encoding_fn_xyz,
        num_encoding_fn_dir=cfg.models.fine.num_encoding_fn_dir,
        include_input_xyz=cfg.models.fine.include_input_xyz,
        include_input_dir=cfg.models.fine.include_input_dir,
        use_viewdirs=cfg.models.fine.use_viewdirs,
        num_layers=cfg.models.coarse.num_layers,
        hidden_size=cfg.models.coarse.hidden_size,
        include_expression=True,
    )
    model_fine.to(device)

cuda


### Read checkpoint

In [68]:
# NOTE(review): the config/dataset cells use person_2, but this checkpoint is person_3. Confirm
# they belong together: the latent-code table comes from the checkpoint, idx_map from the dataset.
ckpt = 'logs/person_3/checkpoint400000.ckpt'
# map_location places every stored tensor on `device`, so a GPU-saved checkpoint also loads on CPU.
checkpoint = torch.load(ckpt, map_location=device)
checkpoint.keys()

dict_keys(['iter', 'model_coarse_state_dict', 'model_fine_state_dict', 'optimizer_state_dict', 'loss', 'psnr', 'background', 'latent_codes'])

In [69]:
# Restore network weights and the auxiliary tensors stored alongside them.
model_coarse.load_state_dict(checkpoint["model_coarse_state_dict"])
fine_state_dict = checkpoint.get("model_fine_state_dict")
if fine_state_dict:
    try:
        model_fine.load_state_dict(fine_state_dict)
    except (AttributeError, RuntimeError):
        # RuntimeError: key/shape mismatch. AttributeError: no fine model was built (model_fine is None).
        print(
            "The checkpoint has a fine-level model, but it could "
            "not be loaded (possibly due to a mismatched config file)."
        )

# Image size / intrinsics stored in the checkpoint override the dataset values.
if "height" in checkpoint:
    hwf[0] = checkpoint["height"]
if "width" in checkpoint:
    hwf[1] = checkpoint["width"]
if "focal_length" in checkpoint:
    hwf[2] = checkpoint["focal_length"]

# Defaults, so later cells can rely on these names even if the checkpoint lacks the entries.
background = None
latent_codes = None
use_latent_code = False
if checkpoint.get("background") is not None:
    # Tensor.to() returns a new tensor; it does not move the tensor in place.
    background = checkpoint["background"].to(device)
    print("loaded background with shape ", background.shape)
if checkpoint.get("latent_codes") is not None:
    latent_codes = checkpoint["latent_codes"].to(device)
    use_latent_code = True
    print("loading index map for latent codes...")
    # Column 1 of idx_map is used by the render loop as a row index into latent_codes.
    idx_map = np.load(os.path.join(cfg.dataset.basedir, "index_map.npy")).astype(int)
    print("loaded latent codes from checkpoint, with shape ", latent_codes.shape)
    print("idx_map shape", idx_map.shape)

# Switch both networks to evaluation mode for rendering.
model_coarse.eval()
if model_fine is not None:
    model_fine.eval()

loaded background with shape  torch.Size([512, 512, 3])
loading index map for latent codes...
loaded latent codes from checkpoint, with shape  torch.Size([5435, 32])
idx_map shape (5318, 2)


In [70]:
# Inspect the index map: the render loop uses column 1 as the row index into `latent_codes`.
idx_map

array([[   0, 3141],
       [   1,  571],
       [   2, 1928],
       ...,
       [5315, 2805],
       [5316, 3215],
       [5317, 3059]])

In [71]:
# Optionally replace the checkpoint background with a frame from the dataset's bg/ folder.
replace_background = True
if replace_background:
    from PIL import Image
    background_path = os.path.join(cfg.dataset.basedir, 'bg', '00050.png')
    with Image.open(background_path) as background_file:
        # Drop any alpha channel: the tensor is later flattened with .view(-1, 3).
        background_image = background_file.convert("RGB")
    # PIL sizes are (width, height).
    background_image.thumbnail((W, H))
    # float32 keeps the model's float32 outputs from being promoted to float64 when combined.
    background = torch.from_numpy(np.array(background_image, dtype=np.float32)).to(device)
    background = background / 255
    if background.shape[:2] != (H, W):
        raise ValueError(
            f"background {tuple(background.shape[:2])} does not match image size {(H, W)}"
        )
    print('loaded custom background of shape', background.shape)

render_poses = render_poses.float().to(device)

loaded custom background of shape torch.Size([512, 512, 3])


### Create directory to save images to.

In [72]:
# Output directories for rendered frames and normal maps (makedirs also creates `savedir`).
savedir = 'renders/person_2_rendered_frames_jupyter'
os.makedirs(os.path.join(savedir, "normals"), exist_ok=True)

# Per-image render times, appended by the render loop.
times_per_image = []

# Render the test split: poses and expressions are both indexed by i_test.
render_poses = poses[i_test].float().to(device)
render_expressions = expressions[i_test].float().to(device)

# Default latent-code row; the render loop replaces it via idx_map.
index_of_image_after_train_shuffle = 0

####################### Ablation switches
no_background = False      # render without a background prior
no_expressions = False     # replace every expression vector with zeros
no_lcode = False           # replace the latent-code table with zeros
nerf = False               # shortcut that enables all three switches above
frontalize = False         # NOTE(review): not read by any cell visible here
interpolate_mouth = False  # NOTE(review): not read by any cell visible here
#######################
if nerf:
    no_background = True
    no_expressions = True
    no_lcode = True
if no_background:
    background = None
if no_expressions:
    render_expressions = torch.zeros_like(render_expressions)
if no_lcode:
    use_latent_code = True
    # Zero table shaped like the trained one, so every idx_map index stays in range. A fixed
    # 5000-row table can be smaller than the checkpoint's table.
    trained_latent_codes = globals().get("latent_codes")
    if trained_latent_codes is not None:
        latent_codes = torch.zeros_like(trained_latent_codes, device=device)
    else:
        latent_codes = torch.zeros(5000, 32, device=device)

In [73]:
# Sanity-check the tensors the render loop iterates over.
for tensor_name, tensor in (("render_expressions", render_expressions), ("render_poses", render_poses)):
    print(f"{tensor_name} shape: ", tensor.shape)

render_expressions shape:  torch.Size([1000, 76])
render_poses shape:  torch.Size([1000, 4, 4])


# Helper functions

In [74]:
def torch_normal_map(depthmap,focal,weights=None,clean=True, central_difference=False, mask_threshold=0.22):
    """Turn a depth/disparity map into an RGB-encoded surface-normal image.

    Every pixel is back-projected with a pinhole model. Neighbouring 3-D points are differenced
    along rows and along columns, and the cross product of the two differences gives the normal.

    Args:
        depthmap: 2-D tensor of shape (H, W). Non-square maps are supported.
        focal: 4-element intrinsics ``(fx, fy, cx_frac, cy_frac)``. The principal point in
            pixels is ``(cx_frac * W, cy_frac * H)``.
        weights: optional (H, W) tensor, used as a blending mask when ``clean`` is True.
        clean: if True and ``weights`` is given, pixels whose weight exceeds ``mask_threshold``
            are set to white, and the rest are blended towards white by their weight.
        central_difference: difference over 2 pixels instead of 1.
        mask_threshold: weight above which a pixel is forced to white.

    Returns:
        Float tensor of shape (H - d, W - d, 3), where ``d`` is 2 with central differences and
        1 otherwise. Unit normals are mapped as ``(n * 0.5 + 0.5) * 255``, giving [0, 255] for
        weights in [0, 1]. Zero-length normals map to 127.5 instead of NaN.
    """
    H, W = depthmap.shape
    fx, fy = focal[0], focal[1]
    cx = focal[2] * W
    cy = focal[3] * H
    # Pixel coordinates broadcast against the (H, W) map: ii is each pixel's column (x) index,
    # jj its row (y) index.
    ii = torch.arange(W, device=depthmap.device)[None, :]
    jj = torch.arange(H, device=depthmap.device)[:, None]
    points = torch.stack(
        [
            ((ii - cx) * depthmap) / fx,
            -((jj - cy) * depthmap) / fy,  # negated: the row index grows downward
            depthmap,
        ],
        dim=-1)
    d = 2 if central_difference else 1
    row_diff = points[d:, :, :] - points[:-d, :, :]
    col_diff = points[:, d:, :] - points[:, :-d, :]
    # Crop both differences to the common (H - d, W - d) region before the cross product.
    normals = torch.cross(col_diff[:-d, :, :], row_diff[:, :-d, :], 2)
    # Unit length; normalize clamps the norm, so zero-length normals stay 0 instead of NaN.
    normals = torch.nn.functional.normalize(normals, dim=2)
    normals = normals * 0.5 + 0.5

    if clean and weights is not None:  # Use volumetric rendering weights to clean up the normal map
        mask = weights.repeat(3, 1, 1).permute(1, 2, 0)[:-d, :-d]
        normals[mask > mask_threshold] = 1.0
        normals = (1 - mask) * normals + mask * torch.ones_like(normals)
    return normals * 255

def save_plt_image(im1, outname):
    """Write ``im1`` to ``outname`` as a borderless figure: 6.4 x 6.4 inches at 80 dpi (512 x 512 px)."""
    fig = plt.figure(figsize=(6.4, 6.4))
    # A single axes spanning the whole canvas, with ticks and frame hidden.
    ax = fig.add_axes([0., 0., 1., 1.])
    ax.set_axis_off()
    ax.imshow(im1, aspect='equal')
    fig.savefig(outname, dpi=80)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    
def cast_to_image(tensor, dataset_type):
    """Convert an (H, W, 3) image tensor into an (H, W, 3) NumPy image array.

    Values are clamped to [0, 1] and converted through torchvision's ToPILImage.
    ``dataset_type`` is accepted for interface compatibility but is not used.
    """
    channels_first = tensor.permute(2, 0, 1).clamp(0.0, 1.0)
    to_pil = torchvision.transforms.ToPILImage()
    pil_image = to_pil(channels_first.detach().cpu())
    return np.array(pil_image)

# Test: pick a single expression to render

In [75]:
# One-row batch (row 760 of the test-split expressions), so the render loop makes a single pass.
sample = render_expressions[760:761]
sample.shape

torch.Size([1, 76])

### Render pass over the sample

In [76]:
# Render every expression in `sample`. Only the last iteration's outputs
# (rgb_*, disp_*, weights) survive the loop; the next cell writes them to disk.
# NOTE(review): `pose` and the per-frame latent-code lookup use `i`, the position inside `sample`,
# not the index of the frame `sample` was sliced from. Confirm this pairing is intended.
for i, expression in enumerate(tqdm(sample)):
    start = time.time()
    with torch.no_grad():
        pose = render_poses[i]

        # Ablation mode: 'expression', 'latent_code', 'view_dir', or anything else for none.
        ablate = 'None'
        ray_directions_ablation = None

        if ablate == 'expression':
            pose = render_poses[100]
        elif ablate == 'latent_code':
            pose = render_poses[100]
            expression = render_expressions[100]
            if idx_map[100 + i, 1] >= 0:
                index_of_image_after_train_shuffle = idx_map[100 + i, 1]
        elif ablate == 'view_dir':
            pose = render_poses[100]
            expression = render_expressions[100]
            # View directions from a different camera; passed to the renderer below.
            _, ray_directions_ablation = get_ray_bundle(
                hwf[0], hwf[1], hwf[2], render_poses[240 + i][:3, :4]
            )

        ## --------------Latent code ---------------
        # Skipped for the latent-code ablation, which chose its own index above.
        if use_latent_code and ablate != 'latent_code':
            if idx_map[i, 1] >= 0:
                index_of_image_after_train_shuffle = idx_map[i, 1]
            # TODO: pins the latent code to idx_map row 10 ("USE THIS if not ablating!"),
            # overriding the per-frame lookup just above.
            index_of_image_after_train_shuffle = idx_map[10, 1]
        latent_code = (
            latent_codes[index_of_image_after_train_shuffle].to(device) if use_latent_code else None
        )
        ## --------------Latent code ends ---------------

        ray_origins, ray_directions = get_ray_bundle(hwf[0], hwf[1], hwf[2], pose)
        rgb_coarse, disp_coarse, _, rgb_fine, disp_fine, _, weights = run_one_iter_of_nerf(
            hwf[0],
            hwf[1],
            hwf[2],
            model_coarse,
            model_fine,
            ray_origins,
            ray_directions,
            cfg,
            mode="validation",
            encode_position_fn=encode_position_fn,
            encode_direction_fn=encode_direction_fn,
            expressions=expression,
            background_prior=background.view(-1, 3) if (background is not None) else None,
            latent_code=latent_code,
            ray_directions_ablation=ray_directions_ablation,
        )

    ## ----------------calculate time
    times_per_image.append(time.time() - start)
    # Average over all recorded times; the list may hold entries from earlier runs of this cell.
    tqdm.write(f"Avg time per image: {sum(times_per_image) / len(times_per_image)}")

100%|██████████| 1/1 [00:13<00:00, 13.80s/it]

Avg time per image: 13.79657769203186





In [77]:
##------------------generate 'normals' image------------
# Prefer the fine network's outputs; fall back to the coarse ones when there is no fine model.
rgb = rgb_fine if rgb_fine is not None else rgb_coarse
disp = disp_fine if disp_fine is not None else disp_coarse
# hwf[2] holds the intrinsics the rays were rendered with (possibly overridden by the checkpoint).
normals = torch_normal_map(disp, hwf[2], weights, clean=True)
save_plt_image(normals.cpu().numpy().astype('uint8'), os.path.join(savedir, 'normals', f"{i:04d}.png"))
##------------------generate 'normals' image ends ------------
##------------------save image------------------------------------
if savedir:
    # File names use `i`, the last index reached by the render loop.
    savefile = os.path.join(savedir, f"{i:04d}.png")
    imageio.imwrite(
        savefile, cast_to_image(rgb[..., :3], cfg.dataset.type.lower())
    )
##------------------save image ends------------------------------------


In [56]:
# Show the pose at index i (the last render-loop index), then keep only its top 3x4 block.
print(render_poses[i])
pose = render_poses[i][:3, :4]
print(pose)

tensor([[ 0.9960, -0.0099, -0.0890, -0.0479],
        [ 0.0194,  0.9941,  0.1071,  0.1003],
        [ 0.0874, -0.1084,  0.9903,  0.5183],
        [-0.0000,  0.0000,  0.0000,  1.0000]], device='cuda:0')
tensor([[ 0.9960, -0.0099, -0.0890, -0.0479],
        [ 0.0194,  0.9941,  0.1071,  0.1003],
        [ 0.0874, -0.1084,  0.9903,  0.5183]], device='cuda:0')


In [42]:
# Flatten the ray bundle and append per-ray bounds: each row is (origin xyz, direction xyz, near, far).
ro = ray_origins.view(-1, 3)
rd = ray_directions.view(-1, 3)
near = torch.full_like(rd[..., :1], cfg.dataset.near)
far = torch.full_like(rd[..., :1], cfg.dataset.far)
rays = torch.cat([ro, rd, near, far], dim=-1)

In [66]:
# render_expressions is expressions[i_test]: one row per test frame.
render_expressions.shape

torch.Size([1000, 76])

In [28]:
    # The ray bundle's full shape, followed by two copies of its shape without the last axis.
    restore_shapes = [ray_directions.shape] + [ray_directions.shape[:-1]] * 2
    restore_shapes

[torch.Size([512, 512, 3]), torch.Size([512, 512]), torch.Size([512, 512])]

In [34]:
# Latent code used by the last render-loop iteration, followed by its shape.
print(latent_code, latent_code.shape, sep="\n")

tensor([-0.0008, -0.0012,  0.0016,  0.0048,  0.0019,  0.0028, -0.0037, -0.0016,
        -0.0117, -0.0039,  0.0035,  0.0012,  0.0096, -0.0092, -0.0006, -0.0067,
        -0.0002, -0.0017,  0.0042,  0.0006,  0.0018, -0.0069,  0.0007, -0.0007,
         0.0014,  0.0061, -0.0032, -0.0011, -0.0011, -0.0017, -0.0043,  0.0026],
       device='cuda:0')
torch.Size([32])


In [30]:
# Expression vector used by the last render-loop iteration, followed by its shape.
print(expression, expression.shape, sep="\n")

tensor([-4.0813e-05,  7.0033e-02,  2.5274e-01, -9.2142e-02, -1.3292e-01,
         2.1672e-02,  7.1768e-02, -7.6453e-02, -2.0912e-01,  1.3151e-03,
         1.3103e-01,  2.9060e-01,  2.8297e-02, -5.2820e-02,  1.9694e-01,
         7.0626e-02, -1.1306e-02,  5.7903e-02, -2.1446e-01,  9.0488e-02,
         2.9039e-02,  2.9313e-02,  3.5732e-02, -4.8023e-02,  3.5983e-02,
         1.9504e-01,  3.3352e-02,  1.2454e-01,  4.5865e-02,  9.8833e-02,
         2.9979e-02,  6.5397e-02,  4.0572e-01, -4.6820e-02, -3.7146e-01,
        -2.0494e-02,  8.1898e-02,  4.1465e-02,  3.5226e-02,  3.0267e-02,
         1.9230e-01, -1.7592e-03,  1.8759e-02, -1.9538e-01, -1.4866e-01,
        -2.7522e-01,  1.6156e-01, -2.0175e-01,  1.4392e-01, -8.9506e-02,
        -1.2474e-01,  1.3857e-02, -8.5681e-02,  4.3055e-02,  1.3732e-01,
         9.9381e-03, -1.1707e-01, -4.1152e-02, -7.0993e-02,  8.9980e-02,
        -1.4620e-01,  5.9456e-02,  1.2137e-01,  3.7929e-02,  1.0319e-01,
        -3.1344e-01,  1.0509e-01,  2.0067e-01, -2.9

In [36]:
# For this dataset hwf[2] is an intrinsics array (4 values), not a scalar focal length.
hwf[2].shape

(4,)

In [37]:
# True when the checkpoint provided latent codes, or when the no_lcode switch forced it.
use_latent_code

True

In [41]:
# Latent-code table; the render loop indexes its rows with idx_map[:, 1].
print(latent_codes)
latent_codes.shape

tensor([[ 0.0173,  0.0127,  0.0038,  ...,  0.0172,  0.0074, -0.0171],
        [ 0.0362, -0.0402, -0.0246,  ...,  0.0041,  0.0423,  0.0309],
        [-0.0221, -0.0022, -0.0253,  ..., -0.0002, -0.0018,  0.0240],
        ...,
        [ 0.0081, -0.0251, -0.0101,  ...,  0.0024,  0.0366,  0.0027],
        [ 0.0070, -0.0065, -0.0115,  ..., -0.0205,  0.0206,  0.0041],
        [-0.0239, -0.0189, -0.0146,  ..., -0.0106, -0.0117,  0.0127]],
       device='cuda:0')


torch.Size([5507, 32])

In [44]:
# Re-inspect the index map (column 1 = row index into latent_codes, as used by the render loop).
idx_map

array([[   0,  615],
       [   1, 3788],
       [   2, 4448],
       ...,
       [5510, 3785],
       [5511, 3432],
       [5512, 2237]])

In [20]:
# [height, width, intrinsics] as passed to get_ray_bundle / run_one_iter_of_nerf.
hwf

[512,
 512,
 array([-2.22321152e+03,  2.42276352e+03,  5.02588000e-01,  4.88307000e-01])]

In [22]:
# Same check as before: the intrinsics entry of hwf has 4 values.
hwf[2].shape

(4,)

In [67]:
# Rays per rendered image (one per pixel), derived from the loaded image size rather than hard-coded.
H * W

262144