Format using ruff (#2)

ianrandman committed Jun 23, 2024
1 parent cecb683 commit 7766277

Showing 2 changed files with 68 additions and 31 deletions.
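This commit applies ruff's formatter across the package. Ruff's formatting is Black-compatible: lines over the default 88-character limit are wrapped, multi-line literals gain trailing commas, and single-quoted strings become double-quoted, which accounts for every hunk below. A minimal before/after sketch of that style (toy code, not taken from this diff):

    # Before: a call that overflows the 88-character line limit
    result = some_function(first_argument, second_argument, third_argument, fourth_arg)

    # After `ruff format`: arguments moved onto their own indented line
    result = some_function(
        first_argument, second_argument, third_argument, fourth_arg
    )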
91 changes: 62 additions & 29 deletions bertopic/_bertopic.py
@@ -298,8 +298,10 @@ def _outliers(self):
 
     @property
     def topic_labels_(self):
-        topic_labels = {key: f"{key}_" + "_".join([word[0] for word in values[:4]])
-                        for key, values in self.topic_representations_.items()}
+        topic_labels = {
+            key: f"{key}_" + "_".join([word[0] for word in values[:4]])
+            for key, values in self.topic_representations_.items()
+        }
         if self._is_zeroshot():
             # Need to correct labels from zero-shot topics
             topic_id_to_zeroshot_label = {
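The property above builds each topic's default label by joining its four top-ranked words. A toy sketch of the shapes involved (invented values; in BERTopic, `topic_representations_` maps a topic id to a ranked list of (word, weight) pairs):

    topic_representations_ = {
        -1: [("the", 0.03), ("and", 0.02), ("of", 0.02), ("to", 0.01)],
        0: [("battery", 0.12), ("phone", 0.10), ("charge", 0.08), ("screen", 0.05)],
    }
    topic_labels = {
        key: f"{key}_" + "_".join([word[0] for word in values[:4]])
        for key, values in topic_representations_.items()
    }
    print(topic_labels)  # {-1: '-1_the_and_of_to', 0: '0_battery_phone_charge_screen'}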
@@ -446,15 +448,21 @@ def fit_transform(
 
         # Zero-shot Topic Modeling
         if self._is_zeroshot():
-            documents, embeddings, assigned_documents, assigned_embeddings = self._zeroshot_topic_modeling(documents, embeddings)
+            documents, embeddings, assigned_documents, assigned_embeddings = (
+                self._zeroshot_topic_modeling(documents, embeddings)
+            )
             # Filter UMAP embeddings to only non-assigned embeddings to be used for clustering
            umap_embeddings = self.umap_model.transform(embeddings)
 
         if len(documents) > 0:  # No zero-shot topics matched
             # Cluster reduced embeddings
-            documents, probabilities = self._cluster_embeddings(umap_embeddings, documents, y=y)
+            documents, probabilities = self._cluster_embeddings(
+                umap_embeddings, documents, y=y
+            )
             if self._is_zeroshot() and len(assigned_documents) > 0:
-                documents, embeddings = self._combine_zeroshot_topics(documents, embeddings, assigned_documents, assigned_embeddings)
+                documents, embeddings = self._combine_zeroshot_topics(
+                    documents, embeddings, assigned_documents, assigned_embeddings
+                )
             else:
                 # All documents matched zero-shot topics
                 documents = assigned_documents
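This branch splits the corpus: documents similar enough to a predefined topic are assigned directly, and only the remainder is clustered. A hedged usage sketch of the zero-shot parameters this code path serves (toy corpus and threshold are invented for illustration; a real run needs enough documents for UMAP/HDBSCAN):

    from bertopic import BERTopic

    docs = ["my battery drains fast", "great camera", "the referee was unfair"]  # toy corpus
    topic_model = BERTopic(
        zeroshot_topic_list=["phone hardware", "sports"],  # predefined topics to match first
        zeroshot_min_similarity=0.8,  # documents below this similarity are clustered instead
    )
    topics, probs = topic_model.fit_transform(docs)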
@@ -501,7 +509,9 @@ def fit_transform(
             else:
                 # Use `topics_before_reduction` because `self.topics_` may have already been updated from
                 # reducing topics, and the original probabilities are needed for `self._map_probabilities()`
-                probabilities = sim_matrix[np.arange(len(documents)), topics_before_reduction]
+                probabilities = sim_matrix[
+                    np.arange(len(documents)), topics_before_reduction
+                ]
 
         # Resulting output
         self.probabilities_ = self._map_probabilities(
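The reformatted indexing above is NumPy fancy indexing: pairing a row range with a column array selects, for each document, its similarity to its own assigned topic. A self-contained sketch with toy values:

    import numpy as np

    sim_matrix = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])  # documents x topics
    topics = np.array([0, 1, 0])  # assigned topic per document
    probabilities = sim_matrix[np.arange(len(topics)), topics]
    print(probabilities)  # [0.9 0.8 0.6]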
@@ -1622,7 +1632,6 @@ def update_topics(
                 "c-TF-IDF embeddings instead of centroid embeddings."
             )
 
-        # Extract words
         documents = pd.DataFrame(
             {"Document": docs, "Topic": topics, "ID": range(len(docs)), "Image": images}
         )
@@ -1633,6 +1642,7 @@ def update_topics(
         # Update topic sizes and assignments
         self._update_topic_size(documents)
 
+        # Extract words and update topic labels
         self.c_tf_idf_, words = self._c_tf_idf(documents_per_topic)
         self.topic_representations_ = self._extract_words_per_topic(words, documents)
 
@@ -2272,7 +2282,7 @@ def merge_topics(
         mappings = {
             topic_to: {
                 "topics_from": topics_from,
-                "topic_sizes": [self.topic_sizes_[topic] for topic in topics_from]
+                "topic_sizes": [self.topic_sizes_[topic] for topic in topics_from],
             }
             for topic_to, topics_from in mappings.items()
         }
@@ -4044,9 +4054,11 @@ def _zeroshot_topic_modeling(
         # Check that if a number of topics was specified, it exceeds the number of zeroshot topics matched
         num_zeroshot_topics = len(assigned_documents["Topic"].unique())
         if self.nr_topics and not self.nr_topics > num_zeroshot_topics:
-            raise ValueError(f'The set nr_topics ({self.nr_topics}) must exceed the number of matched zero-shot topics '
-                             f'({num_zeroshot_topics}). Consider raising nr_topics or raising the '
-                             f'zeroshot_min_similarity ({self.zeroshot_min_similarity}).')
+            raise ValueError(
+                f"The set nr_topics ({self.nr_topics}) must exceed the number of matched zero-shot topics "
+                f"({num_zeroshot_topics}). Consider raising nr_topics or raising the "
+                f"zeroshot_min_similarity ({self.zeroshot_min_similarity})."
+            )
 
         # Select non-assigned topics to be clustered
         documents = documents.iloc[non_assigned_ids]
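Note that the guard also rejects equality: if nr_topics equals the number of matched zero-shot topics, there is no room left for a single clustered topic. A standalone sketch of the same check with hypothetical values:

    nr_topics, num_zeroshot_topics = 5, 5  # hypothetical: all requested topics already matched
    if nr_topics and not nr_topics > num_zeroshot_topics:
        raise ValueError("nr_topics must exceed the number of matched zero-shot topics")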
@@ -4067,23 +4079,37 @@ def _is_zeroshot(self):
             return True
         return False
 
-    def _combine_zeroshot_topics(self,
-                                 documents: pd.DataFrame,
-                                 embeddings: np.ndarray,
-                                 assigned_documents: pd.DataFrame,
-                                 assigned_embeddings: np.ndarray) -> tuple[pd.DataFrame, np.ndarray]:
-
+    def _combine_zeroshot_topics(
+        self,
+        documents: pd.DataFrame,
+        embeddings: np.ndarray,
+        assigned_documents: pd.DataFrame,
+        assigned_embeddings: np.ndarray,
+    ) -> tuple[pd.DataFrame, np.ndarray]:
         # Combine Zero-shot topics with topics from clustering
-        zeroshot_topic_idx_to_topic_id = {zeroshot_topic_id: new_topic_id for new_topic_id, zeroshot_topic_id in
-                                          enumerate(set(assigned_documents.Topic))}
-        self.topic_id_to_zeroshot_topic_idx = {new_topic_id: zeroshot_topic_id for new_topic_id, zeroshot_topic_id in
-                                               enumerate(set(assigned_documents.Topic))}
-        assigned_documents.Topic = assigned_documents.Topic.map(zeroshot_topic_idx_to_topic_id)
+        zeroshot_topic_idx_to_topic_id = {
+            zeroshot_topic_id: new_topic_id
+            for new_topic_id, zeroshot_topic_id in enumerate(
+                set(assigned_documents.Topic)
+            )
+        }
+        self.topic_id_to_zeroshot_topic_idx = {
+            new_topic_id: zeroshot_topic_id
+            for new_topic_id, zeroshot_topic_id in enumerate(
+                set(assigned_documents.Topic)
+            )
+        }
+        assigned_documents.Topic = assigned_documents.Topic.map(
+            zeroshot_topic_idx_to_topic_id
+        )
         num_zeroshot_topics = len(zeroshot_topic_idx_to_topic_id)
 
         # Insert zeroshot topics between outlier cluster and other clusters
         documents.Topic = documents.Topic.apply(
-            lambda topic_id: topic_id + num_zeroshot_topics if topic_id != -1 else topic_id)
+            lambda topic_id: topic_id + num_zeroshot_topics
+            if topic_id != -1
+            else topic_id
+        )
 
         # Combine the clustered documents/embeddings with assigned documents/embeddings in the original order
         documents = pd.concat([documents, assigned_documents])
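The renumbering above gives the zero-shot topics ids 0..k-1 and shifts every clustered topic up by k, leaving the outlier id -1 untouched. A toy sketch of the shift (a pandas Series standing in for the Topic column):

    import pandas as pd

    num_zeroshot_topics = 2  # toy: two zero-shot topics claim ids 0 and 1
    clustered = pd.Series([-1, 0, 1, 2])  # topic ids produced by clustering
    shifted = clustered.apply(
        lambda topic_id: topic_id + num_zeroshot_topics if topic_id != -1 else topic_id
    )
    print(shifted.tolist())  # [-1, 2, 3, 4]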
@@ -4661,7 +4687,7 @@ def _reduce_to_n_topics(
         mappings = {
             topic_to: {
                 "topics_from": topics_from,
-                "topic_sizes": [self.topic_sizes_[topic] for topic in topics_from]
+                "topic_sizes": [self.topic_sizes_[topic] for topic in topics_from],
             }
             for topic_to, topics_from in basic_mappings.items()
         }
@@ -4680,8 +4706,10 @@ def _reduce_to_n_topics(
         # or use a calculated representation.
         if self._is_zeroshot():
             new_topic_id_to_zeroshot_topic_idx = {}
-            topics_to_map = {topic_mapping[0]: topic_mapping[1] for topic_mapping in
-                             np.array(self.topic_mapper_.mappings_)[:, -2:]}
+            topics_to_map = {
+                topic_mapping[0]: topic_mapping[1]
+                for topic_mapping in np.array(self.topic_mapper_.mappings_)[:, -2:]
+            }
 
             for topic_to, topics_from in basic_mappings.items():
                 # When extracting topics, the reduced topics were reordered.
@@ -4690,7 +4718,8 @@ def _reduce_to_n_topics(
 
                 # which of the original topics are zero-shot
                 zeroshot_topic_ids = [
-                    topic_id for topic_id in topics_from
+                    topic_id
+                    for topic_id in topics_from
                     if topic_id in self.topic_id_to_zeroshot_topic_idx
                 ]
                 if len(zeroshot_topic_ids) == 0:
@@ -4699,7 +4728,9 @@ def _reduce_to_n_topics(
                 # If any of the original topics are zero-shot, take the best fitting zero-shot label
                 # if the cosine similarity with the new topic exceeds the zero-shot threshold
                 zeroshot_labels = [
-                    self.zeroshot_topic_list[self.topic_id_to_zeroshot_topic_idx[topic_id]]
+                    self.zeroshot_topic_list[
+                        self.topic_id_to_zeroshot_topic_idx[topic_id]
+                    ]
                     for topic_id in zeroshot_topic_ids
                 ]
                 zeroshot_embeddings = self._extract_embeddings(zeroshot_labels)
@@ -4709,7 +4740,9 @@ def _reduce_to_n_topics(
                 best_zeroshot_topic_idx = np.argmax(cosine_similarities)
                 best_cosine_similarity = cosine_similarities[best_zeroshot_topic_idx]
                 if best_cosine_similarity >= self.zeroshot_min_similarity:
-                    new_topic_id_to_zeroshot_topic_idx[topic_to] = zeroshot_topic_ids[best_zeroshot_topic_idx]
+                    new_topic_id_to_zeroshot_topic_idx[topic_to] = zeroshot_topic_ids[
+                        best_zeroshot_topic_idx
+                    ]
 
         self.topic_id_to_zeroshot_topic_idx = new_topic_id_to_zeroshot_topic_idx

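The loop above keeps a merged topic's zero-shot label only when the label embedding is still close enough to the new topic's embedding. A minimal sketch of that decision, assuming sentence-transformers-style embeddings (toy vectors here):

    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    topic_embedding = np.array([[0.9, 0.1, 0.0]])  # embedding of the merged topic
    zeroshot_embeddings = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # candidate labels
    zeroshot_min_similarity = 0.8

    cosine_similarities = cosine_similarity(topic_embedding, zeroshot_embeddings).flatten()
    best_idx = np.argmax(cosine_similarities)
    if cosine_similarities[best_idx] >= zeroshot_min_similarity:
        print(f"keep zero-shot label {best_idx}")  # prints: keep zero-shot label 0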
8 changes: 6 additions & 2 deletions tests/test_bertopic.py
@@ -75,9 +75,13 @@ def test_full_model(model, documents, request):
     # Test zero-shot topic modeling
     if topic_model._is_zeroshot():
         if topic_model._outliers:
-            assert set(topic_model.topic_labels_.keys()) == set(range(-1, len(topic_model.topic_labels_) - 1))
+            assert set(topic_model.topic_labels_.keys()) == set(
+                range(-1, len(topic_model.topic_labels_) - 1)
+            )
         else:
-            assert set(topic_model.topic_labels_.keys()) == set(range(len(topic_model.topic_labels_)))
+            assert set(topic_model.topic_labels_.keys()) == set(
+                range(len(topic_model.topic_labels_))
+            )
 
     # Test topics over time
     timestamps = [i % 10 for i in range(len(documents))]
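The assertions check that topic ids form a contiguous range, starting at -1 when an outlier topic exists. A toy illustration: with four labels and an outlier topic, the expected key set is {-1, 0, 1, 2}:

    topic_labels = {-1: "-1_the_and", 0: "0_battery_phone", 1: "1_sports_game", 2: "2_food_tasty"}
    assert set(topic_labels.keys()) == set(range(-1, len(topic_labels) - 1))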