Skip to content

Commit 003abd5

Browse files
authored
Improve scatter performance + upgrade to torch-scatter==2.1.0 (#338)
* update * update * update * update * update * update * update * update * update * typo * typo
1 parent c128508 commit 003abd5

File tree

7 files changed

+11
-9
lines changed

7 files changed

+11
-9
lines changed

.github/workflows/building-conda.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
# We have trouble building for Windows - drop for now.
1414
os: [ubuntu-18.04, macos-10.15] # windows-2019
1515
python-version: ['3.7', '3.8', '3.9', '3.10']
16-
torch-version: [1.13.0] # [1.12.0, 1.13.0]
16+
torch-version: [1.12.0, 1.13.0]
1717
cuda-version: ['cpu', 'cu102', 'cu113', 'cu116', 'cu117']
1818
exclude:
1919
- torch-version: 1.12.0
@@ -32,8 +32,6 @@ jobs:
3232
cuda-version: 'cu117'
3333
- os: windows-2019
3434
cuda-version: 'cu102'
35-
- os: windows-2019 # Complains about CUDA mismatch.
36-
python-version: '3.7'
3735

3836
steps:
3937
- uses: actions/checkout@v2

.github/workflows/building.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ jobs:
6363
if: ${{ runner.os != 'macOS' }}
6464
run: |
6565
VERSION=`sed -n "s/^__version__ = '\(.*\)'/\1/p" torch_scatter/__init__.py`
66-
sed -i "s/$VERSION/$VERSION+${{ matrix.cuda-version }}/" torch_scatter/__init__.py
66+
TORCH_VERSION=`echo "pt${{ matrix.torch-version }}" | sed "s/..$//" | sed "s/\.//g"`
67+
CUDA_VERSION=`echo ${{ matrix.cuda-version }}`
68+
echo "New version name: $VERSION+$TORCH_VERSION$CUDA_VERSION"
69+
sed -i "s/$VERSION/$VERSION+$TORCH_VERSION$CUDA_VERSION/" setup.py
70+
sed -i "s/$VERSION/$VERSION+$TORCH_VERSION$CUDA_VERSION/" torch_scatter/__init__.py
6771
shell:
6872
bash
6973

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
cmake_minimum_required(VERSION 3.0)
22
project(torchscatter)
33
set(CMAKE_CXX_STANDARD 14)
4-
set(TORCHSCATTER_VERSION 2.0.9)
4+
set(TORCHSCATTER_VERSION 2.1.0)
55

66
option(WITH_CUDA "Enable CUDA support" OFF)
77
option(WITH_PYTHON "Link to Python when building" ON)

conda/pytorch-scatter/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package:
22
name: pytorch-scatter
3-
version: 2.0.9
3+
version: 2.1.0
44

55
source:
66
path: ../..

csrc/cuda/scatter_cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include "reducer.cuh"
88
#include "utils.cuh"
99

10-
#define THREADS 1024
10+
#define THREADS 256
1111
#define BLOCKS(N) (N + THREADS - 1) / THREADS
1212

1313
template <typename scalar_t, ReductionType REDUCE>

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
1212
CUDAExtension)
1313

14-
__version__ = '2.0.9'
14+
__version__ = '2.1.0'
1515
URL = 'https://github.com/rusty1s/pytorch_scatter'
1616

1717
WITH_CUDA = False

torch_scatter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import torch
66

7-
__version__ = '2.0.9'
7+
__version__ = '2.1.0'
88

99
for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
1010
cuda_spec = importlib.machinery.PathFinder().find_spec(

0 commit comments

Comments
 (0)