## Import Libraries

In [1]:
import pandas as pd
from sentence_transformers import SentenceTransformer
from umap import UMAP

In [3]:
# Load the USHMM encyclopedia articles (columns: title, link, text) and preview them.
articles_csv = "../data/ushmm_encyclopedia_articles.csv"
df = pd.read_csv(articles_csv)
df

Unnamed: 0,title,link,text
0,The 101st Airborne Division during World War II,https://encyclopedia.ushmm.org/content/en/arti...,101st Airborne Division Campaigns during World...
1,The 102nd Infantry Division,https://encyclopedia.ushmm.org/content/en/arti...,"Campaigns\nFormed in September 1942, the 102nd..."
2,The 103rd Infantry Division during World War II,https://encyclopedia.ushmm.org/content/en/arti...,103rd Infantry Division Campaigns during World...
3,The 104th Infantry Division during World War II,https://encyclopedia.ushmm.org/content/en/arti...,104th Infantry Division Campaigns during World...
4,The 10th Armored Division during World War II,https://encyclopedia.ushmm.org/content/en/arti...,10th Armored Division Campaigns during World W...
...,...,...,...
961,Zdziecioł (Zhetel),https://encyclopedia.ushmm.org/content/en/arti...,A Project of the Miles Lerman Center\nZdziecio...
962,Zeilsheim Displaced Persons Camp,https://encyclopedia.ushmm.org/content/en/arti...,Background\nZeilsheim was a displaced persons ...
963,Ziegenhain Displaced Persons Camp,https://encyclopedia.ushmm.org/content/en/arti...,Ziegenhain was a displaced persons (DP) camp i...
964,Zimbabwe: Overview,https://encyclopedia.ushmm.org/content/en/arti...,Introduction\nZimbabwe is a country of approxi...


## Encode Sentences

In [6]:
# Pretrained sentence-embedding model used for every article.
model = SentenceTransformer("nli-mpnet-base-v2")

# Embed the article bodies; X is the embedding matrix consumed by the UMAP step.
articles = df["text"]
X = model.encode(articles)

## Dimensionality Reduction

In [7]:
# Reduce the dimensions with UMAP.
# UMAP is stochastic: without a fixed seed every run yields a different layout,
# so the coordinates saved below would not be reproducible after a kernel restart.
umap = UMAP(random_state=42)
X_tfm = umap.fit_transform(X)

In [13]:
# CSV that receives the articles plus their x/y coordinates; it is also the
# file passed to bulk_text() in the visualization section.
output_file = "../data/ushmm_encyclopedia_articles_coords.csv"

In [14]:
# Attach the first two UMAP components to the articles as plot coordinates,
# then persist the enriched table (without the pandas index).
df["x"], df["y"] = X_tfm[:, 0], X_tfm[:, 1]
df.to_csv(output_file, index=False)

## A Single Function for Encoding and Dimensionality Reduction

In [None]:
def create_coords_map(input_file, text_column, output_file,
                      model=None, reducer=None, random_state=42):
    """Embed the texts of a CSV file, project them with UMAP and save x/y coordinates.

    Reads ``input_file``, encodes the ``text_column`` values, reduces the
    embeddings, stores the first two reduced components as the columns ``x``
    and ``y`` and writes the resulting table to ``output_file``.

    Parameters
    ----------
    input_file : str
        Path of the CSV file to read.
    text_column : str
        Name of the column holding the texts to embed.
    output_file : str
        Path of the CSV file to write (saved without the pandas index).
    model : optional
        Object exposing ``encode(texts)``. Defaults to
        ``SentenceTransformer('nli-mpnet-base-v2')``; pass an already loaded
        model to avoid loading it again.
    reducer : optional
        Object exposing ``fit_transform(X)`` that returns an array with at
        least two columns. Defaults to ``UMAP(random_state=random_state)``.
    random_state : int, optional
        Seed for the default UMAP reducer so repeated runs give the same
        layout. Ignored when ``reducer`` is supplied.

    Returns
    -------
    None
    """
    df = pd.read_csv(input_file)

    if model is None:
        model = SentenceTransformer('nli-mpnet-base-v2')
    if reducer is None:
        reducer = UMAP(random_state=random_state)

    # Hand the encoder a plain list so the result does not depend on the
    # DataFrame's index labels.
    texts = df[text_column].tolist()

    # Calculate embeddings
    X = model.encode(texts)

    # Reduce the dimensions. This step was missing, so the function used to
    # read whatever X_tfm happened to exist in the notebook's global namespace.
    X_tfm = reducer.fit_transform(X)

    # Apply coordinates
    df['x'] = X_tfm[:, 0]
    df['y'] = X_tfm[:, 1]
    df.to_csv(output_file, index=False)

In [None]:
# Run the whole embed -> reduce -> save pipeline on the encyclopedia articles.
create_coords_map(
    input_file="../data/ushmm_encyclopedia_articles.csv",
    text_column="text",
    output_file="../data/ushmm_encyclopedia_articles_coords.csv",
)

