In [22]:

from PIL import Image
import torchvision.transforms.functional as TF
import torchvision.transforms.functional as TF

import torchvision.transforms as transforms
import torchvision.utils as vutils
import cv2
import tqdm
from torchvision import datasets, transforms
from torchvision import transforms

from torch.utils.data import Dataset
import torch.nn.functional as F
import numpy as np
from io import BytesIO
from PIL import Image, ImageOps


In [23]:
# Experiment table -- keep in sync with the other notebook.
# Maps experiment name -> (source image root, model checkpoint path); None means "not applicable".
experiments = dict(
    fantasy=("./data/fantasy", "checkpoints/UNet128_2025-09-12_08-24-21_fantasy.pth"),
    cats_and_dogs=("./data/cats_and_dogs/dataset/training_set", "checkpoints/UNet128_2025-09-08_17-57-30_cats_and_dogs.pth"),
    # NOTE(review): unlike the other entries this checkpoint has no "checkpoints/" prefix -- confirm the path.
    pokemon=("./data/pokemon_archive", "UNet128_2025-09-12_11-03-12_pokemon.pth"),
    dragon=("utils/dataset_fetcher/datasets/", "checkpoints/UNet128_2025-09-12_19-47-15_dragon.pth"),
    yugioh=("data/yugioh/YuGiOhImages", None),
    mnist=(None, None),  # no local root: the export loop downloads MNIST through torchvision
)
# HR: side length (px) of the exported square images. LR: presumably the low-res size used elsewhere.
HR, LR = 256, 64

In [None]:
class CenterCropSquare:
    """Transform that crops a PIL image to its largest centred square.

    The square's side is the shorter image side. When the leftover margin is
    odd, the extra pixel is removed from the right/bottom edge (integer division).
    """

    def __call__(self, img: "Image.Image"):
        w, h = img.size
        side = min(w, h)
        left = (w - side) // 2
        top = (h - side) // 2
        # PIL crop boxes are (left, upper, right, lower) with exclusive right/lower.
        return img.crop((left, top, left + side, top + side))


class PadToSquare:
    def __init__(self, mode="both", color: tuple|None = None, size=None):
        assert mode in ['both', 'w', 'h'], "Mode must be 'both', 'width', or 'height'"
        self.mode = mode
        self.color = color
        self.size = size
    def __call__(self, img: Image):
        w, h = img.size
        np_img = np.array(img)
        if not self.color:
            avg_color = tuple(np_img.mean(axis=(0, 1)).astype(int))
        clr = self.color or avg_color
        size = self.size or max(h,w)
        pad_w = size - w if (size > w and self.mode != 'h') else 0
        pad_h = size - h if (size > h and self.mode != 'w') else 0
        pad_l = pad_w // 2
        pad_r = pad_w - pad_l
        pad_t = pad_h // 2
        pad_b = pad_h - pad_t
        padding = (pad_l, pad_t, pad_r, pad_b) 
        return ImageOps.expand(img, padding, fill=clr)  # fill=0 for black padding


