In [2]:
import dash
from dash import dcc, html
from dash.dependencies import Input, Output, State, ALL
import numpy as np
import plotly.graph_objects as go
from scipy.spatial import ConvexHull

import pandas as pd
import matplotlib.pyplot as plt
import json

from helpers import fill_template, read_config, read_json
from inference import KnnClassifier

# Placeholder user point and predicted id; the callback still passes user_x
# to the figure until real embeddings for the user's input are wired in.
user_x = [0.5, 0, 0]
user_y = 2

# Per-role 3-D coordinates from roles_coordinates_tsne.csv (first column is the index).
df = pd.read_csv("./data/roles_coordinates_tsne.csv", index_col=0)
X = df.loc[:, ["x", "y", "z"]].to_numpy()
labels = df["role_id"].to_numpy()
representatives = np.unique(labels)

# Classifier plus the label-name -> id lookup used to map predictions to ids.
config = read_config()
model = KnnClassifier(config)
rol_to_id_dict = read_json("./data/role_to_id.json")



  from tqdm.autonotebook import tqdm, trange
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [3]:
# Load the dropdown suggestions: each key becomes a form field and its value
# is the list of options offered for that field.
# JSON text is UTF-8 (RFC 8259), so pin the encoding rather than relying on
# the platform's locale default.
with open('./data/suggestions.json', 'r', encoding='utf-8') as f:
    form_dict = json.load(f)

## create a labelled multi-select dropdown Div from one suggestions.json entry
def group_from_json(key, values):
    """Return a 'form-group' Div holding a label and a multi-select dropdown.

    Args:
        key: field name; shown as the label text and stored in the dropdown id.
        values: iterable of options; each is used as both label and value.

    The dropdown id is the pattern-matching dict
    ``{'type': 'input-field', 'key': key}``.
    """
    dropdown_options = [dict(label=choice, value=choice) for choice in values]
    dropdown = dcc.Dropdown(
        id={'type': 'input-field', 'key': key},
        options=dropdown_options,
        multi=True,
    )
    return html.Div([html.Label(f"{key}:"), dropdown], className='form-group')

# One form-group Div per suggestions entry (dict order), wrapped in a container.
divs = list(map(group_from_json, form_dict.keys(), form_dict.values()))
form = html.Div(divs, className='form-container')

In [6]:

# Target Color: used for the user's marker and for the predicted cluster.
c_target = "red"

# Create the Dash app
app = dash.Dash(__name__)

