In [1]:
import sys
sys.path.append("..")

import dash
import dash_bootstrap_components as dbc
import plotly.express as px
import pandas as pd
import plotly.graph_objs as go
import pickle
import numpy

from jupyter_dash import JupyterDash
from dash import dcc
from dash import html
from dash import Input, Output, State

from collections import defaultdict
from src.dataset import Dataset
from src.vectorizers import TokenVectorizer
from gensim.models import Word2Vec, KeyedVectors
from src.lda_utils import get_word_relevance, get_words_relevance, print_topics

In [2]:
# URL of the logo image displayed in the navbar.
PLOTLY_LOGO = "https://images.plot.ly/logo/new-branding/plotly-logomark.png"
# Plotly template name applied to the figures built in the callbacks.
TEMPLATE = 'plotly_white'

In [3]:
def get_distribution(term, word2id, doc_dates, transposed_vectors, dates_frequencies, low_filter=-1, high_filter=9999):
    """Compute, per date bucket, the share of documents that contain ``term``.

    Args:
        term: vocabulary entry to look up.
        word2id: mapping from a term to its row index in ``transposed_vectors``.
        doc_dates: date bucket of every document, indexed by document position.
        transposed_vectors: indexable by term row; each row must expose
            ``toarray()`` returning a 2-D array whose first row holds the
            per-document occurrence counts.
        dates_frequencies: mapping from date bucket to the total used as the
            denominator of the share.
        low_filter: smallest bucket kept (inclusive).
        high_filter: largest bucket kept (inclusive).

    Returns:
        A sorted list of ``(bucket, share)`` tuples. The list is empty when
        ``term`` is not in ``word2id``. Buckets whose hit count equals the
        bucket total are left out.
    """
    row = word2id.get(term, -1)
    if row < 0:
        return []

    # A document counts once per bucket, however often the term occurs in it.
    hits_per_bucket = defaultdict(lambda: 0)
    for doc_index, count in enumerate(transposed_vectors[row].toarray()[0]):
        if count > 0:
            hits_per_bucket[doc_dates[doc_index]] += 1

    distribution = []
    for bucket, hits in hits_per_bucket.items():
        if not (low_filter <= bucket <= high_filter):
            continue
        total = dates_frequencies[bucket]
        if hits == total:
            continue
        distribution.append((bucket, hits / total))

    distribution.sort()
    return distribution

In [4]:
# Load only the decision_date field of the Illinois Appellate Court cases.
# NOTE(review): year=None presumably means "no year restriction" — confirm in src.dataset.
dataset = Dataset()
dates = dataset.load_dataset(year=None, fields=["decision_date"], courts={"Illinois Appellate Court"})

In [5]:
# Load the pre-computed "count" vectors together with the vectorizer that produced them.
vectors, vectorizer = TokenVectorizer.load_vectors_vectorizer(method="count")
# NOTE(review): if this is a scikit-learn vectorizer, get_feature_names() was
# removed in scikit-learn 1.2 in favour of get_feature_names_out() — confirm.
vocab = vectorizer.get_feature_names()
# Lookup tables between vocabulary terms and their positions in `vocab`.
word2id = dict((v, idx) for idx, v in enumerate(vocab))
id2word = dict((idx, v) for idx, v in enumerate(vocab))

# Transposed so that a term index selects a row (this is how get_distribution indexes it).
vectors_trans = vectors.transpose()

https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


In [6]:
# Word2Vec model queried by the "Similar context" callback.
model_path = "../data/models/test_w2v.model"
model = Word2Vec.load(model_path)

In [7]:
# Width of the date buckets: every decision date is reduced to the closest
# lower multiple of `intervals`.
intervals = 10
norm_dates = [e['decision_date'] - e['decision_date']%intervals for e in dates]

# Number of documents that fall into each date bucket.
dates_frequencies = defaultdict(lambda:0)

for d in norm_dates:
    dates_frequencies[d] += 1

In [8]:
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load model files that come from a trusted source.
# The context manager closes the file handle as soon as the model is loaded.
with open("../data/models/IAC_exp_seed_minf_10_max_50%.pk", "rb") as lda_file:
    lda_model = pickle.load(lda_file)

In [26]:
# Dash application created through JupyterDash, styled with the Bootstrap theme stylesheet.
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

