# Utils

In [None]:
from lib.data.metainfo import MetaInfo
from lib.data.transforms import BaseTransform
import hydra
from lib.utils.config import load_config
import matplotlib.pyplot as plt
import torch
from lib.data.transforms import BaseTransform, DilateSketch, SketchTransform, ToSketch
from torchvision.transforms import v2
from lib.data.metainfo import MetaInfo
from lib.data.transforms import BaseTransform
import hydra
from lib.utils.config import load_config
import matplotlib.pyplot as plt
import torch
from torch.nn.functional import l1_loss
import cv2 as cv
import numpy as np
from lib.data.preprocess import PreprocessSynthetic
from lib.visualize.open3d import visualize_object


def plot_images_np(images, size: int = 4):
    """Show one image, or a list of images side by side, with matplotlib.

    Args:
        images: Either a single image that ``imshow`` accepts, or a ``list``
            of such images. Any non-list value is treated as one image.
        size: Edge length of the square figure, forwarded as
            ``figsize=(size, size)``.
    """
    if isinstance(images, list):
        # squeeze=False keeps ``axes`` a 2D array of shape (1, len(images)).
        # With the default squeezing a one-element list yields a bare Axes
        # object, which cannot be zipped with the images.
        _, axes = plt.subplots(
            1,
            len(images),
            figsize=(size, size),
            squeeze=False,
        )
        for ax, image in zip(axes[0], images):
            ax.imshow(image)
            ax.axis("off")  # hide ticks and frame
        plt.show()
    else:
        plt.figure(figsize=(size, size))
        plt.imshow(images)
        plt.axis("off")
        plt.show()


def plot_images(images, size: int = 4):
    """Show one tensor, or a list of tensors side by side, with matplotlib.

    Every tensor is detached, moved to the CPU and converted to a NumPy array
    before it is handed to ``imshow``.

    Args:
        images: Either a single tensor or a ``list`` of tensors. Any non-list
            value is treated as one image.
        size: Edge length of the square figure, forwarded as
            ``figsize=(size, size)``.
    """
    if isinstance(images, list):
        # squeeze=False keeps ``axes`` a 2D array of shape (1, len(images)).
        # With the default squeezing a one-element list yields a bare Axes
        # object, which cannot be zipped with the images.
        _, axes = plt.subplots(
            1,
            len(images),
            figsize=(size, size),
            squeeze=False,
        )
        for ax, image in zip(axes[0], images):
            ax.imshow(image.detach().cpu().numpy())
            ax.axis("off")  # hide ticks and frame
        plt.show()
    else:
        # Unlike the list branch, the axes stay visible for a single image.
        plt.figure(figsize=(size, size))
        plt.imshow(images.detach().cpu().numpy())
        plt.show()


# Shared notebook state: every cell below reads ``cfg`` and ``metainfo``.
# Load the "optimize_sketch" config with the `+dataset=shapenet_chair_4096`
# and `logger=wandb` overrides.
cfg = load_config("optimize_sketch", ["+dataset=shapenet_chair_4096", "logger=wandb"])
# Point the loss checkpoint at `latent_traverse.ckpt` inside the configured
# checkpoint directory.
cfg.loss_ckpt_path = f"{cfg.paths.checkpoint_dir}/latent_traverse.ckpt"
# Dataset metadata handle; later cells use it for label -> object id lookups
# (`label_to_obj_id`, `obj_ids`) and for loading images (`load_image`).
metainfo = MetaInfo(cfg.data.data_dir)
# Un-normalized sketch transform; the view-robustness cell applies it to a
# sketch before calling the loss embedding.
sketch_transform = SketchTransform(normalize=False, kernel_size=3)

# Traversal Dataset

# FIND GOOD TRAVERSALS

