In [1]:
from pyhealth.datasets import UMLSDataset

# Remote location the UMLS knowledge-graph files are read from.
UMLS_ROOT = "https://storage.googleapis.com/pyhealth/umls/"

# Build the base dataset. `dev` and `refresh_cache` are both switched on;
# NOTE(review): refresh_cache=True presumably forces re-processing on every
# run instead of reusing a cached copy -- confirm before relying on it.
umls_ds = UMLSDataset(root=UMLS_ROOT, dev=True, refresh_cache=True)

  from .autonotebook import tqdm as notebook_tqdm


INFO: Pandarallel will run on 64 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [2]:
# Print summary statistics of the base dataset before any task is attached
# (the recorded output shows "Task name: Null" and 0 samples at this point).
umls_ds.stat()


Statistics of base dataset (dev=True):
	- Dataset: UMLSDataset
	- Number of triples: 88176
	- Number of entities: 9737
	- Number of relations: 8
	- Task name: Null
	- Number of samples: 0



In [3]:
from pyhealth.tasks import link_prediction_fn

# Attach the link-prediction task to the dataset, requesting 256 negative
# samples and passing save=False.
# NOTE(review): the result is bound back to the same name, so re-running this
# cell calls set_task on the already-processed dataset -- confirm that is safe.
umls_ds = umls_ds.set_task(
    link_prediction_fn,
    negative_sampling=256,
    save=False,
)

Processing UMLSDataset base dataset...


100%|██████████| 88176/88176 [00:32<00:00, 2702.37it/s]

Saving UMLSDataset base dataset to /home/pj20/.cache/pyhealth/datasets/46e7370273967c215741135e6ccdd2b9





In [4]:
# Print the statistics again now that the task is set (the recorded output
# shows "Task name: link_prediction_fn" and 176352 samples).
umls_ds.stat()


Statistics of base dataset (dev=True):
	- Dataset: UMLSDataset
	- Number of triples: 88176
	- Number of entities: 9737
	- Number of relations: 8
	- Task name: link_prediction_fn
	- Number of samples: 176352



In [5]:
from pyhealth.datasets import split_by_keys, get_dataloader_kg

# Partition the task dataset with ratios 0.9 / 0.05 / 0.05 into the
# train / validation / test subsets used by the loaders below.
# NOTE(review): no seed is passed; if split_by_keys shuffles, the split may
# differ between runs -- confirm whether a seed argument is available.
train_dataset, val_dataset, test_dataset = split_by_keys(
    umls_ds,
    [0.9, 0.05, 0.05],
)

In [6]:
# Training loader: batches of 256, shuffled, built with train=True.
train_loader = get_dataloader_kg(
    train_dataset,
    batch_size=256,
    train=True,
    shuffle=True,
)

# Validation and test loaders share one configuration: batches of 16,
# unshuffled, built with train=False.
eval_loader_kwargs = dict(batch_size=16, train=False, shuffle=False)
val_loader = get_dataloader_kg(val_dataset, **eval_loader_kwargs)
test_loader = get_dataloader_kg(test_dataset, **eval_loader_kwargs)

In [7]:
from pyhealth.models.kg import TransE, RotatE

model = RotatE(
    dataset=umls_ds
)


def preview_batch_shapes(model, loader, extra_key, device="cpu"):
    """Print the head/relation/tail tensor shapes for the first batch of ``loader``.

    This is a sanity check on the batch layout before training: the first
    batch is pushed through ``model.data_process`` and only the resulting
    shapes are printed.

    Args:
        model: Object exposing ``data_process(sample_batch, mode)`` that
            returns a ``(head, relation, tail)`` triple.
        loader: Iterable of batch dicts containing ``'positive_sample'``,
            ``'negative_sample'``, ``'mode'`` and ``extra_key``.
        extra_key: Name of the third per-batch tensor
            (``'subsample_weight'`` for the training loader,
            ``'filter_bias'`` for the validation loader). It is moved to
            ``device`` together with the samples but is not needed for the
            shape check itself.
        device: Device the batch tensors are moved to before processing.

    Returns:
        None. Nothing is printed when ``loader`` yields no batches.
    """
    for batch in loader:
        mode = batch['mode']
        pos_sample, neg_sample, _ = (
            batch[key].to(device)
            for key in ('positive_sample', 'negative_sample', extra_key)
        )
        head, relation, tail = model.data_process((pos_sample, neg_sample), mode)
        print(head.shape, relation.shape, tail.shape)
        break  # only the first batch is needed to inspect shapes


# Training batches carry a subsampling weight, evaluation batches a filter
# bias; apart from that key the two checks are identical.
preview_batch_shapes(model, train_loader, extra_key='subsample_weight')
preview_batch_shapes(model, val_loader, extra_key='filter_bias')


torch.Size([256, 1, 600]) torch.Size([256, 1, 300]) torch.Size([256, 256, 600])
torch.Size([16, 1, 600]) torch.Size([16, 1, 300]) torch.Size([16, 9737, 600])


In [11]:
from pyhealth.models.kg import TransE, RotatE

# Construct a new RotatE instance so that training starts from a newly built
# model rather than the instance used for the batch-shape check earlier.
model = RotatE(dataset=umls_ds)

In [None]:
import torch

from pyhealth.trainer import Trainer

# Prefer the GPU when one is present, but fall back to the CPU so this cell
# still runs on machines without CUDA instead of failing at device placement.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

trainer = Trainer(model=model, device=device, metrics=['hits@n', 'mean_rank'])

# Train for 20 epochs with Adam-style learning rate 1e-3 passed through
# optimizer_params. Model selection monitors validation 'mean_rank' with the
# 'min' criterion, i.e. a lower mean rank is treated as better.
# NOTE(review): steps_per_epoch=100 and evaluation_steps=10 presumably cap the
# number of optimisation steps per epoch and set the validation interval --
# confirm against the Trainer documentation.
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=20,
    steps_per_epoch=100,
    evaluation_steps=10,
    optimizer_params={'lr': 1e-3},
    monitor='mean_rank',
    monitor_criterion='min'
)

In [13]:
# Evaluate on the held-out test loader; the returned metrics dict (HITS@1/5/10,
# mean_rank, mean_reciprocal_rank and loss in the recorded run) is displayed
# as the cell output because it is the last expression.
trainer.evaluate(test_loader)

Evaluation: 100%|██████████| 552/552 [00:03<00:00, 156.28it/s]


{'HITS@1': 0.2869131322295305,
 'HITS@5': 0.5965071444772058,
 'HITS@10': 0.7090043093672035,
 'mean_rank': 20.68314810614652,
 'mean_reciprocal_rank': 0.4259441186798611,
 'loss': 0.038708909117765186}