In [2]:
from tinygrad import Tensor, nn
from tinygrad.helpers import trange
import pickle
import numpy as np

In [3]:
def unpickle(file):
    with open(file, "rb") as f:
        data = pickle.load(f, encoding="bytes")
    return data

In [71]:
H = W = 32
C = 3
batches = [unpickle(f"data/data_batch_{i}") for i in range(1, 6)]
X_train = Tensor.cat(*[Tensor(batch[b"data"].reshape(10000, C, H, W)) for batch in batches])
print(X_train)
Y_train = Tensor.cat(*[Tensor(batch[b"labels"]) for batch in batches])
print(Y_train)

data = unpickle("data/test_batch")
X_test = Tensor(data[b"data"].reshape((10000, C, H, W)))
Y_test = Tensor(np.array(data[b"labels"]))

<Tensor <UOp METAL (50000, 3, 32, 32) uchar (<Ops.ADD: 45>, None)> on METAL with grad None>
<Tensor <UOp METAL (50000,) int (<Ops.ADD: 45>, None)> on METAL with grad None>


In [73]:
P = 16

class MSA:
    def __init__(self, embed_dim):
        self.embed_dim = embed_dim
        self.l_Q = nn.Linear(embed_dim, embed_dim)
        self.l_K = nn.Linear(embed_dim, embed_dim)
        self.l_V = nn.Linear(embed_dim, embed_dim)
    
    def __call__(self, x):
        Q = self.l_Q(x)
        K = self.l_K(x)
        V = self.l_V(x)
        return Q.scaled_dot_product_attention(K, V)

class MLP:
    def __init__(self, embed_dim, ff_dim):
        self.ff1 = nn.Linear(embed_dim, ff_dim)
        self.ff2 = nn.Linear(ff_dim, embed_dim)
    
    def __call__(self, x):
        return self.ff2(self.ff1(x).gelu().dropout()).gelu().dropout()

class EncoderBlock:
    def __init__(self, embed_dim):
        self.msa = MSA(embed_dim)
        self.mlp = MLP(embed_dim, 4 * embed_dim)
    
    def __call__(self, x):
        x = x + self.msa(x.layernorm())
        x = x + self.mlp(x.layernorm())
        return x

class ViT:
    def __init__(self, embed_dim, n_classes):
        self.embed_dim = embed_dim
        self.embedding = nn.Conv2d(3, embed_dim, kernel_size=16, stride=16)
        self.cls_token = Tensor.uniform(1, 1, embed_dim, requires_grad=True)
        self.pos_embedding = Tensor.uniform(1, (H // P) * (W // P) + 1, embed_dim, requires_grad=True)
        self.encoder = EncoderBlock(embed_dim)
        self.mlp_head = nn.Linear(embed_dim, n_classes)

    def __call__(self, x):
        ce = Tensor.zeros(x.shape[0], 1, self.embed_dim) + self.cls_token
        x = self.embedding(x).flatten(2).transpose(1, 2)
        x = ce.cat(x, dim=1)
        x = x + self.pos_embedding
        x = self.encoder(x)
        x = self.mlp_head(x[:,-1,:])
        return x

BS = 512
model = ViT(embed_dim=192, n_classes=10)
opt = nn.optim.Adam(nn.state.get_parameters(model))

@Tensor.train()
def train_step():
    opt.zero_grad()
    samples = Tensor.randint(BS, high=X_train.shape[0])
    loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
    opt.step()
    return loss

@Tensor.test()
def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100

test_acc = float('nan')
for i in (t:=trange(1000)):
    loss = train_step()
    if i%10 == 9: test_acc = get_test_acc().item()
    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")

loss:   2.23 test_accuracy: 34.48%: 100%|███| 1000/1000 [03:05<00:00,  5.40it/s]