# Create a 3D scatter plot (optionally with convex-hull meshes) given
# covariates X (n x 3) and labels y (one per row of X)
def create_figure(X, y, repr, p_x, p_y, show_hulls=False):
    """Build the "Job Space" 3-D scatter figure.

    Every cluster in ``y`` is drawn as a scatter of its rows of ``X`` in a
    grey shade. The cluster whose label equals ``p_y`` is drawn in
    ``c_target`` instead.

    Parameters
    ----------
    X : numpy.ndarray
        Point coordinates; columns 0, 1 and 2 are used as x, y and z.
    y : numpy.ndarray
        One label per row of ``X`` (rows are selected with ``y == label``).
    repr : sequence
        Indexed as ``repr[p_y]`` to name the user marker's legend group.
        NOTE(review): this notebook passes ``np.unique(labels)``, so this
        presumably assumes ``p_y`` is a position in that array. Confirm.
    p_x : sequence of 3 numbers or None
        Coordinates of the user's point. ``None`` skips the marker.
    p_y : label or None
        Predicted label to highlight. ``None`` highlights nothing.
    show_hulls : bool, default False
        If True, also draw a translucent convex-hull mesh for each cluster
        with at least 6 points. ``ConvexHull`` may raise for degenerate
        (e.g. coplanar) clusters.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # Array containing all the traces (user marker, clusters, optional hulls)
    layers = []

    unique_labels = np.unique(y)

    # One grey shade per label, keyed by label *value*. Indexing a list by
    # label only worked when labels were exactly 0..k-1.
    color_palette = plt.get_cmap("Greys")(np.linspace(0.2, 0.8, len(unique_labels)))
    colors = {
        label: f'rgba({int(c[0] * 255)}, {int(c[1] * 255)}, {int(c[2] * 255)}, {c[3]})'
        for label, c in zip(unique_labels, color_palette)
    }

    # Do not plot the marker on initialization. Compare with None explicitly:
    # `if p_x:` raises ValueError when p_x is a NumPy array.
    if p_x is not None:
        # rol_to_id_dict.get() may yield None for an unknown label; indexing
        # repr with None would produce a meaningless name.
        group = f'Label {repr[p_y]}' if p_y is not None else 'Label unknown'
        layers.append(go.Scatter3d(
            x=[p_x[0]], y=[p_x[1]], z=[p_x[2]],
            mode='markers',
            marker=dict(size=10, color=c_target, symbol='cross'),
            name=group,
            legendgroup=group,
            showlegend=False,
        ))

    # Create a scatter (and optionally a hull mesh) for each label in y
    for label in unique_labels:
        # greyscale or highlight color
        c_val = c_target if label == p_y else colors[label]

        # select rows with corresponding label
        X_k = X[y == label]
        x_c, y_c, z_c = X_k[:, 0], X_k[:, 1], X_k[:, 2]

        # Every cluster is plotted. Previously the (unused) hull computation's
        # size threshold also hid the points of small clusters.
        layers.append(go.Scatter3d(
            x=x_c, y=y_c, z=z_c,
            mode='markers',
            marker=dict(size=4, color=c_val, opacity=1),
            name=f'Label {label}',
            legendgroup=f'Label {label}',
            showlegend=False,
        ))

        # Mesh3d for the convex hull, only when requested and enough points exist
        if show_hulls and X_k.shape[0] >= 6:
            hull = ConvexHull(X_k)
            layers.append(go.Mesh3d(
                x=x_c,
                y=y_c,
                z=z_c,
                i=hull.simplices[:, 0],
                j=hull.simplices[:, 1],
                k=hull.simplices[:, 2],
                opacity=0.3,
                color=colors[label],
                hoverinfo='text',
                hovertext=f'{label}',
                name=f'Hull {label}',
                legendgroup=f'Hull {label}',
            ))

    # Create figure
    fig = go.Figure(data=layers)
    fig.update_layout(plot_bgcolor='lightblue',
                      paper_bgcolor='lightgrey',
                      scene=dict(aspectmode='data',
                                 xaxis=dict(visible=False),
                                 yaxis=dict(visible=False),
                                 zaxis=dict(visible=False)),
                      title="Job Space",
                      )

    return fig

# Layout of the app: suggestion form, submit button, and the 3-D graph.
# The initial figure has no user marker and no highlighted cluster.
app.layout = html.Div(
    children=[
        form,
        html.Button('Submit', id='submit-button', n_clicks=0),
        dcc.Graph(
            id='3d-scatter',
            figure=create_figure(X, labels, representatives, p_x=None, p_y=None),
        ),
    ]
)

@app.callback(
    Output('3d-scatter', 'figure'),
    Input('submit-button', 'n_clicks'),
    # capture the input key-value pairs
    State({'type': 'input-field', 'key': ALL}, 'value'),
    State({'type': 'input-field', 'key': ALL}, 'id'),
    State('3d-scatter', 'figure'),
    prevent_initial_call=True,
)
def update_output(n_clicks, input_values, input_ids, prev_fig):
    """Predict a role from the submitted form and redraw the job-space figure.

    Parameters
    ----------
    n_clicks : int or None
        Submit-button click count.
    input_values : list
        Value of every ``input-field`` dropdown, in the same order as
        ``input_ids``. The dropdowns are multi-select; presumably a field with
        nothing selected is ``None`` (never touched) or ``[]`` (cleared).
    input_ids : list of dict
        Pattern-matching ids; ``id['key']`` is the form field name.
    prev_fig : dict
        The figure currently shown.

    Returns
    -------
    A new figure with the predicted cluster highlighted. Returns ``prev_fig``
    unchanged when the button has not been clicked or when any field has no
    selection. Before this change the callback could return None, which
    blanks the graph.
    """
    # Treat None and [] alike as "field not filled in". Previously a cleared
    # dropdown ([]) slipped past a `None in ...` check.
    if not n_clicks or any(not value for value in input_values):
        return prev_fig

    # Create a dictionary to hold form data based on the pattern-matched input fields
    form_data = {input_id['key']: value for input_id, value in zip(input_ids, input_values)}

    # NOTE(review): fill_template receives a JSON *string* despite the
    # parameter name `input_dict`. Confirm this is what helpers expects.
    json_output = json.dumps(form_data, indent=4)
    template = fill_template(input_dict=json_output)
    print("Template is: ", template)

    label = model.predict(template)
    # None when the predicted label has no entry in role_to_id.json
    predicted_id = rol_to_id_dict.get(label, None)
    print(f"label is: {label} and corresponding encoding is: {predicted_id}")

    # TODO: user_x is still the module-level placeholder point. Generate the
    # user's embedding (and reduced X) from the template instead.
    return create_figure(X, labels, representatives, user_x, predicted_id)



# Run the app. `run_server` is deprecated in Dash 2 and removed in Dash 3;
# `run` accepts the same debug/port arguments.
if __name__ == '__main__':
    app.run(debug=True, port=8080)


Template is:  We are seeking a professional with Bachelor's degree, possessing Analytical skills and HTML. This role requires the ability to Work independently, with excellent Effective strategic problem solving. Familiarity with Legal and Government Relations, and the capability to Optimize program delivery is preferred.


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

label is: Legal & Government Relations and corresponding encoding is: 9
Template is:  We are seeking a professional with Bachelor's degree, possessing Analytical skills and HTML. This role requires the ability to Work independently, with excellent Effective strategic problem solving. Familiarity with Legal and Government Relations, and the capability to Optimize program delivery is preferred.


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

label is: Legal & Government Relations and corresponding encoding is: 9
