In [446]:
import sys
sys.path.append("..")

import dash
import dash_bootstrap_components as dbc
import plotly.express as px
import pandas as pd
import plotly.graph_objs as go
import pickle
import numpy
import random

from jupyter_dash import JupyterDash
from dash import dcc
from dash import html
from dash import Input, Output, State
from plotly import colors as plotly_colors
from wordcloud import WordCloud

from collections import defaultdict
from src.dataset import Dataset
from src.vectorizers import TokenVectorizer
from gensim.models import Word2Vec, KeyedVectors
from src.lda_utils import get_word_relevance, get_words_relevance, print_topics

In [235]:
# Remote image shown as the logo inside the navbar.
PLOTLY_LOGO = "https://images.plot.ly/logo/new-branding/plotly-logomark.png"
# Name of the Plotly layout template passed to the figures built by the callbacks.
TEMPLATE = 'plotly_white'

In [236]:
def get_distribution(term, word2id, doc_dates, transposed_vectors, dates_frequencies, low_filter=-1, high_filter=9999):
    """Compute, per date bucket, the share of documents that contain ``term``.

    Args:
        term: word to look up in ``word2id``.
        word2id: mapping from word to its row index in ``transposed_vectors``.
        doc_dates: sequence giving the date bucket of every document index.
        transposed_vectors: term-by-document matrix; ``transposed_vectors[i]``
            must support ``.toarray()[0]`` (e.g. a SciPy sparse matrix).
        dates_frequencies: mapping from date bucket to its total number of documents.
        low_filter: smallest bucket kept (inclusive).
        high_filter: largest bucket kept (inclusive).

    Returns:
        A list of ``(bucket, share)`` tuples sorted by bucket. Buckets in which
        every document contains the term are left out. An unknown term yields ``[]``.
    """
    term_row = word2id.get(term, -1)
    if term_row < 0:
        return []

    # Dense vector holding the count of the term in every document.
    occurrences = transposed_vectors[term_row].toarray()[0]

    # Each document containing the term counts once for its bucket.
    docs_per_bucket = {}
    for doc_index, count in enumerate(occurrences):
        if count > 0:
            bucket = doc_dates[doc_index]
            docs_per_bucket[bucket] = docs_per_bucket.get(bucket, 0) + 1

    shares = []
    for bucket, n_docs in docs_per_bucket.items():
        if bucket < low_filter or bucket > high_filter:
            continue
        total = dates_frequencies[bucket]
        if n_docs == total:
            continue
        shares.append((bucket, n_docs / total))
    shares.sort()
    return shares

In [237]:
# Project-level dataset helper (src.dataset).
dataset = Dataset()
# Load the Illinois Appellate Court records, requesting only the "topic" and
# "decision_date" fields; later cells read el["decision_date"] and el["topic"]
# from every item of `dates`.
dates = dataset.load_dataset(year=None, fields={"topic", "decision_date"}, courts={"Illinois Appellate Court"})

KeyboardInterrupt: 

In [5]:
# Load the persisted "count" vectors together with the vectorizer that produced them.
vectors, vectorizer = TokenVectorizer.load_vectors_vectorizer(method="count")

# `get_feature_names` is deprecated since scikit-learn 1.0 and removed in 1.2:
# prefer `get_feature_names_out` and only fall back for older vectorizers.
# The entries are converted to plain `str` so downstream code sees the same
# element type the legacy method returned.
if hasattr(vectorizer, "get_feature_names_out"):
    vocab = [str(term) for term in vectorizer.get_feature_names_out()]
else:
    vocab = vectorizer.get_feature_names()

# Lookup tables between a word and its column index in `vectors`.
word2id = {term: idx for idx, term in enumerate(vocab)}
id2word = dict(enumerate(vocab))

# Transposed so that a row can be selected by word index (see get_distribution).
vectors_trans = vectors.transpose()

https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


