Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

All devices are equal! #196

Merged
merged 15 commits on Dec 16, 2020
14 changes: 7 additions & 7 deletions extra/training.py
Expand Up @@ -2,21 +2,21 @@
import numpy as np
from tqdm import trange
from extra.utils import get_parameters
from tinygrad.tensor import Tensor, GPU
from tinygrad.tensor import Tensor, GPU, DeviceTypes
[Review thread: Liamdoult marked this conversation as resolved.]

def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, gpu=False, lossfn = lambda out,y: out.mul(y).mean()):
if gpu is True: [x.cuda_() for x in get_parameters([model, optim])]
def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, device=DeviceTypes.CPU, lossfn = lambda out,y: out.mul(y).mean()):
if device == DeviceTypes.GPU: [x.cuda_() for x in get_parameters([model, optim])]
if num_classes is None: num_classes = Y_train.max().astype(int)+1
losses, accuracies = [], []
for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
samp = np.random.randint(0, X_train.shape[0], size=(BS))

x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)
x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), device=device)
Y = Y_train[samp]
y = np.zeros((len(samp),num_classes), np.float32)
# correct loss for NLL, torch NLL loss returns one per row
y[range(y.shape[0]),Y] = -1.0*num_classes
y = Tensor(y, gpu=gpu)
y = Tensor(y, device=device)

# network
out = model.forward(x)
Expand All @@ -36,11 +36,11 @@ def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, gpu=F
accuracies.append(accuracy)
t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))

def evaluate(model, X_test, Y_test, num_classes=None, gpu=False, BS=128):
def evaluate(model, X_test, Y_test, num_classes=None, device=DeviceTypes.CPU, BS=128):
def numpy_eval(num_classes):
Y_test_preds_out = np.zeros((len(Y_test),num_classes))
for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)).cpu().data
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS].reshape((-1, 28*28)).astype(np.float32), device=device)).cpu().data
Y_test_preds = np.argmax(Y_test_preds_out, axis=1)
return (Y_test == Y_test_preds).mean()

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -22,7 +22,7 @@
install_requires=['numpy', 'requests'],
python_requires='>=3.8',
extras_require={
'gpu': ["pyopencl"],
'gpu': ["pyopencl", "six"],
[Review thread: Liamdoult marked this conversation as resolved.]
'testing': [
"pytest",
"torch",
Expand Down
24 changes: 13 additions & 11 deletions test/test_gc.py
@@ -1,43 +1,45 @@
#!/usr/bin/env python
import gc
import unittest
from tinygrad.tensor import Tensor, GPU
from tinygrad.tensor import Tensor, ANE, GPU, DeviceTypes

def tensors_allocated():
return sum([isinstance(x, Tensor) for x in gc.get_objects()])

class TestGC(unittest.TestCase):
gpu = False
device = DeviceTypes.CPU

def test_gc(self):
a = Tensor.zeros(4,4, gpu=self.gpu)
b = Tensor.zeros(4,4, gpu=self.gpu)
a = Tensor.zeros(4,4, device=self.device)
b = Tensor.zeros(4,4, device=self.device)
(a*b).mean().backward()
assert(tensors_allocated() > 0)
del a,b
assert(tensors_allocated() == 0)

def test_gc_complex(self):
a = Tensor.zeros(4,4, gpu=self.gpu)
b = Tensor.zeros(4,4, gpu=self.gpu)
a = Tensor.zeros(4,4, device=self.device)
b = Tensor.zeros(4,4, device=self.device)
assert(tensors_allocated() == 2)
(a*b).mean().backward()
assert(tensors_allocated() == 4)
del b
assert(tensors_allocated() == 2)
b = Tensor.zeros(4,4, gpu=self.gpu)
b = Tensor.zeros(4,4, device=self.device)
print(tensors_allocated())
(a*b).mean().backward()
print(tensors_allocated())
assert(tensors_allocated() == 4)
del b
assert(tensors_allocated() == 2)


@unittest.skipUnless(GPU, "Requires GPU")
class TestGCGPU(TestGC):
device = DeviceTypes.GPU

if GPU:
class TestGCGPU(TestGC):
gpu = True
@unittest.skipUnless(ANE, "Requires ANE")
class TestGCANE(TestGC):
device=DeviceTypes.ANE

if __name__ == '__main__':
unittest.main()
22 changes: 13 additions & 9 deletions test/test_mnist.py
Expand Up @@ -2,7 +2,7 @@
import os
import unittest
import numpy as np
from tinygrad.tensor import Tensor, GPU
from tinygrad.tensor import Tensor, ANE, GPU, DeviceTypes
import tinygrad.optim as optim
from extra.training import train, evaluate
from extra.utils import fetch, get_parameters
Expand Down Expand Up @@ -55,32 +55,36 @@ def forward(self, x):
return x.dot(self.l1).logsoftmax()

class TestMNIST(unittest.TestCase):
gpu=False
device = DeviceTypes.CPU

def test_conv(self):
np.random.seed(1337)
model = TinyConvNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, steps=200, gpu=self.gpu)
assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
train(model, X_train, Y_train, optimizer, steps=200, device=self.device)
assert evaluate(model, X_test, Y_test, device=self.device) > 0.95

def test_sgd(self):
np.random.seed(1337)
model = TinyBobNet()
optimizer = optim.SGD(model.parameters(), lr=0.001)
train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu)
assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
train(model, X_train, Y_train, optimizer, steps=1000, device=self.device)
assert evaluate(model, X_test, Y_test, device=self.device) > 0.95

