In [None]:
%matplotlib inline
import dbutils
import pandas as pd
import matplotlib.pyplot as plt
import os
from datetime import date
from time import mktime
import calendar
import time
import numpy as np
from IPython.display import display
import random
import math
from PIL import Image
from tqdm import tqdm, trange
from IPython.display import display
import ipywidgets as widgets
from ipywidgets import HBox, Output
from IPython.display import display, clear_output
import getpass

# Names of the numeric measurement columns queried from the `measure` table.
measure_keys = [
    "height",
    "weight",
    "age",
    "muac",
    "head_circumference",
]

# Database connector shared by the cells below.
main_connector = dbutils.connect_to_main_database()

# Create measure quality table.

In [None]:
# TODO move this to schema.

# DDL for the `measure_quality` table: one row per (measure_id, type, key)
# holding a real, boolean or text value plus the name of whoever created it.
# WARNING: the statement begins with DROP TABLE, so executing it discards all
# rows that are already stored in `measure_quality`.
sql_statement = """
DROP TABLE IF EXISTS measure_quality;

CREATE TABLE IF NOT EXISTS measure_quality (
    PRIMARY KEY(measure_id, type, key),
    type TEXT NOT NULL,
    key TEXT NOT NULL,
    real_value REAL,
    bool_value BOOLEAN,
    text_value TEXT,
    created_by TEXT NOT NULL,
    measure_id VARCHAR(255) REFERENCES measure(id)
);
"""

# Execution is disabled here; enabling the line below runs the DROP/CREATE above.
#results = main_connector.execute(sql_statement)

# Filter measures by STD and render.

In [None]:
# Re-create the database connector (it is also created in the first cell).
main_connector = dbutils.connect_to_main_database()

def select_measures_in_std_range(std_factor=None):
    """Query all manual measures, optionally keeping only those close to the mean.

    Args:
        std_factor: When None, every manual measure is returned. Otherwise a
            measure is kept only if, for each key in `measure_keys`, its value
            lies less than `std_factor` standard deviations from the average
            over all manual measures.

    Returns:
        A numpy array built from the fetched rows. Each row holds the measure
        id, the QR-code, the timestamp and then one value per measure key.
    """
    use_std_filter = std_factor is not None
    lines = []

    # Common table expression with the average and STD of every key.
    if use_std_filter:
        aggregates = []
        for key in measure_keys:
            aggregates.append("AVG({0}) AS {0}_avg".format(key))
            aggregates.append("STDDEV({0}) AS {0}_stddev".format(key))
        lines.append("WITH AvgStd AS (")
        lines.append("  SELECT")
        lines.append("    " + ",\n    ".join(aggregates))
        lines.append("  FROM measure WHERE type='manual'")
        lines.append(")")

    # Identification columns first, then all measurement columns.
    columns = ", ".join(["measure.id", "qr_code", "measure.timestamp"] + list(measure_keys))
    lines.append("SELECT {} FROM measure".format(columns))
    lines.append(" INNER JOIN person ON measure.person_id = person.id")

    # Make the aggregates available on every row.
    if use_std_filter:
        lines.append(" CROSS JOIN AvgStd")

    # Allow only manual measurements.
    lines.append(" WHERE type='manual'")

    # Filter by mean and STD.
    if use_std_filter:
        for key in measure_keys:
            lines.append(" AND ABS({0} - {0}_avg) / {0}_stddev < {1}".format(key, std_factor))

    # Every line ends with a newline; the statement is closed with a semicolon.
    sql_statement = "".join(line + "\n" for line in lines) + ";"

    # Retrieve the matching scans from the database.
    rows = np.array(main_connector.execute(sql_statement, fetch_all=True))
    print("Found {} scans for STD factor {}.".format(len(rows), std_factor))
    return rows

In [None]:
# Query the measures once without filtering (None) and once per STD factor,
# keeping each result set together with the factor that produced it.
overall_results = []
for std_factor in (None, 2.0, 1.5, 1.0, 0.75):
    results = select_measures_in_std_range(std_factor=std_factor)
    overall_results += [(std_factor, results)]
    
