In [None]:
%matplotlib inline
import dbutils
import pandas as pd
import matplotlib.pyplot as plt
import os
from datetime import date
from time import mktime
import calendar
import time
import numpy as np
from IPython.display import display
import random
import math
from PIL import Image
from tqdm import tqdm, trange
from IPython.display import display


# Numeric columns of the `measure` table that are averaged, filtered and plotted below.
measure_keys = [
    "height",
    "weight",
    "age",
    "muac",
    "head_circumference",
]

# Create the measure quality table, which stores per-measure quality flags such as expert acceptance.

In [None]:
# TODO move this to schema.

# Creating the table is idempotent. Dropping it would also delete every stored
# expert rating, so the drop only happens when this flag is set on purpose.
RESET_MEASURE_QUALITY = False

main_connector = dbutils.connect_to_main_database()

sql_statement = """
CREATE TABLE IF NOT EXISTS measure_quality (
    type TEXT NOT NULL,
    key TEXT NOT NULL,
    real_value REAL,
    bool_value BOOLEAN,
    text_value TEXT,
    measure_id VARCHAR(255) REFERENCES measure(id),
    PRIMARY KEY(measure_id, type, key)
);
"""
if RESET_MEASURE_QUALITY:
    sql_statement = "DROP TABLE IF EXISTS measure_quality;\n" + sql_statement

results = main_connector.execute(sql_statement)

# Filter measures by their distance from the mean (in standard deviations) and plot them.

In [None]:
main_connector = dbutils.connect_to_main_database()


def _build_std_range_sql(std_factor, keys):
    """Build the SELECT statement used by ``select_measures_in_std_range``.

    Args:
        std_factor: Maximum distance from the mean in standard deviations,
            or ``None`` to disable filtering.
        keys: Names of the numeric ``measure`` columns to return and filter on.

    Returns:
        The SQL statement as a string.
    """
    lines = []

    # Mean and standard deviation of every key over all manual measures.
    if std_factor is not None:
        aggregates = []
        for key in keys:
            aggregates.append("AVG({0}) AS {0}_avg".format(key))
            aggregates.append("STDDEV({0}) AS {0}_stddev".format(key))
        lines.append("WITH AvgStd AS (")
        lines.append("  SELECT")
        lines.append("    " + ",\n    ".join(aggregates))
        lines.append("  FROM measure WHERE type='manual'")
        lines.append(")")

    columns = ", ".join(["measure.id", "qr_code", "measure.timestamp"] + list(keys))
    lines.append("SELECT {} FROM measure".format(columns))
    lines.append(" INNER JOIN person ON measure.person_id = person.id")
    if std_factor is not None:
        # AvgStd is an aggregate without GROUP BY, i.e. exactly one row, so the
        # cross join just attaches the statistics to every measure.
        lines.append(" CROSS JOIN AvgStd")

    # Allow only manual measurements.
    lines.append(" WHERE type='manual'")

    if std_factor is not None:
        # float() guarantees a numeric literal in the SQL text. NULLIF turns a
        # zero standard deviation into NULL, which drops the row instead of
        # aborting the whole query with a division-by-zero error.
        factor = float(std_factor)
        for key in keys:
            lines.append(
                " AND ABS({0} - {0}_avg) / NULLIF({0}_stddev, 0) < {1}".format(key, factor)
            )

    return "\n".join(lines) + "\n;"


def select_measures_in_std_range(std_factor=None):
    """Select manual measures, optionally limited to a band around the mean.

    When ``std_factor`` is given, the mean and standard deviation of every key
    in ``measure_keys`` are computed over all manual measures. A row is kept
    only if ``|value - mean| / stddev < std_factor`` holds for *every* key.
    Rows with a NULL in any key therefore fail the filter as well.

    Args:
        std_factor: Maximum allowed distance from the mean, in standard
            deviations. ``None`` returns all manual measures unfiltered.

    Returns:
        A 2-D numpy array with one row per measure and the columns
        ``measure.id, qr_code, measure.timestamp, *measure_keys``. If nothing
        matches, the array has shape ``(0, 3 + len(measure_keys))`` so that
        callers can still slice columns.
    """
    sql_statement = _build_std_range_sql(std_factor, measure_keys)
    results = np.array(main_connector.execute(sql_statement, fetch_all=True))
    if len(results) == 0:
        # np.array([]) is 1-D; keep the 2-D column layout callers rely on.
        results = np.empty((0, 3 + len(measure_keys)), dtype=object)
    print("Found {} scans for STD factor {}.".format(len(results), std_factor))
    return results

# One (std_factor, scans) pair per band width; None means "no filtering".
overall_results = [
    (std_factor, select_measures_in_std_range(std_factor=std_factor))
    for std_factor in [None, 2.0, 1.5, 1.0, 0.75]
]
    
