In [1]:
# the library includes functionalities that greatly simplify 
# the use of multiselect in this project

#!pip install dash_mantine_components

In [2]:
import pandas as pd
import numpy as np
import json
from dash_mantine_components import MantineProvider
import dash

dash._dash_renderer._set_react_version('18.2.0')

jaem_data = pd.read_csv('model_outputs.csv', sep='~', index_col=0).rename(columns={"gender":"sexuality"})
jaem_data['hl_group'] = 0
jaem_data.head(2).T

Unnamed: 0,0,1
sample_index,0,1
template_index,0,0
sexuality,Lesbian,Lesbian
pronoun,she,she
age,30,40
race,Middle Eastern,Native American
socio-economic status,lawyer,doctor
education,high-school,masters
religion,Hinduist,Christian
ae_gens,That's interesting that you complimented the p...,"You are a proud, optimistic, and generally pos..."


In [3]:
sgs = jaem_data[["sexuality", "pronoun", "race", "socio-economic status", "education", "religion"]].to_dict()
sgs = list(map(lambda x: (x[0], list(set(x[1].values()))), list(sgs.items())))
sgs_names = [s[0] for s in sgs]
prot_attr_to_i = {a: i for i, x in enumerate(sgs) for a in x[1]}
sgs_plotgrp = sgs + [('metrics', ['mean_v', 'mean_a', 'mean_d', 'diff_max_v', 
                                  'diff_max_a', 'diff_max_d', 'diff_max_intensity'])]
sgs_pg_names = sgs_names + ['metrics']
sgs

[('sexuality',
  ['Lesbian',
   'NonBinary',
   'Bisexual',
   'Gay',
   'Transgender',
   'Pansexual',
   'Queer',
   'Intersex',
   'Asexual']),
 ('pronoun', ['he', 'they', 'she']),
 ('race',
  ['Asian',
   'Native American',
   'White',
   'Latin',
   'African American',
   'Black',
   'Middle Eastern']),
 ('socio-economic status',
  ['benefits', 'unemployed', 'lawyer', 'doctor', 'nurse', 'janitor']),
 ('education',
  ['bachelors', 'business', 'no', 'masters', 'doctorate', 'high-school']),
 ('religion',
  ['Christian', 'Hinduist', 'Jew', 'Atheist', 'Buddhist', 'Muslim'])]

In [4]:
backgroundColor = '#e9d8a6'
fontColor = '#2b2d42'
heatmapCP = [[0.0, '#001219'], [0.1, '#0a9396'], 
            # [0.3, '#e9d9a6'],
             [0.5, '#e9d8a6'], 
            # [0.7, '#e9d8a6'],
             [0.9, '#0a9396'], 
             [1.0, '#001219']]
violinCP = ['#ee9b00', '#ca6702', '#bb3e03', '#ae2012', '#9b2226']
explanation_default = """The intended workflow is iterative, but you should start above the heatmap and select the groups you want to show. Then find values that are further from 0 and click on the group and add it to the groups (G) shown on the right. Then move onto the violins and slowly narrow down your search for the biased input. If the highlighted (H) group is small enough, you will see point which you can investigate by hand, by clicking on them. The metrics in the heatmap were calculated based on the LLM outputs, to a sample containing given social groups. You can select subgroups to display, the metrics are unreliable, thus you should verify your findings by throughly going through the violin plots. If the selected subgroup is small enough, you will see points, which you can click and the LLM output will be shown. 

The metrics displayed are based on linguistic theory about Valence (V), Arousal (A) and Dominance (D). Their computation (roughly) is done by using lexicons to determine how the model responded to the previous output. Thus if valence is high, that means that the polarity of the responses (should) differ. For more see about VAD see https://saifmohammad.com/WebPages/nrc-vad.html, for more about the data see https://arxiv.org/abs/2411.05777"""

default_score = "mean_v"

In [5]:
from dash import Dash, html, dcc, Input, Output, dash_table, MATCH, Patch, State
import dash_mantine_components as dmc


