In [None]:
import sys
sys.path.insert(0, "..")
from cgmcore import modelutils, utils
import dbutils
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import numpy as np
from tqdm import tqdm
import pandas as pd
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt

def mkdir_if_not_exists(path_to_create):
    """Create the directory ``path_to_create`` unless something already exists there.

    Missing parent directories are created as well. If the path already
    exists (as a directory or as a file) nothing is done.

    Args:
        path_to_create: Path of the directory to create.
    """
    # Imported locally so the helper does not depend on a later notebook cell
    # having imported ``os`` before it is called.
    import os

    if not os.path.exists(path_to_create):
        # exist_ok guards against another process creating the directory
        # between the check above and this call.
        os.makedirs(path_to_create, exist_ok=True)

# Select the model.

In [None]:
# Getting the models.
# Ask the main database for the distinct values of the `type` column in the
# artifact_quality table.
db_connector = dbutils.connect_to_main_database()
select_sql_statement = "SELECT DISTINCT(type) FROM artifact_quality;"
types = db_connector.execute(select_sql_statement, fetch_all=True)
# Keep only type names that are longer than 20 characters and contain "height".
# NOTE(review): presumably this is how model names are told apart from other
# artifact-quality types - confirm the naming convention.
models_in_database = [t[0] for t in types if len(t[0]) > 20 and "height" in t[0]]

# Notebook-global holding the chosen model name; set by select_model().
selected_model = None

@interact(models=models_in_database)
def select_model(models):
    """Widget callback: remember the chosen model name in ``selected_model``.

    Args:
        models: The entry of ``models_in_database`` currently selected in the
            widget.
    """
    global selected_model
    selected_model = models
    print("Selected model '{}'.".format(models))

# Load the model.

In [None]:
# Weights file of the selected model; the model name is used both as the folder
# name and as the file-name prefix.
# NOTE(review): if selected_model is still None here, the path silently
# contains the literal text "None".
model_path = "/whhdata/models/{0}/{0}-pointnet-model-weights.h5".format(selected_model)

def load_model(model_path):
    """Build the network, load its weights from ``model_path`` and compile it.

    The network is created by ``modelutils.create_point_net`` with input
    shape ``(10000, 3)``, output size ``1`` and hidden sizes
    ``[512, 256, 128]``.

    Args:
        model_path: Path of the weights file passed to ``load_weights``.

    Returns:
        The compiled model (rmsprop optimizer, MSE loss, MAE metric).
    """
    point_net = modelutils.create_point_net(
        (10000, 3),
        1,
        hidden_sizes=[512, 256, 128],
    )
    point_net.load_weights(model_path)
    point_net.compile(optimizer="rmsprop", loss="mse", metrics=["mae"])
    return point_net

# Build the network and load its weights once. `model` is a notebook-global
# that render() reads when predicting.
print("Loading model...")
model = load_model(model_path)
print("Model loaded.")

# Get the paths of the scans.

In [None]:
import glob2 as glob
import os

# Folder holding the measurements; each entry directly below it is treated as
# one scan.
measurements_path = "/localssd/20190724_Standardization_AAH/"

# Collect the entries in sorted order so the listing is stable between runs.
scan_paths = sorted(glob.glob(os.path.join(measurements_path, "*")))

# Show which scans were found.
print("Scans to be used:\n" + "\n".join(scan_paths))

In [None]:
# Allow up to 1000 characters per column when pandas displays a data-frame, so
# long file paths are not cut off.
pd.set_option("max_colwidth", 1000)
# Pose-type codes that render() looks for in the PCD paths.
# NOTE(review): this list is not referenced anywhere else in the visible notebook.
pose_types = ["100", "101", "102", "200", "201", "202"]

