In [105]:
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import lovely_tensors

lovely_tensors.monkey_patch()

import torch
import torch.nn as nn
import torch.nn.functional as F

In [106]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_size, num_heads, hidden_dim, dropout_rate=0.1):
        super(TransformerEncoderLayer, self).__init__()

        self.multihead_attn = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads)
        self.linear1 = nn.Linear(embed_size, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, embed_size)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.activation = nn.GELU()

    def forward(self, x):
        y = self.norm1(x)
        y, _ = self.multihead_attn(y, y, y)
        y = self.dropout(y)
        x = x + y

        y = self.norm2(x)
        y = self.linear1(y)
        y = self.activation(y)
        y = self.dropout(y)
        y = self.linear2(y)
        y = self.dropout(y)
        x = x + y

        return x

In [107]:
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_size):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        self.linear = nn.Linear(patch_size * patch_size * in_channels, embed_size)

    def forward(self, x):
        num_patches_h = self.img_size // self.patch_size
        num_patches_w = self.img_size // self.patch_size
        x = x.permute(0, 2, 3, 1)
        x = x.view(x.shape[0], num_patches_h, self.patch_size, num_patches_w, self.patch_size, x.shape[3])
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        x = x.view(x.shape[0], -1, self.patch_size * self.patch_size * x.shape[-1])
        x = self.linear(x)
        return x

In [108]:
class ViT(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_size, num_heads, depth, n_classes):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_size)

        self.positional_encoding = nn.Parameter(torch.randn(1, self.patch_embedding.n_patches + 1, embed_size))

        transformers = [TransformerEncoderLayer(embed_size=embed_size, num_heads=num_heads, hidden_dim=512)] * depth

        self.transformers = nn.Sequential(*transformers)

        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_size))
        self.linear = nn.Linear(embed_size, n_classes)

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.patch_embedding(x)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.positional_encoding
        x = self.transformers(x)
        x = self.linear(x[:, 0])

        return x

In [109]:
vit = ViT(img_size=32, patch_size=4, in_channels=3, embed_size=256, num_heads=8, depth=6, n_classes=10)

input_tensor = torch.randn(1, 3, 32, 32)
out = vit(input_tensor)
out

tensor[1, 10] x∈[-2.232, 5.778] μ=0.753 σ=2.212 grad AddmmBackward0 [[-0.393, 1.612, 1.139, -1.140, 0.874, 2.270, -2.232, -0.205, 5.778, -0.177]]

In [113]:
from torchviz import make_dot

make_dot(out.mean(), params=dict(vit.named_parameters()))

ExecutableNotFound: failed to execute Path('dot'), make sure the Graphviz executables are on your systems' PATH

<graphviz.graphs.Digraph at 0x7fb85c416e30>

In [110]:
input_names = ["img"]
output_names = ["preds"]

x = torch.randn(1, 3, 32, 32)
torch.onnx.export(vit, x, "vit.onnx", input_names=input_names, output_names=output_names, verbose=True)

verbose: False, log level: Level.ERROR
ERROR: missing-standard-symbolic-function
Exporting the operator 'aten::unflatten' to ONNX opset version 14 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
None
<Set verbose=True to see more details>




UnsupportedOperatorError: Exporting the operator 'aten::unflatten' to ONNX opset version 14 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.

In [111]:
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(size=32, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [112]:
import torch
import torch.nn as nn
import torch.optim as optim

model = ViT(img_size=32, patch_size=4, in_channels=3, embed_size=256, num_heads=8, depth=6, n_classes=10)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}")

    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if epoch in lr_drop_epochs:
        for param_group in optimizer.param_groups:
            param_group["lr"] /= 10

    print(f"Accuracy of the model on the test images: {100 * correct / total} %")

KeyboardInterrupt: 