In [1]:
from transformers import AutoTokenizer, AutoModel
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import torch
from time import time

In [2]:
# Medical terms were generated and grouped into categories by ChatGPT ("v3", 6-4-2023).
# NOTE(review): despite its key, the 'brain_tumours' group lists brain anatomy
# (lobes, ventricles, white-matter tracts, ...) rather than tumour types; the key is
# left unchanged so nothing that looks it up breaks.
medical_terms = {
    "brain_tumours": [
        "Basal ganglia", "Brainstem", "Cerebellum", "Diencephalon", "Epithalamus",
        "Frontal lobe", "Gyrus", "Hippocampus", "Insula", "Junction",
        "Kuhne's commissure", "Limbic lobe", "Midbrain", "Nucleus", "Occipital lobe",
        "Parietal lobe", "Quadrigeminal plate", "Reticular formation", "Sulcus",
        "Temporal lobe", "Uncus", "lateral ventricle", "third ventricle",
        "fourth ventricle", "White matter", "Corpus callosum", "Pineal gland", "Pons",
    ],
    "Cardiology": ["Arrhythmia", "Atherosclerosis", "Cardiomyopathy", "Endocarditis", "Myocardial infarction", "Pericarditis", "Tachycardia"],
    "Dermatology": ["Acne", "Dermatitis", "Eczema", "Hives", "Melanoma", "Psoriasis", "Rosacea"],
    "Endocrinology": ["Diabetes mellitus", "Goiter", "Hyperthyroidism", "Hypothyroidism", "Osteoporosis", "Pheochromocytoma", "Pituitary adenoma"],
    "Gastroenterology": ["Cholecystitis", "Cirrhosis", "Colitis", "Gastroenteritis", "Hepatitis", "Pancreatitis", "Ulcerative colitis"],
    "Hematology": ["Anemia", "Hemophilia", "Leukemia", "Lymphoma", "Multiple myeloma", "Sickle cell anemia", "Thrombocytopenia"],
    "Neurology": ["Alzheimer's disease", "Epilepsy", "Meningitis", "Multiple sclerosis", "Parkinson's disease", "Stroke", "Traumatic brain injury"],
    "Oncology": ["Carcinoma", "Chemotherapy", "Immunotherapy", "Metastasis", "Radiation therapy", "Sarcoma", "Tumor"],
}

# Flatten into a single list in dict (category) order, keeping the within-category
# order, so row i of every embedding matrix computed below corresponds to all_terms[i].
all_terms = [term for category_terms in medical_terms.values() for term in category_terms]

In [2]:
# Load the tokenizer and the BERT encoder from the same Hub checkpoint so that the
# token ids produced by the tokenizer match the model's embedding table.
CHECKPOINT = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = AutoModel.from_pretrained(CHECKPOINT)

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
# Tokenise every term in one batch; padding=True pads each term to the longest one so
# the batch stacks into rectangular PyTorch tensors ("pt").
inputs = tokenizer(text=all_terms, return_tensors="pt", padding=True)

In [11]:
# Time N_RUNS forward passes over the whole batch of terms (gradients disabled).
# `outputs` keeps the result of the last pass and is used by the cells below.
N_RUNS = 20
start = time()
with torch.no_grad():
    for _ in range(N_RUNS):
        outputs = model(**inputs)
elapsed = time() - start
# `.2f` (two decimals) -- the previous `:2f` was a *width* of 2, printing six decimals.
print(f"This took {elapsed:.2f}s ({elapsed / N_RUNS:.2f}s per forward pass)")

This took 55.719733s


In [12]:
# Final-layer hidden state for every token of every term, taken from the last timed pass.
# Layout: (terms in batch, padded sequence length, hidden units).
last_hidden_states = outputs["last_hidden_state"]
last_hidden_states.shape

torch.Size([77, 10, 768])

In [13]:
# Pool each term to one vector: keep the hidden state at token position 0, the [CLS]
# token, for every term in the batch.
features_cls = last_hidden_states[:, 0]

In [14]:
# Reduce the [CLS] vectors to 10 dimensions with PCA, then embed them in 2-D with t-SNE
# for plotting. Both steps are stochastic (PCA's 'auto' solver can pick the randomized
# SVD), so fix random_state to make re-runs reproduce the same layout.
EMBEDDING_SEED = 0

# sklearn works on NumPy arrays; convert explicitly (detached, on CPU) instead of
# relying on implicit tensor -> array conversion.
features_cls_np = features_cls.detach().cpu().numpy()

pca = PCA(n_components=10, random_state=EMBEDDING_SEED)
# (Misspelled name kept as-is so any later cell referring to it keeps working.)
pca_tranformed_features_cls = pca.fit_transform(features_cls_np)

# perplexity must stay below the number of samples; 2 keeps neighbourhoods very local.
# NOTE(review): the captured run spent ~55 s on the first 50 iterations for only 77
# points and was interrupted -- worth checking whether n_jobs=10 is causing contention.
tsne = TSNE(
    n_components=2,
    perplexity=2,
    verbose=10,
    n_jobs=10,
    random_state=EMBEDDING_SEED,
)
transformed_features_cls = tsne.fit_transform(pca_tranformed_features_cls)