In [None]:
# Browse latent interpolations between one source shape and several targets.
# For every target the cell plots, left to right: the source sketch, the
# sketch rendered from the interpolated latent, and the target sketch.
#
# Alternative target selections used while searching:
# target_ids = list(range(300))
# target_ids = [1, 4, 15, 21, 24, 30, 44, 46, 65, 71, 88, 130, 140, 177, 188]
source_id = 0
target_ids = [15, 21, 30, 46, 130, 140]

# Interpolation weight: 0 keeps the source latent, 1 gives the target latent.
t = 0.35
# Camera pose handed to both the synthetic preprocessing and the model camera.
azim = 30
elev = -20
dist = 4.5

cfg.model.prior_obj_id = metainfo.obj_ids[0]
cfg.model.prior_mode = 0
model = hydra.utils.instantiate(cfg.model).to("cuda")
base_transform = BaseTransform(normalize=False, transforms=[DilateSketch(3)])
preprocess = PreprocessSynthetic(
    cfg.data.data_dir,
    deepsdf_ckpt_path=cfg.deepsdf_ckpt_path,
    n_renderings=1,
    random=False,
    azims=[azim],
    elev=[elev],
    dist=dist,
)
model.deepsdf.create_camera(azim=azim, elev=elev, dist=dist)

# Sketch of the source object; it is the same for every target.
obj_id = metainfo.label_to_obj_id(source_id)
_, sketches, _, _, _ = preprocess.preprocess(obj_id=obj_id)
input_sketch = sketches[0]

# Loop-invariant helpers: the source latent and the normals-to-sketch
# transform do not depend on the target.
to_sketch = ToSketch()
source_latent = model.deepsdf.lat_vecs.weight[source_id]

# for i in torch.randint(4096, (100,)):
for target_id in target_ids:
    target_obj_id = metainfo.label_to_obj_id(target_id)
    _, sketches, _, _, _ = preprocess.preprocess(obj_id=target_obj_id)
    target_sketch = sketches[0]

    # Only the interpolated shape is rendered; the source and target panels
    # come from the preprocessed sketches above.
    target_latent = model.deepsdf.lat_vecs.weight[target_id]
    interpolated_latent = (1 - t) * source_latent + t * target_latent
    model.latent = interpolated_latent
    interpolated_normal = model.capture_camera_frame()
    # The frame is permuted to channel-first before the sketch transform and
    # permuted back to channel-last for plotting below.
    interpolated_sketch = to_sketch(interpolated_normal.permute(2, 0, 1)).detach().cpu()

    print(target_obj_id, target_id)
    plot_images_np(
        [
            base_transform(input_sketch).permute(1, 2, 0),
            base_transform(interpolated_sketch).permute(1, 2, 0),
            base_transform(target_sketch).permute(1, 2, 0),
        ],
        size=8,
    )

# GET THE FINAL IMAGE FOR THE REPORT

In [None]:
# Build the three image strips for the report: the interpolated sketches, the
# target sketches, and the single source sketch, each cropped by the margins.
x_margin = 30
y_margin = 10
# Interpolation weight: 0 keeps the source latent, 1 gives the target latent.
t = 0.35
azim = 30
elev = -20
dist = 4.5
source_id = 0
target_ids = [15, 21, 30, 46, 130, 140]


def _crop_margins(image):
    """Drop ``y_margin`` rows and ``x_margin`` columns from every border."""
    return image[y_margin:-y_margin, x_margin:-x_margin, :]


base_transform = BaseTransform(normalize=False, transforms=[DilateSketch(3)])
preprocess = PreprocessSynthetic(
    cfg.data.data_dir,
    deepsdf_ckpt_path=cfg.deepsdf_ckpt_path,
    n_renderings=1,
    random=False,
    azims=[azim],
    elev=[elev],
    dist=dist,
)
cfg.model.prior_obj_id = metainfo.obj_ids[0]
cfg.model.prior_mode = 0
model = hydra.utils.instantiate(cfg.model).to("cuda")
model.deepsdf.create_camera(azim=azim, elev=elev, dist=dist)
obj_id = metainfo.label_to_obj_id(source_id)
_, sketches, _, _, _ = preprocess.preprocess(obj_id=obj_id)

