In [None]:
import argparse
import os
import sys
from os import mkdir

import numpy as np
import torch
import torch.nn.functional as F

sys.path.append('..')
from config import cfg
from data import make_data_loader
from engine.trainer import do_train
from modeling import build_model
from solver import make_optimizer, WarmupMultiStepLR
from layers import make_loss

from utils.logger import setup_logger
from utils.feats_pca import feats_map_pca_projection,feats_pca_projection
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import cv2
from imageio_ffmpeg import write_frames
%load_ext autoreload
%autoreload 2
# Make GPU 2 the current CUDA device, so argument-less .cuda() calls below land on it.
GPU_INDEX = 2
torch.cuda.set_device(GPU_INDEX)

In [None]:
# Training run to render from: its directory holds configs.yml and the
# checkpoint files, which are named after the epoch they were saved at.
model_path = 'sport_1_training'
epoch = 908
para_file = 'nr_model_{:d}.pth'.format(epoch)

In [None]:
# Start from the configuration the run was trained with, then override only
# what rendering needs: one sample per batch. Resolution overrides such as
# cfg.INPUT.SIZE_TRAIN / cfg.INPUT.SIZE_TEST = [1000, 750] would also go here,
# before the config is frozen.
config_file = os.path.join(model_path, 'configs.yml')
cfg.merge_from_file(config_file)
cfg.SOLVER.IMS_PER_BATCH = 1
cfg.freeze()

In [None]:
# NOTE(review): `writer` is not used by any visible cell; constructing it still
# sets up the 'tensorboard_test' log directory under the run folder.
tensorboard_dir = os.path.join(model_path, 'tensorboard_test')
writer = SummaryWriter(log_dir=tensorboard_dir)

In [None]:
# Build the evaluation data. Besides the loader, this returns the dataset object
# itself; later cells index it directly and read `dataset.datasets[0].Ts`.
test_loader, dataset = make_data_loader(cfg, is_train=False)

In [None]:
# Build the network in inference configuration and restore the checkpoint.
# The weights are mapped to CPU while loading, and the model is moved to the
# current GPU only afterwards.
checkpoint_path = os.path.join(model_path, para_file)
model = build_model(cfg, isTrain=False)
state_dict = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
model = model.cuda()

In [None]:
# Render the first test batch. This is a sanity check, and it also defines the
# tensors (K, T, inds, near_far_max_splatting_size, ...) that the novel-view
# cells below reuse as templates.
# NOTE(review): feature_maps / tars are not used in the visible cells; they are
# kept only so any other cell that expects them still works.
feature_maps = []
tars = []

try:
    batch = next(iter(test_loader))
except StopIteration:
    # Without this, an empty loader would leave every name below undefined and
    # fail much later with a confusing NameError.
    raise RuntimeError('test_loader yielded no batches') from None

point_indexes = batch[0]
in_points = batch[1].cuda()
K = batch[2].cuda()
T = batch[3].cuda()
num_points = batch[4]
near_far_max_splatting_size = batch[5]
inds = batch[6].cuda()
target = batch[7].cuda()
rgbs = batch[8].cuda()

# Inference only: skip building an autograd graph to save GPU memory.
with torch.no_grad():
    res, depth, features, dir_in_world, rgb, m_point_features = model(
        in_points, K, T, near_far_max_splatting_size, num_points, rgbs, inds)
    
    

In [None]:
# Show the colour channels (0:3) of the first rendered image as HxWxC, and
# report the target tensor's shape.
preview_rgb = res.detach().cpu()[0].permute(1, 2, 0)[:, :, 0:3]
plt.imshow(preview_rgb)
print(target.size())

In [None]:
# Derive the orbit used for novel-view rendering from one sample's point cloud
# and the training camera poses.
# NOTE(review): 56 presumably equals dataset.datasets[0].Ts.size(0), which
# would make this the first sample of frame 1 -- confirm.
_, in_points, _, _, _, _, rgbs = dataset[1 * 56]
cam_poses = dataset.datasets[0].Ts

# Orbit centre: mean of the point cloud.
center = torch.mean(in_points, dim=0).cpu()

# Up direction: negated mean of column 1 of the camera poses' rotation part.
up = -torch.mean(cam_poses[:, 0:3, 1], dim=0)
up = up / torch.norm(up)

# Orbit radius: 1.3x the distance from the first camera's position to the centre.
radius = torch.norm(cam_poses[0, 0:3, 3] - center) * 1.3

