In [None]:
import sys,os

import torch
import numpy as np

sys.path.append('..')

from modeling import SRNsModel
import util
from sklearn import mixture

# ---- Model / checkpoint configuration ----
_RENDERER = 'FC'

_MODEL_PATH = '../log/080320new_lstm/checkpoints//epoch_0184_iter_090000.pth'
_OPT_CAM = False
_ORTHO = True                # passed to SRNsModel as `orthogonal`
_TOT_NUM_INSTANCES = 122     # size of the instance-embedding table; presumably must match the checkpoint
_LATENT_DIM = 256            # per-instance latent code size

_IMG_SIZE = 128
_OUT_SIZE = 128


model = SRNsModel(num_instances=_TOT_NUM_INSTANCES,
                  latent_dim=_LATENT_DIM,
                  renderer=_RENDERER,
                  tracing_steps=10,
                  freeze_networks=True,
                  out_channels=20,
                  img_sidelength=_IMG_SIZE,
                  output_sidelength=_OUT_SIZE,
                  opt_cam=_OPT_CAM,
                  orthogonal=_ORTHO,
                 )

# Restore weights, instance embeddings and camera parameters from the checkpoint.
util.custom_load(model, path=_MODEL_PATH, discriminator=None,
                 overwrite_embeddings=True, overwrite_cam=True)

# Inference only: switch to eval mode and move the parameters to the GPU.
model.eval()
model.cuda();  # trailing ';' keeps the notebook from printing the full module repr

In [None]:
from dataset.face_dataset import FaceRandomPoseDataset
from torch.utils.data import DataLoader
# from torch.utils.tensorboard import SummaryWriter

import cv2
import imageio

# Output location: one sub-directory per sampling mode.
_OUTPUT_DIR = './logs/latent_interpolation/050201face_seg_syns'
_MODE = 'sphere'              # `mode` argument of FaceRandomPoseDataset
_R = 1.2                      # `sample_radius` argument of FaceRandomPoseDataset
_NUM_OBSERVATIONS = 25        # `num_observations` argument of FaceRandomPoseDataset

output_dir = os.path.join(_OUTPUT_DIR, _MODE)

# A single instance observed under _NUM_OBSERVATIONS sampled poses.
dataset = FaceRandomPoseDataset(
    num_instances=1,
    num_observations=_NUM_OBSERVATIONS,
    sample_radius=_R,
    mode=_MODE,
)

# One sample per batch, in a fixed (unshuffled) order, keeping the last partial batch.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    collate_fn=dataset.collate_fn,
)

In [None]:
# Linear interpolation between the latent codes of two distinct training instances,
# plus the mean of all instance embeddings for reference.
#
# The instance count comes from `_TOT_NUM_INSTANCES` (defined with the model); the
# previously referenced `_NUM_INSTANCES` was never defined anywhere.

_SEED = 0  # set to None to draw a different instance pair on every run
if _SEED is not None:
    torch.manual_seed(_SEED)

num_interps = 8

with torch.no_grad():  # pure inference: no autograd graph needed
    # randperm guarantees two *different* indices; randint could return the same
    # index twice, which would make the interpolation degenerate.
    src_idx, trgt_idx = torch.randperm(_TOT_NUM_INSTANCES)[:2].cuda()

    print(src_idx, trgt_idx)

    # Each code is repeated num_interps times so it lines up row-wise with `interp`.
    z_src = model.get_embedding({'instance_idx': src_idx}).unsqueeze(0).repeat(num_interps, 1)
    z_trgt = model.get_embedding({'instance_idx': trgt_idx}).unsqueeze(0).repeat(num_interps, 1)

    print(torch.max(z_src), torch.min(z_src), z_src.shape)

    # Interpolation weights in [0, 1], shaped (num_interps, 1) to broadcast over the latent axis.
    interp = torch.linspace(0.0, 1.0, num_interps).cuda().unsqueeze(1)

    z_interp = z_src * (1.0 - interp) + z_trgt * interp

    # Mean over every instance embedding. The latent size is taken from the
    # embeddings themselves rather than hard-coded.
    all_embeddings = torch.stack([
        model.get_embedding({'instance_idx': torch.tensor(idx).cuda()})
        for idx in range(_TOT_NUM_INSTANCES)
    ])
    embedding_mean = all_embeddings.mean(dim=0)
    print(torch.max(embedding_mean), torch.min(embedding_mean))

In [None]:
from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt
import os

matplotlib.rcParams['figure.figsize'] = [30, 5]

