In [51]:
import sys
sys.path.insert(0, '..')
import kaolin as kal
from kaolin.datasets import shapenet
from kaolin import rep
from kaolin import conversions
import torch
from torch import nn
import matplotlib.pyplot as plt
import math
from periodic_shapes.models import periodic_shape_sampler
from periodic_shapes.models import super_shape_sampler
from periodic_shapes.models import super_shape
from periodic_shapes.models import model_utils
from periodic_shapes import utils
from periodic_shapes.losses import custom_chamfer_loss
import numpy as np
import random
import tqdm
from collections import defaultdict
from torch.autograd import Variable
import torch.optim as optim
import pickle
import os
import dotenv
import plotly.graph_objects as go
from periodic_shapes.visualize import plot


In [14]:
seed = 0
random.seed(seed)  
np.random.seed(seed)  
# PyTorch のRNGを初期化  
torch.manual_seed(seed)

<torch._C.Generator at 0x7f93dd9f58f0>

In [None]:
dotenv.load_dotenv(verbose=True)

category = 'plane'
cache_root = os.getenv('SHAPENET_KAOLIN_CACHE_ROOT')
shapenet_root = os.getenv('SHAPENET_ROOT')
cache_dir = os.path.join(cache_root, category)

categories = [category]


In [17]:
sdf_set = shapenet.ShapeNet_SDF_Points(root=shapenet_root, categories=categories, cache_dir=cache_dir, train=True, split=1.)
point_set = shapenet.ShapeNet_Points(root=shapenet_root, categories=categories, cache_dir=cache_dir, train=True, split=1.)
surface_set = shapenet.ShapeNet_Surface_Meshes(root=shapenet_root, categories=categories, cache_dir=cache_dir, train=True, split=1.)

converting to voxels: 100%|██████████| 4045/4045 [00:00<00:00, 36512.95it/s]
converting to surface meshes: 100%|██████████| 4045/4045 [00:00<00:00, 35735.11it/s]
converting to sdf points: 100%|██████████| 4045/4045 [00:00<00:00, 35017.24it/s]
converting to voxels: 100%|██████████| 4045/4045 [00:00<00:00, 35894.64it/s]
converting to surface meshes: 100%|██████████| 4045/4045 [00:00<00:00, 35952.98it/s]
converting to points: 100%|██████████| 4045/4045 [00:00<00:00, 35285.62it/s]
converting to voxels: 100%|██████████| 4045/4045 [00:00<00:00, 36633.81it/s]
converting to surface meshes: 100%|██████████| 4045/4045 [00:00<00:00, 35191.93it/s]


In [53]:
EPS = 1e-7
m = 4
n = 6
batch = 10
learning_rate = .01
iters = 500
dim = 3
sample_idx = 0
train_theta_sample_num = 30
points_sample_num = 3000
train_grid_sample_num = 5000

device_type = 'cuda:7'
#device_type = 'cpu'
train_periodic_after_abstraction = True

periodicnet = train_periodic_after_abstraction
device = torch.device(device_type)

if periodicnet:
    ocoef = 1.
    ccoef = 10.
else:
    ocoef = 1.
    ccoef = 1.

overlap_reg_coef = 1.
self_overlap_reg_coef = 1.

# points_num, dim
points = point_set[sample_idx]['data']['points'].to(device) * 10
all_points_sample_num = points.shape[0]

# grid_points_num, dim
xyz = sdf_set[sample_idx]['data']['sdf_points'].to(device) * 10
x = xyz[:, 0]
y = xyz[:, 1]
z = xyz[:, 2]
mesh = surface_set[sample_idx]['data']
meshkal = rep.TriangleMesh.from_tensors(mesh['vertices']*10,
                                    mesh['faces'])
meshkal.to('cuda:0')
sdf_func = kal.conversions.trianglemesh_to_sdf(meshkal, x.shape[0])
sgn = (sdf_func(xyz.to('cuda:0')).to(device) <= 0.001).float()
all_grid_sample_num = sgn.shape[0]

def get_target_sample():
    index = []
    for _ in range(batch):
        index_single = random.sample(range(all_grid_sample_num), train_grid_sample_num)
        index.extend(index_single)
    train_x = x[index]
    train_y = y[index]
    train_z = z[index]

    target_coord = torch.stack([train_x, train_y, train_z], axis=1).view(batch, -1, dim)

    target_sgn = sgn[index].float().view(batch, -1)

    points_index = []
    for _ in range(batch):
        index_single = random.sample(range(all_points_sample_num), points_sample_num)
        points_index.extend(index_single)
    target_points = points[points_index, :].view(batch, -1, dim)

    return target_points, target_coord, target_sgn


if train_periodic_after_abstraction:
    primitive.eval()
else:
    primitive = super_shape.SuperShapes(m, n, quadrics=True, train_logits=False, train_ab=False, dim=dim, transition_range=3, latent_dim=1)
    primitive.to(device)


if periodicnet or train_periodic_after_abstraction:
    print('Train periodic net')
    sampler = periodic_shape_sampler.PeriodicShapeSampler(points_sample_num, m, n, factor=2, dim=dim)
    optimizer = optim.Adam([*sampler.parameters(), *primitive.parameters()], lr=learning_rate)
else:
    print('Train super shape')
    sampler = super_shape_sampler.SuperShapeSampler(m, n, dim=dim)
    optimizer = optim.Adam(primitive.parameters(), lr=learning_rate)

sampler.to(device)
torch.autograd.set_detect_anomaly(True)

