### **This notebook only defines the structure of a ViT model; it is not a complete, trained model (the setup, training and evaluation cells are disabled)**

In [1]:
import os

import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import ConcatDataset
from torchvision import datasets, transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Preprocessing for the original images: force a single grayscale channel
# (the VisionTransformer below is built with channels=1) and convert the
# loaded images to tensors.
# NOTE(review): no Resize is applied, so the images on disk are assumed to
# already be 48x48 (the VisionTransformer default img_size) -- confirm.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

# Same pipeline, but every image is mirrored left-right. It is used to build a
# horizontally flipped copy of each dataset.
# hflip is passed directly rather than wrapped in a lambda. A lambda cannot be
# pickled, and DataLoader pickles the dataset (including its transforms) when
# num_workers > 0 on spawn-based platforms such as Windows.
flip_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Lambda(transforms.functional.hflip),
    transforms.ToTensor(),
])

In [3]:
# ImageFolder expects <root>/<class_name>/<image files>.
# os.path.join replaces the hard-coded Windows "\\" separator, so the paths
# also resolve on Linux/macOS. On Windows the resulting strings are unchanged.
data_train = os.path.join("Dataset", "train")
data_test = os.path.join("Dataset", "test")

dataset = datasets.ImageFolder(root=data_train, transform=transform)
evaluate = datasets.ImageFolder(root=data_test, transform=transform)

# Horizontally mirrored copies of the same images double the sample count.
flipped_dataset = datasets.ImageFolder(root=data_train, transform=flip_transform)
flipped_evaluate = datasets.ImageFolder(root=data_test, transform=flip_transform)

combined_dataset = ConcatDataset([dataset, flipped_dataset])
# NOTE(review): the evaluation set also contains the flipped test images, so
# test loss/accuracy are measured over twice the raw test set.
combined_evaluate = ConcatDataset([evaluate, flipped_evaluate])

data_loader = torch.utils.data.DataLoader(combined_dataset, batch_size=64, shuffle=True)
# No shuffling for evaluation, so batches come in a deterministic order.
evaluation = torch.utils.data.DataLoader(combined_evaluate, batch_size=64, shuffle=False)

In [14]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A Conv2d whose kernel size equals its stride maps every
    ``patch_size x patch_size`` patch to an ``emb_size``-dimensional vector.
    Pixels that do not fill a whole patch at the right/bottom edge are
    dropped, matching the floor division used for ``n_patches``.

    Args:
        img_size: side length of the square input image.
        in_channels: number of channels in the input image.
        patch_size: side length of each square patch.
        emb_size: size of each patch embedding vector.
    """

    def __init__(self, img_size=48, in_channels=1, patch_size=4, emb_size=64):
        super().__init__()
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side
        self.projection = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed patches of ``x``.

        For a batched input of shape (batch, channels, H, W), this returns
        (batch, n_patches, emb_size).
        """
        feature_map = self.projection(x)
        # Collapse the spatial grid into one token axis, then make tokens
        # the middle dimension.
        return feature_map.flatten(start_dim=2).permute(0, 2, 1)

In [15]:
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings.

    Args:
        dim: embedding size of each token (nn.MultiheadAttention requires it
            to be divisible by ``n_heads``).
        n_heads: number of attention heads.
        dropout: dropout probability, used on the attention weights inside
            nn.MultiheadAttention and again on the attention output.
    """

    def __init__(self, dim, n_heads, dropout):
        super().__init__()
        self.n_heads = n_heads
        # batch_first=True: the model passes tensors shaped
        # (batch, tokens, dim). With the default batch_first=False,
        # MultiheadAttention would treat dim 0 as the sequence axis and
        # attend across the *samples in the batch* instead of across the
        # tokens of each image.
        self.att = torch.nn.MultiheadAttention(embed_dim=dim,
                                               num_heads=n_heads,
                                               dropout=dropout,
                                               batch_first=True)
        # NOTE(review): MultiheadAttention already applies its own Q/K/V input
        # projections, so these extra Linear layers are redundant. They are
        # kept so previously saved state_dicts still load.
        self.q = torch.nn.Linear(dim, dim)
        self.k = torch.nn.Linear(dim, dim)
        self.v = torch.nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Self-attend over ``x`` of shape (batch, tokens, dim); returns the same shape."""
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        attn_output, _ = self.att(q, k, v)
        attn_output = self.dropout(attn_output)
        return attn_output

