In [1]:
import dash
from dash import dcc, html
from dash.dependencies import Input, Output, State, ALL
import numpy as np
import plotly.graph_objects as go

import matplotlib.pyplot as plt
import json

from helpers import fill_template, read_config, read_json
from inference import KnnClassifier

# user_x, user_y = [0.5, 0, 0], 2


# Load the project configuration and build the KNN classifier ("mde" visualization).
config = read_config()
model = KnnClassifier(visualization_method="mde", config=config)

# Original author's note: compute the roles embeddings if not already in embeddings.txt file.
# NOTE(review): this is a bare attribute access, kept exactly as written. Unless
# `get_embeddings` is a property it computes nothing — confirm whether `()` was intended.
model.embedder.get_embeddings

# Initial coordinates: the x/y/z columns become the point matrix, role_id the labels.
df = model.embedder.get_initial_corrdinates()
X = df[["x", "y", "z"]].values
labels = df["role_id"].values

# Mapping read from role_to_id.json, plus its inverse (value -> key) for display names.
rol_to_id_dict = read_json("./data/role_to_id.json")
id_to_rol_dict = {role_id: role_name for role_name, role_id in rol_to_id_dict.items()}

  from tqdm.autonotebook import tqdm, trange
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


Batches:   0%|          | 0/39 [00:00<?, ?it/s]

Dimensionality Reduction with mde took: 19.183830976486206


In [5]:
# Load the form definition: a mapping of field name -> list of suggested options.
# The encoding is pinned to UTF-8 so the JSON is decoded the same way on every
# platform instead of depending on the locale's default encoding.
with open('./data/suggestions.json', 'r', encoding='utf-8') as f:
    form_dict = json.load(f)

## Create a Div for each dropdown from a key and its list of values from the suggestions JSON
def group_from_json(key, values):
    """Build a labelled multi-select dropdown for one form field.

    Args:
        key: Field name. Shown as the label text (followed by a colon) and stored
            in the dropdown's pattern-matching id
            ``{'type': 'input-field', 'key': key}``.
        values: Sequence of selectable options; each one is used as both the
            option label and the option value.

    Returns:
        An ``html.Div`` with class ``form-group`` containing the label and the
        dropdown.
    """
    # With multi=True the dropdown's value is a list of selections, so the default
    # is a one-element list holding the first option. Slicing (instead of
    # values[0]) also yields an empty selection for an empty option list rather
    # than raising IndexError.
    default_selection = list(values[:1])

    return html.Div([
        html.Label(f"{key}:"),
        dcc.Dropdown(
            id={'type': 'input-field', 'key': key},
            options=[{'label': option, 'value': option} for option in values],
            multi=True,
            value=default_selection,
        ),
    ], className='form-group')

# One dropdown group per (field, options) entry, in the mapping's iteration order.
divs = list(map(group_from_json, form_dict.keys(), form_dict.values()))
# Wrap every group in a single container element.
form = html.Div(children=divs, className='form-container')

In [58]:

# Create the Dash app
app = dash.Dash(__name__)

# --- Marker styling read by create_figure ---
c_user = "#b9a909"   # colour of the user's own point
c_pred = "#09b9a9"   # colour of the predicted cluster
opac_pred = 1        # opacity of the predicted cluster
opac_def = 0.75      # opacity of every other cluster


# Create a scatter plot given covariates X (n x 3) and labels y (n x 1)
def create_figure(X, y, id_to_rol_dict, p_x, p_y):
    """Build the 3-D "Job Space" scatter figure with one trace per label.

    Args:
        X: Point coordinates, shape (n, 3); columns 0, 1, 2 are plotted as x, y, z.
        y: Label ids, one per row of X; rows are selected with ``y == label``.
        id_to_rol_dict: Mapping from label id to role name. Its keys decide which
            traces are drawn; its values become the trace names.
        p_x: Coordinates of the user point (indexable at 0, 1, 2), or None to
            omit the user marker (as done for the initial figure).
        p_y: Label id predicted for the user point. The matching cluster is drawn
            in the prediction colour. May be None when no label is known.

    Returns:
        A ``plotly.graph_objects.Figure`` holding the optional user marker
        followed by one ``Scatter3d`` trace per label.
    """
    # Coerce once so boolean masking and column slicing work for list inputs too.
    X = np.asarray(X)
    y = np.asarray(y)

    # Array containing all the clusters and the user point
    layers = []

    # Traces follow the mapping's own iteration order.
    unique_labels = list(id_to_rol_dict.keys())
    k = len(unique_labels)

    # Using colormap from Matplotlib: k grey shades between 0.2 and 0.8.
    color_palette = plt.get_cmap("Greys")(np.linspace(0.2, 0.8, k))
    colors = [
        f'rgba({int(c[0] * 255)}, {int(c[1] * 255)}, {int(c[2] * 255)}, {c[3]})'
        for c in color_palette
    ]
    # Look shades up by label rather than by list position, so label ids that are
    # not exactly 0..k-1 no longer raise IndexError. Sorting keeps the shade
    # assignment identical to positional indexing when the ids are 0..k-1.
    color_by_label = dict(zip(sorted(unique_labels), colors))

    # Do not plot the user point on initialization
    if p_x is not None:
        user_point = go.Scatter3d(
            x=[p_x[0]], y=[p_x[1]], z=[p_x[2]],
            mode='markers',
            marker=dict(
                size=10,
                color=c_user,
                symbol='cross',
            ),
            # The caller may pass p_y=None when the predicted label has no id;
            # fall back to a placeholder name instead of raising KeyError.
            name=id_to_rol_dict.get(p_y, "Unknown"),
            text="Predicted Job:",
            hovertemplate='<b>%{fullData.text}</b><br>%{fullData.name}<extra></extra>',
            showlegend=False,
        )
        layers.append(user_point)

    # Create color-grouped scatter plots for each label in y
    for label in unique_labels:
        # select rows with corresponding label
        X_k = X[y == label]
        x_c, y_c, z_c = X_k[:, 0], X_k[:, 1], X_k[:, 2]

        # default or prediction color
        if label == p_y:
            c_val = c_pred
            opac = opac_pred
        else:
            c_val = color_by_label[label]
            opac = opac_def

        # Scatter plot for points
        scatter = go.Scatter3d(
            x=x_c, y=y_c, z=z_c,
            mode='markers',
            marker=dict(size=4, color=c_val, opacity=opac),
            name=id_to_rol_dict[label],
            text="test",
            hovertemplate='<b>%{fullData.name}</b><br>%{fullData.text}<extra></extra>',
            legendgroup=id_to_rol_dict[label],
            showlegend=False,
        )
        layers.append(scatter)

    # Create figure; axes are hidden and the aspect ratio follows the data ranges.
    fig = go.Figure(data=layers)
    fig.update_layout(
        paper_bgcolor='white',
        scene=dict(
            aspectmode='data',
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
        ),
        title="Job Space",
    )

    return fig

