In [None]:
import numpy as np
from lib.data.metainfo import MetaInfo
from lib.data.transforms import SketchTransform
import hydra
from lib.utils import load_config
import matplotlib.pyplot as plt
from torchvision.transforms import v2
import torch
from torch.nn.functional import l1_loss


def stats(normal):
    """Print the mean of each of the three colour channels (R, G, B).

    ``normal`` is flattened to ``(-1, 3)`` before averaging, so its total size
    must be a multiple of 3 (e.g. an ``(H, W, 3)`` image array).
    """
    channel_means = normal.reshape(-1, 3).mean(0)
    for label, value in zip("RGB", channel_means):
        print(f"{label} mean: {value}")


# check why latent is not the right, maybe similar problem
def transform(normal):
    _transform = v2.Compose(
        [
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5], std=[0.5]),
        ]
    )
    return _transform(normal).to("cuda")


def plot_images(images):
    """Show one image, or several images side by side, and render the figure.

    Args:
        images: a single array-like image, or a list/tuple of them. A sequence
            is drawn as one row of axis-less subplots in a 4x4-inch figure; a
            single image is drawn in a 2x2-inch figure. An empty sequence draws
            nothing.
    """
    if isinstance(images, (list, tuple)):
        if len(images) == 0:
            # plt.subplots(1, 0) raises; there is nothing to show anyway.
            return
        # squeeze=False keeps `axes` a 2-D array even for a single image, so a
        # one-element list does not fail on zip(<bare Axes>, images).
        fig, axes = plt.subplots(1, len(images), figsize=(4, 4), squeeze=False)
        for ax, image in zip(axes[0], images):
            ax.imshow(image)
            ax.axis("off")  # Turn off axis
        plt.show()
    else:
        plt.figure(figsize=(2, 2))
        plt.imshow(images)
        plt.show()


import os

# Experiment config and model (instantiated via hydra, moved to the GPU).
cfg = load_config("traverse_latent", ["+experiment/traverse_latent=debug_snn"])
model = hydra.utils.instantiate(cfg.model).to("cuda")

# Dataset root; set $SKETCH2SHAPE_DATA_DIR to point elsewhere instead of
# editing the notebook. The default is the original absolute path.
DATA_DIR = os.environ.get(
    "SKETCH2SHAPE_DATA_DIR", "/home/borth/sketch2shape/data/shapenet_chair_4096"
)
metainfo = MetaInfo(DATA_DIR)

# load_image(shape index, view index, image type) -- presumably; per the labels
# printed below, type 0 is the sketch and type 1 the ground-truth normal map.
gt_image_1 = np.asarray(metainfo.load_image(3, 11, 0))
gt_image_2 = np.asarray(metainfo.load_image(3, 11, 1))
wrong_image_1 = np.asarray(metainfo.load_image(2, 11, 1))  # normal of another shape

# Render the DeepSDF normal map for the model's `latent_end` code.
model.latent = model.latent_end
normal = model.capture_video_frame().detach().cpu().numpy()

print("sketch", "gt_normal", "deepsdf_normal", "wrong_image_1")
plot_images([gt_image_1, gt_image_2, normal, wrong_image_1])

# SNN L2-Dist Sketch - Normal

In [None]:
# SNN embedding distance between the GT sketch and the GT normal map of shape 3.
# Only the two ground-truth images are compared, so no DeepSDF frame is rendered.
emb_1 = model.siamese(transform(gt_image_1)[None, ...])
emb_2 = model.siamese(transform(gt_image_2)[None, ...])
print(torch.linalg.vector_norm(emb_1 - emb_2, dim=-1))
plot_images([gt_image_1, gt_image_2])

# SNN L2-Dist Sketch - DeepSDF Normal

In [None]:
# SNN distance: GT sketch of shape 3 vs. the DeepSDF normal rendered from `latent_end`.
model.latent = model.latent_end
normal = model.capture_video_frame().detach().cpu().numpy()
emb_1 = model.siamese(transform(gt_image_1).unsqueeze(0))
emb_2 = model.siamese(transform(normal).unsqueeze(0))
print(torch.linalg.vector_norm(emb_1 - emb_2, ord=2, dim=-1))
plot_images([gt_image_1, normal])

# SNN L2-Dist Normal - DeepSDF Normal

