In [None]:
import copy
import torch
import numpy as np
import torch_geometric.transforms as T
import matplotlib.pyplot as plt
import importlib
import os
import time
from scipy.spatial.transform import Rotation
from scipy.spatial.distance import cdist
import poisson_modelnet_40
importlib.reload(poisson_modelnet_40)

from poisson_modelnet_40 import (
    get_transform,
    get_rotation_transform,
    get_single_modelnet40_sample
)

import pose_estimation
importlib.reload(pose_estimation)
from pose_estimation import ICP, PointCloudMetropolisHastings, nearest_neighbor_src_dst#, metrics_per_step

import tbp.monty.frameworks.environment_utils.transforms
importlib.reload(tbp.monty.frameworks.environment_utils.transforms)
import tbp.monty.frameworks.environments.modelnet
importlib.reload(tbp.monty.frameworks.environments.modelnet)

from tbp.monty.frameworks.environment_utils.transforms import RandomRotate
from tbp.monty.frameworks.utils.metrics import TransformedPointCloudDistance, AngleDisparity, InverseMatrixDeviation
from tbp.monty.frameworks.environments.modelnet import ModelNet40OnlineOptimizationExactCopy



# Step 1: check that we can implement ICP (or something similar)

- Load a single ModelNet40 object.
- Sample k points from the mesh.
- Load the same object and apply a known rotation.
- Call `pose_estimator(src, tgt)`.
- Decode the output of `pose_estimator` so we can compare it to the known transform.
- Examine the fit of the learned transform.
- Measure how much time passed.
- Extend by replacing `error_fn` with Poisson surface reconstruction.

In [None]:
N_SAMPLES = 1024

# dst clouds: get_transform(N_SAMPLES) — presumably samples N_SAMPLES points per mesh.
dst_transform = get_transform(N_SAMPLES)

# src clouds: a rotation about the y axis. With fix_rotation=True the same rotation
# is presumably reused for every sample (TODO confirm); its matrix is kept so the
# estimated pose can be compared against ground truth later.
rot_transform = RandomRotate(axes=["y"], fix_rotation=True)
rotation_matrix = rot_transform.rotation_matrix
src_transform = rot_transform

