Load the AG News dataset from Hugging Face

In [None]:
from datasets import load_dataset

# AG News training split from the Hugging Face Hub; show the first record as a sanity check.
DATASET_NAME, DATASET_SPLIT = "ag_news", "train"
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
first_example = dataset[0]
print(first_example)

BERTopic 期望的输入是一个字符串列表。因此，需要从 Hugging Face 数据集中提取包含文本的列（`text`），并将其转换为 Python 列表。

In [2]:
# BERTopic (and SentenceTransformer.encode) are fed a plain list of strings.
# Convert the `text` column explicitly: recent `datasets` releases (4.x) may return a
# lazy column object instead of a list when a Dataset is indexed by column name.
docs = list(dataset["text"])

根据当前平台自动选择用于计算嵌入的设备：优先使用 CUDA GPU，其次是 macOS 上的 MPS，都不可用时回退到 CPU。

In [None]:
import torch

# Pick the embedding device in priority order: CUDA GPU, then Apple MPS, then CPU.
# MPS is only queried when CUDA is unavailable.
cuda_ok = torch.cuda.is_available()
mps_ok = (not cuda_ok) and torch.backends.mps.is_available()

if cuda_ok:
    print(f"cuda available: {cuda_ok}")
    backend = "cuda"
elif mps_ok:
    # mps for MacOS
    print(f"mps available: {mps_ok}")
    backend = "mps"
else:
    backend = "cpu"

device = torch.device(backend)
print(f"Using device: {device}")


使用 `sentence_transformers` 库预先计算文档嵌入并保存到 `pre_embeddings.npy`，这样反复调整并重新运行 BERTopic 时无需重复计算嵌入。

In [None]:
from pathlib import Path

import numpy as np
import sentence_transformers

# Embedding model used to turn each document into a vector. Note: this is the
# SentenceTransformer, not the BERTopic model; the name `embeddings` is kept so any
# later cell that references it keeps working.
EMBEDDINGS_CACHE = Path("pre_embeddings.npy")
embeddings = sentence_transformers.SentenceTransformer("all-MiniLM-L6-v2", device=device)

# Encoding the whole corpus is the slowest step of the notebook, so reuse the saved
# embeddings when the cache exists and has exactly one row per document; otherwise
# (re)compute them and refresh the cache. Delete the file to force recomputation
# (e.g. after changing the dataset or the embedding model).
cached_embeddings = np.load(EMBEDDINGS_CACHE) if EMBEDDINGS_CACHE.exists() else None
if cached_embeddings is not None and cached_embeddings.shape[0] == len(docs):
    pre_embeddings = cached_embeddings
else:
    pre_embeddings = embeddings.encode(docs, show_progress_bar=True)
    np.save(EMBEDDINGS_CACHE, pre_embeddings)

In [None]:
from bertopic import BERTopic

# Topic model. The SentenceTransformer `embeddings` (which produced `pre_embeddings`)
# is passed as `embedding_model` so later operations that need to embed new text
# (e.g. `transform` without precomputed embeddings, `find_topics`) use the same model,
# as recommended when fitting BERTopic on precomputed embeddings.
# NOTE(review): BERTopic's default UMAP step is stochastic, so topics can differ
# between runs; pass a seeded `umap_model` if reproducible results are required.
model = BERTopic(embedding_model=embeddings, language="english", verbose=True)

# Fit on the precomputed embeddings (no re-embedding) and get each document's topic.
topics, probs = model.fit_transform(docs, embeddings=pre_embeddings)

In [None]:
# Overview of the fitted topics. End the cell with the frame itself so Jupyter renders
# it as a rich table instead of plain-text `print` output; only the first 10 rows are shown.
topic_info = model.get_topic_info()
print("\n--- 主题信息 ---")
topic_info.head(n=10)

In [None]:
# Per-document topic assignments for every entry in `docs`. Keep the full frame in a
# named variable for later inspection, but display only a short preview as a rich table
# rather than printing the whole (large) frame.
document_info = model.get_document_info(docs)
document_info.head(10)

In [None]:
# Build BERTopic's topic visualization and save it as a standalone HTML file.
topics_html_path = "bertopic_topics_viz.html"
fig_topics = model.visualize_topics()
fig_topics.write_html(topics_html_path)