In [9]:
"""Utility functions."""
import io
import json
import math
import os
import pickle
import re
from pathlib import Path

import cairosvg
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image


def show_images(images, titles=None, cols=4, figsize=(16, 8)):
  """Display one or more images in a matplotlib grid.

  Args:
    images: a single image, a list/tuple of images, a torch.Tensor (converted
      with tensor_to_image; a 4-D tensor yields one image per batch item), or
      SVG markup as str/bytes (rasterized with svg_to_image).
    titles: optional sequence of titles; images beyond its length get none.
    cols: number of grid columns.
    figsize: matplotlib figure size in inches.
  """
  if isinstance(images, torch.Tensor):
    images = tensor_to_image(images)
  # A tuple is a collection of images, not a single image.
  if isinstance(images, tuple):
    images = list(images)
  if not isinstance(images, list):
    images = [images]
  if not images:
    # Nothing to draw; avoids images[0] IndexError and a zero-row subplot grid.
    return
  # svg_to_image accepts both str and bytes markup.
  if isinstance(images[0], (str, bytes)):
    images = [svg_to_image(image) for image in images]

  plt.figure(figsize=figsize)
  rows = math.ceil(len(images) / cols)
  for i, image in enumerate(images):
    plt.subplot(rows, cols, i + 1)
    plt.imshow(image)
    if titles and i < len(titles):
      plt.title(titles[i])
    plt.axis("off")
  plt.tight_layout()
  plt.show()


def svg_to_image(svg):
  """Rasterize SVG markup into an RGB PIL image.

  str input is UTF-8 encoded first; bytes are rendered as-is. Any other
  argument is treated as an iterable of SVGs and converted element by element,
  producing a list.
  """
  if not isinstance(svg, (str, bytes)):
    return [svg_to_image(item) for item in svg]
  markup = svg.encode('utf-8') if isinstance(svg, str) else svg
  png_bytes = cairosvg.svg2png(bytestring=markup)
  return Image.open(io.BytesIO(png_bytes)).convert('RGB')


def svg_to_tensor(svg, dtype=torch.bfloat16, device="cuda"):
  """Rasterize SVG markup into an image tensor via svg_to_image + ToTensor.

  Args:
    svg: SVG markup as str/bytes, or an iterable of such items (stacked along
      a new leading batch dimension).
    dtype: dtype of the returned tensor.
    device: device of the returned tensor (defaults to "cuda").

  Returns:
    A CHW tensor for a single SVG, or an NCHW stack for an iterable.
  """
  if isinstance(svg, (str, bytes)):
    image = svg_to_image(svg)
    return transforms.ToTensor()(image).to(device, dtype=dtype)
  # Forward dtype/device so every batch element honors the caller's request
  # (previously the recursion silently fell back to bfloat16).
  return torch.stack([svg_to_tensor(item, dtype=dtype, device=device) for item in svg])


def tensor_to_image(tensor):
  """Convert an image tensor into PIL image(s) with to_pil_image.

  A 4-D tensor is treated as a batch and returns a list of images; any other
  rank is converted as a single image. Values are cast to float32 first
  (presumably because bfloat16 tensors cannot be converted directly).
  """
  def _convert(single):
    return to_pil_image(single.to(torch.float32))

  if tensor.ndim != 4:
    return _convert(tensor)
  return [_convert(item) for item in tensor]


def image_to_tensor(image, dtype=torch.bfloat16, device="cuda"):
  """Convert a PIL image (or a list/tuple of them) to a tensor via ToTensor.

  Args:
    image: a single image, or a list/tuple of images stacked into a batch.
    dtype: dtype of the returned tensor.
    device: device of the returned tensor (defaults to "cuda").

  Returns:
    A CHW tensor for one image, or an NCHW tensor for a list/tuple.
  """
  to_tensor = transforms.ToTensor()
  if isinstance(image, (list, tuple)):
    tensor = torch.stack([to_tensor(img) for img in image])
  else:
    tensor = to_tensor(image)
  return tensor.to(device=device, dtype=dtype)


def mkdir(path):
  """Create directory `path` and any missing parents if it does not exist.

  An empty path (e.g. the parent of a bare filename, i.e. the current
  directory) is ignored; passing it to os.makedirs would raise
  FileNotFoundError.
  """
  if path and not os.path.exists(path):
    os.makedirs(path, exist_ok=True)


