In [1]:
import torch

In [2]:
# Select the compute device once: use the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
from transformers import AutoProcessor, AutoModel
# Checkpoint identifier handed to `from_pretrained` for both the model and its processor.
path_siglip = "google/siglip-so400m-patch14-384"

# Module-level SigLIP model (moved to `device`) and the processor that builds its inputs;
# both are read as globals by `rank_images`.
model_siglip     = AutoModel.from_pretrained(path_siglip).to(device)
processor_siglip = AutoProcessor.from_pretrained(path_siglip)

  from .autonotebook import tqdm as notebook_tqdm
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
def rank_images(prompt: str, image_list, batch_size: int = 4):
    """
    Score every image against a text prompt with SigLIP and pick the best match.

    Uses the module-level `processor_siglip`, `model_siglip` and `device`.

    prompt       : Text instruction that the images should follow
    image_list   : Non-empty list of images in a form accepted by `processor_siglip`
    batch_size   : Number of images per forward pass; adjust according to GPU memory
    return       : (index into image_list of the highest-scoring image,
                    list with one sigmoid score in [0, 1] per image, in input order)
    raises       : ValueError if image_list is empty or batch_size < 1
    """
    # Fail early with a clear message instead of an opaque error from argmax()
    # on an empty tensor (or from range() with a zero step).
    if len(image_list) == 0:
        raise ValueError("image_list must contain at least one image.")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer.")

    scores = []

    for i in range(0, len(image_list), batch_size):
        images = image_list[i : i + batch_size]
        texts = [prompt]

        inputs = processor_siglip(
            text=texts,
            images=images,
            padding="max_length",
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            logits = model_siglip(**inputs).logits_per_image  # shape (B, 1)

        # One prompt per batch, so drop the text axis and map logits to [0, 1].
        probs = torch.sigmoid(logits).squeeze(1).cpu()
        scores.extend(probs.tolist())

    best_idx = int(torch.tensor(scores).argmax())
    return best_idx, scores


In [5]:
import kornia.filters as K
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from diff_jpeg import diff_jpeg_coding
from torch import fft, nn
from torchvision.transforms import InterpolationMode

# Reference resolution: the FFT cutoff and the bilateral kernel size are scaled by
# (image height / DEFAULT_IMAGE_SIZE) inside ImageProcessorTorch.
DEFAULT_IMAGE_SIZE = 384
# Default arguments of apply_fft_low_pass / apply_random_crop_resize.
DEFAULT_FFT_CUTOFF = 0.5
DEFAULT_CROP_PERCENT = 0.05
# Settings used by ImageProcessorTorch.forward, listed in pipeline order.
FORWARD_CROP_PERCENT = 0.03
FORWARD_JPEG_QUALITY_1 = 95
FORWARD_MEDIAN_SIZE = 9
FORWARD_FFT_CUTOFF = 0.5
FORWARD_BILATERAL_D = 5
FORWARD_BILATERAL_SIGMA_COLOR = 75
FORWARD_BILATERAL_SIGMA_SPACE = 75
FORWARD_JPEG_QUALITY_2 = 92
# NOTE(review): the two constants below are not referenced in the visible part of the
# notebook; presumably leftovers from a comparison test -- confirm before removing.
TEST_JPEG_QUALITY = 85
COMPARISON_ATOL = 1e-2

class ImageProcessorTorch(nn.Module):
    """
    A PyTorch image-degradation pipeline (crop/resize, JPEG, median, FFT low-pass, bilateral).

    Every step operates on batched tensors of shape (B, C, H, W) with values in [0, 1].

    Args:
        seed (int, optional): Random seed for the crop offsets. When given, a private
            `np.random.RandomState(seed)` is used; otherwise the global `np.random`
            module is used. Defaults to None.
    """

    def __init__(self, seed=None):
        super().__init__()
        self.seed = seed
        if seed is not None:
            self.rng = np.random.RandomState(seed)
        else:
            self.rng = np.random

    def apply_fft_low_pass(
        self, images: torch.Tensor, cutoff_frequency: float = DEFAULT_FFT_CUTOFF
    ) -> torch.Tensor:
        """
        Apply a circular low-pass filter in the frequency domain using a 2-D FFT.

        Frequencies farther than `r` from the (shifted) spectrum centre are zeroed, where
        `r = (H // 2) * cutoff_frequency * H / DEFAULT_IMAGE_SIZE`.

        Args:
            images (torch.Tensor): A batch of input images with shape (B, C, H, W).
                                   Pixel values must be in [0, 1] and images must be square.
            cutoff_frequency (float, optional): Cutoff relative to half the image size at
                                                the reference resolution. Lower values
                                                remove more high-frequency content.
                                                Defaults to DEFAULT_FFT_CUTOFF.

        Returns:
            torch.Tensor: The filtered images with shape (B, C, H, W), clamped to [0, 1].
                The computation runs in float32, so the result is float32 even for
                lower-precision inputs.
        """
        _, _, h, w = images.shape
        assert h == w, "The images must be square"
        rows, cols = h, w
        crow, ccol = rows // 2, cols // 2

        # FFT and center shift
        fshift = fft.fftshift(
            fft.fft2(images.to(dtype=torch.float32), dim=(-2, -1)), dim=(-2, -1)
        )

        # Scale cutoff to match current image resolution
        cutoff_frequency = cutoff_frequency * h / DEFAULT_IMAGE_SIZE
        r = min(crow, ccol) * cutoff_frequency
        r_sq = r**2

        # Create circular low-pass mask
        y, x = torch.meshgrid(
            torch.arange(rows, device=images.device),
            torch.arange(cols, device=images.device),
            indexing="ij",
        )
        center_dist_sq = (y - crow) ** 2 + (x - ccol) ** 2
        # Build the 0/1 mask in the spectrum's own real dtype so the product below
        # does not mix a half-precision tensor with a complex64 one.
        mask = (
            (center_dist_sq <= r_sq).to(fshift.real.dtype).unsqueeze(0).unsqueeze(0)
        )  # (1, 1, H, W)

        # Apply mask and inverse FFT
        fshift_filtered = fshift * mask
        f_ishift = fft.ifftshift(fshift_filtered, dim=(-2, -1))
        img_back = fft.ifft2(f_ishift, dim=(-2, -1))
        img_back_real = torch.real(img_back)

        # Clamp values to [0, 1]
        filtered_images = torch.clamp(img_back_real, 0, 1)

        return filtered_images

    def apply_random_crop_resize(
        self,
        images: torch.Tensor,
        crop_percent: float = DEFAULT_CROP_PERCENT,
        interpolation=InterpolationMode.BILINEAR,
    ) -> torch.Tensor:
        """
        Randomly trim each image's borders and resize it back to the original size.

        For every image, an independent number of pixels in [0, int(size * crop_percent)]
        is removed from each of the four sides, drawn from `self.rng`.

        Args:
            images (torch.Tensor): Batch of input images with shape (B, C, H, W), values in [0, 1].
            crop_percent (float): Max fraction of height/width removed from each side (0-1).
            interpolation: Interpolation method for resizing. Default is bilinear.

        Returns:
            torch.Tensor: Batch of processed images with the same shape (B, C, H, W).
        """
        b, _, h, w = images.shape
        original_size = (h, w)

        crop_pixels_h = int(h * crop_percent)
        crop_pixels_w = int(w * crop_percent)

        cropped_images = []

        for i in range(b):
            # The draw order (left, top, right, bottom) is kept fixed so a given seed
            # always reproduces the same crops.
            left = self.rng.randint(0, crop_pixels_w + 1)
            top = self.rng.randint(0, crop_pixels_h + 1)
            right = w - self.rng.randint(0, crop_pixels_w + 1)
            bottom = h - self.rng.randint(0, crop_pixels_h + 1)
            new_w = right - left
            new_h = bottom - top

            # Crop & resize back to (H, W)
            cropped_img = TF.resized_crop(
                images[i],
                top,
                left,
                new_h,
                new_w,
                original_size,
                interpolation=interpolation,
            )
            cropped_images.append(cropped_img)
        return torch.stack(cropped_images)

    def apply_median_filter(self, images: torch.Tensor, size: int = 3) -> torch.Tensor:
        """
        Apply an exact median filter to a batch of images via patch unfolding.

        Each output pixel is the median of the `size x size` window around it; borders
        are handled with replicate padding.

        NOTE(review): `torch.median` selects one input element per window, so gradients
        presumably reach only the selected pixels -- confirm before relying on this step
        for dense gradients.

        Args:
            images (torch.Tensor): Input image tensor of shape (B, C, H, W).
            size (int): Size of the median filter kernel. Even values are bumped to the
                next odd value so the window has a centre. Default is 3.

        Returns:
            torch.Tensor: Median-filtered image tensor of shape (B, C, H, W).

        Raises:
            ValueError: If `size` is not a positive integer.
        """
        # Validate that kernel size is a positive integer
        if not isinstance(size, int) or size < 1:
            raise ValueError("size must be a positive integer.")

        # Ensure kernel size is odd (median needs a center)
        if size % 2 == 0:
            size += 1

        # Calculate padding size
        pad = (size - 1) // 2

        # Pad the image using replicate mode to handle borders
        x = F.pad(images, (pad, pad, pad, pad), mode="replicate")

        # Extract image patches of shape (B, C * size^2, H * W)
        patches = F.unfold(x, kernel_size=size)

        batch, _, num_positions = patches.shape

        # Reshape to (B, C, size*size, H*W) to compute median across kernel window
        patches = patches.view(batch, images.shape[1], size * size, num_positions)

        # Compute median along the kernel dimension (dim=2)
        med = torch.median(patches, dim=2).values

        # Reshape result to original image shape
        return med.view_as(images)

    def apply_bilateral_filter(
        self, images: torch.Tensor, d=9, sigma_color=75, sigma_space=75
    ) -> torch.Tensor:
        """
        Apply a bilateral filter using Kornia.

        Bilateral filtering smooths the image while preserving edges, by considering
        both spatial closeness and color similarity.

        Args:
            images (torch.Tensor): Input image tensor of shape (B, C, H, W),
                                with pixel values assumed to be in [0, 1]. Must be square.
            d (int, optional): Diameter of the filter kernel at the reference resolution;
                scaled by H / DEFAULT_IMAGE_SIZE and rounded up. Defaults to 9.
            sigma_color (float, optional): Filter sigma in the color space, given on a
                0-255 scale and divided by 255 before use. Defaults to 75.
            sigma_space (float, optional): Filter sigma in the coordinate space, used for
                both axes. Defaults to 75.

        Returns:
            torch.Tensor: Bilateral-filtered image tensor of shape (B, C, H, W).
        """
        # Normalize sigma_color from 0-255 scale to [0, 1] for Kornia
        normalized_sigma_color = sigma_color / 255.0

        _, _, h, w = images.shape

        # Check if the input is a square image (required by this implementation)
        assert h == w, "The images must be square"

        # Scale the kernel size d proportionally to the input size.
        # NOTE(review): the scaled value can be even for some resolutions; confirm that
        # Kornia accepts even kernel sizes if inputs other than 384x384 are used.
        d = int(np.ceil(d * h / DEFAULT_IMAGE_SIZE))

        # Apply bilateral filter using Kornia
        return K.bilateral_blur(
            images,
            kernel_size=d,
            sigma_color=normalized_sigma_color,
            sigma_space=(sigma_space, sigma_space),  # horizontal and vertical
        )

    def apply_jpeg_compression(self, images: torch.Tensor, quality=85) -> torch.Tensor:
        """
        Simulate JPEG compression with `diff_jpeg_coding`.

        The images are scaled to [0, 255], coded with the same quality for every image
        in the batch, and scaled back to [0, 1].

        NOTE(review): the call uses `ste=False`. Earlier comments in this notebook
        described that setting as "non-differentiable"; that is not verifiable from this
        file -- confirm the meaning of `ste` against the diff_jpeg documentation.

        Args:
            images (torch.Tensor): Input image tensor of shape (B, C, H, W),
                                with pixel values expected to be in the [0, 1] range.
            quality (int, optional): JPEG quality level (1-100); higher means better
                                    quality and less compression. Defaults to 85.

        Returns:
            torch.Tensor: JPEG-coded (simulated) image tensor of shape (B, C, H, W),
                        rescaled by 1/255.
        """
        batch_size, _, _, _ = images.shape

        # One quality value per image, on the same device as the images
        quality_tensor = torch.tensor(
            [quality] * batch_size,
            device=images.device,
            dtype=torch.float16,
        )

        # Inputs are scaled to [0, 255] as the JPEG coder expects that range
        compressed_images = diff_jpeg_coding(
            images * 255.0, jpeg_quality=quality_tensor, ste=False
        )

        # Rescale the result back to [0, 1] before returning
        return compressed_images / 255.0

    def forward(
        self,
        x: torch.Tensor,
        skip_fft_low_pass=False,
        skip_random_crop_resize=False,
        skip_median_filter=False,
        skip_bilateral_filter=False,
        skip_jpeg_compression=False,
    ) -> torch.Tensor:
        """
        Apply the full sequence of image degradation operations to the input tensor.

        Order: random crop/resize -> JPEG (quality FORWARD_JPEG_QUALITY_1) -> median
        filter -> FFT low-pass -> bilateral filter -> JPEG (quality
        FORWARD_JPEG_QUALITY_2). Each step can be skipped independently via its flag;
        see the individual methods for notes on gradient flow.

        Args:
            x (torch.Tensor): Input image tensor of shape (B, C, H, W) with pixel values in [0, 1].
            skip_fft_low_pass (bool): If True, skip FFT low-pass filtering.
            skip_random_crop_resize (bool): If True, skip random cropping and resizing.
            skip_median_filter (bool): If True, skip median filtering.
            skip_bilateral_filter (bool): If True, skip bilateral filtering.
            skip_jpeg_compression (bool): If True, skip JPEG compression (both passes).

        Returns:
            torch.Tensor: Output image tensor after transformations, same shape as input.
        """
        # 1. Random crop and resize
        if not skip_random_crop_resize:
            x = self.apply_random_crop_resize(
                x, crop_percent=FORWARD_CROP_PERCENT
            )

        # 2. JPEG compression (1st pass)
        if not skip_jpeg_compression:
            x = self.apply_jpeg_compression(x, quality=FORWARD_JPEG_QUALITY_1)

        # 3. Median filtering
        if not skip_median_filter:
            x = self.apply_median_filter(x, size=FORWARD_MEDIAN_SIZE)

        # 4. FFT low-pass filter
        if not skip_fft_low_pass:
            x = self.apply_fft_low_pass(x, cutoff_frequency=FORWARD_FFT_CUTOFF)

        # 5. Bilateral filter
        if not skip_bilateral_filter:
            x = self.apply_bilateral_filter(
                x,
                d=FORWARD_BILATERAL_D,
                sigma_color=FORWARD_BILATERAL_SIGMA_COLOR,
                sigma_space=FORWARD_BILATERAL_SIGMA_SPACE,
            )

        # 6. JPEG compression (2nd pass)
        if not skip_jpeg_compression:
            x = self.apply_jpeg_compression(x, quality=FORWARD_JPEG_QUALITY_2)

        return x

    def apply(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Run the degradation pipeline on `x` (alias for calling the module).

        This name shadows `nn.Module.apply(fn)`. To keep the base-class contract usable
        (e.g. `module.apply(init_fn)`), a callable that is not a tensor is forwarded to
        `nn.Module.apply` instead of being treated as an image batch.
        """
        if callable(x) and not isinstance(x, torch.Tensor):
            return super().apply(x)
        return self(x, **kwargs)


In [6]:
#| export

import contextlib
from pathlib import Path

import clip
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms

class AestheticPredictor(nn.Module):
    """Stack of linear layers (with dropout) mapping an embedding to one scalar score."""

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        # The position of each module inside the Sequential fixes its state_dict key
        # ("layers.0", "layers.2", ...), so the order below must not change.
        stack = [
            nn.Linear(input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw score tensor for embeddings `x` (last dim == input_size)."""
        raw_score = self.layers(x)
        return raw_score

class AestheticEvaluatorOriginal:
    """
    CLIP-based aesthetic scorer for PIL images, using CLIP's own preprocessing.

    Args:
        model_path: Path to the aesthetic predictor weights. Defaults to the location
            this notebook was developed with; pass your own path on other machines.
        clip_model_path: Path (or name) given to `clip.load`. Same caveat as above.
    """

    def __init__(
        self,
        model_path="/home/anhndt/pysvgenius/models/sac+logos+ava1-l14-linearMSE.pth",
        clip_model_path="/home/anhndt/pysvgenius/models/ViT-L-14.pt",
    ):
        self.model_path = model_path
        self.clip_model_path = clip_model_path
        self.predictor, self.clip_model, self.preprocessor = self.load()

    def load(self):
        """Loads the aesthetic predictor model and CLIP model onto the global `device`."""
        state_dict = torch.load(self.model_path, weights_only=True, map_location=device)

        # CLIP embedding dim is 768 for CLIP ViT L 14
        predictor = AestheticPredictor(768)
        predictor.load_state_dict(state_dict)
        predictor.to(device)
        predictor.eval()
        clip_model, preprocessor = clip.load(self.clip_model_path, device=device)

        # The predictor is only used for inference here, so freeze its weights.
        for param in predictor.parameters():
            param.requires_grad = False

        return predictor, clip_model, preprocessor

    def score(self, image: Image.Image) -> float:
        """
        Predicts the CLIP aesthetic score of a single PIL image.

        Returns the predictor output divided by 10.
        NOTE(review): the division is meant to map scores to roughly [0, 1]; nothing in
        this code clamps the value -- confirm the predictor's output range if needed.
        """
        image = self.preprocessor(image).unsqueeze(0).to(device)

        with torch.no_grad():
            image_features = self.clip_model.encode_image(image)
            # l2 normalize
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # The predictor holds float32 weights, so cast the CLIP features directly
            # (no need to round-trip through NumPy on the CPU).
            score = self.predictor(image_features.float())

        return score.item() / 10.0



class AestheticEvaluatorTorch:
    """
    CLIP-based aesthetic scorer that works on image tensors and can keep the autograd graph.

    Preprocessing is done with tensor transforms (resize to 224, center crop, CLIP
    mean/std normalization) instead of CLIP's PIL pipeline.

    Args:
        model_path: Path to the aesthetic predictor weights. Defaults to the location
            this notebook was developed with; pass your own path on other machines.
        clip_model_path: Path (or name) given to `clip.load`. Same caveat as above.
    """

    def __init__(
        self,
        model_path="/home/anhndt/pysvgenius/models/sac+logos+ava1-l14-linearMSE.pth",
        clip_model_path="/home/anhndt/pysvgenius/models/ViT-L-14.pt",
    ):
        self.model_path = model_path
        self.clip_model_path = clip_model_path
        self.predictor, self.clip_model, self.preprocessor = self.load()

    def load(self):
        """Loads the aesthetic predictor (half precision) and CLIP model onto the global `device`."""
        state_dict = torch.load(self.model_path, weights_only=True, map_location=device)

        # CLIP embedding dim is 768 for CLIP ViT L 14
        predictor = AestheticPredictor(768).half()
        predictor.load_state_dict(state_dict)
        predictor.to(device)
        predictor.eval()

        # Inference-only weights: freeze them, as AestheticEvaluatorOriginal does, so
        # backpropagating through score() does not accumulate gradients on the predictor.
        for param in predictor.parameters():
            param.requires_grad = False

        clip_model, _ = clip.load(self.clip_model_path, device=device)
        # Tensor-based replacement for CLIP's PIL preprocessing.
        # assumes inputs are float tensors with values in [0, 1] -- TODO confirm with callers
        preprocessor = transforms.Compose(
            [
                transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(224),
                # transforms.Lambda(lambda x: x.clamp_(0, 1)),
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )

        return predictor, clip_model, preprocessor

    def score(self, image: torch.Tensor, no_grad: bool = False) -> torch.Tensor:
        """
        Predicts the CLIP aesthetic score of a batch of image tensors.

        Args:
            image: 4-D image tensor, (B, C, H, W).
            no_grad: If True, run under `torch.no_grad()`; otherwise the autograd graph
                back to `image` is kept.

        Returns:
            The predictor output divided by 10, as a tensor (one row per image).
            NOTE(review): the division is meant to map scores to roughly [0, 1]; the
            value is not clamped.

        Raises:
            ValueError: If `image` is not 4-dimensional.
        """
        if image.ndim != 4:
            raise ValueError(f"image must be a 4-D tensor (B, C, H, W) (shape: {image.shape})")

        with torch.no_grad() if no_grad else contextlib.nullcontext():
            image = self.preprocessor(image)
            image_features = self.clip_model.encode_image(image)
            # l2 normalize
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # The predictor is half precision; CLIP may return another dtype (e.g.
            # float32 when it runs on CPU), so align the features with the predictor.
            predictor_dtype = next(self.predictor.parameters()).dtype
            score_tensor = self.predictor(image_features.to(predictor_dtype))

        return score_tensor / 10.0

In [7]:
#| export
def load_aesthetic_evaluators() -> AestheticEvaluatorTorch:
    """
    Build the tensor-based aesthetic evaluator with its default model paths.

    Returns:
        A single `AestheticEvaluatorTorch` instance (not a tuple).
    """
    print("Loading Aesthetic Evaluators...")
    aesthetic_evaluator_torch = AestheticEvaluatorTorch()
    print("Aesthetic Evaluators loaded.")
    return aesthetic_evaluator_torch


def load_siglip_model(
    device, dtype, model_path_or_name
) -> AutoModel:
    """
    Load a SIGLIP model in eval mode on the requested device.

    Args:
        device: Target device (string such as "cuda"/"cpu"/"mps", or a torch.device).
        dtype: torch dtype to load the weights in. float16 on an MPS device is replaced
            by float32.
        model_path_or_name: Checkpoint path or model id passed to `from_pretrained`.
            Falls back to "google/siglip-so400m-patch14-384" when empty/None.

    Returns:
        The loaded model only; no processor is created here.
    """
    print(f"Loading SIGLIP model from: {model_path_or_name}...")
    # Honour the caller's checkpoint instead of always loading the hard-coded default.
    resolved_model_path = model_path_or_name or "google/siglip-so400m-patch14-384"
    print(f"Using SIGLIP model at: {resolved_model_path}")

    # SIGLIP models are typically loaded in float32.
    # You can specify dtype, but be aware of compatibility.
    # For MPS (Mac), float16 might not be supported, so we fallback to float32.
    effective_dtype = dtype
    # Normalize through torch.device so "mps", "mps:0" and torch.device("mps") all match.
    if torch.device(device).type == "mps" and dtype == torch.float16:
        print("Warning: MPS device selected with float16 for SIGLIP. Forcing float32 for stability.")
        effective_dtype = torch.float32

    # Specify torch_dtype at loading and move to device with .to(device)
    model = AutoModel.from_pretrained(resolved_model_path, torch_dtype=effective_dtype).to(device)

    model.eval()
    print(f"SIGLIP model loaded to {device} with dtype {model.dtype}.")
    return model


# Load the models once at module level: the aesthetic evaluator and a float16 SIGLIP model
# on `device`.
aesthetic_evaluator_torch = load_aesthetic_evaluators()
siglip_model = load_siglip_model(
    device, torch.float16, model_path_or_name="google/siglip-so400m-patch14-384"
)

# Run a garbage-collection pass and ask PyTorch to release its cached CUDA memory
# after loading.
import gc
gc.collect()
torch.cuda.empty_cache()


Loading Aesthetic Evaluators...
Aesthetic Evaluators loaded.
Loading SIGLIP model from: google/siglip-so400m-patch14-384...
Using SIGLIP model at: google/siglip-so400m-patch14-384
SIGLIP model loaded to cuda with dtype torch.float16.


In [8]:
import time
import argparse
from copy import deepcopy
from pathlib import Path

from tqdm import trange
import numpy as np

# import clip # CLIP is not needed
import torch
import pandas as pd  # For loading text prompts
import pydiffvg
import torch.nn.functional as F  # For image resizing

# Import PIL.Image
from PIL import Image

# Add imports related to PaliGemma
from transformers import (
    AutoModel,  # Added for SIGLIP
    AutoProcessor,
)


def render(canvas_width, canvas_height, shapes, shape_groups, seed=0):
    """Rasterize the given shapes/groups with pydiffvg and return the rendered image."""
    serialized_scene = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups
    )
    samples_per_axis = 2  # passed as both num_samples_x and num_samples_y
    background_image = None
    return pydiffvg.RenderFunction.apply(
        canvas_width,
        canvas_height,
        samples_per_axis,
        samples_per_axis,
        seed,
        background_image,
        *serialized_scene,
    )


# Cosine Annealing with Warmup learning rate scheduler
def get_lr(initial_lr, final_lr, iteration, warmup_iterations, total_iterations):
    """
    Learning rate for `iteration` under linear warmup followed by cosine annealing.

    Args:
        initial_lr: Peak learning rate, reached at the end of warmup.
        final_lr: Learning rate at (and after) the end of the schedule.
        iteration: Zero-based current iteration.
        warmup_iterations: Number of warmup iterations; the rate ramps linearly from
            initial_lr / warmup_iterations up to initial_lr.
        total_iterations: Total number of iterations in the schedule.

    Returns:
        The learning rate for this iteration. If there is no decay phase
        (total_iterations <= warmup_iterations), final_lr is returned once warmup is over.
    """
    if iteration < warmup_iterations:
        return initial_lr * (iteration + 1) / warmup_iterations

    decay_span = total_iterations - warmup_iterations
    if decay_span <= 0:
        # No room for a cosine phase; avoid dividing by zero.
        return final_lr

    # Clamp so iterations past the end stay at final_lr instead of climbing back up
    # the cosine curve.
    progress = min(max((iteration - warmup_iterations) / decay_span, 0.0), 1.0)
    return final_lr + (initial_lr - final_lr) * 0.5 * (1 + np.cos(np.pi * progress))


# --- Helper function (split out from optimize_single_svg) ---

def _calculate_target_image_embedding(
    image: Image, model: AutoModel, device: str, dtype: torch.dtype
) -> torch.Tensor | None:
    """
    Compute the SIGLIP image embedding of a target image without tracking gradients.

    The image is converted to RGB, resized to the model's configured square input size,
    rescaled from [0, 255] to [-1, 1] by hand (no processor is used) and passed to
    `model.get_image_features`.
    """
    side = model.config.vision_config.image_size
    resized_rgb = image.convert("RGB").resize((side, side))

    # (H, W, 3) array -> (1, 3, H, W) tensor
    batched_chw = torch.from_numpy(np.array(resized_rgb)).permute(2, 0, 1).unsqueeze(0)
    # Map pixel values from [0, 255] to [-1, 1]
    scaled = batched_chw / 255.0 * 2.0 - 1.0
    pixel_values = scaled.to(device=device, dtype=dtype)

    with torch.no_grad():
        target_embedding = model.get_image_features(pixel_values=pixel_values)
    # The target is a fixed reference, so make sure it never requires gradients.
    target_embedding.requires_grad_(False)

    print(f"  Computed encoding for target image. Shape: {target_embedding.shape}")
    return target_embedding


In [9]:
def _load_svg_and_prepare_params(svg_path: Path) -> tuple[int, int, list, list, dict] | None:
    """Load an SVG file and collect parameters for optimization"""
    if not svg_path.exists():
        print(f"  Warning: SVG file not found: {svg_path}")
        return None
    try:
        canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(str(svg_path))

        # Resize the canvas to 384x384
        assert canvas_width == canvas_height, "Image must be square"
        ratio = 384 / canvas_width
        canvas_width = 384
        canvas_height = 384
        for shape in shapes:
            if hasattr(shape, "points") and shape.points is not None:
                shape.points = shape.points * ratio
            if hasattr(shape, "center") and shape.center is not None:
                shape.center = shape.center * ratio
            if hasattr(shape, "radius") and shape.radius is not None:
                shape.radius = shape.radius * ratio
            if hasattr(shape, "p_min") and shape.p_min is not None:
                shape.p_min = shape.p_min * ratio
            if hasattr(shape, "p_max") and shape.p_max is not None:
                shape.p_max = shape.p_max * ratio
            if hasattr(shape, "stroke_width") and shape.stroke_width is not None:
                shape.stroke_width = shape.stroke_width * ratio

        shape_groups = [g for g in shape_groups if g.shape_ids.numel() > 0]
        print(f"  Loaded SVG file. Canvas: {canvas_width}x{canvas_height}")

        points_vars = []
        stroke_width_vars = []
        color_vars = []

        # Extract parameters to be optimized from all but the background and overlay elements
        for shape in shapes[1:-2]:
            if hasattr(shape, "points") and shape.points is not None:
                shape.points.requires_grad = True
                points_vars.append(shape.points)
            if hasattr(shape, "center") and shape.center is not None:  # Circle
                shape.center.requires_grad = True
                points_vars.append(shape.center)
            if hasattr(shape, "radius") and shape.radius is not None:  # Circle
                shape.radius.requires_grad = True
                points_vars.append(shape.radius)
            if hasattr(shape, "p_min") and shape.p_min is not None:  # Rect
                shape.p_min.requires_grad = True
                points_vars.append(shape.p_min)
            if hasattr(shape, "p_max") and shape.p_max is not None:  # Rect
                shape.p_max.requires_grad = True
                points_vars.append(shape.p_max)
            if hasattr(shape, "stroke_width") and shape.stroke_width is not None:
                shape.stroke_width.requires_grad = True
                stroke_width_vars.append(shape.stroke_width)

        for group in shape_groups[:-2]:
            if hasattr(group, "fill_color") and group.fill_color is not None:
                group.fill_color.requires_grad = True
                color_vars.append(group.fill_color)
            if hasattr(group, "stroke_color") and group.stroke_color is not None:
                group.stroke_color.requires_grad = True
                color_vars.append(group.stroke_color)

        print(
            f"  Parameters to optimize: points={len(points_vars)}, stroke_width={len(stroke_width_vars)}, color={len(color_vars)}"
        )
        if not points_vars and not stroke_width_vars and not color_vars:
            print("  Warning: No optimizable parameters found.")
            return None

        params = {
            "points_vars": points_vars,
            "stroke_width_vars": stroke_width_vars,
            "color_vars": color_vars,
        }
        return canvas_width, canvas_height, shapes, shape_groups, params

    except Exception as e:
        print(f"  Error occurred while loading SVG file: {e}")
        return None


In [10]:
def _setup_optimizer(args, params: dict) -> tuple[torch.optim.Optimizer, list[dict]] | None:
    """Set up the optimizer and parameter groups"""
    coords_initial_lr = args.lr_points
    coords_final_lr = args.lr_points / 10
    color_initial_lr = args.lr_color
    color_final_lr = args.lr_color / 10
    stroke_initial_lr = 0  # Learning rate for stroke width (can be configurable via args)
    stroke_final_lr = 0

    param_groups = []
    if params["points_vars"]:
        param_groups.append(
            {
                "params": params["points_vars"],
                "lr": coords_initial_lr,
                "name": "coords",
                "initial_lr": coords_initial_lr,
                "final_lr": coords_final_lr,
            }
        )
    if params["stroke_width_vars"]:
        param_groups.append(
            {
                "params": params["stroke_width_vars"],
                "lr": stroke_initial_lr,
                "name": "stroke",
                "initial_lr": stroke_initial_lr,
                "final_lr": stroke_final_lr,
            }
        )
    if params["color_vars"]:
        param_groups.append(
            {
                "params": params["color_vars"],
                "lr": color_initial_lr,
                "name": "color",
                "initial_lr": color_initial_lr,
                "final_lr": color_final_lr,
            }
        )

    if not param_groups:
        print("  Error: No valid parameter groups found.")
        return None

    optimizer = torch.optim.Adam(param_groups)
    return optimizer, param_groups


In [11]:
def _run_optimization_loop(
    args,
    canvas_width,
    canvas_height,
    shapes,
    shape_groups,
    optimizer,
    param_groups,
    aesthetic_evaluator_torch,
    similarity_mode,  # Name of the similarity term; selects the loss weight and labels the logs
    siglip_model=None,  # SigLIP model used to embed the rendered image
    siglip_processor=None,  # SigLIP processor (not referenced inside this function)
    text_prompt=None,  # Text prompt (not referenced inside this function)
    target_image_embedding=None,  # Target embedding compared against the SigLIP embedding of the render
    device="cuda",
    dtype=torch.float16,
    image=None
):
    """Optimize the SVG shapes against aesthetic, SigLIP-similarity and MSE losses.

    For ``args.iterations`` steps the scene is rendered, repeated into a batch
    of ``args.batch_size``, passed through ``ImageProcessorTorch`` and scored.
    Before ``args.aesthetic_iter`` only the aesthetic loss is used; afterwards
    the similarity and MSE terms are added with the ``args.w_*`` weights.
    Learning rates are rescheduled every step with ``get_lr``.

    ``best_loss`` / ``best_iteration`` are tracked for logging only: the
    function does not snapshot or restore a best state, it returns the
    ``shapes`` / ``shape_groups`` objects it was given as they are after the
    last executed step.

    Args:
        args: Namespace providing ``iterations``, ``warmup_iter``, ``batch_size``,
            ``jpeg_iter``, ``aesthetic_iter``, ``log_interval``, ``grad_clip_norm``
            and the loss weights ``w_aesthetic``, ``w_mse``, ``w_siglip`` /
            ``w_paligemma``.
        canvas_width, canvas_height: Render size; ``image`` is resized to it.
        shapes, shape_groups: Scene objects passed to ``render``.
        optimizer: Optimizer whose param groups carry ``initial_lr``/``final_lr``/``name``.
        param_groups: Groups whose parameters are gradient-clipped.
        aesthetic_evaluator_torch: Object with a ``score(batch)`` method.
        image: Reference image with a ``resize`` method (presumably a PIL image).
            Required despite the ``None`` default: it is resized unconditionally.

    Returns:
        ``(shapes, shape_groups, aesthetic_score)`` where ``aesthetic_score`` is
        the negated aesthetic loss of the last step that computed one, or
        ``float("-inf")`` when no step got that far (zero iterations, or a
        failure in the very first step).
    """
    print(
        f"  Starting optimization to minimize loss (Aesthetic + {similarity_mode.upper()}) (Batch Size: {args.batch_size})..."
    )
    start_time = time.time()
    best_loss = float("inf")

    best_iteration = 0
    warmup_iterations = args.warmup_iter
    batch_size = args.batch_size

    # Stays None until a step computes an aesthetic loss, so the final return
    # cannot hit an unbound variable when the loop exits early.
    aesthetic_loss = None

    image = image.resize((canvas_width, canvas_height))
    to_tensor = transforms.ToTensor()
    tensor_image = to_tensor(image)
    tensor_image = tensor_image.to(device)

    # Both processors share one seed, presumably so the random crops applied to
    # the render and to the reference image line up. TODO confirm against ImageProcessorTorch.
    seed = int(time.time() * 1000) % (2**32 - 1)
    image_processor_torch = ImageProcessorTorch(seed=seed).to(device)
    image_processor_torch_ref = ImageProcessorTorch(seed=seed).to(device)

    for t in trange(args.iterations):
        # pydiffvg.save_svg(f"{t}.svg", canvas_width, canvas_height, shapes, shape_groups)

        # Re-schedule the learning rate of every group for this step
        for param_group in optimizer.param_groups:
            group_initial_lr = param_group["initial_lr"]
            group_final_lr = param_group["final_lr"]
            current_lr = get_lr(group_initial_lr, group_final_lr, t, warmup_iterations, args.iterations)
            param_group["lr"] = current_lr

        optimizer.zero_grad()

        # Render (channels-last), drop alpha, convert to (1, C, H, W)
        img_render_single = (
            render(canvas_width, canvas_height, shapes, shape_groups, seed=t)[:, :, :3]
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(device=device, dtype=dtype)  # dtype
        )
        img_render_batch = img_render_single.repeat(batch_size, 1, 1, 1)  # (B, C, H_svg, W_svg)

        # # Resize to 224x224
        # img_render_batch = F.interpolate(img_render_batch, size=(224, 224), mode="bilinear", align_corners=False)

        # Apply ImageProcessorTorch (assumed to perform random cropping and resizing internally).
        # JPEG compression is skipped for the first args.jpeg_iter steps.
        if t < args.jpeg_iter:
            processed_img_render_batch = image_processor_torch.apply(img_render_batch, skip_jpeg_compression=True)
        else:
            processed_img_render_batch = image_processor_torch.apply(img_render_batch)  # (B, C, H_proc, W_proc), [0, 1]

        # --- Similarity loss computation ---
        similarity_loss_raw = torch.tensor(np.inf, device=device, dtype=dtype)
        similarity_loss = torch.tensor(0.0, device=device, dtype=dtype)

        # color mse between the processed render and the equally cropped reference image
        tensor_image_batch = tensor_image.repeat(batch_size, 1, 1, 1)
        tensor_image_cropped_batch = image_processor_torch_ref.apply_random_crop_resize(
            tensor_image_batch, crop_percent=FORWARD_CROP_PERCENT
        )
        mse_loss = ((processed_img_render_batch - tensor_image_cropped_batch) ** 2).mean()

        try:
            # (B, C, H, W), [0,1] -> (B, C, H, W), [-1, 1]
            # The SIGLIP processor performs normalization internally, so you can either pass [0,1] or manually match the model's expected input.
            # Here we follow exp028 and manually normalize to [-1, 1].
            render_pixel_values_batch = (processed_img_render_batch * 2.0) - 1.0
            render_pixel_values_batch = render_pixel_values_batch.to(
                dtype=siglip_model.dtype
            )  # Match the dtype of the SIGLIP model

            # Resize to the input size expected by siglip_model
            render_pixel_values_batch = F.interpolate(
                render_pixel_values_batch,
                size=siglip_model.config.vision_config.image_size,
                mode="bilinear",
                align_corners=False,
            )

            rendered_image_embedding_batch = siglip_model.get_image_features(
                pixel_values=render_pixel_values_batch
            )  # (B, embed_dim)

            cosine_similarity_batch = torch.nn.functional.cosine_similarity(
                target_image_embedding, rendered_image_embedding_batch, dim=-1
            )  # (B,)
            similarity_loss_raw = -cosine_similarity_batch.mean()  # Negative sign for loss
            similarity_loss = similarity_loss_raw  # Normalization may not be necessary for SIGLIP, needs further consideration

        except Exception as e:
            print(f"  Error: Exception occurred during SIGLIP similarity calculation (iter {t}): {e}")
            break

        # --- Aesthetic score calculation (batch) ---
        try:
            # processed_img_render_batch is BCHW in [0, 1]; the evaluator always receives float16 here
            aesthetic_score_batch = aesthetic_evaluator_torch.score(
                processed_img_render_batch.to(device, dtype=torch.float16)
            )  # (B,)
            aesthetic_loss = -aesthetic_score_batch.mean()
        except Exception as e:
            print(f" Error: Failed to compute aesthetic score (iter {t}): {e}")
            break

        # --- Total loss calculation (batch average) ---
        if t >= args.aesthetic_iter:
            if similarity_mode == "paligemma":
                loss = args.w_aesthetic * aesthetic_loss + args.w_paligemma * similarity_loss + args.w_mse * mse_loss
            elif similarity_mode == "siglip":
                loss = args.w_aesthetic * aesthetic_loss + args.w_siglip * similarity_loss + args.w_mse * mse_loss
            else:  # Should not happen
                loss = args.w_aesthetic * aesthetic_loss + args.w_mse * mse_loss
        else:  # During aesthetic_iter, use only aesthetic loss
            loss = args.w_aesthetic * aesthetic_loss

        if torch.isnan(loss):
            print(f"  Total loss became NaN (iter {t}). Stopping optimization.")
            break

        try:
            loss.backward()
            with torch.no_grad():
                for group in shape_groups:
                    if (
                        hasattr(group, "fill_color")
                        and group.fill_color is not None
                        and group.fill_color.grad is not None
                    ):
                        group.fill_color.grad[3] = 0.0  # Ignore alpha channel in gradient
                    if (
                        hasattr(group, "stroke_color")
                        and group.stroke_color is not None
                        and group.stroke_color.grad is not None
                    ):
                        group.stroke_color.grad[3] = 0.0  # Ignore alpha channel in gradient
        except Exception as e:
            print(f"  Error: Exception occurred during gradient computation (iter {t}): {e}")
            break

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(
            [p for pg in param_groups for p in pg["params"] if p.grad is not None],
            max_norm=args.grad_clip_norm
        )
        optimizer.step()

        # Clamp values to valid ranges after update
        with torch.no_grad():
            for group in shape_groups:
                if hasattr(group, "fill_color") and group.fill_color is not None and group.fill_color.requires_grad:
                    group.fill_color.data.clamp_(0.0, 1.0)
                if (
                    hasattr(group, "stroke_color")
                    and group.stroke_color is not None
                    and group.stroke_color.requires_grad
                ):
                    group.stroke_color.data.clamp_(0.0, 1.0)
            for shape in shapes:
                if (
                    hasattr(shape, "stroke_width")
                    and shape.stroke_width is not None
                    and shape.stroke_width.requires_grad
                ):
                    shape.stroke_width.data.relu_()
                if hasattr(shape, "radius") and shape.radius is not None and shape.radius.requires_grad:
                    shape.radius.data.clamp_(min=1.0)

        current_loss = loss.item()

        # If after aesthetic_iter and loss is improved, update best result (bookkeeping for the log only)
        if t >= args.aesthetic_iter and current_loss < best_loss:
            best_loss = current_loss
            best_iteration = t

        # Logging
        if (t + 1) % args.log_interval == 0 or t == 0:
            with torch.no_grad():
                current_aesthetic_score_torch_avg = aesthetic_score_batch.mean().item()
                current_similarity_loss_raw_avg = similarity_loss_raw.item()  # already a scalar
                current_similarity_loss_avg = similarity_loss.item()

            elapsed_time = time.time() - start_time
            lrs = {pg["name"]: pg["lr"] for pg in optimizer.param_groups}
            lr_str = "/".join([f"{lr:.6f}" for lr in lrs.values()])
            print(
                f"    Iter [{(t + 1):>4}/{args.iterations}], "
                f"LRs: {lr_str}, "
                f"AesScore(T Avg): {current_aesthetic_score_torch_avg:.6f}, "
                f"{similarity_mode.upper()}Loss(Raw Avg): {current_similarity_loss_raw_avg:.6f}, "
                f"{similarity_mode.upper()}Loss(Norm/Direct Avg): {current_similarity_loss_avg:.6f}, "
                f"Loss(A/Sim/Mse/T Avg): {aesthetic_loss.item():.4f}/{similarity_loss.item():.4f}/{mse_loss.item():.4f}/{loss.item():.6f}, "
                f"Best Loss: {best_loss:.6f} (iter {best_iteration}), "
                f"Time: {elapsed_time:.2f}s"
            )

    print(f"  Optimization completed! Best Loss: {best_loss:.6f} at iteration {best_iteration}")
    if aesthetic_loss is None:
        # No step produced an aesthetic score; report the worst possible score
        # instead of crashing on an unbound variable.
        return shapes, shape_groups, float("-inf")
    return shapes, shape_groups, -aesthetic_loss.item()


In [12]:
#| export
import re
import tempfile
import os
import vtracer

def svg_conversion(img, image_size=(256,256)):
    """Vectorize an image with vtracer and return the SVG document as a string.

    The image is resized to ``image_size``, converted to RGB and written to a
    temporary PNG; vtracer traces that file in color / stacked / polygon mode.

    Parameters
    ----------
    img :
        Image object providing ``resize``, ``convert`` and ``save``
        (presumably a ``PIL.Image.Image``).
    image_size : tuple[int, int]
        Size the image is resized to before tracing.

    Returns
    -------
    str
        Content of the SVG file written by vtracer.
    """
    # The context manager removes the scratch directory deterministically,
    # also when saving or tracing raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_file_path = os.path.join(tmp_dir, "tmp.png")
        resized_img = img.resize(image_size).convert("RGB")
        resized_img.save(tmp_file_path)

        svg_path = os.path.join(tmp_dir, "gen_svg.svg")
        vtracer.convert_image_to_svg_py(
            tmp_file_path,
            svg_path,
            colormode="color",  # ["color"] or "binary"
            # hierarchical="cutout",  # ["stacked"] or "cutout"
            hierarchical="stacked",
            mode="polygon",
        )

        with open(svg_path, 'r', encoding='utf-8') as f:
            return f.read()

def _shift_path_d(d_original, offset_x, offset_y):
    """Translate the coordinates of a path ``d`` string by ``(offset_x, offset_y)``.

    The data is split into ``M...Z`` sub-paths (the whole string is used when
    none is found). Commands are assumed to be absolute and coordinates are
    truncated to integers, which matches vtracer's polygon output (mainly M/L).
    M, L, T take one x,y pair, Q and S two, C three; H and V take one value.
    Arcs (A) and implicit command repetition are not interpreted: their numbers
    are copied through unshifted.

    A sub-path in which a pair-taking command is missing coordinates (or has
    unparsable ones) is dropped entirely instead of being emitted with a
    dangling command.

    Returns the shifted ``d`` string, or ``None`` when no sub-path is usable.
    """
    pairs_per_command = {"M": 1, "L": 1, "T": 1, "Q": 2, "S": 2, "C": 3}

    sub_paths = re.findall(r"M[^M]*?Z", d_original, flags=re.IGNORECASE)
    if not sub_paths and d_original.strip():  # No closed sub-path: treat as one whole path
        sub_paths = [d_original]

    shifted_sub_paths = []
    for sub_path in sub_paths:
        # Tokens are commands (letters) or numbers
        tokens = re.findall(r"[MLZHVCSQTA]|-?[\d\.]+", sub_path, flags=re.IGNORECASE)
        shifted = []
        k = 0
        while k < len(tokens):
            cmd = tokens[k]
            shifted.append(cmd)
            k += 1
            upper = cmd.upper()

            if upper in ("H", "V"):  # Single-coordinate commands
                if k < len(tokens):
                    try:
                        value = int(float(tokens[k]))
                    except ValueError:
                        break  # Parse error: stop reading this sub-path
                    shifted.append(str(value + (offset_x if upper == "H" else offset_y)))
                    k += 1
                continue

            pairs_needed = pairs_per_command.get(upper, 0)
            pairs_read = 0
            while pairs_read < pairs_needed and k + 1 < len(tokens):
                try:
                    x = int(float(tokens[k]))
                    y = int(float(tokens[k + 1]))
                except ValueError:
                    break
                shifted.append(str(x + offset_x))
                shifted.append(str(y + offset_y))
                k += 2
                pairs_read += 1

            if pairs_read < pairs_needed:
                # Ran out of (valid) coordinates mid-command: discard the sub-path
                shifted = shifted[:1]
                break

        if len(shifted) > 1:  # Command + at least one argument
            shifted_sub_paths.append(" ".join(shifted))

    if not shifted_sub_paths:
        return None
    return "".join(shifted_sub_paths)


def svg_conversion_division(img: "Image.Image", image_size=(384, 384), num_divisions=5):
    """
    Vectorize an image tile by tile and merge the tiles into one SVG string.

    1. Resize the image to image_size and divide it into num_divisions x num_divisions tiles
       (the last row/column absorbs the remainder pixels).
    2. Convert each tile to SVG with vtracer (color / stacked / polygon mode).
    3. Merge all tile <path> tags, adding the tile offset directly to the
       coordinates instead of using transforms.

    Tiles that fail to convert are reported with print() and skipped.
    Returns the merged SVG document (header sized to the resized image).

    The ``img`` annotation is a string so that defining this function does not
    require ``Image`` to be imported yet; ``Image`` is still needed at call time.
    """
    # Resize the input image to the specified image_size
    img_resized = img.resize(image_size, Image.LANCZOS)
    W, H = img_resized.size

    # Calculate the base size for each tile
    base_tile_w = W // num_divisions
    base_tile_h = H // num_divisions

    if base_tile_w == 0 or base_tile_h == 0:
        # If the resulting tile size becomes 0, issue a warning and continue with minimal size
        print(
            f"Warning: Calculated base tile size is very small ({base_tile_w}x{base_tile_h}). "
            "Resulting SVG might be distorted or empty."
        )
        # Set minimum size to 1 to continue processing
        base_tile_w = max(1, base_tile_w)
        base_tile_h = max(1, base_tile_h)

    all_shifted_path_tags = []

    for r_idx in range(num_divisions):
        for c_idx in range(num_divisions):
            offset_x = c_idx * base_tile_w
            offset_y = r_idx * base_tile_h

            # Actual width and height of the current tile (last tiles may include remainders)
            current_tile_w = base_tile_w if c_idx < num_divisions - 1 else W - offset_x
            current_tile_h = base_tile_h if r_idx < num_divisions - 1 else H - offset_y

            if current_tile_w <= 0 or current_tile_h <= 0:
                continue  # Skip tiles with zero or negative size

            # Crop the tile from the resized image
            box = (offset_x, offset_y, offset_x + current_tile_w, offset_y + current_tile_h)
            tile_img = img_resized.crop(box)

            if tile_img.width == 0 or tile_img.height == 0:
                continue

            # Same logic as svg_conversion: convert tile to SVG using vtracer
            try:
                with tempfile.TemporaryDirectory() as tmp_dir_name:
                    tmp_file_path = os.path.join(tmp_dir_name, f"tmp_tile_{r_idx}_{c_idx}.png")
                    # Save in RGB mode
                    tile_img.convert("RGB").save(tmp_file_path)

                    svg_tile_output_path = os.path.join(tmp_dir_name, f"tile_gen_svg_{r_idx}_{c_idx}.svg")
                    vtracer.convert_image_to_svg_py(
                        tmp_file_path,
                        svg_tile_output_path,
                        colormode="color",
                        # hierarchical="cutout",  # Use cutout mode
                        hierarchical="stacked",
                        mode="polygon",
                        # Add other vtracer parameters here if needed
                    )
                    with open(svg_tile_output_path, encoding="utf-8") as f:
                        tile_svg_str = f.read()
            except Exception as e:
                print(f"Error processing tile ({r_idx},{c_idx}): {e}")
                continue

            # Extract <path> elements from the tile SVG and move them to the tile position
            for path_tag_original in extract_paths(tile_svg_str):
                d_match = re.search(r'd="([^"]+)"', path_tag_original)
                if not d_match:
                    continue

                shifted_d = _shift_path_d(d_match.group(1), offset_x, offset_y)
                if shifted_d is None:
                    continue

                # Replace only the d attribute in the original path tag
                new_path_tag = re.sub(
                    r'd="[^"]*"', lambda _m, d=shifted_d: f'd="{d}"', path_tag_original, count=1
                )
                all_shifted_path_tags.append(new_path_tag)

    # Construct overall SVG header using resized width and height
    svg_final_header = f'<svg width="{W}" height="{H}" viewBox="0 0 {W} {H}" xmlns="http://www.w3.org/2000/svg">'

    # Concatenate all adjusted paths
    final_svg_str = svg_final_header + "".join(all_shifted_path_tags) + "</svg>"

    return final_svg_str


In [13]:
def remove_transform_from_path(path_text: str) -> str:
    """
    Bake a ``transform="translate(tx,ty)"`` into the coordinates of a <path> tag.

    - Removes the transform attribute
    - Adds (tx, ty) to every comma-separated ``x,y`` pair in the `d` attribute

    Only ``translate(tx,ty)`` with a comma between the two numbers is
    recognised; a tag without such a transform is returned unchanged.
    Coordinates that are not written as ``x,y`` pairs (space-separated pairs,
    H/V values) are left as they are.

    Returns the modified tag string; the rewritten `d` attribute is always
    emitted with double quotes.
    """
    # 1) Extract tx, ty from the transform attribute
    m_tx = re.search(
        r'transform=(["\'])\s*translate\(\s*([-+]?\d*\.?\d+)\s*,\s*([-+]?\d*\.?\d+)\s*\)\s*\1',
        path_text
    )
    if not m_tx:
        # If there's no transform, return as is
        return path_text

    tx, ty = float(m_tx.group(2)), float(m_tx.group(3))

    # 2) Extract the content of the `d` attribute.
    #    The lookbehind keeps the pattern from matching the tail of another
    #    attribute name such as id="..." (which would leave the real path data
    #    unshifted while the transform is still removed).
    m_d = re.search(r'(?<![\w:-])d=(["\'])(?P<d>.*?)\1', path_text)
    if not m_d:
        # If `d` attribute not found, just remove transform and return
        return re.sub(r'\s+transform=(["\']).*?\1', '', path_text)

    d_orig = m_d.group('d')

    # 3) Find coordinate pairs "x,y" and shift them
    def fmt(v: float) -> str:
        # Integers without a decimal point, other values with trailing zeros trimmed
        return str(int(v)) if v.is_integer() else ('%f' % v).rstrip('0').rstrip('.')

    def shift_coord(m):
        x, y = float(m.group(1)), float(m.group(2))
        return fmt(x + tx) + ',' + fmt(y + ty)

    d_shifted = re.sub(r'([-+]?\d*\.?\d+),\s*([-+]?\d*\.?\d+)', shift_coord, d_orig)

    # 4) Replace the old `d` attribute with the new shifted one, and remove transform
    start, end = m_d.span()
    new_path = path_text[:start] + f'd="{d_shifted}"' + path_text[end:]
    new_path = re.sub(r'\s+transform=(["\']).*?\1', '', new_path)

    return new_path


In [14]:
def extract_paths(svg):
    """
    Collect the opening <path ...> tags of an SVG string in document order.

    Each tag is passed through remove_transform_from_path, so a
    translate() transform is folded into the path coordinates.
    Returns the list of cleaned tag strings.
    """
    # Case-insensitive scan for everything from "<path" up to the next ">"
    opening_tags = re.findall(r'<path\b[^>]*>', svg, flags=re.IGNORECASE)
    return list(map(remove_transform_from_path, opening_tags))


In [15]:
def optimize_svg_path_vt(elem: str) -> str:
    """
    Takes a <path> element in M x,y L ... Z format,
    and returns the shortest equivalent path code that
    draws the same shape.
    Optimizes all "M…Z" subpaths in order.
    """
    # --- Extract attributes --------------------------------------------------
    d_match = re.search(r'd="([^"]+)"', elem)
    if not d_match:
        return elem
    d_raw = d_match.group(1)

    fill_m = re.search(r'fill="([^"]+)"', elem)
    fill = fill_m.group(1) if fill_m else None

    # --- Split by subpaths (M…Z blocks) -------------------------------------
    subpaths = re.findall(r'M[^M]*?Z', d_raw)

    optimized_subs = []
    for sub in subpaths:
        # --- Convert string to coordinate list -------------------------------
        tokens = re.findall(r'[MLZ]|-?\d+', sub)
        pts, cmd, i = [], None, 0
        while i < len(tokens):
            t = tokens[i]
            if t in ('M', 'L', 'Z'):
                cmd, i = t, i + 1
                continue
            if cmd in ('M', 'L'):
                pts.append((int(t), int(tokens[i + 1])))
                i += 2
            else:
                i += 1

        # Skip this subpath if no valid coordinates
        if not pts:
            continue

        # --- Build the shortest path command sequence ------------------------
        d_parts = [f'M{pts[0][0]} {pts[0][1]}']
        prev_x, prev_y = pts[0]

        for x, y in pts[1:]:
            dx, dy = x - prev_x, y - prev_y

            cands = []
            # ― Absolute commands ―
            if dy == 0:
                cands.append(f'H{x}')
            if dx == 0:
                cands.append(f'V{y}')
            cands.append(f'L{x} {y}')
            # ― Relative commands ―
            if dy == 0:
                cands.append(f'h{dx}')
            if dx == 0:
                cands.append(f'v{dy}')
            cands.append(f'l{dx} {dy}')

            # Choose the shortest command
            d_parts.append(min(cands, key=len))
            prev_x, prev_y = x, y

        d_parts.append('Z')
        # Skip if only "M" and "Z"
        if len(d_parts) > 2:
            optimized_subs.append(''.join(d_parts))

    if not optimized_subs:
        return None

    d_optimized = ''.join(optimized_subs)
    return f'<path d="{d_optimized}"' + (f' fill="{fill}"' if fill else '') + '/>'


In [16]:
def extract_svg_size_with_regex(svg):
    """Extract the width and height attribute values of the first <svg> tag.

    Returns a ``(width, height)`` tuple of strings exactly as written in the
    tag (units included); an entry is None when the attribute is missing.
    Hyphenated attributes such as ``stroke-width`` are not mistaken for
    ``width`` / ``height``.
    """
    # (?<![\w-]) requires the attribute name to start here, so "stroke-width="
    # on the same tag cannot be picked up by the greedy [^>]* prefix.
    w_match = re.search(r'<svg[^>]*(?<![\w-])width=["\']([^"\']+)["\']', svg)
    h_match = re.search(r'<svg[^>]*(?<![\w-])height=["\']([^"\']+)["\']', svg)
    width  = w_match.group(1) if w_match else None
    height = h_match.group(1) if h_match else None
    return width, height


In [17]:
def svg_compress(svg):
    """
    Compresses an SVG string by:
    - Extracting all paths
    - Optimizing each path to its shortest form
    - Rebuilding a compact SVG with a fixed 384x384 header whose viewBox uses
      the width/height found on the input <svg> tag

    The first surviving path is assumed to be the background: it is replaced
    by a full-canvas <rect> carrying its fill color. If that path has no fill
    attribute, no <rect> is emitted and all paths are kept. An input without
    any usable path yields an empty <svg> element.
    """
    optimized = (optimize_svg_path_vt(p) for p in extract_paths(svg))
    path_list = [p for p in optimized if p is not None]

    width, height = extract_svg_size_with_regex(svg)
    header = f'<svg width="384" height="384" viewBox="0 0 {width} {height}">'
    footer = "</svg>"

    if not path_list:
        return header + footer

    fill_background = re.search(r'fill="([^"]+)"', path_list[0])
    if fill_background is None:
        # No color to build the background rect from: keep every path as is
        return header + ''.join(path_list) + footer

    header += f'<rect width="{width}" height="{height}" fill="{fill_background.group(1)}"/>'
    return header + ''.join(path_list[1:]) + footer


In [18]:
def vtracer_png_to_svg(image, size_range=(50, 500), limit=9500):
    """
    Convert the given image to a compressed SVG string that fits a length budget.

    A binary search over the square image size looks for a large size whose
    compressed SVG is at most `limit` characters long. The search stops once
    the remaining window is 5 or fewer sizes wide, so the result is close to,
    but not necessarily exactly, the largest admissible size.

    If the size found is at least 256, the image is additionally vectorized
    tile by tile at 256x256 with 2, 3, ... 13 divisions per side; the last
    tiled result that is shorter than 9000 characters (and within `limit`)
    replaces the whole-image result.

    Parameters
    ----------
    image : any
        Input image object to be vectorized into SVG.
    size_range : tuple(int, int)
        Minimum and maximum image sizes to search over.
    limit : int
        Maximum allowed length (in characters) for the SVG string.

    Returns
    -------
    best_svg : str
        The selected compressed SVG string.

    Raises
    ------
    ValueError
        If no probed image size within `size_range` satisfies the length limit
        (this includes a `size_range` too narrow for the search to probe).
    """
    tiled_length_cap = 9000  # Tiled results must stay strictly below this length

    lo, hi = size_range
    best_size = None
    best_svg = None

    # Binary search to find a large size that keeps the SVG under limit
    while abs(lo - hi) > 5:
        mid = (lo + hi) // 2

        # Convert to SVG and compress
        svg = svg_compress(svg_conversion(image, (mid, mid)))

        if len(svg) <= limit:
            # This size is valid -> try larger sizes
            best_size = mid
            best_svg = svg
            lo = mid + 1
        else:
            # SVG too large -> try smaller sizes
            hi = mid - 1

    if best_size is None:
        raise ValueError(
            f"No image size in {size_range} produced an SVG of at most {limit} characters."
        )

    # Small result: the budget is already exhausted, keep the whole-image SVG
    if best_size < 256:
        return best_svg

    # Otherwise there is budget left: try progressively finer tilings
    for i in range(2, 14):
        svg = svg_compress(svg_conversion_division(image, (256, 256), i))
        if len(svg) < tiled_length_cap and len(svg) <= limit:
            best_svg = svg
        else:
            break

    return best_svg


In [19]:
# |export
import time  # Added for plotting and measuring execution time
import contextlib
import io
from pathlib import Path
from typing import List

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F  # Added for interpolate
import torchvision.transforms as transforms
import re
from collections import defaultdict

def svg_to_png(svg_code: str, size: tuple = (384, 384)) -> Image.Image:
    """
    Rasterize an SVG string with CairoSVG and return it as an RGB PIL image.

    If the text ``viewBox`` does not occur anywhere in the SVG, a
    ``viewBox="0 0 <width> <height>"`` built from `size` is added to the first
    ``<svg`` tag only (nested <svg> elements are left untouched).

    Parameters
    ----------
    svg_code : str
        The SVG string to convert.
    size : tuple[int, int], default=(384, 384)
        The desired size of the output image (width, height). The rendered
        bitmap is resized to it after rasterization.

    Returns
    -------
    PIL.Image.Image
        The rendered image in RGB mode.
    """
    # Add viewBox to the root element if none is present
    if 'viewBox' not in svg_code:
        svg_code = svg_code.replace('<svg', f'<svg viewBox="0 0 {size[0]} {size[1]}"', 1)

    # Convert SVG to PNG bytes, then decode and normalise mode and size
    # NOTE(review): `cairosvg` is not imported in this cell; it is presumably imported earlier in the notebook.
    png_data = cairosvg.svg2png(bytestring=svg_code.encode('utf-8'))
    return Image.open(io.BytesIO(png_data)).convert('RGB').resize(size)


def tensor_to_pil(tensor):
    """
    Convert an image tensor in HWC layout with values in [0, 1] to a PIL Image.

    A 4-D input has its leading dimension squeezed (expected to be a batch of
    one). A fourth (alpha) channel is dropped. Values outside [0, 1] are
    clamped before scaling, so slight overshoots do not wrap around when cast
    to uint8; the scaled values are truncated, not rounded.
    """
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)  # Remove batch dimension

    # Ensure RGB format
    if tensor.shape[2] == 4:  # RGBA format
        tensor = tensor[:, :, :3]  # Remove alpha channel

    # Convert from [0, 1] to [0, 255]; clamp first because casting an
    # out-of-range float to uint8 does not saturate
    pixels = tensor.detach().cpu().clamp(0.0, 1.0) * 255
    pixels = pixels.to(torch.uint8)

    # Convert to PIL image (expects HWC format)
    return Image.fromarray(pixels.numpy())


def extract_rect_and_path(svg_text, resize):
    """
    Collect the opening <rect ...> and <path ...> tags of an SVG string and
    normalise each of them with _process_tag(tag, resize).

    Returns:
        The processed tag strings, in order of appearance.
    """
    # Case-insensitive: everything from "<rect" / "<path" up to the next ">"
    found_tags = re.findall(r'<(?:rect|path)\b[^>]*>', svg_text, flags=re.IGNORECASE)

    processed = []
    for raw_tag in found_tags:
        processed.append(_process_tag(raw_tag, resize))
    return processed


def _process_tag(tag, resize):
    # 1) Remove `opacity` attributes
    tag = re.sub(r'\s*opacity\s*=\s*"[^"]*"', '', tag, flags=re.IGNORECASE)

    # 2) Convert fill="rgb(R,G,B)" to fill="#RRGGBB"
    def _rgb_to_hex(m: re.Match) -> str:
        nums = [float(m.group(i)) for i in (1,2,3)]
        ints = [int(round(v)) for v in nums]
        return f'fill="#{ints[0]:02X}{ints[1]:02X}{ints[2]:02X}"'
    
    tag = re.sub(
        r'fill\s*=\s*"rgb\(\s*([0-9]+(?:\.[0-9]+)?)\s*,\s*([0-9]+(?:\.[0-9]+)?)\s*,\s*([0-9]+(?:\.[0-9]+)?)\s*\)"',
        _rgb_to_hex,
        tag,
        flags=re.IGNORECASE
    )

    # 3) Round all decimal numbers and scale by resize factor
    def _round_num(m: re.Match) -> str:
        return str(int(round(float(m.group(0)) * resize / 384)))
    
    tag = re.sub(r'-?\d+\.\d+', _round_num, tag)
    return tag


def _close_path_with_Z(d: str) -> str:
    """
    If the path string ends with 'L x y', replace it with 'Z'.
    Otherwise, return the original path.
    """
    m = re.match(r'^(.*?)(?:\s+L\s+[-+]?\d*\.?\d+(?:[,\s]+[-+]?\d*\.?\d+)\s*)$', d)
    if m:
        return m.group(1) + ' Z'
    return d


def merge_paths_by_fill(paths: List[str]) -> List[str]:
    """
    Combine <path> elements that share a fill color into one element per color.

    - Tags without a `d` attribute are ignored.
    - Each `d` string is passed through _close_path_with_Z and the results of
      one color are joined with a single space.
    - Colors keep the order of their first appearance; a missing fill is
      treated as its own group and produces a tag without a fill attribute.
    - Attributes other than `d` and `fill` are not carried over.

    Returns:
        List of '<path d="..." fill="..." />' strings, one per fill value.
    """
    d_by_fill = {}  # fill value -> list of path data, in insertion order

    for tag in paths:
        d_found = re.search(r'd\s*=\s*["\']([^"\']+)["\']', tag)
        if d_found is None:
            continue

        fill_found = re.search(r'fill\s*=\s*["\']([^"\']*)["\']', tag)
        fill_key = fill_found.group(1) if fill_found else ''

        d_by_fill.setdefault(fill_key, []).append(_close_path_with_Z(d_found.group(1)))

    merged_paths = []
    for fill_key, d_chunks in d_by_fill.items():
        fill_attr = f' fill="{fill_key}"' if fill_key else ''
        merged_paths.append(f'<path d="{" ".join(d_chunks)}"{fill_attr} />')
    return merged_paths


def _shortest_segment(prev, cur):
    """
    Return the shortest SVG command for moving from `prev` to `cur`.
    Uses relative and absolute commands and chooses the shortest.
    """
    px, py = prev
    x, y = cur
    dx, dy = x - px, y - py
    cands = [
        (f'H{x}'      , abs(dy)==0),           # Absolute horizontal
        (f'h{dx}'     , abs(dy)==0),           # Relative horizontal
        (f'V{y}'      , abs(dx)==0),           # Absolute vertical
        (f'v{dy}'     , abs(dx)==0),           # Relative vertical
        (f'L{x} {y}'  , True),                 # Absolute diagonal
        (f'l{dx} {dy}', True),                 # Relative diagonal
    ]
    # Filter by condition and select the shortest string
    return min((s for s, ok in cands if ok), key=len)


In [20]:
def _encode_path(pts):
    """List of points → shortest SVG command sequence (with Z)"""
    d = [f'M{pts[0][0]} {pts[0][1]}']
    for a, b in zip(pts, pts[1:]):
        d.append(_shortest_segment(a, b))
    d.append('')
    return ''.join(d)

def svg_path_extream(elem: str) -> str | None:
    """
    Given a <path> element (possibly multiple M…Z subpaths),
    - Try all starting points (cyclic permutation)
    - Try both clockwise and counter-clockwise
    and return the shortest possible path `d`.

    Returns None if no valid path is found.
    """
    # --- Extract attributes -------------------------------------------------
    m_d = re.search(r'd="([^"]+)"', elem)
    d_raw = m_d.group(1)

    fill_m = re.search(r'fill="([^"]+)"', elem)
    fill   = f' fill="{fill_m.group(1)}"'

    # --- Extract subpaths (M…Z) ---------------------------------------------
    subs_raw = re.findall(r'M[^M]*?Z', d_raw)
    optimized = []

    for sub in subs_raw:
        # Tokenize to points
        toks = sub.replace(',', ' ').split()
        pts, cmd, i = [], None, 0
        while i < len(toks):
            t = toks[i]
            if t in ('M', 'L', 'Z'):
                cmd, i = t, i + 1
            elif cmd in ('M', 'L'):
                pt = (int(float(t)), int(float(toks[i + 1])))
                if len(pts) == 0 or pts[-1] != pt:
                    pts.append(pt)
                i += 2
            else:
                i += 1
        if len(pts) < 3:  # Skip if not enough points for a shape
            continue

        # --- Try all N×2 permutations (start point × direction) --------------
        best, best_len = sub, len(sub)
        for seq in (pts, pts[::-1]):  # Normal and reversed
            n = len(seq)
            for k in range(n):
                rot = seq[k:] + seq[:k]  # Rotate starting point
                d_candidate = _encode_path(rot)
                if (l := len(d_candidate)) < best_len:
                    best, best_len = d_candidate, l

        optimized.append(best)

    if not optimized:  # No valid shapes
        return None

    d_final = ''.join(optimized)
    return f'<path d="{d_final}"{fill}/>'


def optimize_svg_size(svg, resize=384, limit=False):
    """
    Rebuild `svg` as a compact document on a `resize` x `resize` viewBox.

    The element list comes from `extract_rect_and_path(svg, resize)`. Its last
    two entries are discarded (text layers), the first remaining entry is used
    as the background <rect>, and the rest are treated as <path> elements.
    Paths are merged per fill color (`merge_paths_by_fill`) and re-encoded
    (`svg_path_extream`); paths for which the latter returns None are dropped.

    The `width`/`height` attributes stay fixed at 384; only the viewBox
    follows `resize`.

    If `limit` is True, paths are kept in order only while the running encoded
    byte count (header, footer and paths so far) stays at or below 10000; the
    first path that would exceed it, and every path after it, is dropped.

    Returns an optimized SVG string.
    """
    opening = f'<svg width="384" height="384" viewBox="0 0 {resize} {resize}">'
    closing = "</svg>"

    # Ignore text layers (last 2 elements)
    elements = extract_rect_and_path(svg, resize)[:-2]

    # The background <rect> rides along with the opening tag; its x="0" /
    # y="0" attributes are stripped first.
    opening += re.sub(r'\s*(?:x|y)="0"', '', elements[0])

    # Merge paths with the same fill, re-encode each one, drop invalid results
    merged = merge_paths_by_fill(elements[1:])
    compact_paths = [
        encoded
        for encoded in map(svg_path_extream, merged)
        if encoded is not None
    ]

    if not limit:
        return opening + ''.join(compact_paths) + closing

    # Trim to 10KB: keep the longest prefix of paths that still fits
    kept = []
    total_bytes = len((opening + closing).encode())
    for encoded in compact_paths:
        total_bytes += len(encoded.encode())
        if total_bytes > 10000:
            break
        kept.append(encoded)

    return opening + "".join(kept) + closing


def optimize_svg_10k(svg):
    """
    Find the largest canvas size at which the optimized SVG stays within
    10,000 encoded bytes.

    Bisects the size range 150..500 with `optimize_svg_size`, stopping once
    the remaining interval is at most 3 wide, and returns the SVG of the
    largest size that was tried and fit.

    If no tried size fits, falls back to size 150 with `limit=True`, which
    drops paths until the document fits.

    Returns an optimized SVG string.
    """
    low, high = 150, 500
    accepted = None  # (size, svg) of the largest size that fit so far

    while abs(low - high) > 3:
        probe = (low + high) // 2
        candidate = optimize_svg_size(svg, probe)

        if len(candidate.encode()) <= 10000:
            # Fits → remember it and look for something larger
            accepted = (probe, candidate)
            low = probe + 1
        else:
            # Too large → look at smaller sizes
            high = probe - 1

    if accepted is None:
        print("Failed to reduce under 10,000 bytes. Removing some paths.")
        return optimize_svg_size(svg, 150, limit=True)

    size, fitted_svg = accepted
    print(f"Resized to {size}")
    return fitted_svg


In [21]:
class OptimizationArgs:
    """
    Settings namespace for the SVG refinement performed by `opt_svg`.

    `opt_svg` hands the class object itself (not an instance) to
    `_setup_optimizer` and `_run_optimization_loop`, so assigning to an
    attribute on the class changes the settings of every later `opt_svg` call.
    How each value is consumed is decided by those helpers, which are defined
    elsewhere in the notebook; the per-field notes below are hedged accordingly.
    """
    # NOTE(review): presumably the number of optimization steps (the run log
    # in this notebook counts "Iter [n/200]") — confirm in _run_optimization_loop.
    iterations = 200
    # NOTE(review): usage of the next three counters is not visible here — confirm.
    jpeg_iter = 200
    aesthetic_iter = 0
    warmup_iter = 0
    # NOTE(review): presumably the logging period in iterations — confirm.
    log_interval = 10
    # NOTE(review): presumably weights of the aesthetic / SigLIP / MSE loss terms — confirm.
    w_aesthetic = 100.0
    w_siglip = 100.0
    w_mse = 5000.0
    batch_size = 1
    # NOTE(review): presumably learning rates for point and color parameters — confirm in _setup_optimizer.
    lr_points = 0.3
    lr_color = 0.01
    grad_clip_norm = 1.0
    # Forwarded by opt_svg as the `similarity_mode` argument of _run_optimization_loop.
    similarity_mode = "siglip"

    # Forwarded by opt_svg as `device` / `dtype` to _calculate_target_image_embedding
    # and _run_optimization_loop. Note that opt_svg configures pydiffvg from the
    # notebook-level `device` variable instead, which may differ from this hard-coded value.
    device = "cuda"
    dtype = torch.float16

def opt_svg(svg, image, skip_final_optimization=False):
    """
    Refine an SVG against a target image with the differentiable-rendering loop.

    Steps, in order:
      1. Write `svg` to `original.svg` in the working directory and load it
         with `_load_svg_and_prepare_params`.
      2. Build the optimizer (`_setup_optimizer`) from `OptimizationArgs`.
      3. Compute the target embedding of `image` and run
         `_run_optimization_loop`.
      4. Save the best shapes to `tmp.svg` via `pydiffvg.save_svg` and read
         the file back.
      5. Unless `skip_final_optimization` is true, pass the result through
         `optimize_svg_10k`.

    Both files are left on disk after the call.

    Returns:
        `(svg_opt, best_aes)` — the resulting SVG text and the third value
        returned by `_run_optimization_loop` (presumably the best aesthetic
        score — confirm against that helper).
    """
    pydiffvg.set_device(torch.device(device))
    args = OptimizationArgs

    path_svg_original = Path("original.svg")
    path_svg_tmp = Path("tmp.svg")

    # The loader works on a file, so the input is persisted first.
    with open(path_svg_original, "w", encoding="utf-8") as svg_file:
        svg_file.write(svg)

    canvas_width, canvas_height, shapes, shape_groups, params = (
        _load_svg_and_prepare_params(path_svg_original)
    )
    optimizer, param_groups = _setup_optimizer(args, params)

    target_image_embedding = _calculate_target_image_embedding(
        image, siglip_model, args.device, args.dtype
    )

    best_shapes, best_shape_groups, best_aes = _run_optimization_loop(
        args,
        canvas_width,
        canvas_height,
        shapes,
        shape_groups,
        optimizer,
        param_groups,
        aesthetic_evaluator_torch=aesthetic_evaluator_torch,
        similarity_mode=args.similarity_mode,
        siglip_model=siglip_model,
        target_image_embedding=target_image_embedding,
        device=args.device,
        dtype=args.dtype,
        image=image,
    )

    pydiffvg.save_svg(
        str(path_svg_tmp), canvas_width, canvas_height, best_shapes, best_shape_groups
    )

    refined_svg = path_svg_tmp.read_text()
    if skip_final_optimization:
        return refined_svg, best_aes
    return optimize_svg_10k(refined_svg), best_aes

In [22]:
# #| export
# from concurrent.futures import ProcessPoolExecutor
# from tqdm import tqdm

# def convert_all(images, max_workers=4):
#     svgs = []
#     with ProcessPoolExecutor(max_workers=max_workers) as executor:
#         for svg in tqdm(executor.map(vtracer_png_to_svg, images), total=len(images)):
#             svgs.append(svg)
#     return svgs

In [23]:
# import cairosvg


In [24]:
# from typing import List, Optional

# import torch
# from diffusers import AutoPipelineForText2Image
# from PIL import Image


# class SDXLTurboGenerator():
#     def __init__(
#         self,
#         model_path: str = "stabilityai/sdxl-turbo",
#         guidance_scale: float = 0.0,
#         num_inference_steps: int = 4,
#         device: str = "cuda",
#         seed: int = 42,
#         lora_path: Optional[str] = None,
#     ):
#         self.model_path = model_path
#         self.guidance_scale = guidance_scale
#         self.num_inference_steps = num_inference_steps
#         self.device = device
#         self.seed = seed
#         self.lora_path = lora_path
#         self.pipe = self._load_pipeline()

#     def _load_pipeline(self):
#         try:
#             pipe = AutoPipelineForText2Image.from_pretrained(
#                 self.model_path, torch_dtype=torch.float16, variant="fp16"
#             ).to(self.device)
#             return pipe
#         except Exception as e:
#             raise ValueError(f"Failed to load pipeline: {e}") from e

#     def process(
#         self,
#         prompt,
#         num_images: int = 1,
#         negative_prompt: str = "",
#         height: int = 512,
#         width: int = 512,
#         **kwargs,
#     ) -> List[Image.Image]:
#         try:
#             generator = torch.Generator("cuda")
#             if self.seed is not None:
#                 generator.manual_seed(self.seed)

#             outputs = self.pipe(
#                 prompt=prompt,
#                 num_images_per_prompt=num_images,
#                 negative_prompt=negative_prompt,
#                 height=height,
#                 width=width,
#                 num_inference_steps=self.num_inference_steps,
#                 guidance_scale=self.guidance_scale,
#                 generator=generator
#             )
#             return outputs.images
#         except Exception as e:
#             raise RuntimeError(f"Error in process(): {e}") from e


In [25]:
# generator = SDXLTurboGenerator()

In [26]:
# def gen_bitmap(prompt, seed):
#     images = generator.process(prompt, 2)
#     return images

In [27]:
# def predict(prompt: str) -> str:
#     time_start = time.time()
#     original_images = []
#     images = []

#     # Generate images from text
#     for seed in tqdm(range(4), desc="Generate image"):
#         bitmap = gen_bitmap(prompt, seed)
#         original_images.extend(bitmap)

#     # Convert bitmaps to SVGs
#     svgs = convert_all(original_images, max_workers=4)

#     for svg in svgs:
#         # Convert SVG to raster image (for scoring)
#         png_bytes = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
#         image = Image.open(io.BytesIO(png_bytes))
#         images.append(image)

#     # Select top-N images by score
#     step1_n = 2
#     _, scores = rank_images(prompt, images)
#     scores = np.array(scores)
#     idx_top_n = np.argsort(-scores)[:step1_n]
#     # idx_top_n = np.arange(step1_n)  # alternative fixed selection

#     print(f"{scores=}")
#     print(f"{idx_top_n=}")

#     # Quick optimization stage
#     OptimizationArgs.iterations = 10
#     OptimizationArgs.jpeg_iter = 10
#     list_svgs_tmp = []
#     list_aes_tmp = []
#     list_images_tmp = []

#     for idx in idx_top_n:
#         svg = svgs[idx]
#         original_image = original_images[idx]
#         svg_tmp, aes_tmp = opt_svg(svg, original_image, skip_final_optimization=True)
#         list_svgs_tmp.append(svg_tmp)
#         list_aes_tmp.append(aes_tmp)
#         list_images_tmp.append(original_image)

#     bestidx = np.argmax(list_aes_tmp)

#     print(f"aes list = {list_aes_tmp}")
#     print(f"{bestidx=}")

#     svg = list_svgs_tmp[bestidx]
#     image = list_images_tmp[bestidx]
#     image.save("raw_image.png")

#     # Final optimization stage
#     OptimizationArgs.iterations = 100
#     OptimizationArgs.jpeg_iter = 100

#     """
#     # Alternative: use best index from the initial SigLIP scoring (rank_images) instead
#     bestidx, scores = rank_images(prompt, images)
#     svg = svgs[bestidx]
#     image = original_images[bestidx]
#     """

#     print(f"SVG length before optimization: {len(svg.encode())}")
#     svg, _ = opt_svg(svg, image)
#     print(f"SVG length after optimization: {len(svg.encode())}")

#     time_end = time.time()
#     time_total = time_end - time_start
#     print(f"{time_total=}")
#     return svg


In [28]:
# svg = predict("flat color illustration, watercolor painting of a fire xbreathing dragon, inspired by Tom Whalen, vibrant palette, bold outlines, simple shapes, app icon.")

In [29]:
# print(len(svg.encode()))

In [30]:
# Re-run the refinement on a previously saved bitmap / SVG pair.
# NOTE(review): these absolute paths are machine-specific; point them at your
# own notebook directory before running.
image = Image.open("/home/anhndt/pysvgenius/notebooks/raw_image.png")
# opt_svg writes original.svg as UTF-8, so read it back with the same encoding
# instead of the platform default.
with open("/home/anhndt/pysvgenius/notebooks/original.svg", "r", encoding="utf-8") as f:
    svg = f.read()

svg, _ = opt_svg(svg, image)

  Loaded SVG file. Canvas: 384x384
  Parameters to optimize: points=121, stroke_width=121, color=122
  Computed encoding for target image. Shape: torch.Size([1, 1152])
  Starting optimization to minimize loss (Aesthetic + SIGLIP) (Batch Size: 1)...


  0%|          | 1/200 [00:00<02:10,  1.52it/s]

    Iter [   1/200], LRs: 0.300000/0.000000/0.010000, AesScore(T Avg): 0.532227, SIGLIPLoss(Raw Avg): -0.854980, SIGLIPLoss(Norm/Direct Avg): -0.854980, Loss(A/Sim/Mse/T Avg): -0.5322/-0.8550/0.0171/-53.367111, Best Loss: -53.367111 (iter 0), Time: 0.68s


  5%|▌         | 10/200 [00:02<00:46,  4.11it/s]

    Iter [  10/200], LRs: 0.298653/0.000000/0.009955, AesScore(T Avg): 0.584961, SIGLIPLoss(Raw Avg): -0.886719, SIGLIPLoss(Norm/Direct Avg): -0.886719, Loss(A/Sim/Mse/T Avg): -0.5850/-0.8867/0.0162/-66.165077, Best Loss: -66.165077 (iter 9), Time: 2.84s


 10%|█         | 21/200 [00:05<00:36,  4.92it/s]

    Iter [  20/200], LRs: 0.294032/0.000000/0.009801, AesScore(T Avg): 0.609863, SIGLIPLoss(Raw Avg): -0.895508, SIGLIPLoss(Norm/Direct Avg): -0.895508, Loss(A/Sim/Mse/T Avg): -0.6099/-0.8955/0.0155/-72.963921, Best Loss: -75.453850 (iter 18), Time: 4.89s


 16%|█▌        | 31/200 [00:06<00:31,  5.44it/s]

    Iter [  30/200], LRs: 0.286234/0.000000/0.009541, AesScore(T Avg): 0.641602, SIGLIPLoss(Raw Avg): -0.898926, SIGLIPLoss(Norm/Direct Avg): -0.898926, Loss(A/Sim/Mse/T Avg): -0.6416/-0.8989/0.0149/-79.394211, Best Loss: -79.394211 (iter 29), Time: 6.70s


 20%|██        | 41/200 [00:08<00:28,  5.60it/s]

    Iter [  40/200], LRs: 0.275450/0.000000/0.009182, AesScore(T Avg): 0.642090, SIGLIPLoss(Raw Avg): -0.899414, SIGLIPLoss(Norm/Direct Avg): -0.899414, Loss(A/Sim/Mse/T Avg): -0.6421/-0.8994/0.0144/-82.216873, Best Loss: -82.290421 (iter 38), Time: 8.53s


 26%|██▌       | 51/200 [00:10<00:27,  5.44it/s]

    Iter [  50/200], LRs: 0.261947/0.000000/0.008732, AesScore(T Avg): 0.667480, SIGLIPLoss(Raw Avg): -0.900391, SIGLIPLoss(Norm/Direct Avg): -0.900391, Loss(A/Sim/Mse/T Avg): -0.6675/-0.9004/0.0142/-85.621758, Best Loss: -85.621758 (iter 49), Time: 10.39s


 30%|███       | 61/200 [00:12<00:25,  5.37it/s]

    Iter [  60/200], LRs: 0.246057/0.000000/0.008202, AesScore(T Avg): 0.678711, SIGLIPLoss(Raw Avg): -0.902344, SIGLIPLoss(Norm/Direct Avg): -0.902344, Loss(A/Sim/Mse/T Avg): -0.6787/-0.9023/0.0140/-88.016617, Best Loss: -88.016617 (iter 59), Time: 12.23s


 36%|███▌      | 71/200 [00:14<00:22,  5.62it/s]

    Iter [  70/200], LRs: 0.228171/0.000000/0.007606, AesScore(T Avg): 0.670410, SIGLIPLoss(Raw Avg): -0.904297, SIGLIPLoss(Norm/Direct Avg): -0.904297, Loss(A/Sim/Mse/T Avg): -0.6704/-0.9043/0.0140/-87.364685, Best Loss: -88.261024 (iter 67), Time: 14.01s


 40%|████      | 81/200 [00:15<00:21,  5.59it/s]

    Iter [  80/200], LRs: 0.208729/0.000000/0.006958, AesScore(T Avg): 0.659180, SIGLIPLoss(Raw Avg): -0.896973, SIGLIPLoss(Norm/Direct Avg): -0.896973, Loss(A/Sim/Mse/T Avg): -0.6592/-0.8970/0.0137/-87.279015, Best Loss: -89.819351 (iter 74), Time: 15.76s


 46%|████▌     | 91/200 [00:17<00:20,  5.43it/s]

    Iter [  90/200], LRs: 0.188210/0.000000/0.006274, AesScore(T Avg): 0.680176, SIGLIPLoss(Raw Avg): -0.902832, SIGLIPLoss(Norm/Direct Avg): -0.902832, Loss(A/Sim/Mse/T Avg): -0.6802/-0.9028/0.0138/-89.242439, Best Loss: -91.293625 (iter 88), Time: 17.64s


 50%|█████     | 101/200 [00:19<00:18,  5.28it/s]

    Iter [ 100/200], LRs: 0.167120/0.000000/0.005571, AesScore(T Avg): 0.698730, SIGLIPLoss(Raw Avg): -0.912598, SIGLIPLoss(Norm/Direct Avg): -0.912598, Loss(A/Sim/Mse/T Avg): -0.6987/-0.9126/0.0138/-92.195816, Best Loss: -92.195816 (iter 99), Time: 19.45s


 56%|█████▌    | 111/200 [00:21<00:16,  5.45it/s]

    Iter [ 110/200], LRs: 0.145978/0.000000/0.004866, AesScore(T Avg): 0.706055, SIGLIPLoss(Raw Avg): -0.911621, SIGLIPLoss(Norm/Direct Avg): -0.911621, Loss(A/Sim/Mse/T Avg): -0.7061/-0.9116/0.0136/-93.928772, Best Loss: -93.928772 (iter 109), Time: 21.28s


 60%|██████    | 121/200 [00:23<00:14,  5.62it/s]

    Iter [ 120/200], LRs: 0.125305/0.000000/0.004177, AesScore(T Avg): 0.700195, SIGLIPLoss(Raw Avg): -0.912109, SIGLIPLoss(Norm/Direct Avg): -0.912109, Loss(A/Sim/Mse/T Avg): -0.7002/-0.9121/0.0136/-93.119934, Best Loss: -94.930260 (iter 113), Time: 23.11s


 66%|██████▌   | 131/200 [00:24<00:11,  5.78it/s]

    Iter [ 130/200], LRs: 0.105608/0.000000/0.003520, AesScore(T Avg): 0.700195, SIGLIPLoss(Raw Avg): -0.912109, SIGLIPLoss(Norm/Direct Avg): -0.912109, Loss(A/Sim/Mse/T Avg): -0.7002/-0.9121/0.0134/-94.393501, Best Loss: -95.893593 (iter 128), Time: 24.85s


 70%|███████   | 141/200 [00:26<00:10,  5.39it/s]

    Iter [ 140/200], LRs: 0.087374/0.000000/0.002912, AesScore(T Avg): 0.684570, SIGLIPLoss(Raw Avg): -0.906738, SIGLIPLoss(Norm/Direct Avg): -0.906738, Loss(A/Sim/Mse/T Avg): -0.6846/-0.9067/0.0131/-93.482246, Best Loss: -95.893593 (iter 128), Time: 26.70s


 76%|███████▌  | 151/200 [00:28<00:08,  5.55it/s]

    Iter [ 150/200], LRs: 0.071052/0.000000/0.002368, AesScore(T Avg): 0.708008, SIGLIPLoss(Raw Avg): -0.914062, SIGLIPLoss(Norm/Direct Avg): -0.914062, Loss(A/Sim/Mse/T Avg): -0.7080/-0.9141/0.0133/-95.771240, Best Loss: -96.040894 (iter 143), Time: 28.48s


 80%|████████  | 161/200 [00:30<00:06,  5.75it/s]

    Iter [ 160/200], LRs: 0.057043/0.000000/0.001901, AesScore(T Avg): 0.683105, SIGLIPLoss(Raw Avg): -0.909180, SIGLIPLoss(Norm/Direct Avg): -0.909180, Loss(A/Sim/Mse/T Avg): -0.6831/-0.9092/0.0133/-92.623894, Best Loss: -97.401848 (iter 150), Time: 30.25s


 86%|████████▌ | 171/200 [00:32<00:05,  5.70it/s]

    Iter [ 170/200], LRs: 0.045692/0.000000/0.001523, AesScore(T Avg): 0.692383, SIGLIPLoss(Raw Avg): -0.900879, SIGLIPLoss(Norm/Direct Avg): -0.900879, Loss(A/Sim/Mse/T Avg): -0.6924/-0.9009/0.0131/-93.870361, Best Loss: -97.401848 (iter 150), Time: 32.08s


 90%|█████████ | 180/200 [00:33<00:03,  5.56it/s]

    Iter [ 180/200], LRs: 0.037278/0.000000/0.001243, AesScore(T Avg): 0.716309, SIGLIPLoss(Raw Avg): -0.914551, SIGLIPLoss(Norm/Direct Avg): -0.914551, Loss(A/Sim/Mse/T Avg): -0.7163/-0.9146/0.0135/-95.592072, Best Loss: -97.401848 (iter 150), Time: 33.91s


 96%|█████████▌| 191/200 [00:35<00:01,  5.68it/s]

    Iter [ 190/200], LRs: 0.032010/0.000000/0.001067, AesScore(T Avg): 0.699707, SIGLIPLoss(Raw Avg): -0.910645, SIGLIPLoss(Norm/Direct Avg): -0.910645, Loss(A/Sim/Mse/T Avg): -0.6997/-0.9106/0.0131/-95.382980, Best Loss: -97.648201 (iter 183), Time: 35.70s


100%|██████████| 200/200 [00:37<00:00,  5.35it/s]


    Iter [ 200/200], LRs: 0.030017/0.000000/0.001001, AesScore(T Avg): 0.706055, SIGLIPLoss(Raw Avg): -0.913574, SIGLIPLoss(Norm/Direct Avg): -0.913574, Loss(A/Sim/Mse/T Avg): -0.7061/-0.9136/0.0133/-95.327522, Best Loss: -97.873169 (iter 194), Time: 37.44s
  Optimization completed! Best Loss: -97.873169 at iteration 194
Resized to 281
