In [1]:
import torch
import os
from pathlib import Path
from networks.mrnet import MRFactory
from datasets.sampler import make_grid_coords
import yaml
from yaml.loader import SafeLoader
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from logs.utils import output_on_batched_domain
import trimesh
from torch.utils.data import BatchSampler

In [2]:
# Paths are resolved relative to the repository root (the parent of this notebook's directory).
os.environ["WANDB_NOTEBOOK_NAME"] = "eval-net.ipynb"
BASE_DIR = Path('.').absolute().parents[0]
MESH_PATH = BASE_DIR.joinpath('meshes')
MODEL_PATH = BASE_DIR.joinpath('models')
CONFIG_PATH = BASE_DIR.joinpath('configs', 'config_base_m_net.yml')

project_name = "eval3d"
with open(CONFIG_PATH, encoding='utf-8') as f:
    hyper = yaml.load(f, Loader=SafeLoader)

# batch_size may be stored in the YAML as a Python expression string; evaluate it to a number.
# NOTE(review): eval() executes arbitrary code from the config file — only load trusted configs.
if isinstance(hyper['batch_size'], str):
    hyper['batch_size'] = eval(hyper['batch_size'])
print(hyper)

{'model': 'M', 'in_features': 2, 'hidden_layers': 1, 'hidden_features': [8, 8, 16, 32, 64, 128, 192], 'bias': True, 'max_stages': 7, 'period': 2, 'domain': [-1, 1], 'omega_0': [1, 2, 4, 6, 8, 16, 32], 'hidden_omega_0': [30, 30, 30, 30, 30, 30, 30], 'superposition_w0': False, 'sampling_scheme': 'regular', 'decimation': True, 'filter': 'gauss', 'attributes': ['d0'], 'loss_function': 'mse', 'opt_method': 'Adam', 'lr': 0.0005, 'loss_tol': 1e-12, 'diff_tol': 1e-09, 'max_epochs_per_stage': [1000, 400, 400, 400, 400, 400, 400], 'batch_size': 16384, 'image_name': 'hallpaz.jpg', 'width': 512, 'height': 512, 'channels': 1, 'device': 'cuda', 'eval_device': 'cpu', 'save_format': 'general', 'visualize_grad': True, 'extrapolate': [-2, 2]}


In [3]:
# TODO: download model automatically from W&B if can't find it locally
# Load the trained multiresolution network checkpoint from the local models directory.
mrmodel = MRFactory.load_state_dict(
    str(MODEL_PATH / 'MGcolor_3-3_w12F_hf320_MEp22_hl1_128px.pth')
)
print("Model: ", type(mrmodel))
print("MODEL TOTAL = ", mrmodel.total_parameters())

Model:  <class 'networks.mrnet.MNet'>
MODEL TOTAL =  300550


In [9]:
# Load the mesh whose vertices will be colored by the network, and build `points`
# (one xyz row per vertex) for the evaluation cell below. Defining `points` here keeps
# the notebook runnable on a fresh kernel.
# NOTE(review): assumes load_mesh returns a single trimesh.Trimesh for this file — confirm for other inputs.
mesh = trimesh.load_mesh(os.path.join(MESH_PATH, 'cube.ply'))
vertices = np.asarray(mesh.vertices, dtype=np.float32)
print(vertices.max(), vertices.min())
points = torch.from_numpy(vertices)
points.shape

1.942396 -1.272332


torch.Size([106289, 3])

In [10]:
# Evaluate the network on every mesh vertex in fixed-size chunks to bound memory use.
# torch.split slices the tensor directly instead of iterating it row by row in Python.
# Each chunk's output is moved to the CPU immediately so GPU memory does not accumulate.
EVAL_BATCH_SIZE = 256 * 256
mrmodel.to(hyper['device'])
with torch.no_grad():
    colors = torch.cat([
        mrmodel(batch.to(hyper['device']))['model_out'].cpu()
        for batch in torch.split(points, EVAL_BATCH_SIZE)
    ])