loss_log = []
for idx in tqdm.tqdm(range(iters)):
    optimizer.zero_grad()

    # Ensure polar coordinate samples are closed at 0 and 2 pi.
    thetas = utils.sample_spherical_angles(batch=batch, sample_num=train_theta_sample_num, sampling='uniform', device=device, dim=dim)

    target_points, target_coord, target_sgn = get_target_sample()

    kwargs = {
        'thetas': thetas,
        'coord': target_coord
    }
    if periodicnet:
        kwargs['points'] = target_points

    param = primitive(torch.zeros([batch, 1], device=device).float())
    print(param['m_vector'].device)

    #print(param['m_vector'].device)
    prd_points, prd_mask, prd_tsd, bnnp_tsd = sampler(param, **kwargs)

    print('prd tsd', prd_tsd.min().item(), prd_tsd.max().item())
    prd_sgn = model_utils.convert_tsd_range_to_zero_to_one(prd_tsd).sum(1)
    print('prd sgn', prd_sgn.min().item(), prd_sgn.max().item())

    overlap_reg = (nn.functional.relu(prd_sgn - 1.2).abs()).mean()

    bnnp_tsd_reshaped = bnnp_tsd.view(batch, n, n, -1)[:, range(n), range(n), :]
    self_overlap_reg = nn.functional.relu(bnnp_tsd_reshaped - 1e-1).mean()

    print('target sgn', target_sgn.min().item(), target_sgn.max().item())
    oloss = nn.functional.binary_cross_entropy_with_logits(prd_sgn.clamp(min=1e-7), target_sgn)

    closs = custom_chamfer_loss.custom_chamfer_loss(prd_points, target_points, surface_mask=prd_mask, prob=None, pykeops=False)

    reg = overlap_reg * overlap_reg_coef + self_overlap_reg * self_overlap_reg_coef
    loss = closs * ccoef + oloss * ocoef + reg

    print(closs.item(), reg.item(), loss.item(), oloss.item(), self_overlap_reg.item())
    loss_log.append(closs.detach().cpu().numpy())
    loss.backward()
    optimizer.step()

primitive.eval()
sampler.eval()



96
prd sgn 0.0 2.0
target sgn 0.0 1.0
0.05955249443650246 0.015071332454681396 1.2881584167480469 0.6775621175765991 0.004308873787522316
 83%|████████▎ | 416/500 [05:26<01:05,  1.28it/s]cuda:7
prd tsd -11584.29296875 0.78464275598526
prd sgn 0.0 2.0
target sgn 0.0 1.0
0.05960880219936371 0.014572544023394585 1.2882498502731323 0.6775893568992615 0.004249352961778641
 83%|████████▎ | 417/500 [05:27<01:05,  1.27it/s]cuda:7
prd tsd -11587.77734375 0.774344801902771
prd sgn 0.0 2.0
target sgn 0.0 1.0
0.05997646600008011 0.013551989570260048 1.2908709049224854 0.6775543093681335 0.0040023974142968655
 84%|████████▎ | 418/500 [05:28<01:04,  1.27it/s]cuda:7
prd tsd -11540.3740234375 0.7565293908119202
prd sgn 0.0 2.0
target sgn 0.0 1.0
0.060159459710121155 0.01314038597047329 1.292189121246338 0.6774540543556213 0.003768208669498563
 84%|████████▍ | 419/500 [05:29<01:03,  1.27it/s]cuda:7
prd tsd -11457.265625 0.730535089969635
prd sgn 0.0 2.036381959915161
target sgn 0.0 1.0
0.05989946424961

PeriodicShapeSampler(
  (encoder): PointNet(
    (conv1): Conv1d(3, 32, kernel_size=(1,), stride=(1,))
    (conv2): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
    (conv3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
    (lin): Linear(in_features=64, out_features=64, bias=True)
    (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn3): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.01, inplace=True)
  )
  (decoder): PrimitiveWiseGroupConvDecoder(
    (decoder): Sequential(
      (0): PrimitiveWiseLinear(
        (main): Sequential(
          (0): Conv1d(414, 192, kernel_size=(1,), stride=(1,), groups=6)
          (1): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (1): 

In [54]:
points_list = [prd_points, target_points.unsqueeze(1)]
fig = plt.figure()
for idx, points in enumerate(points_list):
    plot.plot_primitive_point_cloud_3d(points)
"""
surface_points1 = sampler.extract_super_shapes_surface_point(prd_points[0, ...].unsqueeze(0), primitive(), points=target_points[0, ...].unsqueeze(0))
surafce_points_list = [surface_points1]
fig = plt.figure()
for idx, surface_points in enumerate(surafce_points_list):
    plot.plot_primitive_point_cloud_3d(surface_points)

tsd_list = [prd_sgn]
fig = plt.figure()
for idx, tsd in enumerate(tsd_list):
    plot.draw_primitive_inside_3d(tsd, target_coord, th=0.1)
"""


'\nsurface_points1 = sampler.extract_super_shapes_surface_point(prd_points[0, ...].unsqueeze(0), primitive(), points=target_points[0, ...].unsqueeze(0))\nsurafce_points_list = [surface_points1]\nfig = plt.figure()\nfor idx, surface_points in enumerate(surafce_points_list):\n    plot.plot_primitive_point_cloud_3d(surface_points)\n\ntsd_list = [prd_sgn]\nfig = plt.figure()\nfor idx, tsd in enumerate(tsd_list):\n    plot.draw_primitive_inside_3d(tsd, target_coord, th=0.1)\n'

<Figure size 432x288 with 0 Axes>

In [55]:
#torch.save(primitive.state_dict(), 'primitive_self_overlap_reg.pth')
#torch.save(sampler.state_dict(), 'sampler_self_overlap_reg.pth')