In [None]:
# SNN distance: GT normal map of shape 3 vs. the DeepSDF normal from `latent_end`.
model.latent = model.latent_end
normal = model.capture_video_frame().detach().cpu().numpy()
emb_1 = model.siamese(transform(gt_image_2).unsqueeze(0))
emb_2 = model.siamese(transform(normal).unsqueeze(0))
print(torch.linalg.vector_norm(emb_1 - emb_2, ord=2, dim=-1))
plot_images([gt_image_2, normal])

# SNN L2-Dist Wrong - DeepSDF Normal

In [None]:
# SNN distance: normal map of a different shape (2) vs. the DeepSDF normal from
# `latent_end` -- a negative pair for comparison with the cells above.
model.latent = model.latent_end
normal = model.capture_video_frame().detach().cpu().numpy()
emb_1 = model.siamese(transform(wrong_image_1).unsqueeze(0))
emb_2 = model.siamese(transform(normal).unsqueeze(0))
print(torch.linalg.vector_norm(emb_1 - emb_2, ord=2, dim=-1))
plot_images([wrong_image_1, normal])

# SNN L2-Dist GT Normal All Views (90) - DeepSDF Normal

In [None]:
from tqdm import tqdm

# Compare the DeepSDF normal rendered from `latent_end` against the GT normal
# map (type 1) of shape 3 from every view; report min / mean / max SNN distance.
# `emb_normal` is reused by the k-NN cell below.
N_VIEWS = 90

# Render here so the cell does not depend on `normal` left over from another cell.
model.latent = model.latent_end
normal = model.capture_video_frame().detach().cpu().numpy()

# Inference only: without no_grad each per-view distance could keep its autograd
# graph (if the SNN has trainable parameters) alive for the whole loop.
with torch.no_grad():
    emb_normal = model.siamese(transform(normal)[None, ...])
    l2_dist = []
    for i in tqdm(range(N_VIEWS)):
        gt_image = np.asarray(metainfo.load_image(3, i, 1))
        emb_image = model.siamese(transform(gt_image)[None, ...])
        l2_dist.append(torch.norm(emb_normal - emb_image, dim=-1))
    l2_dist = torch.concatenate(l2_dist)

print(f"{l2_dist.mean()=}")
print(f"{l2_dist.min()=}")
print(f"{l2_dist.max()=}")
plot_images([gt_image, normal])  # last view only

# 10-NN Latent Codes of Shape 3 (Prior)

In [None]:
# Nearest neighbours of `latent_end` among the training latent codes, ranked by
# per-dimension mean L1 distance. For each neighbour: render it, keep its latent
# in `lats`, and print the mean SNN distance between `emb_normal` (DeepSDF normal
# of `latent_end`, from the all-views cell) and that shape's type-0 images.
# NOTE(review): 15 neighbours may have been intended; this cell uses 10.
k_neighbors = 10
n_views = 90

latent = model.latent_end
# `reduce=False` is deprecated; reduction="none" keeps the element-wise loss.
dist = l1_loss(model.model.lat_vecs.weight, latent, reduction="none").mean(-1)
idx = torch.argsort(dist, descending=False)  # loop-invariant: sort once

lats = []
for _i in range(k_neighbors):
    model.latent = model.model.lat_vecs.weight[idx[_i]]
    _n = model.capture_video_frame().detach().cpu().numpy()
    lats.append(model.latent)
    plot_images(_n)
    # Printed explicitly: a bare expression inside a loop is never displayed.
    print(metainfo.label_to_obj_id(idx[_i].item()), idx[_i], dist[idx[_i]])
    with torch.no_grad():
        l2_dist = []
        for i in range(n_views):
            gt_image = np.asarray(metainfo.load_image(idx[_i].item(), i, 0))
            emb_image = model.siamese(transform(gt_image)[None, ...])
            l2_dist.append(torch.norm(emb_normal - emb_image, dim=-1))
        l2_dist = torch.concatenate(l2_dist)
    print(idx[_i], l2_dist.mean())

# Sample Latent Codes from the 10-NN Prior of Shape 3 (noise shrinks to 0 -> Mean of Latent Codes)

In [None]:
# Sample latents around the mean of the neighbour latents `lats`, with zero-mean
# Gaussian noise scaled by their per-dimension std times `s` (s goes 2 -> 0, so
# the last sample is exactly the mean). Print the SNN distance between the GT
# sketch of shape 3 and each rendered DeepSDF normal.
knn_stack = torch.stack(lats)
knn_mean, knn_std = knn_stack.mean(0), knn_stack.std(0)
with torch.no_grad():
    emb_1 = model.siamese(transform(gt_image_1)[None, ...])  # loop-invariant