# For rendering a scatterplot.
def render_for_keys(key1, key2, xlim, ylim):
    """Scatter two measurement keys against each other for every STD factor.

    One series is drawn per entry of `overall_results`, labelled with its STD
    factor.

    Args:
        key1: Measure key plotted on the x-axis; must be in `measure_keys`.
        key2: Measure key plotted on the y-axis; must be in `measure_keys`.
        xlim: Limits passed to `plt.xlim`.
        ylim: Limits passed to `plt.ylim`.

    Raises:
        ValueError: If `key1` or `key2` is not contained in `measure_keys`.
    """
    # Column positions of the keys within the measurement part of a row.
    index1 = measure_keys.index(key1)
    index2 = measure_keys.index(key2)

    # Prepare plot.
    plt.figure(figsize=(10, 10))
    plt.xlabel(key1)
    plt.ylabel(key2)
    plt.xlim(xlim)
    plt.ylim(ylim)

    # Render one series per STD factor.
    for std_factor, all_scans in overall_results:
        label = "STD factor {}".format(std_factor)

        # An empty result set is a 1-D array and cannot be sliced by column.
        # Plot an empty series so the legend still lists the factor.
        if len(all_scans) == 0:
            plt.scatter([], [], s=2, label=label)
            continue

        # The first three columns are id, QR-code and timestamp; the rest are
        # the actual measurements, all numbers.
        measurements = all_scans[:, 3:].astype("float32")
        plt.scatter(measurements[:, index1], measurements[:, index2], s=2, label=label)

    # Done.
    plt.legend()
    plt.show()
    plt.close()

# Render all.
# One entry per scatterplot: (x key, y key, x-axis limits, y-axis limits).
render_parameters = [
    ("height", "weight", (0, 200), (0, 100)),
    ("height", "age", (0, 200), (0, 4000)),
    ("weight", "age", (0, 100), (0, 4000)),
    ("weight", "muac", (0, 100), (0, 100)),
    # NOTE(review): weight is limited to (0, 20) here but to (0, 100) in the
    # other plots - confirm that this is intended.
    ("weight", "head_circumference", (0, 20), (0, 200)),
]
for key1, key2, xlim, ylim in render_parameters:
    render_for_keys(key1, key2, xlim, ylim)

# Interactive cell for accepting/rejecting scans.

In [None]:
# Manual measures whose every key lies within two STDs of the mean.
all_scans = select_measures_in_std_range(std_factor=2.0)

# Split the columns: the first three are measure id, QR-code and timestamp,
# the remaining ones are the actual measurements, converted to float32.
all_scans_head, all_scans_tail = all_scans[:, :3], all_scans[:, 3:].astype("float32")

In [None]:
# Module-level state describing the measure that is currently displayed.
# Both names are overwritten by select_artifacts().
measure_id = measure_targets = None

# Runs the whole thing.
def clear_and_display():
    """Clear the cell output, render a newly selected measure and show the rating buttons."""
    clear_output()

    # select_artifacts() yields (artifacts, qr_code, timestamp); render them ten per row.
    selection = select_artifacts()
    render_artifacts(*selection, 10)

    rating_buttons = [standing_button, lying_button, mixed_button, reject_button, delete_button]
    display(HBox(rating_buttons))
    
# Select a random measure together with its artifacts.
def _count_expert_ratings(candidate_id):
    """Return the number of 'expert_status' rows stored for the given measure."""
    sql_statement = ""
    sql_statement += "SELECT COUNT(*) FROM measure_quality mq "
    sql_statement += " WHERE mq.measure_id='{}'".format(candidate_id)
    sql_statement += " AND mq.key='expert_status'"
    sql_statement += ";"
    return main_connector.execute(sql_statement, fetch_one=True)[0]