input_sketch = base_transform(sketches[0]).permute(1, 2, 0)
interpolated_sketches = []
target_sketches = []

# Loop-invariant helpers: the source latent and the normals-to-sketch
# transform do not depend on the target.
to_sketch = ToSketch()
source_latent = model.deepsdf.lat_vecs.weight[source_id]

for target_id in target_ids:
    target_obj_id = metainfo.label_to_obj_id(target_id)
    _, sketches, _, _, _ = preprocess.preprocess(obj_id=target_obj_id)
    target_sketch = sketches[0]

    # Only the interpolated shape is rendered; the target panel comes from
    # the preprocessed sketch.
    target_latent = model.deepsdf.lat_vecs.weight[target_id]
    interpolated_latent = (1 - t) * source_latent + t * target_latent
    model.latent = interpolated_latent
    interpolated_normal = model.capture_camera_frame()
    interpolated_sketch = to_sketch(interpolated_normal.permute(2, 0, 1)).detach().cpu()

    interpolated_sketches.append(base_transform(interpolated_sketch).permute(1, 2, 0))
    target_sketches.append(base_transform(target_sketch).permute(1, 2, 0))

# Each strip is the cropped panels concatenated along the width (dim=1 of the
# channel-last images).
plot_images_np(
    torch.concatenate([_crop_margins(s) for s in interpolated_sketches], dim=1),
    size=16,
)
plot_images_np(
    torch.concatenate([_crop_margins(s) for s in target_sketches], dim=1),
    size=16,
)
# A single panel needs no concatenation.
plot_images_np(_crop_margins(input_sketch), size=6)

# Debug Test SET

In [None]:
# Print the object ids of a few hand-picked samples. 4224 is the label offset
# that the debug cells below also add to their loop index (presumably the
# first label of the test split -- TODO confirm).
# Earlier shortlist: [30, 25, 56, 51, 98, 136, 143, 213]
idxs = [136, 30, 56, 98, 136, 143, 213]
for i in idxs:
    print(metainfo.label_to_obj_id(4224 + i))

In [None]:
# For 255 consecutive labels starting at 4224: load one stored image as a
# dilated sketch, instantiate the model for that object and plot the sketch
# next to the rendered grayscale, normals and every silhouette variant.
silhouette_keys = (
    "base_silhouette",
    "min_sdf",
    "extra_silhouette",
    "proj_silhouette",
    "proj_blur_silhouette",
    "base_blur_silhouette",
    "weighted_silhouette",
    "final_silhouette",
)

for i in range(255):
    obj_label = 4224 + i  # 10, 4112, (4117), 12, 13, 4152
    print(i, obj_label, metainfo.label_to_obj_id(obj_label))
    azim = 30
    elev = -10

    # Sketch panel: resize, convert to a sketch, then dilate the strokes.
    # Alternative image: metainfo.load_image(obj_label, 0, 10)
    img = metainfo.load_image(obj_label, 6, 9)
    transforms = [v2.Resize((256, 256)), ToSketch(), DilateSketch(kernel_size=5)]
    to_image = BaseTransform(normalize=False, transforms=transforms)
    sketch = to_image(img).permute(1, 2, 0)

    # Fresh model per object, configured through the shared config.
    cfg.model.prior_mode = 0
    cfg.model.prior_obj_id = metainfo.obj_ids[obj_label]
    cfg.model.latent_init = "latent"
    model = hydra.utils.instantiate(cfg.model).to("cuda")
    model.deepsdf.create_camera(azim=azim, elev=elev)

    deepsdf = model.deepsdf
    with torch.no_grad():
        points, surface_mask = deepsdf.sphere_tracing(
            latent=model.latent,
            points=deepsdf.camera_points,
            mask=deepsdf.camera_mask,
            rays=deepsdf.camera_rays,
        )
        # Normals and grayscale are rendered from the same traced surface.
        surface_kwargs = dict(points=points, latent=model.latent, mask=surface_mask)
        normals = deepsdf.render_normals(**surface_kwargs)
        grayscale = deepsdf.render_grayscale(**surface_kwargs)
        # Alternative: grayscale = deepsdf.normals_to_grayscales(normals)
        silhouette = deepsdf.render_silhouette(
            normals=normals,
            points=points,
            latent=model.latent,
            proj_blur_eps=0.7,
            weight_blur_kernal_size=9,
            weight_blur_sigma=9.0,
        )

    panels = [sketch, grayscale, normals]
    panels.extend(silhouette[key] for key in silhouette_keys)
    plot_images(panels, size=32)