##### Navbar

In [27]:
# Dark, sticky top navbar: logo and brand name wrapped in a link to the project repository.
NAVBAR = dbc.Navbar(
    children=[
        html.A(
            # Use row and col to control vertical alignment of logo / brand
            dbc.Row(
                [
                    dbc.Col(html.Img(src=PLOTLY_LOGO, height="30px")),
                    dbc.Col(
                        dbc.NavbarBrand("Illinois Cases Analysis", className="ml-2"),
                        style={"marginLeft": 10}
                    ),
                ],
                align="center",
                className="g-0",
            ),
            href="https://github.com/tomfran/legal-texts-information-retrieval",
            style={"margin": 10, "textDecoration": "none"}
        )
    ],
    color="dark",
    dark=True,
    sticky="top",
)

##### Searchbox

In [28]:
# Search controls: the callbacks fire on "search-button" clicks and read the
# text of "search-input", which they split on "-" into separate search terms.
SEARCH_BOX = dbc.InputGroup(
    [
        dbc.Button("Search", id="search-button", n_clicks=0),
        dbc.Input(id="search-input", placeholder="cocaine, drug - gun, weapon"),
    ],
    style={"marginTop": 20}
)

##### Word Analysis

In [29]:
# Each graph is wrapped in a dcc.Loading container; the graph ids are the
# Output targets of the callbacks defined further down.
CONTEXT_GRAPH = dcc.Loading(
    id="loading-similar-context-words",
    children=[dcc.Graph(id="similar-context-graph")],
    type="default",
)
GRAMS_GRAPH = dcc.Loading(
    id="loading-grams", 
    children=[dcc.Graph(id="grams-graph")],
    type="default",
)
WORD_GENERIC_TOPIC_DISTRIBUTION_GRAPH = dcc.Loading(
    id="loading-word-generic-topics", 
    children=[dcc.Graph(id="word-generic-topics-graph")],
    type="default",
)
# NOTE(review): not referenced by any layout or callback in this section.
WORD_SPECIFIC_TOPIC_DISTRIBUTION_GRAPH = dcc.Loading(
    id="loading-word-specific-topics", 
    children=[dcc.Graph(id="word-specific-topics-graph")],
    type="default",
)

# Tab bar for the word/topic view; the tabs are declared without children.
WORD_TOPIC_TABS = dcc.Tabs(
    id="word-topics-tabs",
    children=[
        dcc.Tab(
            label="Generic",
        ),
        dcc.Tab(
            label="Specific",
        ),
    ]
)

In [30]:
# Children of the "Word analysis" card: header, a warning alert that starts
# hidden (display: none), and the body with the word-level graphs.
# NOTE(review): no callback in this section toggles "no-word-data-alert".
WORD_CARD = [
    dbc.CardHeader(html.H5("Word analysis")),
    dbc.Alert(
        "Not enough data to render these plots, please adjust the filters",
        id="no-word-data-alert",
        color="warning",
        style={"display": "none"},
    ),
    dbc.CardBody(
        [
            # First row: similar-context graph next to the wider frequency graph.
            dbc.Row(
                [
                    dbc.Col([ 
                        CONTEXT_GRAPH
                    ]),
                    dbc.Col([GRAMS_GRAPH], md=8)
                ]
            ),
            # Second row: tab bar followed by the generic topic-distribution graph.
            dbc.Row([WORD_TOPIC_TABS, WORD_GENERIC_TOPIC_DISTRIBUTION_GRAPH])
        ]
    )
]

##### Topic Analysis

In [31]:
# Row with the topic word graphs: a frequency figure on the left and, on the
# right, a tabbed view switching between a treemap and a wordcloud graph.
# NOTE(review): no callback in this section populates these graphs.
TOPIC_WORDS_GRAPHS = dbc.Row(
    [
        dbc.Col(
            dcc.Loading(
                id="loading-frequencies",
                children=[dcc.Graph(id="frequency_figure")],
                type="default",
            )
        ),
        dbc.Col(
            [
                dcc.Tabs(
                    id="tabs",
                    children=[
                        dcc.Tab(
                            label="Treemap",
                            children=[
                                dcc.Loading(
                                    id="loading-treemap",
                                    children=[dcc.Graph(id="topic-treemap")],
                                    type="default",
                                )
                            ],
                        ),
                        dcc.Tab(
                            label="Wordcloud",
                            children=[
                                dcc.Loading(
                                    id="loading-wordcloud",
                                    children=[
                                        dcc.Graph(id="topic-wordcloud")
                                    ],
                                    type="default",
                                )
                            ],
                        ),
                    ],
                )
            ],
            md=8,
        ),
    ]
)