for s in np.linspace(2, 0, 10):
    # randn, not rand: uniform [0, 1) noise is strictly positive and would shift
    # every sample away from the mean instead of spreading around it.
    noise = torch.randn_like(model.latent_end) * knn_std * float(s)
    model.latent = knn_mean + noise
    normal = model.capture_video_frame().detach().cpu().numpy()
    with torch.no_grad():
        emb_2 = model.siamese(transform(normal)[None, ...])
    print(torch.norm(emb_1 - emb_2, dim=-1))
    plot_images(normal)

# SNN Retrieval Top 10 Latent Code Based on Sketch

In [None]:
from tqdm import tqdm

# Embed view 11 of every training shape with the SNN, retrieve the k shapes whose
# embeddings are closest (Euclidean) to the embedding of shape 3's sketch, and
# render their DeepSDF latent codes. `embs` is reused by the retrieval below.
N_SHAPES = 4096  # presumably the number of training shapes (shapenet_chair_4096)

# Inference only: 4096 forward passes must not keep autograd graphs around.
with torch.no_grad():
    embs = []
    for i in tqdm(range(N_SHAPES)):
        img = metainfo.load_image(i, 11, 0)  # type 0 = sketch, view 11
        embs.append(model.siamese(transform(img)[None, ...]))
    embs = torch.stack(embs).squeeze(1)  # (I, D)

    img = metainfo.load_image(3, 11, 0)  # query: sketch of shape 3, view 11
    query_emb = model.siamese(transform(img)[None, ...])  # (Q, D)

# Pairwise Euclidean distance via |q|^2 - 2 q.e + |e|^2. Round-off can make the
# radicand slightly negative (e.g. the query's own entry), so clamp it to 0
# instead of producing NaN.
k = 10
sx = torch.sum(query_emb**2, dim=-1, keepdim=True)  # (Q, 1)
sy = torch.sum(embs**2, dim=-1, keepdim=True)  # (I, 1)
dist = torch.sqrt((-2 * (query_emb @ embs.T) + sx + sy.T).clamp_min(0))  # (Q, I)
idx = torch.argsort(dist)[:, :k]  # (Q, k)
# gather along dim 1: Tensor.take indexes the flattened tensor, which is only
# correct while there is a single query row.
top_dist = dist.gather(1, idx).detach().cpu().numpy()
top_idx = idx.detach().cpu().numpy()

snn_lats = []
for i in top_idx.flatten():
    model.latent = model.model.lat_vecs.weight[i]
    _n = model.capture_video_frame().detach().cpu().numpy()
    snn_lats.append(model.latent)
    plot_images(_n)

In [None]:
# Sample latents around the mean of the SNN-retrieved latents `snn_lats`, with
# zero-mean Gaussian noise scaled by their per-dimension std times `s`
# (s goes 2 -> 0). Print the SNN distance between the GT sketch of shape 3 and
# each rendered DeepSDF normal.
snn_stack = torch.stack(snn_lats)
snn_mean, snn_std = snn_stack.mean(0), snn_stack.std(0)
with torch.no_grad():
    emb_1 = model.siamese(transform(gt_image_1)[None, ...])  # loop-invariant
for s in np.linspace(2, 0, 10):
    # randn, not rand: uniform [0, 1) noise is strictly positive and biases the
    # samples away from the mean.
    noise = torch.randn_like(model.latent_end) * snn_std * float(s)
    model.latent = snn_mean + noise
    normal = model.capture_video_frame().detach().cpu().numpy()
    with torch.no_grad():
        emb_2 = model.siamese(transform(normal)[None, ...])
    print(torch.norm(emb_1 - emb_2, dim=-1))
    plot_images(normal)

# Retrieval Generation

In [None]:
# Shape used for the retrieval experiment below. Presumably a validation shape:
# its index lies outside the range 0..4095 embedded as the retrieval database.
# Alternative: val_id = 4192
val_id = 4193
val_sketch = metainfo.load_image(val_id, 11, 0)  # type 0 = sketch, view 11
val_normal = metainfo.load_image(val_id, 11, 1)  # type 1 = GT normal map, view 11
# SNN distance between the sketch and the GT normal of the same shape.
emb_1 = model.siamese(transform(val_sketch)[None, ...])
emb_2 = model.siamese(transform(val_normal)[None, ...])
print(torch.norm(emb_1 - emb_2, dim=-1))
plot_images([val_sketch, val_normal])

