In [1]:
import torch
import torch.nn as nn

In [2]:
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # initial hidden state
        h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)
        # initial cell state
        c0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)

        out, _ = self.lstm(x, (h0, c0))  # get LSTM output
        # pass last output to Fully Connected layer
        out = self.fc(out[:, -1, :])

        return out


lstm = LSTM(10, 20, 1)  # 10 features, 20 hidden units, 1 output
print(lstm)

LSTM(
  (lstm): LSTM(10, 20, batch_first=True)
  (fc): Linear(in_features=20, out_features=1, bias=True)
)


In [3]:
input_size = 10
time_steps = 5
output_size = 1

X = torch.randn(100, time_steps, input_size)  # 100 samples, 5 time steps, 10 features
Y = torch.randn(100, 1)


In [4]:
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(lstm.parameters(), lr=0.01)

for epoch in range(100):
    output = lstm(X)
    loss = loss_fn(output, Y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

Epoch 1, Loss: 0.916992723941803
Epoch 2, Loss: 0.9167881011962891
Epoch 3, Loss: 0.9165842533111572
Epoch 4, Loss: 0.9163811206817627
Epoch 5, Loss: 0.916178822517395
Epoch 6, Loss: 0.9159771800041199
Epoch 7, Loss: 0.9157764315605164
Epoch 8, Loss: 0.9155763983726501
Epoch 9, Loss: 0.9153769612312317
Epoch 10, Loss: 0.9151782393455505
Epoch 11, Loss: 0.9149802327156067
Epoch 12, Loss: 0.9147830009460449
Epoch 13, Loss: 0.9145863056182861
Epoch 14, Loss: 0.9143903851509094
Epoch 15, Loss: 0.9141950011253357
Epoch 16, Loss: 0.9140003323554993
Epoch 17, Loss: 0.913806140422821
Epoch 18, Loss: 0.9136127233505249
Epoch 19, Loss: 0.913419783115387
Epoch 20, Loss: 0.9132277965545654
Epoch 21, Loss: 0.913036048412323
Epoch 22, Loss: 0.9128448963165283
Epoch 23, Loss: 0.9126545786857605
Epoch 24, Loss: 0.9124644994735718
Epoch 25, Loss: 0.9122752547264099
Epoch 26, Loss: 0.9120864272117615
Epoch 27, Loss: 0.9118981957435608
Epoch 28, Loss: 0.9117104411125183
Epoch 29, Loss: 0.9115233421325684