In [None]:
import numpy as np
from lib.data.metainfo import MetaInfo
from lib.data.transforms import SiameseTransform
import hydra
from lib.utils import load_config
import matplotlib.pyplot as plt
import torch
from torch.nn.functional import cosine_similarity
from tqdm import tqdm
from lib.visualize.image import image_grid
from lightning import Trainer

def transform(normal, device="cuda"):
    """Apply the Siamese preprocessing to one image and move the result to ``device``.

    Args:
        normal: an image accepted by ``SiameseTransform`` (this notebook passes numpy
            arrays and the raw return value of ``metainfo.load_image``).
        device: target torch device; defaults to ``"cuda"`` as before.

    Returns:
        The transformed tensor on ``device``; callers add the batch dim via ``[None, ...]``.
    """
    return SiameseTransform()(normal).to(device)

def plot_images(images, size: int = 4):
    """Show a single image, or a list of images side by side in one row.

    Args:
        images: one image (anything ``plt.imshow`` accepts) or a list of such images.
        size: width and height of the figure in inches.
    """
    if isinstance(images, list):
        if not images:
            return  # nothing to draw; plt.subplots(1, 0) would raise
        # squeeze=False keeps `axes` 2-D even for a single image; otherwise
        # plt.subplots(1, 1) returns a bare Axes and the zip below fails.
        _, axes = plt.subplots(1, len(images), figsize=(size, size), squeeze=False)
        for ax, image in zip(axes[0], images):
            ax.imshow(image)
            ax.axis("off")  # Turn off axis
        plt.show()
    else:
        plt.figure(figsize=(size, size))
        plt.imshow(images)
        plt.show()

def siamese_loss(emb_1, emb_2):
    """Cosine distance between Siamese embeddings.

    Args:
        emb_1, emb_2: embedding tensors, compared along ``cosine_similarity``'s
            default ``dim=1`` (inputs are broadcast to a common shape).

    Returns:
        ``1 - cos(emb_1, emb_2)``: 0 for parallel, 1 for orthogonal, 2 for opposite.
    """
    similarity = cosine_similarity(emb_1, emb_2)
    return 1 - similarity


# Build the latent-traversal model from the Hydra experiment config, load the dataset
# metadata, and show the reference images used throughout the notebook.
# NOTE(review): the model is never switched to eval mode in this notebook; if its
# networks contain dropout/batch-norm layers, embeddings may not be deterministic.
cfg = load_config("traverse_latent", ["+experiment/traverse_latent=siamese_train_train_1"])
model = hydra.utils.instantiate(cfg.model).to("cuda")
# NOTE(review): absolute, user-specific data path — consider making it configurable.
metainfo = MetaInfo("/home/borth/sketch2shape/data/shapenet_chair_4096")

# metainfo.load_image(a, b, c): judging by the labels printed below and the 90-view
# loop further down, the arguments appear to be (shape, view, image type) with
# type 0 = sketch and type 1 = normal map — assumption, confirm against MetaInfo.
gt_image_1 = np.asarray(metainfo.load_image(3, 11, 0))
gt_image_2 = np.asarray(metainfo.load_image(3, 11, 1))
wrong_image_1 = np.asarray(metainfo.load_image(5, 11, 1))

# render the DeepSDF normal map at the model's end latent code
model.latent = model.latent_end
normal = model.capture_camera_frame().detach().cpu().numpy()

print("sketch", "gt_normal", "deepsdf_normal", "wrong_image_1")
plot_images([gt_image_1, gt_image_2, normal, wrong_image_1])

In [None]:
# Per-dimension deviation of latent code 5 from two reference distributions:
#   x1 - signed z-score w.r.t. the k-NN codes `lats` (its sign is used by the next cell)
#   x2 - squared z-score w.r.t. all training latent codes
# NOTE(review): `lats` is created by the k-NN cell further down, so this cell fails on a
# fresh kernel. x1 is signed and x2 squared, so both curves are not on one scale — intended?
latent = model.deepsdf.lat_vecs.weight[5].clone()
std1 = torch.stack(lats).std(0)
mean1 = torch.stack(lats).mean(0)
x1 = ((latent - mean1) / std1)

std2 = model.deepsdf.lat_vecs.weight.std(0)
mean2 = model.deepsdf.lat_vecs.weight.mean(0)
x2 = ((latent - mean2) / std2).pow(2)
# sorted(x)[::-1][:40]
# sorted(x)
plt.plot(x1.detach().cpu().numpy())
plt.plot(x2.detach().cpu().numpy())

