In [1]:
from torchvision import datasets
from tempfile import gettempdir
from uuid import uuid1
import os
import json
import random
from tqdm.auto import trange
import torch

In [2]:
# Fetch both CIFAR-10 splits into the system temp directory; torchvision skips the
# download when verified archives are already present there.
cifar_root = gettempdir()
ds_train = datasets.CIFAR10(cifar_root, train=True, download=True)
ds_test = datasets.CIFAR10(cifar_root, train=False, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# create labels
# CIFAR-10 class names in label-index order; `class_to_id` is the inverse mapping.
labels = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck"
]
# labels[idx] names each exported image's class, so it must agree with torchvision's order.
assert labels == ds_train.classes, "hard-coded labels do not match ds_train.classes"
class_to_id = {name: idx for idx, name in enumerate(labels)}

# Export the training split as individual image files, each under a fresh uuid1 name
# with a randomly picked format (.png/.jpg/.tif); `truth` maps file path -> class name.
target_dir = os.path.join(gettempdir(), "cifar10-ByTe-train")
truth_path = os.path.join(target_dir, "truth.json")

# The cache is keyed on truth.json rather than on the directory: truth.json is only
# written after every image has been saved, so an interrupted export (directory present,
# index missing) is rebuilt instead of crashing in json.load on every re-run.
if not os.path.exists(truth_path):
    print("Creating dataset")
    os.makedirs(target_dir, exist_ok=True)

    ext_rng = random.Random(0)  # seeded so the mix of image formats is reproducible
    truth = {}
    for idx in trange(len(ds_train)):
        image, label_idx = ds_train[idx]
        img_path = os.path.join(target_dir, str(uuid1()) + ext_rng.choice([".png", ".jpg", ".tif"]))
        truth[img_path] = labels[label_idx]
        image.save(img_path)

    # Write to a temporary file and rename it into place so a crash mid-write can
    # never leave a truncated truth.json that would later be treated as a valid cache.
    tmp_path = truth_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(truth, f)
    os.replace(tmp_path, truth_path)
else:
    print("Loading pre-prepared dataset")
    with open(truth_path, "r") as f:
        truth = json.load(f)

Loading pre-prepared dataset


In [4]:
from gperc import Consumer, BinaryConfig, Perceiver, ArrowConsumer

In [5]:
# Wrap the exported files listed in `truth` in a gperc ArrowConsumer with n_bytes=1
# (presumably one byte per token, matching vocab_size 257 = 256 bytes + EOF);
# class names are presumably mapped to ids through `class_to_id` — confirm in gperc.
data = ArrowConsumer(truth, class_to_id=class_to_id, n_bytes=1)

Using custom data configuration 804009c286e1c7b157519e9cf95893bc5b2b5c6c132b210f8a0935fab09e4417
Reusing dataset binary_arrow_builder (/Users/yashbonde/.cache/huggingface/datasets/binary_arrow_builder/804009c286e1c7b157519e9cf95893bc5b2b5c6c132b210f8a0935fab09e4417/0.0.0)


  0%|          | 0/1 [00:00<?, ?it/s]

Tokenising the entire datset, this can takes some time (will use all cores) ...


In [6]:
# Show the consumer's summary: total samples, mode, n_classes, n_bytes, seqlen, vocab size and style.
data

<gperc ArrowConsumer {
  "total_samples": 50000,
  "mode": "F2",
  "n_classes": 10,
  "n_bytes": 1,
  "seqlen": 3323,
  "vocab_size": 257,
  "style": "diff"
}>

In [7]:
# A single sample's token ids come back with a leading batch dimension: (1, seqlen).
data[0]["input_array"].shape

torch.Size([1, 3323])

In [8]:
# Decode the first sample back into raw bytes to eyeball its header (file path, then
# the image container bytes). The final token is <EOF>, whose id lies outside the
# 0-255 byte range, so it is dropped first; bytearray() would raise ValueError otherwise.
first_ids = data[0]["input_array"][0].tolist()
bytearray(first_ids[:-1])[:250]

bytearray(b'/var/folders/dg/b9jch2h97kj2qbcsxj7kb6rc0000gn/T/cifar10-ByTe-train/0f344d5a-47ad-11ec-b801-1e00ea1e7259.tif|II*\x00\x08\x00\x00\x00\n\x00\x00\x01\x04\x00\x01\x00\x00\x00 \x00\x00\x00\x01\x01\x04\x00\x01\x00\x00\x00 \x00\x00\x00\x02\x01\x03\x00\x03\x00\x00\x00\x86\x00\x00\x00\x03\x01\x03\x00\x01\x00\x00\x00\x01\x00\x00\x00\x06\x01\x03\x00\x01\x00\x00\x00\x02\x00\x00\x00\x11\x01\x04\x00\x01\x00\x00\x00\x8c\x00\x00\x00\x15\x01\x03\x00\x01\x00\x00\x00\x03\x00\x00\x00\x16\x01\x04\x00\x01\x00\x00\x00 \x00\x00\x00\x17\x01\x04\x00\x01\x00\x00\x00\x00\x0c\x00\x00\x1c\x01\x03\x00\x01\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x08\x00\x08\x00\x08\x00\xca')

