In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell execution, so edits to the project package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
import shutil
import subprocess
import pydoc
from tqdm.notebook import tqdm

from omegaconf import OmegaConf
import hydra.experimental

import cv2
import numpy as np

import torch
from torch import nn

%matplotlib inline
from matplotlib import pylab as plt

# sys.path.append("..")
import face_expression
from face_expression import utils

In [None]:
# Switch between the two config sources used by the next cell:
#   True  -> compose the repository's default hydra config (no checkpoint is loaded)
#   False -> load the config and latest checkpoint of a specific wandb run
USE_DEFAULT_CONFIG = False

In [None]:
# Select the config (and optionally a trained checkpoint) for the runner.
# Both branches define `config`, `config_path`, `checkpoint_path` and
# `checkpoint_name`, so later cells can read them regardless of the branch taken.
if USE_DEFAULT_CONFIG:
    # Compose the repository's default hydra config.
    with hydra.experimental.initialize(config_path="../face_expression/config"):
        config = hydra.experimental.compose(config_name="config")
        OmegaConf.set_struct(config, False)  # allow adding new fields

    # No experiment run is involved: there is no checkpoint and no saved config file.
    # The names are still bound (to None) so that the cell printing them does not
    # fail with a NameError.
    checkpoint_path = None
    checkpoint_name = None
    config_path = None
else:
    project_dir = "/Vol1/dbstore/datasets/k.iskakov/projects/face_expression"

    # History of evaluated runs; uncomment exactly one `run_name` to select it.
    # run_name = "run-20200910_120005-2icqhqix" # siamese_mediapipe_2d
    # run_name = "run-20200910_142958-27s0s7pi" # siamese_dropout
    # run_name = "run-20200911_003858-1zs8eycn" # siamese_jaw_pose.weight-5.0
    # run_name = "run-20200911_005939-vwlb8xjz" # siamese_small

    # run_name = "run-20200915_200519-nitmitct" # siamese_bbox_filter

    # run_name = "run-20200916_182119-3khcgm0a"  # siamese_normalize_area-False_jaw_pose_weight-10.0
    # run_name = "run-20200916_212159-2kfaqcv4" # siamese_keypoint_l2_normalize_area-False
    # run_name = "run-20200917_122110-1qk66uzu" # siamese+keypoint_l2_loss+normalize-image_shape

    # run_name = "run-20200917_181208-g15oyjuo" # siamese+mediapipe_normalization
    # run_name = "run-20200917_181214-2p2mcq7o" # siamese+mediapipe_normalization+expression_weight-10
    # run_name = "run-20200917_181220-23sm1rck"  # siamese+mediapipe_normalization+use_beta-false
    # run_name = "run-20200917_182953-kr90pwk6"  # siamese+mediapipe_normalization+no_keypoint_l2_loss

    # run_name = "run-20200921_140154-2u0labuw"  # siamese+keypoints_3d

    # run_name = "run-20200923_190202-3lf0gggu"  # siamese+keypoints_3d
    # run_name = "run-20200923_185641-256g37gk"  # siamese+mouth
    # run_name = "run-20200923_180309-2vciol9p"  # siamese+keypoints_3d_loss+expression_loss
    # run_name = "run-20200923_180225-3kupdul7"  # siamese+keypoints_3d_loss

    run_name = "run-20200924_184732-2pugva9j" # siamese+mouth+keypoints_3d_loss+expression_loss

    experiment_dir = os.path.join(project_dir, "wandb", run_name)

    # checkpoint
    # NOTE(review): presumably returns the path of the most recent checkpoint file in
    # the directory -- confirm against utils.common.get_lastest_checkpoint.
    checkpoint_path = utils.common.get_lastest_checkpoint(os.path.join(experiment_dir, "checkpoints"))
    checkpoint_name = os.path.basename(checkpoint_path)
    print(f"Checkpoint: {checkpoint_name}")

    # load config saved alongside the run
    config_path = os.path.join(experiment_dir, "config.yaml")
    with open(config_path) as f:
        config = OmegaConf.load(f)