In [None]:
# Compare, per latent dimension, the sign of the deviation from the k-NN mean (`x1`,
# the signed z-score from the cell above) with the sign of the sketch-loss gradient
# `grad`: +1 where they agree, -1 where they disagree.
# NOTE(review): `grad` comes from the gradient cell below and `lats` from the k-NN
# cell; this cell only works if those ran first.
latent = model.deepsdf.lat_vecs.weight[5].clone()
mean1 = torch.stack(lats).mean(0)
grad_direction = (grad / torch.abs(grad).max()).detach().cpu().numpy()
# latent_direction = ((mean1-latent)/torch.abs((mean1-latent)).max()).detach().cpu().numpy()
pp = np.sign(x1.detach().cpu().numpy()) * np.sign(grad_direction)
plt.plot(pp)
# plt.plot(np.sign(grad_direction))
# plt.plot(latent_direction)
# plt.plot(grad_direction - latent_direction)

# Rank dimensions by |z-score|: x1 is signed here, so ranking its raw values would
# only surface large positive deviations and miss equally large negative ones.
high_diff_latents = torch.argsort(x1.abs(), descending=True).cpu().numpy()
for i in range(10):
    print(pp[high_diff_latents[i]])

In [None]:
# Show the mean of the k-NN latent codes (`lats`), then walk from `latent` toward that
# mean by snapping its most-deviant dimensions to the mean one at a time.
# NOTE(review): depends on `lats` (k-NN cell) and `latent` (an earlier cell) existing.
# model.latent = model.deepsdf.lat_vecs.weight[5].clone()
model.latent = torch.stack(lats).mean(0).clone()
normal = model.capture_camera_frame().detach().cpu().numpy()
plot_images(normal)

std1 = torch.stack(lats).std(0)
mean1 = torch.stack(lats).mean(0)
# squared z-score of every dimension of `latent` w.r.t. the k-NN codes
x1 = ((latent - mean1) / std1).pow(2)

# Start the walk from `latent`, not from the mean: writing mean values into a
# code that already equals the mean is a no-op and yields identical frames.
model.latent = latent.clone()
images = []
high_diff_latents = reversed(torch.argsort(x1))  # most-deviant dimensions first
# high_diff_latents = torch.argsort(x1)
for idx, latent_idx in enumerate(high_diff_latents):
    model.latent[latent_idx] = mean1[latent_idx]
    normal = model.capture_camera_frame().detach().cpu().numpy()
    images.append(normal)
    if idx > 16:  # stop after 18 frames
        break
plot_images(images, size=16)

In [None]:
# Gradient of the Siamese loss between the sketch of shape 3 (view 11) and the
# DeepSDF normal render of latent code 5, taken w.r.t. that latent code. The
# resulting `grad` is also read by the sign-comparison cell above.
sketch = metainfo.load_image(3, 11, 0)
sketch_emb = model.siamese(transform(sketch)[None, ...])

# Detach + clone to get a fresh leaf tensor: `requires_grad` can only be set on
# leaves, and if lat_vecs is trainable, a plain clone of one of its rows is not a
# leaf (setting the flag on it raises RuntimeError).
model.latent = model.deepsdf.lat_vecs.weight[5].detach().clone().requires_grad_(True)
points, surface_mask = model.deepsdf.sphere_tracing(
    latent=model.latent,
    points=model.deepsdf.camera_points,
    rays=model.deepsdf.camera_rays,
    mask=model.deepsdf.camera_mask,
)
rendered_normal = model.deepsdf.render_normals(
    latent=model.latent,
    points=points,
    mask=surface_mask,
)  # (H, W, 3)
normal = model.deepsdf.normal_to_siamese(rendered_normal)  # (1, 3, H, W)
normal_emb = model.siamese(normal)  # (1, D)

snn_loss = siamese_loss(sketch_emb, normal_emb)

(grad,) = torch.autograd.grad(
    outputs=snn_loss,
    inputs=model.latent,
    grad_outputs=torch.ones_like(snn_loss)
)
plt.plot(grad.detach().cpu().numpy())

