Skip to content

Commit

Permalink
Custom from_smiles function in PCQM4Mv2 and MoleculeNet (#9073)
Browse files Browse the repository at this point in the history
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
Kh4L and rusty1s committed Mar 24, 2024
1 parent f0e4c82 commit b678124
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added option to pass custom` from_smiles` functionality to `PCQM4Mv2` and `MoleculeNet` ([#9073](https://github.com/pyg-team/pytorch_geometric/pull/9073))
- Added `group_cat` functionality ([#9029](https://github.com/pyg-team/pytorch_geometric/pull/9029))
- Added support for `EdgeIndex` in `spmm` ([#9026](https://github.com/pyg-team/pytorch_geometric/pull/9026))
- Added option to pre-allocate memory in GPU-based `ApproxKNN` ([#9046](https://github.com/pyg-team/pytorch_geometric/pull/9046))
Expand Down
12 changes: 6 additions & 6 deletions test/nn/pool/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
L2KNNIndex,
MIPSKNNIndex,
)
from torch_geometric.testing import withDevice, withPackage
from torch_geometric.testing import withCUDA, withPackage


@withDevice
@withCUDA
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
def test_l2(device, k):
Expand All @@ -34,7 +34,7 @@ def test_l2(device, k):
assert torch.equal(out.index, index[:, :k])


@withDevice
@withCUDA
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
def test_mips(device, k):
Expand All @@ -58,7 +58,7 @@ def test_mips(device, k):
assert torch.equal(out.index, index[:, :k])


@withDevice
@withCUDA
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
@pytest.mark.parametrize('reserve', [None, 100])
Expand All @@ -82,7 +82,7 @@ def test_approx_l2(device, k, reserve):
assert out.index.min() >= 0 and out.index.max() < 10_000


@withDevice
@withCUDA
@withPackage('faiss')
@pytest.mark.parametrize('k', [2])
@pytest.mark.parametrize('reserve', [None, 100])
Expand All @@ -106,7 +106,7 @@ def test_approx_mips(device, k, reserve):
assert out.index.min() >= 0 and out.index.max() < 10_000


@withDevice
@withCUDA
@withPackage('faiss')
@pytest.mark.parametrize('k', [50])
def test_mips_exclude(device, k):
Expand Down
10 changes: 8 additions & 2 deletions torch_geometric/datasets/molecule_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch

from torch_geometric.data import InMemoryDataset, download_url, extract_gz
from torch_geometric.utils import from_smiles
from torch_geometric.utils import from_smiles as _from_smiles


class MoleculeNet(InMemoryDataset):
Expand Down Expand Up @@ -38,6 +38,10 @@ class MoleculeNet(InMemoryDataset):
final dataset. (default: :obj:`None`)
force_reload (bool, optional): Whether to re-process the dataset.
(default: :obj:`False`)
from_smiles (callable, optional): A custom function that takes a SMILES
string and outputs a :obj:`~torch_geometric.data.Data` object.
If not set, defaults to :meth:`~torch_geometric.utils.from_smiles`.
(default: :obj:`None`)
**STATS:**
Expand Down Expand Up @@ -152,9 +156,11 @@ def __init__(
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
force_reload: bool = False,
from_smiles: Optional[Callable] = None,
) -> None:
self.name = name.lower()
assert self.name in self.names.keys()
self.from_smiles = from_smiles or _from_smiles
super().__init__(root, transform, pre_transform, pre_filter,
force_reload=force_reload)
self.load(self.processed_paths[0])
Expand Down Expand Up @@ -199,7 +205,7 @@ def process(self) -> None:
ys = [float(y) if len(y) > 0 else float('NaN') for y in labels]
y = torch.tensor(ys, dtype=torch.float).view(1, -1)

data = from_smiles(smiles)
data = self.from_smiles(smiles)
data.y = y

if self.pre_filter is not None and not self.pre_filter(data):
Expand Down
10 changes: 8 additions & 2 deletions torch_geometric/datasets/pcqm4m.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from torch_geometric.data import Data, OnDiskDataset, download_url, extract_zip
from torch_geometric.data.data import BaseData
from torch_geometric.utils import from_smiles
from torch_geometric.utils import from_smiles as _from_smiles


class PCQM4Mv2(OnDiskDataset):
Expand Down Expand Up @@ -36,6 +36,10 @@ class PCQM4Mv2(OnDiskDataset):
(default: :obj:`None`)
backend (str): The :class:`Database` backend to use.
(default: :obj:`"sqlite"`)
from_smiles (callable, optional): A custom function that takes a SMILES
string and outputs a :obj:`~torch_geometric.data.Data` object.
If not set, defaults to :meth:`~torch_geometric.utils.from_smiles`.
(default: :obj:`None`)
"""
url = ('https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/'
'pcqm4m-v2.zip')
Expand All @@ -53,6 +57,7 @@ def __init__(
split: str = 'train',
transform: Optional[Callable] = None,
backend: str = 'sqlite',
from_smiles: Optional[Callable] = None,
) -> None:
assert split in ['train', 'val', 'test', 'holdout']

Expand All @@ -64,6 +69,7 @@ def __init__(
'y': float,
}

self.from_smiles = from_smiles or _from_smiles
super().__init__(root, transform, backend=backend, schema=schema)

split_idx = torch.load(self.raw_paths[1])
Expand All @@ -89,7 +95,7 @@ def process(self) -> None:
data_list: List[Data] = []
iterator = enumerate(zip(df['smiles'], df['homolumogap']))
for i, (smiles, y) in tqdm(iterator, total=len(df)):
data = from_smiles(smiles)
data = self.from_smiles(smiles)
data.y = y

data_list.append(data)
Expand Down

0 comments on commit b678124

Please sign in to comment.