In [10]:
#!/usr/bin/python3

import logging 
import sys, os, re
import random
from collections import defaultdict, Counter
import itertools
import pickle
import conllu

import pandas as pd
import numpy as np

import ipywidgets as widgets

from bertopic import BERTopic
from topictuner import TopicModelTuner

from sentence_transformers import SentenceTransformer
from umap import UMAP
from hdbscan import HDBSCAN
from sklearn.feature_extraction.text import CountVectorizer
from bertopic.vectorizers import ClassTfidfTransformer
from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance
from gensim.models.coherencemodel import CoherenceModel
import gensim.corpora as corpora

from gensim.matutils import hellinger
from gensim.matutils import kullback_leibler
from gensim.matutils import jaccard

# util classes

In [2]:
###############
#These classes for UD object types (treebank, sentence, word)
#are from https://github.com/personads/ud-selection/blob/main/lib/data.py
###############
class UniversalDependencies:
    """Container for the treebank samples loaded from one directory of CoNLL-U files."""

    def __init__(self, treebanks=None):
        # Build a fresh list per instance; a mutable default would be shared by all instances.
        self._treebanks = [] if treebanks is None else treebanks

    @staticmethod
    def from_directory(path, verbose=False):
        """Load every '<name>.dev.conllu' / '<name>.train.conllu' file in `path`.

        File names are split on '__' to fill the treebank metadata:
        dev files   -> Language, Treebank, Genre (only when there are exactly 3 parts),
        train files -> Language, Treebank, Genre, Sample_no, Seed.

        Args:
            path: directory containing the CoNLL-U sample files.
            verbose: when True, print the parsed metadata and report each loaded
                treebank (through `logging` if the root logger has handlers,
                otherwise through `print`).

        Returns:
            A UniversalDependencies instance holding the treebanks in sorted file order.

        Raises:
            IndexError: if a train file name has fewer than 5 '__'-separated parts.
        """
        treebanks = []
        cursor = 0

        # iterate over files in TB directory
        for tbf in sorted(os.listdir(path)):
            tbf_path = os.path.join(path, tbf)

            # skip non-conllu files
            if os.path.splitext(tbf)[1] != '.conllu':
                continue

            # parse TB sample name; group(2) holds the split suffix ('.dev' or '.train')
            tb_match = re.match(r'(.+)(?=(\.dev|\.train)\.conllu)', tbf)
            if not tb_match:
                continue
            components = tb_match.group(1).split('__')
            # Decide the split from the file suffix rather than from a substring of the
            # full path, so a directory called e.g. 'dev' cannot mislabel train files.
            split = tb_match.group(2)

            tb_meta = defaultdict()

            if split == '.dev':
                if len(components) == 3:
                    tb_meta['Language'] = components[0]
                    tb_meta['Treebank'] = components[1]
                    tb_meta['Genre'] = components[2]
            else:
                tb_meta['Language'] = components[0]
                tb_meta['Treebank'] = components[1]
                tb_meta['Genre'] = components[2]
                tb_meta['Sample_no'] = components[3]
                tb_meta['Seed'] = components[4]

            if verbose:
                print(tb_meta)

            # load treebank
            treebank = UniversalDependenciesTreebank.from_conllu(
                tbf_path, name=tbf, meta=tb_meta, start_idx=cursor, ud_filter=None
            )
            treebanks.append(treebank)
            cursor += len(treebank)

            # print statistics (if verbose); `info` only exists in this branch, so the
            # reporting must stay inside it
            if verbose:
                info = f"Loaded {treebank}."
                if logging.getLogger().hasHandlers():
                    logging.info(info)
                else:
                    print(info)

        return UniversalDependencies(treebanks=treebanks)

    def get_treebanks(self):
        """Return the list of loaded treebanks."""
        return self._treebanks