In [None]:
# Squared per-dimension z-score of latent code 3 w.r.t. all training latent codes;
# the cell displays its mean over dimensions.
# NOTE(review): `latent` (code 5) is assigned but not used below, where x is computed
# for code 3 — confirm which code was intended.
latent = model.deepsdf.lat_vecs.weight[5]
# std = torch.stack(lats).std(0)
# mean = torch.stack(lats).mean(0)
std = model.deepsdf.lat_vecs.weight.std(0)
mean = model.deepsdf.lat_vecs.weight.mean(0)
x = ((model.deepsdf.lat_vecs.weight[3] - mean) / std).pow(2)
# # sorted(x)[::-1][:40]
x.mean()
# # sorted(x)
# plt.plot(x.detach().cpu().numpy())

# SNN Cosine Distance: Sketch vs. GT Normal

In [None]:
# Baseline: cosine distance between the Siamese embeddings of the sketch (gt_image_1)
# and the GT normal map (gt_image_2) of the same shape and view.
model.latent = model.latent_end
normal = model.capture_camera_frame().detach().cpu().numpy()  # not used in this cell
emb_1 = model.siamese(transform(gt_image_1)[None, ...])
emb_2 = model.siamese(transform(gt_image_2)[None, ...])
print(siamese_loss(emb_1, emb_2))
plot_images([gt_image_1, gt_image_2])

# SNN Cosine Distance: Sketch vs. DeepSDF Normal

In [None]:
# Cosine distance between the sketch embedding and the embedding of the DeepSDF render
# at latent_end. The raw render tensor is fed through model.deepsdf.normal_to_siamese
# (the same path as in the gradient cell) instead of through `transform`.
model.latent = model.latent_end
normal = model.capture_camera_frame()
normal_input = model.deepsdf.normal_to_siamese(normal)

emb_1 = model.siamese(transform(gt_image_1)[None, ...])
emb_2 = model.siamese(normal_input)
print(siamese_loss(emb_1, emb_2))
plot_images([gt_image_1, normal.detach().cpu().numpy()])

# SNN Cosine Distance: GT Normal vs. DeepSDF Normal

In [None]:
# Cosine distance between the GT normal map (gt_image_2) and the DeepSDF render at
# latent_end, both passed through `transform`.
model.latent = model.latent_end
normal = model.capture_camera_frame().detach().cpu().numpy()
emb_1 = model.siamese(transform(gt_image_2)[None, ...])
emb_2 = model.siamese(transform(normal)[None, ...])
print(siamese_loss(emb_1, emb_2))
plot_images([gt_image_2, normal])

# SNN Cosine Distance: Wrong Shape vs. DeepSDF Normal

In [None]:
# Negative control: cosine distance between the normal map of a different shape
# (wrong_image_1, shape 5) and the DeepSDF render at latent_end.
model.latent = model.latent_end
normal = model.capture_camera_frame().detach().cpu().numpy()
emb_1 = model.siamese(transform(wrong_image_1)[None, ...])
emb_2 = model.siamese(transform(normal)[None, ...])
print(siamese_loss(emb_1, emb_2))
plot_images([wrong_image_1, normal])

# SNN Cosine Distance: GT Normals from All 90 Views vs. DeepSDF Normal

In [None]:
from tqdm import tqdm

# Cosine distance between the current DeepSDF render (`normal`, from the cell above)
# and the GT normal maps of shape 3 from every view; prints mean/min/max.
num_views = 90
loss = []
# Evaluation only: without no_grad every Siamese forward keeps its autograd graph
# alive through `loss`.
with torch.no_grad():
    emb_normal = model.siamese(transform(normal)[None, ...])
    for i in tqdm(range(num_views)):
        gt_image = np.asarray(metainfo.load_image(3, i, 1))
        emb_image = model.siamese(transform(gt_image)[None, ...])
        _loss = siamese_loss(emb_normal, emb_image)
        loss.append(_loss)
loss = torch.concatenate(loss)
print(f"{loss.mean()=}")
print(f"{loss.min()=}")
print(f"{loss.max()=}")
plot_images([gt_image, normal])  # shows the last view only

# 10-NN Latent Codes around Shape 0 (Prior)

In [None]:
# Find the 10 latent codes closest to shape 0 (mean absolute difference; the first
# hit is normally shape 0 itself at distance 0), collect them in `lats`, and for each
# print the mean SNN loss between the render of shape 0 and that neighbour's
# sketches over all 90 views.
model.latent_end = model.deepsdf.lat_vecs.weight[0]
model.latent = model.deepsdf.lat_vecs.weight[0]
normal = model.capture_camera_frame().detach().cpu().numpy()
emb_normal = model.siamese(transform(normal)[None, ...])