In [32]:
# Row with two more topic graphs: a years histogram and a narrower (md=4)
# courts pie chart, each wrapped in a loading container.
TOPIC_INFO_GRAPHS = dbc.Row(
    [
        dbc.Col(
            dcc.Loading(
                id="loading-topic-years",
                children=[
                    dcc.Graph(id="topic-years-histogram")
                ],
                type="default",
            )
        ),
        dbc.Col(
            dcc.Loading(
                id="loading-topic-courts",
                children=[
                    dcc.Graph(id="topic-courts-piechart")
                ],
                type="default",
            ),
            md=4
        )
    ]
)

In [33]:
# Children of the topic card: a header with a hard-coded topic title, a
# warning alert that starts hidden (display: none), and the body stacking the
# word graphs, a range slider and the info graphs.
# NOTE(review): no callback in this section updates the header, alert or slider.
TOPIC_CARD = [
    dbc.CardHeader(html.H5("Topic 5 - Driving Incidents", id="selected_topic_name")),
    dbc.Alert(
        "Not enough data to render these plots, please adjust the filters",
        id="no-topic-data-alert",
        color="warning",
        style={"display": "none"},
    ),
    dbc.CardBody(
        [TOPIC_WORDS_GRAPHS, dcc.RangeSlider(id="topic-time-window-slider"), TOPIC_INFO_GRAPHS]
    )
]

##### Body

In [34]:
# Page body: the search box followed by the word-analysis and topic cards.
# NOTE(review): "mt-12" does not look like a standard Bootstrap spacing class — confirm it has an effect.
BODY = dbc.Container(
    [
        SEARCH_BOX,
        dbc.Card(WORD_CARD, style={"marginTop": 20}),
        dbc.Card(TOPIC_CARD, style={"marginTop": 20, "marginBottom": 30}),
    ],
    className="mt-12",
)

##### Callbacks

In [35]:
@app.callback(
    Output("similar-context-graph", "figure"),
    Input('search-button', 'n_clicks'),
    State('search-input', 'value'))
def get_similar_context_graph(n_clicks, words):
    """Plot the 15 entries the Word2Vec model ranks as most similar to the first search term.

    Args:
        n_clicks: click counter of the search button; it only triggers the callback.
        words: raw text of the search box. Only the part before the first
            "-" is used, stripped of surrounding whitespace.

    Returns:
        A horizontal bar figure of the similarity scores, or an empty figure
        when the search box is empty, the first term is blank, or the term is
        not in the Word2Vec vocabulary.
    """
    if not words:
        return go.Figure()

    word = words.split("-")[0].strip()
    if not word:
        # Input such as " - gun" leaves nothing to look up.
        return go.Figure()

    try:
        # Reversed so the best match is drawn last, i.e. at the top of the
        # horizontal chart.
        sim = model.wv.most_similar(word, topn=15)[::-1]
    except KeyError:
        # The term is not in the model vocabulary: render an empty plot
        # instead of letting the callback fail.
        return go.Figure()

    return px.histogram(
        y=[neighbour for neighbour, _ in sim],
        x=[score for _, score in sim],
        orientation="h",
        title="Similar context",
        color_discrete_sequence=['darkturquoise']
    ).update_layout(
        template=TEMPLATE,
        xaxis_title='',
        yaxis_title=''
    )

In [36]:
@app.callback(
    Output('grams-graph', 'figure'),
    [Input('search-button', 'n_clicks')],
    [State('search-input', 'value')])
