In [1]:
# Project files.
from tcsr.data.data_loader import DatasetClasses, DataLoaderDevicePairs
from tcsr.train import helpers as tr_helpers
from tcsr.visualize.renderer_pytorch3d import RendererPatchesUV
import tcsr.visualize.helpers as vis_helpers
from externals.jblib import helpers as jbh
from externals.jblib import file_sys as jbfs
from externals.jblib.deep_learning import torch_helpers as jbth

# 3rd party.
import torch
import imageio
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from IPython.display import display
from IPython.display import Image as iImage

# Python std.
import os
import math

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# Sanity check: report the PyTorch version available in the running kernel.
import torch

print(torch.__version__)

1.7.0


In [3]:
### Settings.
# Paths
# Texture images, given relative to the notebook's working directory
# (`abs_path_base` below). The keys are the valid values of `texture_sel`.
rel_paths_textures = {
    'ortho': 'textures/ortho.png',
    'diag': 'textures/diag.png'}
# Rendering configuration file, loaded from `abs_path_base`.
file_rend_config = 'conf_patches.yaml'

# Root directory of the outputs; the images are written to its 'pngs' and
# 'gif' subdirectories.
abs_path_out = ''  # <-- Set the abs. path for the output data.

# Training runs paths.
# Maps a model key to the directory of its training run, from which the run
# config and the trained weights are loaded.
abs_paths_trruns = {  # <-- Fill in the abs. paths for the individual training runs.
    'an': '',
    'dsr': '',
    'our': '/cvlabdata2/home/jayakuma/testf_withsort/cat_walk/only_curv_neigh_9_alpha_1e-2__p10_bs4_DS_mode-neighbors5type-clean_centF_alignrotF'
}

# Models.
# Keys of `abs_paths_trruns` to process; `None` processes all of them.
models_selected = ['our']  # <-- Choose the models to render the results for.

# Data.
subject = None  # <-- Only set as a string for DFAUST, CAPE, INRIA. E.g. '50002' for DFAUST.

# Rendering settings - from `file_rend_config`.
# Passed to `vis_helpers.get_rend_config` as `ds_spec` and `seq_spec`.
render_ds_specific = 'default'
render_seq_specific = 'default_uv_camplane'
texture_sel = 'diag'  # One of the keys of `rel_paths_textures`: 'ortho', 'diag'.

# Whether to render and save GIFs and/or PNGs.
render_gif = True
render_pngs = False

# Meshes.
# Both values are handed to the renderer; `mesh_edge_verts` is also passed to
# the model when predicting meshes.
mesh_edge_verts = 11
num_patches = 10

# Technical
# If True, the GIF frames are followed by their reversed copy so that the
# animation plays back and forth.
revert_cycle_gif = True
img_size = 400  # Passed to the renderer and encoded in the output file names.
bs = 16  # Batch size of the data loader.
dev = torch.device('cuda')
verbose = False

# Process params.
abs_path_base = os.path.abspath(os.path.curdir)
# Fail with an explicit message instead of a bare KeyError on a mistyped
# texture name.
if texture_sel not in rel_paths_textures:
    raise KeyError(
        f"Unknown texture_sel '{texture_sel}', "
        f"expected one of {sorted(rel_paths_textures)}.")
abs_path_texture = jbfs.jn(abs_path_base, rel_paths_textures[texture_sel])


SyntaxError: EOL while scanning string literal (<ipython-input-3-ad0e453c0546>, line 14)

In [None]:
# Load data.
# The dataset settings are taken from the config of one reference training
# run. The 'our' run is used whenever its path is filled in; otherwise the
# first model that is going to be processed serves as the reference, so that
# rendering other models does not require the 'our' path.
# NOTE(review): this assumes all selected runs share the same dataset
# settings - confirm.
ds_ref_candidates = list(abs_paths_trruns.keys()) \
    if models_selected is None else list(models_selected)
ds_ref_model = 'our' if abs_paths_trruns.get('our') \
    else ds_ref_candidates[0]
conf_ds = jbh.load_conf(
    jbth.get_path_conf(abs_paths_trruns[ds_ref_model]))
dataset = conf_ds['ds']
# Only the first sequence listed in the config is rendered.
sequence = conf_ds['sequences'][0]

# Normalize `subject` to either None or a flat list. Lists and tuples are
# passed through instead of being wrapped again, which keeps this cell
# idempotent: re-running it must not turn ['x'] into [['x']].
if subject is None or (
        isinstance(subject, (list, tuple)) and len(subject) == 0):
    subject = None
elif isinstance(subject, (list, tuple)):
    subject = list(subject)
else:
    subject = [subject]

ds = DatasetClasses[dataset](
    num_pts=conf_ds['N'], subjects=subject, sequences=[sequence],
    mode='within_seq', center=conf_ds['center'],
    align_rot=conf_ds['align_rotation'], resample_pts=True,
    with_reg=False, synth_rot=conf_ds['synth_rot'],
    synth_rot_ang_per_frame=conf_ds['synth_rot_ang_per_frame'],
    synth_rot_up=conf_ds['synth_rot_up'], noise=conf_ds['noise'],
    pairing_mode=conf_ds.get('ds_pairing_mode', 'standard'),
    pairing_mode_kwargs=conf_ds.get('ds_pairing_mode_params', None),
    ds_type=conf_ds.get('ds_type', 'clean'))
# Samples are served in dataset order (no shuffling, no dropped batch).
dl = DataLoaderDevicePairs(DataLoader(
        ds, batch_size=bs, shuffle=False, num_workers=1,
        drop_last=False), gpu=True)
    
