# Utils

In [None]:
from lib.data.metainfo import MetaInfo
from lib.data.transforms import BaseTransform
import hydra
from lib.utils.config import load_config
import matplotlib.pyplot as plt
import torch
from lib.data.transforms import BaseTransform, DilateSketch, SketchTransform, ToSketch
from torchvision.transforms import v2
from lib.data.metainfo import MetaInfo
from lib.data.transforms import BaseTransform
import hydra
from lib.utils.config import load_config
import matplotlib.pyplot as plt
import torch
import numpy as np
from lib.data.preprocess import PreprocessSynthetic
from pathlib import Path
import open3d as o3d
from pathlib import Path
import open3d as o3d
import shutil


def plot_images_np(images, size: int = 4):
    """Display one image, or a list of images in a single row, with matplotlib.

    Args:
        images: A single image accepted by ``plt.imshow`` (e.g. a numpy array),
            or a list of such images to draw side by side.
        size: Width and height, in inches, of the whole figure.

    Raises:
        ValueError: If ``images`` is an empty list.
    """
    if isinstance(images, list):
        if not images:
            raise ValueError("plot_images_np() received an empty list of images")
        # squeeze=False keeps ``axes`` a 2-D array even when there is only one
        # image; otherwise subplots(1, 1) returns a bare Axes and zip() fails.
        _, axes = plt.subplots(1, len(images), figsize=(size, size), squeeze=False)
        for ax, image in zip(axes[0], images):
            ax.imshow(image)
            ax.axis("off")  # Turn off axis
        plt.show()
    else:
        plt.figure(figsize=(size, size))
        plt.imshow(images)
        plt.axis("off")
        plt.show()


def plot_images(images, size: int = 4):
    """Display one torch tensor image, or a list of them in a single row.

    Each tensor is detached, moved to the CPU and converted to numpy before
    being handed to ``imshow``. Tensors are assumed to already be in an
    ``imshow``-compatible layout (e.g. HxW or HxWxC) -- TODO confirm at callers.

    Args:
        images: A single tensor or a list of tensors.
        size: Width and height, in inches, of the whole figure.

    Raises:
        ValueError: If ``images`` is an empty list.
    """

    def _to_numpy(tensor):
        # Detach from autograd and copy to host memory so matplotlib can read it.
        return tensor.detach().cpu().numpy()

    if isinstance(images, list):
        if not images:
            raise ValueError("plot_images() received an empty list of images")
        # squeeze=False keeps ``axes`` a 2-D array even when there is only one
        # image; otherwise subplots(1, 1) returns a bare Axes and zip() fails.
        _, axes = plt.subplots(1, len(images), figsize=(size, size), squeeze=False)
        for ax, image in zip(axes[0], images):
            ax.imshow(_to_numpy(image))
            ax.axis("off")  # Turn off axis
        plt.show()
    else:
        plt.figure(figsize=(size, size))
        plt.imshow(_to_numpy(images))
        plt.axis("off")
        plt.show()


# Shared setup used by the cells below: load the "optimize_sketch" config with
# the dataset/logger overrides, then build the dataset metadata and the sketch
# transform.
cfg = load_config(
    "optimize_sketch",
    [
        "+dataset=shapenet_chair_4096",
        "logger=wandb",
    ],
)
# Checkpoint file named latent_traverse.ckpt inside the configured checkpoint dir.
cfg.loss_ckpt_path = "{}/latent_traverse.ckpt".format(cfg.paths.checkpoint_dir)
metainfo = MetaInfo(cfg.data.data_dir)
sketch_transform = SketchTransform(
    normalize=False,
    kernel_size=3,
)

# Traversal Dataset

In [None]:
# Latent-space interpolation: blend the DeepSDF latent code of a source shape
# with several target shapes and render each blend as a sketch from one camera.
x_margin = 30  # columns cropped from the left and right edge before plotting
y_margin = 10  # rows cropped from the top and bottom edge before plotting
t = 0.35  # interpolation weight: 0 -> source latent, 1 -> target latent
azim = 30
elev = -20
dist = 4.5
source_id = 0
target_ids = [15, 21, 30, 46, 130, 140]

