In [None]:
!pip install -q git+https://github.com/soumik12345/implicit_geometric_representations wandb

In [None]:
import os

import tensorflow as tf
import wandb
from tensorflow.keras import activations
from tqdm.autonotebook import tqdm
from wandb.keras import WandbCallback

from implicit_representation.callbacks import ProgressBarCallback, SDFVisualizationCallback
from implicit_representation.dataloaders import PointCloud2DFromFont, PointCloud2DBatman
from implicit_representation.models import SDFModelBase

In [None]:
# Lookup from the string stored in `config.activation` to the callable
# that is handed to SDFModelBase(activation=...).
ACTIVATION_DICT = dict(
    softplus=activations.softplus,
    swish=tf.nn.swish,
)

In [None]:
# Every tunable value for this run lives in one dict. Passing it to
# `wandb.init(config=...)` records it with the run from the moment it is created.
default_config = dict(
    seed=43,
    # Data configs
    sample_percentage=0.1,
    # Absolute system font path; adjust if DejaVuSans lives elsewhere on your machine.
    font_file="/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
    font_size=200,
    query="Q",
    padding=5,
    # Model configs
    units=80,
    num_intermediate_layers=2,
    activation="softplus",  # must be a key of ACTIVATION_DICT
    point_loss_coeff=100.0,
    eikonal_coefficient=2.0,
    num_padding_points=500,
    # Training configs
    learning_rate=1e-3,
    epochs=2000,
)

wandb.init(
    project="implicit-geometric-representation",
    # Set the WANDB_ENTITY environment variable to log under your own user/team;
    # falls back to the original author's entity when it is unset.
    entity=os.environ.get("WANDB_ENTITY", "geekyrakshit"),
    job_type="2d-font-point-cloud",
    config=default_config,
)
# Read values back through wandb.config so any externally supplied overrides
# (e.g. from a sweep) are what the rest of the notebook uses.
config = wandb.config
tf.keras.utils.set_random_seed(config.seed)

# Fail fast on a bad activation name instead of after building the dataset.
assert config.activation in ACTIVATION_DICT, (
    f"Unknown activation {config.activation!r}; expected one of {sorted(ACTIVATION_DICT)}"
)

In [None]:
# Build the 2D point cloud for `config.query` from the configured font, then
# plot it as a visual sanity check before training.
font_point_cloud = PointCloud2DFromFont(
    font_file=config.font_file,
    font_size=config.font_size,
    sample_percentage=config.sample_percentage,
)
data = font_point_cloud.build(
    query=config.query,
    padding=config.padding,
)
font_point_cloud.plot_points()

In [None]:
# SDF network sized to the sampled point cloud (`data.shape[0]` points);
# every other hyperparameter is taken from `config`.
model = SDFModelBase(
    num_points=data.shape[0],
    units=config.units,
    num_intermediate_layers=config.num_intermediate_layers,
    activation=ACTIVATION_DICT[config.activation],
    point_loss_coeff=config.point_loss_coeff,
    eikonal_coefficient=config.eikonal_coefficient,
    num_padding_points=config.num_padding_points,
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=config.learning_rate),
)

In [None]:
# `tf.expand_dims(..., axis=0)` adds a leading axis of size 1, so the whole
# point cloud is fed as a single sample each epoch.
# Keras' own logging is silenced (verbose=0); ProgressBarCallback reports progress instead.
history = model.fit(
    tf.expand_dims(data, axis=0),
    epochs=config.epochs,
    verbose=0,
    callbacks=[
        ProgressBarCallback(epochs=config.epochs),  # progress display
        SDFVisualizationCallback(data),             # visualizes the learned SDF against `data`
        WandbCallback(),                            # logs training metrics to the active W&B run
    ],
)

In [None]:
# Close the W&B run started by wandb.init so remaining data is uploaded and
# the run is marked finished.
wandb.finish()