class UniversalDependenciesTreebank:
	"""A named list of sentences read from one CoNLL-U file, plus file-level metadata."""

	def __init__(self, sentences=None, name=None, meta=None):
		# Fresh containers per instance: mutable defaults would be shared across treebanks,
		# and a None meta would break the get_* accessors below.
		self._sentences = [] if sentences is None else sentences
		self._name = name
		self._meta = {} if meta is None else meta

	def __repr__(self):
		return f'<UniversalDependenciesTreebank{f" ({self._name})" if self._name else ""}: {len(self._sentences)} sentences>'

	def __len__(self):
		return len(self._sentences)

	def __getitem__(self, key):
		return self._sentences[key]

	def __setitem__(self, key, val):
		self._sentences[key] = val

	@staticmethod
	def from_conllu(path, name=None, meta=None, start_idx=0, ud_filter=None):
		"""Parse the CoNLL-U file at `path` into a treebank.

		Args:
			path: file to read (UTF-8).
			name: treebank name stored on the result.
			meta: metadata mapping stored on the result and passed to `ud_filter`.
			start_idx: index given to the first sentence; later sentences count up from it.
			ud_filter: optional callable `(sentence, meta) -> bool`; sentences it rejects
				are stored as None so positions stay stable.

		Returns:
			A UniversalDependenciesTreebank. Sentences that fail to parse are skipped
			with a warning (via `logging` when the root logger has handlers, else `print`).
		"""
		sentences = []

		def parse_pending(pending, line_idx):
			# Deliberately best-effort: one malformed sentence must not abort the whole file.
			try:
				# parse sentence from current set of lines
				sentence = UniversalDependenciesSentence.from_conllu(start_idx + len(sentences), pending)
				# if filter is set, set any sentences not matching the filter to None
				if (ud_filter is not None) and (not ud_filter(sentence, meta)):
					sentence = None
				sentences.append(sentence)
			except Exception as err:
				warn_msg = f"[Warning] UniversalDependenciesTreebank: Unable to parse '{path}' line {line_idx} ({err}). Skipping."
				if logging.getLogger().hasHandlers():
					logging.warning(warn_msg)
				else:
					print(warn_msg)

		with open(path, 'r', encoding='utf8') as fp:
			cur_lines = []
			line_idx = 0
			for line_idx, line in enumerate(fp):
				line = line.strip()
				# on blank line, construct full sentence from preceding lines
				if line == '':
					parse_pending(cur_lines, line_idx)
					cur_lines = []
					continue
				cur_lines.append(line)
			# A file without a trailing blank line still ends with a complete sentence.
			if cur_lines:
				parse_pending(cur_lines, line_idx)

		return UniversalDependenciesTreebank(sentences=sentences, name=name, meta=meta)

	def to_tokens(self):
		"""Return one list of token forms per sentence."""
		return [sentence.to_tokens() for sentence in self]

	def to_words(self):
		"""Return one list of word forms per sentence."""
		return [sentence.to_words() for sentence in self]

	def to_conllu(self, comments=True, resolve=False):
		"""Serialise all sentences back to a single CoNLL-U string."""
		return ''.join([s.to_conllu(comments, resolve) for s in self._sentences])

	def get_sentences(self):
		return self._sentences

	def get_name(self):
		return self._name

	def get_treebank_name(self):
		return self._meta.get('Treebank', 'Unknown')

	def get_language(self):
		return self._meta.get('Language', 'Unknown')

	def get_domains(self):
		"""Return the space-separated 'Genre' metadata as a sorted list."""
		return sorted(self._meta.get('Genre', '').split(' '))

	def get_statistics(self):
		"""Return sentence/token/word counts and the sorted set of sentence metadata keys."""
		statistics = {
			'sentences': len(self._sentences),
			'tokens': 0,
			'words': 0,
			'metadata': set()
		}

		for sentence in self:
			statistics['tokens'] += len(sentence.to_tokens(as_str=False))
			statistics['words'] += len(sentence.to_words(as_str=False))
			statistics['metadata'] |= set(sentence.get_metadata().keys())

		statistics['metadata'] = list(sorted(statistics['metadata']))

		return statistics

    