base_transform = BaseTransform(normalize=False, transforms=[DilateSketch(3)])
preprocess = PreprocessSynthetic(
    cfg.data.data_dir,
    deepsdf_ckpt_path=cfg.deepsdf_ckpt_path,
    n_renderings=1,
    random=False,
    azims=[azim],
    elev=[elev],
    dist=dist,
)
cfg.model.prior_obj_id = metainfo.obj_ids[0]
cfg.model.prior_mode = 0
model = hydra.utils.instantiate(cfg.model).to("cuda")
model.deepsdf.create_camera(azim=azim, elev=elev, dist=dist)
obj_id = metainfo.label_to_obj_id(source_id)
_, sketches, _, _, _ = preprocess.preprocess(obj_id=obj_id)
input_sketch = base_transform(sketches[0]).permute(1, 2, 0)

# Loop invariants: the sketch extractor and the source latent are the same for
# every target, so build/look them up once.
to_sketch = ToSketch()
source_latent = model.deepsdf.lat_vecs.weight[source_id]

interpolated_sketches = []
target_sketches = []
for target_id in target_ids:
    target_obj_id = metainfo.label_to_obj_id(target_id)
    _, sketches, _, _, _ = preprocess.preprocess(obj_id=target_obj_id)
    target_sketch = sketches[0]

    target_latent = model.deepsdf.lat_vecs.weight[target_id]
    # Linear interpolation between the source and target latent codes.
    model.latent = (1 - t) * source_latent + t * target_latent
    interpolated_normal = model.capture_camera_frame()
    # permute(2, 0, 1) moves the last axis first before ToSketch; presumably
    # HWC -> CHW -- TODO confirm the frame layout.
    interpolated_sketch = to_sketch(interpolated_normal.permute(2, 0, 1)).detach().cpu()

    interpolated_sketches.append(base_transform(interpolated_sketch).permute(1, 2, 0))
    target_sketches.append(base_transform(target_sketch).permute(1, 2, 0))

# Crop the margins and show each group as one horizontal strip. Slicing up to
# ``shape - margin`` (instead of ``-margin``) keeps a margin of 0 from yielding
# an empty image.
for row, fig_size in (
    (interpolated_sketches, 16),
    (target_sketches, 16),
    ([input_sketch], 6),
):
    cropped = [
        img[y_margin : img.shape[0] - y_margin, x_margin : img.shape[1] - x_margin, :]
        for img in row
    ]
    plot_images_np(torch.concatenate(cropped, dim=1), size=fig_size)

# Silhouettes

In [None]:
# Render the silhouette maps (final and base) of selected objects from several
# viewpoints, next to the input sketch the model prior was built from.
mode = 9  # 9 -> load/prior with (view 6, mode 9); anything else -> (view 0, mode 10)
dist = 4.5
mesh_resolution = 128
x_margin = 20  # columns cropped from the left and right edge before plotting
y_margin = 1  # rows cropped from the top and bottom edge before plotting
# Other selections used before: [230, 218, 172, 198, 199, 156] and [172, 199, 156].
idxs = [156]
azims = [90, 65, 40, 10, -20, -130]
elevs = [0, -40, -30, -15, 5, -20]
transforms = [v2.Resize((256, 256)), ToSketch(), DilateSketch(kernel_size=3)]
to_image = BaseTransform(normalize=False, transforms=transforms)
meshes = []

# The same (view id, mode) pair selects the input image and configures the prior.
view_id, prior_mode = (6, 9) if mode == 9 else (0, 10)


def _crop_and_concat(images):
    """Crop the margins of each (H, W, ...) tensor and join them along width.

    Slicing up to ``shape - margin`` (instead of ``-margin``) keeps a margin of
    0 from yielding an empty image.
    """
    cropped = [
        img[y_margin : img.shape[0] - y_margin, x_margin : img.shape[1] - x_margin]
        .detach()
        .cpu()
        for img in images
    ]
    return torch.concatenate(cropped, dim=1).numpy()


