In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to the project sources are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Work from the parent of the notebook's directory: the relative paths used
# below ("./pretrained/...", "docs/frames/...") are resolved against it.
# The starting directory is remembered the first time this cell runs, so that
# re-running the cell lands in the same place instead of climbing one level
# higher on every execution (which a bare os.chdir("..") would do).
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [3]:
from deepsvg.svglib.svg import SVG

from deepsvg import utils
from deepsvg.difflib.tensor import SVGTensor
from deepsvg.svglib.utils import to_gif
from deepsvg.svglib.geom import Bbox
from deepsvg.svgtensor_dataset import SVGTensorDataset, load_dataset
from deepsvg.utils.utils import batchify, linear

import torch
import numpy as np

# DeepSVG latent space operations

In [4]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Load the pretrained model and dataset

In [32]:
from configs.deepsvg.hierarchical_ordered import Config

# Weights file handed to utils.load_model below.
pretrained_path = "./pretrained/hierarchical_ordered.pth.tar"

# Instantiate the configuration and list every attribute it carries.
cfg = Config()
for key, value in vars(cfg).items():
    print(f"{key}: {value}")

# Build the model on the selected device, load the weights into it, switch it
# to eval mode, then print its architecture.
model = cfg.make_model().to(device)
utils.load_model(pretrained_path, model)
model.eval()
print(model)

