In [1]:
import ujson as json
import numpy as np
import pandas as pd
from bertopic import BERTopic
from umap import UMAP

from spacy.lang.en import English

import tqdm

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


In [2]:
# MAG metadata: a JSON object that the rest of the notebook indexes by paper id.
# The encoding is given explicitly so decoding does not depend on the platform's
# locale default (JSON text is UTF-8).
with open("../scidocs/data/paper_metadata_mag_mesh.json", "r", encoding="utf-8") as mag_metadata_file:
    mag_metadata = json.load(mag_metadata_file)

In [3]:
# Peek at a single metadata record: first the whole entry, then just its keys.
sample_pid = 'fedb8360a09a326f403dcca14494e1da8a5f3adc'
sample_paper = mag_metadata[sample_pid]

print(sample_paper)
print(sample_paper.keys())

{'abstract': 'BACKGROUND\nMany patients with coronary artery disease (CAD) fail to attend cardiac rehabilitation following acute coronary events because they lack motivation to exercise. Theory-based approaches to promote physical activity among non-participants in cardiac rehabilitation are required.\n\n\nDESIGN\nA randomized trial comparing physical activity levels at baseline, 6, and 12 months between a motivational counselling (MC) intervention group and a usual care (UC) control group.\n\n\nMETHOD\nOne hundred and forty-one participants hospitalized with acute coronary syndromes not planning to attend cardiac rehabilitation were recruited at a single centre and randomized to either MC (n\u2009=\u200969) or UC (n\u2009=\u200972). The MC intervention, designed from an ecological perspective, included one face-to-face contact and eight telephone contacts with a trained physiotherapist over a 52-week period. The UC group received written information about starting a walking programme 

In [4]:
# Read the MAG validation split and keep its paper ids in a set so that
# membership checks while streaming the embeddings file are constant-time.
mag_val_csv = "../scidocs/data/mag/val.csv"
mag_val = pd.read_csv(mag_val_csv)

mag_val_pids = set(mag_val["pid"])

In [5]:
# Plain-text dump of the validation frame (columns `pid` and `class_label`);
# pandas elides the middle rows and reports the overall shape at the end.
print(mag_val)

                                           pid  class_label
0     901362650a27cf0a40193eaefdd7eea042a70780            9
1     dab652f2a1f1006079a03d373ba5b8ff1af93934           18
2     f29fb8fecb090291a635887ff8e13e566d785da7           13
3     3353eb3708b2df906d7831dffb93fd25c9f5e84b            9
4     d3679b15e371c6d8a2eafb1527aecfb1fdaee16d            9
...                                        ...          ...
3746  23d11486c371ef3235f74ed9994dd8f68686a47e           11
3747  fedb8360a09a326f403dcca14494e1da8a5f3adc           13
3748  2f4226076bf759367076094df6f09a59cb23fbc2           11
3749  62618110bb09c9629338ebd2059d3465907a7196            0
3750  e7511458ab482ba305277caaa4bee18af1b39681            6

[3751 rows x 2 columns]


In [6]:
# Instantiate spaCy's English language class. It is used further down to
# tokenize each paper's text and to read the per-token `is_stop` / `is_punct`
# flags when filtering the documents.
spacy_nlp = English()

In [7]:
# Read the embeddings jsonl created with embed.py and, for every paper that is
# in the MAG validation split, collect three aligned sequences:
#   mag_embeddings - the selected facet's embedding vector
#   mag_docs       - title + abstract with stop words and punctuation removed
#   mag_labels     - the paper's class_label from the validation csv
facet_selected = 0  # index into each record's "embedding" list

embeddings_path = "save_U_k-3_sum_embs_original-0-9+no_sum-0-1+mean-avg_word-0-05_extra_facet_alternate_layer_8_4-alternate_identity_common_random_cross_entropy_02-09/cls_no_sum.jsonl"

# Build the pid -> class_label lookup once instead of boolean-masking the whole
# validation frame for every matched paper (which made the loop quadratic).
# keep="first" mirrors the previous `.iloc[0]` choice if a pid ever repeats.
mag_label_by_pid = (
    mag_val.drop_duplicates(subset="pid", keep="first")
    .set_index("pid")["class_label"]
    .to_dict()
)

mag_embeddings = []
mag_docs = []
mag_labels = []

with open(embeddings_path, "r", encoding="utf-8") as mag_embeddings_file:
    for line in tqdm.tqdm(mag_embeddings_file):
        paper = json.loads(line)
        paper_id = paper["paper_id"]

        # Only papers from the validation split are kept.
        if paper_id not in mag_val_pids:
            continue

        mag_embeddings.append(np.array(paper["embedding"][facet_selected]))

        # Missing or null title/abstract fall back to an empty string.
        metadata = mag_metadata[paper_id]
        text = (metadata.get("title") or "") + " " + (metadata.get("abstract") or "")

        # Filter out stop words and punctuation. Every kept token is followed by
        # a single space (including the last one), exactly as before.
        kept_tokens = [
            str(token)
            for token in spacy_nlp(text)
            if not (token.is_stop or token.is_punct)
        ]
        mag_docs.append("".join(token_text + " " for token_text in kept_tokens))

        mag_labels.append(mag_label_by_pid[paper_id])

mag_embeddings = np.array(mag_embeddings)
mag_labels = np.array(mag_labels)

48473it [00:28, 1705.19it/s]


In [9]:
# Dimensionality-reduction step handed to BERTopic. The seed is fixed so that
# repeated runs of the notebook produce the same projection.
umap_model = UMAP(
    n_neighbors=15,
    n_components=5,
    min_dist=0.0,
    metric='cosine',
    random_state=42,
)

# All other BERTopic components are left at their defaults.
topic_model = BERTopic(umap_model=umap_model)

In [10]:
# Fit the topic model on the filtered documents, passing the selected-facet
# embeddings loaded above as the second argument (presumably used in place of
# BERTopic's own embedding step -- confirm against the BERTopic API docs).
topics, probs = topic_model.fit_transform(mag_docs, mag_embeddings)

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [None]:
# Overview plot of the fitted topics (presumably BERTopic's intertopic
# distance map -- see the BERTopic visualization docs).
topic_model.visualize_topics()

In [None]:
# Hierarchical view of how the fitted topics relate to each other
# (presumably a clustering dendrogram -- see the BERTopic visualization docs).
topic_model.visualize_hierarchy()

In [None]:
# Bar charts for the fitted topics (presumably the top-scoring terms per
# topic -- see the BERTopic visualization docs).
topic_model.visualize_barchart()

In [None]:
# Heatmap over the fitted topics (presumably a topic-to-topic similarity
# matrix -- see the BERTopic visualization docs).
topic_model.visualize_heatmap()

In [None]:
# Term-rank plot for the fitted topics (presumably how term scores decay with
# rank inside each topic -- see the BERTopic visualization docs).
topic_model.visualize_term_rank()