In [1]:
#| default_exp core
# When exporting this notebook to a Python module (a .py file), put all the functions and classes from this notebook into a file called core.py.

nbdev is a tool developed by FastAI that lets you write, document, test, and build Python libraries directly inside Jupyter Notebooks.
It allows you to combine code, explanations, and examples in one place, and then automatically export the code into clean .py files to create real Python packages.

With nbdev, you use special notebook commands (like #| export and #| default_exp) to control which parts of your notebook become part of the library.
It also helps you generate documentation websites, manage version control, and write tests — all from notebooks.

In short, nbdev makes it possible to build professional Python projects while staying entirely inside Jupyter, blending coding, learning, and documenting into one powerful workflow.

In [5]:
# Mark this cell for export (used in some notebook-to-script tools)
#| export

# Import the concurrent module for parallel processing
import concurrent

# Import io module for working with input/output streams (like files in memory)
import io

# Import logging module to create logs for debugging or tracking
import logging

# Import regular expressions module for pattern matching with text
import re

# Import re2 (a faster and safer regex library optimized for performance)
import re2

# Import cairosvg to convert SVG files into other formats like PNG or PDF
import cairosvg

# Import kagglehub to easily download and import packages/models from Kaggle Hub
import kagglehub

# Import torch for working with PyTorch (deep learning framework)
import torch

# Import etree from lxml to parse and manipulate XML/SVG files
from lxml import etree

# Import Hugging Face transformers' modules for loading LLM models and tokenizers
# Import AutoTokenizer: Automatically downloads and loads the correct tokenizer for a given LLM model
# Import AutoModelForCausalLM: Loads a pre-trained Language Model designed for text generation (causal language modeling)
# Import BitsAndBytesConfig: Allows setting configurations for loading models with lower precision (e.g., 8-bit, 4-bit) to save memory and speed up inference
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Download and import the svg-constraints package from Kaggle Hub (contains rules for validating SVGs)
svg_constraints = kagglehub.package_import('metric/svg-constraints')

# Set the computation device to GPU if available, otherwise fallback to CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [6]:
# Define a Model class to load and prepare the LLM
class Model:
    def __init__(self):
        # Set up Quantization Configuration to load the model efficiently
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,  # Load model weights using 4 bits instead of full precision
            bnb_4bit_quant_type="nf4",  # Use Normalized Float 4 (nf4) quantization for better accuracy
            bnb_4bit_use_double_quant=True,  # Apply double quantization to save even more memory
            bnb_4bit_compute_dtype=torch.float16,  # Perform computations in 16-bit floats for faster operations
        )

        # Download the pre-trained Gemma-2 9B model from Kaggle Hub
        self.model_path = kagglehub.model_download('google/gemma-2/Transformers/gemma-2-9b-it/2')

        # Load the tokenizer associated with the downloaded model
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

        # Load the actual pre-trained model for Causal Language Modeling (text generation)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            device_map="auto",  # Automatically assign model to GPU if available, else CPU
            quantization_config=quantization_config,  # Use the defined quantization settings while loading
        )
        self.prompt_template = """Generate SVG code to visually represent the following text description, while respecting the given constraints.
<constraints>
* **Allowed Elements:** `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs`
* **Allowed Attributes:** `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity`
</constraints>

<example>
<description>"A red circle with a blue square inside"</description>
```svg
<svg viewBox="0 0 256 256" width="256" height="256">
  <circle cx="50" cy="50" r="40" fill="red"/>
  <rect x="30" y="30" width="40" height="40" fill="blue"/>
</svg>
```

</example>

Please ensure that the generated SVG code is well-formed, valid, and strictly adheres to these constraints. Focus on a clear and concise representation of the input description within the given limitations. Always give the complete SVG code with nothing omitted. Never use an ellipsis.

<description>"{}"</description>
```svg
<svg viewBox="0 0 256 256" width="256" height="256">
"""
        self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
        self.constraints = svg_constraints.SVGConstraints()
        self.timeout_seconds = 90

   
 # Define a prediction function that generates an SVG from a text description