num_gpus: 4
dataloader_module: deepsvg.svgtensor_dataset
collate_fn: None
data_dir: ./dataset/icons_tensor/
meta_filepath: ./dataset/icons_meta.csv
loader_num_workers: 16
pretrained_path: None
model_cfg: <configs.deepsvg.hierarchical_ordered.ModelConfig object at 0x7f057e7ae090>
num_epochs: 50
num_steps: None
learning_rate: 0.002
batch_size: 1000
warmup_steps: 500
train_ratio: 1.0
nb_augmentations: 1
max_num_groups: 8
max_seq_len: 30
max_total_len: 50
filter_uni: None
filter_category: None
filter_platform: None
filter_labels: None
grad_clip: 1.0
log_every: 20
val_every: 2000
ckpt_every: 1000
stats_to_print: {'train': ['lr', 'time']}
model_args: ['commands', 'args', 'commands', 'args']
optimizer_starts: [0]
SVGTransformer(
  (encoder): Encoder(
    (embedding): SVGEmbedding(
      (command_embed): Embedding(7, 256)
      (arg_embed): Embedding(257, 64)
      (embed_fcn): Linear(in_features=704, out_features=256, bias=True)
      (pos_encoding): PositionalEncodingLUT(
        (dropout): 

In [6]:
# Build the dataset object from the configuration, using the project's
# load_dataset helper (imported from deepsvg.svgtensor_dataset).
dataset = load_dataset(cfg)

In [7]:
def load_svg(filename):
    """Load an SVG file and run it through the dataset's preparation steps.

    Args:
        filename: path of the SVG file, passed to ``SVG.load_svg``.

    Returns:
        The result of ``dataset.preprocess`` applied to the output of
        ``dataset.simplify`` applied to the loaded SVG.
    """
    loaded = SVG.load_svg(filename)
    return dataset.preprocess(dataset.simplify(loaded))

In [8]:
def easein_easeout(t):
    """Ease-in/ease-out curve: ``t^2 / (2 * (t^2 - t) + 1)``.

    Maps 0 -> 0, 0.5 -> 0.5 and 1 -> 1. Uses only arithmetic operators, so it
    is applied element-wise when ``t`` is a tensor (as in ``interpolate``).
    """
    t_squared = t * t
    return t_squared / (2. * (t_squared - t) + 1.)

def interpolate(z1, z2, n=25, filename=None, ease=True, do_display=True):
    """Render a back-and-forth GIF of the linear path between two latents.

    Args:
        z1: latent code at the start of the path.
        z2: latent code at the end of the path.
        n: number of interpolation steps between ``z1`` and ``z2`` (inclusive).
        filename: forwarded to ``to_gif`` as ``file_path``.
        ease: when true, the evenly spaced weights are remapped with
            ``easein_easeout`` before blending.
        do_display: accepted but not used by this function.
            NOTE(review): possibly meant to be forwarded to ``to_gif`` -- confirm.
    """
    weights = torch.linspace(0., 1., n)
    if ease:
        weights = easein_easeout(weights)

    # One decoded PNG frame per blend weight.
    frames = []
    for w in weights:
        blended = (1 - w) * z1 + w * z2
        frames.append(decode(blended, do_display=False, return_png=True))

    # Forward pass followed by the same frames in reverse order.
    to_gif(frames + frames[::-1], file_path=filename, frame_duration=1/12)

In [9]:
def encode(data):
    """Run the model in encode mode on a single sample.

    Args:
        data: mapping indexed by the names listed in ``cfg.model_args``.

    Returns:
        The value returned by ``model(..., encode_mode=True)`` for the
        batchified inputs, computed with gradient tracking disabled.
    """
    inputs = batchify((data[name] for name in cfg.model_args), device)
    with torch.no_grad():
        return model(*inputs, encode_mode=True)

def encode_icon(idx):
    """Encode the dataset icon with id ``idx``, with random augmentation off."""
    return encode(dataset.get(id=idx, random_aug=False))
    
def encode_svg(svg):
    """Encode an SVG object by passing it to ``dataset.get`` and then ``encode``.

    NOTE(review): unlike ``encode_icon``, ``random_aug`` is left at the
    default of ``dataset.get`` here -- confirm that this is intended.
    """
    return encode(dataset.get(svg=svg))

def decode(z, do_display=True, return_svg=False, return_png=False):
    """Decode a latent code into an SVG via the model's greedy sampling.

    Args:
        z: latent code passed to ``model.greedy_sample``.
        do_display: forwarded to ``draw``.
        return_svg: when true, return the SVG object without drawing it.
        return_png: forwarded to ``draw``.

    Returns:
        The SVG object if ``return_svg`` is true, otherwise whatever
        ``draw(do_display=..., return_png=...)`` returns.
    """
    commands_y, args_y = model.greedy_sample(z=z)

    # Only the first element of the sampled commands/args is used.
    first_commands = commands_y[0].cpu()
    first_args = args_y[0].cpu()
    tensor_pred = SVGTensor.from_cmd_args(first_commands, first_args)

    svg = SVG.from_tensor(tensor_pred.data, viewbox=Bbox(256), allow_empty=True)
    svg = svg.normalize().split_paths().set_color("random")

    if not return_svg:
        return svg.draw(do_display=do_display, return_png=return_png)
    return svg

In [10]:
def interpolate_icons(idx1=None, idx2=None, n=25, *args, **kwargs):
    """Encode two dataset icons and animate the interpolation between them.

    Args:
        idx1: id of the icon at the start of the interpolation.
        idx2: id of the icon at the end of the interpolation.
        n: number of interpolation steps.
        *args: extra positional arguments forwarded to ``interpolate`` after
            ``n`` (i.e. ``filename``, ``ease``, ``do_display``).
        **kwargs: extra keyword arguments forwarded to ``interpolate``.
    """
    z1, z2 = encode_icon(idx1), encode_icon(idx2)
    # `n` is passed positionally on purpose: with `n=n` written before `*args`,
    # the first extra positional argument would also bind to `n` and the call
    # would fail with "got multiple values for argument 'n'".
    interpolate(z1, z2, n, *args, **kwargs)

# "Addition" operation

In [11]:
# Estimate a latent direction as the mean of (z2 - z1) over up to 500 randomly
# drawn icons, where z1 encodes an icon and z2 encodes the same icon with the
# last entry of t_sep (and of fillings) dropped.
# NOTE(review): dataset.random_id() is not seeded in this notebook, so z_rmv
# presumably differs from run to run -- confirm, and seed if that matters.
z_list = []

for i in range(500):
    tensors, fillings = dataset._load_tensor(dataset.random_id())
    t_sep = tensors[0]
    # The truncated copies are built before any get_data() call.
    # NOTE(review): keep this order unless get_data() is confirmed not to
    # modify its arguments in place.
    t_sep_rm, fillings_rm = t_sep[:-1], fillings[:-1]

    # Icons with fewer than two entries in t_sep contribute nothing.
    if len(t_sep) >= 2:
        z1 = encode(dataset.get_data(t_sep, fillings))
        z2 = encode(dataset.get_data(t_sep_rm, fillings_rm))
        z_list.append(z2 - z1)
# Concatenate the differences and average over dim 0, keeping that dimension.
z_rmv = torch.cat(z_list).mean(dim=0, keepdims=True)

`z_rmv` now represents the latent direction that removes the last path of an SVG icon.

In [12]:
# Latent code of dataset icon "548" (presumably a balloon, going by the
# variable name), encoded without random augmentation via encode_icon.
z_baloon = encode_icon("548")

Now, what happens if one subtracts `z_rmv` from the representation of an SVG icon? 🤔

In [13]:
# Animate from z_baloon - 2*z_rmv to z_baloon + 2*z_rmv (default 25 eased
# steps); interpolate() plays the frames forward and then in reverse.
interpolate(z_baloon - 2 * z_rmv, z_baloon + 2 * z_rmv)

In [14]:
# Latent code of dataset icon "76279" (presumably speech bubbles, going by the
# variable name), encoded without random augmentation via encode_icon.
z_bubbles = encode_icon("76279")

In [15]:
# Same sweep as for the previous icon, but three multiples of z_rmv on either
# side of the icon's latent code.
interpolate(z_bubbles - 3 * z_rmv, z_bubbles + 3 * z_rmv)

It automagically adds new paths!!

# "Squarify" operation

In [16]:
# Load the circles frame, run it through the dataset's simplify/preprocess
# steps (load_svg), and encode it.
svg1 = load_svg("docs/frames/circles.svg")
z1 = encode_svg(svg1)

In [17]:
# Same treatment for the squares frame.
svg2 = load_svg("docs/frames/squares.svg")
z2 = encode_svg(svg2)

In [18]:
# Latent offset from the circles encoding (z1) to the squares encoding (z2).
z_squarify = z2 - z1

`z_squarify` is the latent direction that transforms round shapes to square shapes.

In [19]:
# Sweep half of z_squarify on either side of the drill icon's latent code.
z_drill = encode_icon("29775")  # Drill
interpolate(z_drill - z_squarify/2, z_drill + z_squarify/2, n=25)

Quite surprisingly, adding this vector to an SVG icon's latent representation makes the icon look more square, while subtracting it makes the icon look more round!