In [1]:
from pyhealth.datasets import UMLSDataset

# Build the UMLS knowledge-graph dataset from the hosted PyHealth files.
# NOTE(review): dev=True presumably loads a reduced subset and
# refresh_cache=True presumably forces re-processing on every run — confirm
# both before reusing this cell for a full experiment.
umls_ds = UMLSDataset(
    root="https://storage.googleapis.com/pyhealth/umls/", dev=True, refresh_cache=True
)

  from .autonotebook import tqdm as notebook_tqdm


INFO: Pandarallel will run on 64 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [2]:
# stat() prints the dataset summary itself and returns None, so it is called
# directly; wrapping it in print() would emit a stray "None" after the summary.
umls_ds.stat()

# Mapping from relation name to the integer id used inside the triples.
print("Relations in KG:", umls_ds.relation2id)


Statistics of base dataset (dev=True):
	- Dataset: UMLSDataset
	- Number of triples: 88176
	- Number of entities: 9737
	- Number of relations: 8
	- Task name: Null
	- Number of samples: 0

None
Relations in KG: {'may_be_treated_by': 0, 'may_be_prevented_by': 1, 'may_prevent': 2, 'may_treat': 3, 'gene_mapped_to_disease': 4, 'disease_mapped_to_gene': 5, 'disease_has_associated_gene': 6, 'gene_associated_with_disease': 7}


In [3]:
from pyhealth.tasks import link_prediction_fn

# Attach the link-prediction task to the dataset. `negative_sampling=256` is
# recorded as a task-specific hyperparameter (see the stat() output in the
# next cell). The name `umls_ds` is rebound because later cells rely on it.
# NOTE(review): save=False presumably skips persisting the task samples —
# confirm against the PyHealth docs.
umls_ds = umls_ds.set_task(
    link_prediction_fn,
    negative_sampling=256,
    save=False,
)

Processing UMLSDataset base dataset...


100%|██████████| 88176/88176 [00:01<00:00, 64441.63it/s]

Saving UMLSDataset base dataset to /home/pj20/.cache/pyhealth/datasets/46e7370273967c215741135e6ccdd2b9





In [4]:
# Print the dataset summary again now that a task is attached; the output
# additionally lists the task name and its task-specific hyperparameters.
umls_ds.stat()


Statistics of base dataset (dev=True):
	- Dataset: UMLSDataset
	- Number of triples: 88176
	- Number of entities: 9737
	- Number of relations: 8
	- Task name: link_prediction_fn
	- Task-specific hyperparameters: {'negative_sampling': 256}



In [5]:
# Inspect the first task sample. Per the printed output it is a dict with a
# 'triple' tuple, 'ground_truth_head' / 'ground_truth_tail' lists of entity
# ids, and a 'subsampling_weight' tensor.
# NOTE(review): the triple looks like (head, relation, tail) ids — confirm.
print(umls_ds[0])

{'triple': (0, 0, 2835), 'ground_truth_head': [1027, 1293, 5264, 1564, 7416, 6434, 2610, 4094, 2717, 5007, 5277, 5949, 0, 6870, 6029], 'ground_truth_tail': [398, 244, 3872, 3053, 1711, 2835, 1348, 2309], 'subsampling_weight': tensor([0.1857])}


In [6]:
from pyhealth.datasets import split, get_dataloader

# Partition the task samples into train / validation / test subsets using
# a 90% / 5% / 5% ratio list.
# NOTE(review): no seed is passed here, so the partition may differ between
# runs — confirm whether split() accepts one.
train_dataset, val_dataset, test_dataset = split(
    umls_ds,
    [0.9, 0.05, 0.05],
)

In [7]:
# Training batches are large and shuffled.
train_loader = get_dataloader(
    train_dataset,
    batch_size=256,
    shuffle=True,
)

# Validation and test batches are small and keep a fixed order.
val_loader = get_dataloader(
    val_dataset,
    batch_size=16,
    shuffle=False,
)
test_loader = get_dataloader(
    test_dataset,
    batch_size=16,
    shuffle=False,
)

In [9]:
from pyhealth.models.kg import TransE, RotatE

# Instantiate the RotatE knowledge-graph embedding model on the task dataset.
# NOTE(review): e_dim is twice r_dim, presumably because entities carry
# real and imaginary halves while relations are phases — confirm against the
# PyHealth RotatE implementation before changing either value.
model = RotatE(dataset=umls_ds, e_dim=600, r_dim=300)

In [10]:
import torch

from pyhealth.trainer import Trainer

# Use the GPU when one is visible and fall back to CPU otherwise, so the cell
# still runs on machines without CUDA instead of failing at device placement.
trainer = Trainer(
    model=model,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    metrics=['hits@n', 'mean_rank'],
)

# Train for 10 epochs of 100 optimizer steps each with Adam at lr=1e-3.
# The best checkpoint is tracked by validation mean_rank, where lower is
# better (hence monitor_criterion='min').
# NOTE(review): evaluation_steps=10 is passed through unchanged; its exact
# meaning is not visible here — confirm against the Trainer documentation.
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=10,
    steps_per_epoch=100,
    evaluation_steps=10,
    optimizer_params={'lr': 1e-3},
    monitor='mean_rank',
    monitor_criterion='min'
)

RotatE()
Metrics: ['hits@n', 'mean_rank']
Device: cuda

Training:
Batch size: 256
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7f0a72ccefd0>
Monitor: mean_rank
Monitor criterion: min
Epochs: 10

Epoch 0 / 10: 100%|██████████| 100/100 [00:04<00:00, 20.86it/s]
--- Train epoch-0, step-100 ---
loss: 2.3303
Evaluation: 100%|██████████| 276/276 [00:03<00:00, 75.06it/s]
--- Eval epoch-0, step-100 ---
HITS@1: 0.0024
HITS@5: 0.0119
HITS@10: 0.0221
mean_rank: 2318.6816
mean_reciprocal_rank: 0.0116
loss: 0.1119
New best mean_rank score (2318.6816) at epoch-0, step-100

Epoch 1 / 10: 100%|██████████| 100/100 [00:05<00:00, 19.78it/s]
--- Train epoch-1, step-200 ---
loss: 0.8013

Epoch 2 / 10: 100%|██████████| 100/100 [00:05<00:00, 19.47it/s]
--- Train epoch-2, step-300 ---
loss: 0.6174

Epoch 3 / 10: 100%|██████████| 100/100 [00:05<00:00, 19.90it/s]
--- Train epoch

In [11]:
# Score the trained model on the held-out test loader. The call returns a
# dict of metrics (HITS@1/5/10, mean_rank, mean_reciprocal_rank, loss), which
# the notebook displays as the cell output.
trainer.evaluate(test_loader)

Evaluation: 100%|██████████| 276/276 [00:03<00:00, 69.59it/s]


{'HITS@1': 0.9324109775459288,
 'HITS@5': 0.9621229303696983,
 'HITS@10': 0.9675663415740531,
 'mean_rank': 45.8667498298934,
 'mean_reciprocal_rank': 0.9459430019329782,
 'loss': 0.007553577323552167}