dataset_root = os.path.expanduser("~/tbp/datasets/ModelNet40/raw")
dataset = ModelNet40OnlineOptimizationExactCopy(
    root=dataset_root,
    transform=None,  # raw torch geometric object
    train=True,
    num_samples_train=2,
    dst_transform=dst_transform,
    src_transform=src_transform,
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

In [None]:
# ICP estimator used on the dataloader sample below.
ICP_N_STEPS = 100
icp = ICP(n_steps=ICP_N_STEPS)

In [None]:
# Take only the first (shuffled) batch and run ICP on it; icp_pc is ICP's output cloud.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    src, dst, label = first_batch
    icp_pc = icp(src, dst)


## Check that the inverse transform maps src back onto dst exactly

In [None]:
r_t = Rotation.from_matrix(rotation_matrix)
r_t_inv = r_t.inv()
# Euler angles (degrees, "xyz") of the applied rotation and of its inverse.
for rotation_obj in (r_t, r_t_inv):
    print(rotation_obj.as_euler("xyz", degrees=True))
# Disparity between a rotation and its exact inverse (sanity check).
print(AngleDisparity.disparity(r_t_inv, r_t))


In [None]:
# Same check through the metric's call interface. `inverse_rotation` is the exact
# inverse of `params`, so the result is presumably ~0.
AngleDisparity()(inverse_rotation=r_t.inv(), params=r_t)

In [None]:
# Undo the known rotation on src; src_inv is expected to land on dst.
snp, dnp = (pc.squeeze(0).numpy() for pc in (src, dst))
tsfm_inv = r_t.inv()
src_inv = tsfm_inv.apply(snp)

In [None]:
# Shape of the inverse-rotated source cloud.
src_inv.shape

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.set_title("Original and transformed point clouds")

# Coordinates are drawn in (y, x, z) order, as in the other 3-D scatter cells.
# Green points should sit on the blue ones if the inverse rotation is right.
ax.scatter(*dnp[:, [1, 0, 2]].T, c="b", alpha=0.1, label="Original")
# ax.scatter(*snp[:, [1, 0, 2]].T, c="r", alpha=0.3, label="Transformed")
ax.scatter(*src_inv[:, [1, 0, 2]].T, c="g", alpha=0.5, s=2., label="Inverse transform")
ax.legend()
plt.show()

# Looks good. Now manually examine the ICP results

In [None]:
# Step at which ICP reached its lowest recorded error, and that error.
icp_error_trace = icp.error_history
icp_min_idx = np.argmin(icp_error_trace)
icp_min_error = icp_error_trace[icp_min_idx]
print(icp_min_idx)

In [None]:
# ICP error trace, presumably one value per iteration.
icp_error_trace = icp.error_history
fig, ax = plt.subplots()
ax.plot(icp_error_trace)
ax.set_xlabel("Time step")
ax.set_ylabel("Error (sum over nearest points)")

In [None]:
icpt = icp.best_params
# Compare ground truth transform to estimate. ICP aligns src (the rotated cloud)
# to dst, so the estimate is compared with the inverse of the applied rotation.
estimate_msg = f"Parameter estimate: \n{icpt.as_matrix()}"
print(estimate_msg)
print(f"true parameters: \n{torch.inverse(rotation_matrix)}")

In [None]:
# Euler angles ("xyz", degrees) of the ICP estimate and of the ground truth.
# The matrix printout above compares the estimate with the *inverse* of the applied
# rotation, so r_t.inv() is the row icpt should match. r_t itself is kept for reference.
print(f"ICP estimate:     {icpt.as_euler('xyz', degrees=True)}")
print(f"applied rotation: {r_t.as_euler('xyz', degrees=True)}")
print(f"inverse (target): {r_t.inv().as_euler('xyz', degrees=True)}")

In [None]:
# Angular disparity between the ground-truth rotation and the ICP estimate.
# NOTE(review): the earlier sanity check calls disparity(r_t.inv(), r_t), while here
# the arguments are (r_t, icpt). Confirm the argument order AngleDisparity.disparity expects.
print(AngleDisparity.disparity(r_t, icpt))

In [None]:
snp, dnp = src.squeeze(0).numpy(), dst.squeeze(0).numpy()

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.set_title("Original and transformed point clouds")

# Coordinates are drawn in (y, x, z) order, as in the other 3-D scatter cells.
ax.scatter(*dnp[:, [1, 0, 2]].T, c="b", alpha=0.3, label="dst")
ax.scatter(*snp[:, [1, 0, 2]].T, c="r", s=2., alpha=0.3, label="src")
ax.legend()
plt.show()

In [None]:
icp_pc_np = icp_pc.squeeze(0).numpy()

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
ax.set_title("Original and icp output pointclouds")

# dst in blue, ICP output in red; coordinates drawn in (y, x, z) order.
for cloud, color, name in ((dnp, "b", "Original"), (icp_pc_np, "r", "Estimated")):
    ax.scatter(*cloud[:, [1, 0, 2]].T, c=color, alpha=0.3, label=name)
ax.legend()
plt.show()

In [None]:
# Side-by-side view of dst (original), src (rotated) and the ICP output.
# The three panels sit in a single row (subplot grid 1x3), so the figure must be
# wide rather than tall; figsize is (width, height) in inches.
fig = plt.figure(figsize=(18, 6))
icp_pc_np = icp_pc.squeeze(0).numpy()

panels = [
    (dnp, "b", "Original"),
    (snp, "r", "Transformed"),
    (icp_pc_np, "g", "Estimated"),
]
for col, (points, color, title) in enumerate(panels, start=1):
    panel_ax = fig.add_subplot(1, 3, col, projection="3d")
    # Coordinates drawn in (y, x, z) order, as in the other 3-D scatter cells.
    panel_ax.scatter(points[:, 1], points[:, 0], points[:, 2], c=color, s=2., alpha=0.3)
    panel_ax.set_title(title)
plt.show()


In [None]:
# Score every recorded ICP parameter snapshot against the true rotation.
# NOTE(review): `metrics_per_step` is commented out of the `pose_estimation` import at the
# top of the notebook, so on a fresh kernel this cell raises NameError.
# TODO: restore the import (or import it from wherever it now lives).
transforms = [icp.get_params(i) for i in range(len(icp.param_history))]
errors, pct_errors, angle_disparities, identity_deviations = metrics_per_step(snp, dnp, transforms, Rotation.from_matrix(rotation_matrix))

In [None]:
# Stack the per-step angle disparities into an array and drop the singleton axis 1.
ad = np.squeeze(np.array(angle_disparities), axis=1)
ad.shape

In [None]:
fig, ax = plt.subplots()
angles = ["x", "y", "z"]
# One curve per column of `ad`, labelled with the matching axis name.
for i, axis_name in enumerate(angles):
    ax.plot(ad[:, i], label=axis_name)

ax.legend()
plt.show()

In [None]:
# Per-step identity deviation returned by metrics_per_step. Presumably
# ||I - R_est @ R_true|| at each ICP step, like the single-step check below; TODO confirm.
fig, ax = plt.subplots()
ax.plot(identity_deviations)
ax.set(title="Identity deviation per ICP step", xlabel="Time step", ylabel="Identity deviation")
plt.show()

In [None]:
# Per-step `pct_errors` returned by metrics_per_step (presumably a percentage error; TODO confirm).
fig, ax = plt.subplots()
ax.plot(pct_errors)
ax.set(title="pct_errors per ICP step", xlabel="Time step", ylabel="pct_errors")
plt.show()

In [None]:
# Per-step error returned by metrics_per_step: the mean pointwise error against
# ground truth, as labelled in the baseline comparison plot below.
fig, ax = plt.subplots()
ax.plot(errors)
ax.set(title="Mean pointwise error per ICP step", xlabel="Time step", ylabel="Mean pointwise error (ground truth)")
plt.show()

In [None]:
m1 = rotation_matrix
m2 = icpt.as_matrix()
identity_est = np.dot(m2, m1)
identity_deviations = np.linalg.norm(np.eye(3) - identity_est)

In [None]:
identity_deviations

In [None]:
np.eye(3) - identity_est

In [None]:
# Baseline for the ICP error: rotate src by 3 random rotations and measure the
# pointwise distance to dst. dss = [baseline 0, baseline 1, baseline 2, ICP].
# Rotation.random samples uniformly over SO(3); uniformly drawn Euler angles do not.
# The seeded generator makes the baseline reproducible across re-runs.
baseline_rng = np.random.default_rng(0)
r1, r2, r3 = (Rotation.random(None, baseline_rng) for _ in range(3))
snp = src.squeeze(0).numpy()
dnp = dst.squeeze(0).numpy()
pcr1, pcr2, pcr3 = (r.apply(snp) for r in (r1, r2, r3))
e1, e2, e3 = (np.linalg.norm(pcr - dnp, axis=1) for pcr in (pcr1, pcr2, pcr3))
# Per-point distance between the ICP output and dst.
d = torch.norm(icp_pc - dst, dim=2).numpy()
dss = [e1, e2, e3, d.squeeze(0)]


In [None]:
fig, ax = plt.subplots()
ax.plot(errors, label="ICP")
ax.set_title("Mean pointwise error at every icp time step")
ax.set_xlabel("Time step")
ax.set_ylabel("Mean pointwise error (ground truth)")
# Dashed reference lines: mean error of src under each random baseline rotation
# (all entries of dss except the last, which is the ICP residual).
for i, e in enumerate(dss[:-1]):
    ax.axhline(e.mean(), color="r", linestyle="dashed", label=f"src + random transform {i}")

ax.legend()

In [None]:
# NOTE(review): repeats the earlier pct_errors plot; consider removing one of them.
fig, ax = plt.subplots()
ax.plot(pct_errors)
ax.set(title="pct_errors per ICP step", xlabel="Time step", ylabel="pct_errors")
plt.show()

In [None]:
# Plot the first component of each entry in rot_diffs.
# NOTE(review): `rot_diffs` is not defined anywhere in this notebook, so this cell raises
# NameError on a fresh kernel. TODO: compute rot_diffs or delete this cell.
plt.plot([rot_diffs[i][0] for i in range(len(rot_diffs))])

In [None]:
# NOTE(review): `rot_diffs` is never defined in this notebook (NameError on a fresh kernel).
rot_diffs

In [None]:
fig, ax = plt.subplots(4, 1, sharex=True)
# One histogram per entry of dss (three random baselines, then ICP), each with a
# dashed red line at its mean.
for i, panel in enumerate(ax):
    dists = dss[i]
    panel.hist(dists, 25)
    panel.axvline(dists.mean(), color="r", linestyle="dashed")


In [None]:
# NOTE(review): disabled check kept for reference; either re-enable it or delete the cell.
# TransformedPointCloudDistance()(icp_pc.T, dst.T)

In [None]:
# Pairwise Euclidean distances between all points of dst (gives a sense of scale).
dst_dist = torch.cdist(dst, dst, p=2.0)

In [None]:
# Shape of the pairwise distance tensor.
dst_dist.size()

In [None]:
# Largest pairwise distance within dst, i.e. the diameter of the point cloud.
dst_dist.max()

In [None]:
# TODO: nothing works because I changed the code so it expects output of dataloader
# NOTE(review): this cell rebinds rot_transform and rotation_matrix from the dataloader section.

n_samples = 1024
obj = get_single_modelnet40_sample(idx=6)
# Separate deep copies for target and input, in case the transforms modify their input.
obj_target, obj_input = (copy.deepcopy(obj) for _ in range(2))

raw_transform = get_transform(n_samples)
target_point_cloud = raw_transform(obj_target)
rot_transform, rotation_matrix = get_rotation_transform(n_samples, axes=["x"])
input_point_cloud = rot_transform(obj_input)

In [None]:
# Fresh ICP estimator for the single-sample experiment below.
ICP_N_STEPS_SINGLE = 20
icp = ICP(n_steps=ICP_N_STEPS_SINGLE)

In [None]:
# Time the ICP run. time.perf_counter is the clock intended for measuring elapsed
# intervals: it has higher resolution than time.time and is not affected by
# system clock adjustments.
icp_t0 = time.perf_counter()
icp_pointcloud = icp(
    src=input_point_cloud.numpy(),
    dst=target_point_cloud.numpy()
)
icp_t1 = time.perf_counter()
icp_time = icp_t1 - icp_t0

In [None]:
# Step with the smallest recorded ICP distance, and that distance.
icp_distance_trace = icp.distances
icp_min_idx = np.argmin(icp_distance_trace)
icp_min_error = icp_distance_trace[icp_min_idx]

In [None]:
# ICP distance trace for the single-sample run.
fig, ax = plt.subplots()
ax.plot(icp.distances)
ax.set_xlabel("Time step")
# Label was missing its closing parenthesis; now matches the earlier error plot.
ax.set_ylabel("Error (sum over nearest points)")

In [None]:
# Final transform estimated by ICP. Later cells index it as T[:3, :3], so it is presumably
# a 4x4 homogeneous matrix; TODO confirm.
# NOTE(review): this rebinds `T`, which the import cell binds to `torch_geometric.transforms`.
T = icp.extract_final_transform()

## Compare the ground-truth transform to the estimate

In [None]:
# Estimated homogeneous transform next to the ground-truth rotation.
estimate_msg = f"Parameter estimate: \n{T}"
truth_msg = f"true parameters: \n{rotation_matrix}"
print(estimate_msg)
print(truth_msg)

In [None]:
# Euler angles (degrees) of the rotation part of T and of the ground truth.
# `R` is only bound further down the notebook, so on a fresh kernel it is undefined here.
# Use the `Rotation` class imported at the top instead.
print(Rotation.from_matrix(T[:3, :3]).as_euler("xyz", degrees=True))
print(Rotation.from_matrix(rotation_matrix).as_euler("xyz", degrees=True))

In [None]:
# Inspect what the ICP call returned (its type is not visible from this notebook).
type(icp_pointcloud)

In [None]:
# TransformedPointCloudDistance()(icp_pointcloud, target_point_cloud.T)

# Elementwise residual between the ICP output and the transposed target cloud.
# NOTE(review): requires broadcast-compatible shapes, and may mix a numpy array with
# a torch tensor; confirm both.
icp_pointcloud - target_point_cloud.T

In [None]:
# Container types of the target cloud and of ICP's internal src.
for pc_obj in (target_point_cloud, icp.src):
    print(type(pc_obj))

## Spot-check that we can extract the full transform correctly

In [None]:
# Re-compose the per-step ICP transforms by hand and compare the result with
# extract_final_transform(). Each new step is left-multiplied onto the running product.
T = np.eye(4)
for step_tsfm in icp.transforms:
    T = step_tsfm @ T

print(f"Estimated T: {T}")
print(f"output T: {icp.extract_final_transform()}")

In [None]:
# Same left-multiplied composition as the previous cell, without the printout.
T = np.eye(4)
for tsfm in icp.transforms:
    T = np.dot(tsfm, T)

In [None]:
n_points = input_point_cloud.shape[0]
m = input_point_cloud.shape[1]
src_final = icp.src

# Homogeneous coordinates: one column per point, with a final row of ones.
new_src = np.ones((m + 1, n_points))
new_src[:m, :] = input_point_cloud.T
est_final = T @ new_src

In [None]:
# True when applying the hand-composed T reproduces ICP's final src element-wise
# (within floating-point tolerance). Replaces a manual count-vs-size comparison.
np.allclose(est_final, src_final)

## Metropolis-Hastings search

In [None]:
# Metropolis-Hastings search; its param_history entries are rotations (see as_matrix below).
mcmc_kwargs = dict(
    n_steps=1_500,
    kappa=8,
    temp=0.1,
    # threshold=0.5,
)
mcmc = PointCloudMetropolisHastings(**mcmc_kwargs)

In [None]:
# Time the MCMC run. time.perf_counter is the clock intended for measuring elapsed
# intervals: it has higher resolution than time.time and is not affected by
# system clock adjustments.
mcmc_t0 = time.perf_counter()
mcmc_pointcloud = mcmc(
    src=input_point_cloud.numpy(),
    dst=target_point_cloud.numpy(),
)
mcmc_t1 = time.perf_counter()
mcmc_time = mcmc_t1 - mcmc_t0

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10,5))
err_ax, ratio_ax = ax