class UniversalDependenciesSentence:
	"""One CoNLL-U sentence: an index, its tokens and its raw '#' comment lines."""

	def __init__(self, idx, tokens, comments=None):
		self.idx = idx
		self._tokens = tokens
		# Fresh list per instance; a mutable default would be shared by all sentences.
		self._comments = [] if comments is None else comments

	def __repr__(self):
		return f"<UniversalDependenciesSentence: ID {self.idx}, {len(self._tokens)} tokens, {len(self._comments)} comments>"

	@staticmethod
	def from_conllu(idx, lines):
		"""Build a sentence from the stripped lines of one CoNLL-U sentence block.

		Lines starting with '#' become comments. All other lines are grouped into tokens:
		an 'a-b' range line is grouped with the word lines it covers, a run of decimal
		ids ('1.1', '1.2', ...) is grouped into one token, anything else is a
		single-word token.
		"""
		tokens, comments = [], []
		line_idx = 0
		while line_idx < len(lines):
			# check for comment
			if lines[line_idx].startswith('#'):
				comments.append(lines[line_idx])
				line_idx += 1
				continue

			# process tokens
			tkn_words = []
			tkn_idx_str = lines[line_idx].split('\t')[0]
			num_words = 1
			# check for multiword token in 'a-b' format
			if '-' in tkn_idx_str:
				tkn_idx_split = tkn_idx_str.split('-')
				# convert token id to tuple signifying range (e.g. (3,4))
				tkn_span = (int(tkn_idx_split[0]), int(tkn_idx_split[1]))
				# Scan forward until a line falls outside the span. Counting from the scan
				# position (instead of the last index inspected) also keeps the final word
				# when the multiword token is the last thing in the sentence.
				span_end = line_idx + 1
				while span_end < len(lines):
					# ids are read as float due to spans such as '1-2; 1; 2; 2.1; ... 3' (e.g. Arabic data)
					span_str = lines[span_end].split('\t')[0]
					if '-' in span_str:
						break
					if int(float(span_str)) > tkn_span[1]:
						break
					span_end += 1
				num_words = span_end - line_idx
			# check for multiword token in decimal format '1; 1.1; 1.2; ... 2' or '0.1; 0.2; ... 1' (e.g. Czech data)
			elif (
				re.match(r'^\d+\.\d+', tkn_idx_str)
				or ((line_idx < (len(lines) - 1)) and re.match(r'^\d+\.\d+\t', lines[line_idx+1]))
			):
				# count words that are part of multiword token
				while (line_idx + num_words) < len(lines):
					if not re.match(r'^\d+\.\d+\t', lines[line_idx+num_words]):
						break
					num_words += 1
				# token span for decimal indices is (a.1, a.n)
				tkn_span_start = float(tkn_idx_str) if re.match(r'^\d+\.\d+', tkn_idx_str) else int(tkn_idx_str) + .1
				tkn_span_end = tkn_span_start + (.1 * (num_words - 1))
				tkn_span = (tkn_span_start, tkn_span_end)
			# if single word token
			else:
				# convert token id to tuple with range 1 (e.g. (3,3))
				tkn_span = (int(tkn_idx_str), int(tkn_idx_str))

			# construct words contained in token
			for word_line in lines[line_idx:line_idx + num_words]:
				tkn_words.append(UniversalDependenciesWord.from_conllu(word_line))
			# construct and append token
			tokens.append(UniversalDependenciesToken(idx=tkn_span, words=tkn_words))
			# increment line index by number of words in token
			line_idx += num_words

		return UniversalDependenciesSentence(idx=idx, tokens=tokens, comments=comments)

	def to_text(self):
		"""Return the surface text by concatenating each token's text."""
		return ''.join([t.to_text() for t in self._tokens])

	def to_tokens(self, as_str=True):
		"""Return the tokens, as forms (default) or as token objects."""
		return [(t.get_form() if as_str else t) for t in self._tokens]

	def to_words(self, as_str=True):
		"""Return the syntactic words of all tokens, as forms (default) or as word objects."""
		return [(w.get_form() if as_str else w) for token in self._tokens for w in token.to_words()]

	def to_conllu(self, comments=True, resolve=False):
		"""Serialise the sentence to CoNLL-U, terminated by a blank line.

		Args:
			comments: include the stored comment lines.
			resolve: passed on to each token's `to_conllu`.
		"""
		parts = []
		# Only emit the comment section when there is one: a lone newline here would
		# place a blank line (the CoNLL-U sentence separator) before the tokens.
		if comments and self._comments:
			parts.append('\n'.join(self._comments) + '\n')
		# Tokens may serialise to '' (omitted); drop those so no blank line is written mid-sentence.
		token_lines = [t.to_conllu(resolve=resolve) for t in self._tokens]
		parts.append('\n'.join([tl for tl in token_lines if tl != '']))
		parts.append('\n\n')
		return ''.join(parts)

	def get_dependencies(self, offset=-1, include_subtypes=True):
		"""Return aligned `(heads, labels)` lists for all words that have a head.

		Args:
			offset: added to every head index (default -1).
			include_subtypes: when False, 'rel:subtype' labels are cut down to 'rel'.

		Words without a head are left out of both lists, so the two lists always have
		the same length.
		"""
		heads, labels = [], []
		for token in self._tokens:
			for w in token.to_words():
				if w.head is None:
					continue
				heads.append(w.head + offset)
				label = w.deprel
				if (label is not None) and (not include_subtypes):
					label = label.split(':')[0]
				labels.append(label)
		return heads, labels

	def get_comments(self, stripped=True):
		"""Return the comment lines; with `stripped` the leading '#' and outer whitespace are removed."""
		if not stripped:
			return list(self._comments)
		return [c[1:].strip() for c in self._comments]

	def get_metadata(self):
		"""Returns metadata from the comments of a sentence.

		Comment should follow the UD metadata guidelines '# FIELD = VALUE' or '# FIELD VALUE'.
		Lines not following this convention are exported in the 'unknown' field.

		Returns a dict of metadata field and value pairs {'FIELD': 'VALUE'}.
		"""
		metadata = {}
		md_patterns = [r'^# ?(.+?) ?= ?(.+)', r'^# ?([^\s]+?)\s([^\s]+)$']
		for comment in self._comments:
			for md_pattern in md_patterns:
				md_match = re.match(md_pattern, comment)
				if md_match:
					metadata[md_match[1]] = md_match[2]
					break
			else:
				# no pattern matched: collect the comment text under 'unknown'
				metadata['unknown'] = metadata.get('unknown', []) + [comment[1:].strip()]
		return metadata


