In [3]:
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

class MultiheadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f'embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})'
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim, bias=False)
        self.key = nn.Linear(embed_dim, embed_dim, bias=False)
        self.value = nn.Linear(embed_dim, embed_dim, bias=False)

        self.out = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(p=0.1)
        self.proj_dropout = nn.Dropout(p=0.1)

        # Kept as a plain Python float so the module does not depend on any
        # global `device` variable and works on whatever device/dtype the
        # input tensor lives on.
        self.scale = self.head_dim ** 0.5

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor of shape ``(b, n, e)`` where ``e`` equals ``embed_dim``.

        Returns:
            Tensor of shape ``(b, n, e)``.
        """
        b, n, e = x.shape
        # Split the embedding into heads: (b, n, e) -> (b, heads, n, head_dim).
        q = self.query(x).reshape(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        # Keys are laid out as (b, heads, head_dim, n) so q @ k yields (b, heads, n, n).
        k = self.key(x).reshape(b, n, self.num_heads, self.head_dim).permute(0, 2, 3, 1)
        v = self.value(x).reshape(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        attn = torch.matmul(q, k)
        attn = attn / self.scale
        attn = torch.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)

        x = torch.matmul(attn, v)
        # Merge heads back: (b, heads, n, head_dim) -> (b, n, e).
        x = x.transpose(1, 2).reshape(b, n, e)
        x = self.out(x)
        x = self.proj_dropout(x)
        return x

class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Dropout -> Linear -> Dropout.

    Args:
        embed_dim: Size of the input and output feature dimension.
        hidden_dim: Size of the intermediate feature dimension.
    """

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.embed_dim, self.hidden_dim = embed_dim, hidden_dim

        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        # A single dropout module (p=0.1) is reused after both linear layers.
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x):
        """Project to ``hidden_dim``, apply ReLU, and project back to ``embed_dim``."""
        hidden = self.dropout(torch.relu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block: self-attention followed by an MLP.

    Each sub-layer is wrapped as ``x + dropout(sublayer(layer_norm(x)))``.

    Args:
        embed_dim: Feature dimension of the token sequence.
        num_heads: Number of attention heads.
        hidden_dim: Hidden size of the feed-forward sub-layer.
        dropout_rate: Dropout probability applied to each sub-layer output
            before it is added back to the residual stream.
    """

    def __init__(self, embed_dim, num_heads, hidden_dim, dropout_rate=0.1):
        super().__init__()
        self.attention = MultiheadSelfAttention(embed_dim, num_heads)
        self.feedforward = FeedForward(embed_dim, hidden_dim)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(p=dropout_rate)
        self.dropout2 = nn.Dropout(p=dropout_rate)

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape."""
        # Dropout is applied to the sub-layer output only. Applying it after
        # the addition would randomly zero the skip connection as well.
        attn_output = x + self.dropout1(self.attention(self.layer_norm1(x)))
        transformer_output = attn_output + self.dropout2(
            self.feedforward(self.layer_norm2(attn_output))
        )
        return transformer_output

class ViT(nn.Module):
    """Vision Transformer classifier for square images.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each patch is linearly projected to ``dim`` features, a learnable
    class token is prepended, positional embeddings are added, and the
    sequence is passed through ``depth`` transformer blocks. The class-token
    output is normalized and fed to an MLP head.

    Args:
        image_size: Height/width of the (square) input image.
        patch_size: Height/width of each patch; must divide ``image_size``.
        num_classes: Number of output logits.
        dim: Token embedding dimension.
        depth: Number of transformer blocks.
        heads: Number of attention heads per block.
        mlp_dim: Hidden size of the feed-forward layers and of the head.
        channels: Number of input image channels (defaults to 3).
    """

    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):
        super().__init__()
        assert image_size % patch_size == 0, 'image size must be divisible by patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size
        # One position per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # Learnable token whose final state is used for classification.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.transformer = nn.Sequential(*[
            TransformerBlock(dim, heads, mlp_dim) for _ in range(depth)
        ])

        self.layer_norm = nn.LayerNorm(dim)
        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.Dropout(p=0.1),
            nn.ReLU(),
            nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image tensor laid out as ``(b, c, H, W)``.

        Returns:
            Logits of shape ``(b, num_classes)``.
        """
        x = self.patch_embedding(x)  # (b, num_patches, dim)
        b, n, _ = x.shape

        # Prepend the class token to every sample, then add positional
        # embeddings. The patch embeddings must stay in the sequence so the
        # prediction actually depends on the input image.
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embedding[:, :n + 1]

        x = self.transformer(x)
        # Classify from the class-token position only.
        x = self.layer_norm(x[:, 0])
        x = self.mlp_head(x)
        return x


In [4]:
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Load the CIFAR-10 dataset
train_dataset = dset.CIFAR10(root='../DataSets/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = dset.CIFAR10(root='../DataSets/', train=False, transform=transforms.ToTensor(), download=True)

# Hyperparameters
image_size = 32
patch_size = 4
num_classes = 10
dim = 64
depth = 6
heads = 8
mlp_dim = 128
batch_size = 64
num_epochs = 10
learning_rate = 0.001

# Build the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ViT(image_size, patch_size, num_classes, dim, depth, heads, mlp_dim).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Train the model
for epoch in range(num_epochs):
    train_loss = 0.0
    train_acc = 0.0
    test_loss = 0.0
    test_acc = 0.0

    # Fit on the training split
    model.train()
    for images, labels in tqdm(train_loader, desc='Train'):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # `loss` is the mean over the batch; weight it by the batch size so
        # that dividing by the dataset size below yields a per-sample average.
        train_loss += loss.item() * labels.size(0)
        train_acc += (outputs.argmax(dim=1) == labels).sum().item()

    train_loss /= len(train_loader.dataset)
    train_acc /= len(train_loader.dataset)

    # Evaluate on the test split
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc='Test'):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Same per-sample weighting as in the training loop.
            test_loss += loss.item() * labels.size(0)
            test_acc += (outputs.argmax(dim=1) == labels).sum().item()

        test_loss /= len(test_loader.dataset)
        test_acc /= len(test_loader.dataset)

    print(f'Epoch {epoch + 1}/{num_epochs}: train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, test_loss={test_loss:.4f}, test_acc={test_acc:.4f}')

print('Training finished')


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

KeyboardInterrupt: 