In [None]:
# Mount Google Drive inside the Colab runtime at /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
!pip install Elasticsearch



In [None]:
import requests as req
import json
import xml.etree.ElementTree as ET
import itertools
import os
import argparse
import sys
sys.path.append('..')
# from biomedkg_utils import switch_dictset_to_dictlist
from elasticsearch import Elasticsearch
from multiprocessing import cpu_count, Process

In [None]:
def switch_dictset_to_dictlist(the_dict):
    '''
    FUNCTION:
    - Return a new dictionary with the same keys as the_dict
      whose values are lists built from the original values.
      The input dictionary is left untouched.

    PARAMS:
    - the_dict: The initial dict with values of sets
    '''
    return {key: list(values) for key, values in the_dict.items()}


In [None]:
def map_disease_mesh_name_to_id(mesh_desc_file = '../data/MeSH/desc2022.xml',
                                output_folder='../parsed_mappings/MeSH'):
    '''
    FUNCTION:
    - Map the disease MeSH names/terms to the
      MeSH IDs and write the mapping to
      <output_folder>/name2id.json.
    - A descriptor is treated as a disease when at least one
      of its tree numbers starts with 'C' or 'F03'.

    PARAMS:
    - mesh_desc_file: Path to the MeSH descriptor XML file
    - output_folder: Folder that receives name2id.json
    '''
    print('Mapping MeSH disease names to IDs')
    root = ET.parse(mesh_desc_file).getroot()

    name2id = dict()

    for ele in root:
        # Records without a tree number list cannot be classified; skip them
        tree_number_list = ele.find('TreeNumberList')
        if tree_number_list is None:
            continue

        # If Tree is a disease
        is_disease = any((tree_number.text or '').startswith(('C', 'F03'))
                         for tree_number in tree_number_list.findall('TreeNumber'))
        if not is_disease:
            continue

        # Name to ID (records missing either field are skipped explicitly
        # instead of being hidden by a bare except)
        id_ele = ele.find('DescriptorUI')
        name_ele = ele.find('DescriptorName')
        string_ele = name_ele.find('String') if name_ele is not None else None
        if id_ele is None or string_ele is None:
            continue
        if id_ele.text is None or string_ele.text is None:
            continue
        name2id.setdefault(string_ele.text, set()).add(id_ele.text)

    name2id = switch_dictset_to_dictlist(name2id)
    with open(os.path.join(output_folder,'name2id.json'),'w') as fout:
        json.dump(name2id, fout)

In [None]:
def map_categories2terms(output_folder,
                         meshterms_per_cat_file='../parsed_mappings/MeSH/meshterms_per_cat.json',
                         category_names_file='../config/textcube_config.json'):
    '''
    FUNCTION:
    - Map disease categories -> MeSH terms by pairing the i-th
      category name with the i-th list of terms, write the mapping
      to <output_folder>/category2terms.json and return it.

    PARAMS:
    - output_folder: Folder that receives category2terms.json
    - meshterms_per_cat_file: JSON file holding one list of MeSH
      terms per category
    - category_names_file: JSON file holding the category names
    '''
    print('Mapping disease categories to their terms and subcategory terms')
    with open(meshterms_per_cat_file, 'r') as fin:
        terms_lol = json.load(fin)
    with open(category_names_file, 'r') as fin:
        category_names = json.load(fin)

    # zip stops at the shorter of the two inputs, so surplus names or
    # term lists are dropped
    category2terms = dict(zip(category_names, terms_lol))

    with open(os.path.join(output_folder, 'category2terms.json'), 'w') as fout:
        json.dump(category2terms, fout)

    return category2terms

In [None]:
def get_mesh_synonyms_api(category2terms, name2id,
                          output_folder='../parsed_mappings/MeSH/'):
    '''
    FUNCTION:
    - Get MeSH synonyms via the NLM MeSH lookup API: for every term
      of every category, fetch the descriptor details and collect
      the labels of its terms as synonyms of the category.
    - Writes the result to <output_folder>/category2synonyms.json
      and returns it.

    PARAMS:
    - category2terms: category -> list of MeSH term names
    - name2id: MeSH term name -> list holding exactly one MeSH ID
    - output_folder: Folder that receives category2synonyms.json
    '''
    category2synonyms = dict()
    total_categories = len(category2terms)
    for category_num, (category, terms) in enumerate(category2terms.items()):
        print(str(category_num) + '/' + str(total_categories), end='\r')

        # Terms in category
        for term in terms:

            mesh_id = name2id[term]
            assert len(mesh_id) == 1
            r_string = 'https://id.nlm.nih.gov/mesh/lookup/details?descriptor=' + mesh_id[0]

            # A timeout keeps one stalled request from hanging the whole run;
            # raise_for_status reports HTTP errors instead of a later KeyError
            response = req.get(r_string, timeout=60)
            response.raise_for_status()

            # Add term synonyms to the category
            for entry in response.json()['terms']:
                category2synonyms.setdefault(category, set()).add(entry['label'])

    category2synonyms = switch_dictset_to_dictlist(category2synonyms)
    with open(os.path.join(output_folder, 'category2synonyms.json'), 'w') as fout:
        json.dump(category2synonyms, fout)

    return category2synonyms

In [None]:
def permute_mesh_synonyms(mesh_synonyms):
    '''
    FUNCTION:
    - Expand comma-separated MeSH synonyms into every ordering of
      their ', '-separated parts, joined by single spaces.
      Synonyms without a comma are passed through unchanged.

    PARAMS:
    - mesh_synonyms: list of mesh synonyms
    '''
    expanded = []
    for synonym in mesh_synonyms:
        # Nothing to reorder when there is no comma
        if ',' not in synonym:
            expanded.append(synonym)
            continue

        parts = synonym.split(', ')
        expanded.extend(' '.join(ordering)
                        for ordering in itertools.permutations(parts))

    return expanded

In [None]:
def map_category_to_permuted_mesh_synonyms(category2synonyms,
                                           output_folder='../parsed_mappings/MeSH/'):
    '''
    FUNCTION:
    - permute the MeSH Synonyms because sometimes they're
      written weirdly with commas (e.g., Order, Messed, Up, Is;
      Messed, Order, Up)
    - Writes category2permuted_synonyms.json and
      permuted_synonyms2category.json into output_folder and
      returns both mappings (in that order).

    PARAMS:
    - category2synonyms: category -> MeSH synonyms mapping
    - output_folder: Folder that receives the two JSON files
    '''
    print('Mapping disease categories to permuted MeSH synonyms')

    # Forward (category -> synonyms) and reverse (synonym -> categories) maps
    category2permuted_synonyms, permuted_synonyms2category = dict(), dict()
    for category, synonyms in category2synonyms.items():
        for synonym in permute_mesh_synonyms(synonyms):
            category2permuted_synonyms.setdefault(category, set()).add(synonym)
            permuted_synonyms2category.setdefault(synonym, set()).add(category)

    category2permuted_synonyms = switch_dictset_to_dictlist(category2permuted_synonyms)
    permuted_synonyms2category = switch_dictset_to_dictlist(permuted_synonyms2category)

    # Context managers close the files even if serialization fails
    with open(os.path.join(output_folder, 'category2permuted_synonyms.json'), 'w') as fout:
        json.dump(category2permuted_synonyms, fout)
    with open(os.path.join(output_folder, 'permuted_synonyms2category.json'), 'w') as fout:
        json.dump(permuted_synonyms2category, fout)

    return category2permuted_synonyms, permuted_synonyms2category

In [None]:
def get_all_uncategorized_pmids(index_name):
    '''
    FUNCTION:
    - Get the PMIDs of all the PubMed articles that
      don't have MeSH labels.
    - Purpose: For applying to unlabeled PubMed articles

    PARAMS:
    - index_name: The name of the ElasticSearch index
      where all the PubMed articles are indexed/stored

    RETURNS:
    - List of unique PMIDs whose 'MeSH' field is empty
    '''

    es = Elasticsearch()
    relevant_uncategorized_pmids = set()
    for num_pmids, entry in enumerate(es_iterate_all_documents(es, index_name)):

        # Save PMIDs of unlabeled documents (empty MeSH field)
        if len(entry['MeSH']) == 0:
            relevant_uncategorized_pmids.add(entry['pmid'])

        # Print progress for every 10000th document, labeled or not
        # (previously this only fired when that document was unlabeled)
        if num_pmids % 10000 == 0:
            print(str(num_pmids) + ' PMIDs processed', end='\r')

    return list(relevant_uncategorized_pmids)

In [None]:
def get_document_text(entry):
    '''
    FUNCTION:
    - Get the abstract and title of the PubMed publication with
      newlines, tabs and triple spaces replaced by single spaces.
      A field that is not a string is returned as ''.

    PARAMS:
    - entry (dict): parsed indexed publication; the text is read
      from entry['_source']['title'] and entry['_source']['abstract']

    RETURNS:
    - (abstract, title) tuple of cleaned strings
    '''

    def _clean(text):
        # Non-string values (e.g., None) count as missing text
        if not isinstance(text, str):
            return ''
        # str.replace returns a new string, so the result must be kept
        return text.replace('\n', ' ').replace('\t', ' ').replace('   ', ' ')

    source = entry['_source']
    title = _clean(source['title'])
    abstract = _clean(source['abstract'])

    return abstract, title

In [None]:
def ds_label_matching(batch_id, relevant_pmid_batch,
                      index_name, index_type,
                      label_unlabeled_only,
                      label_labeled_only,
                      label_all,
                      filter_list,
                      stop_at_this_many_pmids,
                      run, test,
                      output_folder):
    '''
    FUNCTION:
    - Using MeSH term synonyms, label a document with a category if the
      lowercased synonym matches exactly within the text. A category is
      assigned as soon as one of its synonyms is found in the title, or
      more than one synonym has been collected for it (title and
      abstract combined).
    - Writes 'PMID|category' lines to
      <output_folder>/<run|test>_temp_labeling<batch_id>.txt and
      'PMID|synonym' lines (the synonym being checked when the category
      was assigned) to the matching '...synonym.txt' file.

    PARAMS:
    - batch_id: number of this batch; used in the temp file names
    - relevant_pmid_batch: PMIDs to look for in this batch
    - index_name: name of the ElasticSearch index
    - index_type: name of the type of ElasticSearch index
    - label_unlabeled_only: only label documents with no MeSH terms
    - label_labeled_only: only label documents that have MeSH terms
    - label_all: label every document
    - filter_list: currently unused (the keyword filter is disabled)
    - stop_at_this_many_pmids: overall PMID budget; each batch stops
      after stop_at_this_many_pmids / cpu_count() PMIDs
    - run, test: which mode the labeling is for; one must be truthy
    - output_folder: folder holding category2permuted_synonyms.json
      and receiving the temp files
    '''
    es = Elasticsearch()
    if run:
        rot = 'run'
    elif test:
        rot = 'test'
    else:
        # Was `raise Except(...)`, which itself failed with a NameError
        raise Exception('Did not specify if running or testing')

    temp_outfile = os.path.join(output_folder, rot+'_temp_labeling'+str(batch_id)+'.txt')
    procs = cpu_count()
    with open(os.path.join(output_folder,'category2permuted_synonyms.json'),'r') as fin:
        category2permuted_synonyms = json.load(fin)

    # Stays None when the batch is empty, so the final progress line is skipped
    num_pmids = None

    with open(temp_outfile,'w') as fout, \
    open(temp_outfile[:-4]+'synonym'+'.txt','w') as fout1:

        # Get PubMed article's text
        for num_pmids, pmid in enumerate(relevant_pmid_batch):
            entry = es.get(id = pmid, index = index_name, doc_type = index_type)

            # Print progress and break early
            if batch_id == 1 and num_pmids % 1000 == 0:
                print(str(num_pmids)+' PMIDs processed in one of the batches',\
                      end='\r')
            if num_pmids > stop_at_this_many_pmids/procs:
                break

            # Determine which publications to find labels for
            labeled_meshes = entry['_source']['MeSH']
            labeled = len(labeled_meshes) > 0
            unlabeled = len(labeled_meshes) == 0
            if label_labeled_only and labeled:
                pass
            elif label_unlabeled_only and unlabeled:
                pass
            elif label_all:
                pass
            else:
                continue

            # Publication's text (title, abstract)
            abstract, title = get_document_text(entry)

            # Pad the title and drop ',' / ':' so that a synonym next to
            # punctuation or at either end still matches as ' synonym '
            title = ' '+title+' '
            title = title.replace(',',' ')
            title = title.replace(':',' ')

            # Lowercase once per document instead of once per synonym
            title_lower = title.lower()
            abstract_lower = abstract.lower()

            # Check if MeSH synonym is in the text
            for category, synonyms in category2permuted_synonyms.items():
                found_syns = set()
                one_syn_in_title = False

                # Each synonym in a set category
                for synonym in synonyms:
                    syn = synonym.lower()

                    # Synonym in title
                    if ' '+syn+' ' in title_lower:
                        one_syn_in_title = True
                        found_syns.add(synonym)

                    # Synonym in abstract
                    if syn in abstract_lower:
                        add = True

                        # If similar synonym hasn't been counted already
                        for found_syn in found_syns:
                            fsyn = found_syn.lower().replace('\'','')
                            if fsyn in syn or syn in fsyn:
                                add = False
                                break
                        if add:
                            found_syns.add(synonym)

                    # Categorize text with 1+ synonym per category in the text
                    # (This could be modified to include confidence levels for
                    #  how many synonyms were found, remove break then)
                    if one_syn_in_title or len(found_syns) > 1:
                        fout.write(pmid+'|'+category+'\n')
                        fout1.write(pmid+'|'+synonym+'\n')
                        break

        # Final progress line; nothing to report for an empty batch
        if num_pmids is not None:
            print(str(num_pmids)+' PMIDs processed in one of the batches')