# Left: error at each recorded step.
err_ax.plot(mcmc.step_history, mcmc.error_history)
err_ax.set_ylabel("Total cdist error")

# Right: full ratio history (presumably the MH acceptance ratio), with the recorded
# steps marked by "x". Assumes ratio_history supports array indexing.
step_idx = np.array(mcmc.step_history)
ratio_ax.plot(mcmc.ratio_history)
ratio_ax.plot(mcmc.step_history, mcmc.ratio_history[step_idx], "x")

err_ax.set_title(f"Argmin error: {np.min(mcmc.error_history)}")

In [None]:
# Best MCMC sample (lowest recorded error) compared with the ground truth.
min_idx = np.argmin(mcmc.error_history)
rot = mcmc.param_history[min_idx].as_matrix()
for msg in (
    f"Parameter estimate: \n{rot}",
    f"true parameters: \n{rotation_matrix}",
    f"min error: {mcmc.error_history[min_idx]}",
):
    print(msg)

# Compare MCMC and ICP on a single example

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
mcmc_ax, icp_ax = ax

# MCMC error trace with its best error and wall time.
mcmc_ax.plot(mcmc.step_history, mcmc.error_history)
mcmc_ax.set_ylabel("Error")
mcmc_ax.set_title(f"Argmin error: {np.min(mcmc.error_history)}\nTotal time: {mcmc_time}")

