In [20]:
"""
6. Docker
7. Github Actions
"""

import numpy as np
import torch
import sys
import codecs
import os
import shortuuid
from traitlets import dlink

from models.layers.mesh import Mesh
from models import create_model
from data import DataLoader
from options.test_options import TestOptions
from util.util import pad

import ipyvolume as ipv
import ipywidgets as widgets
from IPython.display import display

# Confidence threshold below which get_verdict reports that the model is unsure.
SCORE_THD = 0.5

# Command-line arguments handed to MeshCNN's TestOptions by test_options_from_list(),
# grouped one option (with its values) per line.
ARGS = [
    '--dataroot', 'primitives',
    '--name', 'primitives',
    '--ncf', '64', '128', '256', '256',
    '--pool_res', '1200', '900', '600', '420',
    '--ninput_edges', '2000',
    '--norm', 'group',
    '--resblocks', '1',
    '--gpu_ids', '-1',  # presumably -1 means "run on CPU" in MeshCNN options — confirm
]


def file_to_xyzf(obj_path):
    """Load a Wavefront OBJ file in the form expected by ``ipv.plot_trisurf()``.

    Only ``v`` (vertex) and ``f`` (face) records are read; every other record
    (``vn``, ``vt``, comments, groups, ...) is ignored. Face entries may use any
    OBJ vertex-reference form (``v``, ``v/vt``, ``v/vt/vn``, ``v//vn``); only the
    vertex index is kept. Negative (relative) indices are resolved against the
    vertices read so far.

    :param obj_path: path to obj file
    :return: x, y, z coordinates of vertices as np.arrays; list of faces, each a
             list of 0-based vertex indices (triplets for triangle meshes)
    """
    x, y, z = [], [], []
    faces = []
    # OBJ geometry is plain ASCII; tolerate stray non-UTF-8 bytes in comments/names.
    with open(obj_path, 'rt', encoding='utf-8', errors='replace') as file:
        for line in file:
            tokens = line.split()
            if not tokens:
                continue
            if tokens[0] == 'v':
                x.append(float(tokens[1]))
                y.append(float(tokens[2]))
                z.append(float(tokens[3]))
            elif tokens[0] == 'f':
                face = []
                for ref in tokens[1:]:
                    # "v", "v/vt", "v/vt/vn" and "v//vn" all start with the vertex index.
                    idx = int(ref.split('/')[0])
                    # OBJ indices are 1-based; negative ones count back from the latest vertex.
                    face.append(idx - 1 if idx > 0 else len(x) + idx)
                faces.append(face)
    return np.array(x), np.array(y), np.array(z), faces


def get_verdict(file_path, model, opt, dataset=None):
    """Classify obj file.

    :param file_path: path to obj file
    :param model: neural net for classification (MeshCNN)
    :param opt: MeshCNN configuration
    :param dataset: DataLoader whose ``.dataset`` provides ``mean``, ``std`` and
                    ``classes``; defaults to the module-level ``dataset``
    :return: human-readable string with top-1 score if the top-1 probability is at
             least SCORE_THD, sad message otherwise
    """
    if dataset is None:
        # Fall back to the DataLoader created at module level in this notebook.
        dataset = globals()['dataset']

    mesh = Mesh(file=file_path, opt=opt, hold_history=False)

    # get edge features, padded to the fixed input size and normalized with the
    # mean/std stored on the data loader's dataset
    edge_features = mesh.extract_features()
    edge_features = pad(edge_features, opt.ninput_edges)
    edge_features = (edge_features - dataset.dataset.mean) / dataset.dataset.std
    edge_features = torch.from_numpy(edge_features).float().to(model.device).unsqueeze(0)

    # model inference; no autograd graph is needed
    with torch.no_grad():
        probs = model.net(edge_features, [mesh]).softmax(1)

    top_prob, top_idx = probs.max(1)
    top_prob = top_prob[0].item()
    # SCORE_THD is a probability in [0, 1]: compare before converting to percent.
    if top_prob < SCORE_THD:
        return "The model is unsure that uploaded object belongs to any of predefined classes :("

    label = dataset.dataset.classes[top_idx[0].item()]

    return f"The model is {top_prob * 100:.6f}% sure it's {label}"