# Scatter-plot two measure keys against each other.
def render_for_keys(key1, key2, xlim, ylim, results=None):
    """Plot ``key2`` over ``key1``, with one series per STD factor.

    Args:
        key1: Measure key on the x axis. It must be in ``measure_keys``,
            otherwise ``list.index`` raises ValueError.
        key2: Measure key on the y axis (same requirement).
        xlim: ``(min, max)`` limits of the x axis.
        ylim: ``(min, max)`` limits of the y axis.
        results: Optional list of ``(std_factor, scans)`` pairs as returned by
            ``select_measures_in_std_range``. Defaults to ``overall_results``.
    """
    if results is None:
        results = overall_results

    # Column positions inside the numeric part of a scan row.
    index1 = measure_keys.index(key1)
    index2 = measure_keys.index(key2)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set(xlabel=key1, ylabel=key2, xlim=xlim, ylim=ylim,
           title="{} vs. {}".format(key2, key1))

    for std_factor, all_scans in results:
        # An empty selection has no rows to plot (and a 1-D empty array
        # cannot be column-sliced).
        if len(all_scans) == 0:
            continue
        # Columns 0-2 are id, QR code and timestamp; the rest are measurements.
        measurements = all_scans[:, 3:].astype("float32")
        ax.scatter(measurements[:, index1], measurements[:, index2], s=2,
                   label="STD factor {}".format(std_factor))

    ax.legend()
    plt.show()
    plt.close(fig)

# Pairs of keys to compare, each with its (x limits, y limits).
render_parameters = [
    ("height", "weight", (0, 200), (0, 100)),
    ("height", "age", (0, 200), (0, 4000)),
    ("weight", "age", (0, 100), (0, 4000)),
    ("weight", "muac", (0, 100), (0, 100)),
    ("weight", "head_circumference", (0, 20), (0,200)),
]
for key1, key2, xlim, ylim in render_parameters:
    render_for_keys(key1, key2, xlim, ylim)

# Interactive cell for accepting/rejecting scans.

In [None]:
# Manual measures within 2 standard deviations of the mean for every key.
all_scans = select_measures_in_std_range(std_factor=2.0)

# Identifying columns (measure id, QR code, timestamp) and the numeric measurements.
all_scans_head, all_scans_tail = all_scans[:, :3], all_scans[:, 3:].astype("float32")

In [None]:
import ipywidgets as widgets
from ipywidgets import HBox, Output
from IPython.display import display, clear_output

main_connector = dbutils.connect_to_main_database()

# Measure currently shown to the rater; select_artifacts() overwrites it.
measure_id = None

def clear_and_display():
    """Clear the cell output, then show a fresh random scan and the rating buttons."""
    clear_output()
    selection = select_artifacts()  # (artifacts, qr_code, timestamp)
    render_artifacts(*selection, 10)
    display(HBox([accept_button, reject_button]))
    
# Select a random scan that has RGB images.
def select_artifacts():
    """Pick a random scan that has RGB artifacts and return those artifacts.

    Scans are drawn uniformly from ``all_scans_head`` until one with at least
    one RGB artifact is found. As a side effect, the module-level
    ``measure_id`` is set to the chosen measure so that the Accept/Reject
    buttons know which measure they rate.

    Returns:
        Tuple ``(artifacts, qr_code, timestamp)``. ``artifacts`` is the numpy
        array of ``path`` rows returned by the database.

    Raises:
        ValueError: If there are no scans to choose from.
    """
    global measure_id
    if len(all_scans_head) == 0:
        raise ValueError("No scans available for rating.")

    print("Selecting artifacts...")
    artifacts = []
    while len(artifacts) == 0:
        # randrange excludes its upper bound. randint(0, len(...)) includes it
        # and could pick an index one past the end (IndexError).
        index = random.randrange(len(all_scans_head))
        measure_id, qr_code, timestamp = all_scans_head[index]
        artifacts = _select_jpgs(measure_id, qr_code, timestamp)
    print("Done.")
    return artifacts, qr_code, timestamp


def _select_jpgs(measure_id, qr_code, timestamp):
    """Return the paths of all RGB artifacts stored for ``measure_id``.

    ``qr_code`` and ``timestamp`` are only used in log messages.

    Returns:
        numpy array of result rows, each holding one ``path``.
    """
    print("Considering {} {} {}...".format(measure_id, qr_code, timestamp))
    # measure_id is taken from a database row, not from user input.
    sql_statement = (
        "SELECT path FROM artifact AS a "
        " INNER JOIN measure AS m ON m.id=a.measure_id"
        " INNER JOIN person AS p ON p.id=m.person_id"
        " WHERE m.id='{}'"
        " AND a.type='rgb'"
        ";"
    ).format(measure_id)

    results = np.array(main_connector.execute(sql_statement, fetch_all=True))
    print("Found {} JPGs for QR-code {} and timestamp {}.".format(len(results), qr_code, timestamp))
    return results


# Load one artifact image as a grid tile.
def _load_artifact_tile(path, target_size):
    """Load the image at ``path`` as a resized, rotated RGB uint8 array."""
    local_path = path.replace("whhdata", "localssd")
    # Image.open is lazy; the context manager closes the file once pixels are read.
    # Converting to RGB keeps every tile at 3 channels so hstack/vstack line up.
    with Image.open(local_path) as img:
        tile = np.array(img.convert("RGB").resize(target_size))
    return np.rot90(tile, 3)


