# Utils

In [None]:
import torch
from torchvision.transforms import v2
from PIL import Image
import matplotlib.pyplot as plt
from lib.data.metainfo import MetaInfo

def plot_images(images, size: int = 4):
    """Display one image tensor, or a list of image tensors side by side.

    Args:
        images: Either a single CxHxW tensor or a list of CxHxW tensors.
            Each tensor is permuted to HxWxC, detached and moved to the CPU
            before being passed to ``imshow``.
        size: Width and height, in inches, of the whole figure.
    """
    if isinstance(images, list):
        # squeeze=False always yields a 2-D array of Axes; without it a
        # one-element list gets back a single (non-iterable) Axes object.
        _, axes = plt.subplots(
            1, len(images), figsize=(size, size), squeeze=False
        )
        for ax, image in zip(axes[0], images):
            ax.imshow(image.permute(1, 2, 0).detach().cpu().numpy())
            ax.axis("off")  # Turn off axis
        plt.show()
    else:
        plt.figure(figsize=(size, size))
        plt.imshow(images.permute(1, 2, 0).detach().cpu().numpy())
        plt.show()

# Different Sketch Types

In [None]:
import cv2 as cv
import numpy as np
from dataclasses import dataclass


@dataclass
class ToSketch(object):
    """Convert an image into a black-on-white Canny edge map.

    The input is passed directly to ``cv.Canny``. It must therefore be a
    NumPy array OpenCV accepts, which per the OpenCV docs is an 8-bit HxW or
    HxWxC image, not a CxHxW tensor. ``cv.Canny`` returns a single-channel
    map. It is inverted so edges become black (0) on a white background,
    then replicated into an HxWx3 array.
    """

    t_lower: int = 100  # lower hysteresis threshold (threshold1)
    t_upper: int = 150  # upper hysteresis threshold (threshold2)
    aperture_size: int = 3  # Sobel aperture; OpenCV accepts 3, 5 or 7
    l2_gradient: bool = True  # use the L2 norm for the gradient magnitude

    def __call__(self, image):
        edge = cv.Canny(
            image,
            threshold1=self.t_lower,
            threshold2=self.t_upper,
            apertureSize=self.aperture_size,
            L2gradient=self.l2_gradient,
        )
        # Canny marks edges as 255; invert to get dark strokes on white.
        edge = cv.bitwise_not(edge)
        # Replicate the single channel into an HxWx3 image.
        return np.stack((edge,) * 3, axis=-1)


@dataclass
class SketchDilation(object):
    """Thicken the dark strokes of a white-background sketch.

    The sketch is inverted so strokes become high values. An all-ones
    ``kernal_size`` x ``kernal_size`` convolution then sums each window
    across all three channels. The result is clamped to 1 and inverted back.
    Before the convolution the image is padded by ``self.padding`` pixels
    per side. Afterwards it is resized back to the original height and
    width, so thick strokes near the border are squeezed into frame rather
    than cropped.

    Expects a float image whose background is 1.0 (values presumably in
    [0, 1]; confirm for other inputs). The misspelled parameter name
    ``kernal_size`` is kept for compatibility with existing callers.

    NOTE: the class defines no dataclass fields, so ``@dataclass`` only adds
    ``__repr__``/``__eq__`` (all instances compare equal).
    """

    def __init__(self, kernal_size: int = 1):
        assert kernal_size >= 1
        self.conv = torch.nn.Conv2d(
            in_channels=3,
            out_channels=3,
            # Use the constructor argument, not a same-named global.
            kernel_size=kernal_size,
            padding="same",
            stride=1,
            bias=False,
        )
        # Fixed all-ones filter that is never trained: no gradient tracking.
        self.conv.weight = torch.nn.Parameter(
            torch.ones_like(self.conv.weight), requires_grad=False
        )
        # Extra border (per side) so dilated strokes are not clipped.
        self.padding = (kernal_size - 1) * 2

    def __call__(self, image):
        """Return the dilated image, resized to the input's height/width.

        Accepts a 3xHxW tensor or a batch of them (...x3xHxW); the spatial
        size is read from the last two dimensions.
        """
        H, W = image.shape[-2:]
        img = 1.0 - image  # strokes -> high values, background -> 0
        img = v2.functional.pad(img, padding=self.padding)  # 3x(H+2P)x(W+2P)
        img = self.conv(img)
        img = 1.0 - img.clamp(max=1.0)
        return v2.functional.resize(img, (H, W), antialias=True)