class UniversalDependenciesToken:
	"""A surface token: a span index plus the word lines that belong to it."""

	def __init__(self, idx, words):
		# idx is an int or float tuple describing the covered span
		self.idx = idx
		# words[0] carries the token form; any further entries are the words of a multiword token
		self._words = words

	def to_text(self):
		"""Return the text of the form-carrying first word."""
		form_word = self._words[0]
		return form_word.to_text()

	def to_words(self):
		"""Return the words of this token.

		A single-word token returns its word list unchanged; a multiword token
		returns only those words that have a dependency head.
		"""
		if len(self._words) != 1:
			return [word for word in self._words if word.head is not None]
		return self._words

	def to_conllu(self, resolve=False):
		"""Serialise the token to CoNLL-U lines.

		Without `resolve`, every word line is written. With `resolve`, a multiword
		token is replaced by its constituents that have a head, a token whose first
		word has a head is written as that word alone, and a lone headless word is
		omitted (empty string).
		"""
		if not resolve:
			return '\n'.join([word.to_conllu() for word in self._words])

		form_word = self._words[0]
		# first word has a head: it stands for the whole token
		if form_word.head is not None:
			return form_word.to_conllu()
		# headless first word of a multiword token (e.g. 'i-j' line): emit the constituents
		if len(self._words) > 1:
			constituents = [word for word in self._words[1:] if word.head is not None]
			return '\n'.join([word.to_conllu() for word in constituents])
		# a single headless word (e.g. Coptic '0.1') is dropped
		return ''

	def get_form(self):
		"""Return the form of the first word."""
		form_word = self._words[0]
		return form_word.get_form()


# In[ ]:


class UniversalDependenciesWord:
	"""
	One line of a CoNLL-U sentence.

	ID: Word index, integer starting at 1 for each new sentence; may be a range for tokens with multiple words.
	FORM: Word form or punctuation symbol.
	LEMMA: Lemma or stem of word form.
	UPOSTAG: Universal part-of-speech tag drawn from our revised version of the Google universal POS tags.
	XPOSTAG: Language-specific part-of-speech tag; underscore if not available.
	FEATS: List of morphological features from the universal feature inventory or from a defined language-specific extension; underscore if not available.
	HEAD: Head of the current token, which is either a value of ID or zero (0).
	DEPREL: Universal Stanford dependency relation to the HEAD (root iff HEAD = 0) or a defined language-specific subtype of one.
	DEPS: List of secondary dependencies (head-deprel pairs).
	MISC: Any other annotation.

	[1] https://universaldependencies.org/docs/format.html
	"""
	def __init__(self, idx, form, lemma, upostag, xpostag, feats, head, deprel, deps=None, misc=None):
		self.idx = idx # expects int, float or str
		self.form = form
		self.lemma = lemma
		self.upostag = upostag
		self.xpostag = xpostag
		self.feats = feats # expects dict
		self.head = head # expects int
		self.deprel = deprel # expects str
		self.deps = deps
		self.misc = misc # expects dict

	def __repr__(self):
		return f'<UniversalDependenciesWord: ID {self.idx}, "{self.form}">'

	@staticmethod
	def _parse_key_values(field):
		"""Parse a 'K=V|K=V' field into a dict; None or empty input gives {}.

		Only the first '=' separates key from value, so values that contain '='
		survive intact. Items without any '=' are ignored instead of invalidating
		the whole field.
		"""
		parsed = {}
		if not field:
			return parsed
		for item in field.split('|'):
			key, sep, value = item.partition('=')
			if sep:
				parsed[key] = value
		return parsed

	@staticmethod
	def from_conllu(line):
		"""Build a word from one tab-separated CoNLL-U line.

		Raises:
			ValueError: if the line does not have exactly 10 tab-separated columns,
				or HEAD is not an integer.
		"""
		# split line and initially convert '_' values to None
		idx_str, form, lemma, upostag, xpostag, feats, head, deprel, deps, misc = [(v if v != '_' else None) for v in line.split('\t')]
		# parse idx string (int 1, decimal 1.1 or string '1-2')
		idx = idx_str
		if re.match(r'^\d+\.\d+$', idx_str): idx = float(idx_str)
		elif re.match(r'^\d+$', idx_str): idx = int(idx_str)
		# parse form and lemma (special case '_': a literal underscore token keeps '_' for both)
		form = form if form is not None else '_'
		lemma = lemma if form != '_' else '_'
		# parse dependency head idx (int)
		head = int(head) if head is not None else head
		# parse FEATS and MISC dictionaries
		feats = UniversalDependenciesWord._parse_key_values(feats)
		misc = UniversalDependenciesWord._parse_key_values(misc)
		# construct word
		word = UniversalDependenciesWord(
			idx,
			form, lemma, # form and lemma are str
			upostag, xpostag, # upostag and xpostag are str
			feats,
			head, deprel, deps, # dependency information as str
			misc
		)
		return word

	def to_text(self):
		"""Return the form followed by a space, unless MISC says 'SpaceAfter=No'."""
		text = self.get_form() + ' ' # form + space by default
		# if 'SpaceAfter=No' remove trailing space; misc may be None when built by hand
		if (self.misc or {}).get('SpaceAfter') == 'No':
			text = text[:-1]

		return text

	def to_conllu(self):
		"""Serialise the word to one tab-separated CoNLL-U line; None values become '_'."""
		# convert dictionaries
		feats_str = '|'.join([f'{k}={v}' for k, v in sorted(self.feats.items())]) if self.feats else None
		misc_str = '|'.join([f'{k}={v}' for k, v in sorted(self.misc.items())]) if self.misc else None
		# Keep a missing head as None so it is written as '_' below, not as the text 'None'.
		head_str = str(self.head) if self.head is not None else None

		conllu_values = [
			str(self.idx),
			self.form, self.lemma,
			self.upostag, self.xpostag,
			feats_str,
			head_str, self.deprel, self.deps,
			misc_str
		]
		# convert None to '_'
		conllu_values = [v if v is not None else '_' for v in conllu_values]

		return '\t'.join(conllu_values)

	def get_form(self):
		"""Return the form ('' when unset) with soft hyphens removed."""
		form = self.form if self.form else ''
		form = form.replace('\xad', '') # sanitize soft hyphens
		return form