def update_output(n_clicks, searches):
    """Draw one line per search term showing its document share per date bucket.

    Args:
        n_clicks: click counter of the search button; it only triggers the callback.
        searches: raw text of the search box; terms are separated by "-".

    Returns:
        A figure with one line trace for every term for which
        ``get_distribution`` returns data, or an empty figure when the search
        box is empty.
    """
    if not searches:
        return go.Figure()

    # Range-selector buttons: fixed backward windows in years, then "all".
    range_buttons = [
        dict(count=span, label="{}y".format(span), step="year", stepmode="backward")
        for span in (1, 5, 10, 25, 50)
    ]
    range_buttons.append(dict(step="all"))

    x_axis = dict(
        rangeselector=dict(buttons=range_buttons),
        rangeslider=dict(visible=True),
        title="year",
        type="date",
    )
    layout = go.Layout(
        title="Docs containing searched words frequency",
        template=TEMPLATE,
        xaxis=x_axis,
        yaxis=dict(title="freq"),
    )
    fig = go.Figure(layout=layout)

    for term in (piece.strip() for piece in searches.split("-")):
        grams = get_distribution(term, word2id, norm_dates, vectors_trans, dates_frequencies)
        if grams:
            # Terms without any data get no trace.
            fig.add_trace(go.Scatter(
                x=[bucket for bucket, _ in grams],
                y=[share for _, share in grams],
                mode='lines',
                name=term,
            ))
    return fig

In [37]:
@app.callback(
    Output("word-generic-topics-graph", "figure"),
    Input('search-button', 'n_clicks'),
    State('search-input', 'value'))
def get_generic_topics_radar_graph(n_clicks, searches):
    """Draw a radar chart of the topic relevance of every search term.

    For each term the values returned by ``get_word_relevance`` are rescaled
    so that the largest one equals 10, which matches the fixed radial range
    of the chart.

    Args:
        n_clicks: click counter of the search button; it only triggers the callback.
        searches: raw text of the search box; terms are separated by "-".

    Returns:
        A figure with one polar trace per usable term, or an empty figure
        when the search box is empty. Blank terms and terms whose relevance
        mapping is empty or peaks at zero are skipped.
    """
    if not searches:
        return go.Figure()

    fig = go.Figure(layout=go.Layout(
        title="Topic distribution",
        template=TEMPLATE,
    ))

    for search in searches.split("-"):
        search = search.strip()
        if not search:
            # A trailing or doubled "-" yields a blank term; nothing to look up.
            continue

        topics = get_word_relevance(search, word2id, vocab, lda_model, normalize=True)
        if len(topics) == 0:
            # max() below would raise on an empty mapping.
            continue

        peak = max(topics.values())
        if peak == 0:
            # Rescaling would divide by zero and produce NaN radii.
            continue

        values = numpy.array(list(topics.values()), dtype='f') * 10 / peak

        fig.add_trace(go.Scatterpolar(
            r=values,
            theta=[str(name) for name in topics.keys()],
            fill='toself',
            name=search
        ))

    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 10]
            )),
        showlegend=True
    )
    return fig

In [38]:
# Assemble the page from the navbar and the body container.
app.layout = html.Div(children=[NAVBAR, BODY])

# Start the server in JupyterLab mode with the dev-tools UI and hot reload enabled.
app.run_server(mode='jupyterlab', dev_tools_ui=True, #debug=True, 
               dev_tools_hot_reload =True, threaded=True)

In [22]:
def _terminate_server_for_port(host, port):
    """Ask the JupyterDash server at ``host:port`` to shut down (best effort).

    Sends a GET request to the ``/_shutdown_<token>`` URL built from
    ``JupyterDash._token``. Network and HTTP failures are ignored on purpose.

    Args:
        host: host name or address the server is bound to.
        port: TCP port the server is listening on.

    Returns:
        None.
    """
    # Standard-library HTTP client, imported locally: this notebook never
    # imports `requests`, so a `requests.get` call here would only raise a
    # NameError and no request would be sent.
    import http.client
    import urllib.request

    shutdown_url = "http://{host}:{port}/_shutdown_{token}".format(
        host=host, port=port, token=JupyterDash._token
    )
    try:
        # The response body is irrelevant; the context manager just makes
        # sure the connection is closed. The timeout avoids a hung cell.
        with urllib.request.urlopen(shutdown_url, timeout=5):
            pass
    except (OSError, http.client.HTTPException):
        # Connection refused, dropped mid-shutdown, HTTP error status or
        # timeout: nothing more to do for a best-effort shutdown.
        pass

In [25]:
# Stop the running Dash server.
# NOTE(review): assumes it listens on localhost:8050; run_server above does not set a port — confirm the default.
_terminate_server_for_port("localhost", 8050)