Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into test
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthias Barth committed Feb 25, 2020
2 parents 0c4aaf0 + 35003e2 commit a72fd98
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 84 deletions.
53 changes: 24 additions & 29 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,39 +1,34 @@
jobs:
include:
- os: linux
language: python
python: 3.7
addons:
apt:
sources:
- ubuntu-toolchain-r-test
packages:
- gcc-5
- g++-5
env:
- CC=gcc-5
- CXX=g++-5
language: shell

os:
- linux
- osx
- windows

env:
jobs:
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cpu

install:
- pip install numpy
- pip install torch==1.4.0+cpu torchvision==0.5.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
- pip install torch-scatter==latest+cpu -f https://s3.eu-central-1.amazonaws.com/pytorch-geometric.com/whl/torch-1.4.0.html
- pip install torch-sparse==latest+cpu -f https://s3.eu-central-1.amazonaws.com/pytorch-geometric.com/whl/torch-1.4.0.html
- source script/conda.sh
- conda create --yes -n test python="${PYTHON_VERSION}"
- source activate test
- conda install pytorch=${TORCH_VERSION} torchvision ${TOOLKIT} -c pytorch --yes
- pip install torch-scatter==latest+${IDX} -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}.html
- pip install torch-sparse==latest+${IDX} -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}.html
- pip install torch-cluster
- pip install torch-spline-conv
- pip install cython && pip install gdist
- pip install cython
- git clone https://github.com/the-virtual-brain/tvb-geodesic /tmp/gdist
- cd /tmp/gdist && python setup.py install && cd -
- pip install trimesh
- pip install pycodestyle
- pip install flake8
- pip install codecov
- pip install sphinx
- pip install sphinx_rtd_theme
- pip install flake8 codecov
- pip install sphinx sphinx_rtd_theme
- python setup.py install
script:
- python -c "import torch; print(torch.__version__)"
- pycodestyle .
- flake8 .
- python setup.py install
- python setup.py test
- cd docs && make clean && make html && cd ..
- if [ "${TRAVIS_OS_NAME}" = "linux" ]; then cd docs && make clean && make html && cd ..; fi
after_success:
- codecov
notifications:
Expand Down
16 changes: 8 additions & 8 deletions examples/cluster_gcn.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Reddit
from torch_geometric.data import ClusterDataset, ClusterLoader
from torch_geometric.data import ClusterData, ClusterLoader
from torch_geometric.nn import SAGEConv

dataset = Reddit('../data/Reddit')
data = dataset[0]

print('Partioning the graph... (this may take a while)')
cluster_dataset = ClusterDataset(dataset, num_parts=1500, save=True)
train_loader = ClusterLoader(cluster_dataset, batch_size=20, shuffle=True,
drop_last=True, num_workers=6)
test_loader = ClusterLoader(cluster_dataset, batch_size=20, shuffle=False,
num_workers=6)
cluster_data = ClusterData(data, num_parts=1500, recursive=False,
save_dir=dataset.processed_dir)
loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True,
num_workers=6)
print('Done!')


Expand All @@ -37,7 +37,7 @@ def forward(self, x, edge_index):
def train():
model.train()
total_loss = total_nodes = 0
for data in train_loader:
for data in loader:
data = data.to(device)
optimizer.zero_grad()
logits = model(data.x, data.edge_index)
Expand All @@ -56,7 +56,7 @@ def train():
def test():
model.eval()
total_correct, total_nodes = [0, 0, 0], [0, 0, 0]
for data in test_loader:
for data in loader:
data = data.to(device)
logits = model(data.x, data.edge_index)
pred = logits.argmax(dim=1)
Expand Down
39 changes: 39 additions & 0 deletions script/conda.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash
#
# Travis CI bootstrap for a conda-based test environment.
#
# Depending on ${TRAVIS_OS_NAME} (linux / osx / windows) this script:
#   1. installs Miniconda and prepends its binaries to PATH,
#   2. updates conda and creates the "test" environment for
#      ${PYTHON_VERSION},
#   3. exports TOOLKIT, which .travis.yml appends to its
#      `conda install pytorch=...` command.
#
# .travis.yml runs this file via `source`, so the PATH change and the
# exports below remain visible to the install steps that follow.
# Any other value of ${TRAVIS_OS_NAME} skips steps 1 and 3 entirely.

# --- 1. Install Miniconda and put it on PATH ------------------------------
case "${TRAVIS_OS_NAME}" in
  linux)
    wget -nv https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
    chmod +x miniconda.sh
    ./miniconda.sh -b
    PATH=/home/travis/miniconda3/bin:${PATH}
    ;;
  osx)
    wget -nv https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -O miniconda.sh
    chmod +x miniconda.sh
    ./miniconda.sh -b
    PATH=/Users/travis/miniconda3/bin:${PATH}
    ;;
  windows)
    # No installer script here: Miniconda comes from Chocolatey instead.
    choco install openssl.light
    choco install miniconda3
    PATH=/c/tools/miniconda3/Scripts:$PATH
    ;;
