In [None]:
import torch
import os
from pathlib import Path
from networks.mrnet import MRFactory
from datasets.sampler import make_grid_coords
import yaml
from yaml.loader import SafeLoader
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

In [None]:
# Name of this notebook, exported through the WANDB_NOTEBOOK_NAME environment variable.
os.environ["WANDB_NOTEBOOK_NAME"] = "eval-net.ipynb"

# Project folders are resolved from the parent of the current working directory.
BASE_DIR = Path.cwd().parents[0]
IMAGE_PATH = BASE_DIR / 'img'
MODEL_PATH = BASE_DIR / 'models'

In [None]:
project_name = "test_eval"

# Hyperparameters are read from the base M-Net config using PyYAML's safe loader,
# so only plain YAML data types are constructed.
with open('../configs/config_base_m_net.yml') as config_file:
    hyper = yaml.safe_load(config_file)

In [None]:
# TODO: download model automatically from W&B if can't find it locally
# Another checkpoint that was used here before:
#   'MGchec_louco_3-3_w16T_hf289_MEp2000_hl1_256px.pth'
checkpoint_file = os.path.join(
    MODEL_PATH, "MGpic7._1-1_w128F_hf256256_MEp200_hl1_r512_pr4.pth"
)
# NOTE(review): presumably rebuilds the network from the saved checkpoint - confirm in MRFactory.
mrmodel = MRFactory.load_state_dict(checkpoint_file)

In [None]:
# Report the Python type of the object returned by MRFactory.load_state_dict.
print("Model: ", type(mrmodel))

In [None]:
# List every parameter tensor with its shape and element count, accumulating the
# overall count in the same pass.
param_count = 0
for p in mrmodel.parameters():
    n_elements = p.numel()
    print("p: ", p.shape, " = ", n_elements)
    param_count += n_elements

# One element per stage is subtracted from the raw count.
# NOTE(review): presumably a per-stage scalar that should not be counted - confirm
# against mrmodel.total_parameters(), which is printed next for comparison.
total_params = param_count - mrmodel.n_stages()
print("TOTAL = ", total_params)
print("MODEL TOTAL = ", mrmodel.total_parameters())

In [None]:
# Evaluate the network on a 256x256 coordinate grid covering the configured domain.
eval_grid = make_grid_coords(256, 256, *hyper['domain'])
output = mrmodel(eval_grid)

# Keep the prediction inside the displayable [0, 1] range.
model_out = output['model_out'].clamp(0.0, 1.0)

# Reshape the prediction to a 256x256 array and show it.
preview = model_out.detach().cpu().view(256, 256).numpy()
plt.imshow(preview)

In [None]:
import ipywidgets as widgets
from ipywidgets import interact, interactive, Box, interact_manual

## Extrapolation

In [None]:
# Range slider selecting the interval passed to make_grid_coords by plot_model.
# The default [-1, 1] can be widened up to [-7, 7].
slider = widgets.FloatRangeSlider(
    # label and value range
    description='Interval:',
    min=-7,
    max=7,
    step=0.1,
    value=[-1.0, 1.0],
    # appearance
    orientation='horizontal',
    layout=widgets.Layout(width='50%'),
    readout=True,
    readout_format='.1f',
    # interaction
    disabled=False,
    continuous_update=False,
)

# Shared settings read as globals by the plot_model callbacks.
res = 512         # side length of the rendered square grid
channels = 1      # channel count used when reshaping the network output
model = mrmodel   # network evaluated by the callbacks
def plot_model(interval):
    """Render the network over the given coordinate interval as an image.

    Reads the notebook globals ``model``, ``res`` and ``channels``.

    Args:
        interval: Bounds unpacked into ``make_grid_coords`` after the two
            ``res`` arguments (the range slider supplies a ``(low, high)`` pair).

    Returns:
        The object produced by ``Image.fromarray`` for the ``res`` x ``res``
        8-bit pixel array.
    """
    coords = make_grid_coords(res, res, *interval)
    prediction = model(coords)['model_out'].clamp(0.0, 1.0)

    # Scale [0, 1] values to 8-bit; astype truncates the fractional part.
    image = prediction.detach().cpu().view(res, res, channels).numpy()
    image = (image * 255).astype(np.uint8)

    # A single channel is replicated three times along the last axis.
    rgb = np.repeat(image, 3, axis=-1) if channels == 1 else image
    return Image.fromarray(rgb)

# Connect the range slider to plot_model: the slider value is passed as `interval`.
interact(plot_model, interval=slider)

## Continuous in Scale

In [None]:
# Slider for the continuous multiresolution level, from 0 up to the number of
# stages reported by the loaded network. With continuous_update=True the widget
# reports values while it is being dragged.
level_slider = widgets.FloatSlider(
    value=1.0,
    min=0.0,
    max=float(mrmodel.n_stages()),
    step=0.05,
    description='Multilevel',
    disabled=False,
    continuous_update=True,
    readout=True,
    orientation='horizontal',
    readout_format='.2f',
    layout=widgets.Layout(width='50%')
)
def plot_model(level):
    """Render the network on [-1, 1] at a continuous multiresolution level.

    ``level`` is converted into one weight per stage: stage ``s`` receives
    ``level - s`` clamped to ``[0, 1]``. Stages entirely below the level get
    1.0, stages above it get 0.0, and the stage containing the level gets the
    fractional part. The weights are passed to the network as ``mrweights``.

    Reads the notebook globals ``model``, ``res`` and ``channels``. The stage
    count is taken from ``model``, the same object that is evaluated, so the
    weight vector always matches the network it is given to.

    Args:
        level: Number in ``[0, model.n_stages()]``.

    Returns:
        The object produced by ``Image.fromarray`` for the ``res`` x ``res``
        8-bit pixel array.
    """
    grid = make_grid_coords(res, res, -1.0, 1.0)

    # Coerce to float so every weight is a float and the tensor gets the
    # default floating dtype even when an integer level is passed.
    level = float(level)
    weights = [min(max(level - stage, 0.0), 1.0)
               for stage in range(model.n_stages())]

    output = model(grid, mrweights=torch.tensor(weights))
    model_out = torch.clamp(output['model_out'], 0.0, 1.0)

    # Scale [0, 1] values to 8-bit; astype truncates the fractional part.
    pixels = model_out.cpu().detach().view(res, res, channels).numpy()
    pixels = (pixels * 255).astype(np.uint8)
    if channels == 1:
        # A single channel is replicated three times along the last axis.
        pixels = np.repeat(pixels, 3, axis=-1)
    return Image.fromarray(pixels)

# Connect the level slider to plot_model: the slider value is passed as `level`.
interact(plot_model, level=level_slider)