### Custom BERT implementation

In [None]:
%cd ../..
import torch
from inpainting.models import FOOBERT
from inpainting.trainer import Trainer
from inpainting.datasets import MNIST
from inpainting.utils import configure_environment

# Draw a fresh random integer seed for this run; torch.randint samples from
# [0, 123456), i.e. the upper bound is exclusive. The seed is reused later
# when naming the saved model, so runs do not overwrite each other.
seed_tensor = torch.randint(0, 123456, (1,))
seed = seed_tensor.item()

# NOTE(review): presumably this seeds the RNGs / sets up the device; confirm
# against inpainting.utils.configure_environment.
configure_environment(seed=seed)

In [2]:
# Build the MNIST dataset wrapper used by every later cell.
# NOTE(review): parameter meanings below are inferred from their names and
# from how they are used in this notebook; confirm against
# inpainting.datasets.MNIST.
mnist = MNIST(
    frac=0.05,      # presumably the fraction of the dataset to load
    clusters=35,    # also passed to mnist.itop(...) when building the model
    unimask=False,
    shape=2,        # used as the divisor of 28 when computing max_len
)

In [3]:
# Instantiate the custom BERT model, sizing it from the dataset.
vocab = mnist.tokens
# NOTE(review): 28 is presumably the MNIST image side length and mnist.shape
# the patch side, making this the number of patches per row; confirm.
patch_grid = 28 // mnist.shape

model = FOOBERT(
    vocab_size=vocab,            # BERT's vocabulary size
    embed_size=vocab * 12,       # Hidden size
    num_layers=8,                # Number of Transformer layers
    num_heads=12,                # Number of attention heads
    ff_hidden=3200,              # Feed-forward hidden size
    max_len=patch_grid ** 2,     # Maximum sequence length
    dropout=0.00,
    # ce_weights is left unset here (previously: ce_weights=weights).
    patches=mnist.itop(torch.arange(mnist.clusters)),
)

In [None]:
# Train the model on the dataset built above. The Trainer is deliberately not
# bound to a name, so it can be released as soon as training returns.
Trainer(model, mnist).train(
    epochs=50,
    batch_size=100,
    lr=2e-4,
)

In [None]:
# Pick one dataset item and visualise its token sequence; `x` is reused by
# the inference cell below.
sample = mnist[9]
x = sample["input_ids"]
mnist.plot_sample(x)

In [None]:
# Run the model on the single sample `x` and plot the predicted tokens.
# Inference needs no gradients, so the forward pass runs under
# torch.no_grad() to avoid building an autograd graph and holding
# activations in memory.
# NOTE(review): .cuda() assumes a GPU is available and that the model already
# lives on it (presumably moved there during training); confirm.
with torch.no_grad():
    batch = x.unsqueeze(0).cuda()      # add a leading batch dimension
    logits = model(batch).logits
    y = logits.argmax(dim=-1).cpu()    # most likely token at each position

mnist.plot_sample(y)

In [7]:
# Save the model under a path tagged with this run's random seed, so
# checkpoints from different runs get different paths.
checkpoint_path = f"saved/foobert_{seed}"
model.save(checkpoint_path)