#initialize comfy functions

# basic preprocessing

In [3]:
#preprocessing
# Compiled once at import time. The character class deliberately lists only emoji and
# symbol code points: a range that runs from U+24C2 up into the supplementary planes
# would also delete CJK ideographs, kana, Hangul, Coptic and every other script
# located above U+24C2, which would empty sentences in those languages.
_EMOJI_PATTERN = re.compile(
    "["
    "\U00010000-\U0010FFFF"      # all supplementary planes (emoticons, pictographs, transport, flags, ...)
    "\u2500-\u2BEF"              # box drawing through miscellaneous symbols and arrows (incl. dingbats)
    "\u24C2"                     # circled latin capital letter M
    "\u231A\u23CF\u23E9"         # watch, eject symbol, fast-forward
    "\u3030\u303D\u3297\u3299"   # wavy dash, part alternation mark, circled congratulation/secret
    "\u200D"                     # zero width joiner (glues emoji sequences)
    "\uFE0F"                     # variation selector-16 (emoji presentation)
    "]+",
    flags=re.UNICODE,
)


def remove_emoji(string):
    """Return `string` with emoji and pictographic symbols removed.

    Letters of every script in the Basic Multilingual Plane are kept. Characters in
    the supplementary planes (U+10000 and above) are removed wholesale.
    """
    return _EMOJI_PATTERN.sub(r'', string)

def clean_sentences(sents):
    """Return a new list with markup remnants, '#'/'*' characters and emoji removed.

    Each sentence is processed in this order: '<a_href=...' fragments, simple
    '<tag>' / '</tag>' markers, '#' and '*' characters, then `remove_emoji`.
    """
    def _scrub(text):
        without_links = re.sub(r'<a_href=\S+', '', text)
        without_tags = re.sub(r'<(/)?\w+>', '', without_links)
        without_marks = re.sub("[#*]", "", without_tags)
        return remove_emoji(without_marks)

    return [_scrub(sent) for sent in sents]


# BERTOPIC

