In [1]:
%load_ext autoreload

import sys
sys.path.insert(0, '..')

import torch
from torch import nn
from external.QuaterNet.common import quaternion
import numpy as np
import math
import matplotlib.pyplot as plt

from models import super_shape
from models import super_shape_sampler
from models import periodic_shape_sampler_xyz
from layers import super_shape_functions
import utils
from visualize import plot
import plotly.graph_objects as go

In [6]:
# Sanity check for PeriodicShapeSamplerXYZ: sample surface points of one
# preset super-shape primitive, evaluate the sampler's `tsd` value at those
# points, and plot/print every sample whose value is not within 1e-5 of zero.
# NOTE(review): the project helpers used here (utils, periodic_shape_sampler_xyz,
# plot) are defined outside this notebook; the tensor-shape remarks below come
# from the original author's comments and should be confirmed against them.

# --- primitive parameters and sampling sizes ---
batch = 1
m = 3
n = 1  # forwarded as `nn=` below and as the third sampler constructor argument
n1 = 1
n2 = 10
n3 = 3
a = 1
b = 1
P = 50  # not referenced anywhere else in this cell
dim = 3
sample_num = 30
points_num = 7

# One entry per primitive: zero rotation, zero translation, unit scale.
rotations = [[0., 0., 0.]]
transitions = [[0., 0., 0.]]
linear_scales = [[1., 1., 1.]]

# Build the parameter object that every sampler call below receives.
preset_params = utils.generate_multiple_primitive_params(
    m,
    n1,
    n2,
    n3,
    a,
    b,
    rotations_angle=rotations,
    transitions=transitions,
    linear_scales=linear_scales,
    nn=n)

# Angles at which the primitive surface is sampled.
thetas = utils.sample_spherical_angles(
    sample_num=sample_num, batch=batch, dim=dim)
sampler = periodic_shape_sampler_xyz.PeriodicShapeSamplerXYZ(points_num, m, n, dim=dim, last_scale=0.5)
# Disabled experiment: replace the periodic net output with zeros.
#sampler.get_periodic_net_r = (lambda *args, **kwargs: torch.zeros([batch, n, sample_num**2, 2]))
sampler.eval()
# All-ones tensor of shape [batch, points_num, dim], passed as `points=` to
# each sampler call below.
points = torch.ones([batch, points_num, dim]).float()
# Radius for each sampled angle (author's shape note: B, N, P).
radius = sampler.transform_circumference_angle_to_super_shape_radius(
    thetas, preset_params, points=points)
# World-space coordinates, reshaped to [batch, -1, dim] (author's note: B, P, dim).
coord = sampler.transform_circumference_angle_to_super_shape_world_cartesian_coord(
    thetas, radius, preset_params, points=points).view(batch, -1, dim)

# Evaluate the sampler's `tsd` at the sampled coordinates themselves; the
# filter below treats |value| < 1e-5 as the expected result.
sgn = sampler.transform_world_cartesian_coord_to_tsd(
    coord, preset_params, points=points)

points_numpy = plot.tensor2numpy(coord)
points_numpy = plot.check_and_reduce_batch(points_numpy, 3, [3, 4])

plots = []
marker_opt = dict(size=1)
# NOTE(review): this rebinds the notebook-global `n` (used above as `nn=n` and
# as a sampler argument); the new value is not used again in this cell.
n = points_numpy.shape[0]
# Entry 0 along the first axis; last-axis columns 0/1/2 are used as x/y/z.
x = points_numpy[0, :, 0]
y = points_numpy[0, :, 1]
z = points_numpy[0, :, 2]
# First trace: every sampled point.
plots.append(
        go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=marker_opt))
sgn = sgn.view(-1)
# True where the value lies outside the open interval (-1e-5, 1e-5). NaN
# entries are flagged too, because both comparisons are False for NaN.
error_idx =  ~((-1e-5 < sgn) & (sgn < 1e-5))
print(sgn[error_idx])
print(error_idx)
# Second trace: only the flagged points, drawn over the first trace.
plots.append(
        go.Scatter3d(x=x[error_idx], y=y[error_idx], z=z[error_idx], mode='markers', marker=marker_opt))