# Render a random subsample of the artifacts as one image grid.
def render_artifacts(artifacts, qr_code, timestamp, num_columns=10, num_rows=5, target_size=(1920 // 4, 1080 // 4)):
    """Show a ``num_rows`` x ``num_columns`` grid of randomly chosen artifact images.

    Each image is resized to ``target_size`` and rotated with
    ``np.rot90(..., 3)``. If there are fewer artifacts than grid cells, the
    remaining cells are left black.

    Args:
        artifacts: numpy array whose rows hold the artifact path in column 0.
        qr_code: QR code of the scan, shown in the figure title.
        timestamp: Timestamp of the scan, shown in the figure title.
        num_columns: Tiles per row.
        num_rows: Number of tile rows.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``.
    """
    print("Rendering artifacts...")
    num_tiles = num_columns * num_rows

    # random.sample raises ValueError when asked for more items than exist,
    # so sample only what is available.
    indices = random.sample(range(len(artifacts)), min(num_tiles, len(artifacts)))
    tiles = [_load_artifact_tile(row[0], target_size) for row in artifacts[indices]]

    # After rot90 a (height, width) image becomes (width, height).
    blank_tile = np.zeros((target_size[0], target_size[1], 3), dtype=np.uint8)
    tiles += [blank_tile] * (num_tiles - len(tiles))

    rows = [np.hstack(tiles[r * num_columns:(r + 1) * num_columns]) for r in range(num_rows)]
    result_image = np.vstack(rows)

    fig, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(result_image)
    ax.set_title("{} {}".format(qr_code, timestamp))
    ax.axis("off")
    plt.show()
    plt.close(fig)


def on_accept_button_clicked(_):
    """Store an "accepted" rating for the current measure and show the next scan."""
    insert_acceptance_into_database(measure_id, accepted=True)
    clear_and_display()

# Accept button: marks the scan currently on screen as accepted.
accept_button = widgets.Button(description="Accept")
accept_button.on_click(on_accept_button_clicked)

def on_reject_button_clicked(_):
    """Store a "rejected" rating for the current measure and show the next scan."""
    insert_acceptance_into_database(measure_id, accepted=False)
    clear_and_display()

# Reject button: marks the scan currently on screen as rejected.
reject_button = widgets.Button(description="Reject")
reject_button.on_click(on_reject_button_clicked)

# Update database.
def insert_acceptance_into_database(measure_id, accepted):
    """Record an expert's accept/reject decision for one measure.

    Writes a ``measure_quality`` row with key ``accepted_by_expert``. The
    table's primary key is ``(measure_id, type, key)``, and the random
    selection can show the same scan more than once. A second rating
    therefore overwrites the earlier decision instead of failing with a
    unique-constraint violation.

    Args:
        measure_id: id of the rated row in ``measure``.
        accepted: ``True`` to accept, ``False`` to reject.
    """
    # NOTE(review): type is stored as the literal '?'; confirm the intended value.
    sql_statement = (
        "INSERT INTO measure_quality"
        " (type, key, bool_value, measure_id)"
        " VALUES ('{}', '{}', '{}', '{}')"
        " ON CONFLICT (measure_id, type, key)"
        " DO UPDATE SET bool_value = EXCLUDED.bool_value;"
    ).format("?", "accepted_by_expert", accepted, measure_id)
    main_connector.execute(sql_statement)


# Initial render.
clear_and_display()

# Query database for accepted and rejected measures.

In [None]:
main_connector = dbutils.connect_to_main_database()

# Every expert decision recorded so far, one (measure_id, bool_value) row each.
sql_statement = (
    "SELECT measure_id, bool_value FROM measure_quality"
    " WHERE key='accepted_by_expert';"
)

results = main_connector.execute(sql_statement, fetch_all=True)
for rating in results:
    print(rating)

# Count the accepted and rejected point clouds (PCD artifacts).

In [None]:
main_connector = dbutils.connect_to_main_database()


def _count_rated_pcds(accepted):
    """Count PCD artifacts belonging to measures with the given expert decision.

    The count is over joined artifact rows, i.e. every PCD artifact of a rated
    measure is counted, not every measure.

    Args:
        accepted: ``True`` for accepted measures, ``False`` for rejected ones.

    Returns:
        The ``COUNT(*)`` value returned by the database.
    """
    sql_statement = (
        "SELECT COUNT(*) FROM measure_quality mq"
        " INNER JOIN artifact a ON mq.measure_id = a.measure_id"
        " WHERE mq.key='accepted_by_expert' AND mq.bool_value='{}'"
        " AND a.type='pcd'"
        ";"
    ).format(accepted)
    return main_connector.execute(sql_statement, fetch_one=True)[0]


print("{} accepted PCDs.".format(_count_rated_pcds(True)))
print("{} rejected PCDs.".format(_count_rated_pcds(False)))

# TODO: count the PCDs that have not been rated yet.