In [8]:
class bertopicTM():
    """Bundle of the BERTopic components used to fit one topic model on a list of documents.

    The constructor only wires the components together (embedding model, UMAP,
    HDBSCAN, TopicModelTuner, vectoriser, c-TF-IDF, representation models);
    `run_TM` fits the model, optionally after `tune_hdbscan`.
    """

    def __init__(self, preprocessed_sentences):
        """Store the documents and create all pipeline components.

        Args:
            preprocessed_sentences: list of document strings to model.
        """
        self.docs = preprocessed_sentences
        self.emb_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
        # Alternative embedding models tried previously:
        #   SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
        #   SentenceTransformer("intfloat/multilingual-e5-base")  (needs 'query: ' prefixed docs)

        # Stays None unless tune_hdbscan() assigns it; run_TM passes it on as-is.
        self.embeddings = None
        self.umap_model = UMAP(
            n_neighbors=15,
            n_components=5,
            min_dist=0.0,
            metric='cosine',
            # low_memory=False,
            random_state=42)
        # try leaf, eom tends to create big clusters
        self.hdbscan_model = HDBSCAN(min_cluster_size=15, metric='euclidean', cluster_selection_method='eom', prediction_data=True)

        # tune hdbscan
        self.tmt = TopicModelTuner(docs = self.docs, embedding_model = self.emb_model, hdbscan_model=self.hdbscan_model, reducer_model=self.umap_model, reducer_random_state=1337, verbose = 2)
        self.BestResultsDF_random = pd.DataFrame()
        self.BestResultsDF_pseudoG = pd.DataFrame()
        self.BestResultsDF_grid = pd.DataFrame()

        # topic representation
        self.vectorizer_model = CountVectorizer(max_df = 0.7, ngram_range = (1,2))
        # bm25_weighting=True - robustness to stopwords
        self.ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True, bm25_weighting=True)

        keybert_model = KeyBERTInspired()
        mmr_model = MaximalMarginalRelevance(diversity=0.5)

        self.representation_model = {
            "KeyBERT": keybert_model,
            # "OpenAI": openai_model,  # Uncomment if you will use OpenAI
            "MMR": mmr_model,
        }

    def get_range_check(self):
        """Return the list of `min_cluster_size` values to search around the best result so far.

        Among tuner results with fewer than 1000 uncategorised documents, the run with
        the most clusters is taken (ties resolved by fewest uncategorised). The range
        spans +/-50% of its `min_cluster_size` when that size is below 30, else +/-10%;
        the lower bound is never below 2 and the upper bound is exclusive.

        Raises:
            ValueError: if no tuner result has fewer than 1000 uncategorised documents.
        """
        candidates = (
            self.tmt.ResultsDF
            .sort_values(by = 'number_uncategorized')
            .query('number_uncategorized < 1000')
        )
        if candidates.empty:
            raise ValueError("no tuning result with number_uncategorized < 1000")

        best_num_clusters = candidates.number_of_clusters.max()
        best_clu_size = candidates.query('number_of_clusters == @best_num_clusters').min_cluster_size.tolist()[0]

        perc = 0.5 if best_clu_size < 30 else 0.1
        margin = int(best_clu_size * perc)

        lower_limit = best_clu_size - margin
        if lower_limit == 1:
            lower_limit = 2
        higher_limit = best_clu_size + margin

        return [*range(lower_limit, higher_limit)]

    def tune_hdbscan(self):
        """Run the tuner searches and store the chosen `min_clu_size` / `min_samples`.

        Runs a random search, then a pseudo grid search and a grid search over
        `get_range_check()`, and keeps the grid-search row with the fewest
        uncategorised documents.
        """
        # NOTE(review): this stores whatever createEmbeddings returns - confirm it
        # returns the embedding matrix rather than None.
        self.embeddings = self.tmt.createEmbeddings(self.docs)
        self.tmt.reduce()
        self.tmt.randomSearch([*range(10,15)], [.1, .25, .5, .75, 1], iters = 50)
        self.BestResultsDF_random = self.tmt.summarizeResults()
        self.BestResultsDF_pseudoG = self.tmt.pseudoGridSearch(self.get_range_check(), [x/100 for x in range(10,101,10)])
        self.BestResultsDF_grid = self.tmt.gridSearch(self.get_range_check())

        best_rows = self.BestResultsDF_grid.query('number_uncategorized == number_uncategorized.min()')
        self.min_clu_size = best_rows.min_cluster_size.tolist()[0]
        self.min_samples = best_rows.min_samples.tolist()[0]

    def run_TM(self, use_tuning = False, reduce_outliers = False):
        """Fit the BERTopic model on `self.docs` and return it.

        Args:
            use_tuning: when True, call `tune_hdbscan()` first and cluster with the
                tuned `min_cluster_size` / `min_samples`; otherwise use the default
                HDBSCAN model created in the constructor.
            reduce_outliers: when True, reassign outlier documents with the
                "embeddings" strategy and update the topic representations.

        Returns:
            The fitted model (also stored as `self.model`).
        """
        if not use_tuning:
            hdbscan_model = self.hdbscan_model
        else:
            self.tune_hdbscan()
            hdbscan_model = HDBSCAN(min_cluster_size=self.min_clu_size, min_samples = self.min_samples, metric='euclidean', cluster_selection_method='eom', prediction_data=True)

        # All steps together
        model = BERTopic(
              embedding_model=self.emb_model,          # Step 1 - Extract embeddings
              umap_model=self.umap_model,                    # Step 2 - Reduce dimensionality
              hdbscan_model=hdbscan_model,              # Step 3 - Cluster reduced embeddings
              vectorizer_model=self.vectorizer_model,        # Step 4 - Tokenize topics
              ctfidf_model=self.ctfidf_model,                # Step 5 - Extract topic words
              representation_model=self.representation_model, # Step 6 - (Optional) Fine-tune topic representations
              calculate_probabilities = True,
              # top_n_words = 15
            )

        self.model = model.fit(self.docs, self.embeddings)

        if reduce_outliers:
            topics_, _ = self.model.transform(self.docs, self.embeddings)
            new_topics = self.model.reduce_outliers(self.docs, topics_, strategy="embeddings")
            # NOTE(review): update_topics is called without vectorizer/ctfidf arguments -
            # confirm the installed BERTopic version keeps the configured ones.
            self.model.update_topics(self.docs, topics=new_topics)

        return self.model

    def _calculate_topic_diversity(self):
        """Return the share of unique words among the top-10 words of all topics.

        Computed as len(unique top-10 words) / (10 * number of topics) over everything
        `self.model.get_topics()` returns; 0.0 when there are no topics.
        """
        topic_keywords = self.model.get_topics()

        bertopic_topics = [[word for word, *_ in pairs] for pairs in topic_keywords.values()]
        if not bertopic_topics:
            # avoid dividing by zero when the model has no topics
            return 0.0

        unique_words = set()
        for topic in bertopic_topics:
            unique_words.update(topic[:10])

        return len(unique_words) / (10 * len(bertopic_topics))

    def _calculate_cv_npmi(self, docs, topics):
        """Return `(c_v, c_npmi)` coherence of the fitted model's topics.

        Args:
            docs: the documents, aligned with `topics`.
            topics: one topic id per document.

        Documents are concatenated per topic, tokenised with the model's vectoriser
        analyser, and scored with gensim's CoherenceModel.
        """
        frame = pd.DataFrame({"Document": docs,
                        "ID": range(len(docs)),
                        "Topic": topics})
        documents_per_topic = frame.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
        cleaned_docs = self.model._preprocess_text(documents_per_topic.Document.values)

        vectorizer = self.model.vectorizer_model
        analyzer = vectorizer.build_analyzer()

        tokens = [analyzer(text) for text in cleaned_docs]
        dictionary = corpora.Dictionary(tokens)
        corpus = [dictionary.doc2bow(token) for token in tokens]

        # Score every real topic id. Counting "number of distinct ids minus one" is only
        # right while the outlier topic -1 is present; once outliers have been
        # reassigned it would silently drop the last topic.
        topic_ids = sorted(topic for topic in set(topics) if topic != -1)
        topic_words = [[word for word, _ in self.model.get_topic(topic)]
                    for topic in topic_ids]

        coherence_model = CoherenceModel(topics=topic_words,
                                      texts=tokens,
                                      corpus=corpus,
                                      dictionary=dictionary,
                                      coherence='c_v')
        cv_coherence = coherence_model.get_coherence()

        coherence_model_npmi = CoherenceModel(topics=topic_words,
                                      texts=tokens,
                                      corpus=corpus,
                                      dictionary=dictionary,
                                      coherence='c_npmi')
        npmi_coherence = coherence_model_npmi.get_coherence()

        return cv_coherence, npmi_coherence