# FINAL REPORT SILHOUETTES

In [None]:
# Render every silhouette variant for one hand-picked object and plot them
# next to its sketch, grayscale and normals.
idxs = [136, 30, 56, 98, 136, 143, 213]
# obj_label = 4112
obj_label = 4224 + 30
print(obj_label, metainfo.label_to_obj_id(obj_label))
azim = 90
elev = -10

# Sketch panel: resize, convert to a sketch, then dilate the strokes.
# Alternative image: metainfo.load_image(obj_label, 0, 10)
img = metainfo.load_image(obj_label, 6, 9)
transforms = [v2.Resize((256, 256)), ToSketch(), DilateSketch(kernel_size=5)]
to_image = BaseTransform(normalize=False, transforms=transforms)
sketch = to_image(img).permute(1, 2, 0)

# Model for this object, configured through the shared config.
cfg.model.prior_mode = 0
cfg.model.prior_obj_id = metainfo.obj_ids[obj_label]
cfg.model.latent_init = "latent"
model = hydra.utils.instantiate(cfg.model).to("cuda")
model.deepsdf.create_camera(azim=azim, elev=elev)

silhouette_keys = (
    "base_silhouette",
    "min_sdf",
    "extra_silhouette",
    "proj_silhouette",
    "proj_blur_silhouette",
    "base_blur_silhouette",
    "weighted_silhouette",
    "final_silhouette",
)

deepsdf = model.deepsdf
with torch.no_grad():
    points, surface_mask = deepsdf.sphere_tracing(
        latent=model.latent,
        points=deepsdf.camera_points,
        mask=deepsdf.camera_mask,
        rays=deepsdf.camera_rays,
    )
    # Normals and grayscale are rendered from the same traced surface.
    surface_kwargs = dict(points=points, latent=model.latent, mask=surface_mask)
    normals = deepsdf.render_normals(**surface_kwargs)
    grayscale = deepsdf.render_grayscale(**surface_kwargs)
    # Alternative: grayscale = deepsdf.normals_to_grayscales(normals)
    silhouette = deepsdf.render_silhouette(
        normals=normals,
        points=points,
        latent=model.latent,
        proj_blur_eps=0.7,
        weight_blur_kernal_size=9,
        weight_blur_sigma=9.0,
    )

panels = [sketch, grayscale, normals]
panels.extend(silhouette[key] for key in silhouette_keys)
plot_images(panels, size=32)

# Debug TEST

