# Retraining XLinear on MeSH tags used by Wellcome

## 1. How many labels are actually used in training?

In [1]:
import json

input_path = "/data/grants_tagger/data/raw/allMeSH_2021.json"


def yield_raw_data(input_path):
    with open(input_path, encoding="latin-1") as f_i:
        f_i.readline()  # skip first line ({"articles":[) which is not valid JSON
        for i, line in enumerate(f_i):
            item = json.loads(line[:-2])
            yield item


input_data = yield_raw_data(input_path)

### number of Mesh terms used in the training set

In [2]:
mesh_in_training = []

for line in input_data:
    mesh_in_training.extend(line["meshMajor"])

In [3]:
mesh_training_labels = list(set(mesh_in_training))

In [4]:
len(mesh_training_labels)

29369

## 2. How many labels are used by Wellcome?

In [7]:
wellcome_labels_path = "../data/processed/WT_mesh_tags_used/tags_used.txt"

In [8]:
with open(wellcome_labels_path, "r") as fp:
    y = fp.readlines()

wellcome_labels = [label.split("\n")[0] for label in y]

In [10]:
print(f" there are 29,917 labels in 2021 MeSH")
print(f" there are {len(mesh_training_labels)} labels used in training")
print(f" there are {len(set(wellcome_labels))} labels used by Wellcome")

print(f" which means the trainings set is short by {1-len(mesh_training_labels)/29917}")
print(f" which means {1-len(set(wellcome_labels))/29917} labels aren't used at Wellcome")

 there are 29,917 labels in 2021 MeSH
 there are 29369 labels used in training
 there are 25252 labels used by Wellcome
 which means the trainings set is short by 0.018317344653541512
 which means 0.15593141023498347 labels aren't used at Wellcome


## 3. create a .csv with terms that we would like to stop using because they are contentious

In [22]:
import os 
import httpx
import uuid
import tqdm

def download_from_url(*, url, filename):
    """
    Download a file to the given filename, with a progress bar.
    """
    if os.path.exists(filename):
        return

    # How big is the file?
    # Note: this will throw if the server doesn't return a Content-Length
    # header.  We're downloading snapshots from S3, which always does,
    # but this code may not be suitable elsewhere.
    size = int(httpx.head(url).headers["Content-Length"])

    tmp_path = filename + "." + str(uuid.uuid4()) + ".tmp"

    with open(tmp_path, "wb") as outfile:
        with tqdm.tqdm(
            total=size, unit="B", desc=os.path.basename(filename), unit_scale=True
        ) as pbar:
            with httpx.stream("GET", url) as resp:
                for chunk in resp.iter_bytes():
                    if chunk:
                        outfile.write(chunk)
                        pbar.update(len(chunk))

    os.rename(tmp_path, filename)

In [25]:
# let's download the latest MeSH xml
year = "2021"
mesh_page = f'https://nlmpubs.nlm.nih.gov/projects/mesh/{year}/xmlmesh/desc{year}.xml'
path_to_tree = 'desc'+year+'.xml'
download_from_url(url=mesh_page, filename=path_to_tree)



In [29]:
# this is a library that can parse xml easily
import xml.etree.ElementTree as ET
mesh_tree = ET.parse(path_to_tree)

In [69]:
annotations = []
descriptors = []
for mesh in mesh_tree.iter('DescriptorRecord'):
    descriptors.append(mesh.find('DescriptorName').find('String').text)
    annotation = mesh.find('Annotation')
    if annotation is None:
        annotations.append('')
    else:
        annotations.append(mesh.find('Annotation').text)

In [88]:
descriptors_to_use = []

descriptors_set = zip(descriptors, annotations)

for descriptor, annotation in descriptors_set:
    if "Do not use" not in annotation:
        descriptors_to_use.append(descriptor)

In [91]:
# create a csv with terms we would like to keep

import pandas as pd
pd.DataFrame(to_drop, columns = ['descriptors']).to_csv('../data/processed/descriptors_to_use.csv')

## 3. create an csv with terms we would like to keep

In [100]:
wt_descriptors = set(descriptors_to_use).intersection(set(wellcome_labels))

len(wt_descriptors)

25240

In [101]:
pd.DataFrame(wt_descriptors, columns=["DescriptorName"]).to_csv(
    "../data/processed/wt_tags_used.csv"
)

## 4. Train Xlinear model


In [3]:
%load_ext autoreload
%autoreload 2

from grants_tagger.slim.mesh_xlinear import train, evaluate

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
parameters = {
    "ngram_range": (1, 1),
    "beam_size": 30,
    "only_topk": 200,
    "min_weight_value": 0.1,
    "max_features": 400_000,
}

In [4]:
!$SLACK_USER

In [1]:
# this posts a command to Slack to warn training has started
!curl -X POST -H 'Content-type: application/json' --data "{'text': 'Hi <$SLACK_USER>, training has started'}" $SLACK_HOOK


curl: no URL specified!
curl: try 'curl --help' or 'curl --manual' for more information


In [5]:
# this will train the data
model, label_binarizer = train(
    # uncomment for toy data
        train_data_path='../data/processed/train_mesh2021_wt.jsonl',
        label_binarizer_path='../models/label_binarizer_wt.pkl',
        parameters=parameters,
        model_path='../models/xlinear-wt'
)

../models/label_binarizer_wt.pkl exists. Loading existing
Loading data...




Fitting model
Saving model on ../models/xlinear-toy


In [6]:
results, full_report = evaluate(
    model,
    label_binarizer,
    train_data_path="../data/processed/train_mesh2021_wt.jsonl",
    test_data_path="../data/processed/test_mesh2021_wt.jsonl",
    results_path="../results/results_wt.json",
    full_report_path="../results/full_report_wt.json",
)

Loading data...




Loading data...




Evaluating model


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
!curl -X POST -H 'Content-type: application/json' --data "{'text': 'Hi <$SLACK_USER>, training has finished'}" $SLACK_HOOK

### results from the version that includes all tags:
{'threshold': '0.50', 'precision': '0.74', 'recall': '0.41', 'f1': '0.53'}

In [7]:
# results WT only grants:
results

{'threshold': '0.50', 'precision': '0.78', 'recall': '0.45', 'f1': '0.57'}

In [4]:
!curl -X POST -H 'Content-type: application/json' --data "{'text': 'Hi <$SLACK_USER>, I think I am developing consciousness'}" $SLACK_HOOK

ok