In [5]:

def substitute_vals(mean_probs, distance, len_topic_info):
    """Return `mean_probs` completed to cover topic ids 0..len_topic_info-1, sorted by id.

    Ids missing from `mean_probs` are added, and values equal to 0.0 are replaced,
    using a filler value: 0.0 when `distance` is 'jaccard', otherwise 0.00001.
    Keys already present (including -1, if any) are kept.
    """
    filler = 0.0 if distance == 'jaccard' else 0.00001

    completed = dict(mean_probs)
    for topic_id in range(len_topic_info):
        completed.setdefault(topic_id, filler)

    return {
        topic_id: (filler if completed[topic_id] == 0.0 else completed[topic_id])
        for topic_id in sorted(completed)
    }

# LREC-WS test

In [11]:
def _topic_proportions(topic_ids):
    """Return {topic_id: share of assignments} for a sequence of topic ids ({} when empty)."""
    counts = Counter(topic_ids)
    total = sum(counts.values())
    return {topic: count / total for topic, count in counts.items()}


def _as_bow(proportions, distance, n_topics):
    """Return the completed topic distribution as a list of (topic_id, value) pairs sorted by id."""
    return list(substitute_vals(proportions, distance, n_topics).items())


# Sample file names: <Language>__<Treebank>__<Genre>.dev.conllu and
# <Language>__<Treebank>__<Genre>__s<sample>__rs<seed>.train.conllu
_DEV_NAME = re.compile(r'^([a-zA-Z-]+)__([a-zA-Z-]+)__([a-zA-Z_-]+)\.dev\.conllu')
_TRAIN_NAME = re.compile(r'^([a-zA-Z-]+)__([a-zA-Z-]+)__([a-zA-Z_-]+)__s(\d+)__rs(\d+)\.train\.conllu')

path = 'LREC_treebanks_test'

# distances[<Language>__<Treebank>][<train key>---<dev key>] = {'distances': {...}}
distances = defaultdict(dict)

