Skip to content

Commit

Permalink
Use ogb smiles2graph
Browse files Browse the repository at this point in the history
  • Loading branch information
Kh4L committed Mar 30, 2024
1 parent 9f021b2 commit eaab070
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions examples/multi_gpu/pcqm4m_ogb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
global_add_pool,
global_mean_pool,
)
from torch_geometric.utils import degree
from torch_geometric.utils import degree, from_smiles

### importing OGB-LSC
try:
Expand All @@ -31,6 +31,17 @@
raise error

from torch_geometric.datasets import PCQM4Mv2
from torch_geometric.data import Data

from ogb.utils import smiles2graph

# we need this wrapper as ogb `smiles2graph` returns a dict of np arrays instead of a torch_geometric Data
def ogb_from_smiles_wrapper(smiles, *args, **kwargs):
ret_dict = smiles2graph(smiles, *args, **kwargs)
return Data(x=torch.from_numpy(ret_dict['node_feat']),
edge_index=torch.from_numpy(ret_dict['edge_index']),
edge_attr=torch.from_numpy(ret_dict['edge_feat']),
smiles=smiles)


### GIN convolution along the graph structure
Expand Down Expand Up @@ -448,10 +459,10 @@ def run(rank, dataset, args):

if rank == 0:
if args.on_disk_dataset:
valid_dataset = PCQM4Mv2(root='on_disk_dataset/', split="val")
test_dev_dataset = PCQM4Mv2(root='on_disk_dataset/', split="test")
valid_dataset = PCQM4Mv2(root='on_disk_dataset/', split="val", from_smiles_func=ogb_from_smiles_wrapper)
test_dev_dataset = PCQM4Mv2(root='on_disk_dataset/', split="test", from_smiles_func=ogb_from_smiles_wrapper)
test_challenge_dataset = PCQM4Mv2(root='on_disk_dataset/',
split="holdout")
split="holdout", from_smiles_func=ogb_from_smiles_wrapper)
else:
valid_dataset = dataset[split_idx["valid"]]
test_dev_dataset = dataset[split_idx["test-dev"]]
Expand Down Expand Up @@ -641,7 +652,7 @@ def run(rank, dataset, args):

### automatic dataloading and splitting
if args.on_disk_dataset:
dataset = PCQM4Mv2(root='on_disk_dataset/', split='train')
dataset = PCQM4Mv2(root='on_disk_dataset/', split='train', from_smiles_func=ogb_from_smiles_wrapper)
else:
dataset = PygPCQM4Mv2Dataset(root='dataset/')

Expand Down

0 comments on commit eaab070

Please sign in to comment.