# Reference direction for the start camera: -z projected onto the plane
# orthogonal to `up`. Build it with up's dtype so `dot` does not fail when the
# poses are float64. If -z is (nearly) parallel to `up`, the projection
# vanishes and normalising would produce NaNs, so fall back to +x.
v = torch.tensor([0, 0, -1], dtype=up.dtype)
v = v - up.dot(v) * up
if torch.norm(v) < 1e-6:
    v = torch.tensor([1, 0, 0], dtype=up.dtype)
    v = v - up.dot(v) * up
v = v / torch.norm(v)

# Principal point of the rendering camera (modifies K in place).
# NOTE(review): hard-coded; presumably half of an 800x600 output -- confirm.
K[:, 0, 2] = 400
K[:, 1, 2] = 300

# Start position: `radius` from the centre along -v, raised by 0.1*radius along up.
s_pos = center - v * radius + up * radius * 0.1

center = center.numpy()
up = up.numpy()
radius = radius.item()
s_pos = s_pos.numpy()

# Initial viewing frame of the start camera.
lookat = center - s_pos
lookat = lookat / np.linalg.norm(lookat)

xaxis = np.cross(lookat, up)
xaxis = xaxis / np.linalg.norm(xaxis)

In [None]:

import math
def rotate(angle):
    """Return the 3x3 matrix that rotates by ``-angle`` radians about the y-axis.

    This is ``R_y(-angle)`` under the right-handed convention
    ``R_y(t) = [[cos t, 0, sin t], [0, 1, 0], [-sin t, 0, cos t]]``.
    """
    c = math.cos(angle)
    s = math.sin(-angle)
    # The two off-diagonal terms must have opposite signs. Using sin(-angle)
    # for both gives a symmetric matrix that is not orthogonal
    # (its determinant is cos(2*angle)).
    return np.array([[c, 0, s],
                     [0, 1, 0],
                     [-s, 0, c]])

def rodrigues_rotation_matrix(axis, theta):
    """Return the 3x3 matrix rotating counter-clockwise by ``theta`` radians about ``axis``.

    Uses the Euler-Rodrigues formula. ``axis`` may be any non-zero 3-vector;
    it is normalised internally.

    Raises:
        ValueError: if ``axis`` has zero length. Without this check the
            normalisation divides by zero and the result is silently all NaN.
    """
    axis = np.asarray(axis)
    norm = math.sqrt(np.dot(axis, axis))
    if norm == 0.0:
        raise ValueError('rotation axis must be non-zero')
    axis = axis / norm

    half = float(theta) / 2.0
    a = math.cos(half)
    b, c, d = -axis * math.sin(half)
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
                     [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
                     [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])

# Frame whose point cloud is rendered. `dx` is a step direction that is set
# here but never applied.
index = 0
dx = 1

# One output folder per image kind, tagged with the epoch. exist_ok avoids the
# check-then-create race and the copy-pasted exists/mkdir pairs.
for out_dir_pattern in ('vis_%d', 'vis_mask_%d', 'vis_depth_%d', 'vis_compos_%d'):
    os.makedirs(os.path.join(model_path, out_dir_pattern % epoch), exist_ok=True)

# Per-view intrinsics, 4x4 poses and source-frame ids, filled by the render loop.
sKs = []
sTs = []
frames_id = []
    
def _look_at_pose(cam_pos, target, up_dir):
    """Return a 4x4 pose whose columns are the camera x-axis, y-axis, viewing
    direction and position, with [0, 0, 0, 1] as the last row.

    The viewing direction points from ``cam_pos`` to ``target``. The x-axis is
    view x up, and the y-axis is -(x x view). All axes are unit length.
    """
    forward = target - cam_pos
    forward = forward / np.linalg.norm(forward)
    x_axis = np.cross(forward, up_dir)
    x_axis = x_axis / np.linalg.norm(x_axis)
    y_axis = -np.cross(x_axis, forward)
    y_axis = y_axis / np.linalg.norm(y_axis)
    pose = np.array([x_axis, y_axis, forward, cam_pos]).T
    return np.concatenate([pose, np.array([[0, 0, 0, 1]])])


def _depth_to_vis(depth_map):
    """Scale a 2-D depth map so its maximum maps to 255, then replicate it to 3 channels.

    If the maximum is not positive, a black image is returned instead of the
    NaN/inf values a division by zero would produce.
    """
    depth_max = np.max(depth_map)
    if depth_max > 0:
        scaled = depth_map * 255 / depth_max
    else:
        scaled = np.zeros_like(depth_map)
    return cv2.cvtColor(scaled, cv2.COLOR_GRAY2RGB)


