In [None]:
import os
# import re
import numpy as np
import pandas as pd

import pickle
import textwrap # for wrapping text for plotly charts

from sklearn.feature_extraction import text #to access stop words
from sklearn.feature_extraction.text import TfidfVectorizer
import tensorflow_hub as hub

from sklearn.manifold import TSNE
import umap

import hdbscan
from sklearn.metrics import silhouette_score

from bokeh.plotting import figure, output_notebook, show, ColumnDataSource
from bokeh.models import HoverTool, CategoricalColorMapper, LinearColorMapper
from bokeh.palettes import d3, Magma256
output_notebook()
from bokeh.io import save

import plotly.express as px

## Configuration and `QuestionClusterizer` class definition

In [None]:
# Folder holding the per-deposition CSV files. QuestionClusterizer.create_corpus
# looks this name up at module level, so it must be defined before fit() runs.
dir_name = "/project/data_preprocessed_csv/"
# Folder for pickled results. NOTE(review): the class keeps its own copy of this
# path in self.dir_pickle; this module-level name is not used by the code shown here.
dir_pickle = "/project/pickled_vectors/"
# Deposition metadata, loaded with the first CSV column as the index.
# NOTE(review): the class re-reads the same file in __init__ rather than using this frame.
df_metadata = pd.read_csv("/project/preprocessing/metadata.csv", index_col=0)


