# Utils

In [None]:
import torch
from torchvision.transforms import v2
from PIL import Image
import matplotlib.pyplot as plt
from lib.data.metainfo import MetaInfo

def plot_images(images, size: int = 4):
    """Display one CHW image tensor, or a list/tuple of them side by side.

    Args:
        images: A single tensor laid out as (C, H, W), or a list/tuple of such
            tensors. Each tensor is permuted to (H, W, C), detached and moved
            to CPU before being handed to ``imshow``.
        size: Figure height in inches. A single image is drawn in a
            ``size x size`` figure; a sequence of N images gets a
            ``(size * N) x size`` figure so each panel keeps a ``size``-inch
            width instead of all N being squeezed into one square.

    An empty list/tuple draws nothing.
    """
    if isinstance(images, (list, tuple)):
        if not images:
            # plt.subplots(1, 0) raises; there is simply nothing to show.
            return
        n = len(images)
        # squeeze=False always returns a 2-D array of Axes. With the default
        # squeeze=True a one-element list yields a bare Axes, which is not
        # iterable and made zip() below raise TypeError.
        _, axes = plt.subplots(1, n, figsize=(size * n, size), squeeze=False)
        for ax, image in zip(axes[0], images):
            ax.imshow(image.permute(1, 2, 0).detach().cpu().numpy())
            ax.axis("off")  # Turn off axis
        plt.show()
    else:
        plt.figure(figsize=(size, size))
        plt.imshow(images.permute(1, 2, 0).detach().cpu().numpy())
        plt.show()

# Phong Shader (Ambient + Diffuse, Light Infinitely Far Away)

In [None]:
# Shade one view with ambient + diffuse terms under a single directional light
# (a light source infinitely far away): the same light vector L is used for
# every pixel.
obj_id = 16
light_direction = [1, 1, 1]  # need not be unit length; normalized below
ambient = 0.5
diffuse = 0.5

metainfo = MetaInfo(data_dir="/home/borth/sketch2shape/data/shapenet_chair_4096")
# ToDtype(scale=True) maps 8-bit colors to [0, 1]; Normalize(0.5, 0.5) then
# maps them to [-1, 1] (presumably the range of the encoded normal components).
trans = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.5], std=[0.5])
])
normal = trans(metainfo.load_normal(metainfo.obj_ids[obj_id], "00011"))
# Pixels whose channels are all ~1 (near-white), presumably the background.
mask = normal.sum(0) > 2.95
# normal[0, :, :] *= -1  # flip the shadow so that it looks from every side the same
normal[:, mask] = 0.0

# Normalize by the Euclidean norm, not the component sum. Dividing by the sum
# does not give a unit vector (its length depends on the direction, ~0.58 for
# [1, 1, 1]) and divides by zero for directions such as [1, -1, 0].
# norm() requires a floating-point tensor, hence the explicit dtype.
L = torch.tensor(light_direction, dtype=torch.float32)  # (3,)
L = L / L.norm()
# Per-pixel N·L; assumes normal is (3, H, W) -- TODO confirm. N·L is not
# clamped at 0, so pixels facing away from the light fall below `ambient`.
n_dot_l = (L[:, None, None] * normal).sum(0)
image = torch.zeros_like(normal)
image += ambient
image += diffuse * n_dot_l[None, ...]
# Keep values inside imshow's valid float range [0, 1] in case
# ambient + diffuse > 1 or the decoded normals are not exactly unit length.
image.clamp_(0.0, 1.0)
image[:, mask] = 1
plot_images(image, size=4)

# Mean Shader

In [None]:
# Grayscale "mean shader": the per-pixel average of the normal map's three
# channels, replicated back into three channels for display.
obj_id = 0
metainfo = MetaInfo(data_dir="/home/borth/sketch2shape/data/shapenet_chair_4096")
# No Normalize step here: after ToDtype(scale=True) 8-bit channels are floats
# in [0, 1].
trans = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)
# Another view previously tried here: "00081".
normal = trans(
    metainfo.load_normal(metainfo.obj_ids[obj_id], "00011")
)
# With channels in [0, 1], a channel sum above 2.95 means every channel is
# above 0.95, i.e. near-white pixels (presumably the background); paint them
# pure white.
mask = normal.sum(dim=0) > 2.95
# Optional experiment: normal[2, :, :] *= -1  # flip the shadow
normal[:, mask] = 1
mean = torch.mean(normal, dim=0)
# (H, W) -> (3, H, W): identical gray value in every channel.
image = mean.unsqueeze(0).repeat(3, 1, 1)
plot_images(image, size=4)