fig = go.Figure(data=plots)
fig.update_layout(scene_aspectmode='data')
fig.show()

# Sampling angles of the flagged points, converted from radians to degrees.
error_thetas0 = thetas[..., 0].view(-1)[error_idx]/math.pi * 180.
error_thetas1 = thetas[..., 1].view(-1)[error_idx]/math.pi * 180.

for (theta, phi) in zip(error_thetas0, error_thetas1):
    print(theta, phi)


        1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.1596,
          1.1596, 1.1596, 1.1596, 1.1596, 1.0034, 1.0034, 1.0034, 1.0034,
          1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034,
          1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034,
          1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034,
          1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034,
          1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034,
          1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034,
          1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034, 1.0034,
          1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.1596,
          1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.1596,
          1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.1596,
          1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.1596, 1.4103, 1.4103,
          1.4103, 1.4103, 1.4103, 1.4103

In [3]:
def sphere_polar2cartesian(radius, angles):
    """Convert polar/spherical coordinates to Cartesian coordinates.

    Args:
        radius: tensor with the same number of dimensions as ``angles`` and a
            trailing axis of size 1.
        angles: tensor whose last axis holds ``(theta,)`` or ``(theta, phi)``.
            ``theta`` is the angle inside the xy-plane; ``phi`` is the angle
            measured from the xy-plane towards z.

    Returns:
        Tensor with the coordinates stacked on a new last axis:
        ``(x, y)`` when only ``theta`` is supplied, ``(x, y, z)`` when both
        angles are supplied, where

            x = r * cos(theta) * cos(phi)
            y = r * sin(theta) * cos(phi)
            z = sin(phi) * r
    """
    assert radius.dim() == angles.dim()
    assert radius.size(-1) == 1

    angle_count = angles.size(-1)
    r = radius.squeeze(-1)
    in_plane = angles[..., 0]

    if angle_count == 1:
        # Planar case: a zero out-of-plane angle makes the cos(phi) factor 1.
        out_of_plane = torch.zeros([1], device=in_plane.device)
    else:
        out_of_plane = angles[..., 1]

    cos_out = out_of_plane.cos()
    components = [
        r * in_plane.cos() * cos_out,
        r * in_plane.sin() * cos_out,
    ]
    # z is emitted only when exactly two angles were given.
    if angle_count == 2:
        components.append(out_of_plane.sin() * r)

    return torch.stack(components, dim=-1)



In [4]:
# Round-trip experiment on a single sample direction (theta = phi = 45 deg):
#   1. compute the primitive radius and its Cartesian point,
#   2. push the point outward by adding offsets to the two radius components,
#   3. map the moved point back to polar form with `sampler.cartesian2polar`,
#   4. recompute `sgn` by hand and compare the point against two
#      polar -> Cartesian conversions.
# NOTE(review): SuperShapeSampler, utils and super_shape_functions are defined
# outside this notebook; what they compute internally is not visible here.

# --- primitive parameters and sampling sizes ---
batch = 1
m = 3
n = 1
n1 = 1
n2 = 10
n3 = 2
a = 1
b = 1
P = 50  # not referenced anywhere else in this cell
dim = 3
sample_num = 9  # not referenced anywhere else in this cell
points_num = 7  # not referenced anywhere else in this cell

# One entry per primitive: zero rotation, zero translation, unit scale.
rotations = [[0., 0., 0.]]
transitions = [[0., 0., 0.]]
linear_scales = [[1., 1., 1.]]

preset_params = utils.generate_multiple_primitive_params(
    m,
    n1,
    n2,
    n3,
    a,
    b,
    rotations_angle=rotations,
    transitions=transitions,
    linear_scales=linear_scales,
    nn=n)

# A single angle pair (pi/4, pi/4), shaped [1, 1, 2].
thetas = torch.tensor([math.pi/4, math.pi/4]).view(1, 1, 2)
sampler = super_shape_sampler.SuperShapeSampler(m, n, dim=dim)
sampler.eval()
# Radius for the sampled angle pair (author's shape note: B, N, P).
radius = sampler.transform_circumference_angle_to_super_shape_radius(
    thetas, preset_params)

