In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random

# 하이퍼파라미터 설정
VOCAB_SIZE = 20
SEQ_LEN = 10
BATCH_SIZE = 32
NUM_BATCHES = 100
EMBED_DIM = 32
NUM_HEADS = 2
NUM_LAYERS = 2
EPOCHS = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 랜덤 시드 고정
random.seed(42)
torch.manual_seed(42)

# 1. 역순 데이터셋 정의
class ReverseDataset(Dataset):
    def __init__(self, num_samples):
        self.data = []
        for _ in range(num_samples):
            seq = [random.randint(2, VOCAB_SIZE-1) for _ in range(SEQ_LEN)]
            self.data.append((seq, list(reversed(seq))))

    def __len__(self):
        return len(self.data) 

    def __getitem__(self, idx):
        src, tgt = self.data[idx]
        return torch.tensor(src), torch.tensor(tgt)

train_dataset = ReverseDataset(BATCH_SIZE * NUM_BATCHES)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# 2. 포지셔널 인코딩
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=SEQ_LEN):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.pe = pe.unsqueeze(0)  # (1, max_len, d_model)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)].to(x.device)

# 3. 간단한 Transformer 모델
class SimpleTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, EMBED_DIM)
        self.pos_encoder = PositionalEncoding(EMBED_DIM)
        encoder_layer = nn.TransformerEncoderLayer(d_model=EMBED_DIM, nhead=NUM_HEADS)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=NUM_LAYERS)
        self.fc_out = nn.Linear(EMBED_DIM, VOCAB_SIZE)

    def forward(self, src):
        x = self.embedding(src)
        x = self.pos_encoder(x)
        x = x.transpose(0, 1)  # (seq_len, batch, embed_dim)
        x = self.transformer(x)
        x = x.transpose(0, 1)
        return self.fc_out(x)

# 4. 학습 루프
model = SimpleTransformer().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(EPOCHS):
    total_loss = 0
    for src, tgt in train_loader:
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)
        out = model(src)  # (batch, seq_len, vocab_size)
        loss = criterion(out.view(-1, VOCAB_SIZE), tgt.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")




Epoch 1: Loss = 2.8405
Epoch 2: Loss = 2.1008
Epoch 3: Loss = 0.5492
Epoch 4: Loss = 0.2658
Epoch 5: Loss = 0.1760


In [5]:
train_dataset[0][0]

tensor([ 5,  2, 10,  9,  9,  6,  5, 19,  4, 15])

In [11]:
for src, tgt in train_loader:
    src, tgt = src.to(DEVICE), tgt.to(DEVICE)
    print(src.shape,tgt.shape)

torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([32, 10])
torch.Size([32, 10]) torch.Size([3