esac

# --- 2. Refresh conda and create the test environment ---------------------
conda update --yes conda

conda create --yes -n test python="${PYTHON_VERSION}"

# --- 3. Export per-OS settings --------------------------------------------
# Kept after the conda commands above so that PYTHONHTTPSVERIFY does not
# affect them, exactly as in the original statement order.
case "${TRAVIS_OS_NAME}" in
  linux)
    export TOOLKIT=cpuonly
    ;;
  windows)
    export TOOLKIT=cpuonly
    # NOTE(review): this turns off HTTPS certificate verification for
    # Python on Windows -- presumably a workaround for CI certificate
    # errors; confirm it is still required.
    export PYTHONHTTPSVERIFY=0
    ;;
  osx)
    # Empty on purpose: no extra package is passed to conda install on macOS.
    export TOOLKIT=""
    ;;
esac
4 changes: 2 additions & 2 deletions torch_geometric/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .in_memory_dataset import InMemoryDataset
from .dataloader import DataLoader, DataListLoader, DenseDataLoader
from .sampler import NeighborSampler
from .cluster import ClusterDataset, ClusterLoader
from .cluster import ClusterData, ClusterLoader
from .download import download_url
from .extract import extract_tar, extract_zip, extract_bz2, extract_gz

Expand All @@ -17,7 +17,7 @@
'DataListLoader',
'DenseDataLoader',
'NeighborSampler',
'ClusterDataset',
'ClusterData',
'ClusterLoader',
'download_url',
'extract_tar',
Expand Down
109 changes: 68 additions & 41 deletions torch_geometric/data/cluster.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import copy
import os.path as osp

Expand All @@ -6,56 +8,72 @@
from torch_sparse import SparseTensor, cat


class ClusterDataset(torch.utils.data.Dataset):
def __init__(self, dataset, num_parts, save=True):
assert len(dataset) == 1
assert (dataset[0].edge_index is not None)
class ClusterData(torch.utils.data.Dataset):
r"""Clusters/partitions a graph data object into multiple subgraphs, as
motivated by the `"Cluster-GCN: An Efficient Algorithm for Training Deep
and Large Graph Convolutional Networks"
<https://arxiv.org/abs/1905.07953>`_ paper.
Args:
data (torch_geometric.data.Data): The graph data object.
num_parts (int): The number of partitions.
recursive (bool, optional): If set to :obj:`True`, will use multilevel
recursive bisection instead of multilevel k-way partitioning.
(default: :obj:`False`)
save_dir (string, optional): If set, will save the partitioned data to
the :obj:`save_dir` directory for faster re-use.
"""
def __init__(self, data, num_parts, recursive=False, save_dir=None):
assert (data.edge_index is not None)

self.dataset = dataset
self.num_parts = num_parts
self.save = osp.exists(self.dataset.processed_dir) and save
self.recursive = recursive
self.save_dir = save_dir

self.process()
self.process(data)

def process(self):
filename = f'part_data_{self.num_parts}.pt'
def process(self, data):
recursive = '_recursive' if self.recursive else ''
filename = f'part_data_{self.num_parts}{recursive}.pt'

path = osp.join(self.dataset.processed_dir, filename)
if self.save and osp.exists(path):
path = osp.join(self.save_dir or '', filename)
if self.save_dir is not None and osp.exists(path):
data, partptr, perm = torch.load(path)
else:
data = copy.copy(self.dataset.get(0))
data = copy.copy(data)
num_nodes = data.num_nodes

(row, col), edge_attr = data.edge_index, data.edge_attr
adj = SparseTensor(row=row, col=col, value=edge_attr,
is_sorted=True)
adj, partptr, perm = adj.partition_kway(self.num_parts)
adj = SparseTensor(row=row, col=col, value=edge_attr)
adj, partptr, perm = adj.partition(self.num_parts, self.recursive)

for key, item in data:
if item.size(0) == data.num_nodes:
if item.size(0) == num_nodes:
data[key] = item[perm]

data.edge_index = None
data.edge_attr = None
data.adj = adj

if self.save:
if self.save_dir is not None:
torch.save((data, partptr, perm), path)

self.__data__ = data
self.__perm__ = perm
self.__partptr__ = partptr
self.data = data
self.perm = perm
self.partptr = partptr

def __len__(self):
return self.__partptr__.numel() - 1
return self.partptr.numel() - 1

def __getitem__(self, idx):
start = int(self.__partptr__[idx])
length = int(self.__partptr__[idx + 1]) - start
start = int(self.partptr[idx])
length = int(self.partptr[idx + 1]) - start

