In [44]:
# Setting Arguments
# Command-line style argument string; it is split on whitespace and handed to
# the argument parser when the notebook runs inside Jupyter.
args_text = (
    '--base-model sentence-transformers/all-MiniLM-L6-v2 '
    '--dataset news --n-word 5000 --epochs-1 200 --epochs-2 50 '
    '--bsz 32 --stage-2-lr 2e-2 --stage-2-repeat 5 --coeff-1-dist 50 '
    '--n-cluster 50 '
    '--stage-1-ckpt trained_model/news_model_all-MiniLM-L6-v2_stage1_50t_5000w_199e.ckpt'
)

  self.nonempty_text = [re.sub('\S*@\S*\s?', '', sent) for sent in self.nonempty_text]
  self.nonempty_text = [re.sub('\s+', ' ', sent) for sent in self.nonempty_text]
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\PDT\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\PDT\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\PDT\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [45]:
import re
import os
import sys
import time
import copy
import math
import argparse
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtools.optim import RangerLars
import gensim.downloader
import itertools

from scipy.stats import ortho_group
from scipy.optimize import linear_sum_assignment as linear_assignment
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer

import numpy as np
from tqdm import tqdm_notebook
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.datasets import fetch_20newsgroups
from nltk.corpus import stopwords

from sklearn.feature_extraction.text import CountVectorizer
from utils import AverageMeter
from collections import OrderedDict

import pandas as pd
from sklearn.preprocessing import normalize
from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer
from gensim.corpora.dictionary import Dictionary
from pytorch_transformers import *
from sklearn.mixture import GaussianMixture
import scipy.stats
from sklearn.decomposition import PCA
from sklearn.cluster import OPTICS
from nltk.corpus import stopwords

from gensim.models.coherencemodel import CoherenceModel
from tqdm import tqdm
import scipy.sparse as sp
import nltk
from nltk.corpus import stopwords

from datetime import datetime
from itertools import combinations
import gensim.downloader
from scipy.linalg import qr
from data import *
from model import ContBertTopicExtractorAE
from evaluation import get_topic_qualities
import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [46]:
# Restrict CUDA to a single, deterministically numbered device.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # Arrange GPU devices starting from 0
    "CUDA_VISIBLE_DEVICES": "0",
})

# Data loading

In [47]:
def _parse_args():
    parser = argparse.ArgumentParser(description='Contrastive topic modeling')
    parser.add_argument('--epochs-1', default=100, type=int,
                        help='Number of training epochs for Stage 1')
    parser.add_argument('--epochs-2', default=10, type=int,
                        help='Number of training epochs for Stage 2')
    parser.add_argument('--bsz', type=int, default=64,
                        help='Batch size')
    parser.add_argument('--dataset', default='news', type=str,
                        choices=['news', 'twitter', 'wiki', 'nips', 'stackoverflow', 'reuters', 'r52', 'imdb', 'agnews', 'yahoo'],
                        help='Name of the dataset')
    parser.add_argument('--n-cluster', default=50, type=int,
                        help='Number of clusters')
    parser.add_argument('--n-topic', type=int,
                        help='Number of topics. If not specified, use same value as --n-cluster')
    parser.add_argument('--n-word', default=2000, type=int,
                        help='Number of words in vocabulary')
    
    parser.add_argument('--base-model', type=str,
                        help='Name of base model in huggingface library.')
    
    parser.add_argument('--gpus', default=[0,1], type=int, nargs='+',
                        help='List of GPU numbers to use. Use 0 by default')
    
    parser.add_argument('--coeff-1-sim', default=1.0, type=float,
                        help='Coefficient for NN dot product similarity loss (Phase 1)')
    parser.add_argument('--coeff-1-dist', default=1.0, type=float,
                        help='Coefficient for NN SWD distribution loss (Phase 1)')
    parser.add_argument('--dirichlet-alpha-1', type=float,
                        help='Parameter for Dirichlet distribution (Phase 1). Use 1/n_topic by default.')
    
    parser.add_argument('--stage-1-ckpt', type=str,
                        help='Name of torch checkpoint file Stage 1. If this argument is given, skip Stage 1.')
    
    parser.add_argument('--coeff-2-recon', default=1.0, type=float,
                        help='Coefficient for VAE reconstruction loss (Phase 2)')
    parser.add_argument('--coeff-2-regul', default=1.0, type=float,
                        help='Coefficient for VAE KLD regularization loss (Phase 2)')
    parser.add_argument('--coeff-2-cons', default=1.0, type=float,
                        help='Coefficient for CL consistency loss (Phase 2)')
    parser.add_argument('--coeff-2-dist', default=1.0, type=float,
                        help='Coefficient for CL SWD distribution matching loss (Phase 2)')
    parser.add_argument('--dirichlet-alpha-2', type=float,
                        help='Parameter for Dirichlet distribution (Phase 2). Use same value as dirichlet-alpha-1 by default.')
    
    parser.add_argument('--stage-2-lr', default=2e-1, type=float,
                        help='Learning rate of phase 2')
    parser.add_argument('--stage-2-repeat', default=5, type=int,
                        help='Repetition count of phase 2')
    
    parser.add_argument('--result-file', type=str,
                        help='File name for result summary')
    parser.add_argument('--palmetto-dir', type=str, default='./',
                        help='Directory where palmetto JAR and the Wikipedia index are. For evaluation')
    
    
    # Check if the code is run in Jupyter notebook
    is_in_jupyter = False
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            is_in_jupyter = True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            is_in_jupyter = False  # Terminal running IPython
        else:
            is_in_jupyter = False  # Other type (?)
    except NameError:
        is_in_jupyter = False
    
    if is_in_jupyter:
        return parser.parse_args(args=args_text.split())
    else:
        return parser.parse_args()

def data_load(dataset_name):
    """Instantiate the dataset wrapper that corresponds to ``dataset_name``.

    Args:
        dataset_name: one of 'news', 'imdb', 'agnews', 'yahoo', 'twitter',
            'wiki', 'nips', 'stackoverflow', 'reuters', 'r52'.

    Returns:
        tuple: ``(textData, should_measure_hungarian)`` where the flag is True
        only for the 'news' and 'r52' datasets.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names
            (previously this surfaced as an UnboundLocalError at the return).
    """
    should_measure_hungarian = False
    if dataset_name == 'news':
        textData = newsData()
        should_measure_hungarian = True
    elif dataset_name == 'imdb':
        textData = IMDBData()
    elif dataset_name == 'agnews':
        textData = AGNewsData()
    elif dataset_name == 'yahoo':
        textData = YahooData()
    elif dataset_name == 'twitter':
        textData = twitterData('/home/data/topicmodel/twitter_covid19.tsv')
    elif dataset_name == 'wiki':
        textData = wikiData('/home/data/topicmodel/smplAbstracts/')
    elif dataset_name == 'nips':
        textData = nipsData('/home/data/topicmodel/papers.csv')
    elif dataset_name == 'stackoverflow':
        textData = stackoverflowData('/home/data/topicmodel/stack_overflow.csv')
    elif dataset_name == 'reuters':
        textData = reutersData('/home/data/topicmodel/reuters-21578.txt')
    elif dataset_name == 'r52':
        textData = r52Data('/home/data/topicmodel/r52/')
        should_measure_hungarian = True
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name!r}")
    return textData, should_measure_hungarian

In [48]:
# Parse the run configuration and unpack the values used throughout the notebook.
args = _parse_args()
bsz, epochs_1, epochs_2 = args.bsz, args.epochs_1, args.epochs_2

# The topic count falls back to the cluster count when it was not given.
n_cluster = args.n_cluster
n_topic = n_cluster if args.n_topic is None else args.n_topic
args.n_topic = n_topic

textData, should_measure_hungarian = data_load(args.dataset)

ema_alpha = 0.99
n_word = args.n_word

# Dirichlet concentration: 1/n_cluster for phase 1 unless overridden;
# phase 2 reuses the phase-1 value unless overridden.
dirichlet_alpha_1 = (1 / n_cluster) if args.dirichlet_alpha_1 is None else args.dirichlet_alpha_1
dirichlet_alpha_2 = dirichlet_alpha_1 if args.dirichlet_alpha_2 is None else args.dirichlet_alpha_2

