All devices are equal! (#196)
* Update all devices to be tested

ANE, CPU and OCL all now support all tests.

However, tests are not currently passing on GPU and I cannot test on ANE.

The failing GPU tests are not an issue caused by this update; they have
not been passing due to a missing "six" install requirement.

OpenCL tests have not been run since commit 1a1c63a.

Devices now have three types, handled by a new DeviceTypes enum. (The goal
is to revert to Tensor.<type>, but the current setup allows for keyword
argument defaults: `device=DeviceTypes.CPU`.)

All references to Tensor.GPU/CPU/ANE have been converted to the
corresponding `DeviceTypes` enum values.

Refactored the conversion code to allow conversion from any device to
any device.
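
For illustration, a minimal sketch of what this enum-keyword pattern looks like (the enum body and the constructor below are assumptions for illustration, not the code from this commit):

```python
from enum import Enum

class DeviceTypes(Enum):
    # assumed members, matching the three devices named above
    CPU = 0
    GPU = 1
    ANE = 2

def zeros(*shape, device=DeviceTypes.CPU):
    # hypothetical constructor: the target device is an ordinary
    # keyword argument with a default, e.g. device=DeviceTypes.GPU
    ...
```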

* Add six dependency in requirements.txt

* Resolve failure to run tests

Move six into the gpu extra's required installs and remove it from the
standard installation.

* Remove repeated data conversion

* Refactor method names

Also reduce code with .to and .to_
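
As a rough sketch of how these helpers collapse the per-device methods (the signatures are assumptions based on the names above; `Device` is the enum's final name, see the rename below):

```python
from tinygrad.tensor import Tensor, Device

t = Tensor.zeros(4, 4)   # created on the default device (CPU)
g = t.to(Device.GPU)     # assumed: returns a copy on the target device
t.to_(Device.GPU)        # assumed: moves this tensor in place
```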

* Dynamic device handlers

* Refactor DeviceTypes -> Device

* Add mem copy profiling back

* test_backward_pass_diamond_model passing

* Resolve Sum issue on GPU

* Revert batchnorm2d tests

* Update README with updated API

* ANE testing with

* Last minute line gains
Liamdoult committed Dec 16, 2020
1 parent 78210b5 commit bcf1518
Showing 15 changed files with 249 additions and 184 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -84,7 +84,7 @@ tinygrad supports GPUs through PyOpenCL.

```python
from tinygrad.tensor import Tensor
(Tensor.ones(4,4).cuda() + Tensor.ones(4,4).cuda()).cpu()
(Tensor.ones(4,4).gpu() + Tensor.ones(4,4).gpu()).cpu()
```

### ANE Support?!
15 changes: 8 additions & 7 deletions extra/training.py
@@ -2,21 +2,22 @@
import numpy as np
from tqdm import trange
from extra.utils import get_parameters
from tinygrad.tensor import Tensor, GPU
from tinygrad.tensor import Tensor, GPU, Device

def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, gpu=False, lossfn = lambda out,y: out.mul(y).mean()):
if gpu is True: [x.cuda_() for x in get_parameters([model, optim])]
def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, device=Device.CPU, lossfn = lambda out,y: out.mul(y).mean()):
if device == Device.GPU: [x.gpu_() for x in get_parameters([model, optim])]
elif device == Device.ANE: [x.ane_() for x in get_parameters([model, optim])]
if num_classes is None: num_classes = Y_train.max().astype(int)+1
losses, accuracies = [], []
for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
samp = np.random.randint(0, X_train.shape[0], size=(BS))

x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)
x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), device=device)
Y = Y_train[samp]
y = np.zeros((len(samp),num_classes), np.float32)
# correct loss for NLL, torch NLL loss returns one per row
y[range(y.shape[0]),Y] = -1.0*num_classes
y = Tensor(y, gpu=gpu)
y = Tensor(y, device=device)

# network
out = model.forward(x)
@@ -36,11 +37,11 @@ def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, gpu=F
accuracies.append(accuracy)
t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))

def evaluate(model, X_test, Y_test, num_classes=None, gpu=False, BS=128):
def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128):
def numpy_eval(num_classes):
Y_test_preds_out = np.zeros((len(Y_test),num_classes))
for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)).cpu().data
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS].reshape((-1, 28*28)).astype(np.float32), device=device)).cpu().data
Y_test_preds = np.argmax(Y_test_preds_out, axis=1)
return (Y_test == Y_test_preds).mean()

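Taken together, the new signatures would be driven like this (a usage sketch assuming the model classes and MNIST arrays set up in test/test_mnist.py below):

```python
import tinygrad.optim as optim
from tinygrad.tensor import Device
from extra.training import train, evaluate

# model, X_train, Y_train, X_test, Y_test as set up in test/test_mnist.py
optimizer = optim.SGD(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, steps=1000, device=Device.GPU)
accuracy = evaluate(model, X_test, Y_test, device=Device.GPU)
```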
2 changes: 1 addition & 1 deletion setup.py
@@ -22,7 +22,7 @@
install_requires=['numpy', 'requests'],
python_requires='>=3.8',
extras_require={
'gpu': ["pyopencl"],
'gpu': ["pyopencl", "six"],
'testing': [
"pytest",
"torch",
Empty file added test/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions test/config.py
@@ -0,0 +1,3 @@
import os

ANE = os.environ.get('ANE', False)
25 changes: 14 additions & 11 deletions test/test_gc.py
@@ -1,43 +1,46 @@
#!/usr/bin/env python
import gc
import unittest
from tinygrad.tensor import Tensor, GPU
from tinygrad.tensor import Tensor, GPU, Device
from .config import ANE

def tensors_allocated():
return sum([isinstance(x, Tensor) for x in gc.get_objects()])

class TestGC(unittest.TestCase):
gpu = False
device = Device.CPU

def test_gc(self):
a = Tensor.zeros(4,4, gpu=self.gpu)
b = Tensor.zeros(4,4, gpu=self.gpu)
a = Tensor.zeros(4,4, device=self.device)
b = Tensor.zeros(4,4, device=self.device)
(a*b).mean().backward()
assert(tensors_allocated() > 0)
del a,b
assert(tensors_allocated() == 0)

def test_gc_complex(self):
a = Tensor.zeros(4,4, gpu=self.gpu)
b = Tensor.zeros(4,4, gpu=self.gpu)
a = Tensor.zeros(4,4, device=self.device)
b = Tensor.zeros(4,4, device=self.device)
assert(tensors_allocated() == 2)
(a*b).mean().backward()
assert(tensors_allocated() == 4)
del b
assert(tensors_allocated() == 2)
b = Tensor.zeros(4,4, gpu=self.gpu)
b = Tensor.zeros(4,4, device=self.device)
print(tensors_allocated())
(a*b).mean().backward()
print(tensors_allocated())
assert(tensors_allocated() == 4)
del b
assert(tensors_allocated() == 2)


@unittest.skipUnless(GPU, "Requires GPU")
class TestGCGPU(TestGC):
device = Device.GPU

if GPU:
class TestGCGPU(TestGC):
gpu = True
@unittest.skipUnless(ANE, "Requires ANE")
class TestGCANE(TestGC):
device=Device.ANE

if __name__ == '__main__':
unittest.main()
23 changes: 14 additions & 9 deletions test/test_mnist.py
@@ -2,10 +2,11 @@
import os
import unittest
import numpy as np
from tinygrad.tensor import Tensor, GPU
from tinygrad.tensor import Tensor, GPU, Device
import tinygrad.optim as optim
from extra.training import train, evaluate
from extra.utils import fetch, get_parameters
from .config import ANE

# mnist loader
def fetch_mnist():
@@ -55,32 +56,36 @@ def forward(self, x):
return x.dot(self.l1).logsoftmax()

class TestMNIST(unittest.TestCase):
gpu=False
device = Device.CPU

def test_conv(self):
np.random.seed(1337)
model = TinyConvNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, steps=200, gpu=self.gpu)
assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
train(model, X_train, Y_train, optimizer, steps=200, device=self.device)
assert evaluate(model, X_test, Y_test, device=self.device) > 0.95

def test_sgd(self):
np.random.seed(1337)
model = TinyBobNet()
optimizer = optim.SGD(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu)
assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
train(model, X_train, Y_train, optimizer, steps=1000, device=self.device)
assert evaluate(model, X_test, Y_test, device=self.device) > 0.95

def test_rmsprop(self):
np.random.seed(1337)
model = TinyBobNet()
optimizer = optim.RMSprop(model.parameters(), lr=0.0002)
train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu)
assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
train(model, X_train, Y_train, optimizer, steps=1000, device=self.device)
assert evaluate(model, X_test, Y_test, device=self.device) > 0.95

@unittest.skipUnless(GPU, "Requires GPU")
class TestMNISTGPU(TestMNIST):
gpu = True
device = Device.GPU

@unittest.skipUnless(ANE, "Requires ANE")
class TestMNISTANE(TestMNIST):
device=Device.ANE

if __name__ == '__main__':
unittest.main()
21 changes: 16 additions & 5 deletions test/test_net_speed.py
@@ -4,7 +4,8 @@
import pstats
import unittest
import torch
from tinygrad.tensor import Tensor
from tinygrad.tensor import Tensor, GPU, Device
from .config import ANE

def start_profile():
import time
@@ -20,6 +21,8 @@ def stop_profile(pr, sort='cumtime'):
ps.print_stats(0.2)

class TestConvSpeed(unittest.TestCase):
device= Device.CPU

def test_mnist(self):
# https://keras.io/examples/vision/mnist_convnet/
conv = 3
@@ -62,15 +65,15 @@ def test_mnist(self):

# ****** tinygrad compare *******

c1 = Tensor(c1.detach().numpy())
c2 = Tensor(c2.detach().numpy())
l1 = Tensor(l1.detach().numpy())
c1 = Tensor(c1.detach().numpy(), device=self.device)
c2 = Tensor(c2.detach().numpy(), device=self.device)
l1 = Tensor(l1.detach().numpy(), device=self.device)

cnt = 5
fpt, bpt = 0.0, 0.0
for i in range(1+cnt):
et0 = time.time()
x = Tensor.randn(128, 1, 28, 28)
x = Tensor.randn(128, 1, 28, 28, device=self.device)
x = x.conv2d(c1).relu().avg_pool2d()
x = x.conv2d(c2).relu().max_pool2d()
x = x.reshape(shape=(x.shape[0], -1))
@@ -91,6 +94,14 @@ def test_mnist(self):
print("forward pass: %.3f ms, %.2fx off baseline %.3f ms" % (fpt, fpt/fpt_baseline, fpt_baseline))
print("backward pass: %.3f ms, %.2fx off baseline %.3f ms" % (bpt, bpt/bpt_baseline, bpt_baseline))

@unittest.skipUnless(GPU, "Requires GPU")
class TestConvSpeedGPU(TestConvSpeed):
device = Device.GPU

@unittest.skipUnless(ANE, "Requires ANE")
class TestConvSpeedANE(TestConvSpeed):
device=Device.ANE


if __name__ == '__main__':
unittest.main()
30 changes: 28 additions & 2 deletions test/test_nn.py
@@ -1,10 +1,15 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.tensor import GPU, Device
from tinygrad.nn import *
from extra.utils import get_parameters
import torch
from .config import ANE

class TestNN(unittest.TestCase):
device = Device.CPU

def test_batchnorm2d(self, training=False):
sz = 4

@@ -29,13 +34,13 @@ def test_batchnorm2d(self, training=False):
np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5)

# trial
inn = Tensor.randn(2, sz, 3, 3)
inn = Tensor.randn(2, sz, 3, 3, device=self.device)

# in tinygrad
outt = bn(inn)

# in torch
toutt = tbn(torch.tensor(inn.data))
toutt = tbn(torch.tensor(inn.cpu().data))

# close
np.testing.assert_allclose(outt.data, toutt.detach().numpy(), rtol=5e-5)
@@ -48,6 +53,27 @@ def test_batchnorm2d_training(self):
def test_batchnorm2d_training(self):
self.test_batchnorm2d(True)

@unittest.skipUnless(GPU, "Requires GPU")
class TestNNGPU(TestNN):
device = Device.GPU

@unittest.skip("Tests not added")
def test_batchnorm2d(self): pass

@unittest.skip("Tests not added")
def test_batchnorm2d_training(self): pass


@unittest.skipUnless(ANE, "Requires ANE")
class TestNNANE(TestNN):
device=Device.ANE

@unittest.skip("Tests not added")
def test_batchnorm2d(self): pass

@unittest.skip("Tests not added")
def test_batchnorm2d_training(self): pass


if __name__ == '__main__':
unittest.main()