In [None]:
def multiprocess_ds_label_matching(pmids, index_name, index_type,
                                   label_unlabeled_only,
                                   label_labeled_only,
                                   label_all,
                                   filter_list,
                                   the_function,
                                   stop_at_this_many_pmids,
                                   run = False,
                                   test = False,
                                   output_folder='../parsed_mappings/MeSH/'):
    '''
    FUNCTION:
    - Split the PMIDs round-robin into one batch per process and run
      the_function on every batch in its own process, waiting until
      all of them have finished.

    PARAMS:
    - pmids: the list to be split into input for
      a multiprocessing function
    - the_function: the function each process runs; it is called with
      (batch id, batch, index_name, index_type, label_unlabeled_only,
      label_labeled_only, label_all, filter_list,
      stop_at_this_many_pmids, run, test, output_folder)
    - all remaining parameters are handed to the_function unchanged
    '''
    # One process per CPU, but never more processes than PMIDs
    procs = min(cpu_count(), len(pmids))

    # Deal the PMIDs out to the batches in turn
    batches = [list() for _ in range(procs)]
    for position, pmid in enumerate(pmids):
        batches[position % procs].append(pmid)

    print("Running jobs...")
    jobs = [Process(target = the_function,
                    args = [b_id, batch,
                            index_name, index_type,
                            label_unlabeled_only,
                            label_labeled_only,
                            label_all,
                            filter_list,
                            stop_at_this_many_pmids,
                            run, test, output_folder])
            for b_id, batch in enumerate(batches)]

    # Start everything first, then wait for everything
    for job in jobs:
        job.start()
    for job in jobs:
        job.join()
    print('Done!')

In [None]:
def merge_pmid2new_mesh_labels(run=False, test=False, output_folder = '../parsed_mappings/MeSH/'):
    '''
    FUNCTION:
    - Merges the separate files containing PMID|category_name
      (and PMID|synonym) for the imputed category_name labels.
    - Writes pmid2imputed_category.json and
      pmid2imputed_mesh_synonym.json into output_folder.

    PARAMS:
    - run, test: which set of temp files to merge
      ('run_temp_labeling*' or 'test_temp_labeling*'); one must be truthy
    - output_folder: folder holding the temp files and receiving the
      merged JSON files

    RETURNS:
    - (pmid2imputed_category, pmid2imputed_meshsynonym), each mapping
      a PMID to a list of labels
    '''
    if run:
        rot = 'run'
    elif test:
        rot = 'test'
    else:
        # Previously `rot` was simply left undefined in this case
        raise Exception('Did not specify if running or testing')

    def _read_pairs(path, pmid2labels):
        # Add every 'PMID|label' line of one temp file to pmid2labels
        with open(path) as fin:
            for line in fin:
                line = line.split('|')

                # PMID, label
                assert len(line) == 2
                pmid2labels.setdefault(line[0], set()).add(line[1].strip())

    pmid2imputed_meshsynonym, pmid2imputed_category = dict(), dict()

    for batch_id in range(cpu_count()):
        temp_outfile = os.path.join(output_folder, rot+'_temp_labeling'+str(batch_id)+'.txt')

        # Fewer batches than CPUs are written when there are fewer PMIDs
        # than CPUs, so a missing batch file is not an error
        if not os.path.exists(temp_outfile):
            continue

        # PMID - MeSH Category
        _read_pairs(temp_outfile, pmid2imputed_category)

        # PMID - MeSH Synonym
        _read_pairs(temp_outfile[:-4]+'synonym'+'.txt', pmid2imputed_meshsynonym)

    pmid2imputed_category = switch_dictset_to_dictlist(pmid2imputed_category)
    pmid2imputed_meshsynonym = switch_dictset_to_dictlist(pmid2imputed_meshsynonym)
    print(len(pmid2imputed_category), 'PMIDs with imputed labels')

    # Export PMID-Category mappings to dictionaries
    with open(os.path.join(output_folder, 'pmid2imputed_category.json'), 'w') as fout:
        json.dump(pmid2imputed_category, fout)
    with open(os.path.join(output_folder, 'pmid2imputed_mesh_synonym.json'), 'w') as fout:
        json.dump(pmid2imputed_meshsynonym, fout)

    return pmid2imputed_category, pmid2imputed_meshsynonym

In [None]:
def index_imputed_mesh_categories(index_name, index_type):
    '''
    FUNCTION:
    - Add the imputed categories of every PMID to the 'MeSH' field
      of its document in the ElasticSearch index (duplicates removed).
    - Reads the imputed labels from the notebook-level variable
      pmid2imputed_category, which must exist before this is called.

    PARAMS:
    - index_name: Name of the ElasticSearch index
    - index_type: Type name of the ElasticSearch index
    '''

    es = Elasticsearch()
    for pmid, imputed_category in pmid2imputed_category.items():

        # Current MeSH terms of the publication plus the imputed ones
        document = es.get(id = pmid, index = index_name, doc_type = index_type)
        combined_terms = document['_source']['MeSH']
        combined_terms += imputed_category

        # Write the de-duplicated terms back to the publication
        es.update(index = index_name,
                  id = pmid,
                  doc_type = index_type,
                  doc = {'MeSH': list(set(combined_terms))})

In [None]:
def update_textcube_files(data_folder='../data/',
                          config_folder='../config/',
                          parsed_mapping_folder='../parsed_mappings'):
    '''
    FUNCTION:
    - Add category names to the considered MeSH Terms
      (MeSH/meshterms_per_cat.json is rewritten in place)
    - Add pmid-category to pmid2category mapping files
      (textcube_category2pmid.json and textcube_pmid2category.json
      are rewritten in data_folder)
    - Reads the imputed labels from the notebook-level variable
      pmid2imputed_category, which must exist before this is called.

    PARAMS:
    - data_folder: folder holding the textcube_*.json files
    - config_folder: folder holding textcube_config.json
    - parsed_mapping_folder: folder holding MeSH/meshterms_per_cat.json
    '''
    # Use the parameter; the global `mapping_folder` was used here before,
    # which silently ignored the caller's argument
    meshterms_path = os.path.join(parsed_mapping_folder, 'MeSH/meshterms_per_cat.json')
    pmid2category_path = os.path.join(data_folder, 'textcube_pmid2category.json')
    category2pmid_path = os.path.join(data_folder, 'textcube_category2pmid.json')

    with open(meshterms_path, 'r') as fin:
        meshterms_per_cat = [set(meshlist) for meshlist in json.load(fin)]
    with open(os.path.join(config_folder, 'textcube_config.json'), 'r') as fin:
        category_names = json.load(fin)

    # Add each category's own name to its MeSH terms
    for i in range(0, len(category_names)):
        meshterms_per_cat[i].add(category_names[i])
    meshterms_per_cat = [list(meshlist) for meshlist in meshterms_per_cat]
    with open(meshterms_path, 'w') as fout:
        json.dump(meshterms_per_cat, fout)

    # Update textcube_category2pmid
    # (the old pmid2category file is not read: it is rebuilt from scratch below)
    with open(category2pmid_path, 'r') as fin:
        textcube_category2pmid = json.load(fin)
    category_name2num = {name: num for num, name in enumerate(category_names)}

    for pmid, imputed_categories in pmid2imputed_category.items():
        for imputed_category in imputed_categories:
            cat_num = category_name2num[imputed_category]
            textcube_category2pmid[cat_num].append(pmid)

    for i in range(0, len(textcube_category2pmid)):
        textcube_category2pmid[i] = list(set(textcube_category2pmid[i]))

    # Update textcube_pmid2category
    new_textcube_pmid2category = list()
    for cat_num, pmid_list in enumerate(textcube_category2pmid):
        for pmid in pmid_list:
            new_textcube_pmid2category.append([pmid, cat_num])

    with open(pmid2category_path, 'w') as fout:
        json.dump(new_textcube_pmid2category, fout)
    with open(category2pmid_path, 'w') as fout:
        json.dump(textcube_category2pmid, fout)

In [None]:
def get_all_categorized_pmids(index_name, outpath):
    '''
    FUNCTION:
    - Get the PMIDs of all the PubMed articles that
      have MeSH labels, save them as JSON and return them.

    PARAMS:
    - index_name: The name of the ElasticSearch index
      where all the PubMed articles are indexed/stored
    - outpath: Where the categorized PMIDs will go. Either a prefix
      to which 'relevant_categorized_pmids.json' is appended, or
      the full path of the target '.json' file.

    RETURNS:
    - List of unique PMIDs whose 'MeSH' field is not empty
    '''

    es = Elasticsearch()
    relevant_categorized_pmids = set()
    for num_pmids, entry in enumerate(es_iterate_all_documents(es, index_name)):

        # Save PMIDs of labeled documents (non-empty MeSH field)
        if len(entry['MeSH']) > 0:
            relevant_categorized_pmids.add(entry['pmid'])

        # Print progress for every 10000th document, labeled or not
        if num_pmids % 10000 == 0:
            print(str(num_pmids) + ' PMIDs processed', end='\r')

    relevant_categorized_pmids = list(relevant_categorized_pmids)

    # A caller may hand over the complete file path; appending the file
    # name to it again would produce a mangled path
    if outpath.endswith('.json'):
        out_file = outpath
    else:
        out_file = outpath + 'relevant_categorized_pmids.json'
    with open(out_file, 'w') as fout:
        json.dump(relevant_categorized_pmids, fout)

    return relevant_categorized_pmids

In [None]:
def get_groundtruth_pmid2categories(relevant_pmid_batch, index_name,
                                    index_type, permuted_synonyms2category):
    '''
    FUNCTION:
    - This gets the ground truth labels, the MeSH-labeled PubMed documents:
      every MeSH term of a document that is a known synonym contributes
      the categories of that synonym.

    PARAMS:
    - relevant_pmid_batch: This is the list of the PubMed IDs you want to
      get ground truth for
    - index_name: name of the ElasticSearch index
    - index_type: name of the type of ElasticSearch index
    - permuted_synonyms2category: MeSH synonyms -> category

    RETURNS:
    - dict mapping every PMID to the list of its categories
      (empty when none of its MeSH terms is a known synonym)
    '''
    pmids2real_categories = dict()
    es = Elasticsearch()
    total_pmids = len(relevant_pmid_batch)

    for num_pmids, pmid in enumerate(relevant_pmid_batch):
        if num_pmids % 10000 == 0:
            print('Getting real mappings' + str(num_pmids)+'/'+str(total_pmids),\
                  end='\r')
        entry = es.get(id = pmid, index = index_name, doc_type = index_type)
        real_categories = pmids2real_categories[pmid] = set()

        # Publication's MeSH (if any)
        for mesh in entry['_source']['MeSH']:
            try:
                real_categories.update(permuted_synonyms2category[mesh])
            except (KeyError, TypeError):
                # MeSH term that is not a known synonym (or not usable
                # as a dictionary key): it contributes no category
                continue

    pmids2real_categories = switch_dictset_to_dictlist(pmids2real_categories)
    return pmids2real_categories

In [None]:
def evaluate_label_imputation(pmid2imputed_category, pmid2real_categories):
    '''
    FUNCTION:
    - Compare imputed categories against the ground truth and print
      precision, recall and the TP/FP/FN counts. Each (PMID, category)
      pair is counted exactly once.

    PARAMS:
    - pmid2imputed_category: PMID -> imputed categories. PMIDs that are
      missing here count as having no imputed category.
    - pmid2real_categories: PMID -> ground truth categories. Only the
      PMIDs in this mapping are evaluated.
    '''
    tp, fp, fn = 0, 0, 0

    for pmid, real_categories in pmid2real_categories.items():
        real = set(real_categories)
        imputed = set(pmid2imputed_category.get(pmid, []))

        # Set arithmetic counts every category once; concatenating the two
        # lists counted a category found in both of them twice
        tp += len(real & imputed)   # Real = Yes, Impute = Yes
        fn += len(real - imputed)   # Real = Yes, Impute = No
        fp += len(imputed - real)   # Real = No,  Impute = Yes

    # Report 0.0 instead of dividing by zero when nothing was imputed/real
    precision = tp/(tp+fp) if (tp+fp) else 0.0
    recall = tp/(tp+fn) if (tp+fn) else 0.0

    print('Precision', round(precision, 4))
    print('Recall', round(recall, 4))
    print('TP', tp, 'FP', fp, 'FN',fn)

In [None]:
# Upper bound on the number of PMIDs to process, and the (unused) keyword filter
STOP_AT_THIS_MANY_PMIDS = 9999999999
filter_words = ['placeholder', 'these arent used anymore']

# ElasticSearch index holding the PubMed documents
index_type = 'pubmed_meta_lift'
index_name = 'pubmed_lift'

# Label imputation is switched off when its undo is requested
undo_last_category_label_imputation_only = False
undo_category_label_imputation = False
label_imputation = True
if undo_category_label_imputation:
    label_imputation = False

# Mode switches: testing takes precedence over running
test = False
run = True
if test == True:
    run = False

In [None]:
# All folders are resolved relative to the project root
root_directory = '../'
data_folder, config_folder, mapping_folder = (
    os.path.join(root_directory, subfolder)
    for subfolder in ('data', 'config', 'parsed_mappings')
)
output_folder = os.path.join(mapping_folder, 'Output')