def _select_rgb_artifact_paths(candidate_id):
    """Return the paths of all 'rgb' artifacts of the given measure as a numpy array."""
    sql_statement = ""
    sql_statement += "SELECT path FROM artifact AS a "
    sql_statement += " WHERE a.measure_id='{}'".format(candidate_id)
    sql_statement += " AND a.type='rgb'"
    sql_statement += ";"
    return np.array(main_connector.execute(sql_statement, fetch_all=True))


def select_artifacts():
    """Pick a random measure that is not rated yet and has RGB artifacts.

    The candidates in `all_scans` are visited in random order, each at most
    once. Measures that already have an 'expert_status' entry and measures
    without RGB artifacts are skipped. On success the globals `measure_id`
    and `measure_targets` are set to the selected measure.

    Returns:
        Tuple `(artifacts, qr_code, timestamp)` where `artifacts` is a numpy
        array of rows, each holding one artifact path.

    Raises:
        RuntimeError: If no measure in `all_scans` qualifies.
    """
    global measure_id
    global measure_targets

    # Visit every valid index exactly once. This avoids the out-of-range index
    # that an inclusive upper bound of len(all_scans) can produce, and it
    # guarantees termination when nothing is left to rate.
    candidate_indices = list(range(len(all_scans)))
    random.shuffle(candidate_indices)

    for index in candidate_indices:
        candidate_id, qr_code, timestamp = all_scans_head[index]
        print("Counting artifacts for {} {} {}...".format(candidate_id, qr_code, timestamp))

        # Skip measures that already have an expert rating.
        if _count_expert_ratings(candidate_id) != 0:
            print("Already in database. Skipped.")
            continue

        # Select all JPGs for that measure-id.
        artifacts = _select_rgb_artifact_paths(candidate_id)
        print("Found {} JPGs for QR-code {} and timestamp {}.".format(len(artifacts), qr_code, timestamp))
        if len(artifacts) == 0:
            continue

        # Publish the selection only once it is final, so that the button
        # handlers always rate the measure that is actually displayed.
        measure_id = candidate_id
        measure_targets = all_scans_tail[index]
        return artifacts, qr_code, timestamp

    raise RuntimeError("No unrated measure with RGB artifacts is left in the current selection.")