def multiSelectGroup(data, cId, default_vals=None):
    if default_vals is None:
        default_vals = [[]]*len(data)
    single_group = html.Div(
        style={
           'background-color': backgroundColor,
            #'height': '100%',
            'textColor':'white',
            #'color': backgroundColor,
            #'margin': 0,
            'padding': '-15px',
            'display': 'inline-flex',
            'height': '10%', 
            'width': '100%', #f'{soc_groups*width + 50}px',
            'maxHeight': '37px',
            'overflow-y':'auto',
        }, 
        children = [
            dmc.MultiSelect(
                placeholder=soc_grp[0],
                data=soc_grp[1],
                searchable=True,
                clearable=True,
                maxValues='Infinity' if soc_grp[0] != 'metrics' else 1,
                comboboxProps={"position": "bottom", 
                               "width": f'{50/len(data)}%'},#, "middlewares": {"flip": False, "shift": False}},
                id={'type':cId['type'],
                    'index':cId['index']*1000+i},
                value=default_vals[i],
                w=f'{500/len(data)}',
                #style={'background-color': backgroundColor},
            ) for i, soc_grp in enumerate(data)]
    )
    return single_group

In [6]:
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go


def derive_masks(constraint_groups, data, cols):
    constrs = data.index == data.index
    for i, g in enumerate(constraint_groups):
        if g:
            constrs &= data[cols[i]].isin(g)
    return constrs


def create_violin(data, side='negative', name=None, lineColor='grey', 
                  group='H0', points=None, hoverinfo='none', scalegroup='a'):
    
    return go.Violin(x=data, legendgroup=group, scalegroup=scalegroup,
                     name=name, side=side, line_color=lineColor,
                     width=15, points=points, pointpos=-0.2, jitter=0.2, 
                     hovertext=hoverinfo)


