In [97]:
import sys
sys.path.append("..")

from ris_evaluation.evaluator import Evaluator

from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer

import pandas as pd
import numpy as np

from bertopic import BERTopic

In [98]:
class Model:

    def __init__(self, documents, labels) -> None:
        self.documents = documents
        self.labels = labels

    def train(self):
        raise NotImplementedError
    
    def get_output(self):
        raise NotImplementedError
    
    def get_results_df(self):
        raise NotImplementedError
    
    def get_words_for_topics(self, topics):
        words_by_topics = {}
        for idx, topic in enumerate(topics):
            words = self.documents[idx].split(" ")

            if topic not in words_by_topics:
                words_by_topics[topic] = {}

            for word in words:
                if word not in words_by_topics[topic]:
                    words_by_topics[topic][word] = 0

                words_by_topics[topic][word] += 1

        return words_by_topics

In [99]:
class LDAModel(Model):

    def __init__(self, documents, labels) -> None:
        super().__init__(documents, labels)
        self.n_topics = len(set(labels))

    def train(self):
        self.vectorizer = CountVectorizer()
        X = self.vectorizer.fit_transform(self.documents)

        self.lda = LatentDirichletAllocation(n_components=self.n_topics, random_state=0)
        self.lda.fit(X)

    def get_results_df(self):
        results_df = pd.DataFrame()
        results_df['document'] = self.documents
        results_df['y_true'] = self.labels

        X = self.vectorizer.transform(self.documents)
        results_df['y_pred'] = self.lda.transform(X).argmax(axis=1)
        results_df['y_pred_highest_proba'] = self.lda.transform(X).max(axis=1)
        return results_df

    def get_output(self):
        topics = []
        for topic in self.lda.components_:
            topic_words = []
            for i in topic.argsort()[-10:]:
                topic_words.append(self.vectorizer.get_feature_names_out()[i])
            topics.append(topic_words)

        return {
            "topics": topics,
            "topic-document-matrix": None,
            "topic-word-matrix": None,
            "test-topic-document-matrix": None
        }

In [100]:
class BERTopicModel(Model):

    def __init__(self, documents, labels) -> None:
        super().__init__(documents, labels)
        self.n_topics = len(set(labels)) + 1  # +1 for outliers

    def train(self):
        self.bert_model = BERTopic(language="english", calculate_probabilities=True, nr_topics=self.n_topics)
        self.topics, self.probs = self.bert_model.fit_transform(self.documents)

    def get_results_df(self):
        results_df = pd.DataFrame()
        results_df['document'] = self.documents
        results_df['y_true'] = self.labels

        results_df['y_pred'] = self.topics
        results_df['y_pred_highest_proba'] = np.max(self.probs, axis=1)

        relevant_results_df = results_df[results_df['y_pred'] != -1]
        return relevant_results_df
    
    def get_output(self):
        return {
            "topics": [item for item in self.bert_model.get_topic_info()["Representation"]],
            "topic-document-matrix": self.probs.transpose(),
            "topic-word-matrix": self.bert_model.c_tf_idf_,
            "test-topic-document-matrix": self.probs.transpose()
        }

In [101]:
models = [
    BERTopicModel
]

dataset_names = ['BBC_News']

In [107]:
for dataset in dataset_names:
    # -- Get dataset
    documents_df = pd.read_csv(f'../datasets/data/{dataset}/documents.csv')
    documents_df = documents_df

    documents = documents_df['document'].tolist()
    labels = documents_df['class_name'].tolist()

    metrics_df = pd.DataFrame()

    for model in models:
        # -- Train model
        trained_model = model(documents, labels)
        trained_model.train()

        model_output = trained_model.get_output()

        # -- Evaluate model
        evaluator = Evaluator(model_output)
        results_df = trained_model.get_results_df()

        for idx, topic in enumerate(results_df['y_pred'].tolist()):
            words = documents_df.iloc[idx]['document'].split()
            print(words)
            words = documents[idx].split(" ")
            print(words)
        break

        words_by_extracted_topics = trained_model.get_words_for_topics(results_df['y_pred'].tolist())
        words_by_class = trained_model.get_words_for_topics(results_df['y_true'].tolist())

        coherence = evaluator.compute_coherence()
        diversity = evaluator.compute_diversity()
        supervised_correlation = evaluator.compute_supervised_correlation(words_by_extracted_topics, words_by_class)

        metrics_results = {}
        for coherence_type, coherence_value in coherence.items():
            metrics_results[f'coherence_{coherence_type}'] = coherence_value
        metrics_results['diversity'] = diversity
        metrics_results['supervised_correlation'] = supervised_correlation

        metrics_df = pd.concat([metrics_df, pd.DataFrame(metrics_results, index=[trained_model.__class__.__name__])])

    print(metrics_df)

['hit', 'shelf', 'combine', 'medium', 'player', 'phone', 'gaming', 'gadget', 'sale', 'price', 'handheld', 'device', 'debut', 'sale', 'week', 'catalogue', 'game', 'prepare', 'gadget', 'include', 'great', 'escape', 'conflict', 'back', 'gadget', 'face', 'stiff', 'competition', 'handheld', 'gaming', 'device', 'device', 'pack', 'lot', 'function', 'black', 'cover', 'aim', 'gamer', 'game', 'play', 'gadget', 'play', 'game', 'music', 'track', 'movie', 'store', 'digital', 'photo', 'mobile', 'phone', 'send', 'text', 'multimedia', 'mail', 'message', 'phone', 'service', 'enable', 'send', 'message', 'provide', 'pre', 'pay', 'vodafone', 'account', 'device', 'work', 'global', 'position', 'system', 'aid', 'support', 'variety', 'location', 'base', 'service', 'datum', 'system', 'multi', 'player', 'game', 'gadget', 'retail', 'partner', 'device', 'roll', 'impressive', 'list', 'function', 'face', 'competition', 'establish', 'mobile', 'game', 'main', 'competition', 'sale', 'price', 'cost', 'ready', 'pool', '