In [None]:
import json
import os
from tqdm.notebook import tqdm

In [None]:
import json 
import numpy as np 
from PIL import Image
import random 
import os 
import glob
import argparse 


# Fixed preprocessing hyperparameters -- DO NOT modify.
# Target length (pixels) of the longer side of the cropped object after resizing.
RESIZE_H = 100
RESIZE_W = 100
# Side lengths (pixels) of the canvas the resized object is pasted onto.
H = 128
W = 128


# Use this function to preprocess data
def center_crop_img(tgt_img_path, mask_img_path):
    """Cut an object out of an image with its mask and center it on a canvas.

    Pixels where the mask is 0 are painted white, image and mask are cropped
    to the bounding box of the mask's foreground (value 255), resized so that
    the longer side equals ``RESIZE_W``, and pasted in the middle of a
    ``W`` x ``H`` canvas (white for the image, black for the mask).

    Args:
        tgt_img_path: Path of the target image; loaded as RGB.
        mask_img_path: Path of the foreground mask; loaded as RGB. A mask
            whose maximum value is <= 1 is rescaled to the [0, 255] range.

    Returns:
        Tuple ``(img_canvas, mask_canvas)`` of RGB ``PIL.Image`` objects.

    Raises:
        ValueError: If the mask has no pixel equal to 255.
        AssertionError: If the foreground bounding box has zero width or
            height, or the resized object does not fit on the canvas.
    """
    np_tgt_img = np.array(Image.open(tgt_img_path).convert("RGB"))

    # Mask is processed as [0, 255] values; some masks are stored as [0, 1].
    # Convert to an array once instead of once per use.
    np_mask_img = np.array(Image.open(mask_img_path).convert("RGB"))  # Foreground mask
    if np_mask_img.max() <= 1:
        np_mask_img = np_mask_img * 255
    assert np_mask_img.max() <= 255 and np_mask_img.min() >= 0, f"{np_mask_img.min()}, {np_mask_img.max()}"
    np_tgt_img[np_mask_img == 0] = 255

    # Bounding box of the foreground; the channel index is not needed.
    y, x, _ = np.where(np_mask_img == 255)
    if x.size == 0:
        # Fail with a readable message instead of numpy's empty-reduction error.
        raise ValueError(f"Mask has no foreground (255) pixels: {mask_img_path}")
    x1, x2, y1, y2 = x.min(), x.max(), y.min(), y.max()

    # NOTE: PIL crop boxes exclude the right/lower edge, so column x2 and row y2
    # are not part of the crop. Left as-is so the preprocessing output is unchanged.
    bbox = (x1, y1, x2, y2)
    crop_img = Image.fromarray(np_tgt_img).crop(bbox)
    cropped_mask = Image.fromarray(np_mask_img).crop(bbox)
    w = x2 - x1
    assert w > 0, f"{x2} - {x1} = {w}"
    h = y2 - y1
    assert h > 0, f"{y2} - {y1} = {h}"

    # Resize image with respect to max length
    max_length = max(w, h)
    ratio = RESIZE_W / max_length
    resized_w, resized_h = round(w * ratio), round(h * ratio)  # Avoid float error
    assert resized_h == RESIZE_H or resized_w == RESIZE_W

    resized_img = crop_img.resize((resized_w, resized_h))
    resized_object_mask = cropped_mask.resize((resized_w, resized_h))

    # PIL sizes are (width, height); the original passed (H, W), which only
    # worked because the canvas is square.
    img_canvas = Image.new("RGB", (W, H), (255, 255, 255))
    mask_canvas = Image.new("RGB", (W, H), (0, 0, 0))

    # Offsets that center the resized object on the canvas.
    pos_w = abs(resized_w - W) // 2
    pos_h = abs(resized_h - H) // 2
    assert pos_w + resized_w <= W and pos_h + resized_h <= H

    img_canvas.paste(resized_img, (pos_w, pos_h))
    mask_canvas.paste(resized_object_mask, (pos_w, pos_h))

    return img_canvas, mask_canvas



In [None]:
# Load the metadata file and report how many entries it holds.
json_file = 'metadata.json'
with open(json_file) as f:
    data = json.load(f)
print(len(data))

In [None]:
# Preview a few entries instead of dumping the whole metadata dict into the cell output.
dict(list(data.items())[:3])

In [None]:
# NOTE(review): `ouput_dir` looks like a typo for `output_dir`. Nothing else in the visible
# notebook reads it, and the processing call further down writes to 'center_cropped_2' instead.
ouput_dir = 'data/center_cropped'
os.makedirs(ouput_dir, exist_ok=True)

Convert each object mask to a 3-channel mask.


Directory Structure:

```
data/
    images/
      obj_28_metal_bucket/
        obj_28_metal_bucket_002_CA2.png
        obj_28_metal_bucket_003_CA2.png  
        ...
    masks/
      obj_28_metal_bucket/
        obj_28_metal_bucket_CA2.png
      ...
```