# ICP distance trace with its best error and wall time.
# NOTE(review): the two error measures may not be on the same scale; confirm before comparing.
icp_ax.plot(icp.distances)
icp_ax.set_ylabel("Error")
icp_ax.set_title(f"Argmin error: {icp_min_error}\nTotal time: {icp_time}")

### Check the implementations of scipy `cdist`, torch `cdist`, etc.

In [None]:
import numpy as np
import scipy.spatial as ss
from scipy.spatial.transform import Rotation as R


In [None]:
# Two sets of 100 random 3-D points, stored column-wise as (3, 100).
x1, x2 = (np.random.normal(0, 1, (3, 100)) for _ in range(2))

In [None]:
# scipy's cdist expects one point per row, hence the transposes.
o = ss.distance.cdist(x1.T, x2.T)
# Two different code paths need not agree bit-for-bit, so compare with a tolerance.
assert np.isclose(o[0, 0], np.linalg.norm(x1[:, 0] - x2[:, 0]))

In [None]:
x1 = torch.randn(100, 3)
x2 = torch.randn(100, 3)
o = torch.cdist(x1, x2, p=2)
# With more than 25 rows, torch.cdist's default compute_mode computes the Euclidean
# distance via matrix multiplication. That rounds differently from the direct norm,
# so an exact `==` can fail spuriously; compare with a tolerance instead.
assert torch.isclose(o[0, 0], torch.norm(x1[0, :] - x2[0, :]), atol=1e-4)