data = copy.copy(self.data)
num_nodes = data.num_nodes

data = copy.copy(self.__data__)
for key, item in data:
if item.size(0) == data.num_nodes:
if item.size(0) == num_nodes:
data[key] = item.narrow(0, start, length)

data.adj = data.adj.narrow(1, start, length)
Expand All @@ -65,37 +83,49 @@ def __getitem__(self, idx):
data.edge_index = torch.stack([row, col], dim=0)
data.edge_attr = value

if self.dataset.transform is not None:
data = self.dataset.transform(data)

return data

def __repr__(self):
return (f'{self.__class__.__name__}({self.dataset}, '
return (f'{self.__class__.__name__}({self.data}, '
f'num_parts={self.num_parts})')


class ClusterLoader(torch.utils.data.DataLoader):
def __init__(self, cluster_dataset, batch_size=1, shuffle=False, **kwargs):
r"""The data loader scheme from the `"Cluster-GCN: An Efficient Algorithm
for Training Deep and Large Graph Convolutional Networks"
<https://arxiv.org/abs/1905.07953>`_ paper which merges partitioned subgraphs
and their between-cluster links from a large-scale graph data object to
form a mini-batch.
Args:
cluster_data (torch_geometric.data.ClusterData): The already
partitioned data object.
batch_size (int, optional): How many samples per batch to load.
(default: :obj:`1`)
shuffle (bool, optional): If set to :obj:`True`, the data will be
reshuffled at every epoch. (default: :obj:`False`)
"""
def __init__(self, cluster_data, batch_size=1, shuffle=False, **kwargs):
class HelperDataset(torch.utils.data.Dataset):
def __len__(self):
return len(cluster_dataset)
return len(cluster_data)

def __getitem__(self, idx):
start = int(cluster_dataset.__partptr__[idx])
length = int(cluster_dataset.__partptr__[idx + 1]) - start
start = int(cluster_data.partptr[idx])
length = int(cluster_data.partptr[idx + 1]) - start

data = copy.copy(cluster_dataset.__data__)
data = copy.copy(cluster_data.data)
num_nodes = data.num_nodes
for key, item in data:
if item.size(0) == num_nodes:
data[key] = item.narrow(0, start, length)

return data, idx

def collate(data):
data_list, parts = [d[0] for d in data], [d[1] for d in data]
partptr = cluster_dataset.__partptr__
def collate(batch):
data_list = [data[0] for data in batch]
parts: List[int] = [data[1] for data in batch]
partptr = cluster_data.partptr

adj = cat([data.adj for data in data_list], dim=0)

Expand All @@ -108,7 +138,7 @@ def collate(data):
adj = cat(adjs, dim=0).t()
row, col, value = adj.coo()

data = cluster_dataset.__data__.__class__()
data = cluster_data.data.__class__()
data.num_nodes = adj.size(0)
data.edge_index = torch.stack([row, col], dim=0)
data.edge_attr = value
Expand All @@ -124,9 +154,6 @@ def collate(data):
data[key] = torch.cat([d[key] for d in data_list],
dim=ref.__cat_dim__(key, ref[key]))

if cluster_dataset.dataset.transform is not None:
data = cluster_dataset.dataset.transform(data)

return data

super(ClusterLoader,
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/datasets/citation_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def __repr__(self):


class CoraFull(CitationFull):
r"""Alias for :class:`torch_geometric.dataset.CitationFull`:obj:``("cora)`.
"""
r"""Alias for :class:`torch_geometric.datasets.CitationFull` with
:obj:`name="cora"`."""
def __init__(self, root, transform=None, pre_transform=None):
super(CoraFull, self).__init__(root, 'cora', transform, pre_transform)

Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/datasets/tu_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class TUDataset(InMemoryDataset):
contain only non-isomorphic graphs. (default: :obj:`False`)
"""

url = ('https://ls11-www.cs.tu-dortmund.de/people/morris/'
url = ('http://ls11-www.cs.tu-dortmund.de/people/morris/'
'graphkerneldatasets')
cleaned_url = ('https://raw.githubusercontent.com/nd7141/'
'graph_datasets/master/datasets')
Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/transforms/gdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,10 @@ def __calculate_eps__(self, matrix, num_nodes, avg_degree):
sorted_edges = torch.sort(matrix.flatten(), descending=True).values
if avg_degree * num_nodes > len(sorted_edges):
return -np.inf
return sorted_edges[avg_degree * num_nodes - 1]

left = sorted_edges[avg_degree * num_nodes - 1]
right = sorted_edges[avg_degree * num_nodes]
return (left + right) / 2.0

def __neighbors_to_graph__(self, neighbors, neighbor_weights,
normalization='row', device='cpu'):
Expand Down

0 comments on commit a72fd98

Please sign in to comment.