## Text Analysis - Topic Modelling
### <span style='color: green'>SETUP </span> Prepare and Setup Notebook <span style='float: right; color: red'>MANDATORY</span>

In [2]:
# Setup
%load_ext autoreload
%autoreload 2

import sys, os, collections, zipfile

sys.path = [ '/home/roger/source/text_analytic_tools' ] + sys.path

import re, typing.re
import warnings
import nltk, textacy, spacy 
import pickle
import pandas as pd
import ipywidgets as widgets
import bokeh, bokeh.plotting, bokeh.models, matplotlib.pyplot as plt

import text_analytic_tools

import text_analytic_tools.utility.utils as utility
import text_analytic_tools.utility.widgets as widgets
import text_analytic_tools.common.text_corpus as text_corpus
import text_analytic_tools.common.textacy_utility as textacy_utility
import text_analytic_tools.text_analysis.topic_model as topic_model
import text_analytic_tools.text_analysis.topic_model_utility as topic_model_utility

from beakerx.object import beakerx
from beakerx import *
from IPython.display import display, set_matplotlib_formats

# Silence DeprecationWarning and UserWarning output for the rest of the notebook.
warnings.filterwarnings("ignore", category=DeprecationWarning) 
warnings.filterwarnings("ignore", category=UserWarning) 
logger = utility.getLogger('corpus_text_analysis')

# Project helper; presumably applies default pandas display options — see text_analytic_tools.utility.utils.
utility.setup_default_pd_display(pd)

from text_analytic_tools.config import get_current_domain

# Domain configuration/logic used by later cells (DATA_FOLDER, compile_documents,
# compile_documents_by_filename, get_tagset, SUBSTITUTION_FILENAME, DOCUMENT_FILTERS).
# NOTE: the topic-document network cell re-binds `domain_logic` via `import domain_logic_vatican as domain_logic`.
domain_logic = get_current_domain()

%matplotlib inline

# set_matplotlib_formats('svg')
# Send bokeh plots to the notebook output cells.
bokeh.plotting.output_notebook()

# Lazy accessors: each call re-reads the shared containers, so cells always see
# the latest corpus / topic-model state rather than a value captured at setup time.
def current_corpus_container():
    """Return ``textacy_utility.CorpusContainer.container()``."""
    return textacy_utility.CorpusContainer.container()


def current_corpus():
    """Return ``textacy_utility.CorpusContainer.corpus()``."""
    return textacy_utility.CorpusContainer.corpus()


def current_state():
    """Return the topic-model state singleton (``TopicModelContainer.singleton()``)."""
    return topic_model_utility.TopicModelContainer.singleton()


def current_data():
    """Return ``.data`` of the current topic-model state."""
    return current_state().data


def current_topic_model():
    """Return ``.topic_model`` of the current topic-model state."""
    return current_state().topic_model


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
%%bash
# One-off local setup, kept commented out: creates ./tmp and symlinks
# /home/roger/source/STTM to ./lib. Uncomment to run (machine-specific path).
#mkdir ./tmp
#ln -s /home/roger/source/STTM ./lib

## <span style='color: green;'>MODEL</span> Compute Topic Model Based on Raw Source Text Corpus<span style='color: red; float: right'>ALTERNATIVE #1</span>

