In [1]:
import os 
import torch 
import numpy as np
from PIL import Image

In [2]:
def applySHlight(normal_images, sh_coeff):
  """Shade a normal map with 9-coefficient (order-2) spherical-harmonic lighting.

  Args:
    normal_images: torch tensor whose first axis holds the normal components
      ``N[0], N[1], N[2]`` = x, y, z, each of shape ``[h, w]`` (i.e. a
      ``[3, h, w]`` map such as ``genSurfaceNormals`` returns).
    sh_coeff: SH coefficients of shape ``[9, C]`` -- one row per basis
      function below, one column per colour channel (presumably C == 3 for
      RGB; confirm against the light files). Either a torch tensor or
      anything ``torch.as_tensor`` accepts, e.g. the NumPy array from
      ``np.load``.

  Returns:
    Unclipped shading tensor of shape ``[C, h, w]``. The dtype follows
    PyTorch type promotion, so float64 coefficients give a float64 result.
  """
  N = normal_images
  x, y, z = N[0], N[1], N[2]
  # Basis functions in the order the coefficient rows are consumed; the
  # coefficient files must use exactly this order and normalisation.
  sh = torch.stack(
    [
      torch.ones_like(x),
      x,
      y,
      z,
      x * y,
      x * z,
      y * z,
      x ** 2 - y ** 2,
      3 * (z ** 2) - 1,
    ],
    0,
  )  # [9, h, w]
  pi = np.pi
  # Per-basis factor = SH normalisation constant times a band weight
  # (2*pi/3 for band 1, pi/4 for band 2 -- the Lambertian irradiance weights).
  # NOTE(review): the band-0 factor carries no such weight (no pi); it was
  # marked "confirmed" by the author and is kept because the stored
  # coefficients presumably assume this convention.
  constant_factor = torch.tensor(
    [
      1 / np.sqrt(4 * pi),
      ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
      ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
      ((2 * pi) / 3) * (np.sqrt(3 / (4 * pi))),
      (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
      (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
      (pi / 4) * (3) * (np.sqrt(5 / (12 * pi))),
      (pi / 4) * (3 / 2) * (np.sqrt(5 / (12 * pi))),
      (pi / 4) * (1 / 2) * (np.sqrt(5 / (4 * pi))),
    ],
    dtype=torch.float32,
    device=N.device,  # keep on the normals' device so CUDA inputs work
  )
  sh = sh * constant_factor[:, None, None]

  # Convert explicitly rather than relying on implicit NumPy/tensor interop,
  # and place the coefficients on the same device as the basis.
  coeffs = torch.as_tensor(sh_coeff, device=sh.device)
  # [9, C, 1, 1] * [9, 1, h, w] -> [9, C, h, w], then sum over the basis axis.
  shading = torch.sum(coeffs[:, :, None, None] * sh[:, None, :, :], 0)  # [C, h, w]

  return shading


In [3]:
def genSurfaceNormals(n):
  """Build the normal map of a unit sphere viewed head-on in an n x n image.

  Pixel (i, j) gets x = linspace(-1, 1, n)[j] (increasing left to right) and
  y = linspace(1, -1, n)[i] (decreasing top to bottom, so +y is up), with
  z = sqrt(1 - x^2 - y^2) inside the unit disk.

  Args:
    n: image side length in pixels.

  Returns:
    (normals, mask):
      normals -- tensor of shape [3, n, n] stacking (x, y, z). Outside the
        unit disk z is 0, so those entries are (x, y, 0) and not unit length.
      mask -- bool tensor of shape [n, n], True where 1 - x^2 - y^2 < 0
        (pixels off the sphere).
  """
  xs = torch.linspace(-1, 1, n)
  ys = torch.linspace(1, -1, n)
  # Explicit "ij" indexing matches the historical default and avoids the
  # warning torch emits when `indexing` is omitted.
  y, x = torch.meshgrid(ys, xs, indexing="ij")

  z_sq = 1 - x ** 2 - y ** 2
  mask = z_sq < 0
  # Clamp negatives to 0 before the sqrt (same as zeroing the masked pixels).
  z = torch.sqrt(z_sq.clamp(min=0))
  return torch.stack([x, y, z], 0), mask

In [4]:
def applySHlightXYZ(xyz, sh, scale=0.7):
  """Shade a normal map with SH lighting and map it to the displayable range.

  Args:
    xyz: normal map forwarded unchanged to ``applySHlight`` (e.g. the
      [3, h, w] output of ``genSurfaceNormals``).
    sh: SH coefficients forwarded unchanged to ``applySHlight``.
    scale: fixed brightness gain applied before clipping. Defaults to 0.7,
      the value that was previously hard-coded.

  Returns:
    The shading multiplied by ``scale`` and clipped to [0, 1].
  """
  shading = applySHlight(xyz, sh)
  # A fixed gain (rather than dividing by a per-image max or percentile)
  # preserves relative brightness between different lights.
  return torch.clip(shading * scale, 0, 1)

In [5]:
def drawSphere(sh, img_size=256, is_back=False, white_bg=False):
  """Render a sphere lit by SH coefficients.

  Args:
    sh: SH coefficients forwarded to ``applySHlightXYZ``.
    img_size: output side length in pixels.
    is_back: if True, negate the z component of the normals (presumably to
      visualise the hemisphere facing away from the viewer).
    white_bg: if True, pixels off the sphere are filled with 1.0 (white);
      otherwise with 0.0 (black).

  Returns:
    Tensor of shape [C, img_size, img_size] with values in [0, 1].
  """
  xyz, _ = genSurfaceNormals(img_size)
  # Background = pixels whose normal has z == 0: everything outside the unit
  # disk (genSurfaceNormals zeroes z there) plus the disk's exact rim.
  # Computed up front so later edits to the normals cannot change it.
  background = xyz[2] == 0
  if is_back:
    xyz[2] = -xyz[2]
  out = applySHlightXYZ(xyz, sh)
  # Fill the background after shading: substituting fake normals before
  # shading would yield the lit colour of that normal, not a flat colour.
  out[:, background] = 1.0 if white_bg else 0.0
  return out

In [10]:
def create_image_grid(images, rows=4, cols=4):
  """Tile PIL images row-major into a ``rows`` x ``cols`` grid.

  Every cell is sized to the largest width and height found in ``images``;
  each image is centred in its cell, and unused trailing cells stay black.

  Args:
    images: non-empty sequence of PIL.Image objects, at most rows * cols.
    rows: number of grid rows.
    cols: number of grid columns.

  Returns:
    A new RGB PIL.Image of size (cols * cell_width, rows * cell_height).

  Raises:
    ValueError: if rows or cols is not positive, ``images`` is empty, or it
      holds more than rows * cols images.
  """
  if rows <= 0 or cols <= 0:
    raise ValueError("rows and cols must be positive.")
  capacity = rows * cols
  if not images or len(images) > capacity:
    raise ValueError(f"List must contain between 1 and {capacity} images.")

  # Uniform cell size: the maximum width and height among the images.
  cell_width = max(image.size[0] for image in images)
  cell_height = max(image.size[1] for image in images)

  # Width scales with the number of columns, height with the number of rows.
  grid_image = Image.new('RGB', size=(cols * cell_width, rows * cell_height))

  for i, image in enumerate(images):
    # Row-major placement: each row holds `cols` images.
    row, col = divmod(i, cols)
    # A 2-tuple box is the upper-left corner. A 4-tuple box would have to
    # match the image size exactly, which fails for images smaller than a cell.
    x = col * cell_width + (cell_width - image.size[0]) // 2
    y = row * cell_height + (cell_height - image.size[1]) // 2
    grid_image.paste(image, (x, y))

  return grid_image

In [11]:
# Capture directions; each is used as the filename prefix of the
# per-direction image and light files.
DIRECTIONS = [
    'left',
    'right',
    'bottom',
    'top',
    'back',
    'front',
]

In [None]:

# Root of the validation split; image and light files live under it.
DATASET_DIR = "../../datasets/validation/angelica6axis_small05"
IMG_SIZE = 256

for image_id in [0]:
    panels = []
    for direction in DIRECTIONS:
        image_path = os.path.join(DATASET_DIR, "images", f"{image_id:05d}", f"{direction}{image_id}.png")
        light_path = os.path.join(DATASET_DIR, "light", f"{image_id:05d}", f"{direction}{image_id}_light.npy")

        # The context manager closes the file handle PIL keeps open for lazy loading.
        with Image.open(image_path) as img:
            gt_image = img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))

        # Explicit tensor conversion instead of relying on NumPy/torch interop.
        sh_coeff = torch.as_tensor(np.load(light_path))
        sphere = drawSphere(sh_coeff, img_size=IMG_SIZE, is_back=False, white_bg=False)
        # [C, H, W] in [0, 1] -> [H, W, C] uint8; already IMG_SIZE x IMG_SIZE.
        sphere_image = Image.fromarray((sphere.permute(1, 2, 0).numpy() * 255).astype(np.uint8))

        # Side-by-side panel: ground-truth render on the left, SH sphere on the right.
        panel = Image.new('RGB', (gt_image.width + sphere_image.width, gt_image.height))
        panel.paste(gt_image, (0, 0))
        panel.paste(sphere_image, (gt_image.width, 0))
        panels.append(panel)

    grid = create_image_grid(panels, rows=3, cols=2)
    display(grid)