In [None]:
import numpy as np

import torch
import torch.nn as nn
import math

from datasets import load_dataset
from torch.onnx.symbolic_opset9 import tensor
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
from tokenizers import (
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
    Tokenizer,
)

In [None]:
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
MAX_LENGTH = 400  # max_length passed to the tokenizer when truncating reviews
SEED = 42

# Seed torch's global RNG so model weight init and DataLoader shuffling are
# reproducible across Restart & Run All.
torch.manual_seed(SEED)

In [None]:
imdb_ds = load_dataset("stanfordnlp/imdb")

# Load the bert-base-uncased tokenizer through `transformers` (already imported
# above) rather than `torch.hub.load`, which clones and executes code from a
# GitHub repository at runtime.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

In [None]:
def collate_imdb(batch, tok=None, max_length=None):
    """Collate IMDB rows into a last-token-prediction batch.

    Each review is tokenized and truncated to ``max_length``. The last token
    before the trailing special token ([SEP] for BERT) becomes the label. The
    model input is everything *before* that token, so the target never appears
    in its own input.

    Args:
        batch: list of dataset rows, each a mapping with a ``"text"`` key.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.
        max_length: truncation length; defaults to ``MAX_LENGTH``.

    Returns:
        inputs: LongTensor (batch_size, seq_len), right-padded with
            ``tok.pad_token_id``.
        labels: LongTensor (batch_size,) holding the held-out token ids.
    """
    if tok is None:
        tok = tokenizer
    if max_length is None:
        max_length = MAX_LENGTH

    texts = [row["text"] for row in batch]
    encoded = tok(texts, truncation=True, max_length=max_length).input_ids

    contexts, labels = [], []
    for ids in encoded:
        # Assumes ids = [CLS] t_1 ... t_k [SEP]: predict t_k from [CLS] t_1 ... t_{k-1}.
        labels.append(ids[-2])
        # Keep at least the leading token so an empty review still yields a row.
        contexts.append(ids[:-2] or ids[:1])

    seq_len = max((len(ctx) for ctx in contexts), default=0)
    inputs = torch.full(
        (len(contexts), seq_len), tok.pad_token_id, dtype=torch.long
    )
    for row_idx, ctx in enumerate(contexts):
        inputs[row_idx, : len(ctx)] = torch.tensor(ctx, dtype=torch.long)

    return inputs, torch.tensor(labels, dtype=torch.long)


# Build one loader per split; only the training split is shuffled so that
# evaluation order on the test split stays fixed.
train_data_loader, test_data_loader = (
    DataLoader(
        imdb_ds[split_name],
        batch_size=BATCH_SIZE,
        shuffle=(split_name == "train"),
        collate_fn=collate_imdb,
    )
    for split_name in ("train", "test")
)

In [None]:
def get_angles(pos, i, d_model):
    """Return sinusoidal-encoding angles ``pos / 10000 ** (2 * (i // 2) / d_model)``.

    ``pos`` and ``i`` are broadcast against each other, e.g. a column vector of
    positions against a row vector of feature indices.
    """
    paired_index = i // 2
    exponent = (2 * paired_index) / np.float32(d_model)
    angle_rates = 1 / np.power(10_000, exponent)
    angles = pos * angle_rates
    return angles


def positional_encoding(position, d_model):
    """Build the fixed sinusoidal position table as a FloatTensor.

    Returns a tensor of shape (1, position, d_model). Even feature columns hold
    ``sin(angle)`` and odd feature columns hold ``cos(angle)``.
    """
    positions = np.arange(position)[:, None]
    feature_idx = np.arange(d_model)[None, :]
    angles = get_angles(positions, feature_idx, d_model)

    # Overwrite in place: even columns -> sine, odd columns -> cosine.
    angles[:, ::2] = np.sin(angles[:, ::2])
    angles[:, 1::2] = np.cos(angles[:, 1::2])

    batched = np.expand_dims(angles, axis=0)
    return torch.FloatTensor(batched)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        input_dim: feature size of the incoming sequence ``x``.
        d_model: total projection size, split evenly across heads.
        n_heads: number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, input_dim, d_model, n_heads):
        super().__init__()

        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )

        self.input_dim = input_dim
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.wq = nn.Linear(input_dim, d_model)
        self.wk = nn.Linear(input_dim, d_model)
        self.wv = nn.Linear(input_dim, d_model)
        self.wo = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, t, batch_size, seq_length):
        """Reshape (batch, seq, d_model) -> (batch, n_heads, seq, d_k)."""
        return t.view(batch_size, seq_length, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Self-attend over ``x``.

        Args:
            x: tensor of shape (batch, seq_length, input_dim).
            mask: optional tensor broadcastable to
                (batch, n_heads, seq_length, seq_length), e.g. a
                (batch, 1, 1, seq_length) key-padding mask. Positions where
                ``mask == 0`` receive no attention weight.

        Returns:
            Tensor of shape (batch, seq_length, d_model).
        """
        batch_size, seq_length, _ = x.size()

        # 1. Project to queries / keys / values and split into heads.
        Q = self._split_heads(self.wq(x), batch_size, seq_length)
        K = self._split_heads(self.wk(x), batch_size, seq_length)
        V = self._split_heads(self.wv(x), batch_size, seq_length)

        # 2. Scaled dot-product scores: (batch, n_heads, seq, seq).
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 3. Apply the mask (if any). Broadcasting expands a (batch, 1, 1, seq)
        #    mask across heads and query positions without copying, and also
        #    accepts masks that already carry a head dimension. The dtype's most
        #    negative value avoids the overflow a literal -1e9 hits in float16.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # 4. Attention weights, then weighted sum of values.
        attention_weights = self.softmax(scores)
        output = torch.matmul(attention_weights, V)

        # 4.1 Merge heads back to (batch, seq, d_model).
        output = (
            output.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, self.d_model)
        )

        # 5. Final output projection.
        return self.wo(output)

In [None]:
class TransformerLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention followed by an FFN.

    Each sub-layer output goes through dropout, is added residually to its
    input, and is then layer-normalized.

    Args:
        input_dim: feature size of the layer input; must equal ``d_model``
            because the input is added back residually.
        d_model: model width (attention output and FFN input/output size).
        n_heads: number of attention heads.
        dff: hidden size of the position-wise feed-forward network.
        dropout_rate: dropout probability applied after both sub-layers.

    Raises:
        ValueError: if ``input_dim != d_model``.
    """

    def __init__(self, input_dim, d_model, n_heads, dff, dropout_rate=0.1):
        super().__init__()

        if input_dim != d_model:
            raise ValueError(
                f"input_dim ({input_dim}) must equal d_model ({d_model}) "
                "for the residual connection"
            )

        self.multi_head_attention = MultiHeadAttention(input_dim, d_model, n_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
        )

        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)

        # Both sub-layers use the configured rate; a bare nn.Dropout() would
        # silently default to p=0.5.
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, mask):
        """Apply the attention and feed-forward sub-layers.

        Args:
            x: tensor of shape (batch, seq_len, d_model).
            mask: attention mask forwarded to MultiHeadAttention (0 = ignore).

        Returns:
            Tensor with the same shape as ``x``.
        """
        attn_out = self.multi_head_attention(x, mask)
        x1 = self.layer_norm1(self.dropout1(attn_out) + x)

        ffn_out = self.ffn(x1)
        return self.layer_norm2(self.dropout2(ffn_out) + x1)

