In [1]:
from torchvision import datasets
from tempfile import gettempdir
from uuid import uuid1
import os
import json
import random
from tqdm.auto import trange
import torch

In [2]:
# Download (or reuse the cached copy of) CIFAR-10 under the system temp directory.
data_root = gettempdir()
ds_train = datasets.CIFAR10(data_root, train=True, download=True)
ds_test = datasets.CIFAR10(data_root, train=False, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# CIFAR-10 class names. Their order must match the dataset's integer targets,
# because integer labels are translated to names by indexing into this list.
labels = list((
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
))
# Reverse lookup: class name -> integer id.
class_to_id = dict(zip(labels, range(len(labels))))

# Materialise the CIFAR-10 training split as one image file per sample, each saved
# with a randomly chosen extension (.png / .jpg / .tif), plus a `truth.json`
# manifest mapping every file path to its class name. The result is cached in the
# temp dir and reused on later runs.
target_dir = os.path.join(gettempdir(), "cifar10-ByTe-train")
truth_path = os.path.join(target_dir, "truth.json")
FORMAT_SEED = 4  # seeds the extension choice so a regenerated dataset uses the same formats

# Guard on the manifest rather than the directory: an interrupted creation run
# leaves the directory behind without `truth.json`, which would otherwise make the
# load branch fail with FileNotFoundError on every subsequent run.
if not os.path.exists(truth_path):
    print("Creating dataset")
    os.makedirs(target_dir, exist_ok=True)

    format_rng = random.Random(FORMAT_SEED)  # local RNG: leaves the global `random` state untouched
    truth = {}
    for idx in trange(len(ds_train)):
        img, label_id = ds_train[idx]
        fp = os.path.join(target_dir, str(uuid1()) + format_rng.choice([".png", ".jpg", ".tif"]))
        truth[fp] = labels[label_id]
        img.save(fp)

    # Write to a temporary file and atomically rename it, so a crash mid-write can
    # never leave a truncated manifest that the load branch would then trust.
    tmp_path = truth_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(truth, f)
    os.replace(tmp_path, truth_path)
else:
    print("Loading pre-prepared dataset")
    with open(truth_path, "r") as f:
        truth = json.load(f)

Loading pre-prepared dataset


In [4]:
from gperc import Consumer, BinaryConfig, Perceiver

In [5]:
# Wrap the {file path: class name} manifest in a gperc Consumer; class names are
# mapped to integer ids through `class_to_id`, and n_bytes=1 is the token width
# (presumably one byte per token — confirm against gperc docs).
data = Consumer(truth, class_to_id=class_to_id, n_bytes=1)

[2021-11-17 19:23:08] Creating metadata


  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

In [6]:
# Show the Consumer summary: total samples, mode, number of classes, seqlen, vocab size.
data

<gperc Consumer {
  "total_samples": 50000,
  "mode": null,
  "n_classes": 10,
  "n_bytes": 1,
  "seqlen": 3447,
  "vocab_size": 257,
  "style": "diff"
}>

In [7]:
# A single indexed sample carries a leading dimension of 1 followed by `seqlen` token ids.
data[0]["input_array"].shape

torch.Size([1, 3447])

In [8]:
# Decode the start of sample 0 back to raw bytes to inspect the file header.
# Only the first 250 tokens are needed, so slice before converting to a list.
# Keep only ids in the byte range 0-255: special tokens such as the EOF id
# (`data.EOF_ID`) and any padding cannot be represented in a bytearray.
header_ids = data[0]["input_array"][0, :250].tolist()
bytearray(i for i in header_ids if 0 <= i < 256)

bytearray(b'/var/folders/dg/b9jch2h97kj2qbcsxj7kb6rc0000gn/T/cifar10-ByTe-train/0f31ac58-47ad-11ec-b801-1e00ea1e7259.tif: TIFF image data, little-endian, direntries=10, height=32, bps=134, compression=none, PhotometricIntepretation=RGB, width=32\nII*\x00\x08\x00\x00\x00\n\x00\x00\x01\x04\x00\x01\x00')

In [9]:
# Switch the Consumer into supervised mode; the repr should now report "mode": "supervised".
data.set_supervised_mode()
data

<gperc Consumer {
  "total_samples": 50000,
  "mode": "supervised",
  "n_classes": 10,
  "n_bytes": 1,
  "seqlen": 3447,
  "vocab_size": 257,
  "style": "diff"
}>

In [11]:
# Fetch sample 0 explicitly in supervised form: input_array and attention_mask
# plus the integer 'class' target.
data[0, "supervised"]

{'input_array': tensor([[ 47, 118,  97,  ...,  92,  72, 256]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]]),
 'class': tensor([6])}