#### <span style='color: green'>PREPARE</span> Load (or Create) The Corpus <span style='float: right; color: red'>OPTIONAL</span>
Set up a new corpus from the raw source text files that reside in a zip archive. This step uses the spaCy and textacy frameworks for PoS tagging, which takes some time (several minutes) for large text files. If the same processing and filtering rules are to be applied repeatedly, it is recommended to prepare the corpus once and for all using "1_extract_corpus_text" (see also ALTERNATIVE #2 below).


In [17]:
import text_analytic_tools.notebooks_gui.load_corpus_gui as load_corpus_gui

try:
    container = current_corpus_container()
    load_corpus_gui.display_corpus_load_gui(domain_logic.DATA_FOLDER, document_index=None, container=container)
except Exception as ex:
    # Log first, then re-raise so the notebook still shows the traceback.
    # (`logger.error` previously followed `raise` and was unreachable.)
    logger.error(ex)
    raise

IndexError: list index out of range

#### <span style='color: green;'>MODEL</span> Compute the Topic Model<span style='color: red; float: right'>OPTIONAL</span>


In [None]:
import topic_model_gui

try:
    gui = topic_model_gui.TextacyCorpusUserInterface(
        data_folder=domain_logic.DATA_FOLDER,
        state=current_state(),
        document_index=domain_logic.compile_documents(current_corpus()),
        tagset=domain_logic.get_tagset(),
        substitution_filename=domain_logic.SUBSTITUTION_FILENAME
    )
    gui.display(current_corpus())

except Exception as ex:
    # Log, then re-raise so the notebook still shows the traceback.
    # (`logger.error` previously followed `raise` and was unreachable.)
    logger.error(ex)
    raise

## <span style='color: green'>MODEL </span> Compute Topic Model Based on a Previously Prepared Text Corpus <span style='float: right; color: red'>ALTERNATIVE #2</span>
This step loads a text corpus consisting of pre-processed tokens. This is much faster than the previous alternative, since the corpus is assumed to be already tokenized, lemmatized and filtered, so the topic modelling engines can use it without further processing.

- Use the **1_extract_corpus_text** notebook to prepare this kind of corpus.
- This is recommended for large corpora, where pre-processing takes a long time, and when the same filters and setup are to be used several times.


In [None]:
import topic_model_gui

try:

    def fn_doc_index(corpus):
        """Return ``domain_logic.compile_documents_by_filename(corpus.filenames)``."""
        return domain_logic.compile_documents_by_filename(corpus.filenames)

    # The bare name DATA_FOLDER is not defined in this notebook; use the domain's
    # data folder, as the other corpus cells do.
    gui = topic_model_gui.PreparedCorpusUserInterface(
        data_folder=domain_logic.DATA_FOLDER,
        state=current_state(),
        fn_doc_index=fn_doc_index
    )

    gui.display(None)

except Exception as ex:
    # Log, then re-raise (the log call used to be unreachable after `raise`).
    logger.error(ex)
    raise

## <span style='color: green;'>MODEL</span> Store the Current Model or Load a Previously Computed Topic Model<span style='color: red; float: right'>OPTIONAL</span>

In [None]:
import pickle
import glob
import topic_model
import topic_model_utility

def get_persisted_model_paths(data_folder=None):
    """Return the sorted paths of all ``*.pickle`` files in a folder.

    Args:
        data_folder: folder to scan. Defaults to ``domain_logic.DATA_FOLDER``
            (the original referenced an undefined global ``DATA_FOLDER``).

    Returns:
        Sorted list of matching paths; empty when there are none.
    """
    if data_folder is None:
        data_folder = domain_logic.DATA_FOLDER
    return sorted(glob.glob(os.path.join(data_folder, '*.pickle')))

def get_store_filename(identifier, data_folder=None):
    """Build the file name used to store a topic model.

    Starts from ``<data_folder>/topic_model.pickle`` and decorates it with
    ``utility.path_add_date`` and ``utility.path_add_suffix(..., identifier)``.

    Args:
        identifier: user-supplied suffix distinguishing the stored model.
        data_folder: target folder. Defaults to ``domain_logic.DATA_FOLDER``
            (the original referenced an undefined global ``DATA_FOLDER``).

    Returns:
        The decorated file path.
    """
    if data_folder is None:
        data_folder = domain_logic.DATA_FOLDER
    filename = os.path.join(data_folder, 'topic_model.pickle')
    filename = utility.path_add_date(filename)
    return utility.path_add_suffix(filename, identifier)
    
def display_persist_topic_model_gui(state):
    """Display widgets to store the current topic model or load a stored one.

    Args:
        state: topic-model state container. *Store* writes ``state.data`` to a
            dated pickle file; *Load* replaces ``state.data`` and then shows the
            top tokens of ``state.topic_model``.
    """
    import types  # SimpleNamespace; `types` is not imported by the setup cell

    gui = types.SimpleNamespace(
        stored_path=widgets.Dropdown(description='Path', options=get_persisted_model_paths(), layout=widgets.Layout(width='40%')),
        load=widgets.Button(description='Load', button_style='Success', layout=widgets.Layout(width='80px')),
        store=widgets.Button(description='Store', button_style='Success', layout=widgets.Layout(width='80px')),
        identifier=widgets.Text(description='Identifier', layout=widgets.Layout(width='300px')),
        output=widgets.Output()
    )

    boxes = widgets.VBox([
        widgets.HBox([gui.stored_path, gui.load, gui.store, gui.identifier ]),
        widgets.HBox([
            widgets.Label(value="", layout=widgets.Layout(width='40%')),
            widgets.Label(value="Stored models will be named ./data/topic_model_yyyymmdd_$identifier$.pickle", layout=widgets.Layout(width='40%')),
        ]),
        widgets.VBox([gui.output])
    ])

    def load_handler(*args):
        """Load the selected file into ``state.data`` and display its topics."""

        # Clear messages from earlier clicks, as the store handler does.
        gui.output.clear_output()

        with gui.output:

            if gui.stored_path.value is None:
                print("Please specify which model to load.")
                return

            # NOTE(review): these are *.pickle files; unpickling can execute arbitrary
            # code, so only load models from trusted sources.
            state.data = topic_model.load_model(gui.stored_path.value)

            topics = topic_model_utility.get_lda_topics(state.topic_model, n_tokens=20)

            display(topics)

    def store_handler(*args):
        """Store ``state.data`` under a dated, identifier-suffixed file name."""

        gui.output.clear_output()

        with gui.output:

            if gui.identifier.value == '':
                print("Please specify a unique identifier for the model.")
                return

            # Reject identifiers that the whitelist would alter (unsafe file-name characters).
            if gui.identifier.value != utility.filename_whitelist(gui.identifier.value):
                print("Please use ONLY valid filename characters in identifier.")
                return

            filename = get_store_filename(gui.identifier.value)

            topic_model.store_model(state.data, filename)

            # Refresh the list so the new file can be selected for loading.
            gui.stored_path.options = get_persisted_model_paths()
            gui.stored_path.value = filename if filename in gui.stored_path.options else None

            print('Model stored in file {}'.format(filename))

    gui.load.on_click(load_handler)
    gui.store.on_click(store_handler)

    display(boxes)

display_persist_topic_model_gui(current_state())


## <span style='color: green;'>VISUALIZE</span> Display Topic's Word Distribution as a Wordcloud<span style='color: red; float: right'>TRY IT</span>

In [None]:
# Display LDA topic's token wordcloud
# Keyword arguments forwarded to wordcloud.WordCloud when the word cloud is drawn.
opts = dict(max_font_size=100, background_color='white', width=900, height=600)
import wordcloud
import matplotlib.pyplot as plt

def plot_wordcloud(df, token='token', weight='weight', figsize=(14, 14/1.618), **args):
    """Render a word cloud from a token/weight frame.

    Args:
        df: DataFrame holding a token column and a numeric weight column.
        token: name of the token column.
        weight: name of the weight column.
        figsize: matplotlib figure size (golden-ratio aspect by default).
        **args: forwarded to ``wordcloud.WordCloud``.

    Returns:
        None. When ``df`` is empty a message is printed instead of plotting.
    """
    if len(df) == 0:
        # WordCloud.fit_words raises on an empty mapping; report it instead.
        print('No tokens to display')
        return
    # Build the mapping in row order so it is deterministic. The previous
    # set-of-tuples construction picked an arbitrary weight for duplicated tokens.
    token_weights = dict(zip(df[token], df[weight]))
    image = wordcloud.WordCloud(**args)
    image.fit_words(token_weights)
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(image, interpolation='bilinear')
    ax.axis("off")
    plt.show()
    
def display_wordcloud(
    state,
    topic_id=0,
    n_words=100,
    output_format='Wordcloud',
    gui=None
):
    """Show a topic's top tokens as a word cloud or as a table.

    Args:
        state: topic-model state; reads ``num_topics`` and
            ``processed.topic_token_weights``.
        topic_id: topic to display.
        n_words: number of tokens shown.
        output_format: ``'Wordcloud'``, or anything else for a table.
        gui: widget namespace providing ``progress``, ``text``, ``topic_id``
            and ``n_topics``.
    """
    def tick(n=None):
        gui.progress.value = (gui.progress.value + 1) if n is None else n

    # The model may have been recomputed with a different number of topics.
    if gui.n_topics != state.num_topics:
        gui.n_topics = state.num_topics
        gui.topic_id.value = 0
        gui.topic_id.max = state.num_topics - 1

    tick(1)

    try:
        topic_token_weights = state.processed.topic_token_weights

        df = topic_token_weights.loc[(topic_token_weights.topic_id == topic_id)]

        tokens = topic_model_utility.get_topic_title(topic_token_weights, topic_id, n_tokens=n_words)
        gui.text.value = 'ID {}: {}'.format(topic_id, tokens)

        tick()

        if output_format == 'Wordcloud':
            plot_wordcloud(df, 'token', 'weight', max_words=n_words, **opts)
        else:
            tick()
            # Keyword is `n_tokens`, as in every other get_topic_tokens/get_topic_title call.
            df = topic_model_utility.get_topic_tokens(topic_token_weights, topic_id=topic_id, n_tokens=n_words)
            tick()
            display(df)
    except IndexError:
        print('No data for topic')
    finally:
        # Reset the progress bar on every exit path, including unexpected errors.
        tick(0)
    
def display_wordcloud_gui(state):
    """Build and display the word-cloud explorer for ``state``.

    A topic slider (with prev/next buttons), a word-count slider and a format
    dropdown drive ``display_wordcloud`` through ``widgets.interactive``.
    """
    # NOTE(review): `widgets_utility` and `widgets_config` are not imported in the
    # visible cells; confirm where they are expected to come from.
    formats = ['Wordcloud', 'Table']
    text_id = 'tx02'

    gui = widgets_utility.WidgetUtility(
        n_topics=state.num_topics,
        text_id=text_id,
        text=widgets_config.text(text_id),
        topic_id=widgets.IntSlider(
            description='Topic ID', min=0, max=state.num_topics - 1, step=1, value=0, continuous_update=False
        ),
        word_count=widgets.IntSlider(
            description='#Words', min=5, max=250, step=1, value=25, continuous_update=False
        ),
        output_format=widgets.Dropdown(
            description='Format', options=formats, value=formats[0], layout=widgets.Layout(width="200px")
        ),
        progress=widgets.IntProgress(min=0, max=4, step=1, value=0, layout=widgets.Layout(width="95%")),
    )

    gui.prev_topic_id = gui.create_prev_id_button('topic_id', state.num_topics)
    gui.next_topic_id = gui.create_next_id_button('topic_id', state.num_topics)

    interactive_view = widgets.interactive(
        display_wordcloud,
        state=widgets.fixed(state),
        topic_id=gui.topic_id,
        n_words=gui.word_count,
        output_format=gui.output_format,
        gui=widgets.fixed(gui),
    )

    toolbar = widgets.HBox([gui.prev_topic_id, gui.next_topic_id, gui.topic_id, gui.word_count, gui.output_format])
    output_area = interactive_view.children[-1]  # the interactive's output widget
    display(widgets.VBox([gui.text, toolbar, gui.progress, output_area]))

    # Render once for the initial widget values.
    interactive_view.update()

try:
    display_wordcloud_gui(current_state())
except topic_model_utility.TopicModelException as err:
    # Presumably raised when no topic model is available yet; log at info level only.
    logger.info(err)


## <span style='color: green'>EXPLORE </span> pyLDAvis <span style='float: right; color: red'>TRY IT</span>
pyLDAvis is a Python port of LDAvis, which is described in http://www.aclweb.org/anthology/W14-3110 (presented at the 2014 ACL Workshop on Interactive Language Learning, Visualization, and Interfaces in Baltimore on June 27, 2014).
https://github.com/bmabey/pyLDAvis

In [None]:
import pyLDAvis, pyLDAvis.gensim, pyLDAvis.sklearn
import gensim
pyLDAvis.enable_notebook()
def display_pyLDAvis(state):
    """Show an interactive pyLDAvis view of ``state.data.topic_model``.

    textacy models are unwrapped to their ``.model``, and gensim ``LdaMallet``
    models are converted with ``topic_model_utility.malletmodel2ldamodel``.
    Models whose type string mentions ``sklearn`` use ``pyLDAvis.sklearn``;
    all others use ``pyLDAvis.gensim``. Failures are logged, not raised.
    """
    try:
        wrapped = state.data.topic_model

        if isinstance(wrapped, textacy.tm.topic_model.TopicModel):
            model = wrapped.model
        elif isinstance(wrapped, gensim.models.wrappers.LdaMallet):
            model = topic_model_utility.malletmodel2ldamodel(wrapped)
        else:
            model = wrapped

        # NOTE(review): pyLDAvis.sklearn.prepare expects (model, doc-term matrix, vectorizer);
        # confirm that g_corpus / id2term satisfy that contract for sklearn models.
        backend = pyLDAvis.sklearn if 'sklearn' in str(type(model)) else pyLDAvis.gensim
        prepared = backend.prepare(model, state.data.g_corpus, state.data.id2term)

        display(prepared)
    except Exception as ex:
        logger.warning('This model cannot be visualized with pyLDAvis')
        logger.error(ex)

display_pyLDAvis(current_state())


## <span style='color: green;'>VISUALIZE</span> Display Topic's Word Distribution as a Chart<span style='color: red; float: right'>TRY IT</span>

FIXME: The number of topics specified when computing the model does not hold for all models. For some models `state.num_topics` is too high, which gives an error.


In [None]:
# Display topic's word distribution
import numpy as np
# Same DeprecationWarning filter as in the setup cell (kept so this cell can run on its own).
warnings.filterwarnings("ignore", category=DeprecationWarning) 

def plot_topic_word_distribution(tokens, **args):
    """Scatter-plot token weights and label each point with its token.

    ``tokens`` must provide ``xs``, ``ys`` and ``token`` columns; ``**args`` are
    passed to ``bokeh.plotting.figure``. Even rows get left-aligned labels offset
    by +5px and odd rows right-aligned labels offset by -5px, to reduce overlap.
    Returns the bokeh figure.
    """
    source = bokeh.models.ColumnDataSource(tokens)

    p = bokeh.plotting.figure(toolbar_location="right", **args)
    p.circle(x='xs', y='ys', source=source)

    label_style = dict(level='overlay', text_font_size='8pt', angle=np.pi/6.0)

    # (text alignment, offset sign) for even rows, then odd rows.
    for start, (align, sign) in enumerate([('left', 1), ('right', -1)]):
        label_source = bokeh.models.ColumnDataSource(tokens.iloc[start::2])
        p.add_layout(bokeh.models.LabelSet(
            x='xs', y='ys', text='token',
            text_align=align, text_baseline='middle',
            x_offset=5 * sign, y_offset=5 * sign,
            source=label_source, **label_style
        ))

    p.xaxis[0].axis_label = 'Token #'
    p.yaxis[0].axis_label = 'Probability%'
    for grid in (p.ygrid, p.xgrid):
        grid.grid_line_color = None
    p.axis.axis_line_color = None
    p.axis.major_tick_line_color = None
    p.axis.major_label_text_font_size = "6pt"
    p.axis.major_label_standoff = 0
    return p

def display_topic_tokens(state, topic_id=0, n_words=100, output_format='Chart', gui=None):
    """Show a topic's ``n_words`` heaviest tokens as a bokeh chart or a table.

    Weights are scaled by 100 (percent) and sorted in descending order.

    Args:
        state: topic-model state; reads ``num_topics`` and
            ``processed.topic_token_weights``.
        topic_id: topic to display.
        n_words: number of tokens.
        output_format: ``'Chart'``, or anything else for a table.
        gui: widget namespace providing ``progress``, ``topic_id`` and ``n_topics``.
    """
    def tick(n=None):
        gui.progress.value = (gui.progress.value + 1) if n is None else n

    # The model may have been recomputed with a different number of topics.
    if gui.n_topics != state.num_topics:
        gui.n_topics = state.num_topics
        gui.topic_id.value = 0
        gui.topic_id.max = state.num_topics - 1

    tick(1)

    try:
        tokens = (
            topic_model_utility.get_topic_tokens(state.processed.topic_token_weights, topic_id=topic_id, n_tokens=n_words)
            .drop('topic_id', axis=1)
            .assign(weight=lambda x: 100.0 * x.weight)
            .sort_values('weight', axis=0, ascending=False)
            .reset_index()
            .head(n_words)
        )

        if output_format == 'Chart':
            tick()
            # Token rank (0..n-1) on x, weight in percent on y.
            tokens = tokens.assign(xs=tokens.index, ys=tokens.weight)
            p = plot_topic_word_distribution(tokens, plot_width=1200, plot_height=500, title='', tools='box_zoom,wheel_zoom,pan,reset')
            bokeh.plotting.show(p)
            tick()
        else:
            display(tokens)
    finally:
        # Reset the progress bar even when an error escapes.
        tick(0)
    
def display_topic_distribution_gui(state):
    """Build and display the topic word-distribution explorer for ``state``.

    A topic slider (with prev/next buttons), a word-count slider and a format
    dropdown drive ``display_topic_tokens`` through ``widgets.interactive``.
    """
    # NOTE(review): `widgets_utility` and `widgets_config` are not imported in the
    # visible cells; confirm where they are expected to come from.
    text_id = 'wc01'
    output_options = ['Chart', 'Table']

    gui = widgets_utility.WidgetUtility(
        n_topics=state.num_topics,
        text_id=text_id,
        text=widgets_config.text(text_id),
        # continuous_update=False, as in the word-cloud GUI: re-render only when the
        # slider is released instead of on every intermediate value while dragging.
        topic_id=widgets.IntSlider(description='Topic ID', min=0, max=state.num_topics - 1, step=1, value=0, continuous_update=False),
        n_words=widgets.IntSlider(description='#Words', min=5, max=500, step=1, value=75, continuous_update=False),
        output_format=widgets.Dropdown(description='Format', options=output_options, value=output_options[0], layout=widgets.Layout(width="200px")),
        progress=widgets.IntProgress(min=0, max=4, step=1, value=0, layout=widgets.Layout(width="95%"))
    )

    gui.prev_topic_id = gui.create_prev_id_button('topic_id', state.num_topics)
    gui.next_topic_id = gui.create_next_id_button('topic_id', state.num_topics)

    iw = widgets.interactive(
        display_topic_tokens,
        state=widgets.fixed(state),
        topic_id=gui.topic_id,
        n_words=gui.n_words,
        output_format=gui.output_format,
        gui=widgets.fixed(gui)
    )

    display(widgets.VBox([
        gui.text,
        widgets.HBox([gui.prev_topic_id, gui.next_topic_id, gui.topic_id, gui.n_words, gui.output_format]),
        gui.progress,
        iw.children[-1]  # the interactive's output widget
    ]))

    # Render once for the initial widget values.
    iw.update()

try:
    display_topic_distribution_gui(current_state())
except Exception as err:
    # Log instead of raising so the rest of the notebook can continue.
    logger.error(err)
    


## <span style='color: green;'>VISUALIZE</span> Display Topic's Trend Over Time or Documents<span style='color: red; float: right'>TRY IT</span>
- Displays a topic's share over time (aggregated per year), or over the individual documents of a selected year.


In [None]:
# Plot a topic's yearly weight over time in selected LDA topic model
import math

def plot_topic_trend(df, category_column, value_column, x_label=None, y_label=None, **figopts):
    """Bar chart of ``value_column`` for each value of ``category_column``.

    Args:
        df: source DataFrame.
        category_column: column used as (string) categories on the x axis.
        value_column: column plotted as bar heights.
        x_label, y_label: axis labels; default to the title-cased column names.
        **figopts: options for ``bokeh.plotting.figure``, merged via
            ``utility.extend`` with defaults ``title=''`` and
            ``toolbar_location='right'``.

    Returns:
        The bokeh figure.
    """
    # `np.str` was just an alias of the builtin `str` and was removed in NumPy 1.24.
    xs = df[category_column].astype(str)
    ys = df[value_column]

    figopts = utility.extend(dict(title='', toolbar_location="right"), figopts)

    p = bokeh.plotting.figure(**figopts)

    p.vbar(x=xs, top=ys, width=0.5, fill_color="#b3de69")

    p.xaxis.major_label_orientation = math.pi/4
    p.xgrid.grid_line_color = None
    p.xaxis[0].axis_label = (x_label or category_column.title().replace('_', ' ')).title()
    p.yaxis[0].axis_label = (y_label or value_column.title().replace('_', ' ')).title()
    p.y_range.start = 0.0
    p.x_range.range_padding = 0.01

    return p

def display_topic_trend(
    state,
    topic_id,
    year,
    year_aggregate,
    threshold=0.01,
    output_format='Chart',
    topic_changed=utility.noop
):
    """Show a topic's weight over the years, or over the documents of one year.

    Args:
        state: topic-model state; reads ``processed.document_topic_weights``.
        topic_id: topic to display.
        year: ``None`` to aggregate per year, otherwise a single year whose
            documents are shown individually.
        year_aggregate: ``'mean'`` or ``'max'``; the per-year statistic plotted
            when ``year`` is ``None``.
        threshold: only document weights strictly above this value are used.
        output_format: ``'Table'`` displays the frame; anything else plots a chart.
        topic_changed: callback invoked with ``topic_id`` before computing.
    """
    figopts = dict(plot_width=1000, plot_height=700, title='', toolbar_location="right")

    document_topic_weights = state.processed.document_topic_weights

    topic_changed(topic_id)

    # FIXME VARYING ASPECT: name 'signed_year'
    year_column = 'year'

    value_column = year_aggregate if year is None else 'weight'

    df = document_topic_weights[(document_topic_weights.topic_id == topic_id)]
    # FIXME MISSING YEAR IN FILENAME HACK: drop documents without a positive year
    df = df[(df[year_column] > 0)]

    if year is not None:
        df = df[(df[year_column] == year)]

    df = df[(df.weight > threshold)].reset_index()

    if len(df) == 0:
        print('NO DATA')
        return

    if year is None:

        # int(): range() rejects float years.
        min_year, max_year = int(df[year_column].min()), int(df[year_column].max())
        # Every year in the span becomes a category, including years without data.
        figopts['x_range'] = [str(y) for y in range(min_year, max_year + 1)]

        # Aggregate only the weight column: aggregating every column (as before)
        # fails on non-numeric columns such as file names in recent pandas.
        df = df.groupby([year_column, 'topic_id'])['weight'].agg(['mean', 'max']).reset_index()
        df.columns = [year_column, 'topic_id', 'mean', 'max']
        category_column = year_column

    else:
        # FIXME: Varying ASPECTS (the document label is the file name)
        category_column = 'document_name'
        df[category_column] = df.filename
        # Categories as strings, matching the str-converted x values in plot_topic_trend.
        figopts['x_range'] = [str(x) for x in df[category_column].unique()]

    if output_format == 'Table':
        display(df)
    else:
        p = plot_topic_trend(df, category_column, value_column, **figopts)
        bokeh.plotting.show(p)

def display_topic_trend_gui(state):
    """Build and display the topic-trend explorer for ``state``.

    Topic, year, aggregate, threshold and format widgets drive
    ``display_topic_trend`` through ``widgets.interactive``. The text widget
    shows the selected topic's title.
    """
    # NOTE(review): `widgets_utility` / `widgets_config` are not imported in the visible
    # cells, and the progress bar created here is never advanced by display_topic_trend.
    year_low, year_high = int(state.processed.year_period[0]), int(state.processed.year_period[1])
    year_options = [ ('all years', None) ] + [ (str(x), x) for x in range(year_low, year_high + 1)]

    text_id = 'topic_share_plot'

    gui = widgets_utility.WidgetUtility(
        n_topics=state.num_topics,
        text_id=text_id,
        text=widgets_config.text(text_id),
        year=widgets.Dropdown(description='Year', options=year_options, value=None),
        year_aggregate=widgets.Dropdown(description='Aggregate', options=['mean', 'max'], value='max', layout=widgets.Layout(width="160px")),
        threshold=widgets.FloatSlider(description='Threshold', min=0.0, max=0.25, step=0.01, value=0.10, continuous_update=False),
        topic_id=widgets.IntSlider(description='Topic ID', min=0, max=state.num_topics - 1, step=1, value=0, continuous_update=False),
        output_format=widgets.Dropdown(description='Format', options=['Chart', 'Table'], value='Chart', layout=widgets.Layout(width="160px")),
        progress=widgets.IntProgress(min=0, max=4, step=1, value=0, layout=widgets.Layout(width="340px")),
    )

    gui.prev_topic_id = gui.create_prev_id_button('topic_id', state.num_topics)
    gui.next_topic_id = gui.create_next_id_button('topic_id', state.num_topics)

    def on_topic_changed(topic_id):
        """Sync the topic slider with the model size and show the topic title."""
        try:
            if gui.n_topics != state.num_topics:
                gui.n_topics = state.num_topics
                gui.topic_id.value = 0
                gui.topic_id.max = state.num_topics - 1

            tokens = topic_model_utility.get_topic_title(state.processed.topic_token_weights, topic_id, n_tokens=200)
            gui.text.value = 'ID {}: {}'.format(topic_id, tokens)
        except Exception:
            # Previously a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
            gui.text.value = 'ID {}: NO DATA'.format(topic_id)

    iw = widgets.interactive(
        display_topic_trend,
        state=widgets.fixed(state),
        topic_id=gui.topic_id,
        year=gui.year,
        year_aggregate=gui.year_aggregate,
        threshold=gui.threshold,
        output_format=gui.output_format,
        topic_changed=widgets.fixed(on_topic_changed)
    )

    display(widgets.VBox([
        gui.text,
        widgets.HBox([gui.prev_topic_id, gui.next_topic_id, gui.year, gui.year_aggregate, gui.output_format]),
        widgets.HBox([gui.topic_id, gui.threshold, gui.progress]),
        iw.children[-1]  # the interactive's output widget
    ]))

    # Render once for the initial widget values.
    iw.update()

try:
    display_topic_trend_gui(current_state())
except Exception as err:
    # Log instead of raising so the rest of the notebook can continue.
    logger.error(err)

## <span style='color: green;'>VISUALIZE</span> Display Topic to Document Network<span style='color: red; float: right'>TRY IT</span>
The green nodes are documents, and the blue nodes are topics. The edges (lines) indicate the strength of a topic in the connected document. The width of an edge is proportional to the strength of the connection. Note that only edges with a strength above a certain threshold are displayed.

In [None]:
# Visualize year-to-topic network by means of topic-document-weights
from common.plot_utility import layout_algorithms, PlotNetworkUtility
import domain_logic_vatican as domain_logic
import gui_utility
from common.network_utility import NetworkUtility, DISTANCE_METRICS, NetworkMetricHelper

def plot_document_topic_network(network, layout, scale=1.0, titles=None):
    """Plot a bipartite document-topic network with bokeh.

    Args:
        network: bipartite graph. The first node set returned by
            ``NetworkUtility.get_bipartite_node_set(network, bipartite=0)`` is
            drawn as large green nodes (documents), the second as blue topic nodes.
        layout: node positions, e.g. from one of ``layout_algorithms``.
        scale: currently unused; kept for interface compatibility.
        titles: topic titles indexed by topic id, used by the topic hover
            callback. When ``None`` no hover tool is attached (previously this
            default crashed on ``titles.index``).

    Returns:
        The bokeh figure.
    """
    tools = "pan,wheel_zoom,box_zoom,reset,hover,previewsave"
    document_nodes, topic_nodes = NetworkUtility.get_bipartite_node_set(network, bipartite=0)

    document_source = NetworkUtility.get_node_subset_source(network, layout, document_nodes)
    topic_source = NetworkUtility.get_node_subset_source(network, layout, topic_nodes)
    lines_source = NetworkUtility.get_edges_source(network, layout, scale=6.0, normalize=False)

    # One alpha value per edge, computed from the edge weights.
    edges_alphas = NetworkMetricHelper.compute_alpha_vector(lines_source.data['weights'])

    lines_source.add(edges_alphas, 'alphas')

    p = bokeh.plotting.figure(plot_width=1000, plot_height=600, x_axis_type=None, y_axis_type=None, tools=tools)

    p.multi_line(
        'xs', 'ys', line_width='weights', alpha='alphas', color='black', source=lines_source
    )
    p.circle(
        'x', 'y', size=40, source=document_source, color='lightgreen', level='overlay', line_width=1, alpha=1.0
    )

    r_topics = p.circle('x', 'y', size=25, source=topic_source, color='skyblue', level='overlay', alpha=1.00)

    if titles is not None:
        # Hover callback on topic nodes; presumably writes the topic title into the
        # element 'nx_id1' (the text widget id used by document_topic_network_gui).
        p.add_tools(bokeh.models.HoverTool(
            renderers=[r_topics],
            tooltips=None,
            callback=widgets_utility.wf.glyph_hover_callback(
                topic_source, 'node_id', text_ids=titles.index, text=titles, element_id='nx_id1'
            )
        ))

    text_opts = dict(x='x', y='y', text='name', level='overlay', x_offset=0, y_offset=0, text_font_size='8pt')

    p.add_layout(
        bokeh.models.LabelSet(
            source=document_source, text_color='black', text_align='center', text_baseline='middle', **text_opts
        )
    )
    p.add_layout(
        bokeh.models.LabelSet(
            source=topic_source, text_color='black', text_align='center', text_baseline='middle', **text_opts
        )
    )

    return p
        
def display_document_topic_network(
    layout_algorithm,
    state,
    threshold=0.10,
    document_filters=None,
    ignores=None,
    scale=1.0,
    output_format='network',
    document_index=None,
    tick=utility.noop
):
    """Compute and show the document-topic network, or its edge table.

    Keeps document-topic weights above ``threshold`` for documents that pass
    ``document_filters`` and whose topic is not in ``ignores``. Weights are
    passed through ``utility.clamp_values(..., (0.1, 2.0))``, then a bipartite
    network between document titles (file names) and topic ids is built.

    Args:
        layout_algorithm: key into ``layout_algorithms``.
        state: topic-model state; reads ``processed.topic_token_weights`` and
            ``processed.document_topic_weights``.
        threshold: exclusive lower bound on document-topic weight.
        document_filters: ``(field, value)`` pairs for
            ``gui_utility.get_documents_by_field_filters``.
        ignores: topic ids to leave out.
        scale: passed to the layout argument helper and to the plot.
        output_format: ``'network'`` to plot, ``'table'`` to display the edges.
        document_index: document index used by the field filters.
        tick: progress callback, called as ``tick(1)``, ``tick()`` and ``tick(0)``.
    """
    tick(1)

    try:
        corpus = current_corpus()

        # Documents passing the field filters, keyed by their document id.
        corpus_docs = {
            doc._.meta['document_id']: doc
            for doc in gui_utility.get_documents_by_field_filters(corpus, document_index, document_filters)
        }

        topic_token_weights = state.processed.topic_token_weights
        document_topic_weights = state.processed.document_topic_weights

        titles = topic_model_utility.get_topic_titles(topic_token_weights)

        df = document_topic_weights[document_topic_weights.weight > threshold].reset_index()

        df = df[df.document_id.isin(list(corpus_docs.keys()))]
        # FIXME VARYING ASPECT: domain-specific filters (e.g. parties, period) are not applied here.

        if len(ignores or []) > 0:
            df = df[~df.topic_id.isin(ignores)]

        if len(df) == 0:
            print('No data')
            return

        # Use assign rather than item assignment on a filtered slice
        # (which triggers SettingWithCopyWarning).
        # FIXME VARYING ASPECT: the document title is the file name.
        df = df.assign(
            weight=utility.clamp_values(list(df.weight), (0.1, 2.0)),
            title=df.filename,
        )

        network = NetworkUtility.create_bipartite_network(df, 'title', 'topic_id')
        tick()

        if output_format == 'network':
            args = PlotNetworkUtility.layout_args(layout_algorithm, network, scale)
            layout = (layout_algorithms[layout_algorithm])(network, **args)
            tick()
            p = plot_document_topic_network(network, layout, scale=scale, titles=titles)
            bokeh.plotting.show(p)

        elif output_format == 'table':
            display(df)
    finally:
        # Reset progress on every exit path (the 'No data' return used to leave it at 1).
        tick(0)
        
def document_topic_network_gui(document_index, state, filter_options):
    """Display the document-topic network controls and wire the Compute button.

    Args:
        document_index: document index used by the generated field filters and
            passed on to ``display_document_topic_network``.
        state: topic-model state container.
        filter_options: filter specification for
            ``gui_utility.generate_field_filters``; each generated filter is
            used as a dict with ``'field'`` and ``'widget'`` keys.
    """
    import types  # SimpleNamespace; `types` is not imported by the setup cell

    # NOTE(review): `widgets_utility` and `widgets_config` are not imported in the
    # visible cells; confirm where they are expected to come from.

    def lw(width):
        return widgets.Layout(width=width)

    text_id = 'nx_id1'
    layout_options = [ 'Circular', 'Kamada-Kawai', 'Fruchterman-Reingold']

    n_topics = state.num_topics
    document_filters = gui_utility.generate_field_filters(document_index, filter_options)
    gui = types.SimpleNamespace(
        document_filters=document_filters,
        text=widgets_config.text(text_id),
        # `continuous_update` was misspelled `continues_update`, so the setting had no effect.
        scale=widgets.FloatSlider(description='Scale', min=0.0, max=1.0, step=0.01, value=0.1, continuous_update=False),
        threshold=widgets.FloatSlider(description='Threshold', min=0.0, max=1.0, step=0.01, value=0.50, continuous_update=False),
        output_format=widgets_utility.dropdown('Output', { 'Network': 'network', 'Table': 'table' }, 'network', layout=lw('200px')),
        layout=widgets_utility.dropdown('Layout', layout_options, 'Fruchterman-Reingold', layout=lw('250px')),
        progress=widgets.IntProgress(min=0, max=4, step=1, value=0, layout=widgets.Layout(width="95%")),
        ignores=widgets.SelectMultiple(description='Ignore', options=[('', None)] + [ ('Topic #'+str(i), i) for i in range(0, n_topics) ], value=[], rows=8, layout=lw('200px')),
        compute=widgets.Button(description='Compute', layout=lw('120px')),
        output=widgets.Output(layout={'border': '1px solid black'})
    )

    def tick(x=None):
        """Advance the progress bar by one, or set it to ``x``."""
        gui.progress.value = gui.progress.value + 1 if x is None else x

    def compute_callback_handler(*_args):
        """Recompute the network into the output area using the current widget values."""
        gui.output.clear_output()
        with gui.output:
            display_document_topic_network(
                layout_algorithm=gui.layout.value,
                state=state,
                threshold=gui.threshold.value,
                document_filters=[ (x['field'], x['widget'].value) for x in gui.document_filters],
                ignores=gui.ignores.value,
                scale=gui.scale.value,
                output_format=gui.output_format.value,
                document_index=document_index,
                tick=tick
            )

    display(widgets.VBox([
        widgets.HBox([
            widgets.VBox([gui.layout, gui.threshold, gui.scale ]),
            widgets.VBox([ x['widget'] for x in gui.document_filters]),
            widgets.VBox([gui.ignores, gui.output_format]),
            widgets.VBox([gui.compute, gui.progress]),
        ]),
        gui.output,
        gui.text,
    ]))

    gui.compute.on_click(compute_callback_handler)

try:
    # NOTE: `domain_logic` here is the module imported by this cell (domain_logic_vatican).
    document_index = domain_logic.compile_documents(current_corpus())
    document_topic_network_gui(document_index, current_state(), filter_options=domain_logic.DOCUMENT_FILTERS)
except Exception as err:
    # Log instead of raising so the rest of the notebook can continue.
    logger.error(err)


## <span style='color: green;'>VISUALIZE</span> Topic Trends Overview<span style='color: red; float: right'>TRY IT</span>

- Topic shares are displayed as a scattered heatmap plot, using a color gradient based on the topic's weight in the document.
- [Stanford’s Termite software](http://vis.stanford.edu/papers/termite) uses a similar visualization.

In [None]:
# plot_topic_relevance_by_year
import bokeh.transform

def get_topic_weight_by_year_or_document(document_topic_weights, key='mean', year=None, year_column='year'):
    """Aggregate topic weights per year, or per document within a single year.

    Args:
        document_topic_weights: DataFrame with ``topic_id``, ``weight``,
            ``document_id`` and a year column.
        key: pandas aggregation name applied to ``weight`` (e.g. ``'mean'``, ``'max'``).
        year: ``None`` aggregates per year; otherwise only rows of that year
            are kept and weights are aggregated per document.
        year_column: name of the year column.

    Returns:
        ``(df, pivot_column)``, where ``df`` has the columns
        ``[pivot_column, 'topic_id', 'weight']``.
    """
    # The previous body called `self.get_document_topic_weights` and `config.AGGREGATES`,
    # neither of which exists at notebook level, and ignored `document_topic_weights`.
    # NOTE(review): assumes config.AGGREGATES[key] was the pandas aggregation named `key`.
    pivot_column = year_column if year is None else 'document_id'
    df = document_topic_weights
    if year is not None:
        df = df[df[year_column] == year]
    df = df.groupby([pivot_column, 'topic_id'])['weight'].agg(key).reset_index()
    return df, pivot_column
    
def setup_glyph_coloring(df):
    """Create a white-to-dark-green color mapping for ``weight`` and a matching color bar.

    ``df`` must have a ``weight`` column. The mapper range is fixed to [0, 1]
    regardless of the data. Returns ``(color_transform, color_bar)``.
    """
    # Evaluated but unused; see the commented `high=max_weight` alternative in history.
    _unused_max_weight = df.weight.max()

    # White followed by (presumably) bokeh's Greens[9] in light-to-dark order.
    palette = [
        '#ffffff', '#f7fcf5', '#e5f5e0', '#c7e9c0', '#a1d99b',
        '#74c476', '#41ab5d', '#238b45', '#006d2c', '#00441b',
    ]
    mapper = bokeh.models.LinearColorMapper(palette=palette, low=0.0, high=1.0)
    color_transform = bokeh.transform.transform('weight', mapper)
    color_bar = bokeh.models.ColorBar(
        color_mapper=mapper,
        location=(0, 0),
        ticker=bokeh.models.BasicTicker(desired_num_ticks=len(palette)),
        formatter=bokeh.models.PrintfTickFormatter(format=" %5.2f"),
    )
    return color_transform, color_bar

def compute_int_range_categories(values):
    """Return the distinct entries of ``values`` as sorted categories.

    If ``utility.isint`` accepts every distinct value, they are sorted
    numerically and returned as strings; otherwise they are sorted as-is.
    """
    distinct = values.unique()
    if not all(utility.isint(v) for v in distinct):
        return sorted(distinct)
    return [str(v) for v in sorted(int(v) for v in distinct)]

# bokeh.plotting.figure options for the topic heatmap.
HEATMAP_FIGOPTS = {
    'title': "Topic heatmap",
    'toolbar_location': "right",
    'x_axis_location': "above",
    'plot_width': 1000,
}

def plot_topic_relevance_by_year(df, xs, ys, flip_axis, titles, text_id, **figopts):
    """Draw a categorical heatmap of topic weights with bokeh.

    Args:
        df: frame with the `xs` and `ys` category columns and a `weight` column used for colour.
        xs, ys: names of the x and y category columns; swapped when `flip_axis is True`.
        flip_axis: swap the axes (only the literal `True` triggers the swap).
        titles: topic titles handed to the hover callback together with `titles.index`.
        text_id: id of the text element the hover callback writes into.
        **figopts: extra keyword arguments for `bokeh.plotting.figure`.

    Returns:
        The bokeh figure.

    NOTE(review): `widgets_utility` is not defined in the setup cell; presumably imported elsewhere.
    """
    if flip_axis is True:
        xs, ys = ys, xs
        row_height = 10
    else:
        row_height = 7

    x_factors = compute_int_range_categories(df[xs])
    y_factors = compute_int_range_categories(df[ys])

    color_transform, color_bar = setup_glyph_coloring(df)

    source = bokeh.models.ColumnDataSource(df)

    if x_factors is not None:
        figopts['x_range'] = x_factors

    if y_factors is not None:
        figopts['y_range'] = y_factors
        # Grow the plot with the number of rows, but never shrink below 500 px.
        figopts['plot_height'] = max(len(y_factors) * row_height, 500)

    p = bokeh.plotting.figure(**figopts)

    rect_renderer = p.rect(
        x=xs,
        y=ys,
        source=source,
        alpha=1.0,
        hover_color='red',
        width=1,
        height=1,
        line_color=None,
        fill_color=color_transform,
    )

    # Minimal, grid-free styling.
    p.x_range.range_padding = 0
    p.ygrid.grid_line_color = None
    p.xgrid.grid_line_color = None
    p.axis.axis_line_color = None
    p.axis.major_tick_line_color = None
    p.axis.major_label_text_font_size = "8pt"
    p.axis.major_label_standoff = 0
    p.xaxis.major_label_orientation = 1.0
    p.add_layout(color_bar, 'right')

    hover_callback = widgets_utility.WidgetUtility.glyph_hover_callback(
        source, 'topic_id', titles.index, titles, text_id
    )
    p.add_tools(bokeh.models.HoverTool(tooltips=None, callback=hover_callback, renderers=[rect_renderer]))

    return p

def _doc_topic_heatmap_frame(document_topic_weights, year=None, year_aggregate=None):
    """Shape document-topic weights for the heatmap: a per-year aggregate, or per treaty within one year.

    Args:
        document_topic_weights: frame with `signed_year`, `topic_id` and `weight` columns; the
            per-year view also needs `treaty_id`, `party1` and `party2`.
        year: None for the all-years aggregate, otherwise the `signed_year` to show.
        year_aggregate: 'mean' or 'max' (None falls back to 'max'); only used when `year` is None.

    Returns:
        tuple: (DataFrame with string `topic_id`/`document_id` and a `weight` column, category column name)
    """
    if year is None:
        # Aggregate only `weight`; aggregating every column (including text) fails in recent pandas.
        df = document_topic_weights.groupby(['signed_year', 'topic_id'])['weight'].agg(['mean', 'max']).reset_index()
        df['weight'] = df[year_aggregate or 'max']
        df['signed_year'] = df.signed_year.astype(str)
        category_column = 'signed_year'
    else:
        # Individual treaties for the selected year; copy to avoid writing into a filtered slice.
        df = document_topic_weights[document_topic_weights.signed_year == year].copy()
        df['treaty'] = df.treaty_id + ' ' + df.party1 + ' ' + df.party2
        df = df[['treaty', 'treaty_id', 'topic_id', 'weight']].copy()
        category_column = 'treaty'

    df['document_id'] = df.index.astype(str)
    df['topic_id'] = df.topic_id.astype(str)
    return df, category_column


def display_doc_topic_heatmap(state, key='max', flip_axis=False, glyph='Circle', year=None, year_aggregate=None, output_format=None):
    """Render the document-topic heatmap (or its table) for the topic model held in `state`.

    Args:
        state: container exposing `processed.topic_token_weights` and `processed.document_topic_weights`.
        key, glyph: accepted for interface compatibility; unused.
        flip_axis: swap the heatmap axes.
        year: None to aggregate over all years, otherwise show individual treaties of that year.
        year_aggregate: 'mean' or 'max' aggregate for the all-years view (None -> 'max').
        output_format: 'Heatmap' (the default when None) or anything else to display the table.

    Raises:
        Any exception from the computation, after it has been logged.
    """
    try:
        titles = topic_model_utility.get_topic_titles(state.processed.topic_token_weights, n_tokens=100)

        df, category_column = _doc_topic_heatmap_frame(
            state.processed.document_topic_weights, year=year, year_aggregate=year_aggregate
        )

        if (output_format or 'heatmap').lower() == 'heatmap':
            p = plot_topic_relevance_by_year(
                df,
                xs=category_column,
                ys='topic_id',
                flip_axis=flip_axis,
                titles=titles,
                text_id='topic_relevance',
                **HEATMAP_FIGOPTS)
            bokeh.plotting.show(p)
        else:
            display(df)

    except Exception as ex:
        # Log first so the error is recorded, then surface the traceback in the widget output.
        logger.error(ex)
        raise
        
def doc_topic_heatmap_gui(state):
    """Build and display the topic heatmap GUI (year, aggregate, output format, axis flip) for `state`.

    Args:
        state: topic-model container; `processed.year_period` gives the (min, max) year range.
    """
    import types  # the setup cell does not import `types`; needed for SimpleNamespace

    lw = lambda w: widgets.Layout(width=w)

    text_id = 'topic_relevance'

    year_min, year_max = state.processed.year_period
    year_options = [('all years', None)] + [(x, x) for x in range(year_min, year_max + 1)]

    gui = types.SimpleNamespace(
        text_id=text_id,
        text=widgets_config.text(text_id),
        flip_axis=widgets.ToggleButton(value=True, description='Flip', icon='', layout=lw("80px")),
        year=widgets.Dropdown(description='Year', options=year_options, value=None, layout=lw("160px")),
        year_aggregate=widgets.Dropdown(description='Aggregate', options=['mean', 'max'], value='max', layout=lw("160px")),
        output_format=widgets.Dropdown(description='Output', options=['Heatmap', 'Table'], value='Heatmap', layout=lw("180px"))
    )

    iw = widgets.interactive(
        display_doc_topic_heatmap,
        state=widgets.fixed(state),
        flip_axis=gui.flip_axis,
        year=gui.year,
        year_aggregate=gui.year_aggregate,
        output_format=gui.output_format
    )

    display(widgets.VBox([
        widgets.HBox([gui.year, gui.year_aggregate, gui.output_format, gui.flip_axis]),
        widgets.HBox([iw.children[-1]]), gui.text
    ]))

    # Render once with the initial widget values.
    iw.update()

# Show the heatmap GUI for the current topic model; failures are logged instead of raised.
try:
    doc_topic_heatmap_gui(current_state())
except Exception as err:
    logger.error(err)


## <span style='color: green;'>VISUALIZE</span> Topic Cooccurrence<span style='color: red; float: right'>TRY IT</span>

Computes a weighted graph of topics co-occurring in the same document. Two topics co-occur if they both appear in the same document with weights above the threshold. Edge weights are the number of co-occurrences (each document counts as a binary yes or no). Node size reflects topic proportions over the entire corpus (normalized by document length), computed the same way LDAvis computes node sizes.

In [None]:
# Visualize topic co-occurrence

import common.plot_utility as plot_utility
import common.network_utility as network_utility
import bokeh.plotting # import figure, show, output_notebook, output_file

# Route bokeh output to this notebook (also done in the setup cell; harmless to repeat here).
bokeh.plotting.output_notebook()

def get_topic_titles(topic_token_weights, topic_id=None, n_words=100):
    """Build a title per topic by joining its `n_words` highest-weighted tokens, title-cased.

    Args:
        topic_token_weights: frame with `topic_id`, `token` and `weight` columns.
        topic_id: if given, only that topic's title is computed.
        n_words: number of top tokens included in each title.

    Returns:
        pd.Series of space-joined titles indexed by `topic_id`.
    """
    if topic_id is not None:
        topic_token_weights = topic_token_weights[topic_token_weights.topic_id == topic_id]

    def _join_top_tokens(group):
        return ' '.join(group.token[:n_words].str.title())

    ranked = topic_token_weights.sort_values('weight', ascending=False)
    return ranked.groupby('topic_id').apply(_join_top_tokens)

# FIXME: add doc token length to df_documents
def get_topic_proportions(corpus_documents, document_topic_weights):
    """Return topic proportions over the corpus via `topic_model.compute_topic_proportions`.

    Note that the helper takes the weights first and the documents second.
    """
    return topic_model.compute_topic_proportions(document_topic_weights, corpus_documents)
    
def _topic_cooccurrence_edges(document_topic_weights, parties=None, period=None, ignores=None, threshold=0.10):
    """Count, per topic pair, the documents in which both topics have weight >= `threshold`.

    Args:
        document_topic_weights: frame with `document_id`, `topic_id`, `weight`, `signed_year`,
            `party1` and `party2` columns.
        parties: keep only rows where `party1` or `party2` is in this collection (ignored if empty/None).
        period: inclusive (first_year, last_year) filter on `signed_year`, or None.
        ignores: topic ids to drop, or None.
        threshold: minimum topic weight for a topic to count as present in a document.

    Returns:
        DataFrame with columns ['source', 'target', 'weight'], with source < target.
    """
    df = document_topic_weights

    if ignores is not None:
        df = df[~df.topic_id.isin(ignores)]

    if len(parties or []) > 0:
        df = df[df.party1.isin(parties) | df.party2.isin(parties)]

    if period is not None:
        # Inclusive range; Series.between(..., inclusive=True) is rejected by pandas >= 2.0.
        df = df[(df.signed_year >= period[0]) & (df.signed_year <= period[1])]

    df = df.loc[df.weight >= threshold, ['document_id', 'topic_id']]

    # Self-join on document to enumerate topic pairs; keep each unordered pair once.
    pairs = pd.merge(df, df, how='inner', on='document_id')
    pairs = pairs.loc[pairs.topic_id_x < pairs.topic_id_y]
    edges = pairs.groupby(['topic_id_x', 'topic_id_y']).size().reset_index()
    edges.columns = ['source', 'target', 'weight']
    return edges


def display_topic_co_occurrence_network(
    tm_data,
    parties=None,
    period=None,
    ignores=None,
    threshold=0.10,
    layout='Fruchterman-Reingold',
    scale=1.0,
    output_format='table'
):
    """Show the topic co-occurrence network (or its edge table) for the model in `tm_data`.

    Args:
        tm_data: container exposing `compiled_data` with `topic_token_weights`,
            `document_topic_weights` and `documents`.
        parties, period, ignores, threshold: row filters; see `_topic_cooccurrence_edges`.
        layout: network layout algorithm name passed to the plotting helper.
        scale: layout scale.
        output_format: 'table' shows the edge list; anything else plots the network.

    Raises:
        Any exception from the computation, after it has been logged.
    """
    try:
        model_data = tm_data.compiled_data

        titles = topic_model_utility.get_topic_titles(model_data.topic_token_weights)

        # Add document_id on a copy: the source frame is shared model state and must not be
        # mutated on every widget redraw.
        document_topic_weights = model_data.document_topic_weights.assign(
            document_id=model_data.document_topic_weights.index
        )

        node_sizes = topic_model.compute_topic_proportions(document_topic_weights, model_data.documents)

        df = _topic_cooccurrence_edges(
            document_topic_weights, parties=parties, period=period, ignores=ignores, threshold=threshold
        )

        if len(df) == 0:
            print('No data. Please change selections.')
            return

        if output_format == 'table':
            display(df)
        else:
            network = network_utility.NetworkUtility.create_network(df, source_field='source', target_field='target', weight='weight')
            p = plot_utility.PlotNetworkUtility.plot_network(
                network=network,
                layout_algorithm=layout,
                scale=scale,
                threshold=0.0,
                node_description=titles,
                node_proportions=node_sizes,
                weight_scale=10.0,
                normalize_weights=True,
                element_id='cooc_id',
                figsize=(900, 500)
            )
            bokeh.plotting.show(p)

    except Exception as ex:
        logger.error(ex)
        raise

def topic_coocurrence_network_gui(wti_index, tm_data):
    """Build and display the topic co-occurrence GUI and bind it to display_topic_co_occurrence_network.

    Args:
        wti_index: index object providing `get_party_preset_options()` and `get_countries_list()`.
        tm_data: topic model container; reads `tm_model.num_topics` and `compiled_data.year_period`.
    """
    import types  # the setup cell does not import `types`; needed for SimpleNamespace

    lw = lambda w: widgets.Layout(width=w)
    n_topics = tm_data.tm_model.num_topics

    text_id = 'cooc_id'
    layout_options = ['Circular', 'Kamada-Kawai', 'Fruchterman-Reingold']
    party_preset_options = wti_index.get_party_preset_options()
    parties_options = [x for x in wti_index.get_countries_list() if x != 'ALL OTHER']
    year_min, year_max = tm_data.compiled_data.year_period

    # continuous_update=False: sliders trigger a recompute only when released. The previous
    # spelling `continues_update` is not an ipywidgets trait and was silently ignored.
    gui = types.SimpleNamespace(
        n_topics=n_topics,
        text=widgets_utility.wf.create_text_widget(text_id),
        period=widgets.IntRangeSlider(description='Time', min=year_min, max=year_max, step=1, value=(year_min, year_max), continuous_update=False),
        scale=widgets.FloatSlider(description='Scale', min=0.0, max=1.0, step=0.01, value=0.1, continuous_update=False),
        threshold=widgets.FloatSlider(description='Threshold', min=0.0, max=1.0, step=0.01, value=0.20, continuous_update=False),
        output_format=widgets_utility.dropdown('Output', {'Network': 'network', 'Table': 'table'}, 'network', layout=lw('200px')),
        layout=widgets_utility.dropdown('Layout', layout_options, 'Fruchterman-Reingold', layout=lw('250px')),
        parties=widgets.SelectMultiple(description='Parties', options=parties_options, value=[], rows=7, layout=lw('180px')),
        # NOTE(review): uses widgets_config.dropdown while the other dropdowns use widgets_utility.dropdown; confirm which is intended.
        party_preset=widgets_config.dropdown('Presets', party_preset_options, None, layout=lw('180px')),
        progress=widgets.IntProgress(min=0, max=4, step=1, value=0, layout=widgets.Layout(width="99%")),
        ignores=widgets.SelectMultiple(description='Ignore', options=[('', None)] + [('Topic #' + str(i), i) for i in range(0, n_topics)], value=[], rows=8, layout=lw('180px')),
    )

    def on_party_preset_change(change):  # pylint: disable=W0613
        """Copy the selected preset into the parties selection ('ALL' presets select every option)."""
        if gui.party_preset.value is None:
            return
        gui.parties.value = gui.parties.options if 'ALL' in gui.party_preset.value else gui.party_preset.value

    gui.party_preset.observe(on_party_preset_change, names='value')

    iw = widgets.interactive(
        display_topic_co_occurrence_network,
        tm_data=widgets.fixed(tm_data),
        parties=gui.parties,
        period=gui.period,
        ignores=gui.ignores,
        threshold=gui.threshold,
        layout=gui.layout,
        scale=gui.scale,
        output_format=gui.output_format
    )
    display(widgets.VBox([
        gui.text,
        widgets.HBox([
            widgets.VBox([gui.layout, gui.threshold, gui.scale, gui.period]),
            widgets.VBox([gui.parties, gui.party_preset]),
            widgets.VBox([gui.ignores]),
            widgets.VBox([gui.output_format, gui.progress]),
        ]),
        iw.children[-1]
    ]))
    # Render once with the initial widget values.
    iw.update()
    
# Show the co-occurrence GUI for the current model; failures are logged instead of raised.
# NOTE(review): `get_current_model` and `WTI_INDEX` are not defined in the setup cell; presumably defined in an earlier cell.
try:
    tm_data = get_current_model()
    topic_coocurrence_network_gui(WTI_INDEX, tm_data)
except Exception as err:
    logger.error(err)

## <span style='color: green'>EXPLORE </span> Topic Similarity <span style='float: right; color: red'>WORK IN PROGRESS</span>


#### <span style='color: green'>EXPLORE </span> Topic Similarity Network<span style='float: right; color: red'>WORK IN PROGRESS</span>
This plot displays topic similarity based on **Euclidean or cosine distances** between the **topic-to-word vectors**. Please note that the computations can take some time to execute, especially for larger LDA models.

In [None]:
# Visualization
import types

# Module-level cache used by display_correlation_network; re-running this cell resets it.
zy_data = types.SimpleNamespace(
    **dict.fromkeys(
        [
            'basename',
            'network',
            'X_n_space',
            'X_n_space_feature_names',
            'distance_matrix',
            'metric',
            'topic_proportions',
        ],
        None,
    ),
    n_words=0,
)

def plot_clustering_dendogram(clustering):
    """Plot a hierarchical-clustering dendrogram in a 16x6 inch figure.

    Args:
        clustering: linkage matrix as accepted by `scipy.cluster.hierarchy.dendrogram`.
    """
    # Imported locally: `dendrogram` is not imported anywhere in the setup cell.
    from scipy.cluster.hierarchy import dendrogram

    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.dendrogram.html
    fig = plt.figure(figsize=(16, 6))
    dendrogram(clustering)
    plt.show()
    plt.close(fig)

def VectorSpaceHelper_compute_distance_matrix(X_n_space, metric='euclidean'):
    """Compute the symmetric pairwise distance matrix between the rows of `X_n_space`.

    Args:
        X_n_space: 2-D array-like with one row per vector; sparse input must provide `.toarray()`.
        metric: one of
            - a `scipy.spatial.distance.pdist` metric name (case-insensitive),
            - a callable `f(u, v) -> float`,
            - 'kullback–leibler' or 'kullback-leibler' (uses VectorSpaceHelper.kullback_leibler_divergence),
            - 'scipy.stats.entropy'.

    Returns:
        numpy.ndarray of shape (n_rows, n_rows) with zeros on the diagonal.
    """
    # Local imports: neither `distance` nor `scipy.stats` is imported by the setup cell.
    import scipy.stats
    from scipy.spatial import distance

    # See also https://se.mathworks.com/help/stats/pdist.html
    if isinstance(metric, str):
        metric = metric.lower()
        if metric in ('kullback–leibler', 'kullback-leibler'):
            metric = VectorSpaceHelper.kullback_leibler_divergence
        elif metric == 'scipy.stats.entropy':
            metric = scipy.stats.entropy

    X = X_n_space.toarray() if hasattr(X_n_space, 'toarray') else X_n_space
    return distance.squareform(distance.pdist(X, metric=metric))
    
def display_correlation_network(
    layout_algorithm,
    threshold=0.10,
    scale=1.0,
    metric='Euclidean',
    n_words=200,
    output_format='Network'
):
    """Display the topic similarity network (or its edge list) for the current model.

    Topic-term vectors over the top `n_words` words are compared with the selected distance
    metric. Edges are produced by `VectorSpaceHelper.lower_triangle_iterator` using `threshold`.
    The distance matrix is cached in `zy_data` and recomputed only when the model, metric or
    word count change. The network is rebuilt whenever the threshold changes.

    Args:
        layout_algorithm: network layout name passed to the plotting helper.
        threshold: edge threshold.
        scale: layout scale.
        metric: key into DISTANCE_METRICS.
        n_words: number of top words per topic in the vector space.
        output_format: 'List' shows the edge table; anything else plots the network.

    NOTE(review): `state`, `DISTANCE_METRICS` and `VectorSpaceHelper` are not defined in the
    setup cell; presumably defined in an earlier cell.
    """
    global state, zy_data, zy

    from IPython.display import HTML  # the setup cell imports only `display` and `set_matplotlib_formats`

    try:

        zy.progress.value = 1
        metric = DISTANCE_METRICS[metric]

        node_description = state.get_topics_tokens_as_text()
        node_proportions = state.get_topic_proportions()

        zy.progress.value = 2
        is_stale = (
            zy_data.distance_matrix is None
            or state.basename != zy_data.basename
            or zy_data.metric != metric
            or zy_data.n_words != n_words
        )
        if is_stale:
            zy_data.X_n_space, zy_data.X_n_space_feature_names = state.compute_topic_terms_vector_space(n_words=n_words)
            zy.progress.value = 3
            zy_data.distance_matrix = VectorSpaceHelper_compute_distance_matrix(zy_data.X_n_space, metric=metric)
            zy_data.network = None
            # Record the cache key only after a successful recompute. `metric` must be stored
            # as well, otherwise the comparison above always reports the cache as stale.
            zy_data.basename = state.basename
            zy_data.metric = metric
            zy_data.n_words = n_words

        # The cached network was built from threshold-filtered edges; it is only valid for that threshold.
        if getattr(zy_data, 'threshold', None) != threshold:
            zy_data.threshold = threshold
            zy_data.network = None

        edges_data = VectorSpaceHelper.lower_triangle_iterator(zy_data.distance_matrix, threshold)

        zy.progress.value = 4
        if output_format == 'List':
            df = pd.DataFrame(edges_data, columns=['x', 'y', 'weight'])
            zy.progress.value = 5
            display(HTML(df.to_html()))
        else:
            zy.progress.value = 5
            if zy_data.network is None:
                zy_data.network = network_utility.NetworkUtility.create_network_from_xyw_list(edges_data)
            zy.progress.value = 6
            p = plot_utility.PlotNetworkUtility.plot_network(
                network=zy_data.network,
                layout_algorithm=layout_algorithm,
                scale=scale,
                threshold=threshold,
                node_description=node_description,
                node_proportions=node_proportions,
                element_id='nx_id3',
                figsize=(1000, 600)
            )
            bokeh.plotting.show(p)

        zy.progress.value = 7
        zy.progress.value = 0
    except Exception as ex:
        # Deliberately best-effort: report the problem in the widget output and reset progress.
        print('Error: {}'.format(ex))
        print('Empty set: please change filters')
        zy.progress.value = 0

# Widget bundle for the topic similarity network, read by display_correlation_network via `global zy`.
# Controls labelled with '*' (metric, #words) feed the distance-matrix computation.
# The year slider is created here but is not passed to widgets.interactive below.
# NOTE(review): `wf`, `layout_algorithms`, `state` and `DISTANCE_METRICS` are not defined in the
# setup cell; presumably defined in an earlier cell; confirm.
zy = widgets_utility.WidgetUtility(
    n_topics=state.n_topics,
    # 'nx_id3' matches the element_id display_correlation_network passes to the plot.
    text_id='nx_id3',
    text=wf.create_text_widget('nx_id3'),
    scale=wf.create_float_slider('Scale', min=0.0, max=1.0, step=0.01, value=0.1),
    year=wf.create_int_slider(
        description='Year', min=state.min_year, max=state.max_year, step=1, value=state.min_year
    ),
    n_words=wf.create_int_slider(description='#words*', min=10, max=500, step=1, value=20),
    metric=wf.create_select_widget(label='Metric*', values=list(DISTANCE_METRICS.keys()), default='Euclidean'),
    threshold=wf.create_float_slider('Threshold', min=0.0, max=1.0, step=0.01, value=0.01),
    output_format=wf.create_select_widget('Format', ['Network', 'List'], default='Network'),
    layout=wf.create_select_widget('Layout', list(layout_algorithms.keys()), default='Fruchterman-Reingold'),
    # max=7 matches the highest progress step display_correlation_network sets.
    progress=wf.create_int_progress_widget(min=0, max=7, step=1, value=0, layout=widgets.Layout(width="90%"))
) 
    
# Bind the similarity-network controls to display_correlation_network.
wy = widgets.interactive(
    display_correlation_network,
    layout_algorithm=zy.layout,
    threshold=zy.threshold,
    scale=zy.scale,
    metric=zy.metric,
    n_words=zy.n_words,
    output_format=zy.output_format
)

# Layout: help text, two rows of controls, the progress bar, then the interactive output area.
display(widgets.VBox([
    zy.text,
    widgets.HBox([zy.threshold, zy.metric, zy.output_format]),
    widgets.HBox([zy.n_words, zy.layout, zy.scale]),
    zy.progress,
    wy.children[-1],
]))

# Render once with the initial widget values.
wy.update()
