In [17]:
# Setting Arguments
args_text = '--base-model sentence-transformers/all-MiniLM-L6-v2 '+\
            '--dataset news --n-word 5000 --epochs-1 200 --epochs-2 50 ' + \
            '--bsz 32 --stage-2-lr 2e-2 --stage-2-repeat 5 --coeff-1-dist 50 '+ \
            '--n-cluster 50 ' + \
            '--stage-1-ckpt trained_model/news_model_all-MiniLM-L6-v2_stage1_50t_5000w_199e.ckpt'

In [18]:
import re
import os
import sys
import time
import copy
import math
import argparse
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtools.optim import RangerLars
import gensim.downloader
import itertools

from scipy.stats import ortho_group
from scipy.optimize import linear_sum_assignment as linear_assignment
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer

import numpy as np
from tqdm import tqdm_notebook
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.datasets import fetch_20newsgroups
from nltk.corpus import stopwords

from sklearn.feature_extraction.text import CountVectorizer
from utils import AverageMeter
from collections import OrderedDict

import pandas as pd
from sklearn.preprocessing import normalize
from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer
from gensim.corpora.dictionary import Dictionary
from pytorch_transformers import *
from sklearn.mixture import GaussianMixture
import scipy.stats
from sklearn.decomposition import PCA
from sklearn.cluster import OPTICS
from nltk.corpus import stopwords

from gensim.models.coherencemodel import CoherenceModel
from tqdm import tqdm
import scipy.sparse as sp
import nltk
from nltk.corpus import stopwords

from datetime import datetime
from itertools import combinations
import gensim.downloader
from scipy.linalg import qr
from data import *
from model import ContBertTopicExtractorAE
from evaluation import get_topic_qualities
import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "0" 

# Data loading

In [20]:
def _parse_args():
    parser = argparse.ArgumentParser(description='Contrastive topic modeling')
    parser.add_argument('--epochs-1', default=100, type=int,
                        help='Number of training epochs for Stage 1')
    parser.add_argument('--epochs-2', default=10, type=int,
                        help='Number of training epochs for Stage 2')
    parser.add_argument('--bsz', type=int, default=64,
                        help='Batch size')
    parser.add_argument('--dataset', default='news', type=str,
                        choices=['news', 'twitter', 'wiki', 'nips', 'stackoverflow', 'reuters', 'r52', 'imdb'],
                        help='Name of the dataset')
    parser.add_argument('--n-cluster', default=20, type=int,
                        help='Number of clusters')
    parser.add_argument('--n-topic', type=int,
                        help='Number of topics. If not specified, use same value as --n-cluster')
    parser.add_argument('--n-word', default=2000, type=int,
                        help='Number of words in vocabulary')
    
    parser.add_argument('--base-model', type=str,
                        help='Name of base model in huggingface library.')
    
    parser.add_argument('--gpus', default=[0,1], type=int, nargs='+',
                        help='List of GPU numbers to use. Use 0 by default')
    
    parser.add_argument('--coeff-1-sim', default=1.0, type=float,
                        help='Coefficient for NN dot product similarity loss (Phase 1)')
    parser.add_argument('--coeff-1-dist', default=1.0, type=float,
                        help='Coefficient for NN SWD distribution loss (Phase 1)')
    parser.add_argument('--dirichlet-alpha-1', type=float,
                        help='Parameter for Dirichlet distribution (Phase 1). Use 1/n_topic by default.')
    
    parser.add_argument('--stage-1-ckpt', type=str,
                        help='Name of torch checkpoint file Stage 1. If this argument is given, skip Stage 1.')
    
    parser.add_argument('--coeff-2-recon', default=1.0, type=float,
                        help='Coefficient for VAE reconstruction loss (Phase 2)')
    parser.add_argument('--coeff-2-regul', default=1.0, type=float,
                        help='Coefficient for VAE KLD regularization loss (Phase 2)')
    parser.add_argument('--coeff-2-cons', default=1.0, type=float,
                        help='Coefficient for CL consistency loss (Phase 2)')
    parser.add_argument('--coeff-2-dist', default=1.0, type=float,
                        help='Coefficient for CL SWD distribution matching loss (Phase 2)')
    parser.add_argument('--dirichlet-alpha-2', type=float,
                        help='Parameter for Dirichlet distribution (Phase 2). Use same value as dirichlet-alpha-1 by default.')
    
    parser.add_argument('--stage-2-lr', default=2e-1, type=float,
                        help='Learning rate of phase 2')
    parser.add_argument('--stage-2-repeat', default=5, type=int,
                        help='Repetition count of phase 2')
    
    parser.add_argument('--result-file', type=str,
                        help='File name for result summary')
    parser.add_argument('--palmetto-dir', type=str, default='./',
                        help='Directory where palmetto JAR and the Wikipedia index are. For evaluation')
    
    
    # Check if the code is run in Jupyter notebook
    is_in_jupyter = False
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            is_in_jupyter = True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            is_in_jupyter = False  # Terminal running IPython
        else:
            is_in_jupyter = False  # Other type (?)
    except NameError:
        is_in_jupyter = False
    
    if is_in_jupyter:
        return parser.parse_args(args=args_text.split())
    else:
        return parser.parse_args()

def data_load(dataset_name):
    should_measure_hungarian = False
    if dataset_name == 'news':
        textData = newsData()
        should_measure_hungarian = True
    elif dataset_name == 'imdb':
        textData = IMDBData()
    elif dataset_name == 'agnews':
        textData = AGNewsData()
    elif dataset_name == 'yahoo':
        textData = YahooData()
    elif dataset_name == 'twitter':
        textData = twitterData('/home/data/topicmodel/twitter_covid19.tsv')
    elif dataset_name == 'wiki':
        textData = wikiData('/home/data/topicmodel/smplAbstracts/')
    elif dataset_name == 'nips':
        textData = nipsData('/home/data/topicmodel/papers.csv')
    elif dataset_name == 'stackoverflow':
        textData = stackoverflowData('/home/data/topicmodel/stack_overflow.csv')
    elif dataset_name == 'reuters':
        textData = reutersData('/home/data/topicmodel/reuters-21578.txt')
    elif dataset_name == 'r52':
        textData = r52Data('/home/data/topicmodel/r52/')
        should_measure_hungarian = True
    return textData, should_measure_hungarian

In [21]:
args = _parse_args()
bsz = args.bsz
epochs_1 = args.epochs_1
epochs_2 = args.epochs_2

n_cluster = args.n_cluster
n_topic = args.n_topic if (args.n_topic is not None) else n_cluster
args.n_topic = n_topic

textData, should_measure_hungarian = data_load(args.dataset)

ema_alpha = 0.99
n_word = args.n_word
if args.dirichlet_alpha_1 is None:
    dirichlet_alpha_1 = 1 / n_cluster
else:
    dirichlet_alpha_1 = args.dirichlet_alpha_1
if args.dirichlet_alpha_2 is None:
    dirichlet_alpha_2 = dirichlet_alpha_1