## Visualizing our Clusters

In [23]:
from bokeh.io import curdoc
from bokeh.layouts import column, row
from bokeh.models import (Button, ColumnDataSource, DataTable, TableColumn, TextInput)
from bokeh.plotting import figure, show
from bokeh.models import DataTable, TableColumn, ColorBar, HTMLTemplateFormatter, Spinner, RangeSlider
from bokeh.io import output_notebook
from bokeh.application import Application
from bokeh.application.handlers import FunctionHandler
import numpy as np

In [24]:
# Configure Bokeh to render its output inline in the notebook; run this before show().
output_notebook()

In [30]:
def bulk_text(path, keywords=None):
    """Build a Bokeh application for exploring a 2-D map of texts.

    The app contains a scatter plot of the rows of ``path`` (positioned by
    their ``x``/``y`` columns), a table listing the currently selected rows,
    a spinner that controls the marker size, and a button that writes the
    ``text`` column of the selected rows to a CSV file.

    Parameters
    ----------
    path : str
        CSV file providing at least the columns ``x``, ``y``, ``text``,
        ``title`` and ``link``.
    keywords : optional
        When truthy, every text is labelled with ``determine_keyword`` and
        rows labelled ``'none'`` are drawn with a lower alpha.

    Returns
    -------
    bokeh.application.Application
        Application that can be passed to ``show`` for inline display.
    """
    df = pd.read_csv(path)
    df['alpha'] = 0.5
    if keywords:
        # NOTE(review): determine_keyword is not defined in the visible part of
        # this notebook; confirm it exists before calling with keywords.
        df['color'] = [determine_keyword(str(t), keywords) for t in df['text']]
        df['alpha'] = [0.4 if c == 'none' else 1 for c in df['color']]

    # Row positions of the current selection. Shared with the callbacks below
    # through `nonlocal`; starts empty so SAVE works before anything is selected.
    highlighted_idx = []

    # NOTE(review): the 'color' column is computed above but never passed to the
    # glyph, so keywords currently only affect transparency.
    columns = [
        TableColumn(field="text", title="text", width=500),
        TableColumn(field="title", title="title"),
        # The stray comma between the href and target attributes was removed.
        TableColumn(field="link", title="article", formatter=HTMLTemplateFormatter(template=r'<a href="<%= link %>" target="_blank">View Article</a>')),
    ]

    def update(attr, old, new):
        """Callback used for plot update when lasso selecting"""
        # `nonlocal` (not `global`) so the selection is stored in bulk_text's
        # own variable instead of leaking into the module namespace.
        nonlocal highlighted_idx
        subset = df.iloc[new]
        highlighted_idx = new
        # Shuffle so the table is not always shown in file order.
        subset = subset.iloc[np.random.permutation(len(subset))]
        source.data = subset

    def save():
        """Callback used to save highlighted data points"""
        nonlocal highlighted_idx
        df.iloc[highlighted_idx][['text']].to_csv(text_filename.value, index=False)

    # `source` feeds the table (replaced on every selection); `source_orig`
    # feeds the scatter plot and always holds the full data set.
    source = ColumnDataSource(data=dict())
    source_orig = ColumnDataSource(data=df)

    data_table = DataTable(source=source, columns=columns, width=1500, height=700)
    source.data = df

    p = figure(title="", sizing_mode="scale_both", tools=["lasso_select", "box_select", "pan", "box_zoom", "wheel_zoom", "reset"])
    p.toolbar.active_drag = None
    p.toolbar.active_inspect = None

    circle_kwargs = {"x": "x", "y": "y", "size": 1, "source": source_orig, "alpha": "alpha"}

    scatter = p.circle(**circle_kwargs)
    # NOTE(review): plot_width/plot_height appear to have been replaced by
    # width/height in Bokeh 3.0 - confirm against the installed Bokeh version.
    p.plot_width = 1000
    if "color" in df.columns:
        p.plot_width = 350
    p.plot_height = 700

    # Spinner for node size, linked client-side to the glyph's size.
    spinner = Spinner(title="Circle Size", low=1, high=60, step=1, value=scatter.glyph.size, width=200)
    spinner.js_link("value", scatter.glyph, "size")

    scatter.data_source.selected.on_change('indices', update)

    text_filename = TextInput(value="out.csv", title="Filename:")
    save_btn = Button(label="SAVE")
    save_btn.on_click(save)

    plot = column(p)
    controls_main = column(spinner)
    controls = column(text_filename, save_btn)

    def make_doc(doc):
        """Attach the pre-built layout to the session document."""
        # NOTE(review): the models are created once per bulk_text() call, not
        # per session; presumably each returned app is shown only once - confirm.
        doc.add_root(row(controls_main))
        doc.add_root(row(plot, controls))
        doc.add_root(row(data_table))

    handler = FunctionHandler(make_doc)
    app = Application(handler)
    return app


# Build the explorer from the coordinates CSV written earlier (output_file)
# and display it inline in the notebook.
app = bulk_text(output_file)
show(app)