def test_rmsprop(self):
np.random.seed(1337)
model = TinyBobNet()
optimizer = optim.RMSprop(model.parameters(), lr=0.0002)
train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu)
assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95
train(model, X_train, Y_train, optimizer, steps=1000, device=self.device)
assert evaluate(model, X_test, Y_test, device=self.device) > 0.95

@unittest.skipUnless(GPU, "Requires GPU")
class TestMNISTGPU(TestMNIST):
gpu = True
device = DeviceTypes.GPU

@unittest.skipUnless(ANE, "Requires ANE")
class TestMNISTANE(TestMNIST):
device=DeviceTypes.ANE

if __name__ == '__main__':
unittest.main()
20 changes: 15 additions & 5 deletions test/test_net_speed.py
Expand Up @@ -4,7 +4,7 @@
import pstats
import unittest
import torch
from tinygrad.tensor import Tensor
from tinygrad.tensor import Tensor, ANE, GPU, DeviceTypes

def start_profile():
import time
Expand All @@ -20,6 +20,8 @@ def stop_profile(pr, sort='cumtime'):
ps.print_stats(0.2)

class TestConvSpeed(unittest.TestCase):
device= DeviceTypes.CPU

def test_mnist(self):
# https://keras.io/examples/vision/mnist_convnet/
conv = 3
Expand Down Expand Up @@ -62,15 +64,15 @@ def test_mnist(self):

# ****** tinygrad compare *******

c1 = Tensor(c1.detach().numpy())
c2 = Tensor(c2.detach().numpy())
l1 = Tensor(l1.detach().numpy())
c1 = Tensor(c1.detach().numpy(), device=self.device)
c2 = Tensor(c2.detach().numpy(), device=self.device)
l1 = Tensor(l1.detach().numpy(), device=self.device)

cnt = 5
fpt, bpt = 0.0, 0.0
for i in range(1+cnt):
et0 = time.time()
x = Tensor.randn(128, 1, 28, 28)
x = Tensor.randn(128, 1, 28, 28, device=self.device)
x = x.conv2d(c1).relu().avg_pool2d()
x = x.conv2d(c2).relu().max_pool2d()
x = x.reshape(shape=(x.shape[0], -1))
Expand All @@ -91,6 +93,14 @@ def test_mnist(self):
print("forward pass: %.3f ms, %.2fx off baseline %.3f ms" % (fpt, fpt/fpt_baseline, fpt_baseline))
print("backward pass: %.3f ms, %.2fx off baseline %.3f ms" % (bpt, bpt/bpt_baseline, bpt_baseline))

@unittest.skipUnless(GPU, "Requires GPU")
class TestConvSpeedGPU(TestConvSpeed):
device = DeviceTypes.GPU

@unittest.skipUnless(ANE, "Requires ANE")
class TestConvSpeedANE(TestConvSpeed):
device=DeviceTypes.ANE


if __name__ == '__main__':
unittest.main()
Expand Down
23 changes: 19 additions & 4 deletions test/test_nn.py
@@ -1,10 +1,14 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.tensor import ANE, GPU, DeviceTypes
from tinygrad.nn import *
from extra.utils import get_parameters
import torch

class TestNN(unittest.TestCase):
device = DeviceTypes.CPU

def test_batchnorm2d(self, training=False):
sz = 4

Expand All @@ -16,6 +20,9 @@ def test_batchnorm2d(self, training=False):
bn.running_var = Tensor.randn(sz)
bn.running_var.data[bn.running_var.data < 0] = 0

if self.device==DeviceTypes.GPU: [x.cuda_() for x in get_parameters(bn)]
elif self.device==DeviceTypes.ANE: [x.ane_() for x in get_parameters(bn)]

# create in torch
with torch.no_grad():
tbn = torch.nn.BatchNorm2d(sz).eval()
Expand All @@ -29,25 +36,33 @@ def test_batchnorm2d(self, training=False):
np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5)

# trial
inn = Tensor.randn(2, sz, 3, 3)
inn = Tensor.randn(2, sz, 3, 3, device=self.device)

# in tinygrad
outt = bn(inn)

# in torch
toutt = tbn(torch.tensor(inn.data))
toutt = tbn(torch.tensor(inn.cpu().data))

# close
np.testing.assert_allclose(outt.data, toutt.detach().numpy(), rtol=5e-5)
np.testing.assert_allclose(outt.cpu().data, toutt.detach().numpy(), rtol=5e-5)

np.testing.assert_allclose(bn.running_mean.data, tbn.running_mean.detach().numpy(), rtol=1e-5)
np.testing.assert_allclose(bn.running_mean.cpu().data, tbn.running_mean.detach().numpy(), rtol=1e-5)

# TODO: this is failing
#np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5)

def test_batchnorm2d_training(self):
self.test_batchnorm2d(True)

@unittest.skipUnless(GPU, "Requires GPU")
class TestNNGPU(TestNN):
device = DeviceTypes.GPU

@unittest.skipUnless(ANE, "Requires ANE")
class TestNNANE(TestNN):
device=DeviceTypes.ANE


if __name__ == '__main__':
unittest.main()