# The query code does not change inside the loop, so rank all codes once.
latent = model.latent_end
dist = torch.abs(model.deepsdf.lat_vecs.weight - latent[None, ...]).mean(-1)
idx = torch.argsort(dist, descending=False)

lats = []
for _i in range(10):
    model.latent = model.deepsdf.lat_vecs.weight[idx[_i]]
    lats.append(model.latent)
    obj_id = metainfo.label_to_obj_id(idx[_i].item())

    loss = []
    with torch.no_grad():  # evaluation only; don't retain 90 autograd graphs
        for i in range(90):
            gt_image = np.asarray(metainfo.load_image(idx[_i].item(), i, 0))
            emb_image = model.siamese(transform(gt_image)[None, ...])
            _loss = siamese_loss(emb_normal, emb_image)
            loss.append(_loss)
    loss = torch.concatenate(loss)

    # obj_id/distance used to be a bare expression inside the loop, which displays nothing
    print(obj_id, idx[_i], dist[idx[_i]], loss.mean())
    _n = model.capture_camera_frame().detach().cpu().numpy()
    plot_images(_n)

In [None]:
# inspect the shape of the current latent code
model.latent.shape

In [None]:
# NOTE(review): `ids` is first defined in the next cell; on a fresh kernel this raises NameError.
ids

In [None]:
# Decode the mean of the 128 latent codes closest to the k-NN query (ranked by `dist`
# from the k-NN cell), skipping idx[0], which is normally the query itself.
# NOTE(review): the Retrieval Generation cell later rebinds `dist` to a 2-D (Q, I)
# tensor; re-running this cell after that would not index shapes as intended.
idx = torch.argsort(dist, descending=False)
ids = list(idx[1:1+128].cpu().numpy())
model.latent = model.deepsdf.lat_vecs.weight[ids].mean(0)
normal = model.capture_camera_frame().detach().cpu().numpy()
plot_images(normal)

In [None]:
# Decode the mean of the k-NN latent codes (`lats`) and show it next to the sketch
# of shape 3 (gt_image_1).
model.latent = torch.stack(lats).mean(0)
normal = model.capture_camera_frame().detach().cpu().numpy()
plot_images([normal, gt_image_1])

In [None]:
# Walk linearly from a start latent code to the end code and record, at every step,
# the Siamese (cosine) loss between the end shape's sketch and the render.

# settings
# start_latent_ids = [3, 3385, 2801, 1962, 1058, 782, 1328]  # couch nn
# start_latent_ids = [3, 0, 1, 2, 4, 5, 6] # couch far
start_latent_ids = [0] # couch far
# start_latent_ids = list(range(128))
# When True, every trajectory starts from the mean of the k-NN codes in `ids` (built
# in an earlier cell) instead of lat_vecs[start_latent_id]. This was previously an
# unconditional override inside the loop, which made all trajectories identical;
# set False to actually compare the start ids above.
start_from_nn_mean = True
end_latent_id = 0
sketch_view = 11
traversal_steps = 20
image_skip = 4  # only used by the commented-out image plot below
azims = [40]
elevs = [-30]

# fetch and encode the sketch
sketch = metainfo.load_image(end_latent_id, sketch_view, 0)
sketch_emb = model.siamese(transform(sketch)[None, ...])

# the end point is the same for every trajectory
end_latent = model.deepsdf.lat_vecs.weight[end_latent_id]

image_trajectories = []
loss_trajectories = []
for idx, start_latent_id in enumerate(start_latent_ids):
    # the start point depends only on the trajectory, not on the step t
    if start_from_nn_mean:
        start_latent = model.deepsdf.lat_vecs.weight[ids].mean(0)
    else:
        start_latent = model.deepsdf.lat_vecs.weight[start_latent_id]

    image_trajectory = [sketch]
    loss_trajectory = []
    desc = f"{idx+1}/{len(start_latent_ids)}"
    # t runs 1 -> 0, i.e. from start_latent to end_latent
    for t in tqdm(np.linspace(1, 0, traversal_steps), desc=desc):
        model.latent = t * start_latent + (1 - t) * end_latent

        # calculate the mean loss from all the views
        loss = []
        for azim in azims:
            for elev in elevs:
                model.deepsdf.create_camera(azim=azim, elev=elev)
                rendered_normal = model.capture_camera_frame().detach().cpu().numpy()
                rendered_normal_emb = model.siamese(transform(rendered_normal)[None, ...])
                snn_loss = siamese_loss(sketch_emb, rendered_normal_emb)
                loss.append(snn_loss)
                image_trajectory.append(rendered_normal)
        loss = torch.stack(loss).mean()

        # add the loss to the trajectory
        loss_trajectory.append(loss.detach().cpu().numpy())
    loss_trajectories.append(loss_trajectory)
    image_trajectories.append(image_trajectory)

