In [17]:
import torch
import os
from pathlib import Path
from networks.mrnet import MRFactory
from datasets.sampler import make_grid_coords
import yaml
from yaml.loader import SafeLoader
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from logs.utils import output_on_batched_domain
import meshio
from torch.utils.data import BatchSampler

In [2]:
# Tell W&B which notebook file runs started from here belong to.
os.environ["WANDB_NOTEBOOK_NAME"] = "eval-net.ipynb"
# Project root = parent of the current working directory; data folders hang off it.
BASE_DIR = Path.cwd().parents[0]
MESH_PATH = BASE_DIR / 'meshes'
MODEL_PATH = BASE_DIR / 'models'

In [3]:
project_name = "eval3d"
# Shared hyperparameter file, resolved relative to the notebook's working directory.
CONFIG_FILE = '../configs/config_base_m_net.yml'
with open(CONFIG_FILE) as config_stream:
    # safe_load == load(..., Loader=SafeLoader): only plain YAML types are built.
    hyper = yaml.safe_load(config_stream)

In [31]:
# TODO: download model automatically from W&B if can't find it locally
MODEL_FILENAME = 'MGwood__3-3_w9F_hf256_MEp21_hl1_128px.pth'
model_file = os.path.join(MODEL_PATH, MODEL_FILENAME)
# Rebuild the multi-stage network from its saved checkpoint.
mrmodel = MRFactory.load_state_dict(model_file)

In [5]:
# Show which concrete network class the checkpoint produced.
model_type = type(mrmodel)
print("Model: ", model_type)

Model:  <class 'networks.mrnet.MNet'>


In [6]:
# List every parameter tensor with its shape and element count.
param_list = list(mrmodel.parameters())
for tensor in param_list:
    print("p: ", tensor.shape, " = ", tensor.numel())
# One element per stage is subtracted so this sum matches
# mrmodel.total_parameters() in the recorded output.
# NOTE(review): why total_parameters() excludes one value per stage is not
# visible here — TODO confirm.
total_params = sum(tensor.numel() for tensor in param_list) - mrmodel.n_stages()
print("TOTAL = ", total_params)
print("MODEL TOTAL = ", mrmodel.total_parameters())

p:  torch.Size([256, 3])  =  768
p:  torch.Size([256])  =  256
p:  torch.Size([256, 256])  =  65536
p:  torch.Size([256])  =  256
p:  torch.Size([3, 256])  =  768
p:  torch.Size([3])  =  3
p:  torch.Size([256, 3])  =  768
p:  torch.Size([256])  =  256
p:  torch.Size([256, 512])  =  131072
p:  torch.Size([256])  =  256
p:  torch.Size([3, 256])  =  768
p:  torch.Size([3])  =  3
p:  torch.Size([256, 3])  =  768
p:  torch.Size([256])  =  256
p:  torch.Size([256, 512])  =  131072
p:  torch.Size([256])  =  256
p:  torch.Size([3, 256])  =  768
p:  torch.Size([3])  =  3
TOTAL =  333830
MODEL TOTAL =  333830


In [39]:
# Uniform scale applied to the raw mesh vertex coordinates before querying the model.
# NOTE(review): presumably chosen to map the bunny into the model's input
# domain — TODO confirm the intended value.
POINT_SCALE = 7
mesh = meshio.read(os.path.join(MESH_PATH, 'bunny.obj'))
points = torch.from_numpy(mesh.points).float() * POINT_SCALE
points.shape

torch.Size([2503, 3])

In [40]:
points.min()

tensor(-0.6607)

In [41]:
# Evaluate the model at every mesh vertex, in chunks to bound peak memory.
EVAL_BATCH_SIZE = 256 * 256
mrmodel.to(hyper['device'])
with torch.no_grad():
    # torch.split hands back consecutive row chunks of `points` (the last may
    # be shorter) directly, instead of BatchSampler iterating the tensor row by
    # row in Python and re-stacking each batch.
    colors = torch.cat([
        mrmodel(batch.to(hyper['device']))['model_out']
        for batch in torch.split(points, EVAL_BATCH_SIZE)
    ]).cpu()

