Skip to content

Commit

Permalink
When merging zero-shot, keep single zero-shot label if meets threshol…
Browse files Browse the repository at this point in the history
…d with new topic embedding (#2)
  • Loading branch information
ianrandman committed Jun 18, 2024
1 parent ecd0224 commit 19af331
Showing 1 changed file with 39 additions and 24 deletions.
63 changes: 39 additions & 24 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4070,30 +4070,6 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame) -> pd.DataFrame:
for topic_to, topics_from in basic_mappings.items()
}

# Combine merged zero-shot topics with '_' and
# remap self.topic_id_to_zeroshot_topic_idx based on new self.zeroshot_topic_list.
# This does not keep the ordering where zero-shot topics come before clustered topics.
if self._is_zeroshot():
new_topic_id_to_zeroshot_topics = {
topic_to: [
self.zeroshot_topic_list[self.topic_id_to_zeroshot_topic_idx[topic_id]]
# multiple original topics combined; one or more are zero-shot topics
for topic_id in topics_from if topic_id in self.topic_id_to_zeroshot_topic_idx
] for topic_to, topics_from in basic_mappings.items()
# create mapping if any of the original topics are zero-shot
if any(topic_id in self.topic_id_to_zeroshot_topic_idx for topic_id in topics_from)
}
new_topic_id_to_zeroshot_topics = {
topic_id: '_'.join(topics)
for topic_id, topics in new_topic_id_to_zeroshot_topics.items()
}
self.topic_id_to_zeroshot_topic_idx = {
topic_id: zeroshot_topic_idx
for zeroshot_topic_idx, (topic_id, _)
in enumerate(new_topic_id_to_zeroshot_topics.items())
}
self.zeroshot_topic_list = list(new_topic_id_to_zeroshot_topics.values())

# Map topics
documents.Topic = new_topics
self._update_topic_size(documents)
Expand All @@ -4102,6 +4078,45 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame) -> pd.DataFrame:
# Update representations
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, mappings=mappings)

# When zero-shot topic(s) are present in the topics to merge,
# determine whether to take one of the zero-shot topic labels
# or use a calculated representation.
if self._is_zeroshot():
new_topic_id_to_zeroshot_topic_idx = {}
topics_to_map = {topic_mapping[0]: topic_mapping[1] for topic_mapping in
np.array(self.topic_mapper_.mappings_)[:, -2:]}

for topic_to, topics_from in basic_mappings.items():
# When extracting topics, the reduced topics were reordered.
# Must get the updated topic_to.
topic_to = topics_to_map[topic_to]

# which of the original topics are zero-shot
zeroshot_topic_ids = [
topic_id for topic_id in topics_from
if topic_id in self.topic_id_to_zeroshot_topic_idx
]
if len(zeroshot_topic_ids) == 0:
continue

# If any of the original topics are zero-shot, take the best fitting zero-shot label
# if the cosine similarity with the new topic exceeds the zero-shot threshold
zeroshot_labels = [
self.zeroshot_topic_list[self.topic_id_to_zeroshot_topic_idx[topic_id]]
for topic_id in zeroshot_topic_ids
]
zeroshot_embeddings = self._extract_embeddings(zeroshot_labels)
cosine_similarities = cosine_similarity(
zeroshot_embeddings, [self.topic_embeddings_[topic_to]]
).flatten()
best_zeroshot_topic_idx = np.argmax(cosine_similarities)
best_cosine_similarity = cosine_similarities[best_zeroshot_topic_idx]
if best_cosine_similarity >= self.zeroshot_min_similarity:
new_topic_id_to_zeroshot_topic_idx[topic_to] = zeroshot_topic_ids[best_zeroshot_topic_idx]

self.topic_id_to_zeroshot_topic_idx = new_topic_id_to_zeroshot_topic_idx

self._update_topic_size(documents)
return documents

Expand Down

0 comments on commit 19af331

Please sign in to comment.