In [None]:
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Function
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from model.dataset import SMPLRandomDataset
from model.SkinningNet import JointNet
from model.utils import JointLoss, getSkeletonOneRingIdx

from meshplot import plot
import igl
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:%s" % (device,))

In [None]:
# Training data: randomly generated SMPL samples with noise, scaling and translation
# switched on (exact augmentation semantics live in SMPLRandomDataset).
batchSize = 16
trainDataset = SMPLRandomDataset(
    complexity="all",
    gender="mixed",
    dataSize=5000,
    vertexOrderRandom=True,
    noise=1,
    rotate=False,
    scale=True,
    translate=True,
    centre=False,
)
trainIter = DataLoader(
    trainDataset,
    batch_size=batchSize,
    shuffle=True,
    # Every batch has exactly batchSize samples; the training loop builds its
    # skeleton index tensor with batchSize rows.
    drop_last=True,
    num_workers=0,
)
allTrainNUM = len(trainDataset)

In [None]:
# Build the joint-prediction network on `device` and initialise it from the pretrained checkpoint.
net = JointNet()
net = net.to(device)
# map_location lets a checkpoint that was saved from a GPU load on a CPU-only machine.
# os.path.join replaces the Windows-only ".\stateDict\..." literal, which also relied on
# the invalid string escapes "\s" and "\j".
net.load_state_dict(torch.load(os.path.join("stateDict", "jointNet_pretrain.pkl"), map_location=device))

In [None]:
# Optimisation setup: Adam, with a cosine-annealed learning rate going from lr
# down to 0.1 * lr over `epochs` scheduler steps (the scheduler is stepped once per epoch).
lr = 5e-5
epochs = 100
opt = torch.optim.Adam(net.parameters(), lr=lr)
scheduler = CosineAnnealingLR(opt, T_max=epochs, eta_min=0.1 * lr)
loss = JointLoss()

In [None]:
# Fine-tune the network for `epochs` epochs, writing one checkpoint per epoch.
beginTime = time.time()
net = net.train()

# Create the per-epoch checkpoint folder up front so torch.save cannot fail on a fresh checkout.
epochCkptDir = os.path.join("stateDict", "epochs")
os.makedirs(epochCkptDir, exist_ok=True)

# The skeleton one-ring index tensor is identical for every sample. Build it once with
# batchSize rows and slice it to the actual batch size inside the loop.
skeletonOneRingIdx = getSkeletonOneRingIdx("human")
skeletonOneRingIdx = torch.tensor(skeletonOneRingIdx).unsqueeze(0).repeat(batchSize, 1, 1).to(device).long()
for epoch in range(epochs):
    epochBegin = time.time()
    trainLoss = 0.0
    completedNum = 0
    batchNum = 0
    for V, facesOneRingIdx, rigW, joints in trainIter:
        V = V.to(device).float()
        facesOneRingIdx = facesOneRingIdx.to(device).long()
        rigW = rigW.to(device).float()
        joints = joints.to(device).float()
        curBatch = V.shape[0]
        opt.zero_grad()

        preJ = net(V, rigW, facesOneRingIdx, skeletonOneRingIdx[:curBatch])
        l = loss(preJ, joints)
        l.backward()
        opt.step()

        batchLoss = l.item()
        trainLoss += batchLoss
        batchNum += 1
        completedNum += curBatch
        # Remaining time for this epoch, extrapolated linearly from the fraction done so far.
        compltePer = completedNum / allTrainNUM
        elapsed = time.time() - epochBegin
        leftTime = (elapsed / compltePer - elapsed) / 60
        sys.stdout.write('\r 训练%i / %i, 关节点损失: %f, 当前进度: %0.2f%%, 预计剩余%0.2f分钟'%(completedNum, allTrainNUM, batchLoss, compltePer * 100, leftTime))
        sys.stdout.flush()

    # The file index is the 0-based epoch; the log line below prints epoch + 1.
    torch.save(net.state_dict(), os.path.join(epochCkptDir, "jointNet%d.pkl" % epoch))
    print("\n")
    # The time shown is the total elapsed since training started, not the time for this epoch.
    # max(batchNum, 1) prevents a ZeroDivisionError if the loader yielded no batches.
    print("epoch%d,  关节点损失: %f, 用时：%0.2f minuetes"%(epoch+1, trainLoss/max(batchNum, 1), (time.time() - beginTime) / 60 ))
    print("----------------------------------")
    scheduler.step()

In [None]:
# Save the fine-tuned weights in the same "stateDict" folder as the pretrained checkpoint.
# Use a portable path, and make sure the folder exists first.
os.makedirs("stateDict", exist_ok=True)
torch.save(net.state_dict(), os.path.join("stateDict", "jointNet_pretrain_noise.pkl"))

In [None]:
# Evaluation data: same generator as training, but with noise=0 and scale/translate
# switched off (presumably clean, untransformed samples; see SMPLRandomDataset).
# Batches come in a fixed order.
batchSize = 8
testDataset = SMPLRandomDataset(
    complexity="all",
    gender="mixed",
    dataSize=5000,
    vertexOrderRandom=True,
    noise=0,
    rotate=False,
    scale=False,
    translate=False,
    centre=False,
)
testIter = DataLoader(
    testDataset,
    batch_size=batchSize,
    shuffle=False,
    drop_last=False,  # the final batch may be smaller than batchSize
    num_workers=0,
)
allTestNUM = len(testDataset)