# Layout of the app
app.layout = html.Div(
    children=[
        # Input form built from the suggestions file.
        form,
        html.Button(children='Submit', id='submit-button', n_clicks=0),
        dcc.Graph(
            id='3d-scatter',
            # Initial view: no user point and no predicted label yet.
            figure=create_figure(X, labels, id_to_rol_dict, p_x=None, p_y=None),
        ),
    ],
)

@app.callback(
    Output('3d-scatter', 'figure'),
    Input('submit-button', 'n_clicks'),
    # capture the input key-value pairs
    State({'type': 'input-field', 'key': ALL}, 'value'),
    State({'type': 'input-field', 'key': ALL}, 'id'),
    State('3d-scatter', 'figure'),
    prevent_initial_call=True,
)
def update_output(n_clicks, input_values, input_ids, prev_fig):
    """Redraw the job-space figure for the submitted form.

    Args:
        n_clicks: Click count of the submit button.
        input_values: Current values of all pattern-matched ``input-field``
            dropdowns.
        input_ids: Ids of those dropdowns, in the same order as ``input_values``;
            each id's ``'key'`` entry is the form field name.
        prev_fig: The figure currently shown in the graph.

    Returns:
        A new figure containing the user point and highlighted prediction,
        ``prev_fig`` unchanged when any field is empty, or ``dash.no_update``
        when the button has not been clicked.
    """
    # Without a click there is nothing to do. Returning no_update keeps the
    # current figure, whereas an implicit None would replace it with nothing.
    if not n_clicks:
        return dash.no_update

    # A cleared multi-select reports an empty list rather than None, so None, []
    # and "" are all treated as "field not filled in".
    if any(value in (None, [], "") for value in input_values):
        return prev_fig

    # Create a dictionary to hold form data based on the pattern-matched input fields
    form_data = {input_id['key']: value for input_id, value in zip(input_ids, input_values)}

    # Convert to JSON format
    json_output = json.dumps(form_data, indent=4)

    # generate template using json_output
    template = fill_template(input_dict=json_output)

    # use template to generate embeddings (user_x) and predict label (user_y)
    # also generate X with reduced dimensions given the user input
    template_embedding = model.embedder.encode([template])
    X, user_x = model.embedder.get_corrdinates(template_embedding)
    label = model.predict(template_embedding)
    # user_y is None when the predicted label has no entry in the mapping.
    user_y = model.embedder.rol_to_id_dict.get(label, None)
    print(f"label is: {label} and corresponding encoding is: {user_y}")

    return create_figure(X, labels, id_to_rol_dict, user_x, user_y)



# Run the app
if __name__ == '__main__':
    # Prefer `app.run`, the current name of the dev-server entry point, and fall
    # back to the legacy `run_server` on Dash versions that do not provide `run`.
    start_server = getattr(app, "run", None) or app.run_server
    start_server(debug=True, port=8080)


In [47]:
# Scratch check: opacity assigned to five identical points at (1, 1, 1),
# measured against a reference point at (2, 2, 2).
# NOTE(review): `map_opacity` and `distance` are not defined in the visible cells;
# this cell appears to rely on leftover kernel state — confirm or define them.
p_x = [2, 2, 2]
X_k = np.ones((5, 3))

[map_opacity(distance(row, p_x)) for row in X_k]

[0.6339745962155614,
 0.6339745962155614,
 0.6339745962155614,
 0.6339745962155614,
 0.6339745962155614]