In [1]:
from abc import ABC, abstractmethod


class ISVGRanker(ABC):
    @abstractmethod
    def process(svg_list: list[str], prompt: str = None):
        pass


In [2]:
import io

import cairosvg
from PIL import Image

def svg_to_png(svg_code: str, size: tuple = (384, 384)) -> Image.Image:
    """
    Converts an SVG string to a PNG image using CairoSVG.

    If the SVG does not contain the text `viewBox`, one is added to the first
    `<svg` tag using the provided size.

    Parameters
    ----------
    svg_code : str
         The SVG string to convert.
    size : tuple[int, int], default=(384, 384)
         The desired size of the output PNG image (width, height).

    Returns
    -------
    PIL.Image.Image
         The rendered image, converted to RGB and resized to `size`.
    """
    # Ensure the root element has a viewBox. Only the first "<svg" is touched so
    # that nested <svg> elements are not given a viewBox as well.
    if "viewBox" not in svg_code:
        svg_code = svg_code.replace(
            "<svg", f'<svg viewBox="0 0 {size[0]} {size[1]}"', 1
        )

    # Rasterize with CairoSVG, then normalize mode and dimensions with PIL
    png_data = cairosvg.svg2png(bytestring=svg_code.encode("utf-8"))
    return Image.open(io.BytesIO(png_data)).convert("RGB").resize(size)

In [3]:
import io
from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
import cv2
import cairosvg
import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity as ssim

