In [None]:
import pickle
from tqdm.notebook import tqdm
from sklearn.cluster import DBSCAN
from sklearn.cluster import KMeans
from sklearn.cluster import MeanShift, estimate_bandwidth
import chromadb
import numpy as np

In [None]:
# Input files. `splits` holds the document chunks and `vecs` their embeddings;
# later cells pair the two by position, so both files must come from the same run.
# Alternative input set (uncomment to switch):
#SPLITS_FILE = "splits-1000-200.pickle"
#VECS_FILE = "vecs-1000-200.pickle"
SPLITS_FILE = "reddit-splits-1000-200.pickle"
VECS_FILE = "reddit-vecs-1000-200.pickle"

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load pickles that were produced by a trusted run.
with open(SPLITS_FILE, "rb") as file:
    splits = pickle.load(file)
with open(VECS_FILE, "rb") as file:
    vecs = pickle.load(file)

# Fail early with a clear message instead of silently mis-pairing texts and
# embeddings further down.
if len(splits) != len(vecs):
    raise ValueError(
        f"{SPLITS_FILE} has {len(splits)} splits but {VECS_FILE} has {len(vecs)} vectors; "
        "the two files must be aligned one-to-one"
    )

In [None]:
# Mean-shift clustering of the embeddings. The bandwidth is estimated from the
# data (quantile 0.2, using every point) and then widened by 10 %.
# Alternatives that were tried: DBSCAN(eps=.5, min_samples=3),
# KMeans(n_clusters=10) and MeanShift with the unscaled bandwidth.
bandwidth = estimate_bandwidth(vecs, n_samples=len(vecs), quantile=0.2)
clusters = MeanShift(cluster_all=False, n_jobs=-1, bandwidth=1.1 * bandwidth).fit(vecs)

labels = clusters.labels_

# Label -1 marks noise points; it is counted separately and not as a cluster.
n_noise_ = int((labels == -1).sum())
n_clusters_ = len(set(labels)) - (n_noise_ > 0)

print("Estimated number of clusters: %d" % n_clusters_)
print("Estimated number of noise points: %d" % n_noise_)
print(f"Labels: {set(labels)}")

In [None]:
# Index every split together with its embedding in an ephemeral Chroma
# collection, so the text belonging to a vector can be retrieved by similarity.
chroma_client = chromadb.EphemeralClient()
collection = chroma_client.create_collection(name="docs")
# Ids are the stringified positions of the splits.
ids = list(map(str, range(len(splits))))
collection.add(
    ids=ids,
    embeddings=vecs,
    documents=[split.page_content for split in splits],
    # metadata is left out: some metadata value types caused trouble
    #metadatas=[split.metadata for split in splits],
)
print(f"{collection.count()} docs added to Chroma")

In [None]:
# Cluster ids in ascending order; negative labels (noise) are left out.
# Sorting fixes the order of `representatives` (and of everything derived from
# it) instead of leaving it to set iteration order.
unique_labels = sorted({int(label) for label in labels if label >= 0})

# One representative per cluster: the centroid (mean vector) of its members.
# A boolean mask selects the members in one step rather than re-scanning all
# vectors in Python for every label.
label_array = np.asarray(labels)
vec_array = np.asarray(vecs)
representatives = [
    vec_array[label_array == label].mean(axis=0).tolist()
    for label in unique_labels
]

In [None]:
# For each centroid, fetch the single stored document whose embedding is closest.
representative_docs = collection.query(
    include=["documents"],
    n_results=1,
    query_embeddings=representatives,
)

In [None]:
# ask LLM for summary/tag
from langchain_community.llms import Ollama
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Prompt -> model -> plain string pipeline that condenses a document into a short tag.
output_parser = StrOutputParser()
llm = Ollama(model="llama3")
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Summarize in maximum three words. No other output. No punctuation."),
        ("user", "{input}"),
    ]
)
chain = prompt | llm | output_parser

# Each query result is a list of documents; only the first one is summarised.
tags = [
    chain.invoke({"input": nearest[0]})
    for nearest in tqdm(representative_docs["documents"])
]
print(tags)

In [None]:
# create graph from tags and vecs
from scipy.sparse.csgraph import minimum_spanning_tree

def similarities(vec, vecs):
    """Return the cosine similarity between ``vec`` and each vector in ``vecs``.

    The result is a 1-D NumPy array with one entry per element of ``vecs``,
    in the same order.
    """
    scores = []
    for candidate in vecs:
        # Cosine similarity: dot product divided by the product of both norms.
        scale = np.linalg.norm(vec) * np.linalg.norm(candidate)
        scores.append(np.dot(vec, candidate) / scale)
    return np.array(scores)

# Pairwise cosine distance (1 - cosine similarity) between the representatives.
adjacency_matrix = np.array([1 - similarities(vec, representatives) for vec in representatives])
# Scale so the largest distance is 1. Skipped when no distance is positive
# (e.g. a single cluster), where the division would be 0/0 and fill the
# matrix with NaN.
max_distance = np.max(adjacency_matrix)
if max_distance > 0:
    adjacency_matrix /= max_distance

mst = minimum_spanning_tree(adjacency_matrix)
# 0/1 matrix marking the edges kept in the spanning tree.
edges = (mst.toarray() > 0).astype(int)

import networkx as nx
import matplotlib.pyplot as plt
from adjustText import adjust_text

G = nx.from_numpy_array(edges)
# Graph node i is row i of `edges`, i.e. the i-th representative, whose tag is
# tags[i]. Captions are therefore keyed by position; keying them by cluster
# label breaks as soon as the labels are not exactly 0..n-1 in iteration order.
nodes = dict(enumerate(tags))

# Draw the graph
plt.figure(figsize=(12, 9))
# Fixed seed so the layout is the same on every run.
pos = nx.spring_layout(G, seed=42)
nx.draw_networkx_nodes(G, pos, node_color='lightblue', node_size=250)
nx.draw_networkx_edges(G, pos, edge_color='gray')

# Captions are placed as plain text artists so adjust_text can move them
# apart where they overlap.
texts = [
    plt.text(x, y, nodes[node], fontsize=8, ha='center', va='center')
    for node, (x, y) in pos.items()
]
adjust_text(texts)

plt.title("Topic network")

plt.show()