In [16]:
class FeedForward(nn.Sequential):
    """Token-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability used after the activation and after
            the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        # Build the stages as named locals first; their order here is the
        # order (and state_dict index) inside the Sequential.
        widen = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        hidden_drop = nn.Dropout(dropout)
        narrow = nn.Linear(hidden_dim, dim)
        output_drop = nn.Dropout(dropout)
        super().__init__(widen, activation, hidden_drop, narrow, output_drop)

In [17]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder layer.

    Computes ``x + attn(norm1(x))``, then adds ``feedforward(norm2(...))``
    to that result as a second residual step.

    Args:
        dim: token embedding size.
        n_heads: number of attention heads.
        mlp_ratio: feed-forward hidden width as a multiple of ``dim``
            (the product is truncated to an int).
        dropout: dropout probability passed to both sub-layers.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4.0, dropout=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim=dim, n_heads=n_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.feedforward = FeedForward(
            dim=dim,
            hidden_dim=int(dim * mlp_ratio),
            dropout=dropout,
        )

    def forward(self, x):
        """Apply attention and MLP sub-layers, each with a residual connection."""
        after_attention = x + self.attn(self.norm1(x))
        return after_attention + self.feedforward(self.norm2(after_attention))

In [18]:
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: patch embedding -> prepend a learnable [CLS] token -> add a
    learnable positional embedding -> ``depth`` TransformerBlocks ->
    LayerNorm -> linear head on the [CLS] token. No dropout is applied to
    the embeddings themselves; ``dropout`` only reaches the transformer
    blocks.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        channels: number of input image channels.
        n_classes: number of output logits.
        embed_dim: token embedding size.
        depth: number of transformer blocks.
        n_heads: attention heads per block.
        mlp_ratio: feed-forward hidden width as a multiple of ``embed_dim``.
        dropout: dropout probability used inside every transformer block.
    """

    def __init__(self, img_size=48, patch_size=4, channels=1, n_classes=7, embed_dim=64, depth=8, n_heads=4, mlp_ratio=4., dropout=0.1):
        super().__init__()

        self.patch_embedding = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=channels,
            emb_size=embed_dim,
        )

        # One token per patch plus the leading [CLS] token.
        token_count = 1 + (img_size // patch_size) ** 2
        # Both parameters start at zero and are learned during training.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, token_count, embed_dim))

        blocks = [
            TransformerBlock(dim=embed_dim, n_heads=n_heads, mlp_ratio=mlp_ratio, dropout=dropout)
            for _ in range(depth)
        ]
        self.transform_blocks = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        """Return class logits, one row per input image."""
        tokens = self.patch_embedding(x)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embedding

        for block in self.transform_blocks:
            tokens = block(tokens)

        tokens = self.norm(tokens)
        # Classify from the [CLS] token only (position 0).
        return self.head(tokens[:, 0])

In [None]:
# Disabled setup cell: kept as a string literal so it does not execute.
# When re-enabled, it builds the model once (the previous version constructed
# VisionTransformer twice and discarded the first) and falls back to CPU when
# CUDA is unavailable.
'''
device = "cuda" if torch.cuda.is_available() else "cpu"
model = VisionTransformer().to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
'''

In [None]:
# Disabled training cell: kept as a string literal so it does not execute.
# It runs 101 epochs over `data_loader`. Every 5th epoch (starting at 0) it
# prints the mean training loss, then the mean loss over the `evaluation`
# loader (computed under torch.no_grad() in eval mode), and switches back to
# train mode. At the end it saves the weights to model.pth.
# It relies on `model`, `optimizer`, `criterion` and `device` from the setup
# cell. Note that it leaves the model in train mode.
'''
for epoch in range(101):
    epoch_losses = []
    model.train()
    for step, (inputs, labels) in enumerate(data_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
    if epoch % 5 == 0:
        print(f">>> Epoch {epoch} train loss: ", np.mean(epoch_losses))
        epoch_losses = []
        model.eval()
        with torch.no_grad():
            for step, (inputs, labels) in enumerate(evaluation):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                epoch_losses.append(loss.item())
        print(f">>> Epoch {epoch} test loss: ", np.mean(epoch_losses))
        model.train()

torch.save(model.state_dict(), 'model.pth')
'''

In [None]:
# Disabled accuracy cell: kept as a string literal so it does not execute.
# model.eval() is required here because the training cell ends with
# model.train(). Without it, dropout would stay active while accuracy is
# measured.
'''
correct = 0
total = 0
model.eval()
with torch.no_grad():
    for images, labels in evaluation:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))
'''

In [None]:
# Disabled prediction cell: kept as a string literal so it does not execute.
# It predicts one evaluation batch in eval mode (dropout off) without
# tracking gradients.
'''
model.eval()
with torch.no_grad():
    inputs, labels = next(iter(evaluation))
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = model(inputs)

print("Predicted classes", outputs.argmax(-1))
print("Actual classes", labels)
'''