In [None]:
# Label imputation
if label_imputation:

    # Map MeSH ID - Name
    map_disease_mesh_name_to_id()


    meshterms_per_cat_file = os.path.join(mapping_folder,'MeSH/meshterms_per_cat.json')
    category_names_file = os.path.join(config_folder,'textcube_config.json')

    # Get MeSH Synonyms
    # Category - MeSH Terms
    category2main_terms = map_categories2terms(output_folder = output_folder,
                            meshterms_per_cat_file = meshterms_per_cat_file,
                            category_names_file = category_names_file)

    # Category - MeSH Terms' Synonyms (including terms)
    with open(os.path.join(mapping_folder,'MeSH/meshterm-IS-meshid.json'),'r') as fin:
        name2id = json.load(fin)
    # NOTE(review): this replaces the mapping returned just above with the copy
    # stored under MeSH/, while map_categories2terms wrote its file to
    # output_folder - confirm which of the two is meant to be used.
    with open(os.path.join(mapping_folder, 'MeSH/category2terms.json')) as fin:
        category2main_terms = json.load(fin)

    # Reuse cached synonyms when they can be read; otherwise query the API.
    # Only a missing/unreadable/invalid cache triggers the fallback.
    try:
        with open(os.path.join(mapping_folder,'MeSH/category2synonyms.json')) as fin:
            category2synonym_terms = json.load(fin)
    except (OSError, ValueError):
        category2synonym_terms = get_mesh_synonyms_api(category2main_terms, name2id,output_folder=output_folder)

    # Category - permuted MeSH Terms' Synonyms
    temp1, temp2 =  map_category_to_permuted_mesh_synonyms(category2synonym_terms,output_folder=output_folder)
    category2permuted_synonyms = temp1
    permuted_synonyms2category = temp2


In [None]:
if run:
    print('Running label imputation')

    # Undo last label imputation
    pmid2imp_cat_path = os.path.join(mapping_folder,'MeSH/pmid2imputed_category.json')
    try:
        with open(pmid2imp_cat_path) as fin:
            pmid2imputed_category = json.load(fin)
    except (OSError, ValueError):
        print('No previous label imputation')
    else:
        # The removal is kept outside the try above: there, its error was
        # caught by the enclosing bare except and misreported as
        # 'No previous label imputation'
        print('Removing the labels imputed last time')
        try:
            remove_imputed_category_mesh_terms_previous_li(index_name,
                                                           index_type,
                                                           pmid2imputed_category)
        except Exception as err:
            raise Exception('Couldnt remove the last indexed imputed labels') from err

    # Get relevant PMIDs
    # (the cached file was previously looked up without a path separator,
    #  so the cache was never found and the index was always re-scanned)
    relevant_pmids_path = os.path.join(data_folder,'all_uncategorized_pmids.json')
    try:
        with open(relevant_pmids_path) as fin:
            relevant_pmids = json.load(fin)
    except (OSError, ValueError):
        relevant_pmids = get_all_uncategorized_pmids(index_name)
        with open(relevant_pmids_path,'w') as fout:
            json.dump(relevant_pmids, fout)

    print(len(relevant_pmids), 'relevant pmids')

    # Impute missing MeSH labels
    # Impute PMIDs' Categories (i.e., Impute missing MeSH labels)
    multiprocess_ds_label_matching(pmids = relevant_pmids,
                                   index_name = index_name,
                                   index_type = index_type,
                                   label_unlabeled_only = True,
                                   label_labeled_only = False,
                                   label_all = False,
                                   filter_list = filter_words,
                                   stop_at_this_many_pmids = STOP_AT_THIS_MANY_PMIDS,
                                   the_function = ds_label_matching,
                                   output_folder = output_folder,
                                   run = True)
    pmid2imputed_category, pmid2imputed_meshsynonym = merge_pmid2new_mesh_labels(run=True,output_folder=output_folder)
    # Index the imputed MeSH categories into their PMID entries
    # index_imputed_mesh_categories(index_name=index_name, index_type=index_type)
    # Update the MeSH Terms Per Category file for the textcube
    # update_textcube_files(data_folder=data_folder, config_folder=config_folder, parsed_mapping_folder=mapping_folder)

In [None]:
if test:
    print('Testing label imputation')

    # Get relevant PMIDs
    # Ground truth: PMIDs in your study already labeled by NIH with MeSH terms
    relevant_pmids_path = os.path.join(data_folder, 'relevant_categorized_pmids.json')
    try:
        with open(relevant_pmids_path) as fin:
            relevant_pmids = json.load(fin)
    except (OSError, ValueError):
        relevant_pmids = None
    if relevant_pmids is None:
        # get_all_categorized_pmids appends the file name to the prefix it is
        # given and writes the PMIDs there, so pass the folder (with trailing
        # separator) and read the list back from the file it produced
        get_all_categorized_pmids(index_name, os.path.join(data_folder, ''))
        with open(relevant_pmids_path) as fin:
            relevant_pmids = json.load(fin)

    # Impute missing MeSH labels
    # Impute PMIDs' Categories (i.e., Impute missing MeSH labels)
    # output_folder is passed explicitly so that the synonyms are read from,
    # and the temp files written to, the folder used by the other cells
    multiprocess_ds_label_matching(pmids = relevant_pmids,
                                   index_name = index_name,
                                   index_type = index_type,
                                   label_unlabeled_only = False,
                                   label_labeled_only = True,
                                   label_all = False,
                                   filter_list = filter_words,
                                   stop_at_this_many_pmids = STOP_AT_THIS_MANY_PMIDS,
                                   the_function = ds_label_matching,
                                   output_folder = output_folder,
                                   test = True)
    pmid2imputed_category, pmid2imputed_meshsynonym = merge_pmid2new_mesh_labels(test=True, output_folder=output_folder)

    # PMID - Ground Truth Categories
    pmid2real_categories_path = os.path.join(output_folder, 'pmid2real_categories.json')
    try:
        with open(pmid2real_categories_path) as fin:
            pmid2real_categories = json.load(fin)
    except (OSError, ValueError):
        pmid2real_categories = ['']
    # Recompute unless the cached mapping has exactly the configured size
    if len(pmid2real_categories) != STOP_AT_THIS_MANY_PMIDS:
        pmid2real_categories = get_groundtruth_pmid2categories(relevant_pmids, index_name, index_type, permuted_synonyms2category)
        with open(pmid2real_categories_path, 'w') as fout:
            json.dump(pmid2real_categories, fout)

    # Evaluate results on ground truth
    evaluate_label_imputation(pmid2imputed_category, pmid2real_categories)

# Start of exploration

In [None]:
import requests as req
import json
import xml.etree.ElementTree as ET
import itertools
import os
import argparse
import sys
sys.path.append('..')
from utils.biomedkg_utils import switch_dictset_to_dictlist
from elasticsearch import Elasticsearch
from multiprocessing import cpu_count, Process

In [None]:
def es_iterate_all_documents(es, index, pagesize=250, scroll_timeout="1m", **kwargs):
    """
    Helper to iterate ALL values from a single index
    Yields the '_source' of every document, one page at a time.
    Source: https://techoverflow.net/2019/05/07/elasticsearch-how-to-iterate-scroll-through-all-documents-in-index/

    PARAMS:
    - es: ElasticSearch client (its search and scroll methods are used)
    - index: name of the index to iterate
    - pagesize: number of documents requested per page
    - scroll_timeout: scroll keep-alive, applied to the initial search
      as well as to every follow-up scroll request
    - kwargs: passed on to the initial search
    """
    # Initialize scroll. The keep-alive was hard-coded to "1m" here before,
    # so scroll_timeout was ignored for the first request.
    result = es.search(index=index, scroll=scroll_timeout, size=pagesize, **kwargs)

    while True:
        hits = result["hits"]["hits"]

        # Stop after no more docs
        if not hits:
            break

        # Yield each entry
        yield from (hit['_source'] for hit in hits)

        # Scroll next
        result = es.scroll(scroll_id=result["_scroll_id"], scroll=scroll_timeout)

In [None]:
def get_pmids(index_name):
    '''
    Scan the whole index and sort the PMIDs into four groups, returned
    as de-duplicated lists in this order:
    without MeSH labels, with MeSH labels,
    without MeSH labels but with full text, with MeSH labels and full text.
    '''

    es = Elasticsearch()
    uncategorized_pmids, categorized_pmids = [], []
    uncate_text_pmids, cate_text_pmids = [], []

    for num_pmids, entry in enumerate(es_iterate_all_documents(es, index_name)):

        # A publication is unlabeled when its MeSH field is empty
        unlabeled = (len(entry['MeSH']) == 0)
        (uncategorized_pmids if unlabeled else categorized_pmids).append(entry['pmid'])

        # Publications with full text are additionally tracked per label status
        if len(entry['full_text']) > 0:
            (uncate_text_pmids if unlabeled else cate_text_pmids).append(entry['pmid'])

        # Print progress
        if num_pmids % 10000 == 0:
            print(str(num_pmids) + ' PMIDs processed', end='\r')

    return tuple(list(set(group)) for group in (uncategorized_pmids,
                                                categorized_pmids,
                                                uncate_text_pmids,
                                                cate_text_pmids))

In [None]:
# get_pmids returns four lists (unlabeled, labeled, unlabeled with full text,
# labeled with full text); unpacking them into three names raised a ValueError.
uncate_pmids, cate_pmids, uncate_text_pmids, cate_text_pmids = get_pmids(index_name)
# NOTE(review): `text_pmids` is kept for any later cell that uses it; it is
# presumably meant to hold every PMID with full text - confirm.
text_pmids = uncate_text_pmids + cate_text_pmids

In [None]:
def get_groundtruth_pmid2categories(relevant_pmid_batch, index_name,
                                    index_type, permuted_synonyms2category):
    '''
    FUNCTION:
    - This gets the ground truth labels, the MeSH-labeled PubMed documents

    PARAMS:
    - relevant_pmid_batch: This is the list of the PubMed IDs you want to
      get ground truth for
    - index_name: name of the ElasticSearch index
    - index_type: name of the type of ElasticSearch index
    - permuted_synonyms2category: MeSH synonyms -> category
      (values are expected to be lists of category names)

    RETURNS:
    - dict: PMID -> list of unique categories found via the publication's
      MeSH terms (empty list if none of its terms are mapped)
    '''
    pmids2real_categories = dict()
    es = Elasticsearch()
    total_pmids = len(relevant_pmid_batch)

    for num_pmids, pmid in enumerate(relevant_pmid_batch):
        if num_pmids % 10000 == 0:
            print('Getting real mappings' + str(num_pmids)+'/'+str(total_pmids),\
                  end='\r')
        entry = es.get(id = pmid, index = index_name, doc_type = index_type)

        # Accumulate in a set: de-duplicates without rebuilding a list
        # for every MeSH term.
        found_categories = set()

        # Publication's MeSH (if any)
        for mesh in entry['_source']['MeSH']:
            try:
                found_categories.update(permuted_synonyms2category[mesh])
            except (KeyError, TypeError):
                # Term has no category mapping (or is not usable as a key):
                # skip it. Anything else (e.g. KeyboardInterrupt) propagates.
                continue

        pmids2real_categories[pmid] = found_categories

    # Convert the per-PMID sets into lists so the result is JSON-serializable
    pmids2real_categories = switch_dictset_to_dictlist(pmids2real_categories)
    return pmids2real_categories

In [None]:
pmid2cate = get_groundtruth_pmid2categories(all_doc_pmids, index_name, index_type, permuted_synonyms2category)

In [None]:
import sys
# sys.path.append('/home/ubuntu/InternProjects/Joanne/caseolap_lift/data')

# NOTE(review): absolute, machine-specific paths; consider a configurable data directory.
# The context managers close each file even if json.load raises.
with open('/home/ubuntu/InternProjects/Joanne/caseolap_lift/data/relevant_categorized_pmids.json') as cat:
    categorized_pmids = json.load(cat)

with open('/home/ubuntu/InternProjects/Joanne/caseolap_lift/data/all_uncategorized_pmids.json') as uncat:
    uncategorized_pmids = json.load(uncat)

In [None]:
pmids = categorized_pmids + uncategorized_pmids

In [None]:
len(pmids)

In [None]:
def generate_ctg2pmid(cat_pmids, uncat_pmids, all_pmids):
    '''
    Invert a PMID -> categories mapping into category -> PMIDs.

    PARAMS:
    - cat_pmids: iterable of PMIDs to look up in `all_pmids`
    - uncat_pmids: list stored as-is (same object) under the key 'None'
    - all_pmids: dict mapping a PMID to its list of category tags

    RETURNS:
    - dict: category tag -> list of PMIDs, plus 'None' -> uncat_pmids.
      PMIDs missing from `all_pmids` or having no tags are skipped.
    '''

    category_index = {'None': uncat_pmids}

    for position, pmid in enumerate(cat_pmids):

        # Progress report every 10,000 PMIDs
        if position % 10000 == 0:
            print(str(position) + ' PMIDs processed', end='\r')

        if pmid not in all_pmids:
            continue

        pmid_tags = all_pmids[pmid]
        if not pmid_tags:
            continue

        for category in pmid_tags:
            category_index.setdefault(category, []).append(pmid)

    return category_index

# BERTopic quick start

In [None]:
# %pip installs into the running kernel's environment; the version is pinned to
# the release this notebook was run with so results stay reproducible.
%pip install bertopic==0.15.0