for idx in idxs:
    obj_label = 4224 + idx
    print(obj_label, metainfo.label_to_obj_id(obj_label))
    img = metainfo.load_image(obj_label, view_id, prior_mode)
    sketch = to_image(img).permute(1, 2, 0)

    # setup model
    cfg.model.mesh_resolution = mesh_resolution
    cfg.model.prior_obj_id = metainfo.obj_ids[obj_label]
    cfg.model.latent_init = "latent"
    cfg.model.prior_view_id = view_id
    cfg.model.prior_mode = prior_mode
    model = hydra.utils.instantiate(cfg.model).to("cuda")

    silhouettes = []
    broke_silhouettes = []
    for azim, elev in zip(azims, elevs):
        model.deepsdf.create_camera(azim=azim, elev=elev, dist=dist)
        with torch.no_grad():
            points, surface_mask = model.deepsdf.sphere_tracing(
                latent=model.latent,
                points=model.deepsdf.camera_points,
                mask=model.deepsdf.camera_mask,
                rays=model.deepsdf.camera_rays,
            )
            normals = model.deepsdf.render_normals(
                points=points,
                latent=model.latent,
                mask=surface_mask,
            )
            rendered = model.deepsdf.render_silhouette(
                normals=normals,
                points=points,
                latent=model.latent,
                proj_blur_eps=0.7,
                weight_blur_kernal_size=9,
                weight_blur_sigma=9.0,
            )
        silhouettes.append(rendered["final_silhouette"])
        broke_silhouettes.append(rendered["base_silhouette"])
    # Uncomment to also export the reconstructed mesh (written out below):
    # meshes.append(model.to_mesh())

    plt.set_cmap("gist_yarg")
    silhouette = _crop_and_concat(silhouettes)
    # Exact zeros become NaN so imshow leaves those pixels unpainted.
    silhouette[silhouette == 0.0] = np.nan
    broke_silhouette = _crop_and_concat(broke_silhouettes)

    plot_images(sketch)
    plot_images_np(silhouette, size=16)
    plot_images_np(broke_silhouette, size=16)

# ``meshes`` stays empty unless the to_mesh() line above is re-enabled.
output_dir = Path(cfg.paths.root_dir) / "temp/edge_maps"
for idx, mesh in enumerate(meshes):
    path = output_dir / f"mesh/{idx:05}.obj"
    path.parent.mkdir(parents=True, exist_ok=True)
    o3d.io.write_triangle_mesh(str(path), mesh=mesh, write_triangle_uvs=False)

# View Robustness

In [None]:
# Reconstruct one object from sketches rendered at several viewpoints, render
# every reconstruction from one fixed viewpoint, and export meshes + sketches.
# Other candidates: 4224 + one of [230, 218, 172, 199, 198, 156]; also 36, 4112, (22)
obj_label = 4224 + 230
azims = [110, 50, 10, -20, -130]  # alternative: [100, 50, 10, -20, -130]
elevs = [-40, -30, -10, 5, -20]
mesh_resolution = 256
view_azim = 30  # fixed camera used to render every reconstruction
view_elev = -20

obj_id = metainfo.label_to_obj_id(obj_label)
print(obj_id)
meshes = []
sketches = []

# Loop invariants: the sketch pipeline and model config are the same for every view.
transforms = [v2.Resize((256, 256)), ToSketch(), DilateSketch(kernel_size=3)]
to_image = BaseTransform(normalize=False, transforms=transforms)
cfg.model.mesh_resolution = mesh_resolution
cfg.model.latent_init = "mean"

for azim, elev in zip(azims, elevs):
    # create sketch: the viewpoint is a constructor argument, so build per view.
    preprocess = PreprocessSynthetic(
        cfg.data.data_dir,
        deepsdf_ckpt_path=cfg.deepsdf_ckpt_path,
        n_renderings=1,
        random=False,
        azims=[azim],
        elev=[elev],
    )
    _, _sketches, _, _, _ = preprocess.preprocess(obj_id=obj_id)
    img = _sketches[0]
    sketch = to_image(img).permute(1, 2, 0)

    # setup model; its latent is replaced by the loss module's sketch embedding.
    model = hydra.utils.instantiate(cfg.model).to("cuda")
    model.latent = model.loss.embedding(
        sketch_transform(img)[None, ...].to("cuda"), mode="sketch"
    )[0]
    model.deepsdf.create_camera(azim=view_azim, elev=view_elev)

    with torch.no_grad():
        points, surface_mask = model.deepsdf.sphere_tracing(
            latent=model.latent,
            points=model.deepsdf.camera_points,
            mask=model.deepsdf.camera_mask,
            rays=model.deepsdf.camera_rays,
        )
        normals = model.deepsdf.render_normals(
            points=points,
            latent=model.latent,
            mask=surface_mask,
        )
        grayscale = model.deepsdf.render_grayscale(
            points=points,
            latent=model.latent,
            mask=surface_mask,
        )

    plot_images([sketch, grayscale, normals], size=32)
    meshes.append(model.to_mesh())
    sketches.append(sketch)

output_dir = Path(cfg.paths.root_dir) / "temp/multi_view_v3"
for idx, mesh in enumerate(meshes):
    path = output_dir / f"mesh/{idx:05}.obj"
    path.parent.mkdir(parents=True, exist_ok=True)
    o3d.io.write_triangle_mesh(str(path), mesh=mesh, write_triangle_uvs=False)
