In [1]:
from tinygrad import Device

# Report which backend tinygrad selected as its default device.
default_device = Device.DEFAULT
print(default_device)

METAL


In [2]:
from tinygrad import Tensor, nn

class Model:
    """Small convolutional classifier: two conv -> ReLU -> max-pool stages and a linear head."""

    def __init__(self):
        # Layers are created in the original order so that parameter
        # creation/initialisation order is unchanged.
        self.l1 = nn.Conv2d(1, 32, kernel_size=(3, 3))
        self.l2 = nn.Conv2d(32, 64, kernel_size=(3, 3))
        # NOTE(review): 1600 presumably equals 64 channels * 5 * 5 for 28x28 inputs - confirm.
        self.l3 = nn.Linear(1600, 10)

    def __call__(self, x: Tensor) -> Tensor:
        """Run the forward pass on `x` and return the output of the final linear layer."""
        pooled = self.l1(x).relu().max_pool2d((2, 2))
        pooled = self.l2(pooled).relu().max_pool2d((2, 2))
        # Flatten everything after the first axis, then apply dropout with rate 0.5.
        features = pooled.flatten(1).dropout(0.5)
        return self.l3(features)

In [3]:
from tinygrad.nn.datasets import mnist

# Load the dataset once, then unpack the four returned tensors.
mnist_splits = mnist()
X_train, Y_train, X_test, Y_test = mnist_splits

# Sanity check: show shape and dtype of the training inputs and labels.
print(X_train.shape, X_train.dtype, Y_train.shape, Y_train.dtype)

(60000, 1, 28, 28) dtypes.uchar (60000,) dtypes.uchar


In [4]:
# Build a fresh model and measure its accuracy on the test set before any training.
model = Model()
predicted_labels = model(X_test).argmax(axis=1)
acc = (predicted_labels == Y_test).mean()
print(acc.item())

0.09309999644756317


In [6]:
# Collect the model's parameters and hand them to an Adam optimizer.
model_params = nn.state.get_parameters(model)
optim = nn.optim.Adam(model_params)

# Number of samples drawn per call to step().
batch_size = 128
def step():
    """Run one optimisation step on a randomly drawn mini-batch and return the loss tensor."""
    # Set the global training flag before the forward pass.
    Tensor.training = True

    # Draw `batch_size` random indices bounded by the number of training rows.
    batch_idx = Tensor.randint(batch_size, high=X_train.shape[0])
    batch_x = X_train[batch_idx]
    batch_y = Y_train[batch_idx]

    optim.zero_grad()
    logits = model(batch_x)
    # The value returned by .backward() is what the original bound to `loss`; keep that.
    loss = logits.sparse_categorical_crossentropy(batch_y).backward()
    optim.step()
    return loss

In [7]:
import timeit

# Time five independent single calls of the plain (non-JIT) step function.
timeit.Timer(step).repeat(repeat=5, number=1)

[1.4108480419999978,
 0.2443342919996212,
 0.07059229199876427,
 0.0699064170003112,
 0.06788008300100046]

In [8]:
from tinygrad import GlobalCounters, Context

# Clear the global counters, then run a single step with DEBUG=2 so that
# per-kernel information is printed for just this call.
GlobalCounters.reset()
with Context(DEBUG=2):
    step()

scheduled 49 kernels
*** METAL      1 E_[90mn11[0m                                     arg  1 mem  0.06 GB tm      7.37us/     0.01ms (     0.00 GFLOPS    0.0|0.0     GB/s) ['__imul__']
*** METAL      2 E_[90mn12[0m                                     arg  1 mem  0.06 GB tm      6.83us/     0.01ms (     0.00 GFLOPS    0.0|0.0     GB/s) ['__imul__']
*** METAL      3 E_[90mn6[0m                                      arg  1 mem  0.06 GB tm      6.37us/     0.02ms (     0.00 GFLOPS    0.0|0.0     GB/s) ['randint']
*** METAL      4 r_[34m625[0m[90m_[0m[36m32[0m[90m_[0m[31m15000[0m[90m_[0m[33m3[0m[90m_[0m[35m4[0m[90m[0m                        arg  1 mem  0.06 GB tm      8.00us/     0.03ms (    57.50 GFLOPS   30.0|30.0    GB/s) ['__getitem__']
*** METAL      5 r_[34m5[0m[90m_[0m[36m2[0m[90m_[0m[35m10[0m[90mn1[0m                                arg  1 mem  0.06 GB tm      5.58us/     0.03ms (     0.06 GFLOPS    0.0|0.0     GB/s) ['sparse_categorical_crossent

In [10]:
from tinygrad import TinyJit
# Wrap the training step in TinyJit; the later cells time and train through
# this wrapped callable rather than calling step() directly.
jit_step = TinyJit(step)

In [11]:
import timeit

# Time five independent single calls of the TinyJit-wrapped step.
timeit.Timer(jit_step).repeat(repeat=5, number=1)

[0.14869920800083491,
 0.07247608400030003,
 0.0030040419987926725,
 0.04559262499969918,
 0.022407042000850197]

In [12]:
# Use `i` rather than `step` as the loop counter: rebinding `step` here would
# replace the step() function defined earlier with an int, so re-running the
# earlier timing/debug cells (or re-wrapping it in TinyJit) would fail.
for i in range(7000):
    loss = jit_step()

    # Every 100 iterations, report the latest loss and the test-set accuracy.
    if i % 100 == 0:
        # Clear the global training flag before evaluating on the test set.
        Tensor.training = False
        acc = (model(X_test).argmax(axis=1) == Y_test).mean().item()
        print(f"step {i:4d}, loss {loss.item():.2f}, acc {acc*100.:.2f}%")

step    0, loss 3.50, acc 73.39%
step  100, loss 0.43, acc 95.36%
step  200, loss 0.28, acc 96.91%
step  300, loss 0.11, acc 97.34%
step  400, loss 0.09, acc 97.62%
step  500, loss 0.04, acc 97.63%
step  600, loss 0.21, acc 97.61%
step  700, loss 0.06, acc 97.75%
step  800, loss 0.16, acc 97.84%
step  900, loss 0.09, acc 98.06%
step 1000, loss 0.22, acc 98.35%
step 1100, loss 0.16, acc 98.21%
step 1200, loss 0.13, acc 98.27%
step 1300, loss 0.07, acc 98.39%
step 1400, loss 0.06, acc 98.49%
step 1500, loss 0.14, acc 98.07%
step 1600, loss 0.02, acc 98.35%
step 1700, loss 0.10, acc 98.57%
step 1800, loss 0.16, acc 98.56%
step 1900, loss 0.17, acc 98.59%
step 2000, loss 0.08, acc 98.48%
step 2100, loss 0.04, acc 98.47%
step 2200, loss 0.06, acc 98.62%
step 2300, loss 0.01, acc 98.50%
step 2400, loss 0.08, acc 98.68%
step 2500, loss 0.13, acc 98.64%
step 2600, loss 0.08, acc 98.75%
step 2700, loss 0.03, acc 98.77%
step 2800, loss 0.08, acc 98.64%
step 2900, loss 0.18, acc 98.62%
step 3000,