In [None]:
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Function
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from model.dataset_transfomation import SMPLRandomDataset

from model.SkeletonTransformationNet import TransformationNet, DeformationtionNet
from model.utils import SkinningLoss

from meshplot import plot
import igl
# Seed the global RNGs so that torch/numpy-based randomness (network weight
# initialisation, DataLoader shuffling) is reproducible across runs.
# NOTE(review): if SMPLRandomDataset draws from Python's `random` module or
# its own generator, that source must be seeded separately.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device:%s"%device)

In [None]:
# Training data and its shuffled mini-batch loader.
# drop_last=True guarantees every yielded batch holds exactly batchSize samples.
batchSize = 8
trainDataset = SMPLRandomDataset(gender="male", complexity="simple")
trainIter = DataLoader(
    trainDataset,
    batch_size=batchSize,
    shuffle=True,
    drop_last=True,
    num_workers=0,
)
allTrainNUM = len(trainDataset)  # dataset size, used for progress reporting

In [None]:
# Build the pose-prediction network and the deformation module on `device`.
# `deformNet` is configured with the dataset SMPL layer's `kintree_table`
# (presumably the joint parent hierarchy — confirm in DeformationtionNet).
net = TransformationNet().to(device)
deformNet = DeformationtionNet(trainDataset.smplLayer.kintree_table).to(device)

In [None]:
# Optimisation setup: Adam on the pose network with a cosine-annealed LR.
lr = 0.0001
epochs = 100
# Lower bound of the cosine schedule. It must be strictly below `lr`: with
# eta_min == lr the cosine has zero amplitude and the learning rate never
# changes, which silently disables the scheduler.
lrMin = lr * 0.01
opt = torch.optim.Adam(net.parameters(), lr=lr)
# One cosine half-period spans the whole run (scheduler.step() is called once per epoch).
scheduler = CosineAnnealingLR(opt, T_max=epochs, eta_min=lrMin)
loss = nn.MSELoss()

In [None]:
# Training loop. For each batch, `net` predicts a pose from the source (s*) and
# reference (r*) inputs; `deformNet` applies that pose to the source vertices;
# the result is compared against the target vertices `tV` with MSE. The
# weights of `net` are checkpointed (overwritten) at the end of every epoch.

# Portable checkpoint location. The previous literal ".\stateDict\..." held
# invalid escape sequences and only worked on Windows; also make sure the
# directory exists so torch.save does not fail.
checkpointDir = "stateDict"
checkpointPath = os.path.join(checkpointDir, "skinningNet_epochs.pkl")
os.makedirs(checkpointDir, exist_ok=True)

beginTime = time.time()
net = net.train()

for epoch in range(epochs):
    epochBegin = time.time()
    trainLoss = 0.0
    completedNum = 0
    batchNum = 0
    for sV, sFacesOneRingIdx, sW, sRigW, sJ, rV, rFacesOneRingIdx, rW, rRigW, rJ, tV, tJ in trainIter:
        # Move only the tensors this loop consumes; rW and tJ are unused here,
        # so copying them to the device each batch would be wasted transfer.
        sV = sV.to(device).float()
        sW = sW.to(device).float()
        sFacesOneRingIdx = sFacesOneRingIdx.to(device).long()
        sJ = sJ.to(device).float()
        sRigW = sRigW.to(device).float()
        rV = rV.to(device).float()
        rFacesOneRingIdx = rFacesOneRingIdx.to(device).long()
        rJ = rJ.to(device).float()
        rRigW = rRigW.to(device).float()
        tV = tV.to(device).float()

        opt.zero_grad()

        prePose = net(sV, sFacesOneRingIdx, sRigW, sJ, rV, rFacesOneRingIdx, rRigW, rJ)
        # Only the first three channels of the last axis of sV go to the
        # deformation (presumably xyz positions — confirm against the dataset).
        preV = deformNet(sV[:, :, 0:3], sJ, prePose, sW)

        l = loss(preV, tV)
        l.backward()
        opt.step()

        batchLoss = l.item()
        trainLoss += batchLoss
        batchNum += 1
        completedNum += batchSize
        compltePer = completedNum / allTrainNUM
        # Linear extrapolation of the remaining epoch time, in minutes.
        elapsed = time.time() - epochBegin
        leftTime = elapsed * (1.0 / compltePer - 1.0) / 60
        sys.stdout.write('\r 训练%i / %i, 关节点损失: %f, 当前进度: %0.2f%%, 预计剩余%d分钟'%(completedNum, allTrainNUM, batchLoss, compltePer * 100, leftTime))
        sys.stdout.flush()
    torch.save(net.state_dict(), checkpointPath)
    print("\n")
    # Guard against an epoch with no batches (dataset smaller than batchSize
    # combined with drop_last=True), which would otherwise divide by zero.
    meanLoss = trainLoss / batchNum if batchNum else float("nan")
    # The reported time is cumulative since training began.
    print("epoch%d, 关节点损失: %f, 用时：%0.2f minutes"%(epoch+1, meanLoss, (time.time() - beginTime) / 60 ))
    print("----------------------------------")
    scheduler.step()

In [None]:
# Scratch tensors filled with random values (a: 2x24x4x4, b: 2x6890x24).
# NOTE(review): not referenced by the training cells; looks like a leftover
# shape experiment — consider deleting this cell.
a, b = torch.randn(2, 24, 4, 4), torch.randn(2, 6890, 24)

In [None]:
# Scratch arithmetic (evaluates to 120.0). NOTE(review): purpose unclear,
# possibly a leftover shape sanity check — consider deleting this cell.
5760 / (2 * 24)