In [1]:
from common.opt import opts
import os
import glob
import torch
import random
import logging
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
from common.opt import opts
from common.utils import *
from common.camera import get_uvd2xyz
from common.load_data_hm36_tds import Fusion
from common.load_data_3dhp_mae import Fusion as Fusion2
from common.h36m_dataset import Human36mDataset
from model_temproal.ConvLift import ConvLift_mask,ConvLift
from model_transformer.MPMLift import MPMmask,MPM

In [2]:
import sys
from data.HybirdTools import HybirdSet
from torchstat import stat
from model_transformer.MPMLift import MPMmask
from torchsummary import summary
from thop import profile, clever_format
sys.argv = ['--MAE']
if __name__=='__main__':
    
    opt = opts().parse()
    opt.manualSeed = 1
    
    # seed 必须恒定
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.cuda.manual_seed_all(opt.manualSeed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    opt.MAE = True
    opt.layers = 4
    
    # model = MPMmask(opt).cuda()
    # summary(model, (243, 16, 3))
    
    model = MPMmask(opt)
    input = torch.randn(1, 243, 16, 2)
    flops, params = profile(model, inputs=(input), verbose=True)

    flops, params = clever_format([flops, params], "%.3f")
    print('MACs:' + flops)
    print('Params:' + params)



Dataset: ['h36m']
Task Name:   
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm1d'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
MACs:778.969M
Params:3.214M


In [3]:
train_data2 = Fusion2(opt=opt, train=False, root_path=opt.root_path, MAE=opt.MAE)
dataset_path = opt.root_path + 'data_3d_' + 'h36m' + '.npz'
dataset = Human36mDataset(dataset_path, opt)
train_data = Fusion(opt=opt, train=False, dataset=dataset, root_path=opt.root_path, MAE=opt.MAE, tds=opt.t_downsample)




3dhp length:  2875
INFO: Testing on 2880 frames
h36m length:  543344
INFO: Testing on 543360 frames


In [160]:
# 数据集测试
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
%matplotlib inline
def D3vis(keypoints,name, angle_h = 10, angle_w = 70, mask = None):

    keypoints[0] = (keypoints[1] + keypoints[4] )/2
    leftI = [0,1,2,8,13,14,0,7,8]
    leftJ = [1,2,3,13,14,15,7,8,9]

    rightI = [8,10,11,0,4,5]
    rightJ = [10,11,12,4,5,6]
    
    left_joints = [0,1,2,13,14,15,7,8,9] 
    
    X=keypoints
    x = [k[0] for k in X]
    y = [k[1] for k in X]
    z = [k[2] for k in X]
    
    fig = plt.figure()
    if mask==None:
        mask = np.zeros(16).astype(np.int64)
    
    plt.title(name)
    ax = fig.add_subplot(111,projection='3d')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.view_init(angle_h, angle_w)
    ax.grid(False)
    # ax.set_xticklabels([])
    # ax.set_yticklabels([])
    ax.set_zticklabels([])
    ax.w_xaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))
    ax.w_yaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))
    # ax.w_zaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))
    ax.w_xaxis.set_visible(False)
    ax.w_yaxis.set_visible(False)
    ax.w_zaxis.set_visible(False)
    # ax.set_xlabel('x')
    # ax.set_ylabel('y')
    # ax.set_zlabel('z')
    ax.w_xaxis.line.set_color("none")
    ax.w_yaxis.line.set_color("none")
    ax.w_zaxis.line.set_color("none")
    ax.axis('on')
    plt.xticks([])  # 隐藏x轴的刻度
    plt.yticks([])  # 隐藏y轴的刻度

    ax.set_zlim(0, 4)
    
    
    for k in range(0,len(x)):
        if mask[k]==0:
            ax.scatter(x[k],y[k],z[k],marker='o',s=75,color="#000000")
        else:
            ax.scatter(x[k],y[k],z[k],marker='o',s=80,color="#ffffff",zorder=100,  edgecolors='#000000')
        # ax.text(x[k],y[k],z[k],str(k))
    
    for k in range(len(leftI)): 
        plt.plot([x[leftI[k]],x[leftJ[k]]],[y[leftI[k]],y[leftJ[k]]],[z[leftI[k]],z[leftJ[k]]], linewidth=4, color='#ec4b3c',alpha=0.6) 
    for k in range(len(rightI)): 
        plt.plot([x[rightI[k]],x[rightJ[k]]],[y[rightI[k]],y[rightJ[k]]],[z[rightI[k]],z[rightJ[k]]], linewidth=4, color='#2f9ae4',alpha=0.6)    

    # plt.axis('off') 
    plt.show()