In [None]:
class TextClassifier(nn.Module):
    """Transformer encoder mapping token ids to logits over the vocabulary.

    Token embeddings (scaled by sqrt(d_model)) plus a fixed sinusoidal position
    table pass through ``n_layers`` TransformerLayers. The representation at
    position 0 (presumably the tokenizer's [CLS] token) feeds a linear head with
    ``vocab_size`` outputs.
    """

    def __init__(self, vocab_size, d_model, n_layers, n_heads, dff, max_len):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dff = dff
        self.max_len = max_len

        self.embedding = nn.Embedding(vocab_size, d_model)
        # The sinusoidal table is a fixed function of position, so register it
        # as a buffer: it still moves with .to(device) and is saved in
        # state_dict under the same key, but the optimizer no longer updates it.
        self.register_buffer("pos_encoding", positional_encoding(max_len, d_model))

        self.layers = nn.ModuleList(
            [TransformerLayer(d_model, d_model, n_heads, dff) for _ in range(n_layers)]
        )

        self.classification = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Compute vocabulary logits for a batch of token-id sequences.

        Args:
            x: LongTensor (batch, seq_len) of token ids, padded with
                ``tokenizer.pad_token_id``.

        Returns:
            Tensor of shape (batch, vocab_size).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len}"
            )

        # (batch, 1, 1, seq_len) key-padding mask; False marks padding.
        mask = (x != tokenizer.pad_token_id).unsqueeze(1).unsqueeze(2)

        x = self.embedding(x) * math.sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len]

        for layer in self.layers:
            x = layer(x, mask)

        # Summarize the sequence by its first position's representation.
        return self.classification(x[:, 0])

In [None]:
my_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")


def accuracy(m, dataloader, device=None):
    """Return the fraction of samples whose argmax prediction equals the label.

    The model is run under ``torch.no_grad()``. If ``m`` is an ``nn.Module`` it
    is switched to eval mode for the pass and its previous train/eval mode is
    restored afterwards.

    Args:
        m: callable mapping an input batch to logits of shape (batch, n_classes).
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device to move batches to; defaults to ``my_device``.

    Returns:
        Accuracy in [0, 1] as a float.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        device = my_device

    is_module = isinstance(m, nn.Module)
    was_training = is_module and m.training
    if is_module:
        m.eval()

    n_seen = 0
    n_correct = 0
    try:
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(device), labels.to(device)
                preds = torch.argmax(m(inputs), dim=-1)
                n_seen += labels.shape[0]
                n_correct += (labels == preds).sum().item()
    finally:
        if is_module:
            m.train(was_training)

    if n_seen == 0:
        raise ValueError("accuracy() received a dataloader with no samples")
    return n_correct / n_seen

In [None]:
# Small 5-layer encoder; max_len matches the truncation length used in collate_imdb.
model = TextClassifier(
    vocab_size=len(tokenizer),
    d_model=32,
    n_heads=4,
    dff=32,
    n_layers=5,
    max_len=MAX_LENGTH,
)
model = model.to(my_device)

In [None]:
from torch.optim import Adam

# Adam over every trainable parameter; cross-entropy over the vocabulary logits.
optimizer = Adam(params=model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

In [None]:
n_epochs = 1

for epoch in range(n_epochs):
    total_loss = 0.0
    model.train()
    for inputs, labels in train_data_loader:
        inputs, labels = inputs.to(my_device), labels.to(my_device)

        optimizer.zero_grad()

        # Logits are already (batch, vocab_size). squeeze() is not needed and
        # would drop the batch dim for a batch of one, breaking CrossEntropyLoss.
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()

        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch + 1:3d} | Train Loss: {total_loss}")

    model.eval()
    with torch.no_grad():
        train_acc = accuracy(model, train_data_loader)
        test_acc = accuracy(model, test_data_loader)
    print(f"=========> Train acc: {train_acc:.3f} | Test acc: {test_acc:.3f}")