In [1]:
import sys,os
import configargparse

import torch
import numpy as np

from modeling import SRNsModel
import util

_MODEL_PATH = './logs/050201face_seg_syns/checkpoints/epoch_0010_iter_120000.pth'

_NUM_INSTANCES = 100

model = SRNsModel(num_instances=_NUM_INSTANCES,
                  latent_dim=256,
                  tracing_steps=10,
                  use_encoder=True,
                  feat_dim=228,
                  freeze_networks=True,
                  out_channels=20,
                  img_sidelength=128)

util.custom_load(model, path=_MODEL_PATH, discriminator=None,
                 overwrite_embeddings=False)

model.eval()
model.cuda()

[INIT embedding] encoder. 228 256


SRNsModel(
  (mapping_fn): Sequential(
    (0): Linear(in_features=228, out_features=256, bias=True)
    (1): Tanh()
  )
  (hyper_phi): HyperFC(
    (layers): ModuleList(
      (0): NewCls(
        (hyper_linear): HyperLinear(
          (hypo_params): FCBlock(
            (net): Sequential(
              (0): FCLayer(
                (net): Sequential(
                  (0): Linear(in_features=256, out_features=256, bias=True)
                  (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
                  (2): ReLU(inplace=True)
                )
              )
              (1): FCLayer(
                (net): Sequential(
                  (0): Linear(in_features=256, out_features=256, bias=True)
                  (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
                  (2): ReLU(inplace=True)
                )
              )
              (2): Linear(in_features=256, out_features=1024, bias=True)
            )
          )
        )
        (norm_n

In [2]:
from dataset.face_dataset import FaceRandomPoseDataset
from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter

import cv2
import imageio

_OUTPUT_DIR = './logs/latent_interpolation/050201face_seg_syns'
_MODE = 'load'
_R = 10.0

_ROOT_DIR = '/data/anpei/facial-data/seg_face_syn'
_CAM_INT = os.path.join(_ROOT_DIR, 'intrinsics.txt')

_NUM_OBSERVATIONS=20

output_dir = os.path.join(_OUTPUT_DIR, _MODE)

dataset = FaceRandomPoseDataset(
    intrinsics=_CAM_INT, num_instances=1, num_observations=_NUM_OBSERVATIONS, sample_radius=_R, mode=_MODE)

dataloader = DataLoader(dataset,
                     collate_fn=dataset.collate_fn,
                     batch_size=1,
                     shuffle=False,
                     drop_last=False)

In [3]:
from glob import glob
import os

# load parameters

face_params = []

cam_fps = glob(os.path.join(_ROOT_DIR, '*', 'cameras.npy'))
for idx, cam_fp in enumerate(cam_fps):
#     print('| ', cam_fp)
    instance_dir = os.path.dirname(cam_fp)
    cam_params = np.load(cam_fp, allow_pickle=True).item()    
    world2cam = cam_params['extrinsics']
    
    face_params.append(np.squeeze(cam_params['instance']))
    
print(len(face_params), _CAM_INT)

5000 /data/anpei/facial-data/seg_face_syn/intrinsics.txt


In [None]:
from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt
import os

import random

matplotlib.rcParams['figure.figsize'] = [30, 5]

cam2worlds = []

num_interps = 16

_NUM_ATTRIBUTES = 480

# if observation_idx == _NUM_OBSERVATIONS - 1:
#     imageio.mimsave(os.path.join(instance_dir, 'output.gif'), images[instance_idx], fps=5.0)
#     print('=== [DONE] saving output.gif.')
        
os.makedirs(_OUTPUT_DIR, exist_ok=True)
    
with torch.no_grad():
    
    batch_params = np.asarray(random.sample(face_params, k=_NUM_ATTRIBUTES))
    batch_params[:, :164] = batch_params[0, :164]
        
    for idx, model_input in enumerate(dataloader):
        
        model_input, ground_truth = model_input
        pose = model_input['pose'].repeat(num_interps, 1, 1)
        intrinsics = model_input['intrinsics'].repeat(num_interps, 1, 1)
        uv = model_input['uv'].repeat(num_interps, 1, 1)
    
        images = []
        cur_idx = 0
    
        print('Process: ', idx)
        while cur_idx < _NUM_ATTRIBUTES:
            model_input['params'] = torch.from_numpy(batch_params[cur_idx:cur_idx+pose.shape[0]]).float()

            z_interp = model.get_embedding(model_input)

            predictions, depth_map = model(pose, z_interp, intrinsics, uv)

            B, _, C = predictions.shape

            pred = torch.argmax(predictions, dim=2, keepdim=True)
            output_img = util.lin2img(pred, color_map=dataset.color_map)
            output_img = make_grid(output_img, nrow=4, padding=0).permute((1, 2, 0)).cpu().numpy()            
            cur_idx += pose.shape[0]
    
            output_img = (output_img*255.0).round().clip(0, 255).astype(np.uint8)
            print(output_img.shape, output_img.dtype, np.max(output_img))
        
            images.append(output_img.copy())
    
#             plt.imshow(output_img)
#             plt.show()

        output_fp = os.path.join(_OUTPUT_DIR, 'output_%d.gif'%idx)
        imageio.mimsave(output_fp, images, fps=5.0)
        print('[DONE] Save output: %s, num_frames = %d'%(output_fp, len(images)))  

Process:  0
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
[DONE] Save output: ./logs/latent_interpolation/050201face_seg_syns/output_0.gif, num_frames = 30
Process:  1
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) 

[DONE] Save output: ./logs/latent_interpolation/050201face_seg_syns/output_9.gif, num_frames = 30
Process:  10
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
[DONE] Save output: ./logs/latent_interpolation/050201face_seg_syns/output_10.gif, num_frames = 30
Process:  11
(512, 512, 3) uint8 255
(512, 512, 3) uint8 255
(512, 512