In [42]:
# Per-vertex colour channels as 8-bit values for the PLY export.
# Clamp to [0, 1] first (as the plotting code does): raw model outputs can
# leave that range, and float -> uint8 casts of out-of-range values overflow
# instead of saturating.
colors_u8 = (torch.clamp(colors, 0.0, 1.0).numpy() * 255).astype(np.uint8)
mesh.point_data = {
    # ascontiguousarray keeps each channel a standalone 1-D array, as before.
    name: np.ascontiguousarray(colors_u8[:, channel])
    for channel, name in enumerate(('red', 'green', 'blue'))
}

In [43]:
output_mesh_file = os.path.join(MESH_PATH, 'textured_bunny.ply')
# Write the coloured mesh as ASCII (non-binary) PLY.
mesh.write(output_mesh_file, binary=False)

: 

In [None]:
# output = output_on_batched_domain(mrmodel, 128, 
#                                   [-1, 1], 3, 256*256, hyper['cuda'])

## Extrapolation

In [None]:
import ipywidgets as widgets
from ipywidgets import interact, interactive, Box, interact_manual

In [None]:
slider = widgets.FloatRangeSlider(
    value=[-1.0, 1.0],
    min=-7,
    max=7,
    step=0.1,
    description='Interval:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
    layout=widgets.Layout(width='50%')
)

# Settings shared by the interactive plots.
model = mrmodel
res = 512
# Output width of the loaded model: each stage's final layer has [3, 256]
# weights. plot_model reads the width from the actual output rather than this.
channels = 3


def plot_model(interval):
    """Render the model on a res x res grid spanning ``interval`` as an image.

    Args:
        interval: (start, end) pair from the range slider, forwarded to
            ``make_grid_coords`` after the resolution arguments.

    Returns:
        PIL.Image.Image of size res x res; single-channel output is
        replicated to three channels.
    """
    # NOTE(review): the model's first layers take 3 input features ([256, 3]
    # weights), while make_grid_coords(res, res, ...) looks like a 2-D grid
    # builder — confirm the grid's coordinate dimension matches the model.
    grid = make_grid_coords(res, res, *interval)
    # Run on whatever device the model's parameters currently live on.
    device = next(model.parameters()).device
    with torch.no_grad():
        output = model(grid.to(device))
    model_out = torch.clamp(output['model_out'], 0.0, 1.0)

    n_channels = model_out.shape[-1]
    pixels = model_out.cpu().view(res, res, n_channels).numpy()
    pixels = (pixels * 255).astype(np.uint8)
    if n_channels == 1:
        pixels = np.repeat(pixels, 3, axis=-1)
    return Image.fromarray(pixels)


interact(plot_model, interval=slider)

## Continuous in Scale

In [None]:
level_slider = widgets.FloatSlider(
    value=1.0,
    min=0.0,
    max=float(mrmodel.n_stages()),
    step=0.05,
    description='Multilevel',
    disabled=False,
    continuous_update=True,
    readout=True,
    orientation='horizontal',
    readout_format='.2f',
    layout=widgets.Layout(width='50%')
)


def plot_model(level):
    """Render the model over [-1, 1] with fractional per-stage weights.

    Stage ``s`` (0-based) gets weight ``clamp(level - s, 0, 1)``. Stages with
    ``s + 1 <= level`` are fully on, the next stage is partially blended in,
    and the rest are off.

    Args:
        level: slider value in [0, n_stages].

    Returns:
        PIL.Image.Image of size res x res; single-channel output is
        replicated to three channels.
    """
    weights = [min(max(level - s, 0.0), 1.0) for s in range(model.n_stages())]

    # NOTE(review): as in the extrapolation plot, confirm the grid built by
    # make_grid_coords(res, res, ...) has the model's input dimension (3).
    grid = make_grid_coords(res, res, -1.0, 1.0)
    # Inputs and stage weights must be on the same device as the parameters.
    device = next(model.parameters()).device
    with torch.no_grad():
        output = model(grid.to(device),
                       mrweights=torch.tensor(weights, dtype=torch.float32, device=device))
    model_out = torch.clamp(output['model_out'], 0.0, 1.0)

    n_channels = model_out.shape[-1]
    pixels = model_out.cpu().view(res, res, n_channels).numpy()
    pixels = (pixels * 255).astype(np.uint8)
    if n_channels == 1:
        pixels = np.repeat(pixels, 3, axis=-1)
    return Image.fromarray(pixels)


interact(plot_model, level=level_slider)