In [6]:
# Location of the saved Word2Vec model, relative to the notebook's working directory.
model_path = "../data/models/test_w2v.model"
# Load the model; its `wv` vectors are queried by the "similar context" callback.
w2v_model = Word2Vec.load(model_path)

In [7]:
# Width of the date buckets used by the term-frequency plot
# (presumably years -- TODO confirm the unit of `decision_date`).
intervals = 10
# Decision date of every record, floored to the start of its bucket.
norm_dates = [
    decision_date - decision_date % intervals
    for decision_date in (record['decision_date'] for record in dates)
]

# Number of records that fall into each bucket.
dates_frequencies = defaultdict(int)

for d in norm_dates:
    dates_frequencies[d] = dates_frequencies[d] + 1

In [8]:
# SECURITY NOTE: pickle.load can execute arbitrary code contained in the file;
# only load model files that come from a trusted source.
# The files are opened through context managers so the handles are closed
# as soon as the models are deserialized.
with open("../data/models/IAC_exp_seed_minf_10_max_50%.pk", "rb") as model_file:
    generic_lda_model = pickle.load(model_file)

with open("../data/models/FULL_exp_seed_t_0_2_13_minf_10_max_50%.pk", "rb") as model_file:
    specific_lda_model = pickle.load(model_file)

https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


In [496]:
def get_topic_words(model, topic_id, vectorizer, n_top_words=10, only_interesting=False, interesting_set={}):
    """Return the heaviest words of a topic, with weights scaled to the top word.

    Args:
        model: topic model exposing ``components_``, indexable by topic id; each
            entry holds one weight per vocabulary word.
        topic_id: index of the topic inside ``model.components_``.
        vectorizer: object exposing ``get_feature_names_out()`` (preferred) or
            the legacy ``get_feature_names()``; position ``i`` names weight ``i``.
        n_top_words: maximum number of words returned.
        only_interesting: when true, keep only words found in ``interesting_set``.
        interesting_set: container of words to keep; it is only read, never mutated.

    Returns:
        A list of ``(word, weight)`` tuples ordered by decreasing weight, where
        ``weight`` is the component value divided by the topic's largest value.
    """
    # `get_feature_names` is deprecated since scikit-learn 1.0 and removed in 1.2.
    if hasattr(vectorizer, "get_feature_names_out"):
        vocab = vectorizer.get_feature_names_out()
    else:
        vocab = vectorizer.get_feature_names()

    comp = model.components_[topic_id]
    # Word indices ordered from the heaviest to the lightest weight.
    ranked = numpy.argsort(comp)[::-1]
    max_rel = comp[ranked[0]]

    if only_interesting:
        # Filter before truncating so that up to n_top_words interesting words survive.
        candidates = [i for i in ranked if str(vocab[i]) in interesting_set]
    else:
        candidates = ranked

    # str() keeps plain Python strings even when the vocabulary is a numpy array.
    return [(str(vocab[i]), comp[i] / max_rel) for i in candidates[:n_top_words]]

def get_wordcloud_graphs_topic_words(wordcloud, font_scale=80):
    """Turn a laid-out word cloud into a Plotly figure dictionary.

    Args:
        wordcloud: object exposing ``layout_``, an iterable of
            ``((word, freq), fontsize, position, orientation, color)`` entries.
        font_scale: factor applied to each word frequency to obtain its text
            size in the figure (defaults to the previously hard-coded 80).

    Returns:
        ``{"data": [trace], "layout": layout}`` where the single text-mode
        scatter trace places every word at its layout position.
    """
    word_list = []
    freq_list = []
    x_arr = []
    y_arr = []
    color_list = []

    # The font size and orientation stored in the layout are not used by the figure.
    for (word, freq), _fontsize, position, _orientation, color in wordcloud.layout_:
        word_list.append(word)
        freq_list.append(freq)
        x_arr.append(position[0])
        y_arr.append(position[1])
        color_list.append(color)

    # Text size is driven by the word frequency rather than by the layout font size.
    text_sizes = [freq * font_scale for freq in freq_list]

    trace = go.Scatter(
        x=x_arr,
        y=y_arr,
        textfont=dict(size=text_sizes, color=color_list),
        hoverinfo="text",
        textposition="top center",
        hovertext=["{0} - {1}".format(w, f) for w, f in zip(word_list, freq_list)],
        mode="text",
        text=word_list,
    )

    # Both axes are purely positional: hide grid, ticks and zero line.
    layout = go.Layout(
        {
            "xaxis": {
                "showgrid": False,
                "showticklabels": False,
                "zeroline": False,
                "automargin": True,
                "range": [-100, 250],
            },
            "yaxis": {
                "showgrid": False,
                "showticklabels": False,
                "zeroline": False,
                "automargin": True,
                "range": [-100, 450],
            },
            "margin": dict(t=0, b=0, l=0, r=0, pad=0),
            "hovermode": "closest",
        }
    )

    return {"data": [trace], "layout": layout}

