Load the previously downloaded Wikipedia documents from `docs.pickle`.

In [None]:
import pickle

# Restore the document list saved by an earlier download step.
# NOTE: unpickling can execute arbitrary code — only load files you created yourself.
with open("docs.pickle", "rb") as pickled:
    docs = pickle.load(pickled)

print(f"{len(docs)} documents")

Split all documents into overlapping chunks small enough to fit into the context (input) of a local LLM.

In [None]:
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Sizes are measured in characters. Neighbouring chunks share up to
# CHUNK_OVERLAP characters, so text cut at a chunk boundary keeps some context.
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    add_start_index=True,  # record each chunk's start offset in its metadata
)
splits = text_splitter.split_documents(docs)
print(f"{len(splits)} splits")

Create an embedding vector for every split using an Ollama-served LLM (`llama3`).

In [None]:
from tqdm.auto import tqdm
from langchain_community.embeddings import OllamaEmbeddings

# Embed each split's text, one request per split so tqdm can show progress.
# `embed_documents` takes a list of *strings*, so pass `page_content`, not the
# Document object itself. This is also the exact text stored in Chroma later.
embeddings = OllamaEmbeddings(model="llama3")
vecs = [
    embeddings.embed_documents([split.page_content])[0]
    for split in tqdm(splits, desc="embedding")
]

# Guard against an empty corpus: vecs[0] would raise IndexError.
if vecs:
    print(f"embedding space dim: {len(vecs[0])}")
else:
    print("no splits to embed")

In [None]:
# Save the embeddings to `vecs.pickle` (presumably so they can be reused
# without re-running the slow embedding cell).
with open("vecs.pickle", "wb") as vec_file:
    pickle.dump(vecs, file=vec_file)

In [None]:
from sklearn.cluster import DBSCAN
from sklearn.cluster import KMeans

N_CLUSTERS = 10
SEED = 42  # fixes KMeans' random initialisation so cluster ids are reproducible

# Density-based alternative; unlike KMeans, DBSCAN labels outliers as -1:
#clusters = DBSCAN(eps=.5, min_samples=3).fit(vecs)
clusters = KMeans(n_clusters=N_CLUSTERS, random_state=SEED).fit(vecs)
labels = clusters.labels_

# Number of clusters in labels, ignoring noise (-1) if present.
# KMeans assigns every point to a cluster, so the noise count is 0 here;
# the check is kept so the DBSCAN alternative above works unchanged.
n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
n_noise_ = list(labels).count(-1)

print(f"Estimated number of clusters: {n_clusters_}")
print(f"Estimated number of noise points: {n_noise_}")

In [None]:
from collections import Counter

# Cluster sizes (label -> number of splits), excluding noise (-1, DBSCAN only).
# Filter with `!= -1`, not `> 0`: KMeans numbers clusters from 0.
cluster_sizes = Counter(int(label) for label in labels if label != -1)
print(dict(sorted(cluster_sizes.items())))

In [None]:
# get back texts for labeled vecs stored in vector store
import chromadb

# Store each split's text, embedding and metadata in an in-memory Chroma
# collection, keyed by the split's index in `splits`, so texts can later be
# retrieved by embedding similarity.
assert len(vecs) == len(splits), "expected exactly one embedding per split"

chroma_client = chromadb.Client()
# get_or_create + upsert keep this cell re-runnable in the same kernel;
# create_collection raises if a "wikidocs" collection already exists.
collection = chroma_client.get_or_create_collection(name="wikidocs")
ids = [str(i) for i in range(len(splits))]
collection.upsert(
    documents=[d.page_content for d in splits],
    embeddings=vecs,
    metadatas=[d.metadata for d in splits],
    ids=ids,
)
print(f"{collection.count()} docs added to Chroma")

In [None]:
# retrieve docs
import numpy as np

# One query embedding per cluster: the centroid (mean) of its members'
# embeddings. This represents the whole cluster rather than whichever split
# happens to come first. Noise (-1, DBSCAN only) is excluded. Cluster ids
# start at 0, so the filter must be `!= -1`, not `> 0`.
unique_labels = set(int(l) for l in labels if l != -1)
vec_matrix = np.asarray(vecs)
label_array = np.asarray(labels)
query_embeddings = []
for label in sorted(unique_labels):  # sorted: deterministic query order
    centroid = vec_matrix[label_array == label].mean(axis=0)
    # Convert back to a plain list of floats, the same format as `vecs`.
    query_embeddings.append(centroid.tolist())

In [None]:
# For each cluster query embedding, fetch the stored splits closest to it.
# Only the document texts are requested from Chroma.
representatives = collection.query(
    include=["documents"],
    n_results=5,  # texts retrieved per cluster
    query_embeddings=query_embeddings,
)

In [None]:
# ask LLM for single term/tag
from langchain_community.llms import Ollama
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

llm = Ollama(model="llama3")
prompt = ChatPromptTemplate.from_messages([
    ("system", "Summarize in maximum three words. No other output."),
    ("user", "{input}")
])
output_parser = StrOutputParser()
chain = prompt | llm | output_parser

# `representatives["documents"]` holds one list of texts per query embedding,
# i.e. per cluster. Tag each cluster ONCE from all of its representative texts,
# rather than once per text, so every cluster gets a single tag.
# `i` is the cluster's position in `query_embeddings`.
for i, cluster_texts in enumerate(representatives["documents"]):
    if not cluster_texts:
        continue  # nothing retrieved for this query
    joined = "\n\n---\n\n".join(cluster_texts)
    tag = chain.invoke({"input": joined}).strip()
    print(f"[{i}] {tag}")