In [1]:
import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    def __init__(self, input_dim, output_dim, nhead, num_layers, hidden_dim, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(input_dim, nhead, hidden_dim, dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers)
        self.decoder_layer = nn.TransformerDecoderLayer(input_dim, nhead, hidden_dim, dropout)
        self.transformer_decoder = nn.TransformerDecoder(self.decoder_layer, num_layers)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, src, tgt):
        memory = self.transformer_encoder(src)
        output = self.transformer_decoder(tgt, memory)
        output = self.linear(output)
        return output

In [2]:
# 예시 데이터 생성
seq_length = 50
data = torch.randn(seq_length, 32, 1)  # (시퀀스 길이, 배치 크기, 입력 차원)

# 모델 초기화
input_dim = 1 # 입력 차원
output_dim = 1 # 출력 차원
nhead = 16 # 어텐션 헤드 개수
num_layers = 2 # Transformer 레이어 개수
hidden_dim = 64 # 은닉 차원

In [3]:
model = TransformerModel(input_dim, output_dim, nhead, num_layers, hidden_dim)

AssertionError: embed_dim must be divisible by num_heads

In [None]:
# 훈련
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(100):
    optimizer.zero_grad()
    output = model(data[:-1], data[:-1])  # tgt를 입력과 동일하게 설정
    loss = criterion(output, data)
    loss.backward()
    optimizer.step()

    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')

In [None]:

# 테스트
with torch.no_grad():
    future = 20
    pred = model(data[:-1], data[:-1])
    for _ in range(future):
        next_pred = model(pred[-seq_length:], data[:-1])  # tgt를 입력과 동일하게 설정
        pred = torch.cat([pred, next_pred.unsqueeze(0)], dim=0)


In [None]:

# 시각화
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 6))
plt.plot(data[:, 0, :].numpy(), label='Original data')
plt.plot(range(seq_length, seq_length+future), pred[-future:, 0, :].numpy(), label='Predictions')
plt.legend()
plt.show()