In [3]:
import seaborn as sns

import torch
import torch.nn as nn

sns.set_theme()

In [18]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        self.hidden_size = hidden_size

        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # initial hidden state
        h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.rnn(x, h0)  # get RNN output
        # pass last output to Fully Connected layer
        out = self.fc(out[:, -1, :])
        return out


rnn = RNN(10, 20, 1)  # 10 features, 20 hidden units, 1 output
print(rnn)

RNN(
  (rnn): RNN(10, 20, batch_first=True)
  (fc): Linear(in_features=20, out_features=1, bias=True)
)


In [15]:
input_size = 10
time_steps = 5
output_size = 1

X = torch.randn(100, time_steps, input_size)  # 100 samples, 5 time steps, 10 features
Y = torch.randn(100, 1)


In [16]:

loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(rnn.parameters(), lr=0.004)

for epoch in range(1000):
    output = rnn(X)
    loss = loss_fn(output, Y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

Epoch 1, Loss: 1.3370981216430664
Epoch 2, Loss: 1.3342825174331665
Epoch 3, Loss: 1.331527829170227
Epoch 4, Loss: 1.3288323879241943
Epoch 5, Loss: 1.3261938095092773
Epoch 6, Loss: 1.3236100673675537
Epoch 7, Loss: 1.3210793733596802
Epoch 8, Loss: 1.318600058555603
Epoch 9, Loss: 1.3161700963974
Epoch 10, Loss: 1.313787579536438
Epoch 11, Loss: 1.3114514350891113
Epoch 12, Loss: 1.3091599941253662
Epoch 13, Loss: 1.306911587715149
Epoch 14, Loss: 1.3047049045562744
Epoch 15, Loss: 1.3025386333465576
Epoch 16, Loss: 1.3004112243652344
Epoch 17, Loss: 1.2983217239379883
Epoch 18, Loss: 1.2962685823440552
Epoch 19, Loss: 1.2942510843276978
Epoch 20, Loss: 1.2922680377960205
Epoch 21, Loss: 1.2903180122375488
Epoch 22, Loss: 1.2884001731872559
Epoch 23, Loss: 1.2865136861801147
Epoch 24, Loss: 1.2846574783325195
Epoch 25, Loss: 1.2828304767608643
Epoch 26, Loss: 1.2810320854187012
Epoch 27, Loss: 1.2792612314224243
Epoch 28, Loss: 1.2775171995162964
Epoch 29, Loss: 1.2757993936538696
E