else:
    dirichlet_alpha_2 = args.dirichlet_alpha_2
    
bert_name = args.base_model
bert_name_short = bert_name.split('/')[-1]
gpu_ids = args.gpus

skip_stage_1 = (args.stage_1_ckpt is not None)

In [22]:
trainds = BertDataset(bert=bert_name, text_list=textData.data, N_word=n_word, vectorizer=None, lemmatize=True)
basesim_path = f"./save/{args.dataset}_{bert_name_short}_basesim_matrix_full.pkl"
if os.path.isfile(basesim_path) == False:
    model = SentenceTransformer(bert_name.split('/')[-1], device='cuda')
    base_result_list = []
    for text in tqdm_notebook(trainds.nonempty_text):
        base_result_list.append(model.encode(text))
        
    base_result_embedding = np.stack(base_result_list)
    basereduced_norm = F.normalize(torch.tensor(base_result_embedding), dim=-1)
    basesim_matrix = torch.mm(basereduced_norm, basereduced_norm.t())
    ind = np.diag_indices(basesim_matrix.shape[0])
    basesim_matrix[ind[0], ind[1]] = torch.ones(basesim_matrix.shape[0]) * -1
    torch.save(basesim_matrix, basesim_path)
else:
    basesim_matrix = torch.load(basesim_path)

Downloading: 100%|██████████| 350/350 [00:00<00:00, 350kB/s]
Downloading: 100%|██████████| 226k/226k [00:00<00:00, 440kB/s] 
Downloading: 100%|██████████| 455k/455k [00:00<00:00, 591kB/s] 
Downloading: 100%|██████████| 112/112 [00:00<?, ?B/s] 
100%|██████████| 11314/11314 [00:09<00:00, 1233.27it/s]
Downloading .gitattributes: 100%|██████████| 1.23k/1.23k [00:00<00:00, 410kB/s]
Downloading 1_Pooling/config.json: 100%|██████████| 190/190 [00:00<00:00, 95.0kB/s]
Downloading README.md: 100%|██████████| 10.7k/10.7k [00:00<00:00, 3.55MB/s]
Downloading config.json: 100%|██████████| 612/612 [00:00<00:00, 306kB/s]
Downloading (…)ce_transformers.json: 100%|██████████| 116/116 [00:00<00:00, 58.0kB/s]
Downloading data_config.json: 100%|██████████| 39.3k/39.3k [00:00<00:00, 152kB/s]
Downloading model.safetensors: 100%|██████████| 90.9M/90.9M [00:00<00:00, 112MB/s]
Downloading pytorch_model.bin: 100%|██████████| 90.9M/90.9M [00:02<00:00, 33.5MB/s]
Downloading (…)nce_bert_config.json: 100%|██████████

  0%|          | 0/11314 [00:00<?, ?it/s]

# Step 1: Discovery neighborhood pairs and clustering

In [23]:
def dist_match_loss(hidden, alpha=1.0):
    device = hidden.device
    hidden_dim = hidden.shape[-1]
    rand_w = torch.Tensor(np.eye(hidden_dim, dtype=np.float64)).to(device)
    loss_dist_match = get_swd_loss(hidden, rand_w, alpha)
    return loss_dist_match


def get_swd_loss(states, rand_w, alpha=1.0):
    device = states.device
    states_shape = states.shape
    states = torch.matmul(states, rand_w)
    states_t, _ = torch.sort(states.t(), dim=1)

    states_prior = torch.Tensor(np.random.dirichlet([alpha]*states_shape[1], states_shape[0])).to(device) # (bsz, dim)
    states_prior = torch.matmul(states_prior, rand_w) # (dim, dim)
    states_prior_t, _ = torch.sort(states_prior.t(), dim=1) # (dim, bsz)
    return torch.mean(torch.sum((states_prior_t - states_t) ** 2, axis=0))

In [24]:
try:
    del model
except:
    pass
finally:
    torch.cuda.empty_cache()

In [25]:
if not skip_stage_1:
    model = ContBertTopicExtractorAE(N_topic=n_cluster, N_word=n_word, bert=bert_name, bert_dim=768)
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model, device_ids=gpu_ids)
    model.cuda(gpu_ids[0])

In [26]:
bsz = args.bsz = 128

