Skip to content

Commit

Permalink
Fix metaclass conflict in Dataset (#8999)
Browse files Browse the repository at this point in the history
Fixes #8992
  • Loading branch information
rusty1s committed Mar 1, 2024
1 parent 9b660ac commit 8d625c4
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed metaclass conflict in `Dataset` ([#8999](https://github.com/pyg-team/pytorch_geometric/pull/8999))
- Fixed import errors on `MessagePassing` modules with nested inheritance ([#8973](https://github.com/pyg-team/pytorch_geometric/pull/8973))

### Removed
Expand Down
5 changes: 1 addition & 4 deletions torch_geometric/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import re
import sys
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import (
Any,
Expand All @@ -27,7 +26,7 @@
MISSING = '???'


class Dataset(torch.utils.data.Dataset, ABC):
class Dataset(torch.utils.data.Dataset):
r"""Dataset base class for creating graph datasets.
See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/
create_dataset.html>`__ for the accompanying tutorial.
Expand Down Expand Up @@ -79,12 +78,10 @@ def process(self) -> None:
r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
raise NotImplementedError

@abstractmethod
def len(self) -> int:
r"""Returns the number of data objects stored in the dataset."""
raise NotImplementedError

@abstractmethod
def get(self, idx: int) -> BaseData:
r"""Gets the data object at index :obj:`idx`."""
raise NotImplementedError
Expand Down
3 changes: 1 addition & 2 deletions torch_geometric/data/in_memory_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import copy
import os.path as osp
import warnings
from abc import ABC
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -30,7 +29,7 @@
from torch_geometric.io import fs


class InMemoryDataset(Dataset, ABC):
class InMemoryDataset(Dataset):
r"""Dataset base class for creating graph datasets which easily fit
into CPU memory.
See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/
Expand Down
4 changes: 1 addition & 3 deletions torch_geometric/metrics/link_pred.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Union

import torch
Expand All @@ -15,7 +14,7 @@
BaseMetric = torch.nn.Module # type: ignore


class LinkPredMetric(BaseMetric, ABC):
class LinkPredMetric(BaseMetric):
r"""An abstract class for computing link prediction retrieval metrics.
Args:
Expand Down Expand Up @@ -117,7 +116,6 @@ def reset(self) -> None:
self.accum.zero_()
self.total.zero_()

@abstractmethod
def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
r"""Compute the specific metric.
To be implemented separately for each metric class.
Expand Down
9 changes: 3 additions & 6 deletions torch_geometric/nn/models/rev_gnn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Union

import numpy as np
Expand Down Expand Up @@ -145,7 +144,7 @@ def backward(ctx, *grad_outputs):
return (None, None, None, None) + gradients


class InvertibleModule(torch.nn.Module, ABC):
class InvertibleModule(torch.nn.Module):
r"""An abstract class for implementing invertible modules.
Args:
Expand All @@ -168,13 +167,11 @@ def forward(self, *args):
def inverse(self, *args):
return self._fn_apply(args, self._inverse, self._forward)

@abstractmethod
def _forward(self):
pass
raise NotImplementedError

@abstractmethod
def _inverse(self):
pass
raise NotImplementedError

def _fn_apply(self, args, fn, fn_inverse):
if not self.disable:
Expand Down

0 comments on commit 8d625c4

Please sign in to comment.