for idx, sketch in enumerate(sketches):
    path = output_dir / f"output/sketch/{idx:05}.png"
    path.parent.mkdir(parents=True, exist_ok=True)
    v2.functional.to_pil_image(sketch.permute(2, 0, 1)).save(path)

# Copy the ground-truth mesh next to the reconstructions for comparison.
source_path = metainfo.normalized_mesh_path(obj_id=obj_id)
target_path = output_dir / f"mesh/gt_{obj_label}.obj"
target_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(source_path, target_path)

# Qualitative Hand-Drawn

In [None]:
import shutil
from lib.data.transforms import DilateSketch, BaseTransform
from pathlib import Path

base_transform = BaseTransform(normalize=False, transforms=[DilateSketch(3)])
# NOTE(review): absolute, user-specific run folders; consider deriving them from
# cfg.paths.root_dir like the other cells do.
handdrawn_retrieval_folder = (
    "/home/borth/sketch2shape/logs/optimize_deepsdf/runs/2024-03-06_16-32-35/mesh"
)
handdrawn_encoder_folder = (
    "/home/borth/sketch2shape/logs/optimize_deepsdf/runs/2024-03-06_20-33-37/mesh"
)
handdrawn_silhouette_folder = (
    "/home/borth/sketch2shape/logs/optimize_sketch/runs/2024-03-09_12-25-47/mesh"
)
handdrawn_global_folder = (
    "/home/borth/sketch2shape/logs/optimize_sketch/runs/2024-03-10_17-43-31/mesh"
)

sketch_retrieval_folder = (
    "/home/borth/sketch2shape/logs/optimize_deepsdf/runs/2024-03-08_10-06-03/mesh"
)
sketch_encoder_folder = (
    "/home/borth/sketch2shape/logs/optimize_deepsdf/runs/2024-03-08_15-16-01/mesh"
)
sketch_silhouette_folder = (
    "/home/borth/sketch2shape/logs/optimize_sketch/runs/2024-03-09_12-05-58/mesh"
)
sketch_global_folder = (
    "/home/borth/sketch2shape/logs/optimize_sketch/runs/2024-03-10_17-23-43/mesh"
)
output_folder = "/home/borth/sketch2shape/temp/qualitative_v5"
idxs = [230, 218, 172, 199, 198, 156]

# (output name prefix, run folder, optional). Non-optional meshes must exist
# (shutil.copy raises if missing); optional ones are copied only when present.
mesh_sources = [
    ("handdrawn_retrieval", handdrawn_retrieval_folder, False),
    ("handdrawn_encoder", handdrawn_encoder_folder, False),
    ("handdrawn_silhouette", handdrawn_silhouette_folder, True),
    ("handdrawn_global", handdrawn_global_folder, True),
    ("sketch_retrieval", sketch_retrieval_folder, False),
    ("sketch_encoder", sketch_encoder_folder, False),
    ("sketch_silhouette", sketch_silhouette_folder, True),
    ("sketch_global", sketch_global_folder, True),
]
# (output subfolder, view id, mode, saturate strokes) for the exported input images.
sketch_exports = [
    ("sketch", 6, 9, False),
    ("handdrawn", 0, 10, True),
]

for idx in idxs:
    obj_label = 4224 + idx
    obj_id = metainfo.label_to_obj_id(obj_label)

    for name, folder, optional in mesh_sources:
        source_path = Path(folder) / f"{obj_id}.obj"
        if optional and not source_path.exists():
            continue
        target_path = Path(output_folder) / f"mesh/{name}_{obj_label}.obj"
        target_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(source_path, target_path)

    source_path = metainfo.normalized_mesh_path(obj_id=obj_id)
    target_path = Path(output_folder) / f"mesh/gt_{obj_label}.obj"
    target_path.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy(source_path, target_path)

    for subfolder, view_id, mode, saturate in sketch_exports:
        sketch = base_transform(metainfo.load_image(obj_label, view_id, mode))
        if saturate:
            # Push every pixel above 0.1 to full intensity; values <= 0.1 are kept.
            sketch[sketch > 0.1] = 1.0
        pil_sketch = v2.functional.to_pil_image(sketch)
        target_path = Path(output_folder) / f"output/{subfolder}/{obj_label}.png"
        target_path.parent.mkdir(parents=True, exist_ok=True)
        pil_sketch.save(target_path)