NUM_VIEWS = 360

for i in range(NUM_VIEWS):
    # Load the point cloud for frame `index`.
    # NOTE(review): `index` is never advanced, so every view renders the same
    # frame. `dx` hints at an intended forward/backward sweep (index += dx) --
    # confirm before relying on frames_id.
    _, in_points, _, _, _, _, rgbs = dataset[index * dataset.datasets[0].Ts.size(0)]
    num_points = torch.Tensor([in_points.size(0)])
    in_points = in_points.cuda()
    rgbs = rgbs.cuda()

    if index >= dataset.datasets[0].frame_num - 1:
        dx = -1

    # Swing the start camera about `up` through [-25 deg, +24.5 deg] in
    # 0.5 deg steps, restarting every 100 views.
    angle = math.pi * ((i % 100) - 50) / 360.0
    pos = rodrigues_rotation_matrix(up, -angle).dot(s_pos - center) + center

    nR = _look_at_pose(pos, center, up)
    sTs.append(nR)
    sKs.append(K[0].cpu().numpy())
    frames_id.append(index)

    T[0, :, :] = torch.Tensor(nR).cuda()
    with torch.no_grad():
        res, depth, features, dir_in_world, rgb, m_point_features = model(
            in_points, K, T, near_far_max_splatting_size, num_points, rgbs, inds)

    # Channels 0:3 hold colour; channel 3 is used as the mask.
    img_t = res.detach().cpu()[0]
    mask_t = img_t[3:4, :, :]

    img = cv2.cvtColor(img_t.permute(1, 2, 0).numpy() * 255.0, cv2.COLOR_BGR2RGB)
    mask = mask_t.permute(1, 2, 0).numpy() * 255.0
    img_depth = _depth_to_vis(depth.detach().cpu()[0][0].numpy())

    cv2.imwrite(os.path.join(model_path, 'vis_%d/img_%04d.jpg' % (epoch, i)), img)

    # Composite on white: multiply colour by the mask, then paint every pixel
    # whose mask is below 0.95 white. The mask channel itself is untouched.
    mask3 = mask_t.repeat(3, 1, 1)
    img_t[0:3, :, :] = img_t[0:3, :, :] * mask3
    img_t[0:3, :, :][mask3 < 0.95] = 1.0
    img = cv2.cvtColor(img_t.permute(1, 2, 0).numpy() * 255.0, cv2.COLOR_BGR2RGB)

    cv2.imwrite(os.path.join(model_path, 'vis_compos_%d/img_%04d.jpg' % (epoch, i)), img)
    cv2.imwrite(os.path.join(model_path, 'vis_mask_%d/img_%04d.jpg' % (epoch, i)), mask)
    cv2.imwrite(os.path.join(model_path, 'vis_depth_%d/img_%04d.jpg' % (epoch, i)), img_depth)

    # Drop per-view outputs before the next render so the cache can actually be freed.
    del res, depth, features, dir_in_world, rgb, m_point_features, img
    torch.cuda.empty_cache()
    print(i, '/%d' % NUM_VIEWS)
    
    

# Export the rendered camera path as three epoch-tagged text files.

# Intrinsic_<epoch>.inf: for each view, its id, then the 3x3 K matrix
# (row-major), then a blank line.
intrinsic_file = os.path.join(model_path, 'Intrinsic_%d.inf' % epoch)
with open(intrinsic_file, 'w') as f:
    for view_id in range(len(sKs)):
        k_values = tuple(sKs[view_id].reshape(9).tolist())
        f.write('%d\n' % view_id)
        f.write('%f %f %f\n %f %f %f\n %f %f %f\n' % k_values)
        f.write('\n')

# CamPose_<epoch>.inf: one line per view with columns 2, 0, 1 and 3 of the top
# 3x4 of each pose (lookat, x-axis, y-axis, position as assembled in the render loop).
pose_file = os.path.join(model_path, 'CamPose_%d.inf' % epoch)
with open(pose_file, 'w') as f:
    for pose in sTs:
        reordered = np.concatenate([pose[0:3, col] for col in (2, 0, 1, 3)])
        f.write('%f %f %f %f %f %f %f %f %f %f %f %f\n' % tuple(reordered.tolist()))

# frames_<epoch>.inf: source frame id of each view, one per line.
frames_file = os.path.join(model_path, 'frames_%d.inf' % epoch)
with open(frames_file, 'w') as f:
    f.writelines('%d\n' % int(frame) for frame in frames_id)