def test_options_from_list(arr=ARGS):
    """Imitate CLI arguments input to MeshCNN test config.

    ``sys.argv`` is temporarily replaced by ``[program name, *arr]`` and is
    always restored afterwards, even if parsing raises (e.g. a SystemExit from
    a command-line parser on bad arguments).

    :param arr: arguments to pass as command-line
    :return: MeshCNN test config
    """
    saved_argv = sys.argv[:]
    sys.argv = saved_argv[:1] + list(arr)
    try:
        return TestOptions().parse()
    finally:
        sys.argv = saved_argv


def show_obj(path):
    """Display obj file as interactive plot.

    Creates a new ipyvolume figure with the 'minimal' style and draws the mesh
    as a white triangulated surface.

    :param path: path to obj file
    :return: the current ipyvolume container (``ipv.gcc()``) holding the figure
    """
    x, y, z, faces = file_to_xyzf(path)

    ipv.figure()
    ipv.style.use('minimal')
    ipv.plot_trisurf(x, y, z, triangles=faces, color='white')
    return ipv.gcc()


def get_name(file_upload):
    """Get name of the uploaded file from a FileUpload widget's ``value``.

    Used as the ``dlink`` transform from ``file_picker.value``, so it receives
    the widget's value, not the widget itself. Both value layouts are handled:
    ipywidgets 7 (``{name: {'metadata': {'name': ...}, 'content': ...}}``) and
    ipywidgets 8 (tuple of dicts with a ``'name'`` key).

    :param file_upload: ``widgets.FileUpload.value``
    :return: name of the (first) uploaded file as string, '' if nothing is uploaded
    """
    if not file_upload:
        return ''
    if isinstance(file_upload, dict):
        return next(iter(file_upload.values()))['metadata']['name']
    return file_upload[0]['name']


# widgets
# `accept` uses HTML <input accept> syntax (an extension like '.obj'), not a glob.
file_picker = widgets.FileUpload(accept='.obj', multiple=False)
file_name = widgets.Text()
vbox = widgets.VBox([widgets.HBox([file_picker, file_name])])

# MeshCNN: parse the hard-coded test options, then build the data loader
# (its .dataset supplies the mean/std and class names used by get_verdict) and the model
opt = test_options_from_list()
dataset = DataLoader(opt)
model = create_model(opt)


def change_input(change, model=model, opt=opt):
    """Callback for object processing.

    Responsible for visualization of object and its classification. The uploaded
    content is written to a temporary .obj file in the working directory, which
    is removed even if plotting or classification fails. Does nothing when the
    upload value is empty (e.g. the widget was cleared).

    :param change: traitlets change dict for ``file_picker.value``
    :param model: neural model (MeshCNN)
    :param opt: test configuration from CLI or test_options_from_list()
    :return: None; updates ``vbox.children`` in place
    """
    uploaded = change['new']
    if not uploaded:
        return
    if isinstance(uploaded, dict):
        # ipywidgets 7: {name: {'metadata': {...}, 'content': bytes}}
        content = next(iter(uploaded.values()))['content']
    else:
        # ipywidgets 8: tuple of dicts with a 'content' buffer
        content = uploaded[0]['content']

    path = f"./{shortuuid.uuid()}.obj"
    with open(path, 'wt') as file:
        file.write(codecs.decode(content, encoding="utf-8"))
    try:
        plot = show_obj(path)
        verdict = widgets.Text(get_verdict(path, model, opt))
    finally:
        os.remove(path)
    vbox.children = (widgets.HBox([file_picker, file_name]), plot, verdict)


# Keep the text box showing the uploaded file's name (one-way link via get_name),
# then process every new upload with change_input.
dlink(source=(file_picker, 'value'), target=(file_name, 'value'), transform=get_name)
file_picker.observe(change_input, names='value')
# Last expression of the cell: renders the widget layout in the notebook.
vbox

loaded mean / std from cache
loading the model from ./checkpoints\primitives\latest_net.pth