# Render a subsample of the artifacts.
def render_artifacts(artifacts, qr_code, timestamp, num_columns=10, target_size=(1920 // 4, 1080 // 4)):
    """Render the images of a measure as one grid with a descriptive title.

    Every image is resized to `target_size`, rotated and placed in a grid with
    `num_columns` images per row. An incomplete last row is padded with black
    tiles. The title shows the QR-code, the timestamp and the global
    `measure_targets`.

    Args:
        artifacts: Sequence of rows whose first element is an image path.
        qr_code: QR-code shown in the title.
        timestamp: Timestamp shown in the title.
        num_columns: Number of images per grid row.
        target_size: `(width, height)` each image is resized to.
    """
    print("Rendering artifacts...")

    # Nothing to stack: np.vstack would fail on an empty list.
    if len(artifacts) == 0:
        print("No artifacts to render.")
        return

    target_size = tuple(target_size)

    # Load, resize and rotate every image.
    tiles = []
    for artifact in artifacts:
        # NOTE(review): presumably maps the stored path to a local copy of the
        # data - confirm against the storage layout.
        path = artifact[0].replace("whhdata", "localssd")
        # The context manager closes the image file deterministically.
        with Image.open(path) as img:
            # Force three channels so every tile matches the black padding.
            resized = img.convert("RGB").resize(target_size)
            tiles.append(np.rot90(np.array(resized), 3))

    # Pad the last row with black tiles. A rotated tile has the shape
    # (target width, target height, 3).
    black_tile = np.zeros(target_size + (3,), dtype="uint8")
    while len(tiles) % num_columns != 0:
        tiles.append(black_tile)

    # Assemble rows, then the whole grid.
    row_images = [
        np.hstack(tiles[start:start + num_columns])
        for start in range(0, len(tiles), num_columns)
    ]
    result_image = np.vstack(row_images)

    # Create title string. str() keeps non-string values from raising.
    title_string = "QR-code: {} Timestamp: {} Targets: {}".format(
        str(qr_code),
        str(timestamp),
        ", ".join([str(measure_target) for measure_target in measure_targets]),
    )

    # Render with plt. The height follows the aspect ratio of the grid and is
    # never allowed to truncate to zero.
    # TODO render target
    figure_height = max(1, int(20 * result_image.shape[0] / result_image.shape[1]))
    plt.figure(figsize=(20, figure_height))
    plt.imshow(result_image)
    plt.axis("off")
    plt.title(title_string)
    plt.show()
    plt.close()


# Standing button.
def _make_status_button(description, status):
    """Create a rating button together with its click handler.

    Clicking the button stores `status` for the currently displayed measure
    and then renders the next one.
    """
    button = widgets.Button(description=description)

    def on_clicked(_):
        # `measure_id` is read from the module at click time, so the handler
        # always rates the measure that is displayed at that moment.
        insert_status_into_database(measure_id, status=status)
        clear_and_display()

    button.on_click(on_clicked)
    return button, on_clicked


# One button per rating; each stores its status and advances to the next measure.
standing_button, on_standing_button_clicked = _make_status_button("Standing", "standing")
lying_button, on_lying_button_clicked = _make_status_button("Lying", "lying")
mixed_button, on_mixed_button_clicked = _make_status_button("Mixed", "mixed")
reject_button, on_reject_button_clicked = _make_status_button("Reject", "rejected")
delete_button, on_delete_button_clicked = _make_status_button("Delete", "delete")

# Update database.
def insert_status_into_database(measure_id, status):
    """Store an 'expert_status' rating for a measure in `measure_quality`.

    Args:
        measure_id: Id of the rated measure.
        status: Rating text stored in the `text_value` column.

    The current OS user name is stored as `created_by`.
    """
    # NOTE(review): the `type` column is filled with the literal '?', which
    # looks like an unfinished placeholder - confirm the intended value.
    values = ("?", "expert_status", status, getpass.getuser(), measure_id)
    sql_statement = (
        "INSERT INTO measure_quality"
        " (type, key, text_value, created_by, measure_id)"
        " VALUES ('{}', '{}', '{}', '{}', '{}');"
    ).format(*values)
    main_connector.execute(sql_statement)
    

# Initial render: show a first measure together with the rating buttons.
clear_and_display()

# Query database for accepted and rejected measures.

In [None]:
# Fetch every expert rating: the rated measure, the assigned status and the rater.
sql_statement = "SELECT measure_id, text_value, created_by FROM measure_quality WHERE key='expert_status';"
results = main_connector.execute(sql_statement, fetch_all=True)

# The first selected column is measure_quality.measure_id, not a QR-code, so
# it is labelled accordingly. `pd` comes from the notebook's import cell.
df = pd.DataFrame(results, columns=["measure_id", "status", "created_by"])
display(df)

# Get the number of rated pointclouds per status.

In [None]:
# Statuses an expert can assign to a measure.
rated_statuses = ["standing", "lying", "mixed", "rejected", "delete"]

# Counts the 'pcd' artifacts that belong to measures rated with one status.
count_query_template = (
    "SELECT COUNT(*) FROM measure_quality mq"
    " INNER JOIN artifact a ON mq.measure_id = a.measure_id"
    " WHERE mq.key='expert_status' AND mq.text_value='{}'"
    " AND a.type='pcd'"
    ";"
)

chart_data = []
chart_labels = []
for status in rated_statuses:
    sql_statement = count_query_template.format(status)
    result = main_connector.execute(sql_statement, fetch_one=True)[0]
    print("{} PCDs with status '{}'.".format(result, status))
    chart_data.append(result)
    chart_labels.append("{} ({})".format(status, result))

# Pie chart with one slice per status; the title shows the overall count.
plt.figure(figsize=(10, 10))
plt.pie(chart_data, labels=chart_labels)
plt.title("Showing {} rated pointclouds.".format(np.sum(chart_data)))
plt.show()
plt.close()

# TODO find number of unrated PCDs.