# EMD

In [None]:
import torch
from lib.data.metainfo import MetaInfo
from lib.visualize.open3d import visualize_object, visualize_pointcloud

# Load the surface samples and normalized mesh of one chair and display the mesh.
obj_id = 3
metainfo = MetaInfo(data_dir="/home/borth/sketch2shape/data/shapenet_chair_4096")
obj_key = metainfo.obj_ids[obj_id]  # resolve the identifier once, reuse for both loaders
surface_samples = metainfo.load_surface_samples(obj_key)
mesh = metainfo.load_normalized_mesh(obj_key)
visualize_object(mesh)

In [None]:
# Draw 500 points with the mesh's uniform surface sampler and show them as a point cloud.
samples = mesh.sample_points_uniformly(number_of_points=500)
sampled_points = samples.points
visualize_pointcloud(sampled_points)

In [None]:
# Show the precomputed surface samples loaded via metainfo.load_surface_samples,
# for visual comparison with the points sampled from the mesh in the previous cell.
visualize_pointcloud(surface_samples)

# Earth Mover's Distance (EMD)

In [None]:
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
import numpy as np
import time

# Earth Mover's Distance between the precomputed surface samples and points
# freshly sampled from the mesh. For two equal-size point sets, the optimal
# one-to-one matching (linear_sum_assignment on the pairwise Euclidean distance
# matrix) gives the EMD; it is reported as the mean matched distance.
start = time.perf_counter()
num_repeats = 1
# Sampling without replacement cannot draw more points than are available.
num_samples = min(5000, len(surface_samples))
# Seed once, outside the loop: the run stays reproducible, but each repetition
# draws a different subset, so averaging over repeats is meaningful.
np.random.seed(1)
emd = []
for _ in range(num_repeats):
    idx = np.random.choice(len(surface_samples), num_samples, replace=False)
    gt_samples = surface_samples[idx]
    samples = np.asarray(mesh.sample_points_uniformly(number_of_points=num_samples).points)
    d = cdist(gt_samples, samples)  # (num_samples, num_samples) pairwise distances
    row_ind, col_ind = linear_sum_assignment(d)
    _emd = d[row_ind, col_ind].sum() / min(len(gt_samples), len(samples))
    emd.append(_emd)
print(f"EMD over {num_repeats} run(s) took {time.perf_counter() - start:.2f}s")
np.mean(emd)

In [None]:
# Show the subset of surface samples drawn in the EMD cell (the last repetition's
# `gt_samples`).
visualize_pointcloud(gt_samples)

# FID (Fréchet Inception Distance)

In [None]:
import torch
from lib.data.metainfo import MetaInfo
from lib.visualize.open3d import visualize_object, visualize_pointcloud
import hydra
from lib.utils.config import load_config
import numpy as np
from lib.data.metainfo import MetaInfo
from lib.data.transforms import BaseTransform
import hydra
from lib.utils.config import load_config
import matplotlib.pyplot as plt
import torch
from torch.nn.functional import cosine_similarity
from tqdm import tqdm
from lib.visualize.image import image_grid


def siamese_loss(emb_1, emb_2):
    """Return the cosine distance ``1 - cos(emb_1, emb_2)`` between embeddings.

    The similarity is taken along ``cosine_similarity``'s default ``dim=1``.
    The result is 0 for parallel vectors, 1 for orthogonal ones and 2 for
    opposite ones.
    """
    similarity = cosine_similarity(emb_1, emb_2)
    return 1 - similarity


def transform(normal, device="cuda"):
    """Apply the project's ``BaseTransform`` to an image and move it to ``device``.

    Args:
        normal: Input image, e.g. the array returned by ``np.asarray(metainfo.load_image(...))``
            (TODO confirm the exact input types ``BaseTransform`` accepts).
        device: Target device for the result. Defaults to ``"cuda"``, the
            previously hard-coded device, so existing calls are unchanged.

    Returns:
        The transformed image (presumably a torch tensor, given ``.to``) on ``device``.
    """
    _transform = BaseTransform()
    return _transform(normal).to(device)


def plot_images(images, size: int = 4):
    """Display one image, or a list of images side by side, with matplotlib.

    Args:
        images: A single image, or a list of images, in any form accepted by
            ``matplotlib.pyplot.imshow``. An empty list draws nothing.
        size: Width and height of the figure, in inches.
    """
    if isinstance(images, list):
        if not images:
            # plt.subplots(1, 0) raises; there is simply nothing to show.
            return
        # squeeze=False keeps `axes` a 2-D array even for a single image;
        # otherwise plt.subplots(1, 1) returns a bare Axes that zip cannot iterate.
        _, axes = plt.subplots(1, len(images), figsize=(size, size), squeeze=False)
        for ax, image in zip(axes[0], images):
            ax.imshow(image)
            ax.axis("off")  # hide ticks and frame
        plt.show()
    else:
        plt.figure(figsize=(size, size))
        plt.imshow(images)
        plt.show()


# Load the sketch-optimization config, build the model for one object, and
# load that same object's surface samples and normalized mesh.
obj_id = 0
cfg = load_config("optimize_sketch", ["+dataset=shapenet_chair_4096"])
# Take the data directory from the config rather than a hard-coded path.
metainfo = MetaInfo(cfg.data.data_dir)
# Use one resolved identifier everywhere, so the model and the loaded mesh
# always refer to the same object when `obj_id` is changed.
obj_key = metainfo.obj_ids[obj_id]

cfg.loss_ckpt_path = "/home/borth/sketch2shape/checkpoints/latent_encoder.ckpt"
cfg.model.shape_k = 16
cfg.model.shape_view_id = 11
cfg.model.shape_init = False
cfg.model.shape_prior = False
cfg.model.obj_id = obj_key
model = hydra.utils.instantiate(cfg.model).to("cuda")

surface_samples = metainfo.load_surface_samples(obj_key)
mesh = metainfo.load_normalized_mesh(obj_key)
visualize_object(mesh)

In [None]:
# Encode one sketch image into the model's latent, render the model's current
# camera frame, and show both side by side. The meaning of the
# load_image(4030, 11, 0) arguments is not visible here; 11 presumably matches
# cfg.model.shape_view_id (TODO confirm).
sketch = np.asarray(metainfo.load_image(4030, 11, 0))
sketch_batch = transform(sketch)[None, ...]  # prepend a batch dimension
model.latent = model.loss(sketch_batch)[0]
camera_frame = model.capture_camera_frame()
rendered_normal = camera_frame.detach().cpu().numpy()
plot_images([sketch, rendered_normal], size=8)

In [None]:
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# Seed before building the metric and the images, so both are reproducible.
_ = torch.manual_seed(123)
fid = FrechetInceptionDistance(feature=2048)

# Two synthetic uint8 batches (100 images of 3x256x256) whose intensity ranges
# overlap: the first draws from [0, 200), the second from [100, 255).
image_shape = (100, 3, 256, 256)
imgs_dist1 = torch.randint(0, 200, image_shape, dtype=torch.uint8)
imgs_dist2 = torch.randint(100, 255, image_shape, dtype=torch.uint8)

fid.update(imgs_dist1, real=True)   # reference ("real") distribution
fid.update(imgs_dist2, real=False)  # generated ("fake") distribution
fid.compute()
