In [None]:
import os
import json
import multiprocessing as mp

from tqdm import tqdm

import numpy as np
import pandas as pd

from umap import UMAP
from hdbscan import HDBSCAN
from sklearn.feature_extraction.text import CountVectorizer

from bertopic import BERTopic
from bertopic.representation import MaximalMarginalRelevance

from utils.data import load_data
from utils.embeddings import load_embeddings

# Download NLTK's stop-word corpus so that `stopwords.words('english')` can be
# read when the CountVectorizer is configured further down.
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords

# Seed NumPy's global random number generator. UMAP is constructed later with
# its own explicit `random_state`, so it does not rely on this seed.
np.random.seed(42)

**Load data & embeddings:**

In [None]:
# List the dataset directory to see which files are available before loading.
!ls /datasets/idw-reddit

In [None]:
# Path to the training CSV, read with the project helper imported from utils.data.
# NOTE(review): later cells read `full_id`, `source` and `tokens` from `df`, so the
# loader is presumed to return a DataFrame with at least those columns — confirm.
DATA = '/datasets/idw-reddit/training_data.csv'
df = load_data(DATA)

In [None]:
# Name of the embedding model; its precomputed vectors live under embeddings/.
EMBEDDING_MODEL = 'all-mpnet-base-v2'
EMBEDDING_MODEL_PATH = os.path.join('embeddings', EMBEDDING_MODEL + '.pickle')

embeddings = load_embeddings(EMBEDDING_MODEL_PATH)

# The vectors are attached to the dataframe positionally, so there must be
# exactly one embedding per row.
assert len(df) == len(embeddings), "Error! Embedding length does not match dataframe length!"

# Store each vector as a single cell of a new column, then drop the
# now-redundant standalone name.
df['embedding'] = [vector for vector in embeddings]
del embeddings

In [None]:
# Preview the first rows to check that the `embedding` column was attached.
df.head()

**Create average embedding representations:**

In [None]:
# Each aggregated document is identified by BOTH columns, so the grouping and
# the re-join below use the same composite key.
GROUP_KEYS = ['full_id', 'source']

# average embeddings:
print('Averaging embeddings')
# Stack the per-row vectors of each group into one matrix, then take the
# column-wise mean to get a single vector per (full_id, source).
agg_df = df.groupby(GROUP_KEYS)['embedding'].apply(np.vstack).reset_index()
agg_df['embedding'] = agg_df['embedding'].apply(lambda stacked: stacked.mean(axis=0))

# Aggregate sentences & map to embeddings:
print('Mapping aggregated sentences to embeddings')
# Built into a separate frame so the input `df` is not rebound mid-cell.
tokens_df = df.groupby(GROUP_KEYS)['tokens'].apply(list).reset_index()
tokens_df['tokens'] = tokens_df['tokens'].apply(' '.join)

# Join on the full composite key. Looking tokens up by `full_id` alone would
# give every row sharing a `full_id` across sources the tokens of just one
# source. `validate` raises if either side unexpectedly holds duplicate keys.
agg_df = agg_df.merge(tokens_df, on=GROUP_KEYS, how='left', validate='one_to_one')

# save memory
del df, tokens_df

# Sort by `full_id`; a stable sort keeps rows that share a `full_id` in their
# grouped (full_id, source) order.
agg_df = agg_df.sort_values('full_id', ascending=True, kind='stable')

**Map text representations back to dataframe:**

In [None]:
# CSV of text representations; the mapping cell further down reads its
# `full_id` and `text_representation` columns.
TEXT_PATH = '/datasets/idw-reddit/text_representations.csv'
text_reps = pd.read_csv(TEXT_PATH)

In [None]:
# Preview the first rows of the text-representation table.
text_reps.head()

In [None]:
# Look up each aggregated row's text representation by its `full_id`.
full_id_to_text = dict(zip(text_reps['full_id'], text_reps['text_representation']))
agg_df['text_representation'] = agg_df['full_id'].map(full_id_to_text)
del full_id_to_text

# `.map` yields NaN for ids absent from the lookup, and the CSV itself may hold
# empty cells. These values are later passed to BERTopic as documents, so stop
# here rather than failing after the expensive reduction / clustering steps.
n_missing_text = int(agg_df['text_representation'].isna().sum())
assert n_missing_text == 0, f"Error! {n_missing_text} rows have no text representation!"

In [None]:
# Display the aggregated frame (ids, source, averaged embedding, joined tokens
# and text representation) for a visual sanity check.
agg_df

**Save embeddings mapped to full_ids:**
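- Storing each averaged embedding next to its `full_id` lets later steps re-align embeddings by key, so they do not depend on the row order produced by the Pandas `groupby`.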

In [None]:
# Write one JSON object per line, pairing each full_id with its averaged
# embedding. Rows are serialized one at a time so the embeddings are never
# duplicated wholesale as nested Python lists in memory.
EMBEDDINGS_OUT_PATH = os.path.join('embeddings', 'full_id_to_embeddings.jsonl')

# `.tolist()` converts the ids to plain Python scalars that `json.dumps` accepts.
full_ids = agg_df['full_id'].tolist()

with open(EMBEDDINGS_OUT_PATH, 'w') as f:
    for full_id, embedding in tqdm(zip(full_ids, agg_df['embedding']), total=len(full_ids)):
        record = {'full_id': full_id, 'embeddings': embedding.tolist()}
        f.write(json.dumps(record) + '\n')

del full_ids

print('Done!')

## Train BERTopic

In [None]:
# UMAP settings, collected in one mapping and unpacked into the constructor.
# The resulting model is handed to BERTopic as `umap_model`.
umap_params = {
    'n_components': 5,
    'n_neighbors': 15,
    'metric': 'cosine',
    'min_dist': 0.0,
    'init': 'tswspectral',
    'unique': True,
    'n_epochs': 400,
    'low_memory': True,
    'random_state': 137,  # fixed so the reduction is repeatable
    'verbose': True,
}
umap_model = UMAP(**umap_params)

# HDBSCAN clustering settings. The three tunables stay available as
# notebook-level names.
min_cluster_size, min_samples, cluster_method = 196, 34, 'eom'

hdbscan_params = dict(
    min_cluster_size=min_cluster_size,
    min_samples=min_samples,
    metric='euclidean',
    cluster_selection_method=cluster_method,
    prediction_data=True,
    # one fewer worker than the number of CPUs reported for this machine
    core_dist_n_jobs=mp.cpu_count() - 1,
)
hdbscan_model = HDBSCAN(**hdbscan_params)

# CountVectorizer: n-gram range (1, 1), with NLTK's English stop words removed.
min_gram = max_gram = 1
english_stop_words = list(stopwords.words('english'))

cv_model = CountVectorizer(
    ngram_range=(min_gram, max_gram),
    stop_words=english_stop_words,
    min_df=1,
    max_df=0.95,
)

# MMR for diversity:
# `diversity` is forwarded to the MaximalMarginalRelevance representation model,
# which is passed to BERTopic as `representation_model` in the fitting cell.
diversity = 0.3
mmr_model = MaximalMarginalRelevance(diversity=diversity)

## Fit BERTopic Model
- Save model to `fit_bertopic_model` in the `models` directory.
- Get topics and save them to `fit_data.csv` in the `models` directory.

In [None]:
# Assemble BERTopic from the sub-models configured in the previous cell.
topic_model = BERTopic(
    umap_model=umap_model,
    hdbscan_model=hdbscan_model,
    vectorizer_model=cv_model,
    calculate_probabilities=True,
    top_n_words=20,
    representation_model=mmr_model
)

# Fit on the text representations, passing the precomputed averaged embeddings
# (stacked into one 2-D array) alongside the documents.
topics, probs = topic_model.fit_transform(
    agg_df['text_representation'].tolist(),
    np.array(agg_df['embedding'].tolist())
)

# save model:
MODELS_DIR = 'models'
MODEL_NAME = 'fit_bertopic_model'
MODEL_PATH = os.path.join(MODELS_DIR, MODEL_NAME)
os.makedirs(MODELS_DIR, exist_ok=True)

print(f'Saving model to {MODEL_PATH}')
topic_model.save(MODEL_PATH, serialization="safetensors", save_ctfidf=True, save_embedding_model=False)

# save topics:
# Persist the per-document topic assignment next to the model, as the section
# header describes. `topics` is positionally aligned with the rows of `agg_df`.
# NOTE(review): the column layout (full_id, source, topic) is an assumption —
# adjust if downstream consumers expect a different schema.
FIT_DATA_PATH = os.path.join(MODELS_DIR, 'fit_data.csv')
print(f'Saving topic assignments to {FIT_DATA_PATH}')
agg_df[['full_id', 'source']].assign(topic=topics).to_csv(FIT_DATA_PATH, index=False)

In [None]:
# Show the fitted model's topic overview table.
topic_model.get_topic_info()

`---Complete---`