# Compare dilation strengths (kernel sizes 1, 3, 5, 7, 9) on one sketch per
# object. `metainfo`, `base_transform`, `obj_id` and `images` are reused by
# later cells.
# NOTE(review): machine-specific absolute path; adjust to your data location.
metainfo = MetaInfo(data_dir="/home/borth/sketch2shape/data/shapenet_chair_4096")
base_transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

for obj_id in range(74, 120):
    sketch = metainfo.load_sketch(metainfo.obj_ids[obj_id], "00011")
    # Convert once per object; SketchDilation builds a new tensor from its
    # input rather than modifying it, so the result can be shared.
    sketch_tensor = base_transform(sketch)

    images = []
    for kernel_size in range(1, 10, 2):
        dilation = SketchDilation(kernal_size=kernel_size)
        image = dilation(sketch_tensor)
        images.append(image)

    plot_images(images, size=32)

In [None]:
@dataclass
class ToSilhouette(object):
    """Turn every non-background pixel black, producing a filled silhouette.

    A pixel is treated as surface (foreground) when its values summed over
    dim 0 (the channels of a CxHxW image) fall below ``threshold``. For a
    3-channel image scaled to [0, 1], the default 2.95 means "not almost
    white". Surface pixels are set to 0 in every channel.

    Unlike earlier versions, the input tensor is left untouched and a new
    tensor is returned.
    """

    threshold: float = 2.95  # channel-sum below which a pixel is surface

    def __call__(self, image):
        surface_mask = image.sum(0) < self.threshold
        # Broadcast the HxW mask over the channel dimension.
        return image.masked_fill(surface_mask.unsqueeze(0), 0.0)

@dataclass
class ToGrayScale(object):
    """Replace every channel by the unweighted mean over dim 0.

    For a 3xHxW input the result is a 3xHxW tensor whose three channels are
    identical copies of the per-pixel channel mean.
    """

    def __call__(self, image):
        gray = image.mean(dim=0)
        # Repeat the mean along a new leading dimension and materialise it.
        return gray.expand(3, *gray.shape).contiguous()

# Load `normal` for the object indexed by `obj_id` and display a grayscale
# version of it.
grayscale = ToGrayScale()
# NOTE(review): `obj_id` is whatever the dilation loop left behind (its last
# value, 119, if the loop ran to completion) — set it explicitly if a specific
# object is intended.
normal = metainfo.load_normal(metainfo.obj_ids[obj_id], "00011")
# NOTE(review): `jitter` is not defined anywhere in the visible notebook, so
# this line raises NameError unless another cell defines it (presumably a
# colour-jitter transform) — confirm, or switch to the commented-out line.
image = grayscale(jitter(base_transform(normal)))
# image = grayscale(base_transform(normal))
plot_images(image)

# Stack multiple sketches on top of each other

In [None]:
# Overlay three dilated sketches (images[3], [2], [1], i.e. kernel sizes
# 7, 5, 3 from the dilation loop), rotated by 0, 1 and 2 degrees.
t = 4  # NOTE(review): not used in this cell
overlaps = []
for degree, image in enumerate(images[1:4][::-1]):
    # Invert so strokes are positive and add up where the sketches overlap;
    # rotate's default fill of 0 then matches the inverted background.
    img = 1 - image
    overlaps.append(v2.functional.rotate(img, degree))
# Clamp the summed strokes to 1 so overlapping strokes stay in [0, 1] after
# inverting back (otherwise values go negative and imshow clips them).
overlay = 1 - torch.stack(overlaps).sum(0).clamp(max=1.0)
plot_images(overlay)

# Sharpness

In [None]:
# Show the first dilated sketch at two sharpness factors: 0 (blurred) and
# 100 (strongly sharpened); a factor of 1 would leave the image unchanged.
sharpness_images = [
    v2.functional.adjust_sharpness(images[0], factor) for factor in (0, 100)
]
plot_images(sharpness_images, size=16)

# Resize and Pad

In [None]:
# Downsample the last dilated sketch (images[4], kernel size 9) to 64x64.
resized_image = v2.functional.resize(
    images[4],
    size=(64, 64),
    antialias=True,
)
plot_images(resized_image)

In [None]:
# Add a 5-pixel white border (fill=1.0) around the last dilated sketch.
# The dilation loop yields only 5 images (kernel sizes 1, 3, 5, 7, 9), so
# index from the end instead of using a hard-coded index past the list.
pad_img = v2.functional.pad(images[-1], padding=5, fill=1.0)
plot_images(pad_img)