class SaliencyTopFractionSquareOrFull:
    """Crop the square (side = shorter image side) that holds the most salient pixels.

    Saliency comes from OpenCV's spectral-residual static saliency. The ``k``
    most salient pixels are located; if their bounding box fits inside a
    ``side x side`` square, the square centred on that box (clamped to the
    image bounds) is cropped from the input. Otherwise, or when saliency
    computation fails, the input image is returned unchanged.
    """

    def __init__(self, top_fraction=0.01, k=None):
        """
        Args:
            top_fraction: fraction of image pixels treated as the top salient
                points; used when ``k`` is not given.
            k: absolute number of top salient points; overrides ``top_fraction``.
                For images with 231480 pixels, k=1000 (1000 / 231480) seemed good.
        """
        self.top_fraction = top_fraction
        self.saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
        self.k = k

    def __call__(self, img: "Image.Image"):
        # OpenCV works in BGR channel order.
        img_cv = np.array(img.convert('RGB'))[..., ::-1]
        H, W = img_cv.shape[:2]
        side = min(H, W)

        success, saliency_map = self.saliency.computeSaliency(img_cv)
        if not success or saliency_map.size == 0:
            return img

        # Pick the top-k pixels; clamp k because argpartition raises when k > size.
        flat = saliency_map.flatten()
        k = self.k or max(1, int(self.top_fraction * flat.size))
        k = min(k, flat.size)
        topk_idx = np.argpartition(flat, -k)[-k:]
        ys, xs = np.unravel_index(topk_idx, saliency_map.shape)

        x_min, x_max = xs.min(), xs.max()
        y_min, y_max = ys.min(), ys.max()
        if x_max - x_min + 1 > side or y_max - y_min + 1 > side:
            # Can't fit all points in the square -> return full image
            return img

        # Centre the square on the top-k bounding box, then clamp inside the image.
        cx = (x_min + x_max) // 2
        cy = (y_min + y_max) // 2
        left = int(max(0, min(cx - side // 2, W - side)))
        top = int(max(0, min(cy - side // 2, H - side)))
        # Crop the original image rather than the BGR array. The pixels are the
        # same for RGB input, and the output keeps the input's mode, just as the
        # early-return branches above do.
        return img.crop((left, top, left + side, top + side))

def resize_width_keep_aspect(img, target_width):
    """Resize a PIL image to ``target_width`` pixels wide, scaling the height by the same factor.

    The new height is truncated to an int and clamped to at least 1 pixel, so
    very wide images don't produce a zero-height size (which PIL rejects).
    Uses bicubic resampling.
    """
    w, h = img.size
    new_height = max(1, int(h * (target_width / w)))
    return img.resize((target_width, new_height), resample=Image.BICUBIC)

def crop_top_square(img, top_offset=3, size=64):
    """Cut a ``size`` x ``size`` square flush with the left edge, ``top_offset`` rows down.

    A thin wrapper around ``torchvision.transforms.functional.crop``; the
    handling of out-of-range regions follows that function.
    """
    region = dict(top=top_offset, left=0, height=size, width=size)
    return TF.crop(img, **region)


In [None]:
import os
import random
from torchvision import datasets
from torchvision.transforms import ToPILImage
from PIL import Image
import numpy as np
from tqdm import tqdm
def mnist_tiler(out_dir="data/256x256/mnist_grid/tiles", seed=None):
    """Pack shuffled MNIST training digits into 8x8 grids on 256x256 RGB canvases.

    Each canvas gets a random solid background colour. Each digit is pasted
    (converted to RGB, so as an opaque square) centred in its 32x32 cell.
    Canvases are written as ``{out_dir}/NNNN.webp`` (1-based, zero-padded).
    Digits left over after the last full batch of 64 are dropped.

    Args:
        out_dir: directory for the tiles (created if missing).
        seed: optional seed for a private RNG, making the shuffle and colours
            reproducible. ``None`` keeps using the global ``random`` state.
    """
    mnist = datasets.MNIST(root="./data", train=True, download=True)
    rng = random if seed is None else random.Random(seed)

    # Layout parameters
    canvas_size = 256
    grid_size = 8
    per_canvas = grid_size * grid_size
    cell_size = canvas_size // grid_size
    digit_size = 28
    padding = (cell_size - digit_size) // 2  # centres a digit inside its cell

    # Without a transform the dataset yields (image, label); only images are used.
    images = [img for img, _ in mnist]
    rng.shuffle(images)

    os.makedirs(out_dir, exist_ok=True)

    # Range over full batches only, so the progress bar total matches the work done.
    for i in tqdm(range(0, len(images) - per_canvas + 1, per_canvas)):
        batch = images[i:i + per_canvas]

        bg_color = tuple(rng.randint(0, 255) for _ in range(3))
        canvas = Image.new("RGB", (canvas_size, canvas_size), bg_color)

        for idx, digit in enumerate(batch):
            row, col = divmod(idx, grid_size)
            x = col * cell_size + padding
            y = row * cell_size + padding
            canvas.paste(digit.convert("RGB"), (x, y))

        filename = os.path.join(out_dir, f"{i // per_canvas + 1:04d}.webp")
        canvas.save(filename, format="WEBP")
mnist_tiler()

100%|█████████▉| 937/938 [00:06<00:00, 136.51it/s]


In [None]:
# MNIST: pad each single-band digit (1-tuple fill, grey level 64) onto an
# HR x HR canvas. The Resize is a no-op for inputs no larger than HR.
mnist_transform = transforms.Compose([
    PadToSquare(size=HR, color=(64,)),
    transforms.Resize((HR, HR)),
])

# Yu-Gi-Oh cards: scale to HR wide (keeping aspect), then take the HR x HR
# square starting 3 px below the top edge.
yugioh_transform = transforms.Compose([
    transforms.Lambda(lambda img: resize_width_keep_aspect(img, HR)),
    transforms.Lambda(lambda img: crop_top_square(img, top_offset=3, size=HR)),
])

dragon_transform = transforms.Compose([
    SaliencyTopFractionSquareOrFull(0.01),  # dragon heads without much background
    PadToSquare(),                          # pad (not crop) to a square with the mean colour
    transforms.Resize((HR, HR)),            # resize to HR x HR
])

# Default: pad to a square with the mean image colour, then resize.
hr_transform = transforms.Compose([
    PadToSquare(),                          # pad (not crop) to a square with the mean colour
    transforms.Resize((HR, HR)),            # resize to HR x HR
])

import os

from functools import lru_cache
from tqdm import tqdm
@lru_cache(maxsize=10)
def show_name(l):
    """Print ``l`` only the first time it is seen among the 10 most recent distinct values.

    The LRU cache memoises the ``None`` result, so a repeat call with a
    still-cached argument skips the print. ``l`` must be hashable.
    """
    print(l)

def save_a_dataset(path, ds, class_name='img', quality=50):
    """Write every image of ``ds`` as WEBP to ``{path}/{class_name}/NNNNNN.webp``.

    The single class sub-folder presumably lets the output be re-read with
    ``torchvision.datasets.ImageFolder``. Labels in ``ds`` are ignored.

    If ``path`` already exists, the whole dataset is skipped.
    NOTE(review): an interrupted run leaves a partial directory that is then
    also skipped; delete it to regenerate.

    Args:
        path: output root directory.
        ds: iterable of ``(image, label)`` pairs; each image must support
            PIL-style ``save(path, format=..., quality=...)``.
        class_name: name of the single sub-folder (default ``'img'``).
        quality: WEBP quality passed to ``save`` (default 50).
    """
    if os.path.exists(path):
        print(f"already did {path}")
        return
    # Loop-invariant: create the output folder once, not once per image.
    class_dir = os.path.join(path, class_name)
    os.makedirs(class_dir, exist_ok=True)
    for idx, (img, _label) in enumerate(tqdm(ds)):
        out_path = os.path.join(class_dir, f"{idx:06d}.webp")
        img.save(out_path, format="WEBP", quality=quality)

# Export every experiment's source images, transformed to HR x HR, under data/{HR}x{HR}/<name>.
for k, (src_root, _checkpoint) in experiments.items():
    print(k, src_root)
    if k == 'yugioh':
        src_ds = datasets.ImageFolder(root=src_root, transform=yugioh_transform)
    elif k == 'mnist':
        # src_root is None for MNIST; torchvision downloads it into ./data.
        src_ds = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)
    elif k == 'dragon':
        src_ds = datasets.ImageFolder(root=src_root, transform=dragon_transform)
    else:
        src_ds = datasets.ImageFolder(root=src_root, transform=hr_transform)
    # Distinct name so the source root is not shadowed by the output root.
    out_root = f"data/{HR}x{HR}/{k}"
    save_a_dataset(out_root, src_ds)

fantasy ./data/fantasy


100%|██████████| 12814/12814 [18:08<00:00, 11.77it/s] 


cats_and_dogs ./data/cats_and_dogs/dataset/training_set


100%|██████████| 8000/8000 [01:15<00:00, 105.86it/s]


pokemon ./data/pokemon_archive


100%|██████████| 2503/2503 [00:20<00:00, 122.88it/s]


dragon utils/dataset_fetcher/datasets/


100%|██████████| 10287/10287 [02:21<00:00, 72.53it/s]


yugioh data/yugioh/YuGiOhImages


100%|██████████| 13878/13878 [01:59<00:00, 116.20it/s]


mnist None


100%|██████████| 60000/60000 [02:41<00:00, 372.14it/s]