Collecting bertopic
  Downloading bertopic-0.15.0-py2.py3-none-any.whl (143 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/143.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m143.4/143.4 kB[0m [31m4.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.4/143.4 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
Collecting hdbscan>=0.8.29 (from bertopic)
  Downloading hdbscan-0.8.33.tar.gz (5.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting umap-learn>=0.5.0 (from bertopic)
  Downloading umap-learn-0.5.3.tar.gz (88 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.2/88.2 kB[0

In [None]:
from bertopic import BERTopic
from sklearn.datasets import fetch_20newsgroups

docs = fetch_20newsgroups(subset='all',  remove=('headers', 'footers', 'quotes'))['data']

topic_model = BERTopic()
topics, probs = topic_model.fit_transform(docs)

Downloading (…)e9125/.gitattributes:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)7e55de9125/README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

Downloading (…)55de9125/config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading (…)125/data_config.json:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading (…)e9125/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

Downloading (…)9125/train_script.py:   0%|          | 0.00/13.2k [00:00<?, ?B/s]

Downloading (…)7e55de9125/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)5de9125/modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

In [None]:
topic_model.get_topic_info()

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,6547,-1_to_is_the_and,"[to, is, the, and, for, of, you, in, it, that]",[->\tFirst I want to start right out and say t...
1,0,1830,0_game_team_games_he,"[game, team, games, he, players, season, hocke...",[\nNo. Patrick Roy is the reason the game was...
2,1,619,1_key_clipper_chip_encryption,"[key, clipper, chip, encryption, keys, escrow,...","[[An article from comp.org.eff.news, EFFector ..."
3,2,527,2_ites_cheek_yep_huh,"[ites, cheek, yep, huh, ken, ignore, forget, w...","[\nHuh?, \n \n ..."
4,3,471,3_israel_israeli_jews_arab,"[israel, israeli, jews, arab, jewish, arabs, p...",[From: Center for Policy Research <cpr>\nSubje...
...,...,...,...,...,...
214,213,10,213_crohns_inflammation_patients_colitis,"[crohns, inflammation, patients, colitis, fatt...",[One thing that I haven't seen in this thread ...
215,214,10,214_slip_packet_0x60_driver,"[slip, packet, 0x60, driver, goto, cslipper, p...","[\n\n\nThey are working ok, but your definitio..."
216,215,10,215_disks_ibm_3m_boxies,"[disks, ibm, 3m, boxies, quarenteed, st412, se...",[I have a few the original IBM 10Mb harddisks ...
217,216,10,216_space_astronaut_nasa_candidates,"[space, astronaut, nasa, candidates, aerospace...",[I am looking for any information about the sp...


In [None]:
data = fetch_20newsgroups(subset='all',  remove=('headers', 'footers', 'quotes'))
docs = data["data"]
categories = data["target"]
category_names = data["target_names"]

In [None]:
categories

In [None]:
docs

In [None]:
len(','.join(docs))

In [None]:
len(docs)

In [None]:
sum([len(x) for x in docs])/len(docs)

In [None]:
topic_model.get_topic_info()

In [None]:
df = topic_model.get_topic_info()

In [None]:
df_docs = topic_model.get_document_info(docs)

In [None]:
len(df_docs[df_docs['Topic'] == 0])

In [None]:
topic_model.get_topic(0)

In [None]:
topic_model.get_document_info(docs)

In [None]:
topic_model.visualize_topics()

TO-DO:
1. build a data loader:
   - parameters: include_abstract = True, include_title = True, include_full_text = False, max_docs_per_category = 1000
   - keys: data, labels (-1, 0, 1,...), name (CVD1, CVD2, ...)
   - return: a dict

In [None]:
import json
# Use a context manager so the file handle is closed after loading.
with open('gdrive/MyDrive/bertopic/caseolap_lift/text_mining/ctg2pmid.json') as ctg2pmid_file:
    ctg2pmid = json.load(ctg2pmid_file)

In [None]:
ctg2pmid

In [None]:
index_name = 'pubmed_lift'
index_type = 'pubmed_meta_lift'

In [None]:
def sample_publications(pmids, size=80, sample_prop='', use_prop_sampling = False,
                        num_categories=8):
    '''
    Draw a random sample of PMIDs from every category except 'None'.

    PARAMS:
    - pmids (dict): category -> list of PMIDs; the 'None' key is skipped
    - size (int): total target sample size
    - sample_prop: mapping category -> proportion of `size` to draw from
      that category; only used (and required) when use_prop_sampling=True
    - use_prop_sampling (bool): if False, every category receives
      int(size / num_categories) PMIDs; if True, category `c` receives
      int(size * sample_prop[c]) PMIDs
    - num_categories (int): divisor for equal-share sampling (default 8,
      the value that was previously hard-coded)

    RETURNS:
    - dict: category -> list of sampled PMIDs. A category holding fewer
      PMIDs than requested contributes all of its PMIDs instead of
      making random.sample raise ValueError.

    RAISES:
    - TypeError: use_prop_sampling=True but sample_prop was left as a string
    '''
    # Local import: the notebook's `import random` cell comes after this definition
    import random

    if use_prop_sampling and isinstance(sample_prop, str):
        raise TypeError('sample_prop must map each category to a proportion '
                        'when use_prop_sampling=True')

    sample = {}

    for key in pmids:
        if key == 'None':
            continue

        if use_prop_sampling:
            n = int(size*sample_prop[key])
        else:
            n = int(size/num_categories)

        # Never request more items than the category holds
        candidates = pmids[key]
        sample[key] = random.sample(candidates, min(n, len(candidates)))

    return sample

In [None]:
# NOTE(review): this cell looks like a fragment pasted out of a larger function
# -- TODO confirm. It uses names with no definition in the visible cells
# (relevant_pmid_batch, es, batch_id, stop_at_this_many_pmids, procs,
# label_labeled_only, label_unlabeled_only, label_all) and calls
# get_document_text, which is only defined in a later cell.
for num_pmids, pmid in enumerate(relevant_pmid_batch):
            entry = es.get(id = pmid, index = index_name, doc_type = index_type)

            # Print progress and break early
            if batch_id == 1 and num_pmids % 1000 == 0:
                print(str(num_pmids)+' PMIDs processed in one of the batches',\
                      end='\r')
            if num_pmids > stop_at_this_many_pmids/procs:
                break


            # Determine which publications to find labels for
            labeled_meshes = entry['_source']['MeSH']
            labeled = len(labeled_meshes) > 0
            unlabeled = len(labeled_meshes) == 0
            # The first three branches are no-ops that let the publication
            # through; anything else is skipped.
            if label_labeled_only and labeled:
                pass
            elif label_unlabeled_only and unlabeled:
                pass
            elif label_all:
                pass
            else:
                continue


            # Publication's text (title, abstract, full text if provided)
            abstract, title = get_document_text(entry)
            document_text = title + ' ' + abstract
            # Pad the title with spaces and turn ',' and ':' into spaces
            title = ' '+title+' '
            title = title.replace(',',' ')
            title = title.replace(':',' ')

In [None]:
def get_document_text(entry):
    '''
    FUNCTION:
    - Get the whitespace-cleaned abstract and title of a PubMed publication

    PARAMS:
    - entry (dict): parsed indexed publication; must provide
      entry['_source']['title'] and entry['_source']['abstract']

    RETURNS:
    - (abstract, title): two strings. A field that is not a string
      (e.g. None) becomes ''. Newlines and tabs are replaced by a
      space and runs of three spaces are collapsed to one.
    '''

    def _clean(text):
        # Non-string values are treated as missing text
        if not isinstance(text, str):
            return ''
        # str.replace returns a NEW string; the result must be kept
        # (previously it was discarded, so no cleaning took place).
        return text.replace('\n', ' ').replace('\t', ' ').replace('   ', ' ')

    source = entry['_source']

    # Title
    title = _clean(source['title'])

    # Abstract
    abstract = _clean(source['abstract'])

    return abstract, title

In [None]:
import random

In [None]:
sample = sample_publications(ctg2pmid, size=10)
labels = list(sample.keys())
labels

In [None]:
def dataloader(index_name, index_type, names, sample_size=800, include_abstract=True, include_title=True, include_fulltext=False):
    '''
    Build a labelled document collection from sampled PubMed publications.

    Samples PMIDs per category from the notebook-level `ctg2pmid` dict
    (via sample_publications), fetches every sampled publication from
    ElasticSearch and collects the requested text fields as separate
    documents.

    PARAMS:
    - index_name: name of the ElasticSearch index
    - index_type: name of the type of ElasticSearch index
    - names: category names to load; a category's position in this list
      is used as its integer label
    - sample_size: total sample size passed to sample_publications
    - include_abstract: add each non-empty abstract as a document
    - include_title: add each title as a document (even when empty)
    - include_fulltext: add each non-empty full text as a document

    RETURNS:
    - dict with parallel lists: 'data' (texts), 'labels' (category
      indices) and 'names' (category names)
    '''

    sample_docs = {}
    docs_data = []
    docs_labels = []
    docs_names = []

    def _add_document(text, label, name):
        # Keep the three parallel lists in step
        docs_data.append(text)
        docs_labels.append(label)
        docs_names.append(name)

    sample = sample_publications(ctg2pmid, size=sample_size)
    es = Elasticsearch()

    for idx, key in enumerate(names):
        pmid_batch = sample[key]

        for num_pmids, pmid in enumerate(pmid_batch):
            entry = es.get(id = pmid, index = index_name, doc_type = index_type)

            abstract, title = get_document_text(entry)

            if include_abstract and len(abstract) > 0:
                _add_document(abstract, idx, key)

            if include_title:
                _add_document(title, idx, key)

            if include_fulltext:
                # The document fields live under '_source' in an es.get
                # response; reading entry['full_text'] raised KeyError.
                fulltext = entry['_source'].get('full_text') or ''
                if len(fulltext) > 0:
                    _add_document(fulltext, idx, key)

            print(str(num_pmids)+' PMIDs processed in one of the batches', end='\r')

    sample_docs['data'] = docs_data
    sample_docs['labels'] = docs_labels
    sample_docs['names'] = docs_names

    return sample_docs

In [None]:
order = ["CM", "ARR", "CHD", "VD", "IHD", "CCD", "VOO", "OTH"]

In [None]:
# docs_1 = dataloader(index_name, index_type, order, sample_size=8000)
# Load the cached dataloader output; the context manager closes the file handle.
with open('gdrive/MyDrive/bertopic/caseolap_lift/text_mining/docs_1.json') as docs_1_file:
    docs_1 = json.load(docs_1_file)

In [None]:
docs_1

In [None]:
len(docs_1['data'])

In [None]:
len(','.join(docs_1['data']))

In [None]:
from bertopic import BERTopic
topic_model_1 = BERTopic()
topics, probs = topic_model_1.fit_transform(docs_1['data'])

Downloading (…)e9125/.gitattributes:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)7e55de9125/README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

Downloading (…)55de9125/config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading (…)125/data_config.json:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading (…)e9125/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

Downloading (…)9125/train_script.py:   0%|          | 0.00/13.2k [00:00<?, ?B/s]

Downloading (…)7e55de9125/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)5de9125/modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

In [None]:
topic_model_1.get_topic_info()

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,3038,-1_the_of_and_in,"[the, of, and, in, with, to, patients, was, fo...",[Management of intractable ventricular tachyar...
1,0,271,0_arrest_resuscitation_cpr_cardiac,"[arrest, resuscitation, cpr, cardiac, cardiopu...",[Do patient characteristics or factors at resu...
2,1,114,1_endocarditis_infective_ie_bacterial,"[endocarditis, infective, ie, bacterial, absce...","[Severe mitral or aortic valve regurgitation, ..."
3,2,112,2_artery_coronary_anomalous_left,"[artery, coronary, anomalous, left, origin, ma...",[Anomalous origin of the left coronary artery....
4,3,104,3_tumors_tumor_malignant_chemotherapy,"[tumors, tumor, malignant, chemotherapy, prima...",[Surgical Treatment of Cardiac Tumors: Insight...
...,...,...,...,...,...
153,152,11,152_anticoagulant_oral_lowmolecularweight_therapy,"[anticoagulant, oral, lowmolecularweight, ther...",[Unrelenting Abdominal Pain after Recent Initi...
154,153,10,153_ventricle_double_chambered_right,"[ventricle, double, chambered, right, she, dou...",[[Double outlet right ventricle--a case associ...
155,154,10,154_mitral_repair_annuloplasty_dilated,"[mitral, repair, annuloplasty, dilated, regurg...",[Novel approach to mitral valve repair in chil...
156,155,10,155_mitral_valve_replacement_repair,"[mitral, valve, replacement, repair, elderly, ...",[Prosthetic valve replacement in infants and c...


In [None]:
docs_topic = topic_model_1.get_document_info(docs_1['data'])
docs_topic

Unnamed: 0,Document,Topic,Name,Representation,Representative_Docs,Top_n_words,Probability,Representative_document
0,"Total red cell volume, plasma volume, and sodi...",-1,-1_the_of_and_in,"[the, of, and, in, with, to, patients, was, fo...",[Management of intractable ventricular tachyar...,the - of - and - in - with - to - patients - w...,0.000000,False
1,What is the role of myocardial mast cells?,-1,-1_the_of_and_in,"[the, of, and, in, with, to, patients, was, fo...",[Management of intractable ventricular tachyar...,the - of - and - in - with - to - patients - w...,0.000000,False
2,Feasibility of using multivector impedance to ...,34,34_icd_implantable_defibrillator_cardioverter,"[icd, implantable, defibrillator, cardioverter...",[Association between left ventricular ejection...,icd - implantable - defibrillator - cardiovert...,0.218832,False
3,Impact of predictive value of Fibrosis-4 index...,-1,-1_the_of_and_in,"[the, of, and, in, with, to, patients, was, fo...",[Management of intractable ventricular tachyar...,the - of - and - in - with - to - patients - w...,0.000000,False
4,Is Too Much Oxygen Bad for the Heart?,-1,-1_the_of_and_in,"[the, of, and, in, with, to, patients, was, fo...",[Management of intractable ventricular tachyar...,the - of - and - in - with - to - patients - w...,0.000000,False
...,...,...,...,...,...,...,...,...
7995,Thrombosis and failure of a HeartMate II devic...,44,44_assist_lvad_impella_device,"[assist, lvad, impella, device, support, devic...",[Outcomes of patients with right ventricular f...,assist - lvad - impella - device - support - d...,1.000000,False
7996,Cardiac arrest prognostic factors in children.,0,0_arrest_resuscitation_cpr_cardiac,"[arrest, resuscitation, cpr, cardiac, cardiopu...",[Do patient characteristics or factors at resu...,arrest - resuscitation - cpr - cardiac - cardi...,0.852608,False
7997,Identification of potent and selective amidobi...,-1,-1_the_of_and_in,"[the, of, and, in, with, to, patients, was, fo...",[Management of intractable ventricular tachyar...,the - of - and - in - with - to - patients - w...,0.000000,False
7998,MR demonstration of right atrial involvement i...,27,27_myxoma_tumor_case_atrium,"[myxoma, tumor, case, atrium, atrial, myxomas,...",[Left atrial myxoma in a patient with paroxysm...,myxoma - tumor - case - atrium - atrial - myxo...,0.130659,False


In [None]:
docs_topic = docs_topic[docs_topic['Topic'] != -1]
docs_topic = docs_topic.reset_index(drop=True)
docs_topic

In [None]:
# Take an explicit copy: a 'True topic' column is assigned to `subset` in a
# later cell, and writing to a slice of docs_topic triggers pandas'
# SettingWithCopyWarning.
subset = docs_topic[['Document', 'Topic', 'Name', 'Representation', 'Probability']].copy()
subset

In [None]:
# Map each document text to the list of distinct category names it was
# sampled under (the same text can appear under several categories).
docs_search = {}
for idx, key in enumerate(docs_1['data']):
    val = docs_1['names'][idx]
    names_for_doc = docs_search.setdefault(key, [])
    if val not in names_for_doc:
        names_for_doc.append(val)


In [None]:
docs_search

In [None]:
import numpy as np

In [None]:
len(np.unique(subset['Document'].tolist()))

In [None]:
subset['Document'].apply(lambda x: x in docs_search).all()

In [None]:
true_topic = subset['Document'].apply(lambda x: ', '.join(docs_search[x]))
true_topic

In [None]:
subset['True topic'] = true_topic

In [None]:
subset

In [None]:
import pandas as pd

In [None]:
topic_table = pd.DataFrame(subset.groupby('True topic').apply(lambda x: list(set(x['Topic']))))

In [None]:
topic_table = topic_table.reset_index()

In [None]:
topic_table.columns

In [None]:
topic_table = topic_table.rename(columns={0: 'topics'})

In [None]:
topic_table

In [None]:
topic_table['length_of_topics'] = topic_table['topics'].apply(lambda x: len(x))

In [None]:
one_topic_table = topic_table[topic_table['True topic'].apply(lambda x: len(x)) <= 3]

In [None]:
one_topic_table

In [None]:
set.intersection(*[set(x) for x in one_topic_table['topics'].tolist()])

In [None]:
[set(x) for x in one_topic_table['topics'].tolist()]

In [None]:
topic_model_1.visualize_topics()

In [None]:
topic_model_1.get_params()


{'calculate_probabilities': False,
 'ctfidf_model': ClassTfidfTransformer(),
 'embedding_model': <bertopic.backend._sentencetransformers.SentenceTransformerBackend at 0x7ef776c00d60>,
 'hdbscan_model': HDBSCAN(min_cluster_size=10, prediction_data=True),
 'language': 'english',
 'low_memory': False,
 'min_topic_size': 10,
 'n_gram_range': (1, 1),
 'nr_topics': None,
 'representation_model': None,
 'seed_topic_list': None,
 'top_n_words': 10,
 'umap_model': UMAP(angular_rp_forest=True, low_memory=False, metric='cosine', min_dist=0.0, n_components=5, tqdm_kwds={'bar_format': '{desc}: {percentage:3.0f}%| {bar} {n_fmt}/{total_fmt} [{elapsed}]', 'desc': 'Epochs completed', 'disable': True}),
 'vectorizer_model': CountVectorizer(),
 'verbose': False}

In [None]:
from scipy.cluster import hierarchy as sch

# Hierarchical topics
linkage_function = lambda x: sch.linkage(x, 'single', optimal_ordering=True)
hierarchical_topics = topic_model_1.hierarchical_topics(docs_1['data'], linkage_function=linkage_function)

In [None]:
topic_model_1.visualize_hierarchy(hierarchical_topics=hierarchical_topics)

In [None]:
hierarchical_topics

In [None]:
# Use a context manager so the file handle is closed after loading.
with open('gdrive/MyDrive/bertopic/caseolap_lift/text_mining/category2synonyms.json') as category2synonyms_file:
    category2synonyms = json.load(category2synonyms_file)

In [None]:
[len(val) for key, val in category2synonyms.items()]

In [None]:
category2synonyms

To-do:
1. Parse the MeSH XML file to find the synonyms of each category.
2. Generate a list of categories with their synonyms to fit the model.
3. Run the guided model.

In [None]:
! wget https://nlmpubs.nlm.nih.gov/projects/mesh/MESH_FILES/xmlmesh/desc2022.xml

In [None]:
! ls desc2022.xml

In [None]:
! ls

In [None]:
! ls ../data/MeSH/desc2022.xml

In [None]:
category2synonyms['VOO']

In [None]:
CM_syn = [s for s in mesh.keys() if 'Cardiomyopathy' in s]
CM_syn

In [None]:
ARR_syn = [s for s in mesh.keys() if 'Arrhythmia' in s]
ARR_syn

In [None]:
CHD_syn = [s for s in mesh.keys() if 'Congenital' in s]
CHD_syn

In [None]:
VD_syn = [s for s in mesh.keys() if 'Valve' in s]
VD_syn

In [None]:
CCD_syn = [s for s in mesh.keys() if 'Cardiac Conduction' in s or 'Cardiac Complex' in s or 'Atrioventricular' in s]
CCD_syn

In [None]:
VOO_syn = [s for s in mesh.keys() if 'Subvalvular' in s or 'Pulmonary' in s]
VOO_syn

In [None]:
from bertopic import BERTopic

In [None]:
from bertopic import BERTopic
from bertopic.backend import BaseEmbedder
from bertopic.cluster import BaseCluster
from bertopic.vectorizers import ClassTfidfTransformer
from bertopic.dimensionality import BaseDimensionalityReduction
from sklearn.datasets import fetch_20newsgroups

docs = fetch_20newsgroups(subset='all',  remove=('headers', 'footers', 'quotes'))
data = docs["data"]

seed_topic_list = [["drug", "cancer", "drugs", "doctor"],
                   ["windows", "drive", "dos", "file"],
                   ["space", "launch", "orbit", "lunar"]]

empty_embedding_model = BaseEmbedder()
empty_dimensionality_model = BaseDimensionalityReduction()
empty_cluster_model = BaseCluster()
ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)


topic_model = BERTopic(
    umap_model=empty_dimensionality_model,
    hdbscan_model=empty_cluster_model,
    ctfidf_model=ctfidf_model,
    seed_topic_list=seed_topic_list
)
topics, probs = topic_model.fit_transform(data, y=docs['target'])


In [None]:
topic_model.get_topic_info()

NameError: ignored

In [None]:
docs

In [None]:
list(category2synonyms.values())

In [None]:
from sklearn.feature_extraction.text import CountVectorizer
from bertopic.representation import KeyBERTInspired

docs_1_ = docs_1['data']
seed = list(category2synonyms.values())

# vectorizer_model = CountVectorizer(stop_words="english")
representation_model = KeyBERTInspired()

topic_model_1 = BERTopic(
    umap_model=empty_dimensionality_model,
    hdbscan_model=empty_cluster_model,
    ctfidf_model=ctfidf_model,
    # vectorizer_model=vectorizer_model,
    representation_model=representation_model,
    seed_topic_list=seed
)
topics_1, probs_1 = topic_model_1.fit_transform(docs_1_, y=docs_1['labels'])

In [None]:
topic_model_1.get_params()

{'calculate_probabilities': False,
 'ctfidf_model': ClassTfidfTransformer(reduce_frequent_words=True),
 'embedding_model': <bertopic.backend._sentencetransformers.SentenceTransformerBackend at 0x7afb8c18b100>,
 'hdbscan_model': <bertopic.cluster._base.BaseCluster at 0x7afb899a9900>,
 'language': 'english',
 'low_memory': False,
 'min_topic_size': 10,
 'n_gram_range': (1, 1),
 'nr_topics': None,
 'representation_model': KeyBERTInspired(),
 'seed_topic_list': [['Cardiomyopathy, Congestive',
   'Congestive Heart Failure',
   'Cardiomyopathy, Familial Idiopathic',
   'Congestive Cardiomyopathy',
   'Dilated Cardiomyopathy',
   'Familial Hypertrophic Cardiomyopathy',
   'Dyspnea, Paroxysmal',
   'Asthma, Cardiac',
   'Myocarditis',
   'Injury, Myocardial Reperfusion',
   'Heart Failure',
   'Chagas Cardiomyopathy',
   'Ventricular Dysplasia, Right, Arrhythmogenic',
   'Secondary Myocardial Diseases',
   'Left-Sided Heart Failure',
   'Heart Failure, Congestive',
   'Cardiomyopathy, Hypertro

In [None]:
df = topic_model_1.get_topic_info()

In [None]:
df

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,11454,-1_stenosis_aortic_ventricular_cardiac,"[stenosis, aortic, ventricular, cardiac, myoca...",[The study aim was to assess the value of exer...
1,0,614,0_abnormalities_trisomy_mutations_congenital,"[abnormalities, trisomy, mutations, congenital...",[Conotruncal heart defects (CTDs) are present ...
2,1,301,1_tachyarrhythmias_arrhythmias_tachycardia_arr...,"[tachyarrhythmias, arrhythmias, tachycardia, a...",[A computer-assisted analysis of the TU-comple...
3,2,200,2_cardiac_ventricular_tachycardia_mutations,"[cardiac, ventricular, tachycardia, mutations,...",[The congenital long QT syndrome is a potentia...
4,3,148,3_endocarditis_coxsackievirus_myocarditis_peri...,"[endocarditis, coxsackievirus, myocarditis, pe...",[BACKGROUND Calcified amorphous tumor (CAT) of...
5,4,57,4_cardiomyocyte_myocarditis_immunoglobulins_pr...,"[cardiomyocyte, myocarditis, immunoglobulins, ...",[The impact and clinical relevance of pregnanc...
6,5,32,5_dysplasia_congenital_diagnosis_dyskinesia,"[dysplasia, congenital, diagnosis, dyskinesia,...",[To determine whether CT-guided mucociliary cl...
7,6,21,6_peroxidation_lipoperoxidation_hypercholester...,"[peroxidation, lipoperoxidation, hypercholeste...",[It was found that glucose in the range of con...
8,7,16,7_dilution_saturation_sampling_bronchoscopically,"[dilution, saturation, sampling, bronchoscopic...",[In an article in a previous issue of the Jour...


In [None]:
sum(df['Count'])

12843

In [None]:
# Compare against the documents the model was fit on (docs_1_), not `data`,
# which holds the 20-newsgroups corpus from the quick-start cells.
# Presumably this should equal sum(df['Count']) -- TODO confirm after re-running.
len(docs_1_)

18846

In [None]:
list(df['Representation'])[0]

['and', 'in', 'with', 'to', 'of', 'the', 'patients', 'was', 'for', 'were']

In [None]:
df_doc = topic_model_1.get_document_info(docs_1_)

In [None]:
df_doc['Top_n_words'][0]

'and - in - with - to - of - the - patients - was - for - were'

In [None]:
topic_model_1.visualize_topics()

## To-Do: 08/31
  - combine the title and abstract into a single document
  - model:
    - check whether the weights of the seed words can be modified (1.2 to 1.5)
    - change the n-gram range
  - seed_topic_list:
    - lowercase everything
    - split the phrases into individual words and remove trivial words
    - include only words that appear in a single topic (maybe remove overlapping words)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import json
# Use a context manager so the file handle is closed after loading.
with open('drive/MyDrive/bertopic/caseolap_lift/text_mining/category2synonyms.json') as category2synonyms_file:
    category2synonyms = json.load(category2synonyms_file)

In [None]:
category2synonyms

{'CM': ['Cardiomyopathy, Congestive',
  'Congestive Heart Failure',
  'Cardiomyopathy, Familial Idiopathic',
  'Congestive Cardiomyopathy',
  'Dilated Cardiomyopathy',
  'Familial Hypertrophic Cardiomyopathy',
  'Dyspnea, Paroxysmal',
  'Asthma, Cardiac',
  'Myocarditis',
  'Injury, Myocardial Reperfusion',
  'Heart Failure',
  'Chagas Cardiomyopathy',
  'Ventricular Dysplasia, Right, Arrhythmogenic',
  'Secondary Myocardial Diseases',
  'Left-Sided Heart Failure',
  'Heart Failure, Congestive',
  'Cardiomyopathy, Hypertrophic Obstructive',
  'Heart Failure, Normal Ejection Fraction',
  'Adhalinopathy, Primary',
  'Myocardial Disease',
  'CPEO with Myopathy',
  'Ophthalmoplegia Plus Syndrome',
  'Cardiomyopathies',
  'Hypertrophic Subaortic Stenosis, Idiopathic',
  'LGMD2D',
  'Limb-Girdle Muscular Dystrophy, Type 2D',
  'Noncompaction of the Left Ventricular Myocardium, Autosomal Dominant',
  'Cardiac Failure',
  'Myocardial Ischemic Reperfusion Injury',
  'Alpha-Sarcoglycanopathy',
 

In [None]:
# Use a context manager so the file handle is closed after loading.
with open('drive/MyDrive/bertopic/caseolap_lift/text_mining/docs_1.json') as docs_1_file:
    docs_1 = json.load(docs_1_file)

In [None]:
seed = list(category2synonyms.values())

In [None]:
order = ["CM", "ARR", "CHD", "VD", "IHD", "CCD", "VOO", "OTH"]

#### model

In [None]:
# %pip installs into the running kernel's environment; the version is pinned to
# the release this notebook was run with so results stay reproducible.
%pip install bertopic==0.15.0

Collecting bertopic
  Downloading bertopic-0.15.0-py2.py3-none-any.whl (143 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.4/143.4 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Collecting hdbscan>=0.8.29 (from bertopic)
  Downloading hdbscan-0.8.33.tar.gz (5.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting umap-learn>=0.5.0 (from bertopic)
  Downloading umap-learn-0.5.3.tar.gz (88 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.2/88.2 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting sentence-transformers>=0.4.1 (from bertopic)
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[2K     [90m━━━━━━━━━━━━━━━━━

In [None]:
from bertopic import BERTopic
from bertopic.backend import BaseEmbedder
from bertopic.cluster import BaseCluster
from bertopic.vectorizers import ClassTfidfTransformer
from bertopic.dimensionality import BaseDimensionalityReduction
# from sklearn.datasets import fetch_20newsgroups

In [None]:
empty_embedding_model = BaseEmbedder()
empty_dimensionality_model = BaseDimensionalityReduction()
empty_cluster_model = BaseCluster()
ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)

In [None]:
from sklearn.feature_extraction.text import CountVectorizer
from bertopic.representation import KeyBERTInspired

# Shared inputs: one synonym list per category (seed topics) and the raw documents.
seed = list(category2synonyms.values())
docs_1_ = docs_1['data']

# Sub-models. The vectorizer is created here but intentionally not passed to topic_model_1.
representation_model = KeyBERTInspired()
vectorizer_model = CountVectorizer(stop_words="english")

# Experiment 1: dimensionality reduction and clustering are both replaced by the
# empty placeholder models; the embedding model is left at BERTopic's default.
topic_model_1 = BERTopic(
    seed_topic_list=seed,
    n_gram_range=(1, 2),
    representation_model=representation_model,
    # vectorizer_model=vectorizer_model,
    ctfidf_model=ctfidf_model,
    hdbscan_model=empty_cluster_model,
    umap_model=empty_dimensionality_model,
)

# NOTE(review): with seed_topic_list set, BERTopic's guided mode may replace the `y`
# passed here with seed-derived labels (the recorded topic table has a large -1 group).
# Confirm against the installed BERTopic version before relying on `y`.
topics_1, probs_1 = topic_model_1.fit_transform(docs_1_, y=docs_1['labels'])

In [None]:
df_1 = topic_model_1.get_topic_info()
df_1

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,7147,-1_heart failure_cardiac_aortic valve_ventricular,"[heart failure, cardiac, aortic valve, ventric...",[Aortic stiffness as a marker of cardiac funct...
1,0,388,0_turner syndrome_turners syndrome_trisomy_abn...,"[turner syndrome, turners syndrome, trisomy, a...",[The knee alignment and the foot arch in patie...
2,1,167,1_qt syndrome_arrhythmias_ventricular fibrilla...,"[qt syndrome, arrhythmias, ventricular fibrill...",[Effect of sodium channel blockers on ST segme...
3,2,134,2_arrhythmias_qt syndrome_cardiac_tachycardia,"[arrhythmias, qt syndrome, cardiac, tachycardi...",[A missense mutation (G604S) in the S5/pore re...
4,3,84,3_bacterial endocarditis_bacteraemia_bacterial...,"[bacterial endocarditis, bacteraemia, bacteria...",[Methanobrevibacter smithii Archaemia in Febri...
5,4,38,4_cardiac myosin_myocarditis and_myocarditis_a...,"[cardiac myosin, myocarditis and, myocarditis,...",[Genetic susceptibility to Chagas disease card...
6,5,22,5_bronchiectasis in_of bronchiectasis_bronchie...,"[bronchiectasis in, of bronchiectasis, bronchi...","[Aetiology of bronchiectasis in Guangzhou, sou..."
7,6,15,6_lipoproteins_triglycerides and_palmoplantar ...,"[lipoproteins, triglycerides and, palmoplantar...",[DGCR8 recognizes primary transcripts of micro...
8,7,5,7_american correction_tracheostomy as_required...,"[american correction, tracheostomy as, require...",[[TRACHEOSTOMY AS A MEANS OF PREVENTING RESPIR...


In [None]:
# Experiment 2: BERTopic's default UMAP, HDBSCAN, c-TF-IDF and vectorizer;
# only the representation model, the seed topics and the n-gram range are set.
topic_model_2 = BERTopic(
    seed_topic_list=seed,
    n_gram_range=(1, 2),
    representation_model=representation_model,
)
topics_2, probs_2 = topic_model_2.fit_transform(
    docs_1_,
    y=docs_1['labels'],
)

# Per-topic summary table, shown as the cell output.
df_2 = topic_model_2.get_topic_info()
df_2

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,2745,-1_aortic stenosis_aortic valve_aortic_heart f...,"[aortic stenosis, aortic valve, aortic, heart ...",[Severe aortic stenosis in octogenarians: is o...
1,0,415,0_turner syndrome_turners syndrome_turner_marf...,"[turner syndrome, turners syndrome, turner, ma...",[The knee alignment and the foot arch in patie...
2,1,337,1_heart failure_chronic heart_cardiac_ventricular,"[heart failure, chronic heart, cardiac, ventri...",[[Value of aldosterone receptor blockade in di...
3,2,282,2_cardiac arrest_cardiopulmonary resuscitation...,"[cardiac arrest, cardiopulmonary resuscitation...",[Admission of out-of-hospital cardiac arrest v...
4,3,114,3_myocardial ischemia_ischemiareperfusion inju...,"[myocardial ischemia, ischemiareperfusion inju...",[Myocardial infarct extension during reperfusi...
...,...,...,...,...,...
115,114,12,114_hyperthyroid heart_hyperthyroid patients_o...,"[hyperthyroid heart, hyperthyroid patients, of...",[Occult thyrotoxicosis: a correctable cause of...
116,115,12,115_antiphospholipid syndrome_antiphospholipid...,"[antiphospholipid syndrome, antiphospholipid a...","[[Severe, non-infectious mitral valve endocard..."
117,116,11,116_coronary disease_coronary insufficiency_th...,"[coronary disease, coronary insufficiency, thr...",[[Coronary sclerosis and its sequelae (Statist...
118,117,11,117_occupational poisoning_oxide poisoning_qui...,"[occupational poisoning, oxide poisoning, quin...",[Presence or absence of elevated acute total s...


In [None]:
# Experiment 3: same as the default pipeline, except that the custom
# ClassTfidfTransformer (ctfidf_model) is plugged in.
topic_model_3 = BERTopic(
    seed_topic_list=seed,
    n_gram_range=(1, 2),
    representation_model=representation_model,
    ctfidf_model=ctfidf_model,
)
topics_3, probs_3 = topic_model_3.fit_transform(
    docs_1_,
    y=docs_1['labels'],
)

# Per-topic summary table, shown as the cell output.
df_3 = topic_model_3.get_topic_info()
df_3

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,2963,-1_aortic stenosis_heart failure_aortic valve_...,"[aortic stenosis, heart failure, aortic valve,...",[Prognostic significance of frequent premature...
1,0,408,0_turner syndrome_turners syndrome_growth horm...,"[turner syndrome, turners syndrome, growth hor...",[Current indications for growth hormone therap...
2,1,284,1_cardiac arrest_cardiac arrests_resuscitation...,"[cardiac arrest, cardiac arrests, resuscitatio...",[A retrospective study of pulseless electrical...
3,2,125,2_myocardial ischemia_ischemia reperfusion_of ...,"[myocardial ischemia, ischemia reperfusion, of...",[[Morphine-induced late cardioprotection: pote...
4,3,99,3_atrial fibrillation_oral anticoagulants_anti...,"[atrial fibrillation, oral anticoagulants, ant...",[Stroke prophylaxis in atrial fibrillation: se...
...,...,...,...,...,...
126,125,11,125_leadless pacing_leadless pacemaker_ventric...,"[leadless pacing, leadless pacemaker, ventricu...",[Recent advances in pacemaker and implantable ...
127,126,11,126_atrial fibrosis_atrial fibrillation_inhibi...,"[atrial fibrosis, atrial fibrillation, inhibit...",[Role of Rac1 GTPase activation in atrial fibr...
128,127,11,127_published tavi_tavi_of tavi_tavi is,"[published tavi, tavi, of tavi, tavi is, in ta...",[The official position of the Latin American A...
129,128,10,128_heart murmur_systolic murmur_systolic murm...,"[heart murmur, systolic murmur, systolic murmu...",[Critical evaluation of atrial presystolic mur...


In [None]:
# Experiment 4: dimensionality reduction is replaced by the empty placeholder,
# while clustering, c-TF-IDF and the vectorizer stay at BERTopic's defaults.
topic_model_4 = BERTopic(
    seed_topic_list=seed,
    n_gram_range=(1, 2),
    representation_model=representation_model,
    umap_model=empty_dimensionality_model,
)
topics_4, probs_4 = topic_model_4.fit_transform(
    docs_1_,
    y=docs_1['labels'],
)

# Per-topic summary table, shown as the cell output.
df_4 = topic_model_4.get_topic_info()
df_4

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,3073,-1_heart failure_cardiovascular_myocardial inf...,"[heart failure, cardiovascular, myocardial inf...","[[The QDF-HF (Quality of life, Depression and ..."
1,0,4914,0_aortic stenosis_heart failure_aortic valve_a...,"[aortic stenosis, heart failure, aortic valve,...",[Occult aortic stenosis as cause of intractabl...
2,1,13,1_discussion_discussion hard_hard discussion_d...,"[discussion, discussion hard, hard discussion,...","[Discussion. , Discussion. , Discussion. ]"


In [None]:
# Experiment 5: clustering is replaced by the empty placeholder, while
# dimensionality reduction, c-TF-IDF and the vectorizer stay at BERTopic's defaults.
topic_model_5 = BERTopic(
    seed_topic_list=seed,
    n_gram_range=(1, 2),
    representation_model=representation_model,
    hdbscan_model=empty_cluster_model,
)
topics_5, probs_5 = topic_model_5.fit_transform(
    docs_1_,
    y=docs_1['labels'],
)

# Per-topic summary table, shown as the cell output.
df_5 = topic_model_5.get_topic_info()
df_5

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,7147,-1_heart failure_aortic valve_ventricular_cardiac,"[heart failure, aortic valve, ventricular, car...",[Essential ECG clues in patients with congenit...
1,0,388,0_turner syndrome_turners syndrome_syndrome an...,"[turner syndrome, turners syndrome, syndrome a...",[Social and medical determinants of quality of...
2,1,167,1_qt syndrome_arrhythmias_ventricular tachycar...,"[qt syndrome, arrhythmias, ventricular tachyca...",[The spectrum of symptoms and QT intervals in ...
3,2,134,2_qt syndrome_arrhythmias_cardiac_ventricular,"[qt syndrome, arrhythmias, cardiac, ventricula...",[A missense mutation (G604S) in the S5/pore re...
4,3,84,3_bacterial myocarditis_bacterial endocarditis...,"[bacterial myocarditis, bacterial endocarditis...",[Bacteraemia during Transurethral Resection of...
5,4,38,4_myocarditis and_myocarditis_autoantibodies_c...,"[myocarditis and, myocarditis, autoantibodies,...",[Acute susceptibility of aged mice to infectio...
6,5,22,5_bronchiectasis in_of bronchiectasis_bronchie...,"[bronchiectasis in, of bronchiectasis, bronchi...","[Aetiology of bronchiectasis in Guangzhou, sou..."
7,6,15,6_platelet survival_platelets_triglycerides_my...,"[platelet survival, platelets, triglycerides, ...",[Unrecognized diabetes and myocardial necrosis...
8,7,5,7_american correction_tracheostomy as_required...,"[american correction, tracheostomy as, require...",[[TRACHEOSTOMY AS A MEANS OF PREVENTING RESPIR...


In [None]:
from umap import UMAP
from hdbscan import HDBSCAN
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer

from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from bertopic.vectorizers import ClassTfidfTransformer


# Step 1 - Extract embeddings
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Step 2 - Reduce dimensionality
# random_state pins UMAP's otherwise random initialisation so that repeated runs of the
# notebook are comparable. This umap_model object is reused by the later topic models.
umap_model = UMAP(n_neighbors=15, n_components=5, min_dist=0.0, metric='cosine', random_state=42)

# Step 3 - Cluster reduced embeddings
hdbscan_model = HDBSCAN(min_cluster_size=15, metric='euclidean', cluster_selection_method='eom', prediction_data=True)

# Step 4 - Tokenize topics
vectorizer_model = CountVectorizer(stop_words="english")

# Step 5 - Create topic representation
ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)

# Step 6 - (Optional) Fine-tune topic representations with
# a `bertopic.representation` model
representation_model = KeyBERTInspired()

# All steps together
# NOTE(review): BERTopic documents that n_gram_range is not used when a custom
# CountVectorizer is supplied; if bigrams are wanted here, set ngram_range=(1, 2) on
# vectorizer_model instead. Confirm the intended behaviour before changing it.
topic_model_6 = BERTopic(
  top_n_words=100,
  embedding_model=embedding_model,           # Step 1 - Extract embeddings
  umap_model=umap_model,                     # Step 2 - Reduce dimensionality
  hdbscan_model=hdbscan_model,               # Step 3 - Cluster reduced embeddings
  vectorizer_model=vectorizer_model,         # Step 4 - Tokenize topics
  ctfidf_model=ctfidf_model,                 # Step 5 - Extract topic words
  representation_model=representation_model, # Step 6 - (Optional) Fine-tune topic representations
  seed_topic_list=seed,
  n_gram_range=(1,2)
)


In [None]:
topics_6, probs_6 = topic_model_6.fit_transform(docs_1_, y=docs_1['labels'])

df_6 = topic_model_6.get_topic_info()
df_6

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,2948,-1_echocardiography_aortic_aorta_stenosis,"[echocardiography, aortic, aorta, stenosis, ve...",[Sequential balloon dilatation for combined ao...
1,0,343,0_diuretics_antihypertensive_angiotensin_hyper...,"[diuretics, antihypertensive, angiotensin, hyp...",[OUTpatient intravenous LASix Trial in reducin...
2,1,329,1_cardiomyocytes_cardioprotection_cardioprotec...,"[cardiomyocytes, cardioprotection, cardioprote...",[Changes in PPAR gene expression and myocardia...
3,2,272,2_resuscitation_resuscitated_defibrillation_po...,"[resuscitation, resuscitated, defibrillation, ...",[[Cardiac arrest management: any news? When th...
4,3,204,3_mitral_regurgitation_echocardiographic_mitra...,"[mitral, regurgitation, echocardiographic, mit...",[Ischemic mitral regurgitation: intraventricul...
...,...,...,...,...,...
86,85,17,85_antiarrhythmic_fibrillation_atrial_pharmaco...,"[antiarrhythmic, fibrillation, atrial, pharmac...",[Pharmacologic approaches to rhythm versus rat...
87,86,16,86_myocardial_biomarkers_biomarker_troponin,"[myocardial, biomarkers, biomarker, troponin, ...",[Prediction of Recurrent Events by D-Dimer and...
88,87,16,87_transplantation_transplant_donor_transplanted,"[transplantation, transplant, donor, transplan...",[Development of a successful non-heart-beating...
89,88,16,88_echocardiography_echocardiographic_echocard...,"[echocardiography, echocardiographic, echocard...",[Detection of ventricular thrombi by ultrasoun...


In [None]:
# NOTE(review): this calls a private BERTopic helper without arguments and the recorded
# output is a TypeError, presumably because the helper requires arguments — confirm.
# The public get_topic()/get_topics() accessors are the supported way to read topic words.
topic_model_6._extract_words_per_topic()

TypeError: ignored

In [None]:
# Step 3 - Cluster reduced embeddings
# Rebinds the notebook-level hdbscan_model, so every model built after this cell
# picks up these settings as well.
hdbscan_model = HDBSCAN(
    min_cluster_size=100,
    max_cluster_size=1000,
    leaf_size=20,
    metric='euclidean',
    cluster_selection_method='eom',
    prediction_data=True,
)

# Same pipeline as before, now with the re-configured clustering step.
topic_model_7 = BERTopic(
    top_n_words=100,
    seed_topic_list=seed,
    n_gram_range=(1, 2),
    embedding_model=embedding_model,            # Step 1 - Extract embeddings
    umap_model=umap_model,                      # Step 2 - Reduce dimensionality
    hdbscan_model=hdbscan_model,                # Step 3 - Cluster reduced embeddings
    vectorizer_model=vectorizer_model,          # Step 4 - Tokenize topics
    ctfidf_model=ctfidf_model,                  # Step 5 - Extract topic words
    representation_model=representation_model,  # Step 6 - (Optional) Fine-tune topic representations
)

topics_7, probs_7 = topic_model_7.fit_transform(
    docs_1['data'],
    y=docs_1['labels'],
)

# Per-topic summary table, shown as the cell output.
df_7 = topic_model_7.get_topic_info()
df_7

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,3489,-1_cardiovascular_echocardiography_cardiac_ven...,"[cardiovascular, echocardiography, cardiac, ve...",[Risk factors for sepsis and endocarditis and ...
1,0,886,0_ventricular_ventricle_atrioventricular_aortic,"[ventricular, ventricle, atrioventricular, aor...",[Biventricular repair of transposition of the ...
2,1,775,1_inhibitor_cardioprotective_myocardial_cardiac,"[inhibitor, cardioprotective, myocardial, card...",[A review of heart failure management in the e...
3,2,515,2_tachyarrhythmias_antiarrhythmic_arrhythmias_...,"[tachyarrhythmias, antiarrhythmic, arrhythmias...",[Combined use of time and frequency domain var...
4,3,446,3_aortic_tavi_transcatheter_stenosis,"[aortic, tavi, transcatheter, stenosis, valve,...",[The Outcomes of Pulmonary Hypertension Patien...
5,4,433,4_cardiac_myocarditis_endocarditis_echocardiog...,"[cardiac, myocarditis, endocarditis, echocardi...","[Pediatric cardiac tumors: a 45-year, single-i..."
6,5,400,5_trisomy_abnormalities_syndrome_congenital,"[trisomy, abnormalities, syndrome, congenital,...",[Gender Dysphoria and Gender Change in Disorde...
7,6,276,6_atrial_fibrillation_cardioversion_anticoagul...,"[atrial, fibrillation, cardioversion, anticoag...",[[Analysis of risk factors for all cause-morta...
8,7,265,7_resuscitation_resuscitated_defibrillation_cpr,"[resuscitation, resuscitated, defibrillation, ...",[Duration of cardiopulmonary resuscitation in ...
9,8,220,8_mitral_regurgitation_echocardiography_ventri...,"[mitral, regurgitation, echocardiography, vent...",[How does the use of polytetrafluoroethylene n...


In [None]:
topic_model_7.visualize_topics()

#### Edit seed topic list

In [None]:
# For every seed list, split each entry on whitespace and keep the distinct tokens.
# The tokens of one list end up in arbitrary (set) order.
unique_seed = [
    list({token for phrase in phrases for token in phrase.split()})
    for phrases in seed
]


In [None]:
# Same sub-models as the previous pipeline, but seeded with the single-token lists
# in unique_seed; top_n_words and n_gram_range are left at BERTopic's defaults.
topic_model_8 = BERTopic(
    seed_topic_list=unique_seed,
    embedding_model=embedding_model,            # Step 1 - Extract embeddings
    umap_model=umap_model,                      # Step 2 - Reduce dimensionality
    hdbscan_model=hdbscan_model,                # Step 3 - Cluster reduced embeddings
    vectorizer_model=vectorizer_model,          # Step 4 - Tokenize topics
    ctfidf_model=ctfidf_model,                  # Step 5 - Extract topic words
    representation_model=representation_model,  # Step 6 - (Optional) Fine-tune topic representations
)

topics_8, probs_8 = topic_model_8.fit_transform(
    docs_1['data'],
    y=docs_1['labels'],
)

# Per-topic summary table, shown as the cell output.
df_8 = topic_model_8.get_topic_info()
df_8

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,4057,-1_echocardiography_ventricular_cardiac_ventricle,"[echocardiography, ventricular, cardiac, ventr...",[Left ventricular reconstruction: Early and la...
1,0,814,0_cardiovascular_cardiomyocytes_cardioprotecti...,"[cardiovascular, cardiomyocytes, cardioprotect...",[Beneficial effect of prolonged heme oxygenase...
2,1,432,1_aortic_tavi_stentless_stenosis,"[aortic, tavi, stentless, stenosis, valves, va...",[Early- and mid-term outcomes of transcatheter...
3,2,401,2_coronary_stenting_cardiovascular_atheroscler...,"[coronary, stenting, cardiovascular, atheroscl...",[Relationship of interleukin-6-572C/G promoter...
4,3,343,3_tachyarrhythmias_antiarrhythmic_arrhythmias_...,"[tachyarrhythmias, antiarrhythmic, arrhythmias...",[[Syncope in supraventricular tachycardia. Inc...
5,4,297,4_atrial_antiarrhythmic_ablation_arrhythmia,"[atrial, antiarrhythmic, ablation, arrhythmia,...",[Is there still a role for additional linear a...
6,5,273,5_endocarditis_myocarditis_myopericarditis_per...,"[endocarditis, myocarditis, myopericarditis, p...",[Bacteraemia during Transurethral Resection of...
7,6,270,6_resuscitation_resuscitated_defibrillation_po...,"[resuscitation, resuscitated, defibrillation, ...",[Usefulness of cooling and coronary catheteriz...
8,7,198,7_mitral_regurgitation_echocardiography_ventri...,"[mitral, regurgitation, echocardiography, vent...",[Ischemic mitral regurgitation: intraventricul...
9,8,186,8_ventricular_ventricle_atrioventricular_stenosis,"[ventricular, ventricle, atrioventricular, ste...",[Surgical management of double-outlet right ve...


In [None]:
topic_model_8.visualize_topics()

In [None]:
# Shallow copy of the token lists; unique_seed itself is not modified.
list_of_lists = unique_seed.copy()

# seen_words collects every token on first sight; a token met again afterwards
# is recorded in overlapping_words.
seen_words = set()
overlapping_words = set()

for word_list in list_of_lists:
    for word in word_list:
        (overlapping_words if word in seen_words else seen_words).add(word)

# Rebuild every list, dropping the tokens flagged as overlapping.
new_unique_seed = [
    [word for word in word_list if word not in overlapping_words]
    for word_list in list_of_lists
]

# new_unique_seed now contains lists without overlapping words


In [None]:
[len(x) for x in new_unique_seed]

[91, 15, 184, 20, 59, 1, 3, 73]

In [None]:
# Shallow copy of the synonym lists; seed itself is not modified.
list_of_lists = seed.copy()

# Every synonym encountered so far, across all lists.
seen_words = set()

# Synonyms that occur in at least two *different* lists.
overlapping_words = set()

for word_list in list_of_lists:
    # Collapse repeats inside this list first: a synonym that merely appears twice
    # in the same list is not an overlap between lists and must not be dropped.
    words_in_list = set(word_list)
    overlapping_words.update(words_in_list & seen_words)
    seen_words.update(words_in_list)

# Rebuild every list, keeping only the synonyms that are exclusive to it.
new_seed = [
    [word for word in word_list if word not in overlapping_words]
    for word_list in list_of_lists
]

# new_seed now contains lists without overlapping words

In [None]:
[len(x) for x in new_seed]

[92, 28, 204, 64, 80, 4, 7, 102]

In [None]:
category2synonyms.keys()

dict_keys(['CM', 'ARR', 'CHD', 'VD', 'IHD', 'CCD', 'VOO', 'OTH'])

In [None]:
new_seed[5:7]

[['Cardiac Conduction Defects',
  'Cardiac Conduction System Diseases',
  'Cardiac Conduction System Disease',
  'Cardiac Conduction Defect'],
 ['Outflow Obstruction, Right Ventricular',
  'Outflow Obstruction, Left Ventricular',
  'Ventricular Outflow Obstruction, Left',
  'Right Ventricular Outflow Obstruction',
  'Ventricular Outflow Obstruction',
  'Ventricular Outflow Obstruction, Right',
  'Left Ventricular Outflow Obstruction']]

In [None]:
list_of_lists = seed.copy()

# Dict: synonym -> index of the first list it was seen in.
seen_words = {}

# Dict: synonym shared by several lists -> indices of those lists, in ascending order.
overlapping_words = {}

for idx, word_list in enumerate(list_of_lists):
    for word in word_list:
        first_idx = seen_words.setdefault(word, idx)
        if first_idx == idx:
            # First sighting, or a repeat inside the same list: not an overlap between lists.
            continue
        # Start the entry with the list of first sighting, then add each further list once.
        list_indices = overlapping_words.setdefault(word, [first_idx])
        if idx not in list_indices:
            list_indices.append(idx)


To-do list (09/05):
- Change the weight of the seed words to 2 or 3 (line 3506 of `_bertopic.py` in the installed bertopic 0.15.0).
- Run topic reduction on the model we already have.
- Try fuzzy synonym matching.

#### With seed weight = 2

In [None]:
# Merge the topics of the already-fitted topic_model_6 down to nr_topics=9.
# The call updates the model itself; the recorded cell output is the returned BERTopic object.
topic_model_6.reduce_topics(docs_1_, nr_topics=9)

<bertopic._bertopic.BERTopic at 0x789cd56d1a20>

In [None]:
df_6_ = topic_model_6.get_topic_info()
df_6_

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,2948,-1_stenosis_echocardiography_aorta_aortic,"[stenosis, echocardiography, aorta, aortic, ca...",[Comparison of procedural and in-hospital outc...
1,0,1520,0_ventricular_aortic_regurgitation_ventricle,"[ventricular, aortic, regurgitation, ventricle...",[Should high risk patients with concomitant se...
2,1,1374,1_cardiovascular_coronary_myocardial_angiotensin,"[cardiovascular, coronary, myocardial, angiote...",[Augmentation of endogenous adenosine attenuat...
3,2,1207,2_antiarrhythmic_arrhythmias_arrhythmia_cardiac,"[antiarrhythmic, arrhythmias, arrhythmia, card...",[Congenital long QT syndrome. Congenital long ...
4,3,553,3_cardiac_cardiomyopathy_myocarditis_echocardi...,"[cardiac, cardiomyopathy, myocarditis, echocar...",[Surgical Treatment of Cardiac Tumors: Insight...
5,4,310,4_trisomy_abnormalities_disorders_syndrome,"[trisomy, abnormalities, disorders, syndrome, ...",[Gender Dysphoria and Gender Change in Disorde...
6,5,36,5_ventricle_diagnosis_atrioventricular_ebstein,"[ventricle, diagnosis, atrioventricular, ebste...",[Multiplanar review of three-dimensional echoc...
7,6,29,6_clinical_clinicopathological_clinicoradiolog...,"[clinical, clinicopathological, clinicoradiolo...",[Case records of the Massachusetts General Hos...
8,7,23,7_commentary_invited_discussion_response,"[commentary, invited, discussion, response, , ...","[Invited Commentary. , Invited commentary. , I..."


In [None]:
# Representation words of topic -1, joined into one comma-separated string.
', '.join(topic_model_6.get_topic_info(-1)['Representation'].iloc[0])

'echocardiography, aortic, aorta, stenosis, ventricular, cardiovascular, atrioventricular, ventricle, artery, coronary'

In [None]:
topic_model_7.reduce_topics(docs_1['data'], nr_topics=9)
df_7_ = topic_model_7.get_topic_info()
df_7_

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,3441,-1_echocardiography_cardiac_ventricular_cardio...,"[echocardiography, cardiac, ventricular, cardi...",[Hemodynamic findings during exercise on a bic...
1,0,1239,0_ventricular_ventricle_cardiac_atrioventricular,"[ventricular, ventricle, cardiac, atrioventric...",[Intracardiac repair of lesions associated wit...
2,1,1146,1_cardiovascular_coronary_angiotensin_myocardial,"[cardiovascular, coronary, angiotensin, myocar...",[[Diagnostic and prognostic value of atheroscl...
3,2,547,2_aortic_transcatheter_tavi_stenosis,"[aortic, transcatheter, tavi, stenosis, transf...",[The Outcomes of Pulmonary Hypertension Patien...
4,3,450,3_tachyarrhythmias_tachycardia_antiarrhythmic_...,"[tachyarrhythmias, tachycardia, antiarrhythmic...",[Role of late potentials in identifying patien...
5,4,397,4_trisomy_abnormalities_congenital_syndrome,"[trisomy, abnormalities, congenital, syndrome,...",[Etiological classification and clinical asses...
6,5,276,5_resuscitation_resuscitated_postresuscitation...,"[resuscitation, resuscitated, postresuscitatio...",[Usefulness of cooling and coronary catheteriz...
7,6,273,6_atrial_fibrillation_cardioversion_ablation,"[atrial, fibrillation, cardioversion, ablation...",[Novel surgical ablation through a septal-supe...
8,7,231,7_mitral_regurgitation_ventricular_echocardiog...,"[mitral, regurgitation, ventricular, echocardi...",[How does the use of polytetrafluoroethylene n...


In [None]:
topic_model_7.visualize_topics()

In [None]:
topic_model_1.reduce_topics(docs_1['data'], nr_topics=9)
df_1_ = topic_model_1.get_topic_info()
df_1_

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,3038,-1_the_of_and_in,"[the, of, and, in, with, to, patients, was, fo...",[Ascending aortic aneurysms in unicommissural ...
1,0,2456,0_the_of_and_in,"[the, of, and, in, to, with, patients, was, we...",[[Correlation of the parameters of myocardial ...
2,1,2028,1_the_of_and_in,"[the, of, and, in, with, to, valve, aortic, pa...",[Ross procedure in congenital patients: result...
3,2,315,2_the_of_syndrome_in,"[the, of, syndrome, in, and, with, to, is, pat...",[Social and medical determinants of quality of...
4,3,47,3_the_case_practice_of,"[the, case, practice, of, to, massachusetts, y...",[Case records of the Massachusetts General Hos...
5,4,46,4_of_and_the_in,"[of, and, the, in, with, intoxication, to, wer...",[Clinicopathologic analysis of cardiac dysfunc...
6,5,30,5_lyme_adamsstokes_psoriasis_of,"[lyme, adamsstokes, psoriasis, of, syndrome, c...",[Successful treatment of hand and foot psorias...
7,6,27,6_invited_commentary_discussion_reply,"[invited, commentary, discussion, reply, rouma...","[Invited commentary. , Invited commentary. , I..."
8,7,13,7_digitalis_glycosides_treatment_of,"[digitalis, glycosides, treatment, of, sdigoxi...","[[DIGITALIS GLYCOSIDES IN RHYTHM DISORDERS]. ,..."


In [None]:
topic_model_1.visualize_topics()

In [None]:
topic_model_8.reduce_topics(docs_1['data'], nr_topics=9)
df_8_ = topic_model_8.get_topic_info()
df_8_

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,4057,-1_echocardiography_ventricular_cardiac_ventricle,"[echocardiography, ventricular, cardiac, ventr...",[Left ventricular reconstruction: Early and la...
1,0,1215,0_cardiovascular_coronary_cardiac_myocardial,"[cardiovascular, coronary, cardiac, myocardial...",[Prognostic value of uric acid in patients wit...
2,1,736,1_aortic_ventricular_transcatheter_stenosis,"[aortic, ventricular, transcatheter, stenosis,...",[Outcomes of definitive surgical repair for co...
3,2,463,2_tachyarrhythmias_antiarrhythmic_arrhythmias_...,"[tachyarrhythmias, antiarrhythmic, arrhythmias...",[Combined use of time and frequency domain var...
4,3,458,3_cardiac_myocarditis_endocarditis_echocardiog...,"[cardiac, myocarditis, endocarditis, echocardi...","[Pediatric cardiac tumors: a 45-year, single-i..."
5,4,306,4_hormone_trisomy_ovarian_hypogonadism,"[hormone, trisomy, ovarian, hypogonadism, abno...",[Gender Dysphoria and Gender Change in Disorde...
6,5,297,5_atrial_antiarrhythmic_fibrillation_ablation,"[atrial, antiarrhythmic, fibrillation, ablatio...",[Novel surgical ablation through a septal-supe...
7,6,270,6_resuscitation_resuscitated_defibrillation_po...,"[resuscitation, resuscitated, defibrillation, ...",[Usefulness of cooling and coronary catheteriz...
8,7,198,7_mitral_regurgitation_echocardiography_ventri...,"[mitral, regurgitation, echocardiography, vent...",[How does the use of polytetrafluoroethylene n...


In [None]:
topic_model_8.visualize_topics()

In [None]:
topic_model_0 = BERTopic(seed_topic_list=seed, n_gram_range=(1,4))
topics_0, probs_0 = topic_model_0.fit_transform(docs_1_)

In [None]:
topic_model_0.reduce_topics(docs_1_, nr_topics=9)
df_0_ = topic_model_0.get_topic_info()
df_0_

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,-1,2813,-1_the_of_and_in,"[the, of, and, in, with, to, patients, was, fo...",[Estimation of the left ventricular diastolic ...
1,0,2621,0_the_of_and_in,"[the, of, and, in, with, to, patients, was, of...",[Early left ventricular remodeling after aorti...
2,1,1439,1_the_of_and_in,"[the, of, and, in, with, to, patients, was, fo...",[[Establishment of porcine model of prolonged ...
3,2,1022,2_the_of_and_in,"[the, of, and, in, to, with, patients, was, he...",[Developmental changes in tolerance to ischaem...
4,3,35,3_ebsteins_the_of_in,"[ebsteins, the, of, in, tricuspid, of the, ano...",[Long-term ECG in ambulatory clinical practice...
5,4,34,4_case_the_case records of the_records of the ...,"[case, the, case records of the, records of th...",[Case records of the Massachusetts General Hos...
6,5,13,5_digitalis_of_glycosides_treatment,"[digitalis, of, glycosides, treatment, intoxic...","[Digitalis intoxication. , Digitalis therapy i..."
7,6,12,6_invited commentary_invited_commentary_discus...,"[invited commentary, invited, commentary, disc...","[Invited commentary. , Invited commentary. , I..."
8,7,11,7_reply hard_reply hard lightning_lightning_ha...,"[reply hard, reply hard lightning, lightning, ...","[Reply. , Hard. , Lightning. ]"


To-do:
- see if we can edit the number of words in Representation
- compare the Representation words to the synonyms
- see if we can edit the English stop words in the vectorizer model

### Manual modeling

In [None]:
from sklearn.datasets import fetch_20newsgroups

# Get labeled data
data = fetch_20newsgroups(subset='all',  remove=('headers', 'footers', 'quotes'))
docs = data['data']
y = data['target']

In [None]:
from bertopic import BERTopic
from bertopic.backend import BaseEmbedder
from bertopic.cluster import BaseCluster
from bertopic.vectorizers import ClassTfidfTransformer
from bertopic.dimensionality import BaseDimensionalityReduction

# Prepare our empty sub-models and reduce frequent words while we are at it.
empty_embedding_model = BaseEmbedder()
empty_dimensionality_model = BaseDimensionalityReduction()
empty_cluster_model = BaseCluster()
ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)

# Fit BERTopic without actually performing any clustering
topic_model= BERTopic(
        embedding_model=empty_embedding_model,
        umap_model=empty_dimensionality_model,
        hdbscan_model=empty_cluster_model,
        ctfidf_model=ctfidf_model
)
# topics, probs = topic_model.fit_transform(docs, y=y)


In [None]:
docs_1.keys()

dict_keys(['data', 'labels', 'names'])

In [None]:
# Map input `y` to topics
mappings = topic_model.topic_mapper_.get_mappings()
mappings = {value: data["target_names"][key] for key, value in mappings.items()}

# Assign original classes to our topics
df = topic_model.get_topic_info()
df["Class"] = df.Topic.map(mappings)
df


In [None]:
# Fresh placeholder sub-models, plus a c-TF-IDF transformer that reduces frequent words.
empty_cluster_model = BaseCluster()
empty_dimensionality_model = BaseDimensionalityReduction()
empty_embedding_model = BaseEmbedder()
ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)

# Fit BERTopic without actually performing any clustering: every learned stage is a
# placeholder and the labels in docs_1['labels'] are passed in as `y`.
topic_model = BERTopic(
    ctfidf_model=ctfidf_model,
    hdbscan_model=empty_cluster_model,
    umap_model=empty_dimensionality_model,
    embedding_model=empty_embedding_model,
)
topics, probs = topic_model.fit_transform(
    docs_1['data'],
    y=docs_1['labels'],
)

In [None]:
# Original label -> topic id, as recorded by BERTopic's topic mapper.
mappings = topic_model.topic_mapper_.get_mappings()

# Resolve each original label to its class name.
# Indexing docs_1["names"] directly by label gave "CM" for every topic in the recorded
# output, which suggests "names" holds one entry per document rather than one per class.
# NOTE(review): schema inferred from that output — confirm against how docs_1.json is built.
label_values = docs_1["labels"]
name_values = docs_1["names"]
if len(name_values) == len(label_values):
    # Per-document names: pair each document's label with its name.
    label_to_name = dict(zip(label_values, name_values))
else:
    # One name per class: the label is the position in the list (previous behaviour).
    label_to_name = dict(enumerate(name_values))
mappings = {value: label_to_name[key] for key, value in mappings.items()}

# Assign original classes to our topics
df = topic_model.get_topic_info()
df["Class"] = df.Topic.map(mappings)
df

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs,Class
0,0,1000,0_failure_heart_hf_in,"[failure, heart, hf, in, and, to, of, with, th...",[Mode of death in heart failure: findings from...,CM
1,1,1000,1_af_atrial_fibrillation_with,"[af, atrial, fibrillation, with, patients, of,...",[Tachycardias of right ventricular origin. Ven...,CM
2,2,1000,2_syndrome_the_of_with,"[syndrome, the, of, with, artery, and, in, pul...",[Right Ventricular Outflow Tract Reconstructio...,CM
3,3,1000,3_valve_mitral_aortic_regurgitation,"[valve, mitral, aortic, regurgitation, patient...",[[Percutaneous balloon valvuloplasty for sever...,CM
4,4,1000,4_coronary_myocardial_infarction_in,"[coronary, myocardial, infarction, in, of, and...",[Bivalirudin versus heparin with or without gl...,CM
5,5,1000,5_block_tachycardia_qt_ventricular,"[block, tachycardia, qt, ventricular, the, in,...",[Effect of propranolol on ventricular rate dur...,CM
6,6,1000,6_aortic_valve_stenosis_patients,"[aortic, valve, stenosis, patients, tavr, repl...",[Repair of interrupted aortic arch: a ten-year...,CM
7,7,1000,7_arrest_cardiac_hypertrophy_of,"[arrest, cardiac, hypertrophy, of, and, to, in...",[Successful use of therapeutic hypothermia aft...,CM


In [None]:
topic_model.get_topic_info()

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,0,1000,0_failure_heart_hf_in,"[failure, heart, hf, in, and, to, of, with, th...",[Mode of death in heart failure: findings from...
1,1,1000,1_af_atrial_fibrillation_with,"[af, atrial, fibrillation, with, patients, of,...",[Tachycardias of right ventricular origin. Ven...
2,2,1000,2_syndrome_the_of_with,"[syndrome, the, of, with, artery, and, in, pul...",[Right Ventricular Outflow Tract Reconstructio...
3,3,1000,3_valve_mitral_aortic_regurgitation,"[valve, mitral, aortic, regurgitation, patient...",[[Percutaneous balloon valvuloplasty for sever...
4,4,1000,4_coronary_myocardial_infarction_in,"[coronary, myocardial, infarction, in, of, and...",[Bivalirudin versus heparin with or without gl...
5,5,1000,5_block_tachycardia_qt_ventricular,"[block, tachycardia, qt, ventricular, the, in,...",[Effect of propranolol on ventricular rate dur...
6,6,1000,6_aortic_valve_stenosis_patients,"[aortic, valve, stenosis, patients, tavr, repl...",[Repair of interrupted aortic arch: a ten-year...
7,7,1000,7_arrest_cardiac_hypertrophy_of,"[arrest, cardiac, hypertrophy, of, and, to, in...",[Successful use of therapeutic hypothermia aft...


In [None]:
topic_model.visualize_topics()

In [None]:
from sentence_transformers import SentenceTransformer

In [None]:
dir(embedding_model)

In [None]:
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
# SentenceTransformer computes embeddings through `encode`; it has no `embed` method,
# which is what the previously recorded AttributeError came from.
embeddings = embedding_model.encode(docs_1['data'][:10])

AttributeError: ignored

In [None]:
embeddings

In [None]:
# NOTE(review): the recorded output of this call is a ValueError. `topic_model` was built
# with the placeholder BaseEmbedder, so presumably transform() has no way to embed new
# documents unless embeddings are passed in explicitly — confirm.
topic_model.transform(docs_1['data'][:10])

ValueError: ignored

In [None]:
topic_model