In [None]:
# Build a fresh network and loss for evaluation and initialise the network from a saved checkpoint.
# NOTE(review): this loads the original pretrained weights, not the noise fine-tuned
# "jointNet_pretrain_noise.pkl" written by the training section. Confirm which model
# is meant to be evaluated.
net = JointNet()
net = net.to(device)
loss = JointLoss()
# map_location allows loading a GPU-saved checkpoint on a CPU-only machine;
# os.path.join avoids the Windows-only backslash literal.
net.load_state_dict(torch.load(os.path.join("stateDict", "jointNet_pretrain.pkl"), map_location=device))

In [None]:
# Evaluate the network on the test set without gradient tracking.
# testLoss accumulates per-batch losses; batchNum counts batches.
beginTime = time.time()
net = net.eval()

testLoss = 0.0
completedNum = 0
batchNum = 0

# Build the skeleton index tensor once with batchSize rows. testIter uses drop_last=False,
# so the final batch can be smaller; slice the tensor to each batch's real size.
skeletonOneRingIdx = getSkeletonOneRingIdx("human")
skeletonOneRingIdx = torch.tensor(skeletonOneRingIdx).unsqueeze(0).repeat(batchSize, 1, 1).to(device).long()

for V, facesOneRingIdx, rigW, joints in testIter:
    V = V.to(device).float()
    facesOneRingIdx = facesOneRingIdx.to(device).long()
    rigW = rigW.to(device).float()
    joints = joints.to(device).float()
    curBatch = V.shape[0]

    with torch.no_grad():
        preJ = net(V, rigW, facesOneRingIdx, skeletonOneRingIdx[:curBatch])
        l = loss(preJ, joints)

    batchLoss = l.item()
    testLoss += batchLoss
    batchNum += 1
    # Count the samples actually processed so progress tops out at 100% even with a short last batch.
    completedNum += curBatch
    compltePer = completedNum / allTestNUM
    elapsed = time.time() - beginTime
    leftTime = (elapsed / compltePer - elapsed) / 60
    sys.stdout.write('\r 测试%i / %i, 关节点损失: %f, 当前进度: %0.2f%%, 预计剩余%d分钟'%(completedNum, allTestNUM, batchLoss, compltePer * 100, leftTime))
    sys.stdout.flush()


In [None]:
# Report the mean per-batch test loss and the elapsed evaluation time.
# `leftTime` from the loop is a remaining-time estimate, so it cannot be used as elapsed time.
# Elapsed time is measured from `beginTime`, which the evaluation cell sets; it also includes
# any idle time if this cell is run later.
elapsedMinutes = (time.time() - beginTime) / 60
print('测试完成')
print('测试样本%i个, 关节点损失: %f, 用时%0.2f分钟'%(allTestNUM, testLoss/max(batchNum, 1), elapsedMinutes))

In [None]:
# Take one (shuffled) training batch for visual inspection and predict its joints.
# Without this, `preJ` would still hold the predictions for the last *test* batch,
# which describe different meshes than `V`/`joints` below.
V, facesOneRingIdx, rigW, joints = next(iter(trainIter))
V = V.to(device).float()
facesOneRingIdx = facesOneRingIdx.to(device).long()
rigW = rigW.to(device).float()
joints = joints.to(device).float()

# Skeleton index tensor sized to this batch (trainIter's batch size differs from the test loader's).
visSkeletonOneRingIdx = torch.tensor(getSkeletonOneRingIdx("human")).unsqueeze(0).repeat(V.shape[0], 1, 1).to(device).long()
net = net.eval()
with torch.no_grad():
    preJ = net(V, rigW, facesOneRingIdx, visSkeletonOneRingIdx)

In [None]:
# Select one sample of the current batch and move its vertices, ground-truth joints
# and most recently predicted joints (`preJ`) to NumPy arrays for plotting.
idx = 0
v, j, prejoint = (t[idx].detach().cpu().numpy() for t in (V, joints, preJ))

In [None]:
# Plot the sample's vertices (no faces are passed) and overlay the predicted joints in green.
# Set SHOW_GROUND_TRUTH = True to also draw the ground-truth joints in blue.
SHOW_GROUND_TRUTH = False
p = plot(v, shading={"point_size": "0.04"})
if SHOW_GROUND_TRUTH:
    p.add_points(j, shading={"point_size": "0.2", "point_color": "blue"})
p.add_points(prejoint, shading={"point_size": "0.2", "point_color":"green"})

In [None]:
# Inspect the shape of the stored "test_simple" pose array.
# os.path.join replaces the Windows-only ".\data\..." literal, which relied on the invalid "\d" escape.
np.load(os.path.join("data", "%s_pose.npy" % "test_simple")).shape

In [None]:
# Display the selected sample's ground-truth joint array (NumPy).
j