In [12]:
# NOTE(review): indexing with None appears to return a single sample without its
# 'class' (the displayed tokens match data[0]) — confirm the intended semantics in gperc.
data[None]

{'input_array': tensor([[ 47, 118,  97,  ...,  92,  72, 256]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]])}

In [13]:
# Build a small Perceiver classifier. Its input spec (sequence length, vocabulary
# size, end-of-text token) is read from the Consumer so the two stay in sync.
config = BinaryConfig(
    seqlen=data.seqlen,
    vocab_size=data.vocab_size,
    eot_token=data.EOF_ID,
    n_classes=len(class_to_id),
    latent_dim=32,
    ffw_ratio=1.0,
)
model = Perceiver(config)
# Display the parameter count.
model.num_parameters()

365578

In [14]:
# Sanity-check the model's predictions on the first three samples.
# Inference runs in eval mode, with inputs moved to whichever device the model is
# currently on, so the cell also works when re-run after the model has been moved
# to a GPU. The previous train/eval mode is restored afterwards, so other cells
# are unaffected.
was_training = model.training
model_device = next(model.parameters()).device
model.eval()
try:
    with torch.no_grad():
        sample_inputs = {k: v.to(model_device) for k, v in data[[0, 1, 2]].items()}
        out = model(**sample_inputs)
        print([labels[i] for i in out.argmax(-1).tolist()])
finally:
    model.train(was_training)

['ship', 'ship', 'ship']


In [15]:
# Pre-compute the mini-batches consumed by `get_next_batch` during training.
BATCH_SIZE = 2  # presumably the batch size argument of create_batches — confirm in gperc
data.create_batches(BATCH_SIZE)

In [16]:
# Train the Perceiver, one optimisation step per mini-batch drawn from the Consumer.
# NOTE: each iteration is a single step on a single batch, so `n_epochs` counts
# steps, not full passes over the dataset.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Move the model (and put it in training mode) BEFORE building the optimiser:
# torch.optim recommends constructing optimisers only after the model is on its
# final device.
model = model.to(device)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=0.001)

n_epochs = 1000
# Only declare memorisation once the model has been perfect on this many
# consecutive batches; with tiny batches one 100%-accuracy batch is often luck.
early_stop_window = 20
pbar = trange(n_epochs)

all_losses = []
all_acc = []

for epoch in pbar:
    batch = {k: v.to(device) for k, v in data.get_next_batch("supervised").items()}
    # Feed the attention mask as well as the tokens (when the batch provides one),
    # matching how the model is called with the Consumer's full sample dict.
    model_inputs = {k: batch[k] for k in ("input_array", "attention_mask") if k in batch}
    out = model(**model_inputs)
    target = batch["class"]  # already on `device`
    loss = torch.nn.functional.cross_entropy(out, target)

    optim.zero_grad()
    loss.backward()
    optim.step()

    all_losses.append(loss.item())
    acc = out.argmax(-1).eq(target).float().mean()
    all_acc.append(acc.item())
    pbar.set_description(f"loss: {all_losses[-1]:.5f} | acc: {all_acc[-1]:.2f}")

    recent_acc = all_acc[-early_stop_window:]
    if len(recent_acc) == early_stop_window and all(a == 1.0 for a in recent_acc):
        # memorisation complete
        break

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 