Import Libraries and Prepare Data

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Seed every RNG the notebook touches from one constant so a fresh
# "Restart & Run All" reproduces the same data and weight initialisation.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

<torch._C.Generator at 0x7c82ba0a6cf0>

In [3]:
# 1000 evenly spaced time stamps on [0, 100] (both endpoints included) and a
# noiseless sine wave sampled at them -- the series the LSTM learns to continue.
t = np.linspace(0.0, 100.0, num=1000, endpoint=True)
data = np.sin(t)

In [4]:
def create_sequences(data, seq_length):
    """Slice a series into sliding input windows and next-step targets.

    For every start index ``i`` the input window is ``data[i:i + seq_length]``
    and its target is the value immediately after it, ``data[i + seq_length]``.

    Args:
        data: Array-like series indexed along its first axis. Any trailing
            axes (e.g. per-step features) are carried through unchanged.
        seq_length: Number of consecutive steps in each input window.

    Returns:
        ``(X, y)``: ``X`` has shape ``(n, seq_length, *data.shape[1:])`` and
        ``y`` has shape ``(n, *data.shape[1:])``, where
        ``n = max(len(data) - seq_length, 0)``. Both are fresh copies, not views
        of ``data``. If the series is too short to yield a single pair, both
        arrays are empty but keep these ranks, so downstream indexing such as
        ``X[:, :, None]`` still works.

    Raises:
        ValueError: If ``seq_length`` is negative.
    """
    if seq_length < 0:
        raise ValueError(f"seq_length must be non-negative, got {seq_length}")
    series = np.asarray(data)
    n_windows = max(len(series) - seq_length, 0)
    # Row i of `window_idx` holds i, i+1, ..., i+seq_length-1, so one fancy-index
    # gathers every window at once (advanced indexing always returns a copy).
    window_idx = np.arange(n_windows)[:, None] + np.arange(seq_length)[None, :]
    X = series[window_idx]
    # Basic slicing yields a view; copy so callers can mutate y without touching data.
    y = series[seq_length:seq_length + n_windows].copy()
    return X, y

In [5]:
# Each training input is a window of this many consecutive samples; its
# target is the single sample that follows the window.
seq_length = 10
X, y = create_sequences(data, seq_length=seq_length)

In [6]:
# Convert to float32 tensors (torch.tensor copies the NumPy data). The windows
# gain a trailing size-1 feature axis -> (batch, seq, feature), matching the
# batch_first=True LSTM defined below; targets become (batch, 1) columns.
trainX = torch.tensor(X, dtype=torch.float32).unsqueeze(2)
trainY = torch.tensor(y, dtype=torch.float32).unsqueeze(1)

In [9]:
print(X.shape)        # (len(data) - seq_length, seq_length)
print(y.shape)        # (len(data) - seq_length,)
print(trainX.shape)   # (len(data) - seq_length, seq_length, 1) -- (batch, seq, feature)
print(trainY.shape)   # (len(data) - seq_length, 1)

# See the first input-target pair
print("first numpy array x:", X[0])   # seq_length values: data[0:seq_length]
print("first numpy array y:", y[0])   # the next value: data[seq_length]
print("first tensor x:", trainX[0])   # same window as float32, shape (seq_length, 1)
print("first tensor y:", trainY[0])   # same target as float32, shape (1,)

# Turn the claims above into checks tied to seq_length (rather than hard-coded
# numbers that go stale). They run after the prints so the values are visible
# if one fails.
n_windows = len(data) - seq_length
assert X.shape == (n_windows, seq_length)
assert y.shape == (n_windows,)
assert trainX.shape == (n_windows, seq_length, 1)
assert trainY.shape == (n_windows, 1)
assert np.array_equal(X[0], data[:seq_length])
assert y[0] == data[seq_length]

(990, 10)
(990,)
torch.Size([990, 10, 1])
torch.Size([990, 1])
first numpy array x: [0.         0.09993302 0.19886554 0.29580708 0.3897871  0.47986471
 0.56513807 0.64475345 0.71791378 0.7838866 ]
first numpy array y: 0.8420114062884005
first tensor x: tensor([[0.0000],
        [0.0999],
        [0.1989],
        [0.2958],
        [0.3898],
        [0.4799],
        [0.5651],
        [0.6448],
        [0.7179],
        [0.7839]])
first tensor y: tensor([0.8420])


Define the LSTM Model

In [10]:
class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(LSTMModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        