### Link: https://www.kaggle.com/competitions/drawing-with-llms
### Ref Notebook: https://www.kaggle.com/code/ryanholbrook/drawing-with-llms-getting-started-with-gemma-2

In [15]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here are several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
import kagglehub
import pandas as pd

# Fetch the competition's train.csv file and show its first rows.
train_path = kagglehub.competition_download(
    'drawing-with-llms',
    'train.csv',
)
train = pd.read_csv(train_path)
train.head()

Unnamed: 0,id,description
0,04c411,a starlit night over snow-covered peaks
1,215136,black and white checkered pants
2,3e2bc6,crimson rectangles forming a chaotic grid
3,61d7a8,burgundy corduroy pants with patch pockets and...
4,6f2ca7,orange corduroy overalls


In [2]:
## Ref: https://www.kaggle.com/code/ryanholbrook/drawing-with-llms-getting-started-with-gemma-2

import concurrent
import concurrent.futures
import io
import logging
import re
#import re2

import cairosvg
import kagglehub
import torch
from lxml import etree
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from unsloth import FastLanguageModel
import torch

# Kaggle package that provides SVGConstraints (instantiated in Model.__init__).
svg_constraints = kagglehub.package_import('metric/svg-constraints')

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class Model:
    """Draw an SVG for a text description with Llama 3.2 3B Instruct (via Unsloth).

    ``predict`` prompts the model, pulls the last SVG fragment out of the
    completion, sanitizes it with ``enforce_constraints`` against the
    competition's ``SVGConstraints`` and checks that cairosvg can render it.
    Any failure or timeout returns ``default_svg`` instead of raising.
    """

    # Validator for a <path> "d" attribute: one or more SVG path command
    # letters, each optionally followed by numbers separated by whitespace
    # and/or commas. Uses the stdlib ``re`` module (``re2`` is not imported in
    # this notebook). Every repetition must start with a command letter, which
    # should keep backtracking modest on malformed input.
    # NOTE: stricter than the SVG path grammar - numbers like ".5" and implicit
    # separators like "10-5" are rejected, so such paths are dropped.
    _PATH_D_REGEX = re.compile(
        r'^'  # Start of string
        r'(?:'  # Non-capturing group for each command + numbers block
        r'[MmZzLlHhVvCcSsQqTtAa]'  # Valid SVG path commands
        r'\s*'  # Optional whitespace after command
        r'(?:'  # Non-capturing group for optional numbers
        r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?'  # First number
        r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*'  # Subsequent numbers with mandatory separator(s)
        r')?'  # Numbers are optional (e.g. for Z command)
        r'\s*'  # Optional whitespace after numbers/command block
        r')+'  # One or more command blocks
        r'\s*'  # Optional trailing whitespace
        r'$'  # End of string
    )

    # An SVG fragment that begins a line (optionally indented) in the model
    # output. Requiring a line start skips inline prose such as "the <svg> tag";
    # no trailing newline is required, so an SVG that ends the completion is kept.
    _SVG_BLOCK_REGEX = re.compile(
        r'^[ \t]*(<svg\b.*?</svg>)', re.DOTALL | re.IGNORECASE | re.MULTILINE
    )

    def __init__(self):
        # Placeholders; predict() does not currently write to them.
        self.output_decoded = ""
        self.svg = ""

        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name="unsloth/Llama-3.2-3B-Instruct",
            max_seq_length=2048,
            dtype=None,
            load_in_4bit=False,
        )
        # Put the Unsloth model into inference mode before generate() is used.
        FastLanguageModel.for_inference(self.model)

        # Kept on a single logical line: a backslash continuation inside a
        # triple-quoted string would embed the next line's indentation
        # (dozens of spaces) in the middle of the prompt.
        self.prompt_template = (
            "You are a helpful assistant specialized in generating minimal, valid SVG code. "
            "Generate SVG code to visually represent the following text description: {}"
        )

        # Fallback returned whenever generation, extraction or sanitization fails.
        self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
        self.constraints = svg_constraints.SVGConstraints()
        # Upper bound, in seconds, that predict() waits for one generation.
        self.timeout_seconds = 90

    @classmethod
    def _extract_svg(cls, text: str):
        """Return the last line-initial ``<svg ...>...</svg>`` fragment in *text*.

        Returns ``None`` when *text* contains no such fragment.
        """
        matches = cls._SVG_BLOCK_REGEX.findall(text)
        return matches[-1] if matches else None

    ### You could try increasing `max_new_tokens`
    def predict(self, description: str, max_new_tokens=2048) -> str:
        """Return an SVG string that depicts *description*.

        Generation runs in a worker thread and is bounded by
        ``self.timeout_seconds``. Any exception, a completion without an SVG
        fragment, or a timeout yields ``self.default_svg``.

        Parameters
        ----------
        description : str
            Free-text description of the image to draw.
        max_new_tokens : int, optional
            Maximum number of tokens to generate (default 2048).

        Returns
        -------
        str
            The sanitized SVG, or ``self.default_svg``.
        """
        def generate_svg():
            try:
                prompt = self.prompt_template.format(description)
                inputs = self.tokenizer(text=prompt, return_tensors="pt").to(DEVICE)

                with torch.no_grad():
                    output = self.model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        # Greedy decoding; temperature (and knobs like top_k,
                        # top_p, repetition_penalty) only apply when do_sample=True.
                        do_sample=False,
                        temperature=0.2,
                    )

                # generate() returns prompt + completion; decode only the new
                # tokens so prompt text can never be mistaken for model output.
                prompt_length = inputs["input_ids"].shape[-1]
                output_decoded = self.tokenizer.decode(
                    output[0][prompt_length:], skip_special_tokens=True
                )
                logging.debug('Output decoded from model: %s', output_decoded)

                svg = self._extract_svg(output_decoded)
                if svg is None:
                    return self.default_svg

                logging.debug('Unprocessed SVG: %s', svg)
                svg = self.enforce_constraints(svg)
                logging.debug('Processed SVG: %s', svg)
                # Ensure the generated code can be converted by cairosvg
                cairosvg.svg2png(bytestring=svg.encode('utf-8'))
                return svg

            except Exception as e:
                # Best-effort: any failure degrades to the default SVG, but keep
                # the traceback in the log.
                logging.exception('Exception during SVG generation: %s', e)
                return self.default_svg

        # Not used as a context manager: ``with`` would call shutdown(wait=True)
        # and block until generation finishes, defeating the timeout. After a
        # timeout the worker thread keeps running in the background until the
        # generation completes.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(generate_svg)
            return future.result(timeout=self.timeout_seconds)
        except concurrent.futures.TimeoutError:
            logging.warning("Prediction timed out after %s seconds.", self.timeout_seconds)
            return self.default_svg
        except Exception as e:
            logging.error("An unexpected error occurred: %s", e)
            return self.default_svg
        finally:
            executor.shutdown(wait=False)

    def enforce_constraints(self, svg_string: str) -> str:
        """Enforces constraints on an SVG string, removing disallowed elements
        and attributes.

        Elements whose local tag name is not a key of
        ``self.constraints.allowed_elements`` are removed. Non-element nodes
        such as processing instructions are removed too. Attributes allowed
        neither for the tag nor under the ``'common'`` key are deleted.
        ``href`` attributes that do not start with ``#`` are deleted. ``path``
        elements with a missing or malformed ``d`` attribute are removed.

        Parameters
        ----------
        svg_string : str
            The SVG string to process.

        Returns
        -------
        str
            The processed SVG string, or the default SVG if the input cannot
            be parsed or serialized.
        """
        logging.info('Sanitizing SVG...')

        try:
            parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
            root = etree.fromstring(svg_string, parser=parser)
        except (etree.ParseError, ValueError) as e:
            # ValueError: e.g. a str input that carries an encoding declaration.
            logging.error('SVG Parse Error: %s. Returning default SVG.', e)
            return self.default_svg

        allowed = self.constraints.allowed_elements
        elements_to_remove = []
        for element in root.iter():
            # Processing instructions (and other non-element nodes) have a
            # callable, not a string, as their tag; QName() cannot parse them.
            if not isinstance(element.tag, str):
                elements_to_remove.append(element)
                continue

            tag_name = etree.QName(element.tag).localname

            # Remove disallowed elements
            if tag_name not in allowed:
                elements_to_remove.append(element)
                continue  # Skip attribute checks for removed elements

            # Remove disallowed attributes
            attrs_to_remove = []
            for attr in element.attrib:
                attr_name = etree.QName(attr).localname
                if attr_name not in allowed[tag_name] and attr_name not in allowed['common']:
                    attrs_to_remove.append(attr)

            for attr in attrs_to_remove:
                logging.debug(
                    'Attribute "%s" for element "%s" not allowed. Removing.',
                    attr,
                    tag_name,
                )
                del element.attrib[attr]

            # Check and remove invalid href attributes. Iterate over a snapshot
            # so deleting from element.attrib while scanning is safe.
            for attr, value in list(element.attrib.items()):
                if etree.QName(attr).localname == 'href' and not value.startswith('#'):
                    logging.debug(
                        'Removing invalid href attribute in element "%s".', tag_name
                    )
                    del element.attrib[attr]

            # Validate path elements to help ensure SVG conversion
            if tag_name == 'path':
                d_attribute = element.get('d')
                if not d_attribute:
                    logging.warning('Path element is missing "d" attribute. Removing path.')
                    elements_to_remove.append(element)
                    continue  # Skip further checks for this removed element
                if not self._PATH_D_REGEX.match(d_attribute):
                    logging.warning(
                        'Path element has malformed "d" attribute format. Removing path.'
                    )
                    elements_to_remove.append(element)
                    continue
                logging.debug('Path element "d" attribute validated (regex check).')

        # Remove elements marked for removal (the root has no parent and stays).
        for element in elements_to_remove:
            parent = element.getparent()
            if parent is not None:
                parent.remove(element)
                logging.debug('Removed element: %s', element.tag)

        try:
            cleaned_svg_string = etree.tostring(root, encoding='unicode')
            return cleaned_svg_string
        except ValueError as e:
            logging.error(
                'SVG could not be sanitized to meet constraints: %s', e
            )
            return self.default_svg

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!