In [9]:
# NOTE(review): indexing with `None` is ArrowConsumer-specific; its output is one batch-shaped
# dict (input_array, attention_mask, class) — confirm which sample(s) it selects.
data[None]

{'input_array': tensor([[ 47, 118,  97,  ..., 243,   0, 256]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]]),
 'class': tensor([0])}

In [10]:
# create the model
config = BinaryConfig(
    seqlen = data.seqlen,
    vocab_size = data.vocab_size,
    latent_dim = 32,
    eot_token = data.EOF_ID,
    n_classes = len(class_to_id),
    ffw_ratio=1.0
)
model = Perceiver(config)
model.num_parameters()

257962

In [11]:
# Smoke test: push the first three samples through the model (no gradients) and
# print the class names of the arg-max logits.
with torch.no_grad():
    out = model(data[[0, 1, 2]]["input_array"])
    pred_ids = out.argmax(-1).tolist()
    pred_names = [labels[i] for i in pred_ids]
    print(pred_names)

['automobile', 'bird', 'bird']


In [12]:
BATCH_SIZE = 32  # batch size handed to ArrowConsumer.create_batches
data.create_batches(BATCH_SIZE)

In [16]:
from time import perf_counter, time


class Timer:
    """Context manager that prints how long its body took.

    On exit it prints ``[<seconds>s]<TAB><msg>`` with two decimals, even if the body
    raised (the exception still propagates). The duration is also stored on
    ``self.elapsed`` so it can be read via ``with Timer(...) as t: ...; t.elapsed``.

    Args:
        msg: label printed after the duration.
    """

    def __init__(self, msg=""):
        self.msg = msg
        self.elapsed = None

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time() follows the wall clock
        # and can jump when the system clock is adjusted.
        self.st = perf_counter()
        return self

    def __exit__(self, *exc_info):
        self.elapsed = perf_counter() - self.st
        print(f"[{self.elapsed:.2f}s]\t{self.msg}")
        # implicit None return: exceptions from the body are not suppressed
        
    

# training the model
# Each iteration draws ONE batch from `data`, so `n_epochs` is really the number of
# optimisation steps. Loss and batch accuracy are shown in the progress bar.
from contextlib import contextmanager

optim = torch.optim.Adam(model.parameters(), 0.001)
n_epochs = 1000
# Per-phase Timer lines are printed only for the first few steps; printing them on
# every step would flood the cell output with thousands of lines.
n_profiled_steps = 5
pbar = trange(n_epochs)
device = "cpu" if not torch.cuda.is_available() else "cuda:0"
model = model.to(device)
model.train()  # explicit, in case an earlier cell switched the model to eval mode

all_losses = []
all_acc = []


@contextmanager
def _phase(msg, step):
    """Time one phase of a training step with Timer, for the first `n_profiled_steps` steps only.

    CUDA kernels run asynchronously, so on GPU we wait for queued work before the
    timer stops; otherwise the forward/optimiser timings would only cover launches.
    """
    if step >= n_profiled_steps:
        yield
        return
    with Timer(msg):
        yield
        if device.startswith("cuda"):
            torch.cuda.synchronize(torch.device(device))


for epoch in pbar:
    with _phase("Getting data", epoch):
        batch = data.get_next_batch()

    # move every tensor of the batch to the training device (labels included)
    batch = {k: v.to(device) for k, v in batch.items()}
    target = batch["class"]

    with _phase("Forward pass", epoch):
        out = model(batch["input_array"])

    with _phase("Optim Step", epoch):
        loss = torch.nn.functional.cross_entropy(out, target)
        optim.zero_grad()
        loss.backward()
        optim.step()

    all_losses.append(loss.item())
    acc = out.argmax(-1).eq(target).float().mean()
    all_acc.append(acc.item())
    pbar.set_description(f"loss: {all_losses[-1]:.5f} | acc: {all_acc[-1]:.2f}")

    # Stop once a whole batch is classified perfectly. NOTE(review): this is a
    # single-batch signal (each step presumably sees a different batch), not proof
    # that the training set has been memorised; averaging recent batches is stricter.
    if all_acc[-1] == 1.0:
        break

  0%|          | 0/1000 [00:00<?, ?it/s]

[0.03s]	Getting data
[1.30s]	Forward pass
[0.67s]	Optim Step
[0.02s]	Getting data
[1.17s]	Forward pass
[0.59s]	Optim Step
[0.02s]	Getting data
[1.17s]	Forward pass
[0.59s]	Optim Step
[0.02s]	Getting data
[1.18s]	Forward pass
[0.58s]	Optim Step


KeyboardInterrupt: 