In [27]:
if not skip_stage_1:
    losses = AverageMeter()
    closses = AverageMeter() 
    dlosses = AverageMeter() 
    rlosses = AverageMeter() 
    criterion = nn.CrossEntropyLoss()

    temp_basesim_matrix = copy.deepcopy(basesim_matrix)
    finetuneds = FinetuneDataset(trainds, temp_basesim_matrix, ratio=1, k=1)
    trainloader = DataLoader(finetuneds, batch_size=bsz, shuffle=True, num_workers=0)
    memoryloader = DataLoader(finetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)

    optimizer = RangerLars(model.parameters(), lr=0.001, weight_decay=0.0001)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    global_step = 0
    memory_queue = F.softmax(torch.randn(512, n_cluster).cuda(gpu_ids[0]), dim=1)
    for epoch in range(epochs_1):
        model.train()
        #ema_model.train()
        tbar = tqdm_notebook(trainloader)
        for batch_idx, batch in enumerate(tbar):       
            org_input, pos_input, _, _ = batch
            org_input_ids = org_input['input_ids'].cuda(gpu_ids[0])
            org_attention_mask = org_input['attention_mask'].cuda(gpu_ids[0])
            pos_input_ids = pos_input['input_ids'].cuda(gpu_ids[0])
            pos_attention_mask = pos_input['attention_mask'].cuda(gpu_ids[0])
            batch_size = org_input_ids.size(0)

            all_input_ids = torch.cat((org_input_ids, pos_input_ids), dim=0)
            all_attention_masks = torch.cat((org_attention_mask, pos_attention_mask), dim=0)
            all_topics, _ = model(all_input_ids, all_attention_masks, return_topic=True)

            orig_topic, pos_topic = torch.split(all_topics, len(all_topics) // 2)
            pos_sim = torch.sum(orig_topic * pos_topic, dim=-1)

            # consistency loss
            consistency_loss = -pos_sim.mean()

            # distribution matching loss
            memory_queue = torch.cat((memory_queue.detach(), all_topics), dim=0)[all_topics.size(0):]
            distmatch_loss = dist_match_loss(memory_queue, dirichlet_alpha_1)
            loss = args.coeff_1_sim * consistency_loss + \
                   10 * distmatch_loss

            losses.update(loss.item(), bsz)
            closses.update(consistency_loss.item(), bsz)
            dlosses.update(distmatch_loss.item(), bsz)

            tbar.set_description("Epoch-{} / consistency: {:.5f} - dist: {:.5f}".format(epoch, 
                                                                                        closses.avg, 
                                                                                        dlosses.avg), refresh=True)
            tbar.refresh()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()        
            global_step += 1
        scheduler.step()

In [28]:
if not skip_stage_1:
    model_stage1_name = f'./trained_model/{args.dataset}_model_{bert_name_short}_stage1_{args.n_topic}t_{args.n_word}w_{epoch}e.ckpt'
    torch.save(model.module.state_dict(), model_stage1_name)
else:
    model_stage1_name = args.stage_1_ckpt

# Stage 2: extract vocab set

### Load model

In [29]:
from word_embedding_utils import *

In [30]:
try:
    del model
except:
    pass
try:
    del ema_model
except:
    pass
torch.cuda.empty_cache()

In [31]:
model = ContBertTopicExtractorAE(N_topic=n_cluster, N_word=n_word, bert=bert_name, bert_dim=768)
model.cuda(gpu_ids[0])

model.load_state_dict(torch.load(model_stage1_name), strict=True)

Downloading: 100%|██████████| 612/612 [00:00<?, ?B/s] 
Downloading: 100%|██████████| 86.7M/86.7M [00:00<00:00, 93.5MB/s]


<All keys matched successfully>

In [32]:
temp_basesim_matrix = copy.deepcopy(basesim_matrix)
finetuneds = FinetuneDataset(trainds, temp_basesim_matrix, ratio=1, k=1)
memoryloader = DataLoader(finetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)
result_list = []
model.eval()
with torch.no_grad():
    for idx, batch in enumerate(tqdm_notebook(memoryloader)):        
        org_input, _, _, _ = batch
        org_input_ids = org_input['input_ids'].to(gpu_ids[0])
        org_attention_mask = org_input['attention_mask'].to(gpu_ids[0])
        topic, embed = model(org_input_ids, org_attention_mask, return_topic = True)
        result_list.append(topic)
result_embedding = torch.cat(result_list)
_, result_topic = torch.max(result_embedding, 1)

  0%|          | 0/45 [00:00<?, ?it/s]

In [33]:
d = {'text': trainds.preprocess_ctm(trainds.nonempty_text), 
     'cluster_label': result_topic.cpu().numpy()}
cluster_df = pd.DataFrame(data=d)

In [34]:
docs_per_class = cluster_df.groupby(['cluster_label'], as_index=False).agg({'text': ' '.join})

In [35]:
count_vectorizer = CountVectorizer(token_pattern=r'\b[a-zA-Z]{2,}\b')
ctfidf_vectorizer = CTFIDFVectorizer()
count = count_vectorizer.fit_transform(docs_per_class.text)
ctfidf = ctfidf_vectorizer.fit_transform(count, n_samples=len(cluster_df)).toarray()
words = count_vectorizer.get_feature_names()

In [36]:
# transport to gensim
(gensim_corpus, gensim_dict) = vect2gensim(count_vectorizer, count)
vocab_list = set(gensim_dict.token2id.keys())
stopwords = set(line.strip() for line in open('stopwords_en.txt'))

In [37]:
normalized = [coherence_normalize(doc) for doc in trainds.nonempty_text]
gensim_dict = Dictionary(normalized)
resolution_score = (ctfidf - np.min(ctfidf, axis=1, keepdims=True)) / (np.max(ctfidf, axis=1, keepdims=True) - np.min(ctfidf, axis=1, keepdims=True))

n_word = args.n_word
# n_topic_word = n_word / len(docs_per_class.cluster_label.index)
n_topic_word = n_word
n_topic_word = 10

topic_word_dict = {}
for label in docs_per_class.cluster_label.index:
    total_score = resolution_score[label]
    score_higest = total_score.argsort()
    score_higest = score_higest[::-1]
    topic_word_list = [words[index] for index in score_higest]
    
    topic_word_list = [word for word in topic_word_list if len(word) >= 3]    
    topic_word_list = [word for word in topic_word_list if word not in stopwords]    
    topic_word_list = [word for word in topic_word_list if word in gensim_dict.token2id]
    topic_word_dict[docs_per_class.cluster_label.iloc[label]] = topic_word_list[:int(n_topic_word)]

In [38]:
for key in topic_word_dict:
    print(f"{key}: {topic_word_dict[key]},")
topic_words_list = list(topic_word_dict.values())

0: ['car', 'dealer', 'saturn', 'price', 'toyota', 'model', 'door', 'profit', 'sport', 'engine'],
1: ['ticket', 'battery', 'launch', 'lib', 'rocket', 'pat', 'doug', 'flight', 'cost', 'exploration'],
2: ['msg', 'food', 'sensitivity', 'chinese', 'superstition', 'reaction', 'circuit', 'diet', 'honda', 'taste'],
3: ['bus', 'card', 'ram', 'simms', 'motherboard', 'mac', 'slot', 'board', 'bit', 'memory'],
4: ['jesus', 'god', 'christian', 'christ', 'bible', 'church', 'life', 'faith', 'father', 'sin'],
5: ['scsi', 'drive', 'ide', 'controller', 'mac', 'tape', 'device', 'bus', 'hard', 'quadra'],
6: ['gun', 'firearm', 'handgun', 'weapon', 'criminal', 'amendment', 'control', 'crime', 'nra', 'safety'],
7: ['israel', 'israeli', 'greek', 'arab', 'jew', 'war', 'palestinian', 'lebanese', 'turkish', 'lebanon'],
8: ['israeli', 'israel', 'arab', 'jew', 'palestinian', 'policy', 'palestine', 'jewish', 'gaza', 'jerusalem'],
9: ['image', 'file', 'gif', 'format', 'picture', 'photography', 'bmp', 'bitmap', 'conve

In [41]:
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
from nltk.corpus import stopwords as stop_words
from gensim.utils import deaccent


class WhiteSpacePreprocessing():
    def __init__(self, documents, stopwords_language="english", vocabulary_size=2000):
        self.documents = documents
        self.stopwords = set(stop_words.words(stopwords_language))
        self.vocabulary_size = vocabulary_size

        warnings.simplefilter('always', DeprecationWarning)
        warnings.warn("WhiteSpacePreprocessing is deprecated and will be removed in future versions."
                      "Use WhiteSpacePreprocessingStopwords.")

    def preprocess(self):
        preprocessed_docs_tmp = self.documents
        preprocessed_docs_tmp = [deaccent(doc.lower()) for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [doc.translate(
            str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [' '.join([w for w in doc.split() if len(w) > 0 and w not in self.stopwords])
                                 for doc in preprocessed_docs_tmp]

        vectorizer = CountVectorizer(max_features=self.vocabulary_size)
        vectorizer.fit_transform(preprocessed_docs_tmp)
        temp_vocabulary = set(vectorizer.get_feature_names())

        preprocessed_docs_tmp = [' '.join([w for w in doc.split() if w in temp_vocabulary])
                                 for doc in preprocessed_docs_tmp]

        preprocessed_docs, unpreprocessed_docs, retained_indices = [], [], []
        for i, doc in enumerate(preprocessed_docs_tmp):
            if len(doc) > 0:
                preprocessed_docs.append(doc)
                unpreprocessed_docs.append(self.documents[i])
                retained_indices.append(i)

        vocabulary = list(set([item for doc in preprocessed_docs for item in doc.split()]))

        return preprocessed_docs, unpreprocessed_docs, vocabulary, retained_indices
    
def _hungarian_match(flat_preds, flat_targets, num_samples, class_num):  
    num_k = class_num
    num_correct = np.zeros((num_k, num_k))
  
    for c1 in range(0, num_k):
        for c2 in range(0, num_k):
            votes = int(((flat_preds == c1) * (flat_targets == c2)).sum())
            num_correct[c1, c2] = votes
  
    match = linear_assignment(num_samples - num_correct)
    match = np.array(list(zip(*match)))
    res = []
    for out_c, gt_c in match:
        res.append((out_c, gt_c))
  
    return res

def get_document_topic(topic_words, preprocessed_documents_lemmatized):
    topic_words_flatten = list(itertools.chain.from_iterable(topic_words))
    if '' in topic_words_flatten:
        topic_words_flatten.remove('')
    topic_words_flatten = list(set(topic_words_flatten))
    
    vectorizer = CountVectorizer(vocabulary = topic_words_flatten)
    vectorizer = vectorizer.fit(preprocessed_documents_lemmatized)
    count_mat = vectorizer.transform(preprocessed_documents_lemmatized).toarray()
    
    count_mat_normalized = count_mat + 1e-4
    count_mat_normalized = count_mat_normalized / count_mat_normalized.sum(axis=1).reshape(-1, 1)
    
    topic_mat = vectorizer.transform([' '.join(i) for i in topic_words]).toarray()
    topic_mat_normalized = topic_mat + 1e-4
    topic_mat_normalized = topic_mat_normalized / topic_mat_normalized.sum(axis=1).reshape(-1, 1)
    
    topic_mat_inverse = topic_mat_normalized @ topic_mat_normalized.transpose()
    topic_mat_inverse = np.linalg.inv(topic_mat_inverse)
    topic_mat_inverse = topic_mat_normalized.transpose() @ topic_mat_inverse
    document_topic = count_mat_normalized @ topic_mat_inverse
    return document_topic

class TopicModelDataPreparationNoNumber(TopicModelDataPreparation):
    def fit(self, text_for_contextual, text_for_bow, labels=None, wordlist=None):
        """
        This method fits the vectorizer and gets the embeddings from the contextual model
        :param text_for_contextual: list of unpreprocessed documents to generate the contextualized embeddings
        :param text_for_bow: list of preprocessed documents for creating the bag-of-words
        :param labels: list of labels associated with each document (optional).
        """

        if self.contextualized_model is None:
            raise Exception("You should define a contextualized model if you want to create the embeddings")

        # TODO: this count vectorizer removes tokens that have len = 1, might be unexpected for the users
        self.vectorizer = CountVectorizer(token_pattern=r'\b[a-zA-Z]{2,}\b', vocabulary=wordlist)

        train_bow_embeddings = self.vectorizer.fit_transform(text_for_bow)
        train_contextualized_embeddings = bert_embeddings_from_list(text_for_contextual, self.contextualized_model)
        self.vocab = self.vectorizer.get_feature_names()
        self.id2token = {k: v for k, v in zip(range(0, len(self.vocab)), self.vocab)}

        if labels:
            self.label_encoder = OneHotEncoder()
            encoded_labels = self.label_encoder.fit_transform(np.array([labels]).reshape(-1, 1))
        else:
            encoded_labels = None

        return CTMDataset(train_contextualized_embeddings, train_bow_embeddings, self.id2token, encoded_labels)
    

topic_words_list = list(topic_word_dict.values())
qt = TopicModelDataPreparationNoNumber("sentence-transformers/all-MiniLM-L6-v2")
sp = WhiteSpacePreprocessing(textData.data, stopwords_language='english')
preprocessed_documents, unpreprocessed_corpus, vocab, retained_indices = sp.preprocess()
vectorizer_model = CountVectorizer(stop_words="english")
lemmatizer = WordNetLemmatizer()
preprocessed_documents_lemmatized = [' '.join([lemmatizer.lemmatize(w) for w in doc.split()]) for doc in preprocessed_documents]

document_topic = get_document_topic(topic_words_list, preprocessed_documents_lemmatized)
train_target_filtered = textData.targets.squeeze()[retained_indices]
flat_predict = torch.tensor(np.argmax(document_topic, axis=1))
flat_target = torch.tensor(train_target_filtered).to(flat_predict.device)
num_samples = flat_predict.shape[0]
match = _hungarian_match(flat_predict, flat_target, num_samples, 20)    
reordered_preds = torch.zeros(num_samples).to(flat_predict.device)
for pred_i, target_i in match:
    reordered_preds[flat_predict == pred_i] = int(target_i)
acc = int((reordered_preds == flat_target.float()).sum()) / float(num_samples)
print(acc)

0.17853986211773024


In [42]:
from sklearn.metrics.cluster import normalized_mutual_info_score

In [43]:
normalized_mutual_info_score(reordered_preds, flat_target)

0.2043057067221093

In [45]:
now = datetime.now().strftime('%y%m%d_%H%M%S')
results = get_topic_qualities(topic_words_list, args.palmetto_dir, reference_corpus=[doc.split() for doc in trainds.preprocess_ctm(trainds.nonempty_text)],
                              filename=f'results/{now}.txt')
results

./
[0.38297, 0.37417, 0.45394, 0.4029, 0.52711, 0.46109, 0.45418, 0.55963, 0.70803, 0.46116, 0.45936, 0.44008, 0.35423, 0.41555, 0.56226, 0.43982, 0.49678, 0.46482, 0.48524, 0.42333, 0.4014, 0.42947, 0.44147, 0.48501, 0.60479, 0.3648, 0.53699, 0.49483, 0.36144, 0.46488, 0.37283, 0.50334, 0.47903, 0.51682, 0.49874, 0.46293, 0.4598, 0.4054, 0.46753, 0.51152, 0.3641, 0.52537, 0.5616, 0.53919, 0.3365, 0.39359, 0.59823, 0.45953, 0.47088, 0.51779]
0.4671290000000002


{'topic_N': 50,
 'CV_wiki': 0.4671290000000002,
 'sim_w2v': 0.2019678517246961,
 'diversity': 0.75,
 'filename': 'results/240509_233632.txt'}

In [46]:
normalized = [coherence_normalize(doc) for doc in trainds.nonempty_text]
gensim_dict = Dictionary(normalized)

n_word = args.n_word
n_topic_word = n_word

words_to_idx = {k: v for v, k in enumerate(words)}
topic_word_dict = {}
topic_score_dict = {}
total_score_cat = []
for label in docs_per_class.cluster_label.index:
    total_score = resolution_score[label]
    score_higest = total_score.argsort()
    score_higest = score_higest[::-1]
    topic_word_list = [words[index] for index in score_higest]
    
    total_score_cat.append(total_score)
    topic_word_list = [word for word in topic_word_list if word not in stopwords]    
    topic_word_list = [word for word in topic_word_list if word in gensim_dict.token2id]
    topic_word_list = [word for word in topic_word_list if len(word) >= 3]    
    topic_word_dict[docs_per_class.cluster_label.iloc[label]] = topic_word_list[:int(n_topic_word)]
    topic_score_dict[docs_per_class.cluster_label.iloc[label]] = [total_score[words_to_idx[top_word]] for top_word in topic_word_list[:int(n_topic_word)]]
total_score_cat = np.stack(total_score_cat, axis = 0)

In [47]:
def remove_dup(seq):
    seen = set()
    seen_add = seen.add
    return [x for x in seq if not (x in seen or seen_add(x))]

topic_words_list = list(topic_word_dict.values())
topic_word_set = list(itertools.chain.from_iterable(pd.DataFrame.from_dict(topic_word_dict).values))
word_candidates = remove_dup(topic_word_set)[:n_word]
n_word = len(word_candidates)
n_word

3942

In [48]:
import pickle
with open('our_word_candidates_10000.pkl', 'wb') as f:
    pickle.dump(word_candidates, f)

In [49]:
weight_candidates = {}
for candidate in word_candidates:
    weight_candidates[candidate] = [total_score_cat[label, words_to_idx[candidate]] for label in range(n_cluster)]

In [50]:
weight_cand_to_idx = {k: v for v, k in enumerate(list(weight_candidates.keys()))}
weight_cand_matrix = np.array(list(weight_candidates.values()))

# Re-formulate the bow

In [51]:
def dist_match_loss(hiddens, alpha=1.0):
    device = hiddens.device
    hidden_dim = hiddens.shape[-1]
    H = np.random.randn(hidden_dim, hidden_dim)
    Q, R = qr(H) 
    rand_w = torch.Tensor(Q).to(device)
    loss_dist_match = get_swd_loss(hiddens, rand_w, alpha)
    return loss_dist_match


def js_div_loss(hidden1, hidden2):
    m = 0.5 * (hidden1 + hidden2)
    return kldiv(m.log(), hidden1) + kldiv(m.log(), hidden2)


def get_swd_loss(states, rand_w, alpha=1.0):
    device = states.device
    states_shape = states.shape
    states = torch.matmul(states, rand_w)
    states_t, _ = torch.sort(states.t(), dim=1)

    # Random vector with length from normal distribution
    states_prior = torch.Tensor(np.random.dirichlet([alpha]*states_shape[1], states_shape[0])).to(device) # (bsz, dim)
    states_prior = torch.matmul(states_prior, rand_w) # (dim, dim)
    states_prior_t, _ = torch.sort(states_prior.t(), dim=1) # (dim, bsz)
    return torch.mean(torch.sum((states_prior_t - states_t)**2, axis=0))

In [52]:
class Stage2Dataset(Dataset):
    def __init__(self, encoder, ds, basesim_matrix, word_candidates, k=1, lemmatize=False):
        self.lemmatize = lemmatize
        self.ds = ds
        self.org_list = self.ds.org_list
        self.nonempty_text = self.ds.nonempty_text
        english_stopwords = nltk.corpus.stopwords.words('english')
        self.stopwords_list = set(english_stopwords)
        self.vectorizer = CountVectorizer(vocabulary=word_candidates)
        self.vectorizer.fit(self.preprocess_ctm(self.nonempty_text)) 
        self.bow_list = []
        for sent in tqdm(self.nonempty_text):
            self.bow_list.append(self.vectorize(sent))
            
        sim_weight, sim_indices = basesim_matrix.topk(k=k, dim=-1)
        zip_iterator = zip(np.arange(len(sim_weight)), sim_indices.squeeze().data.numpy())
        self.pos_dict = dict(zip_iterator)
        
        self.embedding_list = []
        encoder_device = next(encoder.parameters()).device
        for org_input in tqdm(self.org_list):
            org_input_ids = org_input['input_ids'].to(encoder_device).reshape(1, -1)
            org_attention_mask = org_input['attention_mask'].to(encoder_device).reshape(1, -1)
            embedding = encoder(input_ids = org_input_ids, attention_mask = org_attention_mask)
            self.embedding_list.append(embedding['pooler_output'].squeeze().detach().cpu())
            
    
    def __len__(self):
        return len(self.org_list)
        
    def preprocess_ctm(self, documents):
        preprocessed_docs_tmp = documents
        preprocessed_docs_tmp = [doc.lower() for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [doc.translate(
            str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [' '.join([w for w in doc.split() if len(w) > 0 and w not in self.stopwords_list])
                                 for doc in preprocessed_docs_tmp]
        if self.lemmatize:
            lemmatizer = WordNetLemmatizer()
            preprocessed_docs_tmp = [' '.join([lemmatizer.lemmatize(w) for w in doc.split()])
                                     for doc in preprocessed_docs_tmp]
        return preprocessed_docs_tmp
        
    def vectorize(self, text):
        text = self.preprocess_ctm([text])
        vectorized_input = self.vectorizer.transform(text)
        vectorized_input = vectorized_input.toarray().astype(np.float64)
#         vectorized_input = (vectorized_input != 0).astype(np.float64)

        # Get word distribution from BoW
        if vectorized_input.sum() == 0:
            vectorized_input += 1e-8
        vectorized_input = vectorized_input / vectorized_input.sum(axis=1, keepdims=True)
        assert abs(vectorized_input.sum() - vectorized_input.shape[0]) < 0.01
        
        vectorized_label = torch.tensor(vectorized_input, dtype=torch.float)
        return vectorized_label[0]
        
        
    def __getitem__(self, idx):
        pos_idx = self.pos_dict[idx]
        return self.embedding_list[idx], self.embedding_list[pos_idx], self.bow_list[idx], self.bow_list[pos_idx]

In [61]:
class Stage2TestDataset(Dataset):
    def __init__(self, encoder, ds, word_candidates, k=1, lemmatize=False):
        self.lemmatize = lemmatize
        self.ds = ds
        self.org_list = self.ds.org_list
        self.nonempty_text = self.ds.nonempty_text
        english_stopwords = nltk.corpus.stopwords.words('english')
        self.stopwords_list = set(english_stopwords)
        self.vectorizer = CountVectorizer(vocabulary=word_candidates)
        self.vectorizer.fit(self.preprocess_ctm(self.nonempty_text)) 
        self.bow_list = []
        for sent in tqdm(self.nonempty_text):
            self.bow_list.append(self.vectorize(sent))
        
        self.embedding_list = []
        encoder_device = next(encoder.parameters()).device
        for org_input in tqdm(self.org_list):
            org_input_ids = org_input['input_ids'].to(encoder_device).reshape(1, -1)
            org_attention_mask = org_input['attention_mask'].to(encoder_device).reshape(1, -1)
            embedding = encoder(input_ids = org_input_ids, attention_mask = org_attention_mask)
            self.embedding_list.append(embedding['pooler_output'].squeeze().detach().cpu())
            
    
    def __len__(self):
        return len(self.org_list)
        
    def preprocess_ctm(self, documents):
        preprocessed_docs_tmp = documents
        preprocessed_docs_tmp = [doc.lower() for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [doc.translate(
            str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [' '.join([w for w in doc.split() if len(w) > 0 and w not in self.stopwords_list])
                                 for doc in preprocessed_docs_tmp]
        if self.lemmatize:
            lemmatizer = WordNetLemmatizer()
            preprocessed_docs_tmp = [' '.join([lemmatizer.lemmatize(w) for w in doc.split()])
                                     for doc in preprocessed_docs_tmp]
        return preprocessed_docs_tmp
        
    def vectorize(self, text):
        text = self.preprocess_ctm([text])
        vectorized_input = self.vectorizer.transform(text)
        vectorized_input = vectorized_input.toarray().astype(np.float64)
#         vectorized_input = (vectorized_input != 0).astype(np.float64)

        # Get word distribution from BoW
        if vectorized_input.sum() == 0:
            vectorized_input += 1e-8
        vectorized_input = vectorized_input / vectorized_input.sum(axis=1, keepdims=True)
        assert abs(vectorized_input.sum() - vectorized_input.shape[0]) < 0.01
        
        vectorized_label = torch.tensor(vectorized_input, dtype=torch.float)
        return vectorized_label[0]
        
        
    def __getitem__(self, idx):
        return self.embedding_list[idx], self.bow_list[idx]

In [62]:
finetuneds = Stage2Dataset(model.encoder, trainds, basesim_matrix, word_candidates, lemmatize=True)    

kldiv = torch.nn.KLDivLoss(reduction='batchmean')
vocab_dict = finetuneds.vectorizer.vocabulary_
vocab_dict_reverse = {i:v for v, i in vocab_dict.items()}
print(n_word)

100%|██████████| 11314/11314 [00:04<00:00, 2527.14it/s]
100%|██████████| 11314/11314 [01:15<00:00, 149.60it/s]

3942





# Stage 3

In [63]:
def measure_hungarian_score(topic_dist, train_target):
    dist = topic_dist
    train_target_filtered = train_target
    flat_predict = torch.tensor(np.argmax(dist, axis=1))
    flat_target = torch.tensor(train_target_filtered).to(flat_predict.device)
    num_samples = flat_predict.shape[0]
    num_classes = dist.shape[1]
    match = _hungarian_match(flat_predict, flat_target, num_samples, num_classes)    
    reordered_preds = torch.zeros(num_samples).to(flat_predict.device)
    for pred_i, target_i in match:
        reordered_preds[flat_predict == pred_i] = int(target_i)
    acc = int((reordered_preds == flat_target.float()).sum()) / float(num_samples)
    return acc

In [64]:
weight_cands = torch.tensor(weight_cand_matrix.max(axis=1)).cuda(gpu_ids[0]).float()

# Main

In [65]:
testds = BertDataset(bert=bert_name, text_list=textData.test_data, N_word=n_word, vectorizer=None, lemmatize=True)
testds2 = Stage2TestDataset(model.encoder, testds, word_candidates, lemmatize=True)


100%|██████████| 7532/7532 [00:06<00:00, 1216.02it/s]
100%|██████████| 7532/7532 [00:02<00:00, 2638.18it/s]
100%|██████████| 7532/7532 [00:49<00:00, 150.99it/s]


In [66]:
from evaluation import evaluate_classification, evaluate_clustering

results_list = []

for i in range(args.stage_2_repeat):
    model = ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)
    model.load_state_dict(torch.load(model_stage1_name), strict=True)
    model.beta = nn.Parameter(torch.Tensor(model.N_topic, n_word))
    nn.init.xavier_uniform_(model.beta)
    model.beta_batchnorm = nn.Sequential()
    model.cuda(gpu_ids[0])
    
    losses = AverageMeter()
    dlosses = AverageMeter() 
    rlosses = AverageMeter()
    closses = AverageMeter()
    distlosses = AverageMeter()
    trainloader = DataLoader(finetuneds, batch_size=bsz, shuffle=True, num_workers=0)
    testloader = DataLoader(testds2, batch_size=bsz, shuffle=False, num_workers=0)
    memoryloader = DataLoader(finetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.stage_2_lr)

    memory_queue = F.softmax(torch.randn(512, n_topic).cuda(gpu_ids[0]), dim=1)
    print("Coeff   / regul: {:.5f} - recon: {:.5f} - c: {:.5f} - dist: {:.5f} ".format(args.coeff_2_regul, 
                                                                                        args.coeff_2_recon,
                                                                                        args.coeff_2_cons,
                                                                                        args.coeff_2_dist))
    for epoch in range(50):
        model.train()
        model.encoder.eval()
        for batch_idx, batch in enumerate(trainloader):
            org_input, pos_input, org_bow, pos_bow = batch
            org_input = org_input.cuda(gpu_ids[0])
            org_bow = org_bow.cuda(gpu_ids[0])
            pos_input = pos_input.cuda(gpu_ids[0])
            pos_bow = pos_bow.cuda(gpu_ids[0])

            batch_size = org_input_ids.size(0)

            org_dists, org_topic_logit = model.decode(org_input)
            pos_dists, pos_topic_logit = model.decode(pos_input)

            org_topic = F.softmax(org_topic_logit, dim=1)
            pos_topic = F.softmax(pos_topic_logit, dim=1)

            # reconstruction loss
            # batchmean
#             org_target = torch.matmul(org_topic.detach(), weight_cands)
#             pos_target = torch.matmul(pos_topic.detach(), weight_cands)
            
#             _, org_target = torch.max(org_topic.detach(), 1)
#             _, pos_target = torch.max(pos_topic.detach(), 1)
            
            recons_loss = torch.mean(-torch.sum(torch.log(org_dists + 1E-10) * (org_bow * weight_cands), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log((1-org_dists) + 1E-10) * ((1-org_bow) * weight_cands), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log(pos_dists + 1E-10) * (pos_bow * weight_cands), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log((1-pos_dists) + 1E-10) * ((1-pos_bow) * weight_cands), axis=1), axis=0)
            recons_loss *= 0.5

            # consistency loss
            pos_sim = torch.sum(org_topic * pos_topic, dim=-1)
            cons_loss = -pos_sim.mean()

            # distribution loss
            # batchmean
            distmatch_loss = dist_match_loss(torch.cat((org_topic, pos_topic), dim=0), dirichlet_alpha_2)
            

            loss = args.coeff_2_recon * recons_loss + \
                   args.coeff_2_cons * cons_loss + \
                   args.coeff_2_dist * distmatch_loss 
            
            losses.update(loss.item(), bsz)
            closses.update(cons_loss.item(), bsz)
            rlosses.update(recons_loss.item(), bsz)
            distlosses.update(distmatch_loss.item(), bsz)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        print("Epoch-{} / recon: {:.5f} - dist: {:.5f} - cons: {:.5f}".format(epoch, rlosses.avg, distlosses.avg, closses.avg))

    print("------- Evaluation results -------")
    all_list = {}
    for e, i in enumerate(model.beta.cpu().topk(15, dim=1).indices):
        word_list = []
        for j in i:
            word_list.append(vocab_dict_reverse[j.item()])
        all_list[e] = word_list
        print("topic-{}".format(e), word_list)

    topic_words_list = list(all_list.values())
    now = datetime.now().strftime('%y%m%d_%H%M%S')
    results = get_topic_qualities(topic_words_list, palmetto_dir=args.palmetto_dir,
                                  reference_corpus=[doc.split() for doc in trainds.preprocess_ctm(trainds.nonempty_text)],
                                  filename=f'results/{now}.txt')
    train_theta = []
    test_theta = []
    for batch_idx, batch in tqdm(enumerate(trainloader)):
        org_input, _, org_bow, _ = batch
        org_input = org_input.cuda(gpu_ids[0])
        org_bow = org_bow.cuda(gpu_ids[0])
        # pos_input = pos_input.cuda(gpu_ids[0])
        # pos_bow = pos_bow.cuda(gpu_ids[0])

        batch_size = org_input_ids.size(0)

        org_dists, org_topic_logit = model.decode(org_input)
        # pos_dists, pos_topic_logit = model.decode(pos_input)

        org_topic = F.softmax(org_topic_logit, dim=1)
        # pos_topic = F.softmax(pos_topic_logit, dim=1)
        
        train_theta.append(org_topic.detach().cpu())
    
    train_theta = np.concatenate(train_theta, axis=0)

    for batch_idx, batch in tqdm(enumerate(testloader)): 
        org_input, org_bow = batch
        org_input = org_input.cuda(gpu_ids[0])
        org_bow = org_bow.cuda(gpu_ids[0])
        # pos_input = pos_input.cuda(gpu_ids[0])
        # pos_bow = pos_bow.cuda(gpu_ids[0])

        batch_size = org_input_ids.size(0)

        org_dists, org_topic_logit = model.decode(org_input)
        # pos_dists, pos_topic_logit = model.decode(pos_input)

        org_topic = F.softmax(org_topic_logit, dim=1)
        # pos_topic = F.softmax(pos_topic_logit, dim=1)
        
        test_theta.append(org_topic.detach().cpu())
    
    test_theta = np.concatenate(test_theta, axis=0)
    
    classification_res = evaluate_classification(train_theta, test_theta, textData.targets, textData.test_targets)
    clustering_res = evaluate_clustering(test_theta, textData.test_targets)
    
    results.update(classification_res)
    results.update(clustering_res)
    
    
    if should_measure_hungarian:
        topic_dist = torch.empty((0, n_topic))
        model.eval()
        evalloader = DataLoader(finetuneds, batch_size=bsz, shuffle=False, num_workers=0)
        non_empty_text_index = [i for i, text in enumerate(textData.data) if len(text) != 0]
        assert len(finetuneds) == len(non_empty_text_index)
        with torch.no_grad():
            for batch in tqdm(evalloader):
                org_input, _, org_bow, __ = batch
                org_input = org_input.cuda(gpu_ids[0])
                org_dists, org_topic_logit = model.decode(org_input)
                org_topic = F.softmax(org_topic_logit, dim=1)
                topic_dist = torch.cat((topic_dist, org_topic.detach().cpu()), 0)
        label_accuracy = measure_hungarian_score(
                             topic_dist,
                             [target for i, target in enumerate(textData.targets)
                              if i in non_empty_text_index]
                         )
        results['label_match'] = label_accuracy

    print(results)
    print()
    results_list.append(results)

Coeff   / regul: 1.00000 - recon: 1.00000 - c: 1.00000 - dist: 1.00000 
Epoch-0 / recon: 2.86379 - dist: 0.08176 - cons: -0.28128
Epoch-1 / recon: 2.75634 - dist: 0.06411 - cons: -0.32115
Epoch-2 / recon: 2.70017 - dist: 0.05786 - cons: -0.33636
Epoch-3 / recon: 2.66723 - dist: 0.05453 - cons: -0.34645
Epoch-4 / recon: 2.64549 - dist: 0.05213 - cons: -0.35427
Epoch-5 / recon: 2.63038 - dist: 0.05052 - cons: -0.35963
Epoch-6 / recon: 2.61908 - dist: 0.04932 - cons: -0.36352
Epoch-7 / recon: 2.61047 - dist: 0.04851 - cons: -0.36641
Epoch-8 / recon: 2.60358 - dist: 0.04786 - cons: -0.36856
Epoch-9 / recon: 2.59796 - dist: 0.04735 - cons: -0.37056
Epoch-10 / recon: 2.59331 - dist: 0.04692 - cons: -0.37222
Epoch-11 / recon: 2.58939 - dist: 0.04643 - cons: -0.37351
Epoch-12 / recon: 2.58614 - dist: 0.04608 - cons: -0.37446
Epoch-13 / recon: 2.58316 - dist: 0.04568 - cons: -0.37548
Epoch-14 / recon: 2.58054 - dist: 0.04528 - cons: -0.37639
Epoch-15 / recon: 2.57820 - dist: 0.04493 - cons: -0.

89it [00:00, 383.62it/s]
59it [00:00, 536.36it/s]
100%|██████████| 89/89 [00:00<00:00, 408.26it/s]


{'topic_N': 50, 'CV_wiki': 0.37257380000000007, 'sim_w2v': 0.11059587169815185, 'diversity': 0.31066666666666665, 'filename': 'results/240509_235251.txt', 'acc': 0.07448220924057355, 'macro-F1': 0.0638077114371968, 'Purity': 0.49880509824747743, 'NMI': 0.4707201633301822, 'label_match': 0.4344175357963585}

Coeff   / regul: 1.00000 - recon: 1.00000 - c: 1.00000 - dist: 1.00000 
Epoch-0 / recon: 2.86838 - dist: 0.10385 - cons: -0.24990
Epoch-1 / recon: 2.76356 - dist: 0.08239 - cons: -0.28325
Epoch-2 / recon: 2.70622 - dist: 0.07164 - cons: -0.30221
Epoch-3 / recon: 2.67133 - dist: 0.06538 - cons: -0.31726
Epoch-4 / recon: 2.64843 - dist: 0.06142 - cons: -0.32773
Epoch-5 / recon: 2.63246 - dist: 0.05876 - cons: -0.33427
Epoch-6 / recon: 2.62044 - dist: 0.05669 - cons: -0.33964
Epoch-7 / recon: 2.61118 - dist: 0.05496 - cons: -0.34379
Epoch-8 / recon: 2.60384 - dist: 0.05372 - cons: -0.34672
Epoch-9 / recon: 2.59788 - dist: 0.05267 - cons: -0.34953
Epoch-10 / recon: 2.59298 - dist: 0.051

89it [00:00, 357.43it/s]
59it [00:00, 536.36it/s]
100%|██████████| 89/89 [00:00<00:00, 412.04it/s]


{'topic_N': 50, 'CV_wiki': 0.37163199999999996, 'sim_w2v': 0.12865569502834454, 'diversity': 0.32666666666666666, 'filename': 'results/240509_235541.txt', 'acc': 0.07859798194370685, 'macro-F1': 0.057332348401032615, 'Purity': 0.5102230483271375, 'NMI': 0.4838984176956681, 'label_match': 0.43980908608803254}

Coeff   / regul: 1.00000 - recon: 1.00000 - c: 1.00000 - dist: 1.00000 
Epoch-0 / recon: 2.86445 - dist: 0.07908 - cons: -0.29007
Epoch-1 / recon: 2.75859 - dist: 0.06574 - cons: -0.32267
Epoch-2 / recon: 2.70197 - dist: 0.06057 - cons: -0.33736
Epoch-3 / recon: 2.66868 - dist: 0.05741 - cons: -0.34523
Epoch-4 / recon: 2.64704 - dist: 0.05537 - cons: -0.35014
Epoch-5 / recon: 2.63193 - dist: 0.05404 - cons: -0.35343
Epoch-6 / recon: 2.62052 - dist: 0.05290 - cons: -0.35606
Epoch-7 / recon: 2.61158 - dist: 0.05212 - cons: -0.35827
Epoch-8 / recon: 2.60460 - dist: 0.05148 - cons: -0.35965
Epoch-9 / recon: 2.59885 - dist: 0.05108 - cons: -0.36048
Epoch-10 / recon: 2.59400 - dist: 0.0

89it [00:00, 193.90it/s]
59it [00:00, 287.81it/s]
100%|██████████| 89/89 [00:00<00:00, 284.35it/s]


{'topic_N': 50, 'CV_wiki': 0.3755574000000001, 'sim_w2v': 0.13726825836771653, 'diversity': 0.344, 'filename': 'results/240509_235837.txt', 'acc': 0.04580456718003186, 'macro-F1': 0.03911680615461721, 'Purity': 0.5132766861391397, 'NMI': 0.4881399747760182, 'label_match': 0.46057981262153086}

Coeff   / regul: 1.00000 - recon: 1.00000 - c: 1.00000 - dist: 1.00000 
Epoch-0 / recon: 2.86565 - dist: 0.08903 - cons: -0.26928
Epoch-1 / recon: 2.75951 - dist: 0.06914 - cons: -0.31323
Epoch-2 / recon: 2.70354 - dist: 0.06186 - cons: -0.33037
Epoch-3 / recon: 2.67056 - dist: 0.05792 - cons: -0.34050
Epoch-4 / recon: 2.64936 - dist: 0.05547 - cons: -0.34662
Epoch-5 / recon: 2.63447 - dist: 0.05364 - cons: -0.35124
Epoch-6 / recon: 2.62334 - dist: 0.05231 - cons: -0.35462
Epoch-7 / recon: 2.61460 - dist: 0.05132 - cons: -0.35762
Epoch-8 / recon: 2.60775 - dist: 0.05050 - cons: -0.36007
Epoch-9 / recon: 2.60221 - dist: 0.04986 - cons: -0.36155
Epoch-10 / recon: 2.59761 - dist: 0.04934 - cons: -0.

89it [00:00, 200.90it/s]
59it [00:00, 281.85it/s]
100%|██████████| 89/89 [00:00<00:00, 263.31it/s]


{'topic_N': 50, 'CV_wiki': 0.38835539999999985, 'sim_w2v': 0.11422481640695953, 'diversity': 0.3, 'filename': 'results/240510_000146.txt', 'acc': 0.020578863515666488, 'macro-F1': 0.014714544879857567, 'Purity': 0.5103558151885289, 'NMI': 0.4826379132355541, 'label_match': 0.4292911437157504}

Coeff   / regul: 1.00000 - recon: 1.00000 - c: 1.00000 - dist: 1.00000 
Epoch-0 / recon: 2.86267 - dist: 0.07934 - cons: -0.27902
Epoch-1 / recon: 2.75148 - dist: 0.06009 - cons: -0.32727
Epoch-2 / recon: 2.69233 - dist: 0.05348 - cons: -0.34603
Epoch-3 / recon: 2.65752 - dist: 0.05004 - cons: -0.35642
Epoch-4 / recon: 2.63523 - dist: 0.04789 - cons: -0.36339
Epoch-5 / recon: 2.61926 - dist: 0.04619 - cons: -0.36889
Epoch-6 / recon: 2.60745 - dist: 0.04482 - cons: -0.37331
Epoch-7 / recon: 2.59835 - dist: 0.04378 - cons: -0.37698
Epoch-8 / recon: 2.59114 - dist: 0.04298 - cons: -0.38001
Epoch-9 / recon: 2.58535 - dist: 0.04233 - cons: -0.38237
Epoch-10 / recon: 2.58044 - dist: 0.04183 - cons: -0.

89it [00:00, 207.46it/s]
59it [00:00, 308.90it/s]
100%|██████████| 89/89 [00:00<00:00, 273.85it/s]


{'topic_N': 50, 'CV_wiki': 0.3743447999999999, 'sim_w2v': 0.11788714024287275, 'diversity': 0.32, 'filename': 'results/240510_000435.txt', 'acc': 0.027748274030801913, 'macro-F1': 0.020221399226255706, 'Purity': 0.5165958576739246, 'NMI': 0.4949721449785493, 'label_match': 0.4359200989923988}



In [67]:
results_df = pd.DataFrame(results_list)
print(results_df)
print('mean')
print(results_df.mean())
print('std')
print(results_df.std())

if args.result_file is not None:
    result_filename = f'results/{args.result_file}'
else:
    result_filename = f'results/{now}.tsv'

results_df.to_csv(result_filename, sep='\t')

   topic_N   CV_wiki   sim_w2v  diversity                   filename  \
0       50  0.372574  0.110596   0.310667  results/240509_235251.txt   
1       50  0.371632  0.128656   0.326667  results/240509_235541.txt   
2       50  0.375557  0.137268   0.344000  results/240509_235837.txt   
3       50  0.388355  0.114225   0.300000  results/240510_000146.txt   
4       50  0.374345  0.117887   0.320000  results/240510_000435.txt   

        acc  macro-F1    Purity       NMI  label_match  
0  0.074482  0.063808  0.498805  0.470720     0.434418  
1  0.078598  0.057332  0.510223  0.483898     0.439809  
2  0.045805  0.039117  0.513277  0.488140     0.460580  
3  0.020579  0.014715  0.510356  0.482638     0.429291  
4  0.027748  0.020221  0.516596  0.494972     0.435920  
mean
topic_N        50.000000
CV_wiki         0.376493
sim_w2v         0.121726
diversity       0.320267
acc             0.049442
macro-F1        0.039039
Purity          0.509851
NMI             0.484074
label_match     0.44