def create_violin_plots(data, select_groups, hl_groups, scores, cols=sgs_names):
    col_n = 3
    fig = make_subplots(
        rows=int(np.ceil(len(select_groups) / col_n)),
        cols=col_n,
        shared_xaxes=False,
    )
    
    h_col = violinCP
    hover_data = ['mt_correct', 'template_index']
    
    for i, sg in enumerate(select_groups):
        sg_constrs = derive_masks(sg, data, cols)
        sg_filt_data = data[sg_constrs]
        row, col =  (i // col_n) + 1, (i % col_n) + 1
        
        for x, hg in enumerate(hl_groups):
            hg_mask = derive_masks(hg, sg_filt_data, cols)
            distr_data = sg_filt_data[hg_mask]
            hoverinfo = 'Index: ' + distr_data.index.astype(str)
            if hoverinfo.shape[0] > 0:
                hoverinfo = hoverinfo + ',' + distr_data[sgs_names].agg(', '.join, axis=1).str.strip()
            points = 'suspectedoutliers' if distr_data.shape[0] > 20 else 'all'
            fig.add_trace(create_violin(distr_data[scores[i]], 'negative', 
                                        lineColor=h_col[x%len(violinCP)], group=f'G{i}/H{x}', 
                                        name=f'G{i}', points=points, hoverinfo=hoverinfo,
                                        scalegroup=scores[i]), 
                          row = row, col = col) 
            
        hoverinfo = 'Index: ' + sg_filt_data.index.astype(str)
        if hoverinfo.shape[0] > 0:
            hoverinfo = hoverinfo + ',' + sg_filt_data[sgs_names].agg(', '.join, axis=1).str.strip()  
        fig.add_trace(create_violin(sg_filt_data[scores[i]], 'positive', lineColor='grey', name=f'G{i}',
                                    group=f'G{i}', hoverinfo=hoverinfo, scalegroup=scores[i]),
                      row=row, col=col)
        fig.update_xaxes(title_text=f'{scores[i]}', row=row, col=col)
        
        #fig.update_traces(legendgrouptitle={"text":f'G{i}/{scores[i]}'})

    fig.update_traces(meanline_visible=True)
    fig.update_layout(plot_bgcolor=backgroundColor, paper_bgcolor=backgroundColor, 
                      violingap=0, violinmode='overlay', height=200*(row+1),
                      title="Score Distributions by Groups")

    return fig


def create_heatmap(data, groups, scores, cols):
    constrs = derive_masks(groups, data, cols)
    constr_data = data[constrs]
    labels = ''
    for i, g in enumerate(groups):
        if g:
            labels += (';' if len(labels) > 0 else '') + constr_data[cols[i]]
    
    plot_data = pd.concat([labels, constr_data[scores]], axis=1)
    std = constr_data[scores].std()
    mean = constr_data[scores].mean()
    to_plot = (plot_data.groupby(plot_data.columns[0]).mean() - mean)/std
    fig =  go.Figure(go.Heatmap(x=scores, y=to_plot.index.values,
                         z=to_plot.values, colorscale=heatmapCP, zmax=0.8, zmin=-0.8))
    fig.update_yaxes(showticklabels=False)
    fig.update_layout(plot_bgcolor=backgroundColor, paper_bgcolor=backgroundColor, title="Combination of the Selected Social Groups")
    return fig
    

In [7]:
from dash import Dash, dcc, html, Input, Output, State, ctx, MATCH, ALL, Patch, callback

plot_metadata = {"plotGroups": [], "highlightGroups": [], "plotMetrics": []}

app = Dash()

app.layout = (html.Div(
    style={'backgroundColor':backgroundColor,
        'height': '100%',
        'text': fontColor,
        'margin': 0,
        'padding': '15px',
        'padding-bottom': '15%' 
    }, children=[
    html.H1(children='Finding empathic and biased samples in LLM output',
            style={'color': fontColor,  'font_family': 'cursive'}),

    dcc.Markdown("""The following visualization is intended to find biased LLM outputs and creating 
                hypotheses whether the LLM is biased toward a certain combination of groups."""),
    
    dmc.MantineProvider([html.Div([
        html.Div([
            # add the space for the violin plots
            html.Div([
                dcc.Graph(id='violin-groups'),
            ], style={'width':'90%', 'display':'inline-block'}),
            # add Highlight groups space
            html.Div([
                html.Pre(id='utterance-dump'),
                html.Div(id='dynamic-dropdown-container-div', children=[]),
                html.Button("+", id="dynamic-add-filter-btn", n_clicks=0),
            ]),
        ]),
    ], style={'width': '60%', 'display':'inline-block', 'vertical-align': 'top'}),
    html.Div([
        html.Div([
            dcc.Markdown("Combination of Social Groups Selection"),
            # add heatmap with scores
            html.Div([multiSelectGroup(sgs, {'type': 'heatmap_msg', 'index': 0}, 
                                       [sgs[0][1]] + [[]]*(len(sgs)-1))]),
            html.Div([dcc.Graph(id='heatmap')]),
            
            dcc.Markdown("Plotted Social Groups Selection"),
            # add group select dropdowns
            html.Div(id='dynamic-group-dropdown-container-div', children=[]),
            html.Button("+", id="dynamic-add-violin-group-btn", n_clicks=0),
        ])
    ], style={'width': '35%', 'display': 'inline-block', 'horizontal-align': 'right'})])
])
)


def fetch_dropdown(n_clicks, group, default_vals=[[]]*len(sgs)):
    patched_children = Patch()

    new_element = html.Div([
        html.Div([
            html.Div(f"{group.upper()}{n_clicks}"),
        ], style={'width': '5%', 'display':'inline-block'}),
        html.Div([
            multiSelectGroup(
                data = sgs if group != 'g' else sgs_plotgrp,
                cId={'type': f'{group}Group-dropdown', 'index': n_clicks},
                default_vals=default_vals,
            ),
        ], style={'width': '90%', 'maxHeight': '10%', 'display':'inline-block'}),
    ])
    patched_children.append(new_element)
    return patched_children


@app.callback(
    Output('heatmap', 'figure'),
    Input({'type': 'heatmap_msg', 'index': ALL}, 'value'),
)
def update_heatmap(groups, scores=sgs_plotgrp[-1][1]):
    return create_heatmap(jaem_data, groups, scores, sgs_names)


@app.callback(
    Output('dynamic-add-violin-group-btn', 'n_clicks'),
    State('dynamic-add-violin-group-btn', 'n_clicks'),
    Input('heatmap', 'clickData'),
    prevent_initial_call=True,
)
def add_group_from_heatmap(n_clicks, clickData):
    hm_groups = [[] for i in range(len(sgs))]
    selected_attrs = clickData["points"][0]["y"].split(";")
    for prot_attr in selected_attrs:
        hm_groups[prot_attr_to_i[prot_attr]].append(prot_attr)
    plot_metadata["plotGroups"].append(hm_groups)
    plot_metadata["plotMetrics"].append(clickData["points"][0]["x"])
    return n_clicks + 1

    
@app.callback(
    Output('dynamic-dropdown-container-div', 'children'),
    Input('dynamic-add-filter-btn', 'n_clicks'),
    )
def display_dropdowns(n_clicks):
    if n_clicks == len(plot_metadata["highlightGroups"]):
        plot_metadata["highlightGroups"].append([[]]*len(sgs))
    return fetch_dropdown(n_clicks, 'h', plot_metadata["highlightGroups"][-1])


@app.callback(
    Output('dynamic-group-dropdown-container-div', 'children'),
    Input('dynamic-add-violin-group-btn', 'n_clicks'),
    )
def display_violins(n_clicks): 
    if n_clicks == len(plot_metadata["plotGroups"]):
        plot_metadata["plotGroups"].append([[]]*len(sgs))
        plot_metadata["plotMetrics"].append(default_score)
    return fetch_dropdown(n_clicks, 'g', plot_metadata["plotGroups"][-1] 
                                  + [[plot_metadata["plotMetrics"][-1]]])


@app.callback(
    Output('utterance-dump', 'children'),
    Input('violin-groups', 'clickData')
)
def show_utterances(clickData):
    if clickData is not None and len(clickData["points"]) == 1:
        index = int(clickData["points"][0]["hovertext"].split(',')[0][7:]) 
        output_text = jaem_data.iloc[index]["ae_gens"]
    else:
        output_text = explanation_default
    return dcc.Markdown(output_text, style={"whiteSpace": "pre-wrap", "wordBreak": "break-word",
                                            "width": "60%", 'horizontal-align':'right', 
                                            'font-family': 'Arial, Helvetica, sans-serif', 'font-size': 14})


@app.callback(
    Output('violin-groups', 'figure'),
    Input({'type': 'hGroup-dropdown', 'index': ALL}, 'value'),
    Input({'type': 'gGroup-dropdown', 'index': ALL}, 'value'),
    prevent_initial_call=True,
)
def log_change(h_attributes, g_attributes):
    trig_id = ctx.triggered_id['index']
    trig_type = ctx.triggered_id['type']
    g, a = trig_id // 1000, trig_id % 1000
    
    if trig_type == 'hGroup-dropdown':
        plot_metadata["highlightGroups"][g][a] = h_attributes[g*len(sgs) + a]
        mc_names = sgs_names
    elif a % len(sgs_plotgrp) == len(sgs_plotgrp) - 1:
        score_to_add = g_attributes[g*len(sgs_plotgrp) + a]
        plot_metadata["plotMetrics"][g] = score_to_add[0] if len(score_to_add) > 0 else default_score
        mc_names = sgs_pg_names
    else: 
        plot_metadata["plotGroups"][g][a] = g_attributes[g*len(sgs_plotgrp) + a]
        mc_names = sgs_pg_names
    
    return create_violin_plots(jaem_data, plot_metadata["plotGroups"], 
                              plot_metadata["highlightGroups"], plot_metadata["plotMetrics"],
                               mc_names)
    
        
if __name__ == '__main__':
    app.run_server(jupyter_mode="external", debug=True, port=42428) # inline/tab/external

Dash app running on http://127.0.0.1:42428/