import random
def D2vis(keypoints,name, mask=):
    # 取第一个batch的f帧
    #fcb6af
    #98c4e2


    leftI = [0,1,2,8,13,14,0,7,8]
    leftJ = [1,2,3,13,14,15,7,8,9]

    rightI = [8,10,11,0,4,5]
    rightJ = [10,11,12,4,5,6]
    
    if mask==None:
        mask = np.zeros(16).astype(np.int64)


    # before
    plt.figure()
    plt.ylim(-1.0,1.0)
    plt.xlim(-1.0,1.0)

    X=keypoints
    x = [k[0] for k in X]
    y = [k[1] for k in X]

    plt.subplot(1,2,1)
    plt.title(name)
    plt.gca().invert_yaxis()

    for k in range(len(leftI)): 
        plt.plot([x[leftI[k]],x[leftJ[k]]],[y[leftI[k]],y[leftJ[k]]], linewidth=4, color='#ec4b3c')
    for k in range(len(rightI)): 
        plt.plot([x[rightI[k]],x[rightJ[k]]],[y[rightI[k]],y[rightJ[k]]], linewidth=4, color='#2f9ae4')
    for i in range(0,len(x)):
        if(1):
            plt.scatter(x[i],y[i],marker='o',s=75,color="#000000")
        else:
            plt.scatter(x[i],y[i],marker='o',s=80,color="#ffffff",zorder=100,  edgecolors='#000000')
        # plt.text(x[i],y[i],str(i))

    plt.xticks([])  # 隐藏x轴的刻度
    plt.yticks([])  # 隐藏y轴的刻度

    # plt.axis('off')
    plt.show()
    return

In [161]:
def normalize_data(data):
    """
    data: [B, j, 3]
    Return: [B, j, 3]
    """
    res_w, res_h = 1000, 1000
    assert data.ndim >= 3
    data = data.copy()
    data[..., :2] = data[..., :2] / res_w * 2 - [1, res_h / res_w]
    data[..., 2:] = data[..., 2:] / res_w * 2
    return data

a,_,_ = train_data[0]
_,b,_ = train_data[123]
print(a.shape,b.shape)

mask = np.zeros(16).astype(np.int64)
ran = random.sample(range(1,15),3)
mask[ran] = 1


k = 100
scale = 3
b[k,:, 2] = b[k,: , 2] * scale
print(b[k])
D2vis(a[k],"", mask)
D3vis(b[k],"")

(243, 16, 2) (243, 16, 3)
[[-7.36545995e-02 -7.61312991e-02  2.95198798e+00]
 [-1.93038419e-01 -1.06525317e-01  2.99630451e+00]
 [-2.04041988e-01 -6.87220767e-02  1.58465803e+00]
 [-2.23911792e-01  3.37322503e-02  2.14091390e-01]
 [ 4.57295328e-02 -4.57372069e-02  2.90767121e+00]
 [ 8.54947194e-02 -8.61959159e-02  1.50134921e+00]
 [ 1.28510401e-01  7.98514520e-04  1.25954717e-01]
 [-8.55847821e-02 -8.64226818e-02  3.71939182e+00]
 [-8.81381184e-02 -1.21283114e-01  4.46310616e+00]
 [-6.24963827e-02 -1.88656762e-01  5.07344818e+00]
 [ 6.93148822e-02 -4.27564383e-02  4.30901861e+00]
 [ 2.70345032e-01  2.93781571e-02  3.69281745e+00]
 [ 4.26190346e-01 -1.09057412e-01  3.28415155e+00]
 [-2.58935481e-01 -1.40386313e-01  4.27196407e+00]
 [-4.76661295e-01 -1.49022982e-01  3.66951895e+00]
 [-6.54739380e-01 -2.36126885e-01  3.21742845e+00]]


ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [30]:
import torch
import numpy as np
a = torch.randn(10,10)
s = a.clone()
mask_num = int(4)
mask2D = np.hstack([
    np.zeros(10 - mask_num),
    np.ones(mask_num),
]).flatten()

np.random.seed()
np.random.shuffle(mask2D)
mask = torch.from_numpy(mask2D).to(torch.bool)

mask_idx = torch.nonzero(mask).squeeze()
unmask_idx = torch.nonzero(~mask).squeeze()
rearrange_idx = torch.cat((unmask_idx, mask_idx),dim=0)
print(rearrange_idx )

a = a[:, ~mask]

b = torch.zeros(10, 4)
a = torch.concat((a, b), dim=1)
a[:, rearrange_idx] = a.clone()
print(s[: ,~mask]==a[:, ~mask])

print()


tensor([3, 4, 6, 7, 8, 9, 0, 1, 2, 5])
tensor([[True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True]])