In [10]:
def norm(date, interval=5):
    """Floor ``date`` to the start of the ``interval``-wide bucket containing it."""
    offset = date % interval
    return date - offset

# Number of records in every interval (bucketed with norm's default width).
all_freqs = defaultdict(int)
for el in dates:
    d = norm(el['decision_date'])
    all_freqs[d] = all_freqs[d] + 1

# For each topic index, the sum of that topic's weights over the records of an
# interval, divided by the number of records in that interval.
topic_dists = defaultdict(lambda: defaultdict(int))
for el in dates:
    d = norm(el['decision_date'])
    t = el["topic"]
    interval_total = all_freqs[d]
    for i, e in enumerate(t):
        topic_dists[i][d] += e / interval_total

In [605]:
# Dash application (JupyterDash variant) styled with the Bootstrap theme shipped
# with dash-bootstrap-components; the callbacks below are registered on it.
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

##### Navbar

In [606]:
# Logo and application title, laid out side by side.
_navbar_logo = dbc.Col(html.Img(src=PLOTLY_LOGO, height="30px"))
_navbar_title = dbc.Col(
    dbc.NavbarBrand("Illinois Cases Analysis", className="ml-2"),
    style={"marginLeft": 10},
)

# A row/col pair controls the vertical alignment of logo and brand; the whole
# brand links to the project repository.
_navbar_brand_link = html.A(
    dbc.Row([_navbar_logo, _navbar_title], align="center", className="g-0"),
    href="https://github.com/tomfran/legal-texts-information-retrieval",
    style={"margin": 10, "textDecoration": "none"},
)

# Dark navbar that stays pinned to the top of the page.
NAVBAR = dbc.Navbar(
    children=[_navbar_brand_link],
    color="dark",
    dark=True,
    sticky="top",
)

##### Searchbox

In [607]:
# Search bar: the button triggers the callbacks, which read the text typed in the input.
_search_button = dbc.Button("Search", id="search-button", n_clicks=0)
_search_input = dbc.Input(id="search-input", placeholder="cocaine, drug - gun, weapon")

SEARCH_BOX = dbc.InputGroup(
    [_search_button, _search_input],
    style={"marginTop": 20},
)

Word Analysis

In [608]:
# Dropdown listing the searched terms; its options are filled by a callback.
# Style keys are camelCase, as Dash expects and as every other component in
# this notebook already does ("marginTop", "textDecoration", ...).
WORD_DROPDOWN = dcc.Dropdown(id="words-drop", clearable=False, style={"fontSize": 12})

# Each graph is wrapped in a dcc.Loading so a spinner is shown while its
# callback is running.
CONTEXT_GRAPH = dcc.Loading(
    id="loading-similar-context-words",
    children=[dcc.Graph(id="similar-context-graph")],
    type="default",
)
GRAMS_GRAPH = dcc.Loading(
    id="loading-grams",
    children=[dcc.Graph(id="grams-graph")],
    type="default",
)
WORD_GENERIC_TOPIC_DISTRIBUTION_GRAPH = dcc.Loading(
    id="loading-word-topics",
    children=[dcc.Graph(id="word-topics-graph")],
    type="default",
)