# You can optionally set how many tokens (words/parts) the model can generate
    def predict(self, description: str, max_new_tokens=512) -> str:
    
        # Inner function to actually generate the SVG
        def generate_svg():
            try:
                # Format the input prompt by inserting the description into a predefined template
                prompt = self.prompt_template.format(description)
                
                # Tokenize the prompt (convert text into model-readable format) and move to device (CPU/GPU)
                inputs = self.tokenizer(text=prompt, return_tensors="pt").to(DEVICE)
    
                # Turn off gradient tracking to speed up generation (no training happening)
                with torch.no_grad():
                    # Generate the output from the model
                    output = self.model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,  # Limit the number of new tokens generated
                        do_sample=True,  # Randomize output for more creativity
                    )
    
                # Decode the generated output tokens back into readable text
                output_decoded = self.tokenizer.decode(output[0], skip_special_tokens=True)
                logging.debug('Output decoded from model: %s', output_decoded)
    
                # Use regex to find anything between <svg>...</svg> tags
                matches = re.findall(r"<svg.*?</svg>", output_decoded, re.DOTALL | re.IGNORECASE)
                if matches:
                    svg = matches[-1]  # Pick the last SVG found
                else:
                    return self.default_svg  # If no SVG found, return the default SVG
    
                logging.debug('Unprocessed SVG: %s', svg)
    
                # Apply SVG constraints (make sure the SVG is valid and safe)
                svg = self.enforce_constraints(svg)
                logging.debug('Processed SVG: %s', svg)
    
                # Test if the generated SVG can be successfully converted into PNG
                cairosvg.svg2png(bytestring=svg.encode('utf-8'))
    
                return svg  # Return the valid generated SVG
    
            except Exception as e:
                # If any error occurs, log it and return the default SVG
                logging.error('Exception during SVG generation: %s', e)
                return self.default_svg

    def enforce_constraints(self, svg_string: str) -> str:
            """Enforces constraints on an SVG string, removing disallowed elements
            and attributes.
    
            Parameters
            ----------
            svg_string : str
                The SVG string to process.
    
            Returns
            -------
            str
                The processed SVG string, or the default SVG if constraints
                cannot be satisfied.
            """
            logging.info('Sanitizing SVG...') 
            try:
            # Create an XMLParser that removes blank text and comments from the SVG string
               parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
            
            # Parse the SVG string into an XML element tree (structure)
               root = etree.fromstring(svg_string, parser=parser)
        
            except etree.ParseError as e:
                # If there is a parsing error (e.g., invalid SVG), log the error and return the default SVG
               logging.error('SVG Parse Error: %s. Returning default SVG.', e)
                
                # Return the default SVG if the parsing fails
               return self.default_svg
            elements_to_remove = []
            for element in root.iter():
                tag_name = etree.QName(element.tag).localname
        
                # Remove disallowed elements
                if tag_name not in self.constraints.allowed_elements:
                    elements_to_remove.append(element)
                    continue  # Skip attribute checks for removed elements  
                # Remove disallowed attributes
                attrs_to_remove = []
                for attr in element.attrib:
                    attr_name = etree.QName(attr).localname
                    if (
                        attr_name
                        not in self.constraints.allowed_elements[tag_name]
                        and attr_name
                        not in self.constraints.allowed_elements['common']
                    ):
                        attrs_to_remove.append(attr)
        
                for attr in attrs_to_remove:
                    logging.debug(
                        'Attribute "%s" for element "%s" not allowed. Removing.',
                        attr,
                        tag_name,
                    )
                    del element.attrib[attr]
                # Check and remove invalid href attributes
                for attr, value in element.attrib.items():
                     if etree.QName(attr).localname == 'href' and not value.startswith('#'):
                        logging.debug(
                            'Removing invalid href attribute in element "%s".', tag_name
                        )
                        del element.attrib[attr]
                 # Validate path elements to help ensure SVG conversion
                if tag_name == 'path':
                    d_attribute = element.get('d')
                    if not d_attribute:
                        logging.warning('Path element is missing "d" attribute. Removing path.')
                        elements_to_remove.append(element)
                        continue # Skip further checks for this removed element


                    # Use regex to validate 'd' attribute format
                    path_regex = re2.compile(
                        r'^'  # Start of string
                        r'(?:'  # Non-capturing group for each command + numbers block
                        r'[MmZzLlHhVvCcSsQqTtAa]'  # Valid SVG path commands (adjusted to exclude extra letters)
                        r'\s*'  # Optional whitespace after command
                        r'(?:'  # Non-capturing group for optional numbers
                        r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?'  # First number
                        r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*'  # Subsequent numbers with mandatory separator(s)
                        r')?'  # Numbers are optional (e.g. for Z command)
                        r'\s*'  # Optional whitespace after numbers/command block
                        r')+'  # One or more command blocks
                        r'\s*'  # Optional trailing whitespace
                        r'$'  # End of string
                    )
                if not path_regex.match(d_attribute):
                    logging.warning(
                        'Path element has malformed "d" attribute format. Removing path.'
                    )
                    elements_to_remove.append(element)
                    continue
                logging.debug('Path element "d" attribute validated (regex check).')
                
        
            # Remove elements marked for removal
            for element in elements_to_remove:
                if element.getparent() is not None:
                    element.getparent().remove(element)
                    logging.debug('Removed element: %s', element.tag)
    
            try:
                cleaned_svg_string = etree.tostring(root, encoding='unicode')
                return cleaned_svg_string
            except ValueError as e:
                logging.error(
                    'SVG could not be sanitized to meet constraints: %s', e
                )
                return self.default_svg
            
          