In [6]:
# Attach the predicted colors to the mesh as per-vertex RGB and save them as an ASCII PLY.
# `mesh` comes from trimesh.load_mesh, so this uses the trimesh API (visual.vertex_colors / export).
# The meshio-style `point_data` / `write(..., binary=False)` calls do not exist on a Trimesh.
assert colors.shape[1] >= 3, "expected at least 3 output channels (RGB)"
assert len(colors) == len(mesh.vertices), "expected one predicted color per mesh vertex"
# Clamp before scaling: values outside [0, 1] would wrap around when cast to uint8.
rgb = (colors[:, :3].clamp(0.0, 1.0).numpy() * 255).astype(np.uint8)
mesh.visual.vertex_colors = rgb
# NOTE(review): the output name says "armadillo", but the cell above loads cube.ply.
# Trailing ';' suppresses the exported file contents that export() returns.
mesh.export(os.path.join(MESH_PATH, 'colorful_armadillo.ply'), encoding='ascii');

## Extrapolation

In [None]:
import ipywidgets as widgets
from ipywidgets import interact, interactive, Box, interact_manual

In [None]:
slider = widgets.FloatRangeSlider(
    value=[-1.0, 1.0],
    min=-7,
    max=7,
    step=0.1,
    description='Interval:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
    layout=widgets.Layout(width='50%')
)

model = mrmodel
res = 512
# Kept for backward compatibility with code that reads `channels`; the renderer below
# infers the channel count from the model output instead.
channels = 1


def plot_model(interval):
    """Render the model on a `res` x `res` grid spanning `interval` on each axis.

    Args:
        interval: (start, end) pair from the range slider, forwarded to make_grid_coords.

    Returns:
        PIL.Image with the clamped model output. Single-channel output is replicated to RGB.
    """
    # The model may have been moved to the GPU by an earlier cell, so evaluate on its device.
    # Assumes make_grid_coords returns a torch tensor — TODO confirm.
    device = next(model.parameters()).device
    grid = make_grid_coords(res, res, *interval).to(device)
    with torch.no_grad():
        output = model(grid)
    model_out = torch.clamp(output['model_out'], 0.0, 1.0)

    # reshape(..., -1) infers the channel count (1 for grayscale, 3 for color models).
    pixels = model_out.detach().cpu().reshape(res, res, -1).numpy()
    pixels = (pixels * 255).astype(np.uint8)
    if pixels.shape[-1] == 1:
        pixels = np.repeat(pixels, 3, axis=-1)
    return Image.fromarray(pixels)


interact(plot_model, interval=slider);

## Continuous in Scale

In [None]:
level_slider = widgets.FloatSlider(
    value=1.0,
    min=0.0,
    max=float(mrmodel.n_stages()),
    step=0.05,
    description='Multilevel',
    disabled=False,
    continuous_update=True,
    readout=True,
    orientation='horizontal',
    readout_format='.2f',
    layout=widgets.Layout(width='50%')
)


def plot_model(level):
    """Render the model on a `res` x `res` grid over [-1, 1] at a fractional detail level.

    Args:
        level: float in [0, n_stages]. Stage s gets weight 1 once level >= s + 1,
            weight (level - s) while s <= level < s + 1, and 0 below s.

    Returns:
        PIL.Image with the clamped model output. Single-channel output is replicated to RGB.
    """
    n_stages = mrmodel.n_stages()
    # The model may have been moved to the GPU by an earlier cell, so build inputs on its device.
    device = next(mrmodel.parameters()).device
    weights = [min(max(level - s, 0.0), 1.0) for s in range(n_stages)]

    # Assumes make_grid_coords returns a torch tensor — TODO confirm.
    grid = make_grid_coords(res, res, -1.0, 1.0).to(device)
    with torch.no_grad():
        output = mrmodel(grid, mrweights=torch.tensor(weights, dtype=torch.float32, device=device))
    model_out = torch.clamp(output['model_out'], 0.0, 1.0)

    # reshape(..., -1) infers the channel count (1 for grayscale, 3 for color models).
    pixels = model_out.detach().cpu().reshape(res, res, -1).numpy()
    pixels = (pixels * 255).astype(np.uint8)
    if pixels.shape[-1] == 1:
        pixels = np.repeat(pixels, 3, axis=-1)
    return Image.fromarray(pixels)


interact(plot_model, level=level_slider);