# # plot the images
# for image_trajectory in image_trajectories:
#     trajectory = [] 
#     for idx, img in enumerate(image_trajectory):
#         if idx == 0 or (idx-1) % image_skip == 0:
#             trajectory.append(img)
#     plot_images(trajectory, size=16)

# # plot the loss curves
# for obj_id, loss_trajectory in zip(start_latent_ids, loss_trajectories):
#     plt.plot(np.linspace(0, 1, traversal_steps), np.stack(loss_trajectory), label=obj_id)
#     plt.legend(loc="upper right")
# plt.ylabel("siamese_loss")
# plt.xlabel("traversal_steps")
# plt.show()

# mean +/- std of the loss across trajectories at each step; x needs one entry per
# traversal step (it was hard-coded to 20, breaking for other traversal_steps)
x = np.linspace(0, 1, traversal_steps)
mean = np.stack(loss_trajectories).mean(0)
std = np.stack(loss_trajectories).std(0)
plt.plot(x, mean)
plt.fill_between(x, (mean-std), (mean+std), color='b', alpha=0.1)
plt.ylabel("siamese_loss")
plt.xlabel("traversal_steps")

In [None]:

# Render one latent code from a grid of camera poses (elevs x azims) and show the
# Siamese (cosine) loss of every pose against a fixed sketch as a heat map.

# settings
# normal_obj_id = 1962
normal_obj_id = 3
# normal_obj_id = 1328
sketch_obj_id = 3
sketch_view = 11
# azims = cfg.data.preprocess_siamese.azims 
# elevs = cfg.data.preprocess_siamese.elevs
azims = [0, 40, 80, 120, 160, 200, 240, 280, 320]
elevs = [-50, -30, -10, 10]

# fetch and encode the sketch
sketch = metainfo.load_image(sketch_obj_id, sketch_view, 0)
sketch_emb = model.siamese(transform(sketch)[None, ...])

image_trajectory = []
model.latent =  model.deepsdf.lat_vecs.weight[normal_obj_id]
# calculate the loss for every camera pose; elevs is the outer loop, so images and
# losses are collected in (elev, azim) row-major order
loss = []
for elev in tqdm(elevs):
    for azim in azims:
        model.deepsdf.create_camera(azim=azim, elev=elev)
        rendered_normal = model.capture_camera_frame().detach().cpu().numpy()
        rendered_normal_emb = model.siamese(transform(rendered_normal)[None, ...])
        snn_loss = siamese_loss(sketch_emb, rendered_normal_emb)
        loss.append(snn_loss)
        image_trajectory.append(rendered_normal)
loss = torch.stack(loss)

# visualize the grid (cols = azimuths, rows = elevations)
image_grid(image_trajectory, cols=len(azims), rows=len(elevs))
plt.show()

# visualize the loss (rows = elevations, cols = azimuths)
loss = loss.reshape(len(elevs), len(azims))
plt.imshow(loss.detach().cpu().numpy(), vmin=0, vmax=1.0)
plt.colorbar()
plt.show()

In [None]:
# visualize the loss of the pose-grid cell above with the color scale narrowed to
# [0, 0.2]; reuses `loss`, elevs and azims from that cell
loss = loss.reshape(len(elevs), len(azims))
plt.imshow(loss.detach().cpu().numpy(), vmin=0, vmax=0.2)
plt.colorbar()
plt.show()

# Sample Latent Codes around the Mean of the k-NN Latent Codes