class ImageProcessor:
    """Chainable set of image filters applied to a PIL image.

    Each ``apply_*`` method replaces ``self.image`` with the filtered result and
    returns ``self`` so that calls can be chained. ``original_image`` holds an
    untouched copy used by ``reset`` and ``visualize_comparison``.
    """

    def __init__(self, image: Image.Image, seed=None):
        """Initialize with a PIL Image object.

        Args:
             image: The PIL image to process; a copy is kept as the original.
             seed: Optional seed. When given, a dedicated
                  ``np.random.RandomState(seed)`` drives the random crop;
                  otherwise the global ``np.random`` module is used.
        """
        self.image = image
        self.original_image = self.image.copy()
        if seed is not None:
            self.rng = np.random.RandomState(seed)
        else:
            self.rng = np.random

    def reset(self):
        """Restore the working image from the stored original and return self."""
        self.image = self.original_image.copy()
        return self

    def visualize_comparison(
        self,
        original_name="Original",
        processed_name="Processed",
        figsize=(10, 5),
        show=True,
    ):
        """Display original and processed images side by side.

        Args:
             original_name: Title for the left (original) panel.
             processed_name: Title for the right (processed) panel.
             figsize: Figure size passed to ``plt.subplots``.
             show: When true, call ``plt.show()`` before returning.

        Returns:
             The matplotlib figure containing both panels.
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
        ax1.imshow(np.asarray(self.original_image))
        ax1.set_title(original_name)
        ax1.axis("off")

        ax2.imshow(np.asarray(self.image))
        ax2.set_title(processed_name)
        ax2.axis("off")

        title = f"{original_name} vs {processed_name}"
        fig.suptitle(title)
        fig.tight_layout()
        if show:
            plt.show()
        return fig

    def apply_median_filter(self, size=3):
        """Apply median filter to remove outlier pixel values.

        Args:
             size: Size of the median filter window.
        """
        self.image = self.image.filter(ImageFilter.MedianFilter(size=size))
        return self

    def apply_bilateral_filter(self, d=9, sigma_color=75, sigma_space=75):
        """Apply bilateral filter to smooth while preserving edges.

        Args:
             d: Diameter of each pixel neighborhood
             sigma_color: Filter sigma in the color space
             sigma_space: Filter sigma in the coordinate space
        """
        # Convert PIL Image to numpy array for OpenCV
        img_array = np.asarray(self.image)

        # Apply bilateral filter
        filtered = cv2.bilateralFilter(img_array, d, sigma_color, sigma_space)

        # Convert back to PIL Image
        self.image = Image.fromarray(filtered)
        return self

    def apply_fft_low_pass(self, cutoff_frequency=0.5):
        """Apply low-pass filter in the frequency domain using FFT.

        The first three channels of the image are filtered independently with
        the same circular mask centred on the shifted spectrum.

        Args:
             cutoff_frequency: Normalized cutoff frequency (0-1).
                  Lower values remove more high frequencies.
        """
        # Convert to numpy array, ensuring float32 for FFT
        img_array = np.array(self.image, dtype=np.float32)

        # The mask depends only on the image size and the cutoff, so it is
        # built once here instead of once per channel.
        rows, cols = img_array.shape[:2]
        crow, ccol = rows // 2, cols // 2
        r = int(min(crow, ccol) * cutoff_frequency)
        x, y = np.ogrid[:rows, :cols]
        mask = np.zeros((rows, cols), np.float32)
        mask[(x - crow) ** 2 + (y - ccol) ** 2 <= r * r] = 1

        # Process each color channel separately
        result = np.zeros_like(img_array)
        for i in range(3):  # For RGB channels
            # Forward FFT with the zero frequency moved to the centre
            fshift = np.fft.fftshift(np.fft.fft2(img_array[:, :, i]))

            # Apply mask and inverse FFT; only the real part is kept
            img_back = np.fft.ifft2(np.fft.ifftshift(fshift * mask))
            result[:, :, i] = np.real(img_back)

        # Clip to 0-255 range and convert to uint8 after processing all channels
        result = np.clip(result, 0, 255).astype(np.uint8)

        # Convert back to PIL Image
        self.image = Image.fromarray(result)
        return self

    def apply_jpeg_compression(self, quality=85):
        """Apply JPEG compression by encoding to an in-memory buffer and reopening.

        Args:
             quality: JPEG quality (0-95). Lower values increase compression.
        """
        buffer = io.BytesIO()
        self.image.save(buffer, format="JPEG", quality=quality)
        buffer.seek(0)
        self.image = Image.open(buffer)
        return self

    def apply_random_crop_resize(self, crop_percent=0.05):
        """Randomly crop and resize back to original dimensions.

        Each of the four edges is trimmed by an independent random number of
        pixels between 0 and ``crop_percent`` of the corresponding dimension.

        Args:
             crop_percent: Percentage of image to crop (0-0.4).
        """
        width, height = self.image.size
        crop_pixels_w = int(width * crop_percent)
        crop_pixels_h = int(height * crop_percent)

        # randint's upper bound is exclusive, hence the +1
        left = self.rng.randint(0, crop_pixels_w + 1)
        top = self.rng.randint(0, crop_pixels_h + 1)
        right = width - self.rng.randint(0, crop_pixels_w + 1)
        bottom = height - self.rng.randint(0, crop_pixels_h + 1)

        self.image = self.image.crop((left, top, right, bottom))
        self.image = self.image.resize((width, height), Image.BILINEAR)
        return self

    def apply(self):
        """Apply an ensemble of defenses.

        Order: random crop/resize, JPEG, median filter, FFT low-pass,
        bilateral filter, and a final JPEG pass.
        """
        return (
            self.apply_random_crop_resize(crop_percent=0.03)
            .apply_jpeg_compression(quality=95)
            .apply_median_filter(size=9)
            .apply_fft_low_pass(cutoff_frequency=0.5)
            .apply_bilateral_filter(d=5, sigma_color=75, sigma_space=75)
            .apply_jpeg_compression(quality=92)
        )


In [4]:
import string
from more_itertools import chunked
import statistics
import torch
import math
from PIL import Image
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    PaliGemmaForConditionalGeneration,
)

class PaliGemmaRanker(ISVGRanker):
    """Evaluates images based on their similarity to a given text description using multiple choice questions."""

    def __init__(self):
        """Load the quantized PaliGemma model, its processor, and the question set.

        The checkpoint is loaded in 4-bit NF4 with double quantization and
        float16 compute, then moved to CUDA when available, otherwise CPU.
        """
        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        self.letters = string.ascii_uppercase
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model_path = "google/paligemma2-10b-mix-448"
        self.processor = AutoProcessor.from_pretrained(
            self.model_path, use_fast=True)
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(
            self.model_path,
            low_cpu_mem_usage=True,
            quantization_config=self.quantization_config,
        ).to(self.device)
        # Question templates ({} is filled with the prompt), their answer
        # choices, and the choice that counts as correct for each question.
        self.question = ['Does image portray "SVG illustration of {}"?',
                         "Do the objects in the image match the following description: {}?"]
        self.choices = [["yes", "no"], ["yes", "no"]]
        self.answers = ['yes', 'yes']

    def score(self, questions, choices, answers, image, n=4):
        """Return the mean correct-answer probability over all questions.

        Questions, choices and answers are evaluated against ``image`` in
        batches of at most ``n``; the three lists must have equal length.
        """
        scores = []
        batches = (chunked(qs, n) for qs in [questions, choices, answers])
        for question_batch, choice_batch, answer_batch in zip(*batches, strict=True):
            scores.extend(
                self.score_batch(
                    image,
                    question_batch,
                    choice_batch,
                    answer_batch,
                )
            )
        return statistics.mean(scores)

    def score_batch(
        self,
        image: Image.Image,
        questions: list[str],
        choices_list: list[list[str]],
        answers: list[str],
    ) -> list[float]:
        """Evaluates the image based on multiple choice questions and answers.

        Parameters
        ----------
        image : PIL.Image.Image
            The image to evaluate.
        questions : list[str]
            List of questions about the image.
        choices_list : list[list[str]]
            List of lists of possible answer choices, corresponding to each question.
        answers : list[str]
            List of correct answers from the choices, corresponding to each question.

        Returns
        -------
        list[float]
            List of scores (values between 0 and 1) representing the probability of the correct answer for each question.
        """
        prompts = [
            self.format_prompt(question, choices)
            for question, choices in zip(questions, choices_list, strict=True)
        ]
        batched_choice_probabilities = self.get_choice_probability(
            image, prompts, choices_list
        )

        # An answer that is not among the choices scores 0.0
        return [
            batched_choice_probabilities[i].get(answers[i], 0.0)
            for i in range(len(questions))
        ]

    def format_prompt(self, question: str, choices: list[str]) -> str:
        """Build the model prompt: the question followed by lettered choices."""
        prompt = f"<image>answer en Question: {question}\nChoices:\n"
        for i, choice in enumerate(choices):
            prompt += f"{self.letters[i]}. {choice}\n"
        return prompt

    def _letter_token_ids(self, num_choices: int) -> list[list[int]]:
        """Return, for each of the first ``num_choices`` letters, its candidate token ids.

        Each letter contributes the first token of the bare letter and the first
        token of the letter preceded by a space.
        """
        encode = self.processor.tokenizer.encode
        ids_per_letter = []
        for letter in self.letters[:num_choices]:
            candidates = (
                encode(letter, add_special_tokens=False)[0],
                encode(" " + letter, add_special_tokens=False)[0],
            )
            ids_per_letter.append(
                [token for token in candidates if isinstance(token, int)]
            )
        return ids_per_letter

    def mask_choices(self, logits, choices_list):
        """Masks logits for the first token of each choice letter for each question in the batch.

        Every position is set to ``-inf`` except the token ids that correspond
        to the choice letters available for that row's question.
        """
        batch_size = logits.shape[0]
        masked_logits = torch.full_like(logits, float("-inf"))

        # Tokenize the letters once for the whole batch rather than per row
        max_choices = max((len(choices) for choices in choices_list), default=0)
        letter_ids = self._letter_token_ids(max_choices)

        for batch_idx in range(batch_size):
            for i in range(len(choices_list[batch_idx])):
                for token_id in letter_ids[i]:
                    masked_logits[batch_idx, token_id] = logits[batch_idx, token_id]

        return masked_logits

    def get_choice_probability(self, image, prompts, choices_list) -> list[dict]:
        """Return, per prompt, a dict mapping each choice to its probability.

        The same image is paired with every prompt in one forward pass. The
        logits at the last sequence position are restricted to the choice-letter
        tokens, soft-maxed, summed per choice and renormalized.
        """
        inputs = self.processor(
            images=[image] * len(prompts),
            text=prompts,
            return_tensors="pt",
            padding="longest",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            # Logits for the last (predicted) token
            # NOTE(review): with padding="longest" this presumably relies on the
            # processor padding on the left - confirm the tokenizer's padding side.
            logits = outputs.logits[:, -1, :]
            masked_logits = self.mask_choices(logits, choices_list)
            probabilities = torch.softmax(masked_logits, dim=-1)

        # Tokenize the letters once for the whole batch rather than per row
        max_choices = max((len(choices) for choices in choices_list), default=0)
        letter_ids = self._letter_token_ids(max_choices)

        batched_choice_probabilities = []
        for batch_idx in range(len(prompts)):
            choice_probabilities = {}
            for i, choice in enumerate(choices_list[batch_idx]):
                prob = 0.0
                for token_id in letter_ids[i]:
                    prob += probabilities[batch_idx, token_id].item()
                choice_probabilities[choice] = prob

            # Renormalize probabilities for each question
            total_prob = sum(choice_probabilities.values())
            if total_prob > 0:
                renormalized_probabilities = {
                    choice: prob / total_prob
                    for choice, prob in choice_probabilities.items()
                }
            else:
                renormalized_probabilities = (
                    choice_probabilities  # Avoid division by zero if total_prob is 0
                )
            batched_choice_probabilities.append(renormalized_probabilities)

        return batched_choice_probabilities

    def ocr(self, image, free_chars=4, use_num_char=False):
        """Score an image by how much text the model reads from it.

        Returns 1.0 while the decoded text has at most ``free_chars``
        characters and decays exponentially beyond that. When
        ``use_num_char`` is true, a ``(score, decoded_text)`` tuple is returned
        instead of the bare score.
        """
        inputs = (
            self.processor(
                text="<image>ocr\n",
                images=image,
                return_tensors="pt",
            )
            .to(torch.float16)
            .to(self.model.device)
        )
        input_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs, max_new_tokens=32, do_sample=False)
            # Keep only the newly generated tokens of the single sequence
            outputs = outputs[0][input_len:]
            decoded = self.processor.decode(outputs, skip_special_tokens=True)

        num_char = len(decoded)

        # Exponentially decreasing towards 0.0 if more than free_chars detected
        text_score = min(1.0, math.exp(-num_char + free_chars))
        return (text_score, decoded) if use_num_char else text_score

    def process(self, svg_list: list[str], prompt: str = None):
        """Score every SVG in ``svg_list`` against ``prompt``.

        Each SVG is rendered to PNG, passed through the ``ImageProcessor``
        defense pipeline with a fixed seed, and scored with the yes/no questions.

        Parameters
        ----------
        svg_list : list[str]
            SVG documents given as source strings.
        prompt : str
            Description to compare the images with. Required by this ranker.

        Returns
        -------
        list[dict]
            One ``{"svg": ..., "score": ...}`` entry per input, in input order.

        Raises
        ------
        ValueError
            If ``prompt`` is None; the questions would otherwise be built
            around the literal text "None".
        """
        if prompt is None:
            raise ValueError(
                "PaliGemmaRanker.process requires a text prompt to score against"
            )

        question = [template.format(prompt) for template in self.question]
        choices = self.choices
        answers = self.answers

        results = []
        for svg in svg_list:
            image_processor = ImageProcessor(
                image=svg_to_png(svg), seed=42).apply()
            image = image_processor.image.copy()
            score = self.score(
                questions=question,
                choices=choices,
                answers=answers,
                image=image
            )
            results.append({
                "svg": svg,
                "score": score
            })
        return results


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Instantiate the ranker once: the constructor loads the quantized PaliGemma
# checkpoint and its processor, so this cell is slow to run.
paligemma_ranker = PaliGemmaRanker()


Loading checkpoint shards: 100%|██████████| 4/4 [00:56<00:00, 14.06s/it]


In [6]:
import os

def load_svg_files_from_folder(folder_path: str) -> list[str]:
    """Read every ``.svg`` file directly inside ``folder_path``.

    Parameters
    ----------
    folder_path : str
        Directory to scan (not recursive).

    Returns
    -------
    list[str]
        File contents decoded as UTF-8, ordered by file name so that the
        result (and any index into it) is reproducible across runs.
    """
    svg_contents = []

    # os.listdir returns entries in arbitrary order; sort for determinism
    for filename in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, filename)
        # Match the extension case-insensitively and skip directories
        if filename.lower().endswith(".svg") and os.path.isfile(file_path):
            with open(file_path, "r", encoding="utf-8") as f:
                svg_contents.append(f.read())

    return svg_contents


In [7]:
# Read every .svg file from the test folder and set the text description
# that the SVGs will be scored against.
svg_folder = "../data/test"
svg_list = load_svg_files_from_folder(svg_folder)
prompt = "a purple forest at dusk"

In [8]:
# Score only the first loaded SVG. `process` returns one {"svg", "score"} dict
# per input, so `score` is a one-element list here rather than a number.
score = paligemma_ranker.process([svg_list[0]], prompt)

In [9]:
# Display the result list; each entry echoes the full SVG source with its score.
score 


[{'svg': '<svg xmlns="http://www.w3.org/2000/svg" width="768" height="768" viewBox="0 0 384 384"><rect width="384" height="384" fill="#92769a"/><polygon points="31.0,0.0 1.0,383.0 154.0,383.0 167.0,190.0 119.0,70.0 99.0,288.0 107.0,45.0 52.0,209.0" fill="#b297b9"/><polygon points="252.0,6.0 247.0,232.0 190.0,233.0 184.0,105.0 247.0,366.0 383.0,383.0 350.0,0.0 318.0,319.0 270.0,317.0" fill="#b297b9"/><polygon points="99.0,0.0 187.0,272.0 195.0,0.0 171.0,70.0 152.0,15.0 169.0,154.0" fill="#050f1f"/><polygon points="252.0,0.0 197.0,0.0 184.0,104.0 232.0,66.0" fill="#f7dbe7"/><polygon points="256.0,0.0 272.0,316.0 318.0,317.0 332.0,4.0 299.0,195.0 318.0,0.0 300.0,60.0 278.0,5.0 286.0,188.0" fill="#4b0d55"/><polygon points="108.0,45.0 101.0,287.0 114.0,287.0" fill="#4b0d55"/><polygon points="205.0,254.0 160.0,320.0 156.0,383.0 211.0,382.0 224.0,286.0 213.0,383.0 229.0,363.0 224.0,383.0 282.0,383.0 238.0,378.0 246.0,320.0" fill="#1f2034"/><polygon points="351.0,0.0 383.0,341.0 383.0,0.0 371.