# Camera-to-world pose of every rendered view. All rows of a batch share the same
# pose (it is tiled below), so only row 0 is kept.
cam2worlds = []

with torch.no_grad():
    for idx, (model_input, ground_truth) in enumerate(dataloader):
        # Each view is rendered once per interpolation step, so tile the camera
        # tensors along the batch axis to match z_interp.
        pose = model_input['pose'].repeat(num_interps, 1, 1)
        intrinsics = model_input['intrinsics'].repeat(num_interps, 1, 1)
        uv = model_input['uv'].repeat(num_interps, 1, 1)

        print(idx, pose.shape, intrinsics.shape, uv.shape)
        cam2worlds.append(pose[0])

        predictions, _depth_map = model(pose, z_interp, intrinsics, uv)

        # Per-pixel label = argmax over axis 2, then colorized via the dataset palette.
        pred = torch.argmax(predictions, dim=2, keepdim=True)
        output_img = util.lin2img(pred, color_map=dataset.color_map)
        # One grid row per view, one tile per interpolation step.
        output_img = make_grid(output_img, nrow=num_interps, padding=1).permute((1, 2, 0)).cpu().numpy()

        fig, ax = plt.subplots()
        ax.imshow(output_img)
        ax.set_title(f'view {idx}')
        ax.axis('off')
        plt.show()
        plt.close(fig)  # avoid accumulating one open figure per view
        

In [None]:
import shutil
from glob import glob

# Export one camera-to-world pose from the rendering loop so it can be reused by
# the segmentation-video data.
_SEG_VIDEO_ROOT = '/data/anpei/facial-data/seg_video'  # TODO: make configurable instead of an absolute path
_CAM_VIEW_IDX = 7  # which rendered view (index into cam2worlds) to export

# Add a leading axis of size 1 before saving.
cam2world = np.expand_dims(cam2worlds[_CAM_VIEW_IDX].cpu().numpy(), axis=0)

np.save(os.path.join(_SEG_VIDEO_ROOT, '0000', 'cam2world.npy'), cam2world)

# NOTE(review): allow_pickle=True lets a crafted .npy file execute code on load;
# only use this with trusted data.
cams = np.load(os.path.join(_SEG_VIDEO_ROOT, '0000', 'cams.npy'), allow_pickle=True)
# TODO: per-camera processing of `cams` was started here as a `for cam in cams:` loop
# with no body (an IndentationError that blocked the whole cell). `cams` is unused for now.

# Write the same pose into every numbered directory under fake/.
for dst_dir in sorted(glob(os.path.join(_SEG_VIDEO_ROOT, 'fake', '[0-9]*'))):
    if not os.path.isdir(dst_dir):
        continue  # the glob pattern can also match plain files
    print(dst_dir)
    np.save(os.path.join(dst_dir, 'cam2world.npy'), cam2world)

In [None]:
from seg_sampler import FaceSegSampler

# Build the sampler and report the shapes of its `uv` and `intrinsics` attributes.
sampler = FaceSegSampler()
print(*(buf.shape for buf in (sampler.uv, sampler.intrinsics)))

# Draw 100 samples; the visualization cell below consumes `smp_ins`.
smp_ins = sampler.sample_ins(100)
print(smp_ins.shape)


In [None]:
import sys,os
import configargparse

import torch
import numpy as np

from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt

matplotlib.rcParams['figure.figsize'] = [100, 5]

import util

# 19 RGB colors (0-255), one per segmentation label; presumably row i colors label i
# inside util.lin2img — confirm against util.
_COLOR_MAP = np.asarray([[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [
                        255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]])

# Normalize to [0, 1] floats for plotting.
_COLOR_MAP = torch.tensor(_COLOR_MAP, dtype=torch.float32) / 255.0

# Flatten each sampled map to (num_samples, H*W, 1). The sample count is read from
# the data instead of being hard-coded to match `sample_ins(100)` above; reshape
# (not view) also tolerates a non-contiguous array.
num_samples = smp_ins.shape[0]
img = torch.from_numpy(smp_ins).squeeze(1).reshape(num_samples, -1).unsqueeze(2)
output_img = util.lin2img(img, color_map=_COLOR_MAP)
output_img = make_grid(output_img, nrow=8, padding=1).permute((1, 2, 0)).cpu().numpy()

fig, ax = plt.subplots()
ax.imshow(output_img)
ax.set_title(f'{num_samples} sampled segmentation maps')
ax.axis('off')
plt.show()
plt.close(fig)