In [None]:
# Echo which checkpoint and config file the rest of the notebook will use.
checkpoint_report = f"checkpoint_path = {checkpoint_path}"
print(checkpoint_report)
config_report = f"config_path = {config_path}"
print(config_report)

In [None]:
# runner
# Resolve the runner class from the dotted path stored in the config, build it,
# move it to the configured device and (optionally) restore trained weights.
runner_cls = pydoc.locate(config.runner.cls)
if runner_cls is None:
    # pydoc.locate returns None instead of raising when the path cannot be resolved;
    # fail here with a readable message rather than "'NoneType' object is not callable".
    raise ImportError(f"Cannot locate runner class: {config.runner.cls}")

runner = runner_cls(config)
runner = runner.to(config.device)

if checkpoint_path is not None:
    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint saved from a device that is not available here still loads.
    state_dict = torch.load(checkpoint_path, map_location=config.device)
    runner.load_state_dict(state_dict)

# Inference only: switch to evaluation mode (the trailing ';' hides the cell output).
runner.eval();

In [None]:
from face_expression.utils.misc import get_dataloaders

# Replace the run's test split with the azure_people_test dataset definition.
azure_people_test_data_config_path = "../face_expression/config/data/azure_people_test.yaml"
with open(azure_people_test_data_config_path) as f:
    azure_people_test_data_config = OmegaConf.load(f)

config.data.test = azure_people_test_data_config.data.test

# Splits to build dataloaders for (alternatives kept for quick switching).
modes = ('train', 'val')
# modes = ('train', 'val', 'test')
# modes = ('test',)

# Override loader / dataset settings for every selected split.
for mode in modes:
    split_config = config.data[mode]

    loader_args = split_config.dataloader.args
    loader_args.batch_size = 128
    loader_args.num_workers = 2
    loader_args.shuffle = False  # keep the dataset order of the samples

    split_config.dataset.args.sample_range = [1500, float('+inf'), 1]

dataloaders = get_dataloaders(config, splits=modes)

# Split that the inference loop iterates over.
# dataloader = dataloaders['train']
dataloader = dataloaders['val']
# dataloader = dataloaders['test']

In [None]:
if not USE_DEFAULT_CONFIG:
    # setup dirs for result
    # Layout: <project_dir>/artifacts/triple_video/<experiment_name>/
    #             frames#<experiment_name>#<checkpoint_name>/   (per-frame JPEGs)
    #             video#<experiment_name>#<checkpoint_name>.mp4 (encoded video)
    root_result_dir = os.path.join(config.log.project_dir, "artifacts", "triple_video")

    result_dir = os.path.join(root_result_dir, config.log.experiment_name)
    frame_dir = os.path.join(result_dir, f"frames#{config.log.experiment_name}#{checkpoint_name}")
    output_video_path = os.path.join(result_dir, f"video#{config.log.experiment_name}#{checkpoint_name}.mp4")

    # Start from a clean slate so frames of a previous execution do not leak into the video.
    shutil.rmtree(frame_dir, ignore_errors=True)

    # The video is a regular file: shutil.rmtree cannot delete it (with
    # ignore_errors=True it just silently does nothing), so remove it as a file.
    if os.path.isfile(output_video_path):
        os.remove(output_video_path)
    else:
        # Covers the unlikely case of a directory sitting at that path.
        shutil.rmtree(output_video_path, ignore_errors=True)

    os.makedirs(result_dir, exist_ok=True)
    os.makedirs(frame_dir, exist_ok=True)