In [None]:
from multiprocessing import Pool
from functools import partial
def process_image(k, output_dir, data):
    """Center-crop one metadata entry and save the resulting image and mask as PNG.

    Args:
        k: Metadata key, e.g. 'obj_28_metal_bucket_010_NA3'. The last 8
            characters are stripped to get the object name and the last 3
            characters are used as the viewpoint id.
        output_dir: Root folder; files are written to
            ``<output_dir>/images/<object>/<k>.png`` and
            ``<output_dir>/masks/<object>/<object>_<viewpoint>.png``.
        data: Mapping from key to a dict with 'tgt_img_path' and 'mask_path'.
    """
    entry = data[k]
    img, mask = center_crop_img(entry['tgt_img_path'], entry['mask_path'])

    # key example: 'obj_28_metal_bucket_010_NA3'
    object_name, viewpoint_id = k[:-8], k[-3:]

    image_dir = os.path.join(output_dir, 'images', object_name)
    mask_dir = os.path.join(output_dir, 'masks', object_name)
    for folder in (image_dir, mask_dir):
        if not os.path.exists(folder):
            os.makedirs(folder, exist_ok=True)

    output_img_path = os.path.join(image_dir, f"{k}.png")
    output_mask_path = os.path.join(mask_dir, f"{object_name}_{viewpoint_id}.png")

    img.save(output_img_path)
    mask.save(output_mask_path)
    for saved_path in (output_img_path, output_mask_path):
        print(f"Saved {saved_path}")
  

# Worker-side shim used by process_images.
def _process_entry(item, output_dir):
    """Unpack one ``(key, entry)`` pair and run ``process_image`` on just that entry."""
    key, entry = item
    process_image(key, output_dir, {key: entry})


# Create a pool of workers
def process_images(json_path, output_dir):
    """Center-crop every entry of a metadata JSON file using a process pool.

    Args:
        json_path: Path to a JSON file mapping keys to entries that
            ``process_image`` understands.
        output_dir: Root folder for the cropped images and masks; created
            if it does not exist.
    """
    with open(json_path) as f:
        data = json.load(f)
    os.makedirs(output_dir, exist_ok=True)

    # The callable travels with every task sent to a worker, so bind only
    # `output_dir` to it and ship a single metadata entry per task instead of
    # binding (and re-serializing) the whole metadata dict each time.
    worker = partial(_process_entry, output_dir=output_dir)
    with Pool() as pool:
        # Results are all None; iterate only to drive the progress bar.
        for _ in tqdm(pool.imap_unordered(worker, data.items()), total=len(data)):
            pass

In [None]:
# Center-crop every metadata entry into 'center_cropped_2/images/<object>/' and 'center_cropped_2/masks/<object>/'.
process_images(json_file, 'center_cropped_2')

Convert each object mask to a 3-channel mask.


Directory Structure:

```
data/
    images/
      obj_28_metal_bucket/
        obj_28_metal_bucket_002_CA2.png
        obj_28_metal_bucket_003_CA2.png  
        ...
    masks/
      obj_28_metal_bucket/
        obj_28_metal_bucket_CA2.png
      ...
```

In [None]:
# Make jsonl file, which contains the path to the images and masks
def make_jsonl(output_path, data_dir):
    """Write a JSON Lines index pairing every image with its viewpoint mask.

    Walks ``<data_dir>/images/<object>/`` and, for each image file, derives
    the mask file name by removing the 4 characters that precede the last 7
    (e.g. 'obj_28_metal_bucket_002_CA2.png' -> 'obj_28_metal_bucket_CA2.png').
    The mask is looked up under ``<data_dir>/masks/<object>/``; its existence
    is not checked.

    Objects and images are visited in sorted order so the output is
    reproducible. Hidden entries (names starting with '.'), non-directories
    at the object level and non-files at the image level are skipped.

    Args:
        output_path: Destination ``.jsonl`` file; overwritten if present.
        data_dir: Folder containing the ``images`` and ``masks`` subfolders.

    Raises:
        FileNotFoundError: If ``<data_dir>/images`` does not exist. The
            output file is not touched in that case.
    """
    images_dir = os.path.join(data_dir, 'images')
    masks_dir = os.path.join(data_dir, 'masks')

    # Build all records first so a failure while listing leaves no truncated output.
    lines = []
    for object_name in sorted(os.listdir(images_dir)):
        object_dir = os.path.join(images_dir, object_name)
        if object_name.startswith('.') or not os.path.isdir(object_dir):
            continue  # e.g. .DS_Store or .ipynb_checkpoints
        for image_name in sorted(os.listdir(object_dir)):
            image_path = os.path.join(object_dir, image_name)
            if image_name.startswith('.') or not os.path.isfile(image_path):
                continue
            mask_name = image_name[:-11] + image_name[-7:]
            mask_path = os.path.join(masks_dir, object_name, mask_name)
            lines.append(json.dumps({'image_path': image_path, 'mask_path': mask_path}) + '\n')

    with open(output_path, 'w') as f:
        f.writelines(lines)