# Tab selector between the two topic models; each tab's value equals its label
# and "Generic" is selected initially.
WORD_TOPIC_TABS = dcc.Tabs(
    id="word-topics-tabs",
    value="Generic",
    children=[
        dcc.Tab(label=tab_name, value=tab_name)
        for tab_name in ("Generic", "Specific")
    ],
)

In [609]:
# Warning banner, created hidden (display: none).
_word_alert = dbc.Alert(
    "Not enough data to render these plots, please adjust the filters",
    id="no-word-data-alert",
    color="warning",
    style={"display": "none"},
)

# First row: term dropdown with the similar-context plot, next to the wider
# (md=8) frequency plot.
_word_plots_row = dbc.Row(
    [
        dbc.Col([WORD_DROPDOWN, CONTEXT_GRAPH]),
        dbc.Col([GRAMS_GRAPH], md=8),
    ]
)

# Second row: model tabs followed by the topic-distribution plot.
_word_topics_row = dbc.Row([WORD_TOPIC_TABS, WORD_GENERIC_TOPIC_DISTRIBUTION_GRAPH])

# Children of the "Word analysis" card: header, alert, body.
WORD_CARD = [
    dbc.CardHeader(html.H5("Word analysis")),
    _word_alert,
    dbc.CardBody([_word_plots_row, _word_topics_row]),
]

Topic Analysis

In [610]:
# Left column: top words of the selected topic.
_topic_top_words = dcc.Loading(
    id="loading-topic-top-words",
    children=[dcc.Graph(id="topic-top-words-graph")],
    type="default",
)

# Right column: the same topic shown as a treemap or as a word cloud.
_treemap_tab = dcc.Tab(
    label="Treemap",
    children=[
        dcc.Loading(
            id="loading-treemap",
            children=[dcc.Graph(id="topic-treemap")],
            type="default",
        )
    ],
)
_wordcloud_tab = dcc.Tab(
    label="Wordcloud",
    children=[
        dcc.Loading(
            id="loading-wordcloud",
            children=[dcc.Graph(id="topic-wordcloud")],
            type="default",
        )
    ],
)
_topic_view_tabs = dcc.Tabs(id="tabs", children=[_treemap_tab, _wordcloud_tab])

TOPIC_WORDS_GRAPHS = dbc.Row(
    [
        dbc.Col(_topic_top_words),
        dbc.Col([_topic_view_tabs], md=8),
    ]
)

In [611]:
# Histogram of the selected topic over time.
_topic_years = dcc.Loading(
    id="loading-topic-years",
    children=[dcc.Graph(id="topic-years-histogram")],
    type="default",
)

# Pie chart placeholder for the courts of the selected topic.
_topic_courts = dcc.Loading(
    id="loading-topic-courts",
    children=[dcc.Graph(id="topic-courts-piechart")],
    type="default",
)

# The courts chart takes 4 of the 12 grid columns on medium screens and wider.
TOPIC_INFO_GRAPHS = dbc.Row(
    [
        dbc.Col(_topic_years),
        dbc.Col(_topic_courts, md=4),
    ]
)

In [612]:
# Header of the topic card; its children are replaced by a callback once a
# topic is selected.
_topic_header = dbc.CardHeader(
    id="topic-header",
    children=[html.H5("Topic 5 - Driving Incidents", id="selected_topic_name")],
)

# Warning banner, created hidden (display: none).
_topic_alert = dbc.Alert(
    "Not enough data to render these plots, please adjust the filters",
    id="no-topic-data-alert",
    color="warning",
    style={"display": "none"},
)

# Children of the topic card: header, alert, then the two rows of graphs.
TOPIC_CARD = [
    _topic_header,
    _topic_alert,
    dbc.CardBody([TOPIC_WORDS_GRAPHS, TOPIC_INFO_GRAPHS]),
]