[t-SNE] Computing 7 nearest neighbors...
[t-SNE] Indexed 77 samples in 0.095s...
[t-SNE] Computed neighbors for 77 samples in 0.015s...
[t-SNE] Computed conditional probabilities for sample 77 / 77
[t-SNE] Mean sigma: 0.661291
[t-SNE] Computed conditional probabilities in 0.002s
[t-SNE] Iteration 50: error = 76.6415405, gradient norm = 0.5136506 (50 iterations in 55.580s)


KeyboardInterrupt: 

In [None]:
# Plot the 2-D t-SNE embedding once, one scatter call per medical_terms category (so
# points are coloured by category), then label every point with its term.
# all_terms concatenates medical_terms in dict order, so each category occupies a
# contiguous run of rows in transformed_features_cls.
assert len(transformed_features_cls) == len(all_terms), "embedding rows must match terms"

fig, ax = plt.subplots(figsize=(12, 10))
row = 0
for category, terms in medical_terms.items():
    points = transformed_features_cls[row:row + len(terms)]
    ax.scatter(points[:, 0], points[:, 1], label=category)
    row += len(terms)

for (x, y), term in zip(transformed_features_cls, all_terms):
    ax.text(x, y, term, fontsize=8)

ax.set(
    title="Bio_ClinicalBERT [CLS] embeddings (PCA -> t-SNE)",
    xlabel="t-SNE 1",
    ylabel="t-SNE 2",
)
ax.legend(title="Category")
plt.show()

In [11]:
from dpat.mil.models.ccmil import CCMIL
from dpat.mil.models.varmil import VarAttention
from dpat.data import PMCHHGH5Dataset, PMCHHGH5DataModule

import torch


In [15]:
# --- Data ----------------------------------------------------------------------
# NOTE: absolute paths on the author's machine; change DATA_ROOT to run elsewhere.
DATA_ROOT = "/home/sdejong/pmchhg"
SPLIT_DIR = f"{DATA_ROOT}/images-tif/splits_with_locations"
SPLIT_STEM = "medulloblastoma+pilocytic-astrocytoma_pmchhg"
FOLD_TAG = "subfold-0-fold-0"
NUM_CLASSES = 2  # presumably medulloblastoma vs. pilocytic astrocytoma -- confirm

dm = PMCHHGH5DataModule(
    file_path=f"{DATA_ROOT}/features/imagenet-11-4-2023-fold-0.hdf5",
    train_path=f"{SPLIT_DIR}/{SPLIT_STEM}_train-{FOLD_TAG}.csv",
    val_path=f"{SPLIT_DIR}/{SPLIT_STEM}_val-{FOLD_TAG}.csv",
    test_path=f"{SPLIT_DIR}/{SPLIT_STEM}_test-{FOLD_TAG}.csv",
    clinical_context=False,
    num_classes=NUM_CLASSES,
)
dm.setup("fit")

# --- Models --------------------------------------------------------------------
# Both MIL models get identical hyper-parameters so their forward passes are comparable.
# NOTE(review): no trained MIL checkpoint is loaded here, so the weights appear to be a
# fresh initialisation; seed it so the printed outputs are reproducible across runs.
torch.manual_seed(0)
mil_kwargs = dict(
    in_features=1024,  # presumably the per-tile feature size stored in the HDF5 file -- confirm
    num_classes=NUM_CLASSES,
    T_max=1000,
    dropout=0.5,
    lr=0.0003,
    momentum=0.01,
    wd=0.01,
)
# A fresh `layers` list per model so neither can mutate the other's configuration.
# .eval() switches both to inference mode (e.g. dropout off) for the timing cell.
ccmil = CCMIL(layers=[2, 2, 2], **mil_kwargs).eval()
varmil = VarAttention(layers=[2, 2, 2], **mil_kwargs).eval()

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
# DataLoader over the training split prepared by dm.setup("fit"); only its first batch
# is used in the CCMIL/VarMIL timing comparison.
loader = dm.train_dataloader()

In [20]:
from time import perf_counter


def _timed(label, fn, *args):
    """Call ``fn(*args)`` once, then print ``label`` with the first output and the elapsed time.

    Only the call itself is timed (printing the result is excluded). Returns the full result.
    """
    t0 = perf_counter()
    result = fn(*args)
    elapsed = perf_counter() - t0
    print(label, result[0])
    print("t=", elapsed)
    return result


# Repeat a few rounds: the first round includes one-off warm-up cost, later rounds show
# steady-state forward-pass time on the same batch.
N_ROUNDS = 3
with torch.no_grad():
    batch = next(iter(loader))
    # NOTE(review): str() passes the *repr* of the whole batch["cc"] object to CCMIL;
    # confirm this is the text format CCMIL expects. Computed once since it never changes.
    cc_text = str(batch["cc"])
    for _ in range(N_ROUNDS):
        _timed("CCMIL", ccmil, batch["data"], cc_text)
        _timed("VarMIL", varmil, batch["data"])
        print("\n")

CCMIL tensor([[-0.2352, -0.1425]])
t= 2.5096795558929443
VarMIL tensor([[-0.2796,  0.5978]])
t= 0.807194709777832


CCMIL tensor([[-0.2352, -0.1425]])
t= 0.39732885360717773
VarMIL tensor([[-0.2796,  0.5978]])
t= 0.30567264556884766


CCMIL tensor([[-0.2352, -0.1425]])
t= 0.3991687297821045
VarMIL tensor([[-0.2796,  0.5978]])
t= 0.4811522960662842


