In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f518d2c04f0>

In [2]:
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

train_set = datasets.mnist.MNIST(root='./datasets', train=True, download=True, transform=transforms.ToTensor())
test_set = datasets.mnist.MNIST(root='./datasets', train=False, download=True, transform=transforms.ToTensor())

train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

In [3]:
def patchify(images, n_patches):
    n, c, h, w = images.shape

    assert h == w, "Patchify method is implemented for square images only"

    patches = torch.zeros(n, n_patches ** 2, h * w * c // n_patches ** 2)
    patch_size = h // n_patches

    for idx, image in enumerate(images):
        for i in range(n_patches):
            for j in range(n_patches):
                patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                patches[idx, i * n_patches + j] = patch.flatten()
    return patches

def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result

class MultiHeadAttention(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MultiHeadAttention, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = int(d / n_heads)
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

class ViTBlock(nn.Module):
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(ViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mha = MultiHeadAttention(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d)
        )
    def forward(self, x):
        out = x + self.mha(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out
        
class ViT(nn.Module):
    def __init__(self, chw=(1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):
        # Super constructor
        super(ViT, self).__init__()

        # Attributes
        self.chw = chw # (C, H, W)
        self.n_patches = n_patches
        self.hidden_d = hidden_d

        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)

        # 1) Linear mapper
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding
        self.register_buffer('positional_embeddings', get_positional_embeddings(n_patches ** 2 + 1, hidden_d), persistent=False)

        # 4) Transformer encoder blocks
        self.blocks = nn.ModuleList([ViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])

        # 5) Classification MLP
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )

    def forward(self, images):
        n, c, h, w = images.shape
        patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)
        tokens = self.linear_mapper(patches)
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)
        out = tokens + self.positional_embeddings.repeat(n, 1, 1)

        for block in self.blocks:
            out = block(out)
        out = out[:, 0]  # classification token만 선택
        out = self.mlp(out)  # 분류 레이어를 통해 최종 예측 생성
        return out

In [4]:
model = ViT(chw=(1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

model.to(device)

ViT(
  (linear_mapper): Linear(in_features=16, out_features=8, bias=True)
  (blocks): ModuleList(
    (0-1): 2 x ViTBlock(
      (norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (mha): MultiHeadAttention(
        (q_mappings): ModuleList(
          (0-1): 2 x Linear(in_features=4, out_features=4, bias=True)
        )
        (k_mappings): ModuleList(
          (0-1): 2 x Linear(in_features=4, out_features=4, bias=True)
        )
        (v_mappings): ModuleList(
          (0-1): 2 x Linear(in_features=4, out_features=4, bias=True)
        )
        (softmax): Softmax(dim=-1)
      )
      (norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=8, out_features=32, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=32, out_features=8, bias=True)
      )
    )
  )
  (mlp): Sequential(
    (0): Linear(in_features=8, out_features=10, bias=True)
    (1): Softmax(dim=-1)
  )
)

In [None]:
def accuracy(predictions, labels):
    _, preds = torch.max(predictions, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

n_epochs = 10
for epoch in range(n_epochs):
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{n_epochs}")
    for i, (images, labels) in progress_bar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_acc += accuracy(outputs, labels)
        progress_bar.set_postfix({'loss': total_loss / (i+1), 'accuracy': total_acc / (i+1)})

    print(f"Epoch [{epoch+1}/{n_epochs}], Loss: {total_loss/len(train_loader):.4f}, Accuracy: {total_acc/len(train_loader):.4f}")

Epoch 1/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [05:15<00:00,  1.49it/s, loss=2.21, accuracy=tensor(0.2165)]


Epoch [1/10], Loss: 2.2147, Accuracy: 0.2165


Epoch 2/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [05:09<00:00,  1.51it/s, loss=2.05, accuracy=tensor(0.4105)]


Epoch [2/10], Loss: 2.0501, Accuracy: 0.4105


Epoch 3/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [05:10<00:00,  1.51it/s, loss=1.96, accuracy=tensor(0.5029)]


Epoch [3/10], Loss: 1.9599, Accuracy: 0.5029


Epoch 4/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [05:08<00:00,  1.52it/s, loss=1.91, accuracy=tensor(0.5531)]


Epoch [4/10], Loss: 1.9103, Accuracy: 0.5531


Epoch 5/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [05:11<00:00,  1.50it/s, loss=1.87, accuracy=tensor(0.6001)]


Epoch [5/10], Loss: 1.8656, Accuracy: 0.6001


Epoch 6/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [05:11<00:00,  1.50it/s, loss=1.83, accuracy=tensor(0.6292)]


Epoch [6/10], Loss: 1.8345, Accuracy: 0.6292


Epoch 7/10: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [11:55<00:00,  1.53s/it, loss=1.81, accuracy=tensor(0.6542)]


Epoch [7/10], Loss: 1.8094, Accuracy: 0.6542


Epoch 8/10:  45%|█████████████████████████████████████████████████████                                                                  | 209/469 [02:09<02:34,  1.68it/s, loss=1.79, accuracy=tensor(0.6689)]