The Model class is part of a machine learning pipeline designed to generate and process SVG (Scalable Vector Graphics) files based on textual descriptions. In the initialization method, the class sets up a quantization configuration to optimize the model's memory usage and computational efficiency. The model and tokenizer are loaded from a pre-trained model path, allowing it to process input text into a format that the model can understand. The `predict` method generates an SVG based on the input description by first formatting the description into a prompt, tokenizing it, and passing it through the model to generate a response. The model’s output is decoded, and an SVG is extracted using regular expressions. The generated SVG is then passed through a constraint enforcement function, which ensures that the SVG adheres to predefined constraints such as allowed tags and attributes. If the SVG passes validation, it can be converted to an image; if not, a default SVG is returned. The `enforce_constraints` method sanitizes the generated SVG by removing disallowed elements, attributes, and ensuring that path data is correctly formatted. This class effectively combines machine learning and data sanitization to generate valid, safe SVG images based on textual input, with error handling and logging to ensure the process runs smoothly.

In [7]:
# Import the kaggle_evaluation module to run evaluation tests on the model
import kaggle_evaluation

# Set up logging to show INFO level messages (force=True ensures that this config will override any previous settings)
logging.basicConfig(level=logging.INFO, force=True)

# Run the evaluation test on the Model class defined earlier
kaggle_evaluation.test(Model)


Creating Model instance...


2025-04-26 12:50:23.202364: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745671823.425378      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745671823.490373      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Running inference tests...
Wrote test submission file to "/tmp/kaggle-evaluation-submission-068o07w6.csv".
Success!


In [9]:
def generate():
    import polars as pl  # Import Polars, a fast DataFrame library (alternative to pandas) for efficient data manipulation
    from IPython.display import SVG # Import SVG from IPython.display to render SVG images directly inside Jupyter notebooks or similar environments
    import time  # Import the time module
    
    logging.basicConfig(level=logging.DEBUG, force=True)
    
    train = pl.read_csv('/kaggle/input/drawing-with-llms/train.csv')
    display(train.head())
    
    model = Model()
    svgs = []
    for desc in train.get_column('description'):
        start_time = time.time()  # Record start time
        svg = model.predict(desc)
        end_time = time.time()    # Record end time
        elapsed_time = end_time - start_time # Calculate elapsed time
        print(f"Prediction time for description '{desc[:20]}...': {elapsed_time:.4f} seconds") # Print time
    
        try:
            display(SVG(svg))
           
        except Exception as e:
            print(e)
            continue

generate()

id,description
str,str
"""02d892""","""a purple forest at dusk"""
"""0dcd2e""","""gray wool coat with a faux fur…"
"""1e9ac1""","""a lighthouse overlooking the o…"
"""2b25db""","""burgundy corduroy pants with p…"
"""4e6a54""","""orange corduroy overalls"""


DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): dp.kaggle.net:443
DEBUG:urllib3.connectionpool:https://dp.kaggle.net:443 "POST /kaggle-jwt-handler/AttachDatasourceUsingJwtRequest HTTP/1.1" 200 None


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Prediction time for description 'a purple forest at d...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'gray wool coat with ...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'a lighthouse overloo...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'burgundy corduroy pa...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'orange corduroy over...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'a purple silk scarf ...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'a green lagoon under...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'crimson rectangles f...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'purple pyramids spir...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'magenta trapezoids l...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'a snowy plain...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'black and white chec...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'a starlit night over...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'khaki triangles and ...': 0.0000 seconds


<IPython.core.display.SVG object>

None
Prediction time for description 'a maroon dodecahedro...': 0.0000 seconds


<IPython.core.display.SVG object>

None


Summary

First, I imported several important libraries needed for model loading, SVG handling, and optimization, including torch, transformers, kagglehub, cairosvg, lxml.etree, and re2. I set up the device (CPU or GPU) for running the model using torch.device. I then initialized a Model class. Inside the Model, I configured 4-bit quantization using BitsAndBytesConfig to make the model run lighter and faster. I downloaded a pre-trained model from KaggleHub (specifically, Google's Gemma-2 model) and loaded it with AutoTokenizer and AutoModelForCausalLM.

I also prepared a default SVG string (a simple red circle) to return in case anything went wrong. I defined SVG constraints to make sure any generated SVG follows safe, clean rules — only allowing certain tags and attributes, and validating structure carefully. I set a timeout value of 90 seconds for SVG generation to prevent long-running processes.

In the predict method, I designed the logic to take a text description, create a prompt, generate output from the model, extract SVG code using regex, apply constraints, and finally validate it by trying to convert it with cairosvg. If any step failed, I returned the default SVG. I used logging to track and debug important steps.

To enforce constraints, I parsed the generated SVG string using lxml.etree.XMLParser, removed disallowed elements (like <script>) and disallowed attributes (like onclick), validated important attributes such as the "d" attribute of <path> elements using regex, and finally re-assembled the cleaned SVG. If any problem occurred while cleaning, I safely fell back to the default SVG.

Finally, I wrote test code to evaluate the model using Kaggle's kaggle_evaluation module. Before doing that, I also imported polars for possible data handling and IPython.display.SVG to visually display the SVGs inside notebooks.

In short, I built a system that takes a text description, uses a large language model to generate an SVG image, strictly cleans and validates the SVG based on set constraints, and outputs a safe, valid SVG image.