def load_pickle(path):
  """Return the object deserialized from the pickle file at `path`.

  WARNING: unpickling can execute arbitrary code; only load trusted files.
  """
  with open(path, "rb") as handle:
    obj = pickle.load(handle)
  return obj


def write_pickle(o, path):
  """Pickle `o` to `path` with the highest protocol, creating parent dirs.

  The parent directory comes from os.path.dirname, so OS-specific separators
  work and a root-level path such as "/x.pkl" (parent "/") no longer tries
  to create "".
  """
  parent = os.path.dirname(path)
  if parent and not os.path.exists(parent):
    os.makedirs(parent, exist_ok=True)
  with open(path, "wb") as f:
    pickle.dump(o, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_json(path):
  """Parse and return the JSON document stored at `path`.

  The file is decoded as UTF-8 rather than with the platform's
  locale-dependent default encoding.
  """
  with open(path, "r", encoding="utf-8") as f:
    return json.load(f)


def write_json(o, path):
  """Write `o` as indented (2-space) JSON to `path`, creating parent dirs.

  The parent directory comes from os.path.dirname (portable, and safe for
  root-level paths), and the file is written as UTF-8.
  """
  parent = os.path.dirname(path)
  if parent and not os.path.exists(parent):
    os.makedirs(parent, exist_ok=True)
  with open(path, "w", encoding="utf-8") as f:
    json.dump(o, f, indent=2)


def print_max_memory():
  """Print the peak CUDA memory allocated by tensors, in MB (1e6 bytes)."""
  peak_mb = torch.cuda.max_memory_allocated() / 1e6
  print(f"Peak allocated memory: {peak_mb:.2f} MB")


def print_allocated_memory():
  """Print the CUDA memory currently allocated by tensors, in MB (1e6 bytes)."""
  allocated_mb = torch.cuda.memory_allocated() / 1e6
  print(f"Allocated memory: {allocated_mb:.2f} MB")


def print_cached_memory():
  """Print the CUDA memory reserved by the caching allocator, in MB (1e6 bytes)."""
  reserved_mb = torch.cuda.memory_reserved() / 1e6
  print(f"Cached memory: {reserved_mb:.2f} MB")

In [None]:
"""Image -> svg conversion."""
from colorsys import rgb_to_hls
import math
import re
import random
import tempfile
import xml.etree.ElementTree as etree
from tqdm import tqdm
import cv2
from kmeans_gpu import KMeans
import numpy as np
import pydiffvg
import torch
from transformers import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup


"""Image -> svg conversion."""


def _make_schedule(optim, cosine_schedule, warmup_steps, n_iter):
    """Build the LR schedule (cosine or constant, both with warmup) for one optimizer."""
    if cosine_schedule:
        return get_cosine_schedule_with_warmup(
            optimizer=optim,
            num_warmup_steps=warmup_steps,
            num_training_steps=n_iter,
        )
    return get_constant_schedule_with_warmup(optim, warmup_steps)


def _render_on_white(render, canvas_width, canvas_height, shapes, shape_groups):
    """Rasterize the scene with pydiffvg and alpha-composite it over white.

    The rendered image is indexed channels-last (RGB in [..., :3], alpha in
    [..., 3]; presumably (H, W, 4)) and is returned permuted to (1, 3, H, W).
    """
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    rendered = render(canvas_width,   # width
                      canvas_height,  # height
                      2,              # num_samples_x
                      2,              # num_samples_y
                      0,              # seed
                      None,           # background image
                      *scene_args)
    alpha = rendered[:, :, 3:4]
    rgb = alpha * rendered[:, :, :3] + (1 - alpha)
    return rgb.unsqueeze(0).permute(0, 3, 1, 2)


def _project_colors(color_vars, initial_colors, color_bounds, first_fill, color_decay):
    """After an optimizer step, pull fill colors toward their initial values and clamp them in place."""
    for k, (color, initial) in enumerate(zip(color_vars, initial_colors)):
        if color_decay > 0:
            # In-place color_decay * color + (1 - color_decay) * initial, so the
            # tensor keeps its storage instead of being rebound via `.data =`.
            color.data.lerp_(initial, 1.0 - color_decay)
        color.data.clamp_(0.0, 1.0)
        if color_bounds is not None:
            low, high = color_bounds[k]
            color.data.clamp_(low, high)
    if first_fill is not None:
        # The first shape group is forced fully opaque (presumably the background;
        # the last channel is taken to be alpha). Applied after the deviation
        # clamp so that clamp cannot undo it.
        first_fill.data[-1] = 1.0


def _project_points(points_vars, initial_points, point_decay, canvas_width, canvas_height):
    """After an optimizer step, pull control points toward their start and keep them on the canvas."""
    for points, initial in zip(points_vars, initial_points):
        if point_decay > 0:
            points.data.lerp_(initial, 1.0 - point_decay)
        # pydiffvg points are (N, 2) rows of (x, y): x is bounded by the width
        # and y by the height, so non-square canvases are handled correctly.
        points.data[:, 0].clamp_(0.0, canvas_width)
        points.data[:, 1].clamp_(0.0, canvas_height)


def _scene_to_svg_string(canvas_width, canvas_height, shapes, shape_groups):
    """Serialize the scene with pydiffvg.save_svg and return the SVG text."""
    # save_svg is given a file name; a temporary directory guarantees the file
    # is removed afterwards (NamedTemporaryFile(delete=False) leaked one per call).
    with tempfile.TemporaryDirectory() as tmp_dir:
        svg_path = os.path.join(tmp_dir, "optimized.svg")
        pydiffvg.save_svg(svg_path, canvas_width, canvas_height, shapes, shape_groups)
        with open(svg_path, "r") as svg_file:
            return svg_file.read()


def optimize_svg(svg, image, n_iter=100, point_lr=2.0, color_lr=0.05,
                 warmup_steps=0, cosine_schedule=False, loss_fn=None,
                 optimizer=torch.optim.Adam, return_best=False, max_color_deviation=None,
                 color_decay=0.0, point_decay=0.0):
    """Refine an SVG's control points and fill colors so its rendering matches `image`.

    The SVG is parsed into a pydiffvg scene, rendered differentiably, composited
    over white and compared to the target. Control points and fill colors are
    then updated by gradient descent with separate optimizers and learning rates.

    Args:
        svg: SVG markup (str), parsed with xml.etree and pydiffvg.parse_scene.
        image: target image, converted with image_to_tensor as float32
            (presumably a PIL image at the SVG canvas size -- TODO confirm).
        n_iter: number of optimization steps.
        point_lr: learning rate for control points.
        color_lr: learning rate for fill colors.
        warmup_steps: LR warmup steps for both schedules.
        cosine_schedule: use cosine decay after warmup instead of a constant LR.
        loss_fn: optional callable(rendered, target) -> scalar tensor, both
            batched NCHW. Defaults to the mean absolute (L1) error.
        optimizer: optimizer class, called as optimizer(params, lr=...).
        return_best: if True, return the parameters from the step with the
            lowest loss instead of the final ones.
        max_color_deviation: if set, every color channel is kept within this
            distance of its initial value (and inside [0, 1]).
        color_decay / point_decay: if > 0, after every step the value becomes
            decay * current + (1 - decay) * initial.

    Returns:
        The optimized SVG document as a string (written by pydiffvg.save_svg).
    """
    target = image_to_tensor([image], dtype=torch.float32)
    render = pydiffvg.RenderFunction.apply

    root = etree.fromstring(svg)
    canvas_width, canvas_height, shapes, shape_groups = pydiffvg.parse_scene(root)

    # Optimize the control points of every shape that has them (paths,
    # polygons). Shapes without a `points` tensor (rects, circles, ellipses)
    # are left alone instead of raising AttributeError.
    points_vars = []
    for shape in shapes:
        points = getattr(shape, "points", None)
        if isinstance(points, torch.Tensor):
            points.requires_grad = True
            points_vars.append(points)

    # Optimize solid fill colors, de-duplicated by storage so shared tensors are
    # registered and projected only once. Non-tensor fills (e.g. None) are skipped.
    unique_colors = {}
    for group in shape_groups:
        fill = group.fill_color
        if isinstance(fill, torch.Tensor):
            fill.requires_grad = True
            unique_colors.setdefault(fill.data_ptr(), fill)
    color_vars = list(unique_colors.values())

    first_fill = None
    if shape_groups and isinstance(shape_groups[0].fill_color, torch.Tensor):
        first_fill = shape_groups[0].fill_color

    # Initial values are kept in lists parallel to the parameter lists; keying
    # them by data_ptr() broke once decay rebound `.data` (KeyError next step).
    initial_points = [p.detach().clone() for p in points_vars]
    initial_colors = [c.detach().clone() for c in color_vars]
    color_bounds = None
    if max_color_deviation is not None:
        # Loop-invariant: compute the per-channel bounds once.
        color_bounds = [
            (torch.clamp(c - max_color_deviation, 0.0, 1.0),
             torch.clamp(c + max_color_deviation, 0.0, 1.0))
            for c in initial_colors
        ]

    # torch optimizers reject empty parameter lists, so only build those needed.
    optims = [optimizer(params, lr=lr)
              for params, lr in ((points_vars, point_lr), (color_vars, color_lr))
              if params]
    if not optims:
        # Nothing is optimizable; return the scene as parsed.
        return _scene_to_svg_string(canvas_width, canvas_height, shapes, shape_groups)
    scheds = [_make_schedule(o, cosine_schedule, warmup_steps, n_iter) for o in optims]

    best_loss = float("inf")
    best_state = None
    for _ in tqdm(range(n_iter)):
        for optim in optims:
            optim.zero_grad()
        rendered = _render_on_white(render, canvas_width, canvas_height, shapes, shape_groups)
        if loss_fn is None:
            loss = torch.abs(rendered - target).mean()
        else:
            loss = loss_fn(rendered, target)
        loss.backward()

        if return_best:
            loss_value = loss.item()
            if loss_value < best_loss:
                # `loss` was evaluated at the current (pre-step) parameters.
                best_loss = loss_value
                best_state = ([p.detach().clone() for p in points_vars],
                              [c.detach().clone() for c in color_vars])

        for optim in optims:
            optim.step()
        for sched in scheds:
            sched.step()

        _project_colors(color_vars, initial_colors, color_bounds, first_fill, color_decay)
        _project_points(points_vars, initial_points, point_decay, canvas_width, canvas_height)

    if best_state is not None:
        for tensor, saved in zip(points_vars + color_vars, best_state[0] + best_state[1]):
            tensor.data.copy_(saved)

    return _scene_to_svg_string(canvas_width, canvas_height, shapes, shape_groups)

: 

In [3]:
from scour import scour
def optimize_svg_with_scour(svg):
    """Minify SVG markup with scour, then strip a few leftover attributes.

    Args:
        svg: SVG document as a string.

    Returns:
        The minified SVG string: empty ids and version="1.0"/"1.1"/"2.0"
        attributes removed, runs of spaces collapsed, newlines after tags dropped.
    """
    options = scour.parse_args([
        '--enable-viewboxing',
        '--enable-id-stripping',
        '--enable-comment-stripping',
        '--shorten-ids',
        '--indent=none',
        '--strip-xml-prolog',
        '--remove-metadata',
        '--remove-descriptive-elements',
        '--disable-embed-rasters',
        '--create-groups',
        '--renderer-workaround',
        '--set-precision=2',
    ])

    svg = scour.scourString(svg, options)

    # Remove whole attributes (with their leading whitespace). Requiring
    # whitespace before the name keeps e.g. data-id="" intact, which a plain
    # str.replace('id=""', '') would mangle into "data-".
    svg = re.sub(r'\s+id=""', '', svg)
    svg = re.sub(r'\s+version="(?:1\.0|1\.1|2\.0)"', '', svg)
    # Collapse every run of spaces; a single replace('  ', ' ') only halves them.
    svg = re.sub(r' {2,}', ' ', svg)
    svg = svg.replace('>\n', '>')

    return svg

In [4]:
# Demo inputs. Paths default to the original author's checkout; set the
# PYSVGENIUS_ROOT environment variable to run this notebook elsewhere.
PROJECT_ROOT = Path(os.environ.get("PYSVGENIUS_ROOT", "/home/anhndt/pysvgenius"))
SVG_PATH = PROJECT_ROOT / "notebooks" / "test_svg.svg"
IMAGE_PATH = PROJECT_ROOT / "data" / "test" / "raw_image.png"
TARGET_SIZE = (384, 384)  # (width, height) passed to PIL's resize

svg = SVG_PATH.read_text(encoding="utf-8")
image = Image.open(IMAGE_PATH).convert("RGB")
resized_image = image.resize(TARGET_SIZE, Image.Resampling.LANCZOS)


In [5]:
# Refine the SVG against the resized target for 200 steps, then minify it.
optimized_svg = optimize_svg_with_scour(
    optimize_svg(svg, resized_image, n_iter=200)
)

100%|██████████| 200/200 [00:05<00:00, 36.76it/s]