class QuestionClusterizer:
    """Cluster the questions asked in a set of depositions and plot them.

    ``fit`` runs the whole pipeline: choose the deposition files, collect
    their question rows into a corpus, vectorise the corpus, reduce the
    vectors to two dimensions, cluster them with HDBSCAN, score the
    clustering with the silhouette coefficient and draw a plotly scatter.

    Attributes filled in by the pipeline (all ``None`` until computed):
        filenames: deposition file names to read.
        corpus: list of question texts.
        vectors: one vector per question.
        vectors_dim_reduced: the vectors after dimensionality reduction.
        clusters: HDBSCAN label per question (-1 marks noise).
        silhouette: silhouette score of the non-noise points, or ``None``
            when it is undefined or has not been computed.
    """

    def __init__(self, vectorizer, dim_reducer=umap.UMAP, random_state=0):
        """Store the pipeline components and load the deposition metadata.

        Args:
            vectorizer: either an object with a ``fit_transform`` method
                returning something with ``.toarray()``, or a callable whose
                type name contains "tensorflow" (see ``vectorize``).
            dim_reducer: a reducer *class* (not an instance); it is
                instantiated with ``random_state`` in ``reduce_dimensions``.
            random_state: seed handed to the reducer.
        """
        self.vectorizer = vectorizer
        self.filenames = None
        self.corpus = None
        self.df_metadata = pd.read_csv(
            "/project/preprocessing/metadata.csv", index_col=0
        )
        self.dim_reducer = dim_reducer
        self.random_state = random_state
        self.title = "Visualising questions from a set of depositions"
        self.dir_pickle = "/project/pickled_vectors/"

        # Results of the pipeline stages. Defined up front so that plotting
        # or pickling before a stage has run sees None instead of raising
        # AttributeError.
        self.vectors = None
        self.vectors_dim_reduced = None
        self.clusters = None
        self.silhouette = None

    @staticmethod
    def _prompt_for_case(n_cases):
        """Ask on stdin for a case number until a valid index is entered."""
        while True:
            answer = input("Select a case by providing the number: ")
            try:
                selection = int(answer)
            except ValueError:
                print(f"'{answer}' is not a whole number.")
                continue
            # Negative numbers are rejected: they would silently index from
            # the end of the list instead of matching the printed menu.
            if 0 <= selection < n_cases:
                return selection
            print(f"Please enter a number between 0 and {n_cases - 1}.")

    def create_filenames_from_metadata(self, case_selection=None):
        """Set ``self.filenames`` to the depositions of one case.

        Only cases with more than one deposition are offered; to include
        single-deposition cases change the ``> 1`` below to ``> 0``.

        Args:
            case_selection: index into the list of eligible cases. When
                ``None`` (the default) the cases are printed and the index
                is read interactively from stdin.

        Raises:
            ValueError: if no case has more than one deposition.
            IndexError: if ``case_selection`` is given but out of range.
        """
        counts = self.df_metadata.groupby("case").case.count()
        cases = counts[counts > 1].index.values
        if len(cases) == 0:
            raise ValueError("No case in the metadata has more than one deposition.")

        if case_selection is None:
            for i, case in enumerate(cases):
                print(f"{i:2}: {case}")
            case_selection = self._prompt_for_case(len(cases))
        elif not 0 <= case_selection < len(cases):
            raise IndexError(
                f"case_selection must be between 0 and {len(cases) - 1}, "
                f"got {case_selection}"
            )

        self.filenames = self.df_metadata.loc[
            self.df_metadata.case == cases[case_selection], "filename"
        ].values

    def create_corpus(self):
        """Collect the question texts of every file in ``self.filenames``.

        Each file name has its extension replaced by ``.csv`` and is read
        from the module-level ``dir_name`` folder. Rows whose ``text_type``
        is ``"q"`` contribute their ``text`` value to ``self.corpus``.
        A file whose text cannot be extracted is reported and skipped.
        """
        self.corpus = []
        for filename in self.filenames:
            # splitext instead of filename[:-3] so extensions that are not
            # exactly three characters long are handled as well.
            csv_name = os.path.splitext(filename)[0] + ".csv"
            df = pd.read_csv(dir_name + csv_name, index_col=0)
            df = df[df.text_type.isin(["q"])]
            try:
                # Rows with missing text are dropped: a NaN cannot be
                # vectorised or wrapped for the plot hover text.
                self.corpus += df.text.dropna().values.tolist()
            except Exception as err:
                # Best effort: keep going with the remaining files, but say
                # which file was skipped and why.
                print(f"Skipping {filename}: {err!r}")

    def vectorize(self):
        """Turn ``self.corpus`` into ``self.vectors`` (one row per question)."""
        # do over-simplistic check for universal sentence embedding vectoriser
        if "tensorflow" in str(type(self.vectorizer)):
            # Stored as a plain array so the pickled result can be loaded
            # without TensorFlow. NOTE(review): assumes the model returns a
            # single dense tensor - confirm for models other than USE.
            self.vectors = np.asarray(self.vectorizer(self.corpus))
        else:
            # else assume we have vectoriser with fit_transform method
            self.vectors = self.vectorizer.fit_transform(self.corpus).toarray()

    def reduce_dimensions(self):
        """Fit the reducer on ``self.vectors`` and store the embedding.

        Raises:
            TypeError: if ``self.dim_reducer`` is not callable, e.g. after
                ``load_pickle``, which restores only the reducer's name.
        """
        if not callable(self.dim_reducer):
            raise TypeError(
                "dim_reducer must be a reducer class such as umap.UMAP, got "
                f"{self.dim_reducer!r}"
            )
        self.vectors_dim_reduced = self.dim_reducer(
            random_state=self.random_state
        ).fit_transform(self.vectors)

    def clusterize(self):
        """Cluster the reduced vectors with HDBSCAN (default settings)."""
        hdbscan_clusterer = hdbscan.HDBSCAN()
        hdbscan_clusterer.fit(self.vectors_dim_reduced)
        self.clusters = hdbscan_clusterer.labels_

    def score_clusters(self):
        """Compute the silhouette score of the non-noise points.

        ``self.silhouette`` is set to ``None`` when the score is undefined,
        i.e. when fewer than two clusters remain after dropping noise or
        every remaining point is its own cluster.
        """
        # hdbscan has label of -1 for points it considers to be 'noise' or not a part of a cluster
        # these have been removed for purposes of calculating silhouette.
        # to include noise, just remove the `> -1`
        indices = self.clusters > -1
        labels = self.clusters[indices]

        # silhouette_score needs 2 <= n_labels <= n_samples - 1 and raises
        # ValueError otherwise, which used to abort fit() before plotting.
        n_labels = len(np.unique(labels))
        if 2 <= n_labels <= len(labels) - 1:
            self.silhouette = silhouette_score(
                self.vectors_dim_reduced[indices], labels
            )
        else:
            self.silhouette = None

    def plot_plotly(self, clusters=True, remove_noise=True):
        """Show a plotly scatter of the reduced vectors coloured by cluster.

        Args:
            clusters: must be truthy; plotting without clusters is not
                implemented.
            remove_noise: when true, points labelled -1 are not drawn.

        Raises:
            NotImplementedError: if ``clusters`` is falsy.
        """
        if not clusters:
            raise NotImplementedError("plotting without clusters is not implemented")

        df_plot = pd.DataFrame(self.vectors_dim_reduced, columns=["x", "y"])
        df_plot["Text"] = self.corpus
        df_plot["cluster"] = self.clusters

        # need to manually wrap text for plotly graphs, using html newline tags
        df_plot["Text"] = df_plot["Text"].apply(
            lambda txt: "<br>".join(textwrap.wrap(txt, width=30))
        )

        if remove_noise:
            df_plot = df_plot[df_plot.cluster > -1]

        fig = px.scatter(
            df_plot,
            x="x",
            y="y",
            # Only the wrapped question text is shown on hover.
            hover_data=dict(x=False, y=False, Text=True, cluster=False),
            width=600,
            height=600,
            color="cluster",
            color_continuous_scale="rainbow",
        )

        fig.update(layout_coloraxis_showscale=False)
        fig.update_layout(
            plot_bgcolor="rgba(0, 0, 0, 0)",
            title="Visualising all questions from depositions of a certain case",
        )

        # Both axes get the same black frame and no ticks or titles.
        axis_style = dict(
            linecolor="black",
            mirror=True,
            title=None,
            showticklabels=False,
            linewidth=2,
        )
        fig.update_xaxes(**axis_style)
        fig.update_yaxes(**axis_style)

        fig.show()

    def plot_bokeh(self, clusters=True, remove_noise=True):
        """Show a bokeh scatter of the reduced vectors coloured by cluster.

        The figure title includes ``self.silhouette``.

        Args:
            clusters: must be truthy; plotting without clusters is not
                implemented.
            remove_noise: when true, points labelled -1 are not drawn.

        Raises:
            NotImplementedError: if ``clusters`` is falsy.
        """
        if not clusters:
            raise NotImplementedError("plotting without clusters is not implemented")

        df_plot = pd.DataFrame(self.vectors_dim_reduced, columns=["x", "y"])
        df_plot["text"] = self.corpus
        df_plot["clusters"] = self.clusters

        if remove_noise:
            df_plot = df_plot[df_plot.clusters > -1]

        source = ColumnDataSource(
            data=dict(
                x=df_plot.x,
                y=df_plot.y,
                clusters=df_plot.clusters,
                text=df_plot.text,
            )
        )

        # The colour range spans all labels, including the noise label even
        # when the noise points themselves have been filtered out above.
        color_map = LinearColorMapper(
            palette="Turbo256", low=self.clusters.min(), high=self.clusters.max()
        )

        TOOLS = "box_zoom,hover,reset"
        p = figure(title=self.title + f". Silhouette: {self.silhouette}", tools=TOOLS)
        p.background_fill_color = "white"
        p.xgrid.grid_line_color = None
        p.ygrid.grid_line_color = None

        p.scatter(
            x="x",
            y="y",
            color={"field": "clusters", "transform": color_map},
            source=source,
        )

        hover = p.select(dict(type=HoverTool))
        hover.tooltips = [
            ("text", "@text"),
        ]

        show(p)

    def fit(self, use_metadata=True, filenames=None, case_selection=None):
        """Run the full pipeline and show the plotly plot.

        Args:
            use_metadata: when true, the files are taken from one case in
                the metadata table; otherwise ``filenames`` is used.
            filenames: file names to read when ``use_metadata`` is false.
            case_selection: optional case index forwarded to
                ``create_filenames_from_metadata``; ``None`` prompts on stdin.

        Raises:
            TypeError: if ``use_metadata`` is false and no ``filenames``
                are given.
        """
        if use_metadata:
            self.create_filenames_from_metadata(case_selection)
        else:
            if filenames is None:
                raise TypeError("filenames is required when use_metadata is False")
            self.filenames = filenames

        self.create_corpus()
        self.vectorize()
        self.reduce_dimensions()
        self.clusterize()
        self.score_clusters()
        self.plot_plotly()

    def pickle_clusterizer(self, filename):
        """Write the pipeline results to ``self.dir_pickle + filename``."""
        # note that the vectorizer and dim_reducer are not stored - only the name of them are
        # stored. this is because pickle module cannot store them.
        info = dict(
            vectorizer=str(self.vectorizer),
            dim_reducer=str(self.dim_reducer),
            filenames=self.filenames,
            corpus=self.corpus,
            vectors=self.vectors,
            vectors_dim_reduced=self.vectors_dim_reduced,
            clusters=self.clusters,
            silhouette=self.silhouette,
        )
        with open(self.dir_pickle + filename, "wb") as f:
            pickle.dump(info, f)

    def load_pickle(self, filename):
        """Restore results written by ``pickle_clusterizer``.

        ``vectorizer`` and ``dim_reducer`` come back as strings, so
        ``vectorize`` and ``reduce_dimensions`` cannot be re-run afterwards;
        the plotting methods can.
        """
        # SECURITY: pickle.load can execute arbitrary code. Only load files
        # that this notebook (or a trusted party) wrote.
        with open(self.dir_pickle + filename, "rb") as f:
            info = pickle.load(f)
        self.vectorizer = info["vectorizer"]
        self.dim_reducer = info["dim_reducer"]
        self.filenames = info["filenames"]
        self.corpus = info["corpus"]
        self.vectors = info["vectors"]
        self.vectors_dim_reduced = info["vectors_dim_reduced"]
        self.clusters = info["clusters"]
        # Pickles written before the score was stored have no such key.
        self.silhouette = info.get("silhouette")

