Closed
Description
🐛 Bug
To Reproduce
This code:
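```python
import time
import numpy as np
import torch

# A list of two boolean arrays, each of shape (8, 800, 1333).
data = [np.random.rand(8, 800, 1333) > 0.5 for _ in range(2)]
dtype_np = np.bool_  # np.bool (alias of builtin bool) was removed in NumPy 1.24
dtype_pt = torch.bool

def f1():  # list of ndarrays -> ndarray via np.asarray
    return np.asarray(data, dtype=dtype_np)

def f2():  # list of ndarrays -> ndarray via np.stack
    return np.stack(data)

def f3():  # list of ndarrays -> tensor via torch.as_tensor (the slow case)
    return torch.as_tensor(data, dtype=dtype_pt)

def f4():  # zero-copy torch.from_numpy per array, then torch.stack
    return torch.stack([torch.from_numpy(x) for x in data])

def benchmark(f, iters, warmup):
    for _ in range(warmup):
        f()
    start = time.perf_counter()
    for _ in range(iters):
        f()
    # All work above runs on the CPU; only synchronize when CUDA exists
    # so the script also runs on machines without a GPU.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter() - start

print(benchmark(f1, 10, 1))
print(benchmark(f2, 10, 1))
print(benchmark(f3, 10, 1))
print(benchmark(f4, 10, 1))
```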
prints:

```
0.2839105408638716
0.013459203764796257
20.43221192806959
0.012655530124902725
```
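The relative slowdowns implied by these four timings can be checked with a few lines of arithmetic. This sketch only copies the printed numbers (seconds for 10 calls of each function) into a dict; the labels are descriptive names added here, not part of the original script.

```python
# Timings printed by the benchmark above (seconds for 10 calls each).
timings = {
    "f1 np.asarray": 0.2839105408638716,
    "f2 np.stack": 0.013459203764796257,
    "f3 torch.as_tensor": 20.43221192806959,
    "f4 torch.stack + torch.from_numpy": 0.012655530124902725,
}

slow = timings["f3 torch.as_tensor"]
# Ratio of the torch.as_tensor time to each of the other three approaches.
ratios = {name: slow / t for name, t in timings.items() if name != "f3 torch.as_tensor"}
for name, r in ratios.items():
    print(f"torch.as_tensor is {r:.0f}x slower than {name}")
```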
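I don't expect torch.as_tensor to be this much slower: on a list of NumPy arrays it takes 20.4 s for 10 calls, about 72x slower than np.asarray() (0.28 s) and over 1500x slower than np.stack() (0.013 s) or torch.stack() over torch.from_numpy() tensors (0.013 s).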
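Until the list-of-arrays path is faster, the timings suggest a workaround: build a single ndarray first (as f2 does) and hand that to PyTorch. Below is a minimal sketch, assuming every array in the list has the same shape and dtype; `arrays_to_tensor` is a hypothetical helper, not a PyTorch API. `torch.from_numpy` shares memory with the stacked array, so when the requested dtype already matches, the only full copy is the one made by `np.stack`.

```python
import numpy as np
import torch


def arrays_to_tensor(arrays, dtype=None):
    """Hypothetical helper: convert a list of same-shape ndarrays to one tensor.

    Stacks on the NumPy side first (one contiguous copy), then wraps the result
    with torch.from_numpy, which shares memory with the stacked array.
    """
    stacked = np.stack(arrays)          # shape: (len(arrays), *arrays[0].shape)
    tensor = torch.from_numpy(stacked)  # zero-copy: shares memory with `stacked`
    if dtype is not None:
        tensor = tensor.to(dtype)       # returns the same tensor if dtype already matches
    return tensor


# Small example with the same structure as the report's `data`.
masks = [np.random.rand(8, 80, 133) > 0.5 for _ in range(2)]
t = arrays_to_tensor(masks, dtype=torch.bool)
print(t.shape, t.dtype)
```

Because the conversion to a tensor happens on one contiguous ndarray, its cost should be close to f2's rather than f3's; that is an expectation, not a measured result.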
Environment
PyTorch version: 1.2.0.dev20190805
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04) 7.4.0
CMake version: version 3.12.2
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100
Nvidia driver version: 410.79
cuDNN version: Could not collect
Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.4.3
[pip3] numpy==1.16.4
[pip3] numpydoc==0.7.0
[pip3] torch-nightly==1.2.0.dev20190805
[pip3] torchvision==0.4.0a0+6c7189f
[conda] blas 1.0 mkl
[conda] mkl 2019.1 144
[conda] mkl-include 2019.1 144
[conda] mkl-service 1.1.2 py36he904b0f_5
[conda] mkl_fft 1.0.6 py36hd81dba3_0
[conda] mkl_random 1.0.2 py36hd81dba3_0