diff --git a/README.md b/README.md index c8e5ff45150e..44e3e70fa522 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ tinygrad supports GPUs through PyOpenCL. ```python from tinygrad.tensor import Tensor -(Tensor.ones(4,4).cuda() + Tensor.ones(4,4).cuda()).cpu() +(Tensor.ones(4,4).gpu() + Tensor.ones(4,4).gpu()).cpu() ``` ### ANE Support?! diff --git a/extra/training.py b/extra/training.py index 3ce050679fd6..24fc84acb40b 100644 --- a/extra/training.py +++ b/extra/training.py @@ -2,21 +2,22 @@ import numpy as np from tqdm import trange from extra.utils import get_parameters -from tinygrad.tensor import Tensor, GPU +from tinygrad.tensor import Tensor, GPU, Device -def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, gpu=False, lossfn = lambda out,y: out.mul(y).mean()): - if gpu is True: [x.cuda_() for x in get_parameters([model, optim])] +def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, device=Device.CPU, lossfn = lambda out,y: out.mul(y).mean()): + if device == Device.GPU: [x.gpu_() for x in get_parameters([model, optim])] + elif device == Device.ANE: [x.ane_() for x in get_parameters([model, optim])] if num_classes is None: num_classes = Y_train.max().astype(int)+1 losses, accuracies = [], [] for i in (t := trange(steps, disable=os.getenv('CI') is not None)): samp = np.random.randint(0, X_train.shape[0], size=(BS)) - x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu) + x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), device=device) Y = Y_train[samp] y = np.zeros((len(samp),num_classes), np.float32) # correct loss for NLL, torch NLL loss returns one per row y[range(y.shape[0]),Y] = -1.0*num_classes - y = Tensor(y, gpu=gpu) + y = Tensor(y, device=device) # network out = model.forward(x) @@ -36,11 +37,11 @@ def train(model, X_train, Y_train, optim, steps, num_classes=None, BS=128, gpu=F accuracies.append(accuracy) t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy)) -def evaluate(model, X_test, Y_test, num_classes=None, gpu=False, BS=128): +def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128): def numpy_eval(num_classes): Y_test_preds_out = np.zeros((len(Y_test),num_classes)) for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None): - Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)).cpu().data + Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS].reshape((-1, 28*28)).astype(np.float32), device=device)).cpu().data Y_test_preds = np.argmax(Y_test_preds_out, axis=1) return (Y_test == Y_test_preds).mean() diff --git a/setup.py b/setup.py index af121472768a..823bfcab2c05 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ install_requires=['numpy', 'requests'], python_requires='>=3.8', extras_require={ - 'gpu': ["pyopencl"], + 'gpu': ["pyopencl", "six"], 'testing': [ "pytest", "torch", diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/config.py b/test/config.py new file mode 100644 index 000000000000..ab20e8b39e70 --- /dev/null +++ b/test/config.py @@ -0,0 +1,3 @@ +import os + +ANE = os.environ.get('ANE', False) diff --git a/test/test_gc.py b/test/test_gc.py index 7aa90921396d..2dc007d34bdc 100644 --- a/test/test_gc.py +++ b/test/test_gc.py @@ -1,31 +1,32 @@ #!/usr/bin/env python import gc import unittest -from tinygrad.tensor import Tensor, GPU +from tinygrad.tensor import 
Tensor, GPU, Device +from .config import ANE def tensors_allocated(): return sum([isinstance(x, Tensor) for x in gc.get_objects()]) class TestGC(unittest.TestCase): - gpu = False + device = Device.CPU def test_gc(self): - a = Tensor.zeros(4,4, gpu=self.gpu) - b = Tensor.zeros(4,4, gpu=self.gpu) + a = Tensor.zeros(4,4, device=self.device) + b = Tensor.zeros(4,4, device=self.device) (a*b).mean().backward() assert(tensors_allocated() > 0) del a,b assert(tensors_allocated() == 0) def test_gc_complex(self): - a = Tensor.zeros(4,4, gpu=self.gpu) - b = Tensor.zeros(4,4, gpu=self.gpu) + a = Tensor.zeros(4,4, device=self.device) + b = Tensor.zeros(4,4, device=self.device) assert(tensors_allocated() == 2) (a*b).mean().backward() assert(tensors_allocated() == 4) del b assert(tensors_allocated() == 2) - b = Tensor.zeros(4,4, gpu=self.gpu) + b = Tensor.zeros(4,4, device=self.device) print(tensors_allocated()) (a*b).mean().backward() print(tensors_allocated()) @@ -33,11 +34,13 @@ def test_gc_complex(self): del b assert(tensors_allocated() == 2) - +@unittest.skipUnless(GPU, "Requires GPU") +class TestGCGPU(TestGC): + device = Device.GPU -if GPU: - class TestGCGPU(TestGC): - gpu = True +@unittest.skipUnless(ANE, "Requires ANE") +class TestGCANE(TestGC): + device=Device.ANE if __name__ == '__main__': unittest.main() diff --git a/test/test_mnist.py b/test/test_mnist.py index a123938d1f84..8edae37f4d1c 100644 --- a/test/test_mnist.py +++ b/test/test_mnist.py @@ -2,10 +2,11 @@ import os import unittest import numpy as np -from tinygrad.tensor import Tensor, GPU +from tinygrad.tensor import Tensor, GPU, Device import tinygrad.optim as optim from extra.training import train, evaluate from extra.utils import fetch, get_parameters +from .config import ANE # mnist loader def fetch_mnist(): @@ -55,32 +56,36 @@ def forward(self, x): return x.dot(self.l1).logsoftmax() class TestMNIST(unittest.TestCase): - gpu=False + device = Device.CPU def test_conv(self): np.random.seed(1337) model = TinyConvNet() optimizer = optim.Adam(model.parameters(), lr=0.001) - train(model, X_train, Y_train, optimizer, steps=200, gpu=self.gpu) - assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95 + train(model, X_train, Y_train, optimizer, steps=200, device=self.device) + assert evaluate(model, X_test, Y_test, device=self.device) > 0.95 def test_sgd(self): np.random.seed(1337) model = TinyBobNet() optimizer = optim.SGD(model.parameters(), lr=0.001) - train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu) - assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95 + train(model, X_train, Y_train, optimizer, steps=1000, device=self.device) + assert evaluate(model, X_test, Y_test, device=self.device) > 0.95 def test_rmsprop(self): np.random.seed(1337) model = TinyBobNet() optimizer = optim.RMSprop(model.parameters(), lr=0.0002) - train(model, X_train, Y_train, optimizer, steps=1000, gpu=self.gpu) - assert evaluate(model, X_test, Y_test, gpu=self.gpu) > 0.95 + train(model, X_train, Y_train, optimizer, steps=1000, device=self.device) + assert evaluate(model, X_test, Y_test, device=self.device) > 0.95 @unittest.skipUnless(GPU, "Requires GPU") class TestMNISTGPU(TestMNIST): - gpu = True + device = Device.GPU + +@unittest.skipUnless(ANE, "Requires ANE") +class TestMNISTANE(TestMNIST): + device=Device.ANE if __name__ == '__main__': unittest.main() diff --git a/test/test_net_speed.py b/test/test_net_speed.py index 94a852d3dccc..8b1ed84a16e5 100644 --- a/test/test_net_speed.py +++ b/test/test_net_speed.py @@ -4,7 +4,8 @@ import 
pstats import unittest import torch -from tinygrad.tensor import Tensor +from tinygrad.tensor import Tensor, GPU, Device +from .config import ANE def start_profile(): import time @@ -20,6 +21,8 @@ def stop_profile(pr, sort='cumtime'): ps.print_stats(0.2) class TestConvSpeed(unittest.TestCase): + device= Device.CPU + def test_mnist(self): # https://keras.io/examples/vision/mnist_convnet/ conv = 3 @@ -62,15 +65,15 @@ def test_mnist(self): # ****** tinygrad compare ******* - c1 = Tensor(c1.detach().numpy()) - c2 = Tensor(c2.detach().numpy()) - l1 = Tensor(l1.detach().numpy()) + c1 = Tensor(c1.detach().numpy(), device=self.device) + c2 = Tensor(c2.detach().numpy(), device=self.device) + l1 = Tensor(l1.detach().numpy(), device=self.device) cnt = 5 fpt, bpt = 0.0, 0.0 for i in range(1+cnt): et0 = time.time() - x = Tensor.randn(128, 1, 28, 28) + x = Tensor.randn(128, 1, 28, 28, device=self.device) x = x.conv2d(c1).relu().avg_pool2d() x = x.conv2d(c2).relu().max_pool2d() x = x.reshape(shape=(x.shape[0], -1)) @@ -91,6 +94,14 @@ def test_mnist(self): print("forward pass: %.3f ms, %.2fx off baseline %.3f ms" % (fpt, fpt/fpt_baseline, fpt_baseline)) print("backward pass: %.3f ms, %.2fx off baseline %.3f ms" % (bpt, bpt/bpt_baseline, bpt_baseline)) +@unittest.skipUnless(GPU, "Requires GPU") +class TestConvSpeedGPU(TestConvSpeed): + device = Device.GPU + +@unittest.skipUnless(ANE, "Requires ANE") +class TestConvSpeedANE(TestConvSpeed): + device=Device.ANE + if __name__ == '__main__': unittest.main() diff --git a/test/test_nn.py b/test/test_nn.py index fa769d3fff10..ba00c7340bb5 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1,10 +1,15 @@ #!/usr/bin/env python import unittest import numpy as np +from tinygrad.tensor import GPU, Device from tinygrad.nn import * +from extra.utils import get_parameters import torch +from .config import ANE class TestNN(unittest.TestCase): + device = Device.CPU + def test_batchnorm2d(self, training=False): sz = 4 @@ -29,13 +34,13 @@ def test_batchnorm2d(self, training=False): np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5) # trial - inn = Tensor.randn(2, sz, 3, 3) + inn = Tensor.randn(2, sz, 3, 3, device=self.device) # in tinygrad outt = bn(inn) # in torch - toutt = tbn(torch.tensor(inn.data)) + toutt = tbn(torch.tensor(inn.cpu().data)) # close np.testing.assert_allclose(outt.data, toutt.detach().numpy(), rtol=5e-5) @@ -48,6 +53,27 @@ def test_batchnorm2d(self, training=False): def test_batchnorm2d_training(self): self.test_batchnorm2d(True) +@unittest.skipUnless(GPU, "Requires GPU") +class TestNNGPU(TestNN): + device = Device.GPU + + @unittest.skip("Tests not added") + def test_batchnorm2d(self): pass + + @unittest.skip("Tests not added") + def test_batchnorm2d_training(self): pass + + +@unittest.skipUnless(ANE, "Requires ANE") +class TestNNANE(TestNN): + device=Device.ANE + + @unittest.skip("Tests not added") + def test_batchnorm2d(self): pass + + @unittest.skip("Tests not added") + def test_batchnorm2d_training(self): pass + if __name__ == '__main__': unittest.main() diff --git a/test/test_ops.py b/test/test_ops.py index 3226e4eafefb..4e283f87e20d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -4,14 +4,17 @@ import unittest import timeit import functools -from tinygrad.tensor import Tensor, GPU +from tinygrad.tensor import Tensor, GPU, Device +from .config import ANE -def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, gpu=False, forward_only=False): +def 
helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False): torch.manual_seed(0) ts = [torch.rand(x, requires_grad=True) for x in shps] tst = [Tensor(x.detach().numpy()) for x in ts] - if gpu: - tst = [x.cuda() for x in tst] + if device==Device.GPU: + tst = [x.gpu() for x in tst] + elif device==Device.ANE: + tst = [x.ane() for x in tst] out = torch_fxn(*ts) ret = tinygrad_fxn(*tst) @@ -23,7 +26,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0 ret.mean().backward() for t, tt in zip(ts, tst): - np.testing.assert_allclose(t.grad, tt.grad.cpu().data, atol=grad_atol, rtol=grad_rtol) + np.testing.assert_allclose(t.grad, tt.cpu().grad.data, atol=grad_atol, rtol=grad_rtol) # speed torch_fp = timeit.Timer(functools.partial(torch_fxn, *ts)).timeit(5) * 1000/5 @@ -38,58 +41,59 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0 print("testing %30r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms" % (shps, torch_fp, tinygrad_fp, torch_fbp-torch_fp, tinygrad_fbp-tinygrad_fp)) class TestOps(unittest.TestCase): - gpu = False + device=Device.CPU + def test_add(self): - helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add, gpu=self.gpu) + helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add, device=self.device) def test_sub(self): - helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub, gpu=self.gpu) + helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub, device=self.device) def test_mul(self): - helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul, gpu=self.gpu) + helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul, device=self.device) def test_div(self): - helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, gpu=self.gpu) + helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, device=self.device) def test_pow(self): - helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, gpu=self.gpu) + helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, device=self.device) def test_sqrt(self): - helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, gpu=self.gpu) + helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, device=self.device) def test_relu(self): - helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu, gpu=self.gpu) + helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu, device=self.device) def test_leakyrelu(self): - helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu, gpu=self.gpu) + helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu, device=self.device) def test_abs(self): - helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs, gpu=self.gpu) + helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs, device=self.device) def test_sigmoid(self): - helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, gpu=self.gpu) + helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, device=self.device) def test_dot(self): - helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, gpu=self.gpu) + helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, device=self.device) def test_sum(self): - helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, gpu=self.gpu) + helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device) def test_sum_axis(self): - helper_test_op([(3,4,5,6)], lambda x: 
x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), gpu=self.gpu) + helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device) def test_mean_axis(self): - helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), gpu=self.gpu) + helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), device=self.device) def test_logsoftmax(self): - helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu) + helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, device=self.device) def test_tanh(self): - helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, gpu=self.gpu) + helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, device=self.device) def test_topo_sort(self): - helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6, gpu=self.gpu) + helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6, device=self.device) def test_scalar_mul(self): - helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2, gpu=self.gpu) + helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2, device=self.device) def test_scalar_rmul(self): - helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x, gpu=self.gpu) + helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x, device=self.device) def test_scalar_sub(self): - helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2, gpu=self.gpu) + helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2, device=self.device) def test_scalar_rsub(self): - helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x, gpu=self.gpu) + helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x, device=self.device) def test_broadcast_full(self): for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul), (torch.div, Tensor.div), (torch.pow, Tensor.pow)]: for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]: with self.subTest(op=torch_op.__name__, shapes=shapes): - helper_test_op(shapes, torch_op, tinygrad_op, gpu=self.gpu) + helper_test_op(shapes, torch_op, tinygrad_op, device=self.device) def test_broadcast_partial(self): @@ -98,17 +102,18 @@ def test_broadcast_partial(self): for shapes in [((1,32,32,32), (1,32,1,1)), ((5,13,24,16,2), (1,13,24,1,1)), ((4,1), (4,5)), ((1,4), (5,4))]: with self.subTest(op=torch_op.__name__, shapes=shapes): - helper_test_op(shapes, torch_op, tinygrad_op, gpu=self.gpu, forward_only=self.gpu) + # NOTE: ANE backwards? 
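For reference, each of these op tests runs on all three backends through the class-attribute pattern at the bottom of this file: the CPU base class defines the cases, and the GPU/ANE subclasses only override `device` behind `@unittest.skipUnless`. A hypothetical extra test written against that pattern might look like the sketch below (the class and test names are illustrative and not part of this patch; `ANE` is the environment flag read in `test/config.py`, and `helper_test_op` is assumed to be the helper defined above in `test/test_ops.py`):

```python
import unittest
from tinygrad.tensor import Tensor, GPU, Device
from .config import ANE

class TestOpsSketch(unittest.TestCase):      # hypothetical example, not part of this patch
  device = Device.CPU
  def test_tanh_again(self):
    # helper_test_op moves the tinygrad inputs to self.device before comparing against torch
    helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh,
                   atol=1e-6, grad_atol=1e-6, device=self.device)

@unittest.skipUnless(GPU, "Requires GPU")
class TestOpsSketchGPU(TestOpsSketch):
  device = Device.GPU

@unittest.skipUnless(ANE, "Requires ANE")
class TestOpsSketchANE(TestOpsSketch):
  device = Device.ANE
```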
+ helper_test_op(shapes, torch_op, tinygrad_op, device=self.device, forward_only=self.device!=Device.CPU) def test_pad2d(self): - helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), gpu=self.gpu) + helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device) def test_reshape(self): - helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)), gpu=self.gpu) - helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)), gpu=self.gpu) + helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)), device=self.device) + helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)), device=self.device) def test_detach(self): - helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), gpu=self.gpu, forward_only=True) + helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), device=self.device, forward_only=True) def test_conv2d(self): for bs in [1,8]: @@ -119,7 +124,7 @@ def test_conv2d(self): with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H, width=W): helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(), - lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), gpu=self.gpu, grad_rtol=1e-5) + lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), device=self.device, grad_rtol=1e-5) def test_strided_conv2d(self): bs = 4 @@ -128,18 +133,18 @@ def test_strided_conv2d(self): with self.subTest(stride := 2): helper_test_op([(bs,cin,11,28), (4,cin,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(), - lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), gpu=self.gpu) + lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), device=self.device) with self.subTest(stride := (2,1)): helper_test_op([(bs,cin,11,28), (4,cin,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(), - lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), gpu=self.gpu) + lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), device=self.device) def test_maxpool2d(self): for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: with self.subTest(kernel_size=ksz): helper_test_op([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz), - lambda x: Tensor.max_pool2d(x, kernel_size=ksz), gpu=self.gpu) + lambda x: Tensor.max_pool2d(x, kernel_size=ksz), device=self.device) def test_avgpool2d(self): shape = (32,2,111,28) @@ -147,11 +152,15 @@ def test_avgpool2d(self): with self.subTest(kernel_size=ksz): helper_test_op([shape], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz), - lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), gpu=self.gpu) + lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), device=self.device) @unittest.skipUnless(GPU, "Requires GPU") class TestOpsGPU(TestOps): - gpu = True + device=Device.GPU + +@unittest.skipUnless(ANE, "Requires ANE") +class TestOpsANE(TestOps): + device=Device.ANE if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/test/test_optim.py b/test/test_optim.py index 0dc1976ce38e..99ddcff57c20 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -1,18 +1,20 @@ import numpy as np import torch import unittest -from tinygrad.tensor import Tensor, GPU +from tinygrad.tensor import Tensor, GPU, Device from tinygrad.optim 
import Adam, SGD, RMSprop from extra.utils import get_parameters +from .config import ANE x_init = np.random.randn(1,3).astype(np.float32) W_init = np.random.randn(3,3).astype(np.float32) m_init = np.random.randn(1,3).astype(np.float32) -def step_tinygrad(optim, kwargs={}, gpu=False): +def step_tinygrad(optim, kwargs={}, device=Device.CPU): net = TinyNet() optim = optim([net.x, net.W], **kwargs) - if gpu is True: [x.cuda_() for x in get_parameters([net, optim])] + if device==Device.GPU: [x.gpu_() for x in get_parameters([net, optim])] + elif device==Device.ANE: [x.ane_() for x in get_parameters([net, optim])] out = net.forward() out.backward() optim.step() @@ -54,20 +56,20 @@ def forward(self): class TestOptim(unittest.TestCase): - gpu = False + device = Device.CPU def test_adam(self): - for x,y in zip(step_tinygrad(Adam, gpu=self.gpu), + for x,y in zip(step_tinygrad(Adam, device=self.device), step_pytorch(torch.optim.Adam)): np.testing.assert_allclose(x, y, atol=1e-4) def test_sgd(self): - for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}, gpu=self.gpu), + for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}, device=self.device), step_pytorch(torch.optim.SGD, kwargs={'lr': 0.001})): np.testing.assert_allclose(x, y, atol=1e-5) def test_rmsprop(self): - for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}, gpu=self.gpu), + for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}, device=self.device), step_pytorch(torch.optim.RMSprop, kwargs={'lr': 0.001, 'alpha': 0.99})): np.testing.assert_allclose(x, y, atol=1e-5) @@ -75,7 +77,11 @@ def test_rmsprop(self): @unittest.skipUnless(GPU, "Requires GPU") class TestOptimGPU(TestOptim): - gpu = True + device = Device.GPU + +@unittest.skipUnless(ANE, "Requires ANE") +class TestOptimANE(TestOptim): + device = Device.ANE if __name__ == '__main__': diff --git a/test/test_tensor.py b/test/test_tensor.py index 3d2955deb65f..f54527b08402 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1,8 +1,10 @@ import numpy as np import torch import unittest -from tinygrad.tensor import Tensor, GPU +from tinygrad.tensor import Tensor, GPU, Device from extra.gradcheck import numerical_jacobian, jacobian, gradcheck +from .config import ANE + x_init = np.random.randn(1,3).astype(np.float32) U_init = np.random.randn(3,3).astype(np.float32) @@ -11,13 +13,13 @@ m_init = np.random.randn(1,3).astype(np.float32) class TestTinygrad(unittest.TestCase): - gpu = False + device = Device.CPU def test_backward_pass(self): def test_tinygrad(): - x = Tensor(x_init, gpu=self.gpu) - W = Tensor(W_init, gpu=self.gpu) - m = Tensor(m_init, gpu=self.gpu) + x = Tensor(x_init, device=self.device) + W = Tensor(W_init, device=self.device) + m = Tensor(m_init, device=self.device) out = x.dot(W).relu() out = out.logsoftmax() out = out.mul(m).add(m).sum() @@ -39,16 +41,16 @@ def test_pytorch(): def test_backward_pass_diamond_model(self): def test_tinygrad(): - u = Tensor(U_init) - v = Tensor(V_init) - w = Tensor(W_init) + u = Tensor(U_init, device=self.device) + v = Tensor(V_init, device=self.device) + w = Tensor(W_init, device=self.device) x = u.mul(v).relu() y = u.mul(w).relu() out = x.add(y).mul(y).relu() out = out.logsoftmax() out = out.sum() out.backward() - return out.data, u.grad.data, v.grad.data, w.grad.data + return out.cpu().data, u.cpu().grad.data, v.cpu().grad.data, w.cpu().grad.data def test_pytorch(): u = torch.tensor(U_init, requires_grad=True) @@ -74,8 +76,8 @@ def test_jacobian(self): torch_func = lambda x: 
torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1) PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy() - tiny_x = Tensor(x, gpu=self.gpu) - tiny_W = Tensor(W, gpu=self.gpu) + tiny_x = Tensor(x, device=self.device) + tiny_W = Tensor(W, device=self.device) tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax() J = jacobian(tiny_func, tiny_x) NJ = numerical_jacobian(tiny_func, tiny_x) @@ -87,8 +89,8 @@ def test_gradcheck(self): W = np.random.RandomState(1337).random((10, 5)) x = np.random.RandomState(7331).random((1, 10)) - 0.5 - tiny_x = Tensor(x, gpu=self.gpu) - tiny_W = Tensor(W, gpu=self.gpu) + tiny_x = Tensor(x, device=self.device) + tiny_W = Tensor(W, device=self.device) tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax() self.assertTrue(gradcheck(tiny_func, tiny_x)) @@ -99,7 +101,7 @@ def test_gradcheck(self): @unittest.skipUnless(GPU, "Requires GPU") class TestTinygradGPU(TestTinygrad): - gpu = True + device = Device.GPU @unittest.skip("float64 not supported on GPU") def test_jacobian(self): pass @@ -107,6 +109,9 @@ def test_jacobian(self): pass @unittest.skip("float64 not supported on GPU") def test_gradcheck(self): pass +@unittest.skipUnless(ANE, "Requires ANE") +class TestOpsANE(TestTinygrad): + device=Device.ANE if __name__ == '__main__': unittest.main() diff --git a/tinygrad/ops_gpu.py b/tinygrad/ops_gpu.py index 32c46c6fb30d..269dbb49ac1a 100644 --- a/tinygrad/ops_gpu.py +++ b/tinygrad/ops_gpu.py @@ -1,5 +1,5 @@ import numpy as np -from .tensor import Function, register, GPUBuffer, Tensor +from .tensor import Function, register, GPUBuffer, Tensor, Device import pyopencl as cl import functools @@ -178,7 +178,7 @@ def backward(ctx, grad_output): grad_x, grad_y = grad_output, grad_output shape_x, shape_y = ctx.saved_tensors return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y), -register('add', Add, device=Tensor.GPU) +register('add', Add, device=Device.GPU) class Sub(Function): @staticmethod @@ -191,7 +191,7 @@ def backward(ctx, grad_output): grad_x, grad_y = grad_output, unary_op(ctx, '-a', grad_output) shape_x, shape_y = ctx.saved_tensors return unbroadcast(ctx, grad_x, shape_x), unbroadcast(ctx, grad_y, shape_y), -register('sub', Sub, device=Tensor.GPU) +register('sub', Sub, device=Device.GPU) class Mul(Function): @staticmethod @@ -205,7 +205,7 @@ def backward(ctx, grad_output): grad_x = binary_op(ctx, 'a*b', y, grad_output) grad_y = binary_op(ctx, 'a*b', x, grad_output) return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape), -register('mul', Mul, device=Tensor.GPU) +register('mul', Mul, device=Device.GPU) class Pow(Function): @staticmethod @@ -221,7 +221,7 @@ def backward(ctx, grad_output): grad_y = binary_op(ctx, 'a*b', grad_output, binary_op(ctx, 'pow(a, (float)b) * log(a);', x, y)) return unbroadcast(ctx, grad_x, x.shape), unbroadcast(ctx, grad_y, y.shape), -register('pow', Pow, device=Tensor.GPU) +register('pow', Pow, device=Device.GPU) class Sum(Function): @staticmethod @@ -237,8 +237,8 @@ def backward(ctx, grad_output): input, axis = ctx.saved_tensors shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))] output = GPUBuffer(shape, hostbuf=grad_output) - return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape)) -register('sum', Sum, device=Tensor.GPU) + return binary_op(ctx, 'a+b', output, buffer_new(ctx, input.shape, zero=True)) +register('sum', Sum, device=Device.GPU) class Dot(Function): @staticmethod @@ -289,7 +289,7 @@ def 
backward(ctx, grad_output): i32(1), msize, isize, i32(1), osize, osize) return grad_input, grad_weight -register('dot', Dot, device=Tensor.GPU) +register('dot', Dot, device=Device.GPU) # ************* simple ops ************* @@ -332,7 +332,7 @@ def backward(ctx, grad_output): i32(oy), i32(ox), i32(iy), i32(ix) ) return ret -register('pad2d', Pad2D, device=Tensor.GPU) +register('pad2d', Pad2D, device=Device.GPU) class Reshape(Function): @staticmethod @@ -347,7 +347,7 @@ def forward(ctx, x, shape): def backward(ctx, grad_output): in_shape, = ctx.saved_tensors return GPUBuffer(in_shape, hostbuf=grad_output) -register('reshape', Reshape, device=Tensor.GPU) +register('reshape', Reshape, device=Device.GPU) # ************* activation ops ************* @@ -361,7 +361,7 @@ def forward(ctx, input): def backward(ctx, grad_output): input, = ctx.saved_tensors return binary_op(ctx, 'a * (b >= 0)', grad_output, input) -register('relu', ReLU, device=Tensor.GPU) +register('relu', ReLU, device=Device.GPU) class Sigmoid(Function): @staticmethod @@ -374,7 +374,7 @@ def forward(ctx, input): def backward(ctx, grad_output): ret, = ctx.saved_tensors return binary_op(ctx, 'a * (b * (1 - b));', grad_output, ret) -register('sigmoid', Sigmoid, device=Tensor.GPU) +register('sigmoid', Sigmoid, device=Device.GPU) class AvgPool2D(Function): @staticmethod @@ -389,7 +389,7 @@ def backward(ctx, grad_output): orig_shape, = ctx.saved_tensors return supersample_op(ctx, grad_output, orig_shape, ctx.kernel_size, result_op="input[iid] / (ksz.x * ksz.y)") -register('avg_pool2d', AvgPool2D, device=Tensor.GPU) +register('avg_pool2d', AvgPool2D, device=Device.GPU) class MaxPool2D(Function): @staticmethod @@ -409,7 +409,7 @@ def backward(ctx, grad_output): result_op="(maxidx == kernidx) * input[iid]", decls="int maxidx=((__global float*)input2)[iid]; int kernidx=(gid.x%ksz.x) + ksz.x*(gid.y%ksz.y)", input2=idxs) -register('max_pool2d', MaxPool2D, device=Tensor.GPU) +register('max_pool2d', MaxPool2D, device=Device.GPU) class LogSoftmax(Function): @staticmethod @@ -426,7 +426,7 @@ def backward(ctx, grad_output): lsum = reduce_op(ctx, "out += a", "out", grad_output, axis=[1]) texp = binary_op(ctx, "exp(a) * b", output, lsum) return binary_op(ctx, "a - b", grad_output, texp) -register('logsoftmax', LogSoftmax, device=Tensor.GPU) +register('logsoftmax', LogSoftmax, device=Device.GPU) # ************* conv ops ************* @@ -553,4 +553,4 @@ def backward(ctx, grad_output): convw(ctx.cl_queue, [ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *conv_args) convx(ctx.cl_queue, [bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *conv_args) return dx, dw -register('conv2d', Conv2D, device=Tensor.GPU) +register('conv2d', Conv2D, device=Device.GPU) diff --git a/tinygrad/optim.py b/tinygrad/optim.py index eebe670ae09b..e793f7741ba7 100644 --- a/tinygrad/optim.py +++ b/tinygrad/optim.py @@ -25,7 +25,7 @@ def __init__(self, params, lr=0.001, decay=0.9, eps=1e-8): super(RMSprop, self).__init__(params) self.lr, self.decay, self.eps = lr, decay, eps - self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), gpu=params[0].gpu, requires_grad=False) for t in self.params] + self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), device=params[0].device, requires_grad=False) for t in self.params] def step(self): for i, t in enumerate(self.params): @@ -37,8 +37,8 @@ def __init__(self, params, lr=0.001, b1=0.9, b2=0.999, eps=1e-8): super(Adam, self).__init__(params) self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, 0 - self.m = 
[Tensor(np.zeros(t.shape, dtype=np.float32), gpu=params[0].gpu, requires_grad=False) for t in self.params] - self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), gpu=params[0].gpu, requires_grad=False) for t in self.params] + self.m = [Tensor(np.zeros(t.shape, dtype=np.float32), device=params[0].device, requires_grad=False) for t in self.params] + self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), device=params[0].device, requires_grad=False) for t in self.params] def step(self): self.t = self.t + 1 diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 730f0b13dd3c..14c5643c1607 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1,5 +1,6 @@ # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py from inspect import signature +import functools import numpy as np import os from collections import defaultdict @@ -34,6 +35,7 @@ def __exit__(self, *junk): cl_ctx, cl_queue = None, None def require_init_gpu(): + if not GPU: raise Exception("No GPU Support, install pyopencl") global cl_ctx, cl_queue if cl_queue is None: devices = cl.get_platforms()[0].get_devices(device_type=cl.device_type.GPU) @@ -64,33 +66,16 @@ def require_init_ane(): # **** start with two base classes, Tensor and Function **** +class Device: CPU, GPU, ANE = 0, 1, 2 + class Tensor: did_float_warning = False ops = defaultdict(dict) - CPU, GPU, ANE = 0, 1, 2 - - def __init__(self, data, gpu=None, requires_grad=True): - if "ANETensor" in str(type(data)): - self.device = Tensor.ANE - elif isinstance(data, list): - data = np.array(data, dtype=np.float32) - elif GPU and isinstance(data, GPUBuffer): - self.device = Tensor.GPU - elif not isinstance(data, np.ndarray): - raise TypeError(f"Error constructing tensor with {data!r}") - - if isinstance(data, np.ndarray): - if data.dtype != np.float32 and not Tensor.did_float_warning: - # warning? 
float64 is actually needed for numerical jacobian - print(f"warning, {data.shape!r} isn't float32") - Tensor.did_float_warning = True - self.device = Tensor.CPU - - self.data, self.grad, self.requires_grad = data, None, requires_grad + def __init__(self, data, device=Device.CPU, requires_grad=True): + self.data = self._move_data(data, device) - if gpu: - self.cuda_() + self.device, self.grad, self.requires_grad = device, None, requires_grad # internal variables used for autograd graph construction self._ctx = None @@ -145,7 +130,7 @@ def backward(self): # fill in the first grad with one # this is "implicit gradient creation" - self.grad = Tensor(np.ones(self.shape, dtype=self.dtype), gpu=self.gpu, requires_grad=False) + self.grad = Tensor(np.ones(self.shape, dtype=self.dtype), device=self.device, requires_grad=False) for t0 in reversed(self.deepwalk(set(), [])): assert (t0.grad is not None) @@ -157,53 +142,59 @@ def backward(self): if g is not None: assert g.shape == t.shape, \ f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}" - gt = Tensor(g, requires_grad=False) + gt = Tensor(g, device=self.device, requires_grad=False) t.grad = gt if t.grad is None else (t.grad + gt) # ***** tinygrad supports CPU and GPU ***** - def cpu(self): - if self.device == Tensor.GPU: - with ProfileOp("toCPU", [self]): - ret = Tensor(np.empty(self.shape, dtype=np.float32), gpu=False) - cl.enqueue_copy(cl_queue, ret.data, self.data.cl, is_blocking=True) - if self.grad: - ret.grad = self.grad.cpu() - return ret - elif self.device == Tensor.ANE: - return Tensor(self.data.data().astype(np.float32), gpu=False) - else: - return self + @staticmethod + def _move_data(data, device): + if isinstance(data, GPUBuffer): + if device == Device.GPU: return data + old = data + data = np.empty(old.shape, dtype=np.float32) + with ProfileOp("toCPU", [data]): + cl.enqueue_copy(cl_queue, data, old.cl, is_blocking=True) + + elif "ANETensor" in str(type(data)): + if device == Device.ANE: return data + with ProfileOp("toCPU", [data]): + data = data.data().astype(np.float32) + + if not isinstance(data, np.ndarray): + data = np.array(data, dtype=np.float32) + + if data.dtype != np.float32 and not Tensor.did_float_warning: + # warning? 
float64 is actually needed for numerical jacobian + print(f"warning, {data.shape!r} isn't float32") + Tensor.did_float_warning = True + + if device == Device.GPU: + require_init_gpu() + with ProfileOp("toGPU", [data]): + return GPUBuffer(data.shape, data) + + elif device == Device.ANE: + require_init_ane() + with ProfileOp("toANE", [data]): + ndata = ane.tensor(data.shape) + ndata.data()[:] = data + return ndata + return data + + def to_(self, device): + self.data, self.device = self._move_data(self.data, device), device + if self.grad: self.grad.to_(device) + + def to(self, device): + ret = Tensor(self.data, device) + if self.grad: ret.grad = self.grad.to(device) + return ret - @property - def gpu(self): - return self.device == Tensor.GPU - - def cuda_(self): - self.data = self.cuda().data - self.device = Tensor.GPU - - def cuda(self): - if not GPU: - raise Exception("No GPU Support, install pyopencl") - if not self.gpu: - with ProfileOp("toGPU", [self]): - require_init_gpu() - ret = Tensor(GPUBuffer(self.shape, self.data)) - if self.grad: - ret.grad = self.grad.cuda() - return ret - return self - - def ane(self): - assert(not self.gpu) - require_init_ane() - ndata = ane.tensor(self.shape) - ndata.data()[:] = self.data - return Tensor(ndata) + def _is(self, device): return self.device == device def detach(self): - return Tensor(self.data, self.gpu) + return Tensor(self.data, device=self.device) # ***** non first class ops ***** @@ -232,7 +223,7 @@ def leakyrelu(self, neg_slope=0.01): def dropout(self, p=0.5): _mask = np.asarray(np.random.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype) - ret = self * Tensor(_mask, requires_grad=False, gpu=self.gpu) + ret = self * Tensor(_mask, requires_grad=False, device=self.device) return ret.div(1.0 - p) def abs(self): @@ -259,18 +250,18 @@ def apply(self, *x, **kwargs): setattr(ctx, k, v) with ProfileOp(ctx.__class__.__name__, x): ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs), - requires_grad=any([t.requires_grad for t in x])) + device=ctx.device, requires_grad=any([t.requires_grad for t in x])) if ret.requires_grad: ret._ctx = ctx return ret -def register(name, fxn, device=Tensor.CPU): +def register(name, fxn, device=Device.CPU): Tensor.ops[device][name] = fxn def dispatch(*x, **kwargs): tt = [arg for arg in x if isinstance(arg, Tensor)][0] - x = [Tensor(np.array([arg], dtype=tt.dtype), gpu=tt.gpu, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x] + x = [Tensor(np.array([arg], dtype=tt.dtype), device=tt.device, requires_grad=False) if not isinstance(arg, Tensor) else arg for arg in x] f = (Tensor.ops[tt.device])[name] - f.cl_ctx, f.cl_queue, f.ane = cl_ctx, cl_queue, ane + f.cl_ctx, f.cl_queue, f.ane, f.device = cl_ctx, cl_queue, ane, tt.device return f.apply(f, *x, **kwargs) setattr(Tensor, name, dispatch) # TODO: div is a second class op, so it doesn't work here @@ -279,6 +270,11 @@ def dispatch(*x, **kwargs): setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x))) setattr(Tensor, f"__r{name}__", lambda self,x: dispatch(x,self)) +for device in [device for device in Device.__dict__.keys() if device[0] != "_"]: + setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, Device.__dict__[device])) + setattr(Tensor, f"{device.lower()}_", functools.partialmethod(Tensor.to_, Device.__dict__[device])) + setattr(Tensor, f"is_{device.lower()}", property(functools.partialmethod(Tensor._is, Device.__dict__[device]))) + # this registers all the operations import tinygrad.ops_cpu 
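The `functools.partialmethod` loop above is what turns the single `to`/`to_` pair into per-device helpers, so `cuda()`/`cuda_()` from the old API become `gpu()`/`gpu_()` (plus `cpu()`/`cpu_()` and `ane()`/`ane_()`) without writing each method by hand. A minimal usage sketch of the resulting API, assuming pyopencl and an OpenCL device are available for the GPU path:

```python
from tinygrad.tensor import Tensor, Device

t = Tensor([[1.0, 2.0], [3.0, 4.0]])      # constructor defaults to Device.CPU
g = t.gpu()                               # generated from Tensor.to: returns a GPU copy
t.gpu_()                                  # generated from Tensor.to_: moves t (and t.grad) in place
out = (g + g).cpu()                       # bring the result back before reading .data
print(out.data)

w = Tensor.zeros(3, 3, device=Device.GPU) # tensor-creating helpers forward device= as well
```

The in-place forms are what `train()` and `step_tinygrad()` above use to move already-created parameters, while `evaluate()` and the tests construct tensors on the target device directly through the `device=` argument.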
try: