Bug in GAMNET #197

Closed
bbb801 opened this issue Jul 24, 2023 · 1 comment

bbb801 commented Jul 24, 2023

Dear Sir/Madam,

When I run 'drug_recommendation_mimic4_gamenet.py' from the examples directory, I get the following error:

Epoch 0 / 20:   0%|          | 0/2 [00:00<?, ?it/s]
queries shape torch.Size([64, 10, 128])
prev_drugs shape torch.Size([64, 10, 147])
curr_drugs shape torch.Size([64, 147])
a_s shape torch.Size([64, 9])
DM_values shape torch.Size([64, 10, 147])
Epoch 0 / 20:   0%|          | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/featurize/PyHealth/examples/drug_recommendation_mimic4_gamenet.py", line 103, in <module>
    model, trainer = train_gamenet(data, train_loader, val_loader)
  File "/home/featurize/PyHealth/examples/drug_recommendation_mimic4_gamenet.py", line 77, in train_gamenet
    trainer.train(
  File "/home/featurize/work/py38/lib/python3.8/site-packages/pyhealth/trainer.py", line 195, in train
    output = self.model(**data)
  File "/home/featurize/work/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/featurize/work/py38/lib/python3.8/site-packages/pyhealth/models/gamenet.py", line 410, in forward
    loss, y_prob = self.gamenet(queries, prev_drugs, curr_drugs, mask)
  File "/home/featurize/work/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/featurize/work/py38/lib/python3.8/site-packages/pyhealth/models/gamenet.py", line 211, in forward
    a_m = torch.einsum("bv,bvz->bz", a_s, DM_values.float())
  File "/home/featurize/work/py38/lib/python3.8/site-packages/torch/functional.py", line 378, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): subscript v has size 10 for operand 1 which does not broadcast with previously seen size 9

It seems this can be fixed as shown in the attached screenshot. Please have a look.
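A minimal reproduction of the failing einsum, using random tensors with the shapes printed in the log (an illustration, not PyHealth code; it only assumes PyTorch is installed). The attention weights a_s cover 9 past visits while DM_values covers all 10, so the shared subscript v disagrees. Dropping the last visit from the values makes both v axes size 9.

```python
import torch

# Illustrative random tensors with the shapes printed in the log:
# batch of 64 patients, 10 visits, 147 drug labels. Values are made up.
batch, visits, num_drugs = 64, 10, 147

# Attention weights: one weight per memory key, i.e. per past visit (9 of them).
a_s = torch.softmax(torch.randn(batch, visits - 1), dim=-1)  # (64, 9)
# Memory values built from all 10 visits, as in the failing version.
DM_values = torch.randint(0, 2, (batch, visits, num_drugs))  # (64, 10, 147)

try:
    torch.einsum("bv,bvz->bz", a_s, DM_values.float())
except RuntimeError as err:
    # Subscript "v" is 9 in a_s but 10 in DM_values, so einsum raises.
    print("einsum failed:", err)

# Dropping the last visit from the values makes both "v" axes size 9.
a_m = torch.einsum("bv,bvz->bz", a_s, DM_values[:, :-1, :].float())
print(a_m.shape)  # torch.Size([64, 147])
```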

ycq091044 (Collaborator) commented Aug 31, 2023

Hi @bbb801

Could you please change the dynamic-memory lines to:

        # dynamic memory
        DM_keys = queries[:, :-1, :]
        DM_values = prev_drugs[:, :-1, :]

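This should make the code run. These two lines mean that the dynamic memory uses visits 1 to N-1 (every visit except the current one), so both DM_keys and DM_values should be sliced with :-1. After the change, DM_values has shape [64, 9, 147], which matches the 9 attention weights in a_s.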
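To make the shape bookkeeping concrete, here is a simplified sketch of a GAMENet-style dynamic-memory read with both slices applied. It is not PyHealth's actual GAMENetLayer.forward: it assumes the last visit's query attends over the earlier visits' keys, uses random tensors with the sizes from the log, and omits the visit mask that PyHealth passes in.

```python
import torch

# Hypothetical sizes matching the log: 64 patients, 10 visits,
# 128-dimensional visit queries, 147 drug labels. Values are random.
B, V, D, Z = 64, 10, 128, 147
queries = torch.randn(B, V, D)
prev_drugs = torch.randint(0, 2, (B, V, Z)).float()

# Dynamic memory restricted to visits 1..N-1, as in the suggested fix.
DM_keys = queries[:, :-1, :]       # (B, V-1, D)
DM_values = prev_drugs[:, :-1, :]  # (B, V-1, Z)

# Simplified read (no padding mask): the last visit's query attends over past keys.
query = queries[:, -1, :]                                                # (B, D)
a_s = torch.softmax(torch.einsum("bd,bvd->bv", query, DM_keys), dim=-1)  # (B, V-1)
# Attention-weighted sum of past drug vectors; "v" is now V-1 on both sides.
a_m = torch.einsum("bv,bvz->bz", a_s, DM_values)                         # (B, Z)
print(a_s.shape, a_m.shape)  # torch.Size([64, 9]) torch.Size([64, 147])
```

Because keys and values are sliced identically, the number of attention weights always equals the number of memory slots, whatever the visit count is.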

The change above is in the GAMENetLayer class. We have also improved the GAMENet class by replacing the target-label calculation block:

        # previous block
        curr_drugs = [p[-1] for p in drugs_hist]
        curr_drugs = batch_to_multihot(curr_drugs, label_size)
        curr_drugs = curr_drugs.to(self.device)

        # new block
        curr_drugs = self.prepare_labels(drugs, self.label_tokenizer)

Please refer to the current master branch.
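For readers unfamiliar with the previous target block: it takes each patient's last-visit drug indices and turns them into a multi-hot vector of length label_size. The helper below, to_multihot, is a hypothetical stand-in written for illustration, not PyHealth's batch_to_multihot or prepare_labels; it assumes drugs are already encoded as integer indices below num_labels.

```python
from typing import List

import torch


def to_multihot(batch: List[List[int]], num_labels: int) -> torch.Tensor:
    """Hypothetical helper: row i has 1.0 at every label index listed for sample i."""
    out = torch.zeros(len(batch), num_labels)
    for row, indices in enumerate(batch):
        # An empty index list is a no-op, leaving that row all zeros.
        out[row, torch.tensor(indices, dtype=torch.long)] = 1.0
    return out


# Made-up per-visit drug indices for two patients (the first has two visits).
drugs_hist = [[[0, 3], [1, 2]], [[4]]]
curr_drugs = [p[-1] for p in drugs_hist]  # last visit only: [[1, 2], [4]]
print(to_multihot(curr_drugs, 5).tolist())
# [[0.0, 1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]]
```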

ycq091044 added a commit that referenced this issue Aug 31, 2023
ycq091044 added a commit that referenced this issue Aug 31, 2023