In [None]:
#| default_exp core

In [15]:
#| export

import concurrent
import io
import logging
import re
import re2

import cairosvg
import kagglehub
import torch
from lxml import etree
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

In [None]:
# Import the `metric/svg-constraints` Kaggle package and bind it to `svg_constraints`.
svg_constraints = kagglehub.package_import('metric/svg-constraints')

## Fine-tune the Model with SFT

> This is stage 1. The dataset was obtained from Hugging Face.

In [2]:
from huggingface_hub import login
# Authenticate with the Hugging Face Hub; this call produces the login widget shown as the cell output.
login()
from datasets import load_dataset


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [7]:
# Hand-picked dataset from the Hugging Face Hub; the model will be fine-tuned on it.

ds = load_dataset("thesantatitan/deepseek-svg-dataset")

In [9]:
# Inspect the first training example (a dict with `prompt` and `completion` keys).
ds['train'][0]

{'prompt': "Generate svg code for an image that looks like: a black background with a white and red flower. Don't use markdown just give svg code\n",
 'completion': '\n<reasoning>\n\nAlright, so the user is asking for an SVG code that creates an image with a black background and a white and red flower. They specified not to use markdown and just provide the SVG code. \n\nFirst, I need to structure the SVG properly. The basic structure includes the <svg> tag with width and height attributes. I\'ll set it to 400x400 for a square canvas.\n\nNext, the background. I\'ll use a <rect> element covering the entire canvas with a black fill.\n\nNow, for the flower. I\'ll need multiple layers: a white petal layer, a red center, and stems. \n\nThe petals can be created using circles. I\'ll position them around the center point (200,200) at different angles using transform attributes. Each petal will be a circle with a white fill and a subtle stroke for texture.\n\nThe center of the flower will be a

In [19]:
# Give each split of `ds` its own name, then echo both as the cell output.
train_dataset = ds['train']
test_dataset = ds['test']
train_dataset, test_dataset

(Dataset({
     features: ['prompt', 'completion'],
     num_rows: 4750
 }),
 Dataset({
     features: ['prompt', 'completion'],
     num_rows: 250
 }))

In [22]:
def preprocess_dataset(example):
    """
    Clean the ``completion`` field of a single dataset example.

    Every ``<reasoning>...</reasoning>`` section (plus one trailing newline, if
    present) and every ``<generated_svg>`` / ``</generated_svg>`` tag is deleted,
    then surrounding whitespace is stripped. The example is updated in place and
    returned.
    """
    # (pattern, flags) pairs, applied in order; DOTALL lets a reasoning
    # section span several lines.
    removals = (
        (r"<reasoning>.*?</reasoning>\n?", re.DOTALL),
        (r"</?generated_svg>", 0),
    )

    cleaned = example['completion']
    for pattern, flags in removals:
        cleaned = re.sub(pattern, "", cleaned, flags=flags)

    example['completion'] = cleaned.strip()
    return example

# Run `preprocess_dataset` over every example of the train split, then the test split.
clean_train_dataset, clean_test_dataset = (
    split.map(preprocess_dataset) for split in (train_dataset, test_dataset)
)

Map:   0%|          | 0/4750 [00:00<?, ? examples/s]

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

In [25]:
# Echo the cleaned splits as the cell output.
clean_train_dataset, clean_test_dataset

(Dataset({
     features: ['prompt', 'completion'],
     num_rows: 4750
 }),
 Dataset({
     features: ['prompt', 'completion'],
     num_rows: 250
 }))

# Main Model Class

> This is the main model class used for prediction.

In [None]:
#| export
import concurrent
import io
import logging
import re
import re2

import cairosvg
import kagglehub
import torch
from lxml import etree
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Kaggle package that provides `SVGConstraints`, which `Model.__init__` instantiates.
svg_constraints = kagglehub.package_import('metric/svg-constraints')

# Device the tokenized prompt is moved to in `Model.predict`: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Model:
    """Text-to-SVG generator backed by a 4-bit quantized Gemma 2 checkpoint.

    ``predict`` fills the prompt template with a description, samples a
    completion, takes the last ``<svg>...</svg>`` span of the decoded text,
    sanitizes it with ``enforce_constraints`` and checks that cairosvg can
    rasterize it. Any failure or timeout yields ``default_svg`` instead.
    """

    # Loose sanity check for a path ``d`` attribute: optional leading
    # whitespace, then one or more blocks made of a command letter followed by
    # its numeric arguments (digits, sign, decimal point, exponent, comma and
    # whitespace separators). Kept as a plain string so that defining the
    # class does not itself require re2.
    _PATH_DATA_PATTERN = r'^\s*(?:[MmLlHhVvCcSsQqTtAaZz][0-9eE+,\.\-\s]*)+$'

    def __init__(self):
        """Download and load the tokenizer and quantized model, and set defaults."""
        # Quantization configuration: 4-bit NF4 weights with double
        # quantization, float16 compute.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        self.model_path = kagglehub.model_download('google/gemma-2/Transformers/gemma-2-9b-it/2')
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map="auto",
            quantization_config=quantization_config,
        )
        # The template ends with an already-opened <svg> tag so that the model
        # continues the markup; `{}` is replaced by the description.
        self.prompt_template = """Generate SVG code to visually represent the following text description, while respecting the given constraints.
<constraints>
* **Allowed Elements:** `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs`
* **Allowed Attributes:** `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity`
</constraints>

<example>
<description>"A red circle with a blue square inside"</description>
```svg
<svg viewBox="0 0 256 256" width="256" height="256">
  <circle cx="50" cy="50" r="40" fill="red"/>
  <rect x="30" y="30" width="40" height="40" fill="blue"/>
</svg>
```
</example>


Please ensure that the generated SVG code is well-formed, valid, and strictly adheres to these constraints. Focus on a clear and concise representation of the input description within the given limitations. Always give the complete SVG code with nothing omitted. Never use an ellipsis.

<description>"{}"</description>
```svg
<svg viewBox="0 0 256 256" width="256" height="256">
"""
        # Fallback returned whenever generation, sanitizing or rendering fails.
        self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
        self.constraints = svg_constraints.SVGConstraints()
        # Wall-clock budget (seconds) for a single `predict` call.
        self.timeout_seconds = 90

    # You could try increasing `max_new_tokens`
    def predict(self, description: str, max_new_tokens=512) -> str:
        """Generate SVG markup for a text description.

        Parameters
        ----------
        description : str
            Text inserted into the prompt template.
        max_new_tokens : int, optional
            Upper bound on the number of tokens sampled by the model.

        Returns
        -------
        str
            The sanitized SVG, or ``self.default_svg`` when no SVG can be
            extracted, sanitizing/rendering raises, or the call exceeds
            ``self.timeout_seconds``.
        """
        # `import concurrent` alone does not load the `futures` submodule;
        # import it explicitly so the executor below is always available.
        import concurrent.futures

        def generate_svg():
            try:
                prompt = self.prompt_template.format(description)
                inputs = self.tokenizer(text=prompt, return_tensors="pt").to(DEVICE)

                with torch.no_grad():
                    output = self.model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.9,
                        # Let generation stop by itself once the budget is
                        # spent, so an abandoned worker does not keep running.
                        max_time=self.timeout_seconds,
                    )

                output_decoded = self.tokenizer.decode(output[0], skip_special_tokens=True)
                logging.debug('Output decoded from model: %s', output_decoded)

                # Presumably the decoded text still contains the prompt (and
                # its example SVG), hence the last match is the one kept.
                matches = re.findall(r"<svg.*?</svg>", output_decoded, re.DOTALL | re.IGNORECASE)
                if not matches:
                    return self.default_svg
                svg = matches[-1]

                logging.debug('Unprocessed SVG: %s', svg)
                svg = self.enforce_constraints(svg)
                logging.debug('Processed SVG: %s', svg)
                # Ensure the generated code can be converted by cairosvg
                cairosvg.svg2png(bytestring=svg.encode('utf-8'))
                return svg
            except Exception as e:
                logging.error('Exception during SVG generation: %s', e)
                return self.default_svg

        # Run generation in a worker thread to enforce the time budget. The
        # executor is deliberately NOT used as a context manager: leaving a
        # `with` block calls shutdown(wait=True), which would block until the
        # worker finishes and make the timeout ineffective.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        future = executor.submit(generate_svg)
        try:
            return future.result(timeout=self.timeout_seconds)
        except concurrent.futures.TimeoutError:
            logging.warning("Prediction timed out after %s seconds.", self.timeout_seconds)
            return self.default_svg
        except Exception as e:
            logging.error("An unexpected error occurred: %s", e)
            return self.default_svg
        finally:
            # Return immediately; a still-running worker cannot be interrupted
            # and finishes in the background.
            executor.shutdown(wait=False)

    def enforce_constraints(self, svg_string: str) -> str:
        """Enforces constraints on an SVG string, removing disallowed elements
        and attributes.

        Elements whose tag is not a key of ``self.constraints.allowed_elements``
        are dropped. Attributes are kept only when listed for the element's tag
        or under the ``'common'`` key. ``href`` attributes that do not start
        with ``#`` are dropped, and so are ``path`` elements whose ``d``
        attribute is missing, empty, or fails the format check.

        Parameters
        ----------
        svg_string : str
            The SVG string to process.

        Returns
        -------
        str
            The processed SVG string, or the default SVG if constraints
            cannot be satisfied.
        """
        logging.info('Sanitizing SVG...')

        try:
            parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
            root = etree.fromstring(svg_string, parser=parser)
        except etree.ParseError as e:
            logging.error('SVG Parse Error: %s. Returning default SVG.', e)
            return self.default_svg

        allowed = self.constraints.allowed_elements
        # Compiled once per call instead of once per <path> element.
        path_regex = re2.compile(self._PATH_DATA_PATTERN)

        # Removal is deferred so the tree is not modified while it is walked.
        elements_to_remove = []
        for element in root.iter():
            tag_name = etree.QName(element.tag).localname

            # Remove disallowed elements
            if tag_name not in allowed:
                elements_to_remove.append(element)
                continue  # Skip attribute checks for removed elements

            # Remove disallowed attributes (collected first, deleted afterwards)
            attrs_to_remove = []
            for attr in element.attrib:
                attr_name = etree.QName(attr).localname
                if attr_name not in allowed[tag_name] and attr_name not in allowed['common']:
                    attrs_to_remove.append(attr)

            for attr in attrs_to_remove:
                logging.debug(
                    'Attribute "%s" for element "%s" not allowed. Removing.',
                    attr,
                    tag_name,
                )
                del element.attrib[attr]

            # Check and remove invalid href attributes; iterate over a
            # snapshot because entries are deleted inside the loop.
            for attr, value in list(element.attrib.items()):
                if etree.QName(attr).localname == 'href' and not value.startswith('#'):
                    logging.debug(
                        'Removing invalid href attribute in element "%s".', tag_name
                    )
                    del element.attrib[attr]

            # Validate path elements to help ensure SVG conversion
            if tag_name == 'path':
                d_attribute = element.get('d')
                if not d_attribute:
                    logging.warning('Path element is missing "d" attribute. Removing path.')
                    elements_to_remove.append(element)
                    continue  # Skip further checks for this removed element

                # The pattern accepts any number of command blocks, so paths
                # such as "M10 10 L20 20 Z" pass the check.
                if not path_regex.match(d_attribute):
                    logging.warning(
                        'Path element has malformed "d" attribute format. Removing path.'
                    )
                    elements_to_remove.append(element)
                    continue
                logging.debug('Path element "d" attribute validated (regex check).')

        # Remove elements marked for removal (the root has no parent and stays)
        for element in elements_to_remove:
            parent = element.getparent()
            if parent is not None:
                parent.remove(element)
                logging.debug('Removed element: %s', element.tag)

        try:
            return etree.tostring(root, encoding='unicode')
        except ValueError as e:
            logging.error(
                'SVG could not be sanitized to meet constraints: %s', e
            )
            return self.default_svg

In [None]:
import kagglehub
import polars as pl

# Fetch `train.csv` from the `drawing-with-llms` competition and read it with polars.
train_path = kagglehub.competition_download('drawing-with-llms', 'train.csv')
train = pl.read_csv(train_path)

# Preview the first rows.
train.head()

In [None]:
from IPython.display import SVG

# Smoke test: build the model and generate an SVG for one sample description.
model = Model()
svg = model.predict('Moon surrounded by stars')

# Print the raw markup, then display it through IPython's SVG wrapper.
print(svg)
display(SVG(svg))