In [None]:
import os
import sys
import time

import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from datasets.dataset import SMPLRandomDataset
from models.SkinningNet import JointNet, WeightBindingNet, SkinningNet
from utils.LossFunction import SkinningLoss
from utils.OneRingIdx import getSkeletonOneRingIdx

# Pick the compute device: CUDA when a GPU is visible to torch, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("device:{}".format(device))

In [None]:
# Training data: samples generated by SMPLRandomDataset, batched and shuffled.
# NOTE(review): the semantics of the dataset keyword arguments (complexity,
# noise, rotate, scale, translate, centre) live in datasets.dataset — confirm there.
batchSize = 8
trainDataset = SMPLRandomDataset(complexity = "skinning", 
                                 gender = "mixed", 
                                 dataSize = 5000, 
                                 vertexOrderRandom = True, 
                                 noise = 0, 
                                 rotate = True, 
                                 scale = True,
                                 translate = True,
                                 centre = False)
trainIter = DataLoader(trainDataset, num_workers=0, batch_size=batchSize, shuffle=True, drop_last=True)
# Number of samples actually visited per epoch. drop_last=True discards the
# final partial batch, so this is len(trainIter) * batchSize rather than
# len(trainDataset); the per-batch progress percentage and remaining-time
# estimate divide by this value.
allTrainNUM = len(trainIter) * batchSize

In [None]:
# Assemble SkinningNet from the separately trained JointNet and
# WeightBindingNet checkpoints, then move the combined model to `device`.
# map_location remaps tensors that were saved on another device (e.g. a GPU)
# onto `device`, so these checkpoints also load on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints from a trusted source.
jointNet = JointNet()
jointNet.load_state_dict(torch.load("./results/JointNet/latest.pkl", map_location=device))

# jointNum=24 presumably matches the SMPL skeleton's joint count — confirm.
weightNet = WeightBindingNet(jointNum = 24)
weightNet.load_state_dict(torch.load("./results/WeightBindingNet/latest.pkl", map_location=device))

net = SkinningNet(jointNet, weightNet)
net = net.to(device)

In [None]:
# Optimisation setup: Adam over all SkinningNet parameters, with a cosine
# annealing schedule whose period is `epochs` scheduler steps and whose floor
# is 10% of the initial learning rate.
lr = 5e-5
epochs = 100
opt = torch.optim.Adam(params=net.parameters(), lr=lr)
scheduler = CosineAnnealingLR(optimizer=opt, T_max=epochs, eta_min=lr * 0.1)
loss = SkinningLoss()

In [None]:
# Fine-tune SkinningNet end to end. The optimised objective is jl + sl (the
# first two values returned by SkinningLoss); the progress line and the
# per-epoch summary report only the joint loss `jl` and accuracy `acc`.
save_path = './results/SkinningNet/'
# torch.save does not create missing directories; create the checkpoint
# directory up front so the first save cannot fail after a full epoch of work.
os.makedirs(save_path, exist_ok=True)

if len(trainIter) == 0:
    # With drop_last=True a dataset smaller than one batch yields no batches,
    # which would otherwise surface as a ZeroDivisionError in the epoch summary.
    raise ValueError("trainIter yields no batches: dataset has fewer samples than batchSize")

beginTime = time.time()
net = net.train()

for epoch in range(epochs):
    epochBegin = time.time()
    trainLoss = 0.0
    completedNum = 0
    batchNum = 0
    trainAcc = 0.0
    for V, facesOneRingIdx, rigW, joints in trainIter:
        V = V.to(device).float()
        joints = joints.to(device).float()
        facesOneRingIdx = facesOneRingIdx.to(device).long()
        rigW = rigW.to(device).float()

        opt.zero_grad()

        preJ, preA = net(V, facesOneRingIdx)
        # rigW has axes 1 and 2 swapped before being passed to the loss;
        # presumably the dataset stores it transposed relative to preA — TODO confirm.
        jl, sl, acc = loss(preJ, preA, joints, rigW.permute(0, 2, 1))
        l = jl + sl
        l.backward()
        opt.step()

        batchLoss = float(jl)
        batchAcc = float(acc)
        trainAcc += batchAcc
        trainLoss += batchLoss
        batchNum += 1
        completedNum += batchSize
        compltePer = completedNum / allTrainNUM
        # Linear extrapolation of the time left in this epoch, in minutes.
        elapsed = time.time() - epochBegin
        leftTime = (elapsed / compltePer - elapsed) / 60
        sys.stdout.write('\r Trainning %i / %i, joint loss: %f, acc: %0.2f%%, percentage: %0.2f%%, remain %d minuetes'%(completedNum, allTrainNUM, batchLoss, batchAcc*100, compltePer * 100, leftTime))
        sys.stdout.flush()

    # Keep a numbered snapshot every 10 epochs (0-based epoch index in the
    # filename) and always overwrite latest.pkl with the current weights.
    if epoch % 10 == 0:
        torch.save(net.state_dict(), save_path + "%d.pkl" % epoch)
    torch.save(net.state_dict(), save_path + "latest.pkl")
    print("\n")
    print("Epoch%d, joint loss: %f, acc: %0.2f%%, time cost: %0.2f minuetes"%(epoch+1, trainLoss/batchNum, trainAcc/batchNum*100, (time.time() - beginTime) / 60 ))
    print("----------------------------------")
    # The cosine schedule is advanced once per epoch.
    scheduler.step()