In [None]:
# Sample latent codes around the mean of the k-NN codes (`lats`) with per-dimension
# noise scaled by their std and a factor s shrinking from 2 to 0; print the L2
# embedding distance of each render to the sketch gt_image_1.
# NOTE(review): `lats` are neighbours of shape 0 (k-NN cell) while gt_image_1 is the
# sketch of shape 3 — confirm this comparison is intended.
torch.manual_seed(0)  # reproducible samples
lats_mean = torch.stack(lats).mean(0)
lats_std = torch.stack(lats).std(0)
emb_1 = model.siamese(transform(gt_image_1)[None, ...])  # same for every sample
for s in np.linspace(2, 0, 10):
    # zero-mean Gaussian noise: rand_like is uniform on [0, 1) and would push every
    # sample away from the mean in the same (positive) direction
    noise = torch.randn_like(model.latent_end) * lats_std * s
    model.latent = lats_mean + noise
    normal = model.capture_camera_frame().detach().cpu().numpy()
    emb_2 = model.siamese(transform(normal)[None, ...])
    print(torch.norm(emb_1 - emb_2, dim=-1))
    plot_images(normal)

# SNN Retrieval: Top-10 Latent Codes for a Query Sketch

In [None]:
from tqdm import tqdm

# Encode the sketch (image type 0, view 11) of every training shape, retrieve the k
# shapes whose sketch embeddings are closest (cosine distance) to the query sketch of
# shape 3, and render their latent codes (collected in `snn_lats`).
num_shapes = 4096
# Retrieval only: without no_grad all 4096 forward graphs would be kept alive.
with torch.no_grad():
    embs = []
    for i in tqdm(range(num_shapes)):
        img = metainfo.load_image(i, 11, 0)  # sketch of shape i from a good view
        emb = model.siamese(transform(img)[None, ...])
        embs.append(emb)
    embs = torch.stack(embs).squeeze(1)  # (num_shapes, D)

    img = metainfo.load_image(3, 11, 0)  # query: sketch of shape 3 from the same view
    query_emb = model.siamese(transform(img)[None, ...])

# get the top index in SNN space
k = 10
# one broadcast call, (num_shapes, D) vs (1, D) -> (num_shapes,), instead of a Python loop
loss = siamese_loss(embs, query_emb)
idx = torch.argsort(loss)[:k]  # (k,)
top_loss, top_idx = loss.take(idx).detach().cpu().numpy(), idx.detach().cpu().numpy()

snn_lats = []
# `top_l` (not `loss`) so the full loss vector is not shadowed by the loop variable
for top_l, i in zip(top_loss, top_idx):
    model.latent = model.deepsdf.lat_vecs.weight[i]
    _n = model.capture_camera_frame().detach().cpu().numpy()
    snn_lats.append(model.latent)
    print(top_l, i)
    plot_images(_n, size=2)

In [None]:
# shape of the retrieved top-k losses from the retrieval cell above
top_loss.shape

In [None]:
# Sample latent codes around the mean of the retrieved codes (`snn_lats`) with
# per-dimension noise scaled by their std and a factor s shrinking from 2 to 0;
# print the L2 embedding distance of each render to the query sketch gt_image_1.
torch.manual_seed(0)  # reproducible samples
snn_lats_mean = torch.stack(snn_lats).mean(0)
snn_lats_std = torch.stack(snn_lats).std(0)
emb_1 = model.siamese(transform(gt_image_1)[None, ...])  # same for every sample
for s in np.linspace(2, 0, 10):
    # zero-mean Gaussian noise: rand_like is uniform on [0, 1) and would push every
    # sample away from the mean in the same (positive) direction
    noise = torch.randn_like(model.latent_end) * snn_lats_std * s
    model.latent = snn_lats_mean + noise
    normal = model.capture_camera_frame().detach().cpu().numpy()
    emb_2 = model.siamese(transform(normal)[None, ...])
    print(torch.norm(emb_1 - emb_2, dim=-1))
    plot_images(normal)

# Retrieval Generation

In [None]:
# Held-out check: L2 distance between the Siamese embeddings of a shape's sketch and
# its GT normal map from view 11.
# NOTE(review): val_id lies beyond the 4096 shapes encoded in the retrieval cell;
# presumably a validation shape — confirm.
# val_id = 4192 # show 
val_id = 4193
val_sketch = metainfo.load_image(val_id, 11, 0) # sketch (image type 0) from a good view
val_normal = metainfo.load_image(val_id, 11, 1) # GT normal map (image type 1) from the same view
emb_1 = model.siamese(transform(val_sketch)[None, ...])
emb_2 = model.siamese(transform(val_normal)[None, ...])
print(torch.norm(emb_1 - emb_2, dim=-1))
plot_images([val_sketch, val_normal])

