In [1]:
import os
import torch as tc
import time
from src.star.star import STAR
from tqdm.notebook import tqdm
from torch.autograd import Variable
from src.mesh_manipulation import load_mesh
from pytorch3d.loss import point_mesh_face_distance
from pytorch3d.structures import Meshes, Pointclouds
device = tc.device("cuda" if tc.cuda.is_available() else "cpu")


In [2]:
def zero_poses(poses):
    poses = poses.clone().detach()
    poses[:,12:17] = 0
    poses[:,54:65] = 0
    poses = Variable(poses, requires_grad=True)
    return poses

In [3]:
all_vertices = []
for gender in ['male', 'female']:
    files_path = f"./data/MOVE4D/{gender}/"
    files = os.listdir(files_path)
    files.sort()
    files_iterator = tqdm(files[:10], desc=f"loading {gender} bodies", position=0)
    for file in files_iterator:
        body_vertices = load_mesh(files_path + file, device)
        body_vertices *= 0.001
        body_vertices -= body_vertices.mean(axis=0)
        all_vertices.append(body_vertices)
bodies = Pointclouds(points=all_vertices)

loading male bodies:   0%|          | 0/10 [00:00<?, ?it/s]

loading female bodies:   0%|          | 0/10 [00:00<?, ?it/s]

In [4]:
for nbetas in range(10,301):

    for bi, body in enumerate(bodies):


        poses = tc.FloatTensor(tc.zeros((1, 72))).to(device)
        poses = Variable(poses, requires_grad=True)
        betas = tc.FloatTensor(tc.zeros((1, nbetas))).to(device)
        betas = Variable(betas, requires_grad=True)
        trans = tc.FloatTensor(tc.zeros((1, 3))).to(device)
        trans = Variable(trans, requires_grad=True)
        gender = 'male' if bi < 5 else 'female'
        star = STAR(gender=gender, num_betas=nbetas)
        star_faces = star.faces[None,...].to(device)

        initial_time = time.time()

        for it in range(5000):
            optimizer = tc.optim.Adam([trans, betas, poses], lr=0.2)
            d = star(betas=betas, pose=poses, trans=trans)
            optimizer.zero_grad()
            star_meshes = Meshes(verts=d, faces=star_faces)
            loss = point_mesh_face_distance(star_meshes, body)
            loss.backward(retain_graph=True)
            optimizer.step()

        print(time.time() - initial_time)

0.33370065689086914
0.33507466316223145
0.35886311531066895
0.35014843940734863
0.3678295612335205
0.36156177520751953
0.34595179557800293


KeyboardInterrupt: 