# Topic Clustering Experiment
## Overview
Uses data created by the [download-reddit-data](./download-reddit-data.ipynb) and [create-embeddings](./create-embeddings.ipynb) notebooks to
 * generate topic clusters from embedding vectors using a clustering algorithm from *scikit-learn*
 * retrieve representative text for each cluster
 * query an LLM for topic description texts (~tags) for each cluster using those texts
 * connect the topics into a topic graph

In [None]:
import pickle
from tqdm.notebook import tqdm
import chromadb
from langchain_community.llms import Ollama
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from adjustText import adjust_text

## Experiment Data
### Input/Output Setup
Loads *pickled* data from the [create-embeddings](./create-embeddings.ipynb) notebook.

In [None]:
SPLITS_FILE = "reddit-splits-1000-200.pickle"
VECS_FILE = "reddit-vecs-1000-200.pickle"

### Load

In [None]:
with open(SPLITS_FILE, "rb") as file:
    splits = pickle.load(file)
with open(VECS_FILE, "rb") as file:
    vecs = pickle.load(file)

## Clustering
Generate *topic clusters* from embedding vectors using a clustering algorithm from *scikit-learn*.

**Please note**: how well the embedding vectors cluster depends on how the embedding function works. For example, some embeddings are known to work better with particular similarity measures such as *cosine similarity*.
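
One common way to bring a Euclidean-distance-based algorithm closer to cosine similarity is to scale every vector to unit length before clustering, because for unit vectors the squared Euclidean distance equals `2 - 2 * cosine similarity`. The following is a minimal sketch in plain Python to illustrate that relation; `l2_normalize` and `cosine_similarity` are hypothetical helpers and are not used elsewhere in this notebook.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length; zero vectors are returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vec))
    return list(vec) if norm == 0 else [x / norm for x in vec]

def cosine_similarity(a, b):
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# toy vectors, for illustration only
a, b = [3.0, 4.0], [4.0, 3.0]
ua, ub = l2_normalize(a), l2_normalize(b)

# squared Euclidean distance between the normalized vectors
squared_distance = sum((x - y) ** 2 for x, y in zip(ua, ub))
print(squared_distance, 2 - 2 * cosine_similarity(a, b))
```

Whether normalization actually improves the clusters for a given embedding model is something to verify experimentally.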

**Experimentation potential**: the generated clusters also depend heavily on the chosen algorithm and its parameters, which also makes it possible to influence the number of output topics in this step. The provided code is only a naive first attempt at pulling everything together.

https://scikit-learn.org/stable/modules/clustering.html#clustering provides a good overview of clustering algorithms available in *scikit-learn*.
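
When comparing algorithms or parameter settings, it can help to look at the cluster sizes and not only at the number of clusters. The sketch below summarizes a label sequence; it assumes the *scikit-learn* convention that the label `-1` marks noise points, and `summarize_labels` is a hypothetical helper operating on made-up labels.

```python
from collections import Counter

def summarize_labels(labels):
    """Return (number of clusters, number of noise points, size per cluster)."""
    counts = Counter(int(label) for label in labels)
    n_noise = counts.pop(-1, 0)  # -1 is the noise label
    return len(counts), n_noise, dict(sorted(counts.items()))

# made-up labels, for illustration only
example_labels = [0, 0, 1, -1, 2, 1, 0, -1]
n_clusters, n_noise, sizes = summarize_labels(example_labels)
print(n_clusters, n_noise, sizes)
```

A run that puts almost all points into one cluster or into noise usually indicates that the parameters need adjusting.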

In [None]:
# MeanShift
from sklearn.cluster import MeanShift, estimate_bandwidth
bandwidth = estimate_bandwidth(vecs, quantile=0.2, n_samples=len(vecs))
clusterer = MeanShift(bandwidth=bandwidth*1.1, n_jobs=-1, cluster_all=False)

# some alternatives for clustering

#from sklearn.cluster import DBSCAN
#clusterer = DBSCAN(eps=.5, min_samples=3)

#from sklearn.cluster import KMeans
#clusterer = KMeans(n_clusters=10)

clusters = clusterer.fit(vecs)

labels = clusters.labels_

# Number of clusters in labels, ignoring noise if present.
n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
n_noise_ = list(labels).count(-1)

print("Estimated number of clusters: %d" % n_clusters_)
print("Estimated number of noise points: %d" % n_noise_)
print(f"Labels: {set(labels)}")

## Topic Creation
The general idea is to map the topic labels (the output of clustering) back to text, which serves as prompt input for an LLM. The LLM is asked to produce short topic headlines (~tags) for each cluster.

### Retrieve Text for Topics
To retrieve *representative text* given a topic cluster, the natural choice is a vector DB. In this notebook, *ChromaDB* is used: https://docs.trychroma.com/

The approach is as follows:
 * Init a ChromaDB and index all document *splits* using their previously computed embeddings
 * Compute vectors representing each topic cluster (centroid of all vectors forming the topic/cluster)
 * Query text from ChromaDB for each vector representing a topic cluster

#### ChromaDB Text Indexing

In [None]:
# index all splits with their embeddings, so that texts can be looked up by vector later
chroma_client = chromadb.EphemeralClient()
collection = chroma_client.create_collection(name="docs")
ids=[str(i) for i in range(len(splits))]
collection.add(
    documents=[d.page_content for d in splits],
    embeddings=vecs,
    # trouble with some metadata types
    #metadatas=[d.metadata for d in splits],
    ids=ids
)
print(f"{collection.count()} docs added to Chroma")

#### Representative Topic Vectors
**Experimentation potential**: computing the centroid of the vectors in a cluster is just one possible approach. It keeps this notebook simple, but it might not be the best choice, depending on the clustering algorithm and the shape of the clusters it produces. For example, one could also define some kind of bounding volume and collect several representative vectors from it for use in the next step.
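
One simple variation is to pick the actual cluster member closest to the centroid, so that the representative is guaranteed to be a real document vector. This is a minimal sketch in plain Python under that assumption; `centroid` and `closest_to_centroid_index` are hypothetical helpers, and the toy cluster is made up.

```python
import math

def centroid(vectors):
    """Component-wise mean of a non-empty list of equally sized vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]

def closest_to_centroid_index(vectors):
    """Index of the member with the smallest Euclidean distance to the centroid."""
    center = centroid(vectors)
    return min(range(len(vectors)), key=lambda i: math.dist(vectors[i], center))

# toy cluster with one outlier that pulls the centroid away from the dense region
cluster = [[0.0, 0.0], [1.0, 0.0], [0.9, 0.2], [4.0, 4.0]]
best = closest_to_centroid_index(cluster)
print(best, cluster[best])
```

Note that this only approximates a true medoid, which would minimize the summed distance to all other members.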

In [None]:
# sorted, so that the order of representatives (and later tags) is deterministic
unique_labels = sorted({int(l) for l in labels if l >= 0})
vecs_array = np.asarray(vecs)
labels_array = np.asarray(labels)
# get representative vec for all labels (centroid)
representatives = []
for label in unique_labels:
    # boolean mask selects all vectors assigned to this cluster
    vecs_with_label = vecs_array[labels_array == label]
    centroid = vecs_with_label.mean(axis=0).tolist()
    representatives.append(centroid)

#### Representative Topic Text
**Experimentation potential**: the following code retrieves exactly one document split per topic vector. This inherently has some issues:
 * multiple splits could be concatenated to fill the LLM context window more efficiently
 * picking one split might yield bad results if this split is not very informative (for example, in some cases it contains headlines only)
 * multiple splits could be picked from multiple vectors for more variance (see above)
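
The first point could be approached by requesting more than one result per topic vector and concatenating the returned splits up to a size limit. The sketch below uses a character budget as a rough stand-in for the LLM context window, which is an assumption; `pack_splits` is a hypothetical helper, and the example texts are made up.

```python
def pack_splits(texts, max_chars=2000, separator="\n\n"):
    """Concatenate texts in order until the next one would exceed max_chars."""
    packed, used = [], 0
    for text in texts:
        # the separator only costs space between two texts
        extra = len(text) + (len(separator) if packed else 0)
        if used + extra > max_chars:
            break
        packed.append(text)
        used += extra
    return separator.join(packed)

# made-up splits, ordered from most to least relevant
ranked = ["first split", "second split", "a much longer third split " * 10]
prompt_input = pack_splits(ranked, max_chars=40)
print(repr(prompt_input))
```

With `n_results` greater than 1, each entry of `representative_docs["documents"]` is a list of texts that could be passed to such a helper in place of `docs[0]`.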

In [None]:
representative_docs = collection.query(
    query_embeddings=representatives,
    n_results=1,
    include=["documents"]
)

### Generate Topic Text
Use the representative text and a system prompt to let the LLM generate *topic descriptions* per cluster.

**Experimentation potential**:
 * Different LLMs
 * Different system prompts
 * Combine multiple representative texts (see above)

In [None]:
# ask LLM for summary/tag
llm = Ollama(model="llama3")
prompt = ChatPromptTemplate.from_messages([
    ("system", "Summarize in maximum three words. No other output. No punctuation."),
    ("user", "{input}")
])
output_parser = StrOutputParser()
chain = prompt | llm | output_parser

tags = [chain.invoke({ "input": docs[0] }) for docs in tqdm(representative_docs["documents"])]
print(tags)

## Topic Graph
Up to this step, we have a set of topic clusters associated with their respective topic vector in embedding space and generated topic text. The overall aim is to structure them in some kind of knowledge graph.

Natural candidates to solve this are graph algorithms. Each topic (text + vector) can be represented as a node. The vectors can be used to compute edge *costs* in a fully connected undirected graph (complete graph), for example by applying a *similarity measure*, similar to what a vector database does when retrieving relevant documents for a query vector.

### Create
This notebook computes a *Minimum Spanning Tree* (MST) with *SciPy*, following the general approach described above. Edge costs are modeled as *cosine distance* (1 - cosine similarity) between topic vectors, so that similar topics are cheap to connect.

**Experimentation potential**:
 * MST is only a naive first shot. By construction it outputs a *tree*, but a more generic *graph* representation might be better suited
 * Different models for edge costs
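
As one possible alternative to a tree, each topic could be linked to its `k` most similar topics, which yields a general (possibly cyclic, possibly disconnected) graph. The following is a minimal sketch in plain Python using cosine distance as edge cost; `cosine_distance` and `knn_edges` are hypothetical helpers, and the toy vectors are made up.

```python
import math

def cosine_distance(a, b):
    """1 - cosine similarity of two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return 1.0 - dot / (math.hypot(*a) * math.hypot(*b))

def knn_edges(vectors, k=2):
    """Undirected edges linking each vector to its k nearest neighbours."""
    edges = set()
    for i, vec in enumerate(vectors):
        others = [j for j in range(len(vectors)) if j != i]
        others.sort(key=lambda j: cosine_distance(vec, vectors[j]))
        for j in others[:k]:
            # store each undirected edge once, smaller index first
            edges.add((min(i, j), max(i, j)))
    return sorted(edges)

# toy topic vectors forming two obvious groups
topic_vectors = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
edges = knn_edges(topic_vectors, k=1)
print(edges)
```

The resulting edge list could be added to a *networkx* graph with `add_edges_from`. Unlike the MST, this graph is not guaranteed to be connected, as the toy example shows.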

In [None]:
# create graph from tags and vecs
from scipy.sparse.csgraph import minimum_spanning_tree

# cosine sim for each (vec, other) with other in vecs
def similarities(vec, vecs):
    return np.array([np.dot(vec, other_vec) / (np.linalg.norm(vec) * np.linalg.norm(other_vec)) for other_vec in vecs])

adjacency_matrix = np.array([1 - similarities(vec, representatives) for vec in representatives])
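# scale so that the largest distance is 1
adjacency_matrix /= np.max(adjacency_matrix)

# MST over the dense distance matrix; zero entries are treated as "no edge"
mst = minimum_spanning_tree(adjacency_matrix)
# binary adjacency of the tree (edge weights are dropped)
edges = (mst.toarray() > 0).astype(int)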

### Visualize

In [None]:
G = nx.from_numpy_array(edges)
# graph nodes are numbered by position (0..n-1), in the same order as the tags
nodes = dict(enumerate(tags))

# Draw the graph
plt.figure(figsize=(12, 9))
pos = nx.spring_layout(G)  # Positioning of nodes
# Draw nodes and edges
nx.draw_networkx_nodes(G, pos, node_color='lightblue', node_size=250)
nx.draw_networkx_edges(G, pos, edge_color='gray')

# Draw labels
texts = []
for node, (x, y) in pos.items():
    texts.append(plt.text(x, y, nodes[node], fontsize=8, ha='center', va='center'))

# Adjust text to avoid overlap
adjust_text(texts)

# Set the title
plt.title("Topic network")

plt.show()