print('input angle', thetas/math.pi*180)
# NOTE(review): 'primitve' is a typo for 'primitive' in this label.
print('primitve radius', radius)

coord = sampler.transform_circumference_angle_to_super_shape_world_cartesian_coord(
    thetas, radius, preset_params).view(batch, -1, dim)
print('coord', coord)


# Offsets added in place to last-axis components 0 and 1 of `radius`.
r1ext = 0.
r2ext = 2.
radius[..., 0] += r1ext
radius[..., 1] += r2ext

print('extended radius', radius)

# Same conversion as above, now with the extended radius; `coord` is rebound
# (author's shape note: B, P, dim).
coord = sampler.transform_circumference_angle_to_super_shape_world_cartesian_coord(
    thetas, radius, preset_params).view(batch, -1, dim)
print('coord extended', coord)

# Map the moved point back to polar form; the helper returns four values.
r1, r2, theta, phi = sampler.cartesian2polar(coord.view(1, 1, 1, 3), preset_params)
print('theta', theta/math.pi*180, 'phi', phi/math.pi*180)
print('r1', r1, 'r2', r2)

# Apply the same offsets, in place, to the radii returned by cartesian2polar.
r1 += r1ext
r2 += r2ext


# Euclidean norm of the moved point along the last axis.
op = ((coord**2).sum(-1)).sqrt()
# NOTE(review): the label says 'new theta', but the value printed is `op`,
# the norm computed on the previous line, not an angle.
print('new theta', op)
# sqrt(r1^2 * r2^2 * cos(phi)^2 + r2^2 * sin(phi)^2)
oi = (((r1**2) * (r2**2) * (phi.cos()**2) + (r2**2) *
                    (phi.sin()**2))).sqrt()
# Hand-computed value: 0 when `op` equals `oi`.
sgn = 1 - op / oi

# Disabled: the sampler's own computation of the same quantity.
#sgn = sampler.transform_world_cartesian_coord_to_tsd(coord, preset_params)

print('sgn', sgn)


# Same expression as `op` above.
primitive_r = ((coord**2).sum(-1)).sqrt()
print(primitive_r)
print(primitive_r.shape, thetas.shape)
#angles = thetas
# Same expression as `oi` above.
denominator1 = (((r1**2) * (r2**2) * (phi.cos()**2) + (r2**2) *
                    (phi.sin()**2))).sqrt()
print('denom1', denominator1)

# Rebuild Cartesian coordinates from the recovered angles and offset radii
# using the project helper.
angles = torch.stack([theta, phi], axis=-1)
print('angles', angles/math.pi * 180)
radius2= torch.stack([r1, r2], axis=-1)
coord3 = super_shape_functions.polar2cartesian(radius2, angles)
print('coord3', coord3)

# Convert with the notebook-level spherical helper, using the point's norm as
# a single radius.
print(coord.shape, angles.shape)
coord2 = sphere_polar2cartesian(primitive_r.view(1, 1, 1, 1), angles)
print(coord2)

input angle tensor([[[45., 45.]]])
primitve radius tensor([[[[1.3986, 1.3986]]]])
coord tensor([[[0.9780, 0.9780, 0.9890]]])
extended radius tensor([[[[1.3986, 3.3986]]]])
coord extended tensor([[[2.3766, 2.3766, 2.4032]]])
theta tensor([[[45.]]]) phi tensor([[[45.]]])
r1 tensor([[[1.3986]]]) r2 tensor([[[1.3986]]])
new theta tensor([[4.1318]])
sgn tensor([[[-1.1921e-07]]])
tensor([[4.1318]])
torch.Size([1, 1]) torch.Size([1, 1, 2])
denom1 tensor([[[4.1318]]])
angles tensor([[[[45., 45.]]]])
coord3 tensor([[[[2.3766, 2.3766, 2.4032]]]])
torch.Size([1, 1, 3]) torch.Size([1, 1, 1, 2])
tensor([[[[2.0659, 2.0659, 2.9217]]]])