# Short model name = last path component of the base-model identifier.
bert_name = args.base_model
bert_name_short = bert_name.rsplit('/', 1)[-1]
gpu_ids = args.gpus

# A supplied Stage-1 checkpoint means Stage-1 training is skipped.
skip_stage_1 = args.stage_1_ckpt is not None

In [49]:
# Build the training dataset and the document-to-document cosine-similarity
# matrix of the base sentence embeddings. The matrix is cached on disk and
# recomputed only when the cache file is missing.
trainds = BertDataset(bert=bert_name, text_list=textData.data, N_word=n_word, vectorizer=None, lemmatize=True)
basesim_path = f"./save/{args.dataset}_{bert_name_short}_basesim_matrix_full.pkl"
if not os.path.isfile(basesim_path):
    model = SentenceTransformer(bert_name.split('/')[-1], device='cuda')
    base_result_list = []
    for text in tqdm_notebook(trainds.nonempty_text):
        base_result_list.append(model.encode(text))

    base_result_embedding = np.stack(base_result_list)
    # Rows are L2-normalised, so the matrix product below is cosine similarity.
    basereduced_norm = F.normalize(torch.tensor(base_result_embedding), dim=-1)
    basesim_matrix = torch.mm(basereduced_norm, basereduced_norm.t())
    # Set self-similarity to -1 so a document is never its own top neighbour.
    ind = np.diag_indices(basesim_matrix.shape[0])
    basesim_matrix[ind[0], ind[1]] = torch.ones(basesim_matrix.shape[0]) * -1
    # Create the cache directory first; torch.save does not create parents.
    os.makedirs(os.path.dirname(basesim_path), exist_ok=True)
    torch.save(basesim_matrix, basesim_path)
else:
    # NOTE(review): torch.load unpickles the file; only load caches you created yourself.
    basesim_matrix = torch.load(basesim_path)

100%|██████████| 11314/11314 [00:05<00:00, 1910.66it/s]


# Stage 1: Neighborhood-pair discovery and clustering

In [50]:
def dist_match_loss(hidden, alpha=1.0):
    """Distribution-matching loss of ``hidden`` using an identity projection.

    Args:
        hidden: tensor whose last dimension is the projected dimension.
        alpha: Dirichlet concentration forwarded to ``get_swd_loss``.

    Returns:
        The value returned by ``get_swd_loss`` for the identity projection.
    """
    identity = torch.eye(hidden.shape[-1]).to(hidden.device)
    return get_swd_loss(hidden, identity, alpha)


def get_swd_loss(states, rand_w, alpha=1.0):
    """Compare sorted projections of ``states`` with sorted Dirichlet samples.

    Args:
        states: 2-D tensor (rows x dims).
        rand_w: projection matrix applied to both ``states`` and the prior samples.
        alpha: symmetric Dirichlet concentration used for the prior samples.

    Returns:
        Scalar tensor: squared differences of the sorted projections, summed
        over the projected dimension and averaged over the sorted positions.
    """
    n_rows, n_dims = states.shape[0], states.shape[1]
    target_device = states.device

    # Project, then sort every projected coordinate across the rows.
    projected = torch.matmul(states, rand_w)
    sorted_states = torch.sort(projected.t(), dim=1)[0]

    # Same treatment for one Dirichlet sample per row.
    prior_samples = np.random.dirichlet([alpha] * n_dims, n_rows)
    prior = torch.Tensor(prior_samples).to(target_device)
    sorted_prior = torch.sort(torch.matmul(prior, rand_w).t(), dim=1)[0]

    squared_gap = (sorted_prior - sorted_states) ** 2
    return torch.mean(torch.sum(squared_gap, axis=0))

In [51]:
# Drop any existing `model` and release cached GPU memory.
try:
    del model
except NameError:
    # `model` is only defined when the similarity matrix was computed above
    # (or by an earlier run of a later cell); nothing to delete otherwise.
    pass
finally:
    torch.cuda.empty_cache()

In [52]:
# Build the Stage-1 model unless a Stage-1 checkpoint was supplied.
if not skip_stage_1:
    # NOTE(review): bert_dim is hard-coded to 768 — confirm it matches the hidden size of `bert_name`.
    model = ContBertTopicExtractorAE(N_topic=n_cluster, N_word=n_word, bert=bert_name, bert_dim=768)
    # Wrap in DataParallel only when more than one CUDA device is visible.
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model, device_ids=gpu_ids)
    model.cuda(gpu_ids[0])

In [53]:
# NOTE(review): this overrides the batch size parsed from the arguments
# (both the local `bsz` and `args.bsz`) with a fixed 128.
bsz = args.bsz = 128