This code could modify your python environment or operating system.

Review this code at https://www.kaggle.com/code/metric/svg-constraints/versions/1
or in your download cache at /home/vino/.cache/kagglehub/notebooks/metric/svg-constraints/output/versions/1

It is strongly recommended that you run this code within a container
such as Docker to provide a secure, isolated execution environment.
See https://www.kaggle.com/docs/packages for more information.

Do you want to proceed? (y)es/[no]:  y


In [3]:
# import kagglehub
# package = kagglehub.package_import('dster/drawing-with-llms-starter-notebook/versions/3')
# model = package.Model()
# svg = model.predict('a goose winning a gold medal')

In [4]:
# import sys
# sys.path.append('/home/vino/ML_Projects/kaggle/03_LLM_Drawing_SVG/drawing-with-llms')
# import kaggle_evaluation
# kaggle_evaluation.test(Model)

In [5]:
# Instantiate once: Model.__init__ loads the Llama 3.2 weights.
model = Model()

==((====))==  Unsloth 2025.2.15: Fast Llama patching. Transformers: 4.49.0.
   \\   /|    GPU: NVIDIA GeForce RTX 4070 Ti SUPER. Max memory: 15.693 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [6]:
svg = model.predict(description='a goose winning a gold medal')

In [9]:
#kaggle_evaluation.test(Model)

In [10]:
# from transformers import AutoProcessor, AutoModel
# model_sl = AutoModel.from_pretrained("google/siglip-so400m-patch14-384")
# processor_sl = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")
# import torch
# from PIL import Image
# import cairosvg
# import os

# def svgMetric(prompt, svg):
    
#     try:
#         # Convert SVG to PNG
#         cairosvg.svg2png(svg, write_to="./tmp/temp.png")
        
#         # Open and process the image
#         image = Image.open('./tmp/temp.png').convert("RGB")
#         texts = ["SVG illustration of " + prompt]
#         inputs = processor_sl(text=texts, images=image, padding="max_length", return_tensors="pt")
        
#         # Inference without gradient tracking
#         with torch.no_grad():
#             outputs = model_sl(**inputs)
        
#         logits_per_image = outputs.logits_per_image
#         probs = torch.sigmoid(logits_per_image)
        
#         # Clean up temporary PNG file
#         #os.remove('./tmp/temp.png')
        
#         return probs[0][0].item()
    
#     except Exception as e:
#         print(f"An error occurred: {e}")
#         return None