In [None]:
# For 255 consecutive labels starting at 4224: plot the dilated sketch next
# to the grayscale and normal renders of the model instantiated for it.
azim = 30
elev = -10
for i, obj_label in enumerate(range(4224, 4224 + 255)):
    # Labels noted earlier: 10, 4112, (4117), 12, 13, 4152
    print(i, obj_label, metainfo.label_to_obj_id(obj_label))

    # Sketch panel: resize, convert to a sketch, then dilate the strokes.
    img = metainfo.load_image(obj_label, 6, 9)
    transforms = [v2.Resize((256, 256)), ToSketch(), DilateSketch(kernel_size=3)]
    to_image = BaseTransform(normalize=False, transforms=transforms)
    sketch = to_image(img).permute(1, 2, 0)

    # Fresh model per object, configured through the shared config.
    cfg.model.prior_mode = 0
    cfg.model.prior_obj_id = metainfo.obj_ids[obj_label]
    cfg.model.latent_init = "latent"
    model = hydra.utils.instantiate(cfg.model).to("cuda")
    model.deepsdf.create_camera(azim=azim, elev=elev)

    deepsdf = model.deepsdf
    with torch.no_grad():
        points, surface_mask = deepsdf.sphere_tracing(
            latent=model.latent,
            points=deepsdf.camera_points,
            mask=deepsdf.camera_mask,
            rays=deepsdf.camera_rays,
        )
        # Normals and grayscale are rendered from the same traced surface.
        surface_kwargs = dict(points=points, latent=model.latent, mask=surface_mask)
        normals = deepsdf.render_normals(**surface_kwargs)
        grayscale = deepsdf.render_grayscale(**surface_kwargs)

    panels = [sketch, grayscale, normals]
    plot_images(panels, size=32)

# View Robustness

In [None]:
# View robustness: render sketches of one object from several camera poses,
# embed each sketch into a latent code, and render the resulting shape from a
# fixed viewing pose so the reconstructions can be compared.
obj_label = 4224 + 2  # 10, 4112, (4117), 12, 13, 4152
azims = [110, 50, 10, -20, -40, -130]
elevs = [-50, -10, -15, 5, -30, -20]
mesh_resolution = 256

obj_id = metainfo.label_to_obj_id(obj_label)
print(obj_id)
# Fixed pose used to look at every reconstruction.
view_azim = 30
view_elev = -20
meshes = []

for azim, elev in zip(azims, elevs):
    # Synthetic sketch of the object seen from (azim, elev).
    preprocess = PreprocessSynthetic(
        cfg.data.data_dir,
        deepsdf_ckpt_path=cfg.deepsdf_ckpt_path,
        n_renderings=1,
        random=False,
        azims=[azim],
        elev=[elev],
    )
    _, sketches, _, _, _ = preprocess.preprocess(obj_id=obj_id)
    img = sketches[0]
    # NOTE(review): `img` already comes out of the sketch preprocessing, yet
    # the display pipeline applies ToSketch again, whereas the traversal
    # cells only dilate such sketches -- confirm this is intended.
    transforms = [v2.Resize((256, 256)), ToSketch(), DilateSketch(kernel_size=3)]
    to_image = BaseTransform(normalize=False, transforms=transforms)
    sketch = to_image(img).permute(1, 2, 0)

    # Model whose latent is the embedding of the input sketch.
    cfg.model.mesh_resolution = mesh_resolution
    model = hydra.utils.instantiate(cfg.model).to("cuda")
    sketch_batch = sketch_transform(img)[None, ...].to("cuda")
    model.latent = model.loss.embedding(sketch_batch, mode="sketch")[0]
    model.deepsdf.create_camera(azim=view_azim, elev=view_elev)

    deepsdf = model.deepsdf
    with torch.no_grad():
        points, surface_mask = deepsdf.sphere_tracing(
            latent=model.latent,
            points=deepsdf.camera_points,
            mask=deepsdf.camera_mask,
            rays=deepsdf.camera_rays,
        )
        # Normals and grayscale are rendered from the same traced surface.
        surface_kwargs = dict(points=points, latent=model.latent, mask=surface_mask)
        normals = deepsdf.render_normals(**surface_kwargs)
        grayscale = deepsdf.render_grayscale(**surface_kwargs)

    panels = [sketch, grayscale, normals]
    plot_images(panels, size=32)

    mesh = model.to_mesh()
    meshes.append(mesh)


# Only the reconstruction from the first view is opened in the viewer.
visualize_object(meshes[0])

# Qualitative Hand Drawn

In [None]:
# TODO