In [None]:
# Run the model on the first `n_batches` batches, render a visualisation for every
# sample and dump the images as consecutively numbered JPEG frames into `frame_dir`.
n_batches = 10
# n_batches = float('+inf')
count = 0  # global frame counter; also used as the frame file name
for i, input_dict in tqdm(enumerate(dataloader), total=min(n_batches, len(dataloader))):
    # Check the limit first so that the batch past the limit is not moved to the device.
    if i >= n_batches:
        break

    with torch.no_grad():
        input_dict = utils.common.dict2device(input_dict, config.device, dtype=torch.float32)

        output_dict = runner.forward(input_dict)

        # --- disabled diagnostics: keypoint reprojection and mouth keypoint loss ---
#         keypoints_3d_pred, _, _ = utils.misc.infer_smplx(
#             runner.smplx_model, output_dict['expression_pred'], output_dict['pose_pred'], input_dict['beta']
#         )
#         keypoints_2d_pred = utils.misc.project_keypoints_3d(keypoints_3d_pred, input_dict['projection_matrix'])


#         keypoints_3d_target, _, _ = utils.misc.infer_smplx(
#             runner.smplx_model, input_dict['expression'], input_dict['pose'], input_dict['beta']
#         )
#         keypoints_2d_target = utils.misc.project_keypoints_3d(keypoints_3d_target, input_dict['projection_matrix'])

#         keypoints_2d_target = keypoints_2d_target.detach().cpu().numpy()

#         loss = runner.keypoint_3d_l2_criterion(
#             keypoints_3d_pred[:, SMPLX_MOUTH_INDICES], keypoints_3d_target[:, SMPLX_MOUTH_INDICES]
#         ).item()

        # Mean absolute value of the predicted vs. target expression over the batch.
        expression_pred_norm = np.abs(output_dict['expression_pred'].cpu().numpy()).mean()
        expression_norm = np.abs(input_dict['expression'].cpu().numpy()).mean()
        print(f"expression norm: {expression_pred_norm}, {expression_norm}")

        # NOTE(review): presumably returns one image per sample of the batch -- confirm
        # against utils.vis.vis_triple_with_smplx.
        triple_images = utils.vis.vis_triple_with_smplx(runner.smplx_model, runner.renderer, input_dict, output_dict, float('+inf'), alpha=0.5)

        for triple_image in triple_images:
            image_path = os.path.join(frame_dir, f"{count:06d}.jpg")

            # The image is shown as-is by plt.imshow below, i.e. it is treated as RGB,
            # while cv2.imwrite expects BGR. COLOR_RGB2BGR performs the same channel
            # swap as COLOR_BGR2RGB but names the actual direction.
            frame_bgr = cv2.cvtColor(triple_image, cv2.COLOR_RGB2BGR)

            # cv2.imwrite reports failure through its return value instead of raising.
            # A missing frame would leave a gap in the numbered sequence ffmpeg reads.
            if not cv2.imwrite(image_path, frame_bgr):
                raise IOError(f"Failed to write frame: {image_path}")

            # Preview every 100th frame inline.
            if count % 100 == 0:
                plt.imshow(triple_image)
                plt.show()

            count += 1

In [None]:
# Encode the numbered JPEG frames from `frame_dir` into an H.264 video at 25 fps.
frame_pattern = os.path.join(frame_dir, "%06d.jpg")

cmd = ["ffmpeg", "-y"]
cmd += ["-framerate", "25"]
cmd += ["-i", frame_pattern]
cmd += ["-c:v", "libx264"]
cmd += ["-vf", "fps=25"]
cmd += ["-pix_fmt", "yuv420p"]
cmd.append(output_video_path)

# Capture both streams so that ffmpeg's own log can be surfaced on failure.
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if result.returncode != 0:
    raise ValueError(result.stderr.decode("utf-8"))

In [None]:
# Report where the video was written, plus a ready-to-paste command for copying
# it from the cluster to a local directory.
video_location_message = f"Output video path:\n{output_video_path}"
download_command = f"scp cluster:{output_video_path} ~/face_expression"

print(video_location_message)
print()
print(download_command)

---