##### Body

In [613]:
# Page content below the navbar: search box, word card, topic card.
_page_sections = [
    SEARCH_BOX,
    dbc.Card(WORD_CARD, style={"marginTop": 20}),
    dbc.Card(TOPIC_CARD, style={"marginTop": 20, "marginBottom": 30}),
]

BODY = dbc.Container(_page_sections, className="mt-12")

##### Callbacks

In [614]:
@app.callback(
    [
        Output("words-drop", "options"),
        Output("words-drop", "value"),
    ],
    Input('search-button', 'n_clicks'),
    State('search-input', 'value')
)
def populate_bank_dropdown(n_clicks, searches):
    """Fill the word dropdown with the terms typed in the search box.

    Args:
        n_clicks: click counter of the search button (only used as trigger).
        searches: raw text of the search input; terms are separated by "-".

    Returns:
        ``(options, value)``: one ``{"label", "value"}`` entry per non-empty
        term, and the first term as selected value. ``([], None)`` when the
        input holds no usable term.
    """
    if not searches:
        return [], None

    # str.strip returns a new string, so its result has to be kept; tokens
    # that are empty after stripping (e.g. "a - - b") are dropped.
    terms = [token.strip() for token in searches.split("-")]
    options = [{"label": term, "value": term} for term in terms if term]
    if not options:
        return [], None
    return options, options[0]['value']

In [615]:
@app.callback(
    Output("similar-context-graph", "figure"),
    Input("words-drop", "value")
)
def get_similar_context_graph(word):
    """Plot the 15 words whose Word2Vec vectors are most similar to ``word``.

    Args:
        word: term selected in the dropdown (may be ``None`` or empty).

    Returns:
        A horizontal bar-like figure with the most similar word on top, or an
        empty figure (``{}``) when no word is selected or the word is not in
        the Word2Vec vocabulary.
    """
    if not word:
        return {}
    word = word.strip()
    try:
        # Reversed so that the most similar word ends up at the top of the plot.
        sim = w2v_model.wv.most_similar(word, topn=15)[::-1]
    except KeyError:
        # most_similar raises KeyError for an out-of-vocabulary word; show an
        # empty figure instead of failing the callback.
        return {}
    return px.histogram(
        y=[neighbour for neighbour, _ in sim],
        x=[similarity for _, similarity in sim],
        orientation="h",
        title="Similar context",
        color_discrete_sequence=['darkturquoise']
    ).update_layout(
        template=TEMPLATE,
        xaxis_title='',
        yaxis_title=''
    )

In [616]:
@app.callback(
    Output('grams-graph', 'figure'),
    [Input('search-button', 'n_clicks')],
    [State('search-input', 'value')])
def update_output(n_clicks, searches):
    """Draw, for every searched term, the share of documents containing it over time.

    Args:
        n_clicks: click counter of the search button (only used as trigger).
        searches: raw text of the search input; terms are separated by "-".

    Returns:
        A figure with one line per term that has a distribution, or ``{}`` when
        the search input is empty.
    """
    if not searches:
        return {}

    # Range-selector buttons: fixed look-back windows plus an "all" button.
    window_buttons = [
        dict(count=span, label=f"{span}y", step="year", stepmode="backward")
        for span in (1, 5, 10, 25, 50)
    ]
    window_buttons.append(dict(step="all"))

    time_axis = dict(
        rangeselector=dict(buttons=window_buttons),
        rangeslider=dict(visible=True),
        title="year",
        type="date",
    )
    fig = go.Figure(
        layout=go.Layout(
            title="Docs containing searched words frequency",
            template=TEMPLATE,
            xaxis=time_axis,
            yaxis=dict(title="freq"),
        )
    )

    for raw_term in searches.split("-"):
        term = raw_term.strip()
        grams = get_distribution(term, word2id, norm_dates, vectors_trans, dates_frequencies)
        if grams:
            # Terms without any distribution are simply not drawn.
            fig.add_trace(
                go.Scatter(
                    x=[bucket for bucket, _ in grams],
                    y=[share for _, share in grams],
                    mode='lines',
                    name=term,
                )
            )
    return fig