#@interact(scan_path=scan_paths)
def render(scan_path, file_path=None):
    """Predict on every PCD of one scan and report the results.

    All files matching the glob pattern ``<scan_path>/**/*.pcd`` are loaded,
    passed through ``utils.subsample_pointcloud`` with a target of 10000 and
    fed to the notebook-global ``model``. The first value of each prediction
    is collected in a data-frame and drawn as one bar per PCD, coloured by
    the pose-type marker found in the PCD path.

    Args:
        scan_path: Folder of a single scan.
        file_path: Output path without extension. When given, the table is
            written to ``file_path + ".csv"`` and the chart to
            ``file_path + ".png"``. When ``None`` both are shown inline.

    Returns:
        None. Returns early, after printing a message, if no PCD is found.

    Raises:
        ValueError: If a PCD path contains none of the known pose-type
            markers. This is raised while building the chart, i.e. after the
            table has already been written or displayed.
    """
    # Get all PCDs.
    pcd_paths = glob.glob(os.path.join(scan_path, "**", "*.pcd"))
    if len(pcd_paths) == 0:
        print("Could not find any PCDs at {}.".format(scan_path))
        return
    print("Found {} PCDs for {}...".format(len(pcd_paths), scan_path))
    print()

    # Load the artifacts. The subsample target matches the first dimension of
    # the (10000, 3) input shape used in load_model().
    num_points = 10000
    print("Loading PCDs. This might take a while...")
    pcd_arrays = []
    for pcd_path in tqdm(pcd_paths):
        pcd_array = utils.load_pcd_as_ndarray(pcd_path)
        pcd_array = utils.subsample_pointcloud(pcd_array, num_points)
        pcd_arrays.append(pcd_array)
    pcd_arrays = np.array(pcd_arrays)

    # Predict on all.
    print("Predicting...")
    predictions = model.predict(pcd_arrays, verbose=1)

    # Just check.
    assert len(pcd_paths) == len(predictions)

    # Each prediction is indexable; only its first entry is reported.
    predicted_targets = [prediction[0] for prediction in predictions]

    # Create a data-frame.
    print("Creating data-frame...")
    df = pd.DataFrame.from_records(
        list(zip(pcd_paths, predicted_targets)),
        columns=["path", "predicted target"],
    )
    if file_path is not None:
        df.to_csv(file_path + ".csv")
    else:
        display(df)

    # Render the barchart.
    print("Rendering plot...")
    colors = [_pose_color(pcd_path) for pcd_path in pcd_paths]
    # Bars start at bottom_value and are (prediction - bottom_value) tall, so
    # the top of each bar sits at the predicted value.
    bottom_value = 60
    x = list(range(len(predicted_targets)))
    plt.figure(figsize=(20, 5))
    plt.bar(
        x,
        [value - bottom_value for value in predicted_targets],
        color=colors,
        bottom=bottom_value,
    )
    plt.title(scan_path.split("/")[-1])
    if file_path is not None:
        plt.savefig(file_path + ".png")
    else:
        plt.show()
    plt.close()


def _pose_color(pcd_path):
    """Return the bar colour for the pose-type marker contained in ``pcd_path``.

    The markers ``_100_``, ``_101_``, ``_102_``, ``_200_``, ``_201_`` and
    ``_202_`` are tried in that order and the first one found decides.

    Raises:
        ValueError: If the path contains none of the markers.
    """
    pose_colors = (
        ("100", "C0"),
        ("101", "C1"),
        ("102", "C3"),
        ("200", "C4"),
        ("201", "C5"),
        ("202", "C6"),
    )
    for pose_code, color in pose_colors:
        if "_{}_".format(pose_code) in pcd_path:
            return color
    raise ValueError("Unknown pose type in PCD path: {}".format(pcd_path))


# Create some folders.
# Results go to <root_path>/<selected model name>/.
root_path = "/whhdata/standardization_test_results"
mkdir_if_not_exists(root_path)
# NOTE(review): this rebinds the global `model_path`, which earlier held the
# path of the weights file; from here on it is the per-model results folder.
model_path = os.path.join(root_path, selected_model)
mkdir_if_not_exists(model_path)

# Process each scan.
for scan_path in scan_paths:
    # The last path component names the scan (presumably its QR code) and is
    # used as the base name of the result files.
    qr_code = scan_path.split("/")[-1]
    print("Processing {}...".format(qr_code))
    # render() appends ".csv" and ".png" to this path.
    output_path = os.path.join(model_path, qr_code)
    render(scan_path, output_path)
    
print("Done.")