In [54]:
# Stage-1 training: pull each document towards its nearest neighbour
# (consistency loss) while matching the topic distribution of a rolling
# memory queue to a Dirichlet prior (distribution-matching loss).
if not skip_stage_1:
    # Running averages shown in the progress bar.
    losses = AverageMeter()
    closses = AverageMeter()
    dlosses = AverageMeter()
    rlosses = AverageMeter()
    criterion = nn.CrossEntropyLoss()

    # Work on a copy so the cached similarity matrix is not modified.
    temp_basesim_matrix = copy.deepcopy(basesim_matrix)
    finetuneds = FinetuneDataset(trainds, temp_basesim_matrix, ratio=1, k=1)
    trainloader = DataLoader(finetuneds, batch_size=bsz, shuffle=False, num_workers=0)
    memoryloader = DataLoader(finetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)

    optimizer = RangerLars(model.parameters(), lr=0.001, weight_decay=0.0001)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    global_step = 0
    # Fixed-size queue (512 rows) of topic distributions, seeded with random softmax rows.
    memory_queue = F.softmax(torch.randn(512, n_cluster).cuda(gpu_ids[0]), dim=1)
    for epoch in range(epochs_1):
        model.train()
        #ema_model.train()
        tbar = tqdm_notebook(trainloader)
        for batch_idx, batch in enumerate(tbar):
            org_input, pos_input, _, _ = batch
            org_input_ids = org_input['input_ids'].cuda(gpu_ids[0])
            org_attention_mask = org_input['attention_mask'].cuda(gpu_ids[0])
            pos_input_ids = pos_input['input_ids'].cuda(gpu_ids[0])
            pos_attention_mask = pos_input['attention_mask'].cuda(gpu_ids[0])
            batch_size = org_input_ids.size(0)

            # Encode anchors and positives in a single forward pass.
            all_input_ids = torch.cat((org_input_ids, pos_input_ids), dim=0)
            all_attention_masks = torch.cat((org_attention_mask, pos_attention_mask), dim=0)
            all_topics, _ = model(all_input_ids, all_attention_masks, return_topic=True)

            orig_topic, pos_topic = torch.split(all_topics, len(all_topics) // 2)
            pos_sim = torch.sum(orig_topic * pos_topic, dim=-1)

            # consistency loss
            consistency_loss = -pos_sim.mean()

            # distribution matching loss: append the new topics and drop the
            # same number of oldest rows so the queue length stays constant.
            memory_queue = torch.cat((memory_queue.detach(), all_topics), dim=0)[all_topics.size(0):]
            distmatch_loss = dist_match_loss(memory_queue, dirichlet_alpha_1)
            # NOTE(review): the distribution term uses a hard-coded weight of 10;
            # args.coeff_1_dist is not applied here — confirm this is intended.
            loss = args.coeff_1_sim * consistency_loss + \
                   10 * distmatch_loss

            # Weight the running averages by the actual batch size; the last
            # batch of an epoch can be smaller than `bsz`.
            losses.update(loss.item(), batch_size)
            closses.update(consistency_loss.item(), batch_size)
            dlosses.update(distmatch_loss.item(), batch_size)

            tbar.set_description("Epoch-{} / consistency: {:.5f} - dist: {:.5f}".format(epoch,
                                                                                        closses.avg,
                                                                                        dlosses.avg), refresh=True)
            tbar.refresh()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1
        # Learning rate decays once per epoch.
        scheduler.step()

In [55]:
# Persist the Stage-1 weights, or reuse the checkpoint supplied on the command line.
if not skip_stage_1:
    model_stage1_name = f'./trained_model/{args.dataset}_model_{bert_name_short}_stage1_{args.n_topic}t_{args.n_word}w_{epoch}e.ckpt'
    os.makedirs(os.path.dirname(model_stage1_name), exist_ok=True)
    # The model is wrapped in DataParallel only when several GPUs are visible;
    # unwrap it in that case so the saved keys carry no `module.` prefix.
    if isinstance(model, torch.nn.DataParallel):
        torch.save(model.module.state_dict(), model_stage1_name)
    else:
        torch.save(model.state_dict(), model_stage1_name)
else:
    model_stage1_name = args.stage_1_ckpt

# Stage 2: extract vocab set

### Load model

In [56]:
from word_embedding_utils import *

In [57]:
# Free any models left from earlier cells before reloading the Stage-1 weights.
try:
    del model
except NameError:
    # No model was created in this session.
    pass
try:
    del ema_model
except NameError:
    # `ema_model` is not created anywhere in the cells above.
    pass
torch.cuda.empty_cache()

In [58]:
# Rebuild the topic extractor on the first selected GPU and restore the Stage-1 weights.
# NOTE(review): bert_dim is hard-coded to 768 — confirm it matches the hidden size of `bert_name`.
model = ContBertTopicExtractorAE(N_topic=n_cluster, N_word=n_word, bert=bert_name, bert_dim=768)
model.cuda(gpu_ids[0])

# strict=True: the checkpoint keys must match the model's state dict exactly.
model.load_state_dict(torch.load(model_stage1_name), strict=True)

<All keys matched successfully>

In [59]:
# Run the Stage-1 model over every document and assign each one to the
# cluster with the largest topic score.
temp_basesim_matrix = copy.deepcopy(basesim_matrix)
finetuneds = FinetuneDataset(trainds, temp_basesim_matrix, ratio=1, k=1)
memoryloader = DataLoader(finetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)

model.eval()
result_list = []
with torch.no_grad():
    for batch in tqdm_notebook(memoryloader):
        # Only the original (anchor) inputs are needed here.
        org_input, _, _, _ = batch
        device_id = gpu_ids[0]
        org_input_ids = org_input['input_ids'].to(device_id)
        org_attention_mask = org_input['attention_mask'].to(device_id)
        topic, embed = model(org_input_ids, org_attention_mask, return_topic = True)
        result_list.append(topic)

result_embedding = torch.cat(result_list)
result_topic = torch.max(result_embedding, 1)[1]

  0%|          | 0/45 [00:00<?, ?it/s]

In [60]:
# One row per non-empty document: preprocessed text plus its assigned cluster id.
d = dict(
    text=trainds.preprocess_ctm(trainds.nonempty_text),
    cluster_label=result_topic.cpu().numpy(),
)
cluster_df = pd.DataFrame(d)

In [61]:
# Concatenate all documents of each cluster into a single space-joined text,
# giving one row per cluster label.
docs_per_class = cluster_df.groupby(['cluster_label'], as_index=False).agg({'text': ' '.join})

In [62]:
# Count the distinct whitespace-separated tokens across all cluster texts.
supertxt = "".join(cluster_text + " " for cluster_text in docs_per_class['text'])
wwwwwwwwwwwwww = supertxt.split()
ssssssssssssss = set(wwwwwwwwwwwwww)
print(len(ssssssssssssss))

5000


In [63]:
# Term counts per cluster text, then class-based TF-IDF over those counts.
# count_vectorizer = CountVectorizer(token_pattern=r'\b[a-zA-Z]{2,}\b')
count_vectorizer = CountVectorizer()
ctfidf_vectorizer = CTFIDFVectorizer()
count = count_vectorizer.fit_transform(docs_per_class.text)
ctfidf = ctfidf_vectorizer.fit_transform(count, n_samples=len(cluster_df)).toarray()
# get_feature_names() was removed in scikit-learn 1.2; use its replacement
# when available and keep `words` a plain list of str either way.
if hasattr(count_vectorizer, 'get_feature_names_out'):
    words = count_vectorizer.get_feature_names_out().tolist()
else:
    words = count_vectorizer.get_feature_names()

In [64]:
# Number of vocabulary terms produced by the count vectorizer.
len(words)

5000

In [65]:
# transport to gensim
(gensim_corpus, gensim_dict) = vect2gensim(count_vectorizer, count)
vocab_list = set(gensim_dict.token2id.keys())
stopwords = set(line.strip() for line in open('stopwords_en.txt'))

In [66]:
normalized = [coherence_normalize(doc) for doc in trainds.nonempty_text]
gensim_dict = Dictionary(normalized)

# Min-max scale every cluster's c-TF-IDF row.
row_min = np.min(ctfidf, axis=1, keepdims=True)
row_max = np.max(ctfidf, axis=1, keepdims=True)
resolution_score = (ctfidf - row_min) / (row_max - row_min)

n_word = args.n_word
# Number of top-scoring words kept per topic for the evaluation below.
n_topic_word = 15

# Map each cluster label to its highest-scoring words, best first.
# (Optional filters tried earlier: minimum word length, stopword removal,
# membership in gensim_dict.token2id — none is applied here.)
topic_word_dict = {}
for label in docs_per_class.cluster_label.index:
    total_score = resolution_score[label]
    score_higest = total_score.argsort()[::-1]
    topic_word_list = [words[index] for index in score_higest]
    cluster_key = docs_per_class.cluster_label.iloc[label]
    topic_word_dict[cluster_key] = topic_word_list[:int(n_topic_word)]

In [67]:
# Show the top words of every topic, then keep them as a list of lists.
for key, top_words in topic_word_dict.items():
    print(f"{key}: {top_words},")
topic_words_list = list(topic_word_dict.values())

0: ['car', 'cars', 'dealer', 'saturn', 'toyota', 'price', 'lights', 'models', 'profit', 'engine', 'door', 'model', 'automotive', 'buying', 'dodge'],
1: ['ticket', 'launch', 'lib', 'battery', 'tickets', 'pat', 'rockets', 'doug', 'flight', 'exploration', 'engines', 'cost', 'henry', 'batteries', 'lunar'],
2: ['msg', 'food', 'sensitivity', 'chinese', 'superstition', 'foods', 'diet', 'honda', 'taste', 'effects', 'reaction', 'brain', 'eat', 'studies', 'circuit'],
3: ['bus', 'ram', 'simms', 'motherboard', 'card', 'bit', 'memory', 'isa', 'cache', 'board', 'controller', 'mac', 'cards', 'eisa', 'simm'],
4: ['jesus', 'god', 'bible', 'christ', 'christian', 'church', 'christians', 'faith', 'life', 'matthew', 'christianity', 'lord', 'gods', 'father', 'holy'],
5: ['scsi', 'drive', 'ide', 'drives', 'controller', 'mac', 'tape', 'bus', 'hard', 'quadra', 'isa', 'rom', 'disk', 'devices', 'data'],
6: ['gun', 'guns', 'firearms', 'handgun', 'amendment', 'control', 'nra', 'safety', 'firearm', 'weapons', 'arms

In [68]:
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
from nltk.corpus import stopwords as stop_words
from gensim.utils import deaccent


class WhiteSpacePreprocessing():
    """Lower-case, de-accent and whitespace-tokenise documents, dropping
    stopwords and restricting tokens to the most frequent vocabulary.

    Args:
        documents: list of raw document strings.
        stopwords_language: language name passed to the NLTK stopword corpus.
        vocabulary_size: maximum number of vocabulary terms to keep.
    """

    def __init__(self, documents, stopwords_language="english", vocabulary_size=2000):
        self.documents = documents
        self.stopwords = set(stop_words.words(stopwords_language))
        self.vocabulary_size = vocabulary_size

        # NOTE(review): this changes the process-wide warning filter for
        # DeprecationWarning, while the warning below is issued without a category.
        warnings.simplefilter('always', DeprecationWarning)
        warnings.warn("WhiteSpacePreprocessing is deprecated and will be removed in future versions."
                      "Use WhiteSpacePreprocessingStopwords.")

    def preprocess(self):
        """Preprocess ``self.documents``.

        Returns:
            tuple: ``(preprocessed_docs, unpreprocessed_docs, vocabulary,
            retained_indices)``. Documents that end up empty are dropped;
            ``retained_indices`` holds the original positions of the kept ones.
        """
        preprocessed_docs_tmp = self.documents
        preprocessed_docs_tmp = [deaccent(doc.lower()) for doc in preprocessed_docs_tmp]
        # Replace every punctuation character with a space.
        preprocessed_docs_tmp = [doc.translate(
            str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [' '.join([w for w in doc.split() if len(w) > 0 and w not in self.stopwords])
                                 for doc in preprocessed_docs_tmp]

        vectorizer = CountVectorizer(max_features=self.vocabulary_size)
        vectorizer.fit_transform(preprocessed_docs_tmp)
        # get_feature_names() was removed in scikit-learn 1.2; fall back to it
        # only on versions that lack get_feature_names_out().
        if hasattr(vectorizer, 'get_feature_names_out'):
            temp_vocabulary = set(vectorizer.get_feature_names_out().tolist())
        else:
            temp_vocabulary = set(vectorizer.get_feature_names())

        # Keep only tokens that made it into the capped vocabulary.
        preprocessed_docs_tmp = [' '.join([w for w in doc.split() if w in temp_vocabulary])
                                 for doc in preprocessed_docs_tmp]

        preprocessed_docs, unpreprocessed_docs, retained_indices = [], [], []
        for i, doc in enumerate(preprocessed_docs_tmp):
            if len(doc) > 0:
                preprocessed_docs.append(doc)
                unpreprocessed_docs.append(self.documents[i])
                retained_indices.append(i)

        vocabulary = list(set([item for doc in preprocessed_docs for item in doc.split()]))

        return preprocessed_docs, unpreprocessed_docs, vocabulary, retained_indices
    
def _hungarian_match(flat_preds, flat_targets, num_samples, class_num):  
    num_k = class_num
    num_correct = np.zeros((num_k, num_k))
  
    for c1 in range(0, num_k):
        for c2 in range(0, num_k):
            votes = int(((flat_preds == c1) * (flat_targets == c2)).sum())
            num_correct[c1, c2] = votes
  
    match = linear_assignment(num_samples - num_correct)
    match = np.array(list(zip(*match)))
    res = []
    for out_c, gt_c in match:
        res.append((out_c, gt_c))
  
    return res

def get_document_topic(topic_words, preprocessed_documents_lemmatized):
    """Score every document against every topic from word counts alone.

    Documents and topics are both turned into smoothed, row-normalised count
    vectors over the union of topic words; document vectors are then projected
    onto the topics through the right pseudo-inverse of the topic matrix.

    Args:
        topic_words: list of word lists, one per topic.
        preprocessed_documents_lemmatized: list of document strings.

    Returns:
        numpy array with one row per document and one column per topic.
    """
    # Union of all topic words. Every empty string is dropped (list.remove
    # would only drop the first one) and the order is made deterministic.
    topic_words_flatten = sorted(
        set(word for word in itertools.chain.from_iterable(topic_words) if word != ''))

    vectorizer = CountVectorizer(vocabulary = topic_words_flatten)
    vectorizer = vectorizer.fit(preprocessed_documents_lemmatized)
    count_mat = vectorizer.transform(preprocessed_documents_lemmatized).toarray()

    # Additive smoothing keeps rows without any topic word from dividing by zero.
    count_mat_normalized = count_mat + 1e-4
    count_mat_normalized = count_mat_normalized / count_mat_normalized.sum(axis=1).reshape(-1, 1)

    topic_mat = vectorizer.transform([' '.join(i) for i in topic_words]).toarray()
    topic_mat_normalized = topic_mat + 1e-4
    topic_mat_normalized = topic_mat_normalized / topic_mat_normalized.sum(axis=1).reshape(-1, 1)

    # T^T (T T^T)^-1: right pseudo-inverse of the topic matrix T.
    topic_mat_inverse = topic_mat_normalized @ topic_mat_normalized.transpose()
    topic_mat_inverse = np.linalg.inv(topic_mat_inverse)
    topic_mat_inverse = topic_mat_normalized.transpose() @ topic_mat_inverse
    document_topic = count_mat_normalized @ topic_mat_inverse
    return document_topic

class TopicModelDataPreparationNoNumber(TopicModelDataPreparation):
    """Data preparation whose bag-of-words keeps alphabetic tokens only."""

    def fit(self, text_for_contextual, text_for_bow, labels=None, wordlist=None):
        """
        This method fits the vectorizer and gets the embeddings from the contextual model
        :param text_for_contextual: list of unpreprocessed documents to generate the contextualized embeddings
        :param text_for_bow: list of preprocessed documents for creating the bag-of-words
        :param labels: list of labels associated with each document (optional).
        :param wordlist: fixed vocabulary passed to the CountVectorizer (optional).
        :return: a CTMDataset built from the embeddings, the bag-of-words, the id-to-token map and the encoded labels
        """
        # NOTE(review): OneHotEncoder, CTMDataset and bert_embeddings_from_list are not
        # imported in the visible cells — confirm one of the star imports provides them.

        if self.contextualized_model is None:
            raise Exception("You should define a contextualized model if you want to create the embeddings")

        # TODO: this count vectorizer removes tokens that have len = 1, might be unexpected for the users
        self.vectorizer = CountVectorizer(token_pattern=r'\b[a-zA-Z]{2,}\b', vocabulary=wordlist)

        train_bow_embeddings = self.vectorizer.fit_transform(text_for_bow)
        train_contextualized_embeddings = bert_embeddings_from_list(text_for_contextual, self.contextualized_model)
        # get_feature_names() was removed in scikit-learn 1.2; use its
        # replacement when available and keep the vocabulary a plain list.
        if hasattr(self.vectorizer, 'get_feature_names_out'):
            self.vocab = self.vectorizer.get_feature_names_out().tolist()
        else:
            self.vocab = self.vectorizer.get_feature_names()
        self.id2token = dict(enumerate(self.vocab))

        if labels:
            self.label_encoder = OneHotEncoder()
            encoded_labels = self.label_encoder.fit_transform(np.array([labels]).reshape(-1, 1))
        else:
            encoded_labels = None

        return CTMDataset(train_contextualized_embeddings, train_bow_embeddings, self.id2token, encoded_labels)
    

# Clustering accuracy: assign each document to a topic from its word counts,
# map topics to ground-truth classes with the Hungarian method, then score.
topic_words_list = list(topic_word_dict.values())
qt = TopicModelDataPreparationNoNumber("sentence-transformers/all-MiniLM-L6-v2")
# Named so that it does not overwrite the `scipy.sparse as sp` import alias.
whitespace_prep = WhiteSpacePreprocessing(textData.data, stopwords_language='english')
preprocessed_documents, unpreprocessed_corpus, vocab, retained_indices = whitespace_prep.preprocess()
vectorizer_model = CountVectorizer(stop_words="english")
lemmatizer = WordNetLemmatizer()
preprocessed_documents_lemmatized = [' '.join([lemmatizer.lemmatize(w) for w in doc.split()]) for doc in preprocessed_documents]

document_topic = get_document_topic(topic_words_list, preprocessed_documents_lemmatized)
# Keep only the targets of documents that survived preprocessing.
train_target_filtered = textData.targets.squeeze()[retained_indices]
flat_predict = torch.tensor(np.argmax(document_topic, axis=1))
flat_target = torch.tensor(train_target_filtered).to(flat_predict.device)
num_samples = flat_predict.shape[0]
# NOTE(review): only ids 0..19 are matched. Documents predicted as a topic id
# of 20 or more keep the initial value 0 in `reordered_preds` — confirm this
# is intended when there are more topics than 20.
match = _hungarian_match(flat_predict, flat_target, num_samples, 20)
reordered_preds = torch.zeros(num_samples).to(flat_predict.device)
for pred_i, target_i in match:
    reordered_preds[flat_predict == pred_i] = int(target_i)
acc = int((reordered_preds == flat_target.float()).sum()) / float(num_samples)
print(acc)

0.16978964115255435


In [69]:
from sklearn.metrics.cluster import normalized_mutual_info_score

In [70]:
# Normalized mutual information between the reordered predictions and the targets.
normalized_mutual_info_score(reordered_preds, flat_target)

0.19309600164021448

In [71]:
# Timestamp used to name the results file of this evaluation run.
now = datetime.now().strftime('%y%m%d_%H%M%S')
# Score the topic words; the reference corpus is the preprocessed training
# text, split into whitespace tokens per document.
results = get_topic_qualities(topic_words_list, args.palmetto_dir, reference_corpus=[doc.split() for doc in trainds.preprocess_ctm(trainds.nonempty_text)],
                              filename=f'results/{now}.txt')
results

./
[-1.7659, -2.75958, -5.73115, -3.87078, -1.51117, -2.65345, -3.40299, -3.64834, -3.5648, -6.31774, -2.86435, -4.14767, -5.16925, -3.03433, -2.76406, -4.74539, -3.82516, -1.5307, -4.25273, -5.6412, -4.97686, -3.95081, -3.79115, -5.95325, -2.05163, -1.90816, -3.0101, -2.22296, -2.23495, -3.61482, -5.7169, -2.20765, -4.27635, -8.62208, -5.53544, -1.84872, -3.67992, -5.81377, -4.49597, -2.59091, -3.00314, -5.36454, -3.68631, -4.5383, -3.13057, -2.02037, -3.83416, -5.56778, -4.50328, -6.3241]
-3.8735137999999996
[0.02447, -0.096, -0.09703, -0.00554, 0.14874, 0.03883, 0.02828, 0.08621, 0.09544, -0.01558, -0.05051, -0.03389, -0.0938, 0.04603, 0.05896, -0.03482, -0.0055, 0.10847, -0.11768, -0.01343, -0.05341, -0.05391, -0.06132, -0.06, 0.10642, 0.00153, -0.1337, 0.02273, -0.05109, -0.0062, -0.04994, 0.16947, 0.00989, -0.01169, -0.03567, 0.08334, 0.02871, -0.06607, -0.04349, -0.04062, -0.06495, -0.0105, -0.07653, -0.04343, -0.06654, 0.01004, 0.09758, 0.00865, -0.08983, -0.08316]
-0.0098408
[

{'topic_N': 50,
 'umass_wiki': -3.8735137999999996,
 'npmi_wiki': -0.0098408,
 'uci_wiki': -1.3311518000000004,
 'CV_wiki': 0.4333020000000001,
 'cp_wiki': 0.05935320000000001,
 'sim_w2v': 0.18555901170050965,
 'diversity': 0.7453333333333333,
 'filename': 'results/240517_152354.txt'}

In [72]:
normalized = [coherence_normalize(doc) for doc in trainds.nonempty_text]
gensim_dict = Dictionary(normalized)

n_word = args.n_word
n_topic_word = n_word

# For every cluster: words known to the gensim dictionary ranked by
# descending score (cut to n_topic_word), together with their scores.
words_to_idx = {word: position for position, word in enumerate(words)}
topic_word_dict = {}
topic_score_dict = {}
total_score_cat = []
for label in docs_per_class.cluster_label.index:
    total_score = resolution_score[label]
    score_higest = total_score.argsort()[::-1]
    total_score_cat.append(total_score)

    topic_word_list = [words[index] for index in score_higest
                       if words[index] in gensim_dict.token2id]
    kept_words = topic_word_list[:int(n_topic_word)]
    cluster_key = docs_per_class.cluster_label.iloc[label]
    topic_word_dict[cluster_key] = kept_words
    topic_score_dict[cluster_key] = [total_score[words_to_idx[top_word]] for top_word in kept_words]

# Stack the per-cluster score rows into one matrix.
total_score_cat = np.stack(total_score_cat, axis = 0)

In [73]:
def remove_dup(seq):
    """Return the items of ``seq`` as a list, keeping only the first
    occurrence of each item and preserving the original order."""
    return list(dict.fromkeys(seq))

# Interleave the ranked topic words row by row (first word of every topic,
# then second word of every topic, ...), de-duplicate, and cut to n_word.
topic_words_list = list(topic_word_dict.values())
topic_word_set = pd.DataFrame.from_dict(topic_word_dict).values.ravel().tolist()
word_candidates = remove_dup(topic_word_set)[:n_word]
n_word = len(word_candidates)
n_word

5000

In [74]:
# Persist the selected candidate vocabulary to disk.
# NOTE(review): the file name says 10000, while the list is cut to at most
# n_word entries in the previous cell — confirm the intended name.
import pickle
with open('our_word_candidates_10000.pkl', 'wb') as f:
    pickle.dump(word_candidates, f)

In [75]:
# For every candidate word, collect its score in each of the n_cluster rows.
weight_candidates = {
    candidate: [total_score_cat[label, words_to_idx[candidate]] for label in range(n_cluster)]
    for candidate in word_candidates
}

In [76]:
# Row index of every candidate word, and the matrix with one row per candidate.
weight_cand_to_idx = {name: position for position, name in enumerate(weight_candidates)}
weight_cand_matrix = np.array(list(weight_candidates.values()))

# Re-formulate the bow

In [77]:
def dist_match_loss(hiddens, alpha=1.0):
    """Distribution-matching loss of ``hiddens`` under a random orthogonal projection.

    A square Gaussian matrix is QR-decomposed and its Q factor is used as the
    projection passed to ``get_swd_loss``.

    Args:
        hiddens: tensor whose last dimension is the projected dimension.
        alpha: Dirichlet concentration forwarded to ``get_swd_loss``.

    Returns:
        The value returned by ``get_swd_loss`` for this projection.
    """
    dim = hiddens.shape[-1]
    gaussian = np.random.randn(dim, dim)
    orthogonal = qr(gaussian)[0]
    projection = torch.Tensor(orthogonal).to(hiddens.device)
    return get_swd_loss(hiddens, projection, alpha)


def js_div_loss(hidden1, hidden2):
    """Sum of ``kldiv`` between the log of the midpoint and each input.

    NOTE(review): ``kldiv`` is defined outside the visible cells; presumably it
    takes log-probabilities first and probabilities second — confirm.
    """
    log_midpoint = (0.5 * (hidden1 + hidden2)).log()
    first_term = kldiv(log_midpoint, hidden1)
    second_term = kldiv(log_midpoint, hidden2)
    return first_term + second_term


def get_swd_loss(states, rand_w, alpha=1.0):
    """Compare sorted projections of ``states`` with sorted Dirichlet samples.

    Args:
        states: 2-D tensor (rows x dims).
        rand_w: projection matrix applied to both ``states`` and the prior samples.
        alpha: symmetric Dirichlet concentration used for the prior samples.

    Returns:
        Scalar tensor: squared differences of the sorted projections, summed
        over the projected dimension and averaged over the sorted positions.
    """
    target_device = states.device
    row_count, dim_count = states.shape[0], states.shape[1]

    def _sorted_projection(matrix):
        # Project, transpose to (dims x rows), and sort each row ascending.
        return torch.sort(torch.matmul(matrix, rand_w).t(), dim=1)[0]

    sorted_states = _sorted_projection(states)

    # One Dirichlet sample per input row, treated exactly like the inputs.
    dirichlet_draws = np.random.dirichlet([alpha] * dim_count, row_count)
    sorted_prior = _sorted_projection(torch.Tensor(dirichlet_draws).to(target_device))

    return torch.mean(torch.sum((sorted_prior - sorted_states) ** 2, axis=0))

In [78]:
class Stage2Dataset(Dataset):
    """Stage-2 dataset of precomputed encoder embeddings and bag-of-words
    distributions, paired with each document's most similar neighbour.

    Args:
        encoder: model called with ``input_ids`` / ``attention_mask`` whose
            output exposes ``'pooler_output'``.
        ds: source dataset providing ``org_list`` (tokenised inputs) and
            ``nonempty_text`` (raw strings).
        basesim_matrix: document-to-document similarity tensor.
        word_candidates: fixed vocabulary for the bag-of-words vectorizer.
        k: number of neighbours taken from ``basesim_matrix``.
            NOTE(review): indexing in ``__getitem__`` looks like it only
            supports k=1 — confirm before using larger values.
        lemmatize: whether to lemmatize tokens during preprocessing.
    """

    def __init__(self, encoder, ds, basesim_matrix, word_candidates, k=1, lemmatize=False):
        self.lemmatize = lemmatize
        self.ds = ds
        self.org_list = self.ds.org_list
        self.nonempty_text = self.ds.nonempty_text
        english_stopwords = nltk.corpus.stopwords.words('english')
        self.stopwords_list = set(english_stopwords)
        self.vectorizer = CountVectorizer(vocabulary=word_candidates)
        self.vectorizer.fit(self.preprocess_ctm(self.nonempty_text))
        self.bow_list = []
        for sent in tqdm(self.nonempty_text):
            self.bow_list.append(self.vectorize(sent))

        # Map every document index to the index of its top-k neighbour(s).
        sim_weight, sim_indices = basesim_matrix.topk(k=k, dim=-1)
        # .cpu() makes the conversion work for a GPU-resident matrix as well.
        zip_iterator = zip(np.arange(len(sim_weight)), sim_indices.squeeze().detach().cpu().numpy())
        self.pos_dict = dict(zip_iterator)

        self.embedding_list = []
        encoder_device = next(encoder.parameters()).device
        # The embeddings are stored detached, so no autograd graph is needed
        # while encoding; this avoids building one per document.
        with torch.no_grad():
            for org_input in tqdm(self.org_list):
                org_input_ids = org_input['input_ids'].to(encoder_device).reshape(1, -1)
                org_attention_mask = org_input['attention_mask'].to(encoder_device).reshape(1, -1)
                embedding = encoder(input_ids = org_input_ids, attention_mask = org_attention_mask)
                self.embedding_list.append(embedding['pooler_output'].squeeze().detach().cpu())

    def __len__(self):
        """Number of documents."""
        return len(self.org_list)

    def preprocess_ctm(self, documents):
        """Lower-case, replace punctuation with spaces, drop stopwords and
        optionally lemmatize; returns a list of space-joined strings."""
        preprocessed_docs_tmp = documents
        preprocessed_docs_tmp = [doc.lower() for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [doc.translate(
            str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [' '.join([w for w in doc.split() if len(w) > 0 and w not in self.stopwords_list])
                                 for doc in preprocessed_docs_tmp]
        if self.lemmatize:
            lemmatizer = WordNetLemmatizer()
            preprocessed_docs_tmp = [' '.join([lemmatizer.lemmatize(w) for w in doc.split()])
                                     for doc in preprocessed_docs_tmp]
        return preprocessed_docs_tmp

    def vectorize(self, text):
        """Return the word distribution of one document over the fixed
        vocabulary as a 1-D float tensor that sums to 1."""
        text = self.preprocess_ctm([text])
        vectorized_input = self.vectorizer.transform(text)
        vectorized_input = vectorized_input.toarray().astype(np.float64)
#         vectorized_input = (vectorized_input != 0).astype(np.float64)

        # Get word distribution from BoW
        if vectorized_input.sum() == 0:
            # No vocabulary word present: fall back to a uniform distribution.
            vectorized_input += 1e-8
        vectorized_input = vectorized_input / vectorized_input.sum(axis=1, keepdims=True)
        assert abs(vectorized_input.sum() - vectorized_input.shape[0]) < 0.01

        vectorized_label = torch.tensor(vectorized_input, dtype=torch.float)
        return vectorized_label[0]

    def __getitem__(self, idx):
        """Return (embedding, neighbour embedding, BoW, neighbour BoW) for ``idx``."""
        pos_idx = self.pos_dict[idx]
        return self.embedding_list[idx], self.embedding_list[pos_idx], self.bow_list[idx], self.bow_list[pos_idx]

In [79]:
class Stage2TestDataset(Dataset):
    """Evaluation dataset of pre-computed encoder embeddings and bag-of-words distributions.

    Unlike the training counterpart, items carry no positive pair:
    ``dataset[i]`` is ``(embedding_i, bow_i)``.
    """

    def __init__(self, encoder, ds, word_candidates, k=1, lemmatize=False):
        """Pre-compute the bag-of-words distribution and encoder embedding of every document.

        Args:
            encoder: module called as ``encoder(input_ids=..., attention_mask=...)``
                whose output exposes ``'pooler_output'``; it is run on whatever
                device its parameters live on.
            ds: source dataset providing ``org_list`` (tokenised inputs with
                ``'input_ids'`` / ``'attention_mask'``) and ``nonempty_text``.
            word_candidates: fixed vocabulary handed to ``CountVectorizer``.
            k: unused; kept so existing call sites keep working.
            lemmatize: whether ``preprocess_ctm`` lemmatises tokens.
        """
        self.lemmatize = lemmatize
        self.ds = ds
        self.org_list = self.ds.org_list
        self.nonempty_text = self.ds.nonempty_text
        english_stopwords = nltk.corpus.stopwords.words('english')
        self.stopwords_list = set(english_stopwords)
        self.vectorizer = CountVectorizer(vocabulary=word_candidates)
        self.vectorizer.fit(self.preprocess_ctm(self.nonempty_text))
        self.bow_list = [self.vectorize(sent) for sent in tqdm(self.nonempty_text)]

        self.embedding_list = []
        encoder_device = next(encoder.parameters()).device
        # The embeddings are stored detached, so there is no reason to build an
        # autograd graph for each forward pass.
        with torch.no_grad():
            for org_input in tqdm(self.org_list):
                org_input_ids = org_input['input_ids'].to(encoder_device).reshape(1, -1)
                org_attention_mask = org_input['attention_mask'].to(encoder_device).reshape(1, -1)
                embedding = encoder(input_ids=org_input_ids, attention_mask=org_attention_mask)
                self.embedding_list.append(embedding['pooler_output'].squeeze().detach().cpu())

    def __len__(self):
        """Return the number of tokenised documents."""
        return len(self.org_list)

    def preprocess_ctm(self, documents):
        """Lower-case, replace punctuation by spaces, drop stop words and optionally lemmatise.

        Args:
            documents: iterable of strings.

        Returns:
            list[str]: one space-joined token string per input document.
        """
        # Loop-invariant: build the punctuation table once, not once per document.
        punct_to_space = str.maketrans(string.punctuation, ' ' * len(string.punctuation))
        lemmatizer = WordNetLemmatizer() if self.lemmatize else None

        cleaned_docs = []
        for doc in documents:
            # str.split() never yields empty tokens, so only the stop-word test is needed.
            tokens = [w for w in doc.lower().translate(punct_to_space).split()
                      if w not in self.stopwords_list]
            if lemmatizer is not None:
                # Lemmatisation runs after stop-word removal.
                tokens = [lemmatizer.lemmatize(w) for w in tokens]
            cleaned_docs.append(' '.join(tokens))
        return cleaned_docs

    def vectorize(self, text):
        """Return the normalised bag-of-words distribution of one document as a 1-D float tensor."""
        text = self.preprocess_ctm([text])
        vectorized_input = self.vectorizer.transform(text)
        vectorized_input = vectorized_input.toarray().astype(np.float64)

        # Get word distribution from BoW. A document with no in-vocabulary word
        # gets a tiny constant everywhere so it normalises to a uniform
        # distribution instead of 0/0.
        if vectorized_input.sum() == 0:
            vectorized_input += 1e-8
        vectorized_input = vectorized_input / vectorized_input.sum(axis=1, keepdims=True)
        assert abs(vectorized_input.sum() - vectorized_input.shape[0]) < 0.01

        vectorized_label = torch.tensor(vectorized_input, dtype=torch.float)
        return vectorized_label[0]

    def __getitem__(self, idx):
        """Return ``(embedding, bag-of-words distribution)`` for sample ``idx``."""
        return self.embedding_list[idx], self.bow_list[idx]

In [80]:
# Build the stage-2 training dataset from the stage-1 encoder, the tokenised
# training set, the base similarity matrix and the candidate vocabulary.
finetuneds = Stage2Dataset(model.encoder, trainds, basesim_matrix, word_candidates, lemmatize=True)    

# KL-divergence criterion with 'batchmean' reduction.
kldiv = torch.nn.KLDivLoss(reduction='batchmean')
# Vocabulary of the dataset's vectorizer (presumably word -> column index, as
# for a fitted CountVectorizer — confirm) and the inverted mapping, which the
# training cell uses to turn top-k column indices back into words.
vocab_dict = finetuneds.vectorizer.vocabulary_
vocab_dict_reverse = {i:v for v, i in vocab_dict.items()}
print(n_word)

100%|██████████| 11314/11314 [00:04<00:00, 2738.14it/s]
100%|██████████| 11314/11314 [00:56<00:00, 201.09it/s]

5000





# Stage 3

In [81]:
def measure_hungarian_score(topic_dist, train_target):
    """Accuracy of arg-max topic assignments after matching topics to labels.

    Each sample is assigned its arg-max topic; ``_hungarian_match`` then supplies
    (predicted topic, target label) pairs — presumably an optimal one-to-one
    assignment, confirm against its definition — which are used to relabel the
    predictions before comparing them with the targets.

    Args:
        topic_dist: 2-D array/tensor, one row of topic scores per sample.
        train_target: sequence of integer labels, one per sample.

    Returns:
        float: fraction of samples whose relabelled prediction equals the target.
    """
    predicted = torch.tensor(np.argmax(topic_dist, axis=1))
    target = torch.tensor(train_target).to(predicted.device)
    n_samples = predicted.shape[0]
    n_classes = topic_dist.shape[1]

    assignment = _hungarian_match(predicted, target, n_samples, n_classes)

    # Samples whose topic is not part of the assignment keep the initial value 0.
    remapped = torch.zeros(n_samples).to(predicted.device)
    for topic_id, label_id in assignment:
        remapped[predicted == topic_id] = int(label_id)

    n_correct = int((remapped == target.float()).sum())
    return n_correct / float(n_samples)

In [82]:
# Row-wise maximum of weight_cand_matrix as a float32 tensor on the first GPU;
# the training cell multiplies the bag-of-words targets by it in the
# reconstruction loss.
weight_cands = torch.tensor(weight_cand_matrix.max(axis=1)).cuda(gpu_ids[0]).float()

# Main

In [83]:
# Held-out split wrapped in a BertDataset (no vectorizer is passed in).
testds = BertDataset(bert=bert_name, text_list=textData.test_data, N_word=n_word, vectorizer=None, lemmatize=True)
# Pre-compute encoder embeddings and bag-of-words distributions for the test
# documents, restricted to the same candidate vocabulary as training.
testds2 = Stage2TestDataset(model.encoder, testds, word_candidates, lemmatize=True)


100%|██████████| 7532/7532 [00:03<00:00, 2013.88it/s]
100%|██████████| 7532/7532 [00:02<00:00, 2771.11it/s]
100%|██████████| 7532/7532 [00:37<00:00, 200.86it/s]


In [84]:
from evaluation import evaluate_classification, evaluate_clustering

def _collect_topic_dist(model, loader):
    """Return the soft-maxed topic distribution of every sample in ``loader``.

    The first element of each batch is taken as the pre-computed document
    embedding (this is the layout of both ``finetuneds`` and ``testds2`` items).
    The model is switched to eval mode and autograd is disabled for the pass.

    Returns:
        torch.Tensor: CPU tensor with one row per sample, in loader order.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in tqdm(loader):
            doc_emb = batch[0].cuda(gpu_ids[0])
            _, topic_logit = model.decode(doc_emb)
            chunks.append(F.softmax(topic_logit, dim=1).cpu())
    return torch.cat(chunks, dim=0)


results_list = []

# Number of stage-2 epochs. The loop used to hard-code 50; prefer the value of
# --epochs-2 when the argument parser exposes it (assumed dest: `epochs_2`).
n_epochs_stage2 = getattr(args, 'epochs_2', 50)

for repeat_idx in range(args.stage_2_repeat):
    # Fresh model per repeat: start from the stage-1 checkpoint, then replace
    # the topic-word matrix with a newly initialised one and drop its batch norm.
    model = ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)
    model.load_state_dict(torch.load(model_stage1_name), strict=True)
    model.beta = nn.Parameter(torch.Tensor(model.N_topic, n_word))
    nn.init.xavier_uniform_(model.beta)
    model.beta_batchnorm = nn.Sequential()
    model.cuda(gpu_ids[0])

    losses = AverageMeter()
    rlosses = AverageMeter()
    closses = AverageMeter()
    distlosses = AverageMeter()
    # shuffle=False is required for the evaluation passes below, whose rows must
    # line up with textData.targets.
    # NOTE(review): training therefore also sees the data in a fixed order.
    trainloader = DataLoader(finetuneds, batch_size=bsz, shuffle=False, num_workers=0)
    testloader = DataLoader(testds2, batch_size=bsz, shuffle=False, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.stage_2_lr)

    print("Coeff   / regul: {:.5f} - recon: {:.5f} - c: {:.5f} - dist: {:.5f} ".format(args.coeff_2_regul, 
                                                                                        args.coeff_2_recon,
                                                                                        args.coeff_2_cons,
                                                                                        args.coeff_2_dist))
    for epoch in range(n_epochs_stage2):
        model.train()
        model.encoder.eval()
        for batch in trainloader:
            org_input, pos_input, org_bow, pos_bow = (t.cuda(gpu_ids[0]) for t in batch)

            # Real size of this batch (the last one may be smaller than bsz).
            # Previously this read `org_input_ids`, a name never defined here.
            batch_size = org_input.size(0)

            org_dists, org_topic_logit = model.decode(org_input)
            pos_dists, pos_topic_logit = model.decode(pos_input)

            org_topic = F.softmax(org_topic_logit, dim=1)
            pos_topic = F.softmax(pos_topic_logit, dim=1)

            # reconstruction loss (batch mean): weighted binary cross-entropy
            # between the decoded word distribution and the BoW distribution,
            # for both the anchor and its positive partner.
            recons_loss = torch.mean(-torch.sum(torch.log(org_dists + 1E-10) * (org_bow * weight_cands), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log((1-org_dists) + 1E-10) * ((1-org_bow) * weight_cands), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log(pos_dists + 1E-10) * (pos_bow * weight_cands), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log((1-pos_dists) + 1E-10) * ((1-pos_bow) * weight_cands), axis=1), axis=0)
            recons_loss *= 0.5

            # consistency loss: maximise the dot product of paired topic vectors
            pos_sim = torch.sum(org_topic * pos_topic, dim=-1)
            cons_loss = -pos_sim.mean()

            # distribution loss (batch mean)
            distmatch_loss = dist_match_loss(torch.cat((org_topic, pos_topic), dim=0), dirichlet_alpha_2)

            loss = args.coeff_2_recon * recons_loss + \
                   args.coeff_2_cons * cons_loss + \
                   args.coeff_2_dist * distmatch_loss

            # Weight each batch by its real size so the short last batch is not
            # over-weighted in the averages.
            losses.update(loss.item(), batch_size)
            closses.update(cons_loss.item(), batch_size)
            rlosses.update(recons_loss.item(), batch_size)
            distlosses.update(distmatch_loss.item(), batch_size)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # NOTE(review): the meters are not reset between epochs, so (assuming
        # AverageMeter.avg is a running mean) these are averages since epoch 0.
        print("Epoch-{} / recon: {:.5f} - dist: {:.5f} - cons: {:.5f}".format(epoch, rlosses.avg, distlosses.avg, closses.avg))

    print("------- Evaluation results -------")
    # Top-15 words of every topic according to the learned topic-word matrix.
    all_list = {}
    for topic_idx, word_ids in enumerate(model.beta.cpu().topk(15, dim=1).indices):
        word_list = [vocab_dict_reverse[word_id.item()] for word_id in word_ids]
        all_list[topic_idx] = word_list
        print("topic-{}".format(topic_idx), word_list)

    topic_words_list = list(all_list.values())
    now = datetime.now().strftime('%y%m%d_%H%M%S')
    results = get_topic_qualities(topic_words_list, palmetto_dir=args.palmetto_dir,
                                  reference_corpus=[doc.split() for doc in trainds.preprocess_ctm(trainds.nonempty_text)],
                                  filename=f'results/{now}.txt')

    # Document-topic distributions, extracted in eval mode without gradients so
    # they are consistent with the label-matching measurement below.
    train_topic_dist = _collect_topic_dist(model, trainloader)
    train_theta = train_topic_dist.clone().numpy()
    test_theta = _collect_topic_dist(model, testloader).numpy()

    classification_res = evaluate_classification(train_theta, test_theta, textData.targets, textData.test_targets)
    clustering_res = evaluate_clustering(test_theta, textData.test_targets)

    results.update(classification_res)
    results.update(clustering_res)

    if should_measure_hungarian:
        non_empty_text_index = [i for i, text in enumerate(textData.data) if len(text) != 0]
        assert len(finetuneds) == len(non_empty_text_index)
        # The train distributions computed above come from the same dataset, in
        # the same order and in eval mode, so they are reused here instead of
        # running a third identical pass. Targets are indexed directly rather
        # than filtered with a quadratic `i in list` test.
        label_accuracy = measure_hungarian_score(
                             train_topic_dist,
                             [textData.targets[i] for i in non_empty_text_index]
                         )
        results['label_match'] = label_accuracy

    print(results)
    print()
    results_list.append(results)

Coeff   / regul: 1.00000 - recon: 1.00000 - c: 1.00000 - dist: 1.00000 
Epoch-0 / recon: 2.54389 - dist: 0.07809 - cons: -0.27672
Epoch-1 / recon: 2.43787 - dist: 0.05814 - cons: -0.32910
Epoch-2 / recon: 2.38195 - dist: 0.05084 - cons: -0.35136
Epoch-3 / recon: 2.34941 - dist: 0.04676 - cons: -0.36403
Epoch-4 / recon: 2.32836 - dist: 0.04424 - cons: -0.37254
Epoch-5 / recon: 2.31367 - dist: 0.04250 - cons: -0.37886
Epoch-6 / recon: 2.30267 - dist: 0.04112 - cons: -0.38391
Epoch-7 / recon: 2.29419 - dist: 0.04006 - cons: -0.38790
Epoch-8 / recon: 2.28749 - dist: 0.03916 - cons: -0.39119
Epoch-9 / recon: 2.28200 - dist: 0.03846 - cons: -0.39398
Epoch-10 / recon: 2.27746 - dist: 0.03788 - cons: -0.39632
Epoch-11 / recon: 2.27360 - dist: 0.03738 - cons: -0.39856
Epoch-12 / recon: 2.27022 - dist: 0.03693 - cons: -0.40075
Epoch-13 / recon: 2.26728 - dist: 0.03655 - cons: -0.40279
Epoch-14 / recon: 2.26471 - dist: 0.03620 - cons: -0.40447
Epoch-15 / recon: 2.26245 - dist: 0.03586 - cons: -0.

89it [00:00, 186.58it/s]
59it [00:00, 267.54it/s]
100%|██████████| 89/89 [00:00<00:00, 261.00it/s]


{'topic_N': 50, 'umass_wiki': -3.1000192, 'npmi_wiki': -0.031743600000000004, 'uci_wiki': -1.3323108, 'CV_wiki': 0.36467320000000003, 'cp_wiki': 0.0405284, 'sim_w2v': 0.12225475113772923, 'diversity': 0.328, 'filename': 'results/240517_153125.txt', 'acc': 0.49402549123738715, 'macro-F1': 0.4674685936857392, 'Purity': 0.5114179500796601, 'NMI': 0.48645531376616835, 'label_match': 0.43344528902245005}

Coeff   / regul: 1.00000 - recon: 1.00000 - c: 1.00000 - dist: 1.00000 
Epoch-0 / recon: 2.54682 - dist: 0.08964 - cons: -0.27331
Epoch-1 / recon: 2.44628 - dist: 0.07193 - cons: -0.31212
Epoch-2 / recon: 2.39432 - dist: 0.06492 - cons: -0.32813
Epoch-3 / recon: 2.36389 - dist: 0.06105 - cons: -0.33729
Epoch-4 / recon: 2.34408 - dist: 0.05873 - cons: -0.34304
Epoch-5 / recon: 2.33015 - dist: 0.05702 - cons: -0.34696
Epoch-6 / recon: 2.31965 - dist: 0.05548 - cons: -0.35053
Epoch-7 / recon: 2.31144 - dist: 0.05421 - cons: -0.35408
Epoch-8 / recon: 2.30486 - dist: 0.05305 - cons: -0.35684
Ep

89it [00:00, 195.60it/s]
59it [00:00, 282.30it/s]
100%|██████████| 89/89 [00:00<00:00, 258.72it/s]


{'topic_N': 50, 'umass_wiki': -3.2232423999999997, 'npmi_wiki': -0.0460716, 'uci_wiki': -1.6506857999999993, 'CV_wiki': 0.380886, 'cp_wiki': -0.011788799999999999, 'sim_w2v': 0.11873108914489816, 'diversity': 0.30933333333333335, 'filename': 'results/240517_154244.txt', 'acc': 0.47437599575146044, 'macro-F1': 0.44354017434120074, 'Purity': 0.49070631970260226, 'NMI': 0.4733430637891917, 'label_match': 0.46049142655117553}

Coeff   / regul: 1.00000 - recon: 1.00000 - c: 1.00000 - dist: 1.00000 
Epoch-0 / recon: 2.54453 - dist: 0.08944 - cons: -0.25046
Epoch-1 / recon: 2.44148 - dist: 0.06865 - cons: -0.29795
Epoch-2 / recon: 2.38818 - dist: 0.06046 - cons: -0.31811
Epoch-3 / recon: 2.35742 - dist: 0.05631 - cons: -0.32885
Epoch-4 / recon: 2.33767 - dist: 0.05383 - cons: -0.33614
Epoch-5 / recon: 2.32390 - dist: 0.05220 - cons: -0.34096
Epoch-6 / recon: 2.31364 - dist: 0.05084 - cons: -0.34485
Epoch-7 / recon: 2.30584 - dist: 0.04980 - cons: -0.34777
Epoch-8 / recon: 2.29946 - dist: 0.04

In [None]:
# One row per stage-2 repeat.
results_df = pd.DataFrame(results_list)
print(results_df)
# `filename` is a string column; restrict the statistics to numeric columns
# explicitly, since pandas >= 2.0 no longer drops non-numeric columns silently.
print('mean')
print(results_df.mean(numeric_only=True))
print('std')
print(results_df.std(numeric_only=True))

# `now` is the timestamp of the last repeat of the training cell.
if args.result_file is not None:
    result_filename = f'results/{args.result_file}'
else:
    result_filename = f'results/{now}.tsv'

# Make sure the target directory exists before writing.
os.makedirs(os.path.dirname(result_filename), exist_ok=True)
results_df.to_csv(result_filename, sep='\t')

   topic_N   CV_wiki   sim_w2v  diversity                   filename      acc  \
0       50  0.367416  0.198575   0.290667  results/240516_233645.txt  0.58284   
1       50  0.367488  0.197549   0.266667  results/240516_234214.txt  0.58976   
2       50  0.364076  0.192047   0.293333  results/240516_234737.txt  0.58264   
3       50  0.361422  0.191338   0.269333  results/240516_235509.txt  0.58844   
4       50  0.368351  0.197752   0.264000  results/240517_000053.txt  0.58924   

   macro-F1   Purity       NMI  
0  0.581245  0.57880  0.007457  
1  0.589379  0.57484  0.007546  
2  0.581773  0.57700  0.007772  
3  0.587958  0.57088  0.006728  
4  0.588312  0.57708  0.007461  
mean
topic_N      50.000000
CV_wiki       0.365751
sim_w2v       0.195452
diversity     0.276800
acc           0.586584
macro-F1      0.585733
Purity        0.575720
NMI           0.007393
dtype: float64
std
topic_N      0.000000
CV_wiki      0.002920
sim_w2v      0.003463
diversity    0.014035
acc          0.0035