In [617]:
@app.callback(
    Output("word-topics-graph", "figure"),
    [
        Input('search-button', 'n_clicks'),
        Input('word-topics-tabs', 'value')
    ],
    State('search-input', 'value'))
def get_generic_topics_radar_graph(n_clicks, tab, searches):
    """Build the radar chart of topic relevance for every searched term.

    Args:
        n_clicks: click counter of the search button (only used as trigger).
        tab: selected tab value; "Generic" selects the generic LDA model, any
            other value the specific one.
        searches: raw text of the search input; terms are separated by "-".

    Returns:
        A polar figure with one filled trace per term, each rescaled so that
        its largest value is 10, or ``{}`` when the search input is empty.
    """
    if not searches:
        return {}

    lda_model = generic_lda_model if tab == "Generic" else specific_lda_model
    fig = go.Figure(layout=go.Layout(title="Topic distribution", template=TEMPLATE))

    for raw_term in searches.split("-"):
        term = raw_term.strip()
        topics = get_word_relevance(term, word2id, vocab, lda_model, normalize=True)

        # Rescale the relevances so that the strongest topic reaches the axis maximum.
        relevances = list(topics.values())
        values = numpy.array(relevances, dtype='f') * 10 / max(relevances)

        fig.add_trace(
            go.Scatterpolar(
                r=values,
                theta=[str(name) for name in list(topics.keys())],
                fill='toself',
                name=term,
            )
        )

    radial_axis = dict(visible=True, range=[0, 10])
    fig.update_layout(polar=dict(radialaxis=radial_axis), showlegend=True)
    return fig

In [618]:
@app.callback(
    [
        Output("topic-header", "children"),
        Output("topic-top-words-graph", "figure"),
        Output("topic-treemap", "figure"),
        Output("topic-wordcloud", "figure")
    ],
    [
        Input('word-topics-graph', 'clickData'),
        Input('word-topics-tabs', 'value')
    ],)
def get_topic_words_radar_graph(selected_topic, tab):
    """Render the word views (header, top words, treemap, word cloud) of a topic.

    NOTE(review): a later cell defines a second callback with this same
    function name. Dash registers each callback when it is decorated, so both
    run, but the module-level name ends up bound to the later definition.

    Args:
        selected_topic: ``clickData`` of the radar chart; the topic id is read
            from the ``pointNumber`` of the first clicked point (assumes the
            point order matches the topic ids -- TODO confirm).
        tab: selected tab value; "Generic" selects the generic LDA model, any
            other value the specific one.

    Returns:
        ``(header_children, top_words_figure, treemap_figure, wordcloud_figure)``;
        a "Select a topic" header and three empty figures when nothing is selected.
    """
    if not selected_topic:
        return [html.H5("Select a topic")], {}, {}, {}

    topic_id = selected_topic['points'][0]['pointNumber']
    lda_model = generic_lda_model if tab == "Generic" else specific_lda_model
    sim = get_topic_words(lda_model,
                          topic_id,
                          vectorizer,
                          n_top_words=100,
                          only_interesting=False)

    words = [word for word, _ in sim]
    freqs = [weight for _, weight in sim]

    # Treemap of the 50 heaviest words: labels, parents and values must all
    # have the same length, so the values are sliced like the labels.
    treemap_words = words[:50]
    treemap_trace = go.Treemap(
        labels=treemap_words,
        parents=[""] * len(treemap_words),
        values=freqs[:50],
    )
    treemap_layout = go.Layout({"margin": dict(t=0, b=0, l=0, r=0, pad=0)})
    treemap_figure = {"data": [treemap_trace], "layout": treemap_layout}

    # The word cloud uses all the retrieved words.
    wc = WordCloud().generate_from_frequencies(frequencies=dict(sim))
    wordcloud = get_wordcloud_graphs_topic_words(wc)

    # Bar-like view of the 20 heaviest words, heaviest on top.
    top_words_figure = px.histogram(
        y=words[:20][::-1],
        x=freqs[:20][::-1],
        orientation="h",
        color_discrete_sequence=['darkturquoise']
    ).update_layout(
        template=TEMPLATE,
        xaxis_title='',
        yaxis_title='',
        height=550
    )

    return [html.H5(f"{tab} Topic {topic_id}")], top_words_figure, treemap_figure, wordcloud