In [None]:
# Retrieve the k training shapes whose SNN embeddings (`embs`) are closest to the
# validation sketch, render each retrieved latent, and record its SNN distance to
# the sketch in `val_weights` (used by the weighted-latent cell).
with torch.no_grad():
    query_emb = model.siamese(transform(val_sketch)[None, ...])  # (Q, D)

# Pairwise Euclidean distance; clamp round-off negatives to 0 before sqrt.
k = 10
sx = torch.sum(query_emb**2, dim=-1, keepdim=True)  # (Q, 1)
sy = torch.sum(embs**2, dim=-1, keepdim=True)  # (I, 1)
dist = torch.sqrt((-2 * (query_emb @ embs.T) + sx + sy.T).clamp_min(0))  # (Q, I)
idx = torch.argsort(dist)[:, :k]  # (Q, k)
# gather along dim 1: Tensor.take indexes the flattened tensor (only correct for Q == 1).
top_dist = dist.gather(1, idx).detach().cpu().numpy()
top_idx = idx.detach().cpu().numpy()

# The sketch embedding does not change inside the loop; reuse it.
emb_1 = query_emb
val_lats = []
val_weights = []
for i in top_idx.flatten():
    model.latent = model.model.lat_vecs.weight[i]
    _n = model.capture_video_frame().detach().cpu().numpy()
    val_lats.append(model.latent)
    with torch.no_grad():
        emb_2 = model.siamese(transform(_n)[None, ...])
    weight = torch.norm(emb_1 - emb_2, dim=-1)
    val_weights.append(weight)
    print(weight)
    plot_images(_n)

In [None]:
# Sample latents around the mean of the latents retrieved for the validation
# sketch (`val_lats`), with zero-mean Gaussian noise scaled by their std times
# `s` (s goes 2 -> 0). Each DeepSDF normal is scored against the validation
# sketch; the previous target was gt_image_1, the sketch of training shape 3,
# which has nothing to do with this retrieval.
val_stack = torch.stack(val_lats)
val_mean, val_std = val_stack.mean(0), val_stack.std(0)
with torch.no_grad():
    emb_1 = model.siamese(transform(val_sketch)[None, ...])  # loop-invariant
for s in np.linspace(2, 0, 10):
    # randn, not rand: uniform [0, 1) noise is strictly positive and biases the
    # samples away from the mean.
    noise = torch.randn_like(model.latent_end) * val_std * float(s)
    model.latent = val_mean + noise
    normal = model.capture_video_frame().detach().cpu().numpy()
    with torch.no_grad():
        emb_2 = model.siamese(transform(normal)[None, ...])
    print(torch.norm(emb_1 - emb_2, dim=-1))
    plot_images(normal)

# Weighted Latent Code

In [None]:
# Blend the retrieved latents with weights softmax(-beta * d), where d is each
# latent's SNN distance to the validation sketch. beta = 0 gives the plain mean;
# larger beta concentrates weight on the closest retrievals. Each blend is scored
# against the validation GT normal map.
val_dists = torch.concatenate(val_weights)  # one distance per retrieved latent
val_lat_stack = torch.stack(val_lats)  # presumably (k, D)
with torch.no_grad():
    emb_1 = model.siamese(transform(val_normal)[None, ...])  # loop-invariant
for beta in range(0, 20, 2):
    # Same as exp(-beta*d) / sum(exp(-beta*d)), but softmax subtracts the max
    # first, so large beta * d cannot underflow every term to 0 (0/0 = NaN).
    w = torch.softmax(-beta * val_dists, dim=0)
    print(w)
    model.latent = (val_lat_stack * w[..., None]).sum(0)
    normal = model.capture_video_frame().detach().cpu().numpy()
    with torch.no_grad():
        emb_2 = model.siamese(transform(normal)[None, ...])
    print(torch.norm(emb_1 - emb_2, dim=-1))
    plot_images([val_sketch, normal, val_normal])

In [None]:
from lib.data.scheduler import Coarse2FineScheduler

# Downsample the GT normal map of shape 3 with the coarse-to-fine scheduler at
# epoch 0 (average pooling), then convert it with `model.normal_to_image` at
# resolution 64.
scheduler = Coarse2FineScheduler(milestones=[1, 2], resolution=256)
scheduler.current_epoch = 0
image = model.normal_to_image(
    scheduler.downsample(transform(gt_image_2), reducer="avg"),
    resolution=64,
)

In [None]:
# Display the image produced by `model.normal_to_image` in the previous cell.
frame = image.detach().cpu().numpy()
plt.imshow(frame)