for tb_dir in sorted(os.listdir(path)):

    ud_path = os.path.join(path, tb_dir)

    if not os.path.isdir(ud_path):
        continue

    ud = UniversalDependencies.from_directory(ud_path, verbose=True)
    treebanks = ud.get_treebanks()
    if not treebanks:
        # nothing to model in this directory
        continue

    ################
    # Get all sentences in tb and remember which slice of the document list each file owns
    ################
    lengths = {}
    tb_sents = []
    for tb_set in treebanks:
        start = len(tb_sents)
        tb_sents += [s.to_text() for s in tb_set.get_sentences()]
        lengths[tb_set.get_name()] = {'s': start, 'e': len(tb_sents)}

    # One directory is expected to hold samples of a single language/treebank; the id
    # is taken from the metadata of the last loaded file.
    last_meta = treebanks[-1]._meta
    tbf_id = f"{last_meta['Language']}__{last_meta['Treebank']}"

    clean_docs = clean_sentences(tb_sents)

    ################
    BT = bertopicTM(clean_docs)
    tb_model = BT.run_TM(use_tuning = False, reduce_outliers = True)

    # Topic proportions per sample, grouped by '<Language>__<Treebank>'.
    results_dev = defaultdict(dict)
    results_train = defaultdict(dict)

    for k, span in lengths.items():
        topic_prop = _topic_proportions(tb_model.topics_[span['s']: span['e']])

        if k.endswith('.dev.conllu'):
            reg = _DEV_NAME.match(k)
            if reg is None:
                raise ValueError(f"unexpected dev sample file name: {k!r}")
            lang, cor, g = reg.groups()
            results_dev[f'{lang}__{cor}'][f'{g}'] = topic_prop

        elif k.endswith('.train.conllu'):
            reg = _TRAIN_NAME.match(k)
            if reg is None:
                raise ValueError(f"unexpected train sample file name: {k!r}")
            lang, cor, g, samp, rseed = reg.groups()
            results_train[f'{lang}__{cor}'][f'{g}_{samp}_{rseed}'] = topic_prop

    len_topic_info = len(tb_model.get_topic_info())

    # The dev vectors do not depend on the train sample, so build them once.
    dev_vectors = {
        key_dev: (
            _as_bow(dev_prop, 'kl', len_topic_info),
            _as_bow(dev_prop, 'jaccard', len_topic_info),
        )
        for key_dev, dev_prop in results_dev[tbf_id].items()
    }

    for key_train, train_prop in results_train[tbf_id].items():
        distr_train = _as_bow(train_prop, 'kl', len_topic_info)
        distr_train_jcrd = _as_bow(train_prop, 'jaccard', len_topic_info)

        for key_dev, (distr_dev, distr_dev_jcrd) in dev_vectors.items():

            between = f"{key_train}---{key_dev}"

            hlngr = hellinger(distr_train, distr_dev)
            kl = kullback_leibler(distr_train, distr_dev)
            jcrd = jaccard(distr_train_jcrd, distr_dev_jcrd)

            distances[tbf_id][between] = {'distances': {'kl': kl, 'hlngr': hlngr, 'jcrd': jcrd}}

            
# uncomment to pickle the distances
# with open(f'distances_allmodels.pkl', 'wb') as outp:
#     pickle.dump(distances, outp, pickle.HIGHEST_PROTOCOL)

defaultdict(None, {'Language': 'Belarusian', 'Treebank': 'HSE', 'Genre': 'fiction'})
Loaded <UniversalDependenciesTreebank (Belarusian__HSE__fiction.dev.conllu): 500 sentences>.
defaultdict(None, {'Language': 'Belarusian', 'Treebank': 'HSE', 'Genre': 'news'})
Loaded <UniversalDependenciesTreebank (Belarusian__HSE__news.dev.conllu): 636 sentences>.
defaultdict(None, {'Language': 'Belarusian', 'Treebank': 'HSE', 'Genre': 'news', 'Sample_no': 's1', 'Seed': 'rs0'})
Loaded <UniversalDependenciesTreebank (Belarusian__HSE__news__s1__rs0.train.conllu): 881 sentences>.
defaultdict(None, {'Language': 'Belarusian', 'Treebank': 'HSE', 'Genre': 'news', 'Sample_no': 's1', 'Seed': 'rs1234'})
Loaded <UniversalDependenciesTreebank (Belarusian__HSE__news__s1__rs1234.train.conllu): 872 sentences>.
defaultdict(None, {'Language': 'Belarusian', 'Treebank': 'HSE', 'Genre': 'news', 'Sample_no': 's1', 'Seed': 'rs42'})
Loaded <UniversalDependenciesTreebank (Belarusian__HSE__news__s1__rs42.train.conllu): 871 sen

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Belarusian_HSE_tb_model Counter({1: 811, 0: 636, 3: 622, 2: 447, 4: 404, 5: 303, 6: 248, 11: 242, 10: 237, 13: 232, 19: 230, 36: 223, 21: 208, 7: 201, 12: 193, 22: 181, 14: 161, 8: 161, 9: 157, 111: 149, 158: 149, 32: 148, 40: 145, 18: 144, 26: 136, 145: 130, 16: 127, 155: 126, 23: 123, 38: 122, 152: 121, 15: 119, 17: 119, 92: 117, 59: 113, 28: 113, 24: 112, 42: 111, 20: 110, 49: 109, 54: 109, 30: 108, 219: 108, 65: 106, 179: 105, 39: 105, 46: 105, 25: 104, 27: 103, 229: 101, 33: 99, 103: 98, 269: 96, 37: 96, 166: 96, 50: 95, 41: 95, 228: 94, 44: 94, 80: 92, 89: 92, 34: 91, 53: 91, 71: 90, 29: 90, 31: 90, 35: 89, 253: 89, 47: 88, 143: 85, 48: 85, 108: 85, 61: 84, 173: 83, 150: 83, 112: 83, 56: 83, 85: 82, 79: 81, 51: 81, 73: 81, 76: 80, 201: 80, 221: 80, 133: 78, 84: 78, 60: 78, 252: 75, 113: 74, 52: 72, 192: 71, 135: 70, 161: 70, 45: 70, 55: 70, 74: 69, 126: 69, 69: 69, 178: 69, 67: 69, 72: 68, 266: 68, 125: 67, 251: 66, 66: 66, 43: 66, 57: 66, 88: 65, 137: 65, 140: 65, 81: 64, 162: 6

AssertionError: 