# Create renderer.
# Read the rendering config file and extract the entry matching the current
# dataset/sequence and the chosen `render_*_specific` settings.
path_rend_config = jbfs.jn(abs_path_base, file_rend_config)
conf_rend_all = jbh.load_conf(path_rend_config)
confr = vis_helpers.get_rend_config(
    conf_rend_all, dataset,
    ds_spec=render_ds_specific,
    seq=sequence,
    seq_spec=render_seq_specific)

# Camera, light and texture options all come from the selected config.
renderer_kwargs = {
    'img_size': img_size,
    'camera': confr['camera'],
    'light_loc': np.array(confr['light']['location']),
    'light_colors': confr['light']['colors'],
    'camera_animation': confr['camera_anim'],
    'uv_style': confr['texture']['style'],
    'uv_style_kwargs': None,
    'gpu': True,
}
renderer = RendererPatchesUV(
    mesh_edge_verts, num_patches, abs_path_texture, **renderer_kwargs)

# Process all models:
# `images_all` maps each model key to the output of
# `vis_helpers.render_uv_patches` for that model.
images_all = {}
process_models = list(abs_paths_trruns.keys()) \
    if models_selected is None else models_selected
for m in process_models:
    print(f"Processing method {m}")

    # Load model.
    path_trrun = abs_paths_trruns[m]
    path_conf, path_trstate = jbth.get_path_conf_tr_state(path_trrun)
    conf = jbh.load_conf(path_conf)
    model = tr_helpers.create_model_train(conf)
    # Deserialize the checkpoint onto the CPU; `load_state_dict` then copies
    # the values into the model's own parameters. This avoids a failure when
    # the checkpoint was saved from a device that does not exist here.
    tr_state = torch.load(path_trstate, map_location='cpu')
    model.load_state_dict(tr_state['weights'])
    model.eval()

    # Prepare UVs for camera plane.
    if confr['texture']['style'] == 'camera_plane':
        # The UVs are derived from the mesh predicted for one reference
        # sample, whose dataset index is given in the rendering config.
        ref_idx = \
            confr['texture']['style_args']['camera_plane']['ref_idx']
        # `[:1]` keeps only the first entry along the leading axis
        # (presumably the first point cloud of the pair - TODO confirm).
        pts = ds[ref_idx]['pts'][:1].to(model.device)
        # Flatten the predicted mesh vertices to rows of 3 coordinates.
        vp = model.predict_mesh(pts, mesh_edge_verts=mesh_edge_verts)[0] \
            .reshape((-1, 3)).detach().cpu().numpy()
        renderer._prepare_uvs_cam_plane(
            vp, cam_azi=confr['camera']['azi'],
            cam_ele=confr['camera']['ele'])

    # Predict all samples.
    images_all[m] = vis_helpers.render_uv_patches(
        model, renderer, dl, confr, mesh_edge_verts,
        model_type=conf['model'])
    
# Save visuals.
# Common prefix of the output file/directory names and the frame count.
name_base = vis_helpers.name_from_config(
    render_ds_specific, render_seq_specific)
num_imgs = len(ds)

# Create joined images.
# For every frame, the renderings of all processed models are concatenated
# along axis 1 (side by side, assuming H x W x C images - TODO confirm).
imgs_joined = []
for i in range(num_imgs):
    print(f"\rProcessing img {i + 1}/{num_imgs}", end='')
    frame_per_model = [images_all[m][i] for m in process_models]
    imgs_joined.append(np.concatenate(frame_per_model, axis=1))

# One array holding all joined frames, with the frame index as first axis.
imagesf_joined = np.stack(imgs_joined, axis=0)

# Save joint pngs.
if render_pngs:
    print('Generating pngs.')

    # Output directory: <abs_path_out>/pngs/<dataset>/<sequence>/<name>.
    dir_name_pngs = name_base + f"_size{img_size}_tex-{texture_sel}"
    abs_path_pngs = jbfs.jn(
        abs_path_out, 'pngs', dataset, sequence, dir_name_pngs)
    jbfs.make_dir(abs_path_pngs)

    # One file per joined frame, numbered from 0 with 4 digits.
    for imi, im in enumerate(imagesf_joined):
        print(f"\rSaving png {imi + 1}/{num_imgs}", end='')
        path_png = jbfs.jn(abs_path_pngs, f"fr_{imi:04d}.png")
        plt.imsave(path_png, im)
            
# Render gif.
if render_gif:
    print('Generating gif.')
    # Convert the float frames to 8-bit. The values are clipped to [0, 1]
    # first, since anything outside that range would overflow in the cast
    # to uint8 and produce wrong pixel values.
    imagesi_joined = \
        (np.clip(imagesf_joined, 0., 1.) * 255.).astype(np.uint8)
    fps = confr['gif']['fps']
    subsample = confr['gif']['subsample']

    # Output file: <abs_path_out>/gif/<dataset>/<sequence>/<name>.gif.
    abs_path_gif = jbfs.jn(abs_path_out, 'gif', dataset, sequence)
    jbfs.make_dir(abs_path_gif)
    nm_gif_out = f"{name_base}_fps{fps}_ss{subsample}" \
                 f"_size{img_size}_tex-{texture_sel}.gif"
    pth_out = jbfs.jn(abs_path_gif, nm_gif_out)

    # Keep every `subsample`-th frame.
    imgsgif = imagesi_joined[::subsample]
    if revert_cycle_gif:
        # Append the frames in reverse order, leaving out the last and the
        # first one so that neither is repeated when the GIF loops.
        imgsgif = np.concatenate([imgsgif, imgsgif[-2:0:-1]], axis=0)
    # The frame rate is reduced by the subsampling factor.
    imageio.mimwrite(pth_out, imgsgif, fps=math.ceil(fps / subsample),
                     palettesize=32, subrectangles=True)

    # Display.
    img_disp = iImage(filename=pth_out)
    display(img_disp)