In [None]:
from lib.data.metainfo import MetaInfo
from lib.utils.config import load_config
from pathlib import Path
from tqdm import tqdm

# Load the sketch-optimization config with the ShapeNet chair dataset override,
# and also set the dataset name on the loaded config explicitly.
cfg = load_config("optimize_sketch", ["dataset=shapenet_chair_4096"])
cfg.data.dataset_name = "shapenet_chair_4096"
metainfo = MetaInfo(cfg.data.data_dir)
# Jupyter only echoes the *last* expression of a cell; this sanity check sat
# mid-cell, so its value was silently discarded. Print it explicitly instead.
print(len(metainfo.obj_ids), metainfo.data_dir)
# Directory containing the raw hand-drawn sketch images.
source = Path(cfg.data.data_dir) / "hand_drawn"

In [None]:
from PIL import Image
import numpy as np
from torchvision.transforms import v2
import torch

padding = 0.05  # margin around the squared sketch, as a fraction of its side length
transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
final_transform = v2.Compose([v2.Resize((256, 256)), v2.ToPILImage()])
shapes_path = Path(cfg.data.data_dir) / "shapes"

images = []  # NOTE(review): never filled in this cell; kept in case a later cell reads it


def _crop_to_content(sketch):
    """Return *sketch* cropped to the tight box around its ink pixels, or None if blank.

    A pixel counts as ink when the sum over its channels is exactly 0.0, i.e.
    every channel is pure black.
    NOTE(review): anti-aliased grey strokes are ignored. For RGBA input an
    opaque black pixel sums to 1.0 (alpha) and is never detected. Confirm the
    hand-drawn files are RGB/greyscale.
    """
    rows, cols = torch.nonzero(sketch.sum(0) == 0.0, as_tuple=True)
    if rows.numel() == 0:
        return None
    # max() is an inclusive index, so +1 keeps the last ink row/column in the slice.
    top, bottom = int(rows.min()), int(rows.max()) + 1
    left, right = int(cols.min()), int(cols.max()) + 1
    return sketch[:, top:bottom, left:right]


def _pad_to_square(image, margin_ratio):
    """Pad a (C, H, W) tensor with white (1.0) to a centred square plus a margin.

    The margin added on every side is ``int(side * margin_ratio)``, where
    ``side`` is the larger of H and W.
    """
    _, height, width = image.shape
    side = max(height, width)
    extra_w, extra_h = side - width, side - height
    margin = int(side * margin_ratio)
    # Give the odd leftover pixel to the right/bottom so the result is exactly square
    # (otherwise the later 256x256 resize would stretch the sketch).
    left = extra_w // 2 + margin
    right = extra_w - extra_w // 2 + margin
    top = extra_h // 2 + margin
    bottom = extra_h - extra_h // 2 + margin
    return torch.nn.functional.pad(image, (left, right, top, bottom), value=1.0)


# Sorted for a deterministic processing order; skip sub-directories and other non-files.
for path in tqdm(sorted(p for p in Path(source).iterdir() if p.is_file())):
    obj_id = path.stem
    # Close the file handle as soon as the pixels are decoded.
    with Image.open(path) as image:
        sketch = transform(image)

    bbox = _crop_to_content(sketch)
    if bbox is None:
        # No ink found: report the object and fall back to the uncropped sketch.
        print(obj_id)
        bbox = sketch

    hand_drawn = final_transform(_pad_to_square(bbox, padding))
    out_path = shapes_path / obj_id / "hand_drawn_sketch" / "00001.png"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    hand_drawn.save(out_path)