In [None]:
# Index the cropped images and masks under 'center_cropped_2' as one JSON object per line.
make_jsonl('center_cropped_2.jsonl', 'center_cropped_2')

In [None]:
# Path of the JSON Lines index of image/mask pairs consumed by the hint-rendering loop.
img_mask_json_path = 'center_cropped_2.jsonl'

In [None]:
import os
from dataclasses import dataclass
from typing import Optional

import imageio
import numpy as np
import cv2
import simple_parsing


# Options describing one radiance-hint rendering job.
@dataclass
class Args:
    # Path to the image, to generate hints for.
    img: str
    # Seed for the generation
    seed: int = 3407
    # Field of view for the mesh reconstruction, none for auto estimation from the image
    fov: Optional[float] = None

    # Path to the mask for the image
    mask_path: Optional[str] = None
    # Use SAM for background removal
    use_sam: bool = True
    # Mask threshold for foreground object extraction
    mask_threshold: float = 25.0

    # Power of the point light
    power: float = 1200.0
    # Use GPU for radiance hints rendering
    use_gpu_for_rendering: bool = True

    # X position of the point light
    pl_x: float = 1.0
    # Y position of the point light
    pl_y: float = 1.0
    # Z position of the point light
    pl_z: float = 1.0
    

def generate_hint(
    img,
    seed=3407,
    fov=None,
    mask_path=None,
    use_sam=True,
    mask_threshold=25.0,
    power=1200.0,
    use_gpu_for_rendering=True,
    pl_x=1.0,
    pl_y=1.0,
    pl_z=1.0,
    output_dir='radiance_hints',
):
    """Render radiance hints for one image lit by a single point light.

    The image (and the mask, when given) is resized to 512x512, a mesh is
    reconstructed from them, and the hints are rendered into
    ``<output_dir>/<image name up to first '.'>/pl-<x>-<y>-<z>-<power>/``.

    Args:
        img: Path to the input image.
        seed: Stored on ``Args``; not read anywhere else in this function.
        fov: Field of view passed to the mesh reconstruction (``None`` is
            forwarded as-is).
        mask_path: Optional mask file. If it has three dimensions only the
            last channel is used. Without it the mask comes from ``rm_bg``.
        use_sam: Forwarded to ``rm_bg``; only relevant without ``mask_path``.
        mask_threshold: Forwarded to the mesh reconstruction.
        power: Point light power, forwarded to the hint renderer.
        use_gpu_for_rendering: Forwarded to the hint renderer as ``use_gpu``.
        pl_x, pl_y, pl_z: Position of the point light.
        output_dir: Root folder for the rendered hints.

    Returns:
        None. Results are written below ``output_dir``.
    """
    args = Args(
        img=img, seed=seed, fov=fov,
        mask_path=mask_path, use_sam=use_sam, mask_threshold=mask_threshold,
        power=power, use_gpu_for_rendering=use_gpu_for_rendering,
        pl_x=pl_x, pl_y=pl_y, pl_z=pl_z,
    )

    # Project imports stay local to the function and in their original order.
    from DiLightNet.demo.mesh_recon import mesh_reconstruction
    from DiLightNet.demo.relighting_gen import relighting_gen
    from DiLightNet.demo.render_hints import render_hint_images, render_bg_images
    from DiLightNet.demo.rm_bg import rm_bg

    # Input image at the working resolution.
    target_size = (512, 512)
    input_image = cv2.resize(imageio.v3.imread(args.img), target_size)

    # Foreground mask: estimated from the image unless a mask file was given.
    if not args.mask_path:
        _, mask = rm_bg(input_image, use_sam=args.use_sam)
    else:
        mask = imageio.v3.imread(args.mask_path)
        if mask.ndim == 3:
            mask = mask[..., -1]
        mask = cv2.resize(mask, target_size)
    # Expand the single-channel mask to three identical channels.
    mask = mask[..., None].repeat(3, axis=-1)

    # One point light per call.
    point_lights = [(args.pl_x, args.pl_y, args.pl_z)]

    # Folder for this image / lighting combination.
    image_id = os.path.basename(args.img).split(".")[0]
    light_tag = f"pl-{args.pl_x}-{args.pl_y}-{args.pl_z}-{args.power}"
    output_folder = os.path.join(output_dir, image_id, light_tag)
    os.makedirs(output_folder, exist_ok=True)

    print("Rendering radiance hints")
    # Mesh reconstruction and fov estimation for hints rendering
    mesh, fov = mesh_reconstruction(input_image, mask, False, args.fov, args.mask_threshold)
    print(f"Mesh reconstructed with fov: {fov}")
    render_hint_images(
        mesh, fov, point_lights, args.power,
        output_folder=output_folder, use_gpu=args.use_gpu_for_rendering,
    )
    print(f"Radiance hints rendered to {output_folder}")

In [None]:
# Read every JSON Lines record first, then render radiance hints for each image/mask pair.
with open(img_mask_json_path) as f:
    data = [json.loads(line) for line in f]

for d in tqdm(data):
    generate_hint(d['image_path'], mask_path=d['mask_path'], output_dir='radiance_hints')