In [None]:
# Retrieve the k training shapes whose sketch embeddings (`embs`, from the retrieval
# cell) are closest in Euclidean distance to the validation sketch, render each
# retrieved code, and store the render-to-sketch embedding distance as its weight.
query_emb = model.siamese(transform(val_sketch)[None, ...])

# get the top index in SNN space
k = 10
# ||q - e||^2 = ||q||^2 - 2 q.e + ||e||^2
sx = torch.sum(query_emb**2, dim=-1, keepdim=True)  # (Q, 1)
sy = torch.sum(embs**2, dim=-1, keepdim=True)  # (I, 1)
# round-off can make the squared distance slightly negative; clamp before sqrt
dist = torch.sqrt((-2 * (query_emb @ embs.T) + sx + sy.T).clamp_min(0))  # (Q, I)
idx = torch.argsort(dist)[:, :k]  # (Q, k)
# gather indexes row-wise; dist.take() uses flat indices and is only correct for Q == 1
top_dist = torch.gather(dist, 1, idx).detach().cpu().numpy()
top_idx = idx.detach().cpu().numpy()

val_lats = []
val_weights = []
for i in top_idx.flatten():
    model.latent = model.deepsdf.lat_vecs.weight[i]
    _n = model.capture_camera_frame().detach().cpu().numpy()
    val_lats.append(model.latent)
    # query_emb already is the embedding of val_sketch; no need to re-encode it
    emb_2 = model.siamese(transform(_n)[None, ...])
    weight = torch.norm(query_emb - emb_2, dim=-1)
    val_weights.append(weight)
    print(weight)
    plot_images(_n)

In [None]:
# Sample latent codes around the mean of the codes retrieved for the validation
# sketch (`val_lats`) with per-dimension noise scaled by their std and a factor s
# shrinking from 2 to 0; print the L2 embedding distance of each render to val_sketch.
torch.manual_seed(0)  # reproducible samples
val_lats_mean = torch.stack(val_lats).mean(0)
val_lats_std = torch.stack(val_lats).std(0)
# Compare against val_sketch, the query these codes were retrieved for (not
# gt_image_1, the sketch of training shape 3).
emb_1 = model.siamese(transform(val_sketch)[None, ...])
for s in np.linspace(2, 0, 10):
    # zero-mean Gaussian noise: rand_like is uniform on [0, 1) and would push every
    # sample away from the mean in the same (positive) direction
    noise = torch.randn_like(model.latent_end) * val_lats_std * s
    model.latent = val_lats_mean + noise
    normal = model.capture_camera_frame().detach().cpu().numpy()
    emb_2 = model.siamese(transform(normal)[None, ...])
    print(torch.norm(emb_1 - emb_2, dim=-1))
    plot_images(normal)

# Weighted Latent Code

In [None]:
# Blend the retrieved codes with weights proportional to exp(-i * d), where d are the
# render-to-sketch distances from the retrieval cell. i = 0 gives the plain mean;
# larger i concentrates the weight on the best matches.
distances = torch.concatenate(val_weights)  # (k,)
stacked_lats = torch.stack(val_lats)  # (k, D)
emb_1 = model.siamese(transform(val_normal)[None, ...])  # GT embedding, loop-invariant
for i in range(0, 20, 2):
    # softmax == exp(-i*d) / sum(exp(-i*d)), but cannot underflow to 0/0 (NaN) for large i*d
    w = torch.softmax(-i * distances, dim=0)
    print(w)
    model.latent = (stacked_lats * w[..., None]).sum(0)
    normal = model.capture_camera_frame().detach().cpu().numpy()
    emb_2 = model.siamese(transform(normal)[None, ...])
    print(torch.norm(emb_1 - emb_2, dim=-1))
    plot_images([val_sketch, normal, val_normal])

In [None]:
# Coarse-to-fine check: run the transformed GT normal map (gt_image_2) through the
# scheduler's `downsample` with the "avg" reducer, then through model.normal_to_image
# at resolution 64.
# NOTE(review): semantics of Coarse2FineScheduler (resolution, milestones,
# current_epoch) and of normal_to_image are not visible here — confirm.
from lib.data.scheduler import Coarse2FineScheduler
scheduler = Coarse2FineScheduler(resolution=256, milestones=[1, 2])
scheduler.current_epoch = 0
image = scheduler.downsample(transform(gt_image_2), reducer="avg")
image = model.normal_to_image(image, resolution=64)

In [None]:
# show the image produced by the coarse-to-fine cell above
plt.imshow(image.detach().cpu().numpy())