In [1]:
from pyhealth.datasets import UMLSDataset

# Remote location the UMLS knowledge-graph files are read from.
UMLS_ROOT = "https://storage.googleapis.com/pyhealth/umls/"

# Build the base dataset in dev mode.
# NOTE(review): refresh_cache=True presumably forces re-processing instead of
# reusing a cached copy - confirm against the UMLSDataset documentation.
umls_ds = UMLSDataset(root=UMLS_ROOT, refresh_cache=True, dev=True)

  from .autonotebook import tqdm as notebook_tqdm


INFO: Pandarallel will run on 64 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [2]:
# stat() prints the dataset summary itself and returns None, so it is called
# directly; wrapping it in print() would emit a stray "None" line.
umls_ds.stat()

# Mapping from each relation name in the knowledge graph to its integer id.
print("Relations in KG:", umls_ds.relation2id)


Statistics of base dataset (dev=True):
	- Dataset: UMLSDataset
	- Number of triples: 88176
	- Number of entities: 9737
	- Number of relations: 8
	- Task name: Null
	- Number of samples: 0

None
Relations in KG: {'may_be_treated_by': 0, 'may_be_prevented_by': 1, 'may_prevent': 2, 'may_treat': 3, 'gene_mapped_to_disease': 4, 'disease_mapped_to_gene': 5, 'disease_has_associated_gene': 6, 'gene_associated_with_disease': 7}


In [3]:
from pyhealth.tasks import link_prediction_fn

# Attach the link-prediction task to the dataset. The returned dataset is
# rebound to the same name, so later cells see the task-specific samples.
umls_ds = umls_ds.set_task(
    link_prediction_fn,
    negative_sampling=256,
    save=False,
)

Processing UMLSDataset base dataset...


100%|██████████| 88176/88176 [00:01<00:00, 68839.24it/s]

Saving UMLSDataset base dataset to /home/pj20/.cache/pyhealth/datasets/46e7370273967c215741135e6ccdd2b9





In [4]:
# Print the dataset summary again, now that umls_ds has been rebound to the
# result of set_task().
umls_ds.stat()


Statistics of base dataset (dev=True):
	- Dataset: UMLSDataset
	- Number of triples: 88176
	- Number of entities: 9737
	- Number of relations: 8
	- Task name: link_prediction_fn
	- Task-specific hyperparameters: {'negative_sampling': 256}



In [17]:
# Inspect the first task sample to see which fields each sample carries.
first_sample = umls_ds[0]
print(first_sample)

{'triple': (0, 0, 2835), 'ground_truth_head': [1027, 1293, 5264, 1564, 7416, 6434, 2610, 4094, 2717, 5007, 5277, 5949, 0, 6870, 6029], 'ground_truth_tail': [398, 244, 3872, 3053, 1711, 2835, 1348, 2309], 'subsampling_weight': tensor([0.1857])}


In [6]:
from pyhealth.datasets import get_dataloader, split

# Fractions handed to split(), in train / validation / test order.
split_ratios = [0.9, 0.05, 0.05]
train_dataset, val_dataset, test_dataset = split(umls_ds, split_ratios)

In [7]:
# Training uses large shuffled batches; validation and test use small
# batches in a fixed order.
train_loader = get_dataloader(
    train_dataset,
    batch_size=256,
    shuffle=True,
)
val_loader = get_dataloader(
    val_dataset,
    batch_size=16,
    shuffle=False,
)
test_loader = get_dataloader(
    test_dataset,
    batch_size=16,
    shuffle=False,
)

In [10]:
# Only RotatE is instantiated below; the other imported model classes are not
# used in this cell.
from pyhealth.models.kg import (
    ComplEx,
    DistMult,
    RotatE,
    TransE,
)

# NOTE(review): e_dim / r_dim are presumably the entity and relation embedding
# sizes - confirm against the pyhealth.models.kg documentation.
model = RotatE(dataset=umls_ds, e_dim=600, r_dim=300)

In [None]:
import torch

from pyhealth.trainer import Trainer

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

trainer = Trainer(model=model, device=device, metrics=['hits@n', 'mean_rank'])

# Train on the training loader while evaluating on the validation loader.
# The monitored metric is mean_rank, where lower is better ('min').
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=10,
    steps_per_epoch=100,
    evaluation_steps=10,
    optimizer_params={'lr': 1e-3},
    monitor='mean_rank',
    monitor_criterion='min'
)

In [12]:
# Score the trained model on the held-out test loader and display the
# resulting metrics as the cell output.
test_metrics = trainer.evaluate(test_loader)
test_metrics

Evaluation: 100%|██████████| 276/276 [00:04<00:00, 68.93it/s]


{'HITS@1': 0.9292356543433885,
 'HITS@5': 0.9605352687684282,
 'HITS@10': 0.9671127239736902,
 'mean_rank': 45.18371512814697,
 'mean_reciprocal_rank': 0.9439417134301843,
 'loss': 0.008202967498842896}