In [None]:
rot = R.from_euler("x", 45, degrees=True)
o = rot.apply(obj_target.pos)
# dtype, container type and element count of the rotated positions.
for attr in (o.dtype, type(o), o.size):
    print(attr)

In [None]:
x = torch.randn(1, 10, 3)
xs = x[0]        # drop the leading size-1 axis -> (10, 3)
xss = xs[None]   # and add it back -> (1, 10, 3)

In [None]:
# Show both shapes. A bare `xs.size()` that is not the cell's last expression is never displayed.
print(xs.size())
print(xss.size())

In [None]:
# Converting a float64 numpy array keeps the dtype (torch.float64).
torch.tensor(np.random.normal(0, 1, 10))

In [None]:
x = torch.randn(10)
# The same vector 2-norm, computed by torch and by numpy.
n1 = x.norm(p=2)
n2 = np.linalg.norm(x.numpy())

In [None]:
# torch result
n1

In [None]:
# numpy result
n2

In [None]:
x = torch.randn(100, 3)
# Row-wise 2-norms, once in torch and once in numpy.
n1 = x.norm(p=2, dim=1)
n2 = np.linalg.norm(x.numpy(), axis=1)

In [None]:
# torch.norm and np.linalg.norm may round differently, so an elementwise `==` between
# a torch tensor and a numpy array is not a reliable check; compare with a tolerance.
np.allclose(n1.numpy(), n2)

