In [None]:
import os
from pathlib import Path

from bertopic import BERTopic
from google.cloud import bigquery
from google.oauth2 import service_account
from sentence_transformers import SentenceTransformer

# Authenticate to BigQuery with a service-account key file. The path comes from the
# standard GOOGLE_APPLICATION_CREDENTIALS variable when it is set, so the notebook is
# not pinned to one user's home directory; otherwise fall back to the dbt creds file.
key_path = os.environ.get(
    "GOOGLE_APPLICATION_CREDENTIALS",
    os.path.expanduser("~/.config/dbt-user-creds.json"),
)
# If the key's default scopes are insufficient, pass e.g.
# scopes=["https://www.googleapis.com/auth/cloud-platform"] here.
credentials = service_account.Credentials.from_service_account_file(key_path)

# Run queries under the project that owns the service account.
client = bigquery.Client(
    credentials=credentials,
    project=credentials.project_id,
)
# Load every cleaned Reddit post; each row is read below via its `text` and
# `subreddit` fields.
query_job = client.query("SELECT * FROM `reddit_texts.posts_clean`")
data = list(query_job)
docs = [row["text"] for row in data]
cats = [row["subreddit"] for row in data]

# Integer subreddit labels for BERTopic's (semi-)supervised `y`.
# Sorting matters: the iteration order of `set(...)` of strings depends on per-process
# hash randomization, so an unsorted vocabulary would assign different ids to the same
# subreddit on every kernel restart.
catset = sorted(set(cats))
# Dict lookup avoids a linear `catset.index(cat)` scan for every document.
cat_to_id = {cat: idx for idx, cat in enumerate(catset)}
cat_ids = [cat_to_id[cat] for cat in cats]

# Sentence-embedding backbone for BERTopic (the saved-model directory below is named
# after it, so keep the two in sync).
# NOTE(review): sentence-transformers lists the `*-nli-mean-tokens` checkpoints as
# deprecated; a newer model may give better topics — TODO confirm before switching,
# since it changes every downstream result.
EMBEDDING_MODEL_NAME = "distilbert-base-nli-mean-tokens"

sentence_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device="cpu")
# calculate_probabilities=True asks BERTopic for per-document topic probabilities,
# which the visualize_distribution cell indexes by document.
model = BERTopic(
    embedding_model=sentence_model, language="english", calculate_probabilities=True
)

# Subreddit ids are passed as `y` to guide the topic fit with the known labels.
# NOTE(review): BERTopic's default UMAP step is stochastic, so topics can differ
# between runs; pass a seeded `umap_model` if exact reproducibility is needed.
topics, probabilities = model.fit_transform(docs, y=cat_ids)

In [None]:
# First ten rows of the fitted model's topic-frequency table.
topic_freq = model.get_topic_freq()
topic_freq.head(10)

In [None]:
# Create the output directory for the serialized model (no-op if it already exists).
# pathlib instead of the `!mkdir -p` shell magic keeps this portable (e.g. Windows)
# and runnable outside IPython.
Path("../models/bertopic/distilbert/").mkdir(parents=True, exist_ok=True)

In [None]:
# Persist the fitted model now, before the update_topics cell below replaces its
# topic representations in place.
MODEL_PATH = "../models/bertopic/distilbert/model"
model.save(MODEL_PATH)

In [None]:
# Keyword representation of every topic (output grows with the number of topics).
topic_keywords = model.get_topics()
topic_keywords

In [None]:
# Re-derive topic keywords using 1- to 5-word n-grams. This mutates `model` in place,
# so this and later plots show the n-gram representations, not the ones that were saved.
N_GRAM_RANGE = (1, 5)
model.update_topics(docs, topics, n_gram_range=N_GRAM_RANGE)
model.visualize_barchart(top_n_topics=10)

In [None]:
# Intertopic distance map of the fitted topics.
intertopic_map = model.visualize_topics()
intertopic_map

In [None]:
# Spot-check a single document: print its text, then plot its topic probabilities,
# hiding topics below MIN_TOPIC_PROBABILITY.
doc_id = 200
MIN_TOPIC_PROBABILITY = 0.015
print(docs[doc_id])
model.visualize_distribution(
    probabilities[doc_id],
    min_probability=MIN_TOPIC_PROBABILITY,
)

In [None]:
# Hierarchical structure of the topics, restricted to the top 50.
hierarchy_fig = model.visualize_hierarchy(top_n_topics=50)
hierarchy_fig

In [None]:
# Topic-similarity heatmap with topics grouped into 8 clusters, rendered at 1000x1000 px.
heatmap_fig = model.visualize_heatmap(
    n_clusters=8,
    width=1000,
    height=1000,
)
heatmap_fig

In [None]:
# How keyword scores fall off with term rank for each topic.
term_rank_fig = model.visualize_term_rank()
term_rank_fig

In [None]:
# Display the fitted BERTopic instance's repr.
fitted_model = model
fitted_model