In [None]:
import flash
import torch
import torchmetrics
import torchvision.transforms as transforms
from flash.core.classification import ClassificationTask
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
from flash_attn import FlashAttn
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

In [None]:
# Fetch the MNIST archive into the local ./data directory.
# NOTE(review): later cells open ./data with MNIST(download=False), which
# presumably needs the extracted archive to match torchvision's on-disk
# layout — confirm.
download_data(
    "https://pl-flash-data.s3.amazonaws.com/mnist.zip",
    "./data",
)

In [None]:
# Define the model
class TransformerModel(nn.Module):
    """Classifier that flattens each sample, embeds it, encodes it and emits class logits.

    Pipeline: flatten -> linear embedding -> ``FlashAttn`` encoder -> linear
    projection to ``num_classes`` outputs.

    Args:
        input_dim: Number of features per flattened sample. The default of
            784 matches a flattened 28x28 image.
        hidden_dim: Width of the embedding, passed to the encoder as both its
            ``input_dim`` and ``hidden_dim``.
        num_classes: Number of output logits per sample.
        num_layers: Forwarded to ``FlashAttn`` as ``num_layers``.
        num_heads: Forwarded to ``FlashAttn`` as ``num_heads``.
    """

    def __init__(
        self,
        input_dim: int = 784,
        hidden_dim: int = 256,
        num_classes: int = 10,
        num_layers: int = 4,
        num_heads: int = 8,
    ):
        super().__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)
        # NOTE(review): the FlashAttn keyword arguments are kept exactly as this
        # notebook wrote them — confirm they match the installed flash_attn API.
        self.encoder = FlashAttn(
            input_dim=hidden_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
        )
        self.decoder = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the decoder output for ``x`` after flattening it to 2-D.

        Args:
            x: Tensor whose element count is a multiple of the embedding's
                input size; it is reshaped to ``(-1, in_features)``.

        Returns:
            The output of the final linear layer, one row per flattened sample.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed images,
        # are accepted; the width is read from the embedding layer so it always
        # agrees with how the model was constructed.
        x = x.reshape(-1, self.embedding.in_features)
        x = self.embedding(x)
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [None]:
# Load the MNIST training split from ./data (no download is attempted here).
# NOTE(review): assumes the files under ./data are already in the layout
# torchvision's MNIST expects — confirm.
dataset = MNIST('./data', train=True, download=False, transform=transforms.ToTensor())

# Hold out 5,000 samples for validation. The training size is derived from the
# dataset length so the two sizes always sum to len(dataset), which
# random_split requires.
val_size = 5000
train_size = len(dataset) - val_size

# A seeded generator makes the split identical on every Restart & Run All.
train_data, val_data = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

# Reshuffle the training samples every epoch; validation order is left fixed.
train_loader = DataLoader(train_data, batch_size=256, shuffle=True, num_workers=4)
val_loader = DataLoader(val_data, batch_size=256, num_workers=4)


In [None]:
# Instantiate the network with its default sizes and attach an Adam optimizer
# (learning rate 1e-3) to all of its parameters.
model = TransformerModel()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


In [None]:
# Loss function for training: cross-entropy.
criterion = nn.CrossEntropyLoss()

# Metrics to track, keyed by the name they are reported under.
# NOTE(review): newer torchmetrics releases appear to require an explicit
# `task` (and `num_classes`) argument for Accuracy — confirm against the
# installed version.
metrics = dict(accuracy=torchmetrics.Accuracy())


In [None]:
# A Lightning trainer can only drive a LightningModule, and `Trainer.finetune`
# has no `criterion` / `optimizer` / `metrics` keywords. The plain nn.Module is
# therefore wrapped in a Flash ClassificationTask that carries the loss and the
# metrics. Flash builds its own optimizer from an optimizer class plus a
# learning rate, so both are read back from the `optimizer` configured earlier
# instead of being repeated here.
# NOTE(review): the ClassificationTask keyword names follow the Flash
# custom-task documentation — confirm against the installed lightning-flash.
task = ClassificationTask(
    model,
    loss_fn=criterion,
    optimizer=type(optimizer),
    learning_rate=optimizer.defaults["lr"],
    metrics=metrics,
)

# Use every visible CUDA device; device_count() is 0 when none is available.
trainer = flash.Trainer(max_epochs=10, gpus=torch.cuda.device_count())

# The network is trained from scratch (there is no pretrained backbone to
# fine-tune), so `fit` is used. The loaders are passed positionally because the
# keyword name of the training loader differs between Lightning releases.
trainer.fit(task, train_loader, val_loader)


In [None]:
# Load the held-out MNIST test split from ./data (no download is attempted).
test_dataset = MNIST('./data', train=False, download=False, transform=transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=256, num_workers=4)

# Evaluate the model already attached to the trainer (model=None). The loader
# is passed positionally because its keyword was renamed from
# `test_dataloaders` to `dataloaders` across PyTorch Lightning releases; the
# positional form works with both.
trainer.test(None, test_loader)