## testing clusterizer class

In [None]:
# TF-IDF vectoriser restricted to 2-grams (ngram_range=(2, 2)) with IDF weighting on.
vectorizer = TfidfVectorizer(use_idf=True, ngram_range=(2, 2))

In [None]:
# Run the pipeline on the TF-IDF vectors with t-SNE in place of the default reducer.
# With no arguments, fit() asks on stdin which case to use.
clusterizer = QuestionClusterizer(vectorizer=vectorizer, dim_reducer=TSNE)
clusterizer.fit()

In [None]:
# clusterizer.pickle_clusterizer('tfidf22_fnone_snone_hdbscan_tsne_cell.pkl')

## testing Universal Sentence Encoder embeddings

In [None]:
# Load version 4 of the Universal Sentence Encoder from TensorFlow Hub.
# The loaded object is passed to QuestionClusterizer as its vectorizer below.
embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")

In [None]:
# Same pipeline, now with the sentence-encoder model as vectoriser and the
# class's default dimensionality reducer. This rebinds `clusterizer`.
clusterizer = QuestionClusterizer(embed)
clusterizer.fit()

In [None]:
# File name under which the result of the run above would be pickled.
# The save call is left commented out, so running this cell writes nothing.
filename = 'use_hdbscan.pkl'
# clusterizer.pickle_clusterizer(filename)