In [619]:
@app.callback(
    Output("topic-years-histogram", "figure"),
    [
        Input('word-topics-graph', 'clickData'),
        Input('word-topics-tabs', 'value')
    ],)
def get_topic_words_radar_graph(selected_topic, tab):
    """Build the bar chart of the selected topic's weight per date interval.

    Args:
        selected_topic: ``clickData`` of the radar chart; the topic id is read
            from the ``pointNumber`` of the first clicked point.
        tab: selected tab value; it triggers the callback but is not read here.

    Returns:
        A ``{"data", "layout"}`` figure dictionary, or ``{}`` when no topic is
        selected.
    """
    if not selected_topic:
        return {}

    topic_id = selected_topic['points'][0]['pointNumber']
    distribution = topic_dists[topic_id]

    bars = {
        "x": list(distribution.keys()),
        "y": list(distribution.values()),
        "text": list(distribution.keys()),
        "type": "bar",
        "name": "",
    }

    # Range-selector buttons: fixed look-back windows plus an "all" button.
    window_buttons = [
        dict(count=span, label=f"{span}y", step="year", stepmode="backward")
        for span in (1, 5, 10, 25, 50)
    ]
    window_buttons.append(dict(step="all"))

    time_axis = dict(
        rangeselector=dict(buttons=window_buttons),
        rangeslider=dict(visible=True),
        title="year",
        type="date",
        showticklabels=True,
    )
    layout = {
        "autosize": True,
        "margin": dict(t=10, b=20, l=40, r=0, pad=4),
        "xaxis": time_axis,
    }
    return {"data": [bars], "layout": layout}

In [620]:
# Page layout: navbar on top, followed by the search box and the analysis cards.
app.layout = html.Div(children=[NAVBAR, BODY])

# Start the server from the notebook in 'jupyterlab' display mode, with the
# dev-tools UI and hot reload enabled; debug mode is left commented out.
app.run_server(mode='jupyterlab', dev_tools_ui=True, #debug=True, 
               dev_tools_hot_reload =True, threaded=True)

In [395]:
def _terminate_server_for_port(host, port):
    """Ask the JupyterDash server listening on ``host:port`` to shut down.

    The request is best effort: if the server cannot be reached (for instance
    because it is not running) the failure is ignored.

    The previous implementation called ``requests.get``, but ``requests`` is
    not imported in the visible notebook, so the resulting ``NameError`` was
    swallowed by a blanket ``except Exception`` and no request was ever sent.
    The standard library is used instead.

    Args:
        host: host name of the running server.
        port: port of the running server.
    """
    import urllib.request

    shutdown_url = "http://{host}:{port}/_shutdown_{token}".format(
        host=host, port=port, token=JupyterDash._token
    )
    try:
        # The timeout keeps an unresponsive server from blocking the notebook.
        with urllib.request.urlopen(shutdown_url, timeout=5):
            pass
    except OSError:
        # URLError, HTTPError, connection errors and timeouts are all OSError
        # subclasses: nothing to shut down, or the server refused the request.
        pass

In [396]:
# _terminate_server_for_port("localhost", 8050)


Function get_feature_names is deprecated; get_feature_names is deprecated in 1.0 and will be removed in 1.2. Please use get_feature_names_out instead.