In [None]:
# torch row-wise norms
n1

In [None]:
# numpy row-wise norms
n2

In [None]:
# 20 column-stored 3-D points and a single query point of shape (3, 1).
x, new_point = torch.randn(3, 20), torch.randn(3, 1)

In [None]:
# Offset of every column of x from new_point (broadcast over the 20 columns), as numpy.
sub = torch.sub(x, new_point).numpy()

In [None]:
# Distance from each of the 20 points in x to new_point (one norm per column).
# TODO: this line originally ended in a dangling `>` (an unfinished threshold
# comparison), which is a SyntaxError; add the intended threshold back here.
np.linalg.norm(sub, axis=0)

In [None]:
# The scipy Rotation class, imported above under the alias R.
R

In [None]:
# 45-degree rotation about the x axis.
x = R.from_euler(seq="x", angles=45, degrees=True)

In [None]:
# x was built with R.from_euler in the previous cell, so this shows True.
isinstance(x, R)

In [None]:
# A random Gaussian 3x3 matrix is generally not a rotation matrix. According to the
# scipy docs, from_matrix approximates a non-orthogonal input instead of rejecting it;
# presumably this cell checks that behaviour.
x = np.random.normal(0, 1, (3, 3))
xr = R.from_matrix(x)

In [None]:
# Euler angles in radians (as_euler uses degrees=False by default).
xr.as_euler("xyz")

In [None]:
# With a positive divisor, Python's % returns a value in [0, 360).
355 % 360

In [None]:
# Angles past a full turn wrap back around to small values.
361 % 360

In [None]:
# Complement of 355 degrees modulo a full turn.
(360 - 355) % 360