Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] authored and Kh4L committed Mar 30, 2024
1 parent eaab070 commit b5dae1a
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions examples/multi_gpu/pcqm4m_ogb.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@
print("`PygPCQM4Mv2Dataset` requires rdkit (`pip install rdkit`)")
raise error

from torch_geometric.datasets import PCQM4Mv2
from ogb.utils import smiles2graph

from torch_geometric.data import Data
from torch_geometric.datasets import PCQM4Mv2

from ogb.utils import smiles2graph

# we need this wrapper as ogb `smiles2graph` returns a dict of np arrays instead of a torch_geometric Data
# we need this wrapper as ogb `smiles2graph` returns a dict of np arrays instead of a torch_geometric Data
def ogb_from_smiles_wrapper(smiles, *args, **kwargs):
ret_dict = smiles2graph(smiles, *args, **kwargs)
return Data(x=torch.from_numpy(ret_dict['node_feat']),
Expand Down Expand Up @@ -459,10 +460,14 @@ def run(rank, dataset, args):

if rank == 0:
if args.on_disk_dataset:
valid_dataset = PCQM4Mv2(root='on_disk_dataset/', split="val", from_smiles_func=ogb_from_smiles_wrapper)
test_dev_dataset = PCQM4Mv2(root='on_disk_dataset/', split="test", from_smiles_func=ogb_from_smiles_wrapper)
test_challenge_dataset = PCQM4Mv2(root='on_disk_dataset/',
split="holdout", from_smiles_func=ogb_from_smiles_wrapper)
valid_dataset = PCQM4Mv2(root='on_disk_dataset/', split="val",
from_smiles_func=ogb_from_smiles_wrapper)
test_dev_dataset = PCQM4Mv2(
root='on_disk_dataset/', split="test",
from_smiles_func=ogb_from_smiles_wrapper)
test_challenge_dataset = PCQM4Mv2(
root='on_disk_dataset/', split="holdout",
from_smiles_func=ogb_from_smiles_wrapper)
else:
valid_dataset = dataset[split_idx["valid"]]
test_dev_dataset = dataset[split_idx["test-dev"]]
Expand Down Expand Up @@ -652,7 +657,8 @@ def run(rank, dataset, args):

### automatic dataloading and splitting
if args.on_disk_dataset:
dataset = PCQM4Mv2(root='on_disk_dataset/', split='train', from_smiles_func=ogb_from_smiles_wrapper)
dataset = PCQM4Mv2(root='on_disk_dataset/', split='train',
from_smiles_func=ogb_from_smiles_wrapper)
else:
dataset = PygPCQM4Mv2Dataset(root='dataset/')

Expand Down

0 comments on commit b5dae1a

Please sign in to comment.