In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [None]:
# Select the compute device: prefer Apple-silicon MPS, then CUDA, else CPU.
# `device` is always bound so later cells that reference it work on any machine.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    print("MPS device not found.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Mini-batch size shared by the training and test loaders.
batch_size = 64

# Shuffle only the training split; the test split keeps its original order.
# NOTE(review): train_dataset / test_dataset are not created in any visible cell —
# confirm they are defined earlier so Restart & Run All works.
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=should_shuffle)
    for split, should_shuffle in ((train_dataset, True), (test_dataset, False))
)

In [None]:
class LSTM(nn.Module):
    """Unidirectional LSTM regressor that predicts one value per input sequence.

    Inputs are passed through ReLU, an ``nn.LSTM`` encoder (``batch_first=True``),
    dropout, and a two-layer MLP head; only the final time step drives the output.

    Args:
        lstm_num_layers: Number of stacked LSTM layers.
        lstm_hidden_size: Size of the LSTM hidden state.
        lstm_dropout: Dropout probability applied to the LSTM output
            (it is not passed to ``nn.LSTM`` as inter-layer dropout).
        fc1_output_size: Width of the hidden layer in the MLP head.
        input_size: Number of features per time step (defaults to 8).
    """

    def __init__(self, lstm_num_layers=1, lstm_hidden_size=64, lstm_dropout=0.2,
                 fc1_output_size=16, input_size=8):
        super().__init__()
        self.lstm_num_layers = lstm_num_layers
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_dropout = lstm_dropout
        self.fc1_output_size = fc1_output_size
        self.input_size = input_size

        self.lstm = nn.LSTM(input_size=self.input_size,
                            hidden_size=self.lstm_hidden_size,
                            num_layers=self.lstm_num_layers,
                            batch_first=True)

        self.dropout = nn.Dropout(p=self.lstm_dropout)
        self.fc1 = nn.Linear(self.lstm_hidden_size, self.fc1_output_size)
        self.fc2 = nn.Linear(self.fc1_output_size, 1)

    def forward(self, x):
        """Map a batch of sequences to one prediction per sequence.

        Args:
            x: Float tensor of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        batch_size = x.size(0)
        # Zero initial states of shape (num_layers, batch, hidden): one slot per
        # layer because the LSTM is unidirectional (batch_first does not affect
        # h0/c0). Created on x's device/dtype instead of a notebook-global device.
        h0 = x.new_zeros(self.lstm_num_layers, batch_size, self.lstm_hidden_size)
        c0 = x.new_zeros(self.lstm_num_layers, batch_size, self.lstm_hidden_size)

        # NOTE(review): ReLU on the raw inputs zeroes every negative feature value;
        # kept to preserve the original design — confirm this is intended.
        x = F.relu(x)

        h_lstm, _ = self.lstm(x, (h0, c0))
        # Only the last time step feeds the head, so slice before dropout and the
        # MLP instead of computing them for every step and discarding the rest.
        h_last = h_lstm[:, -1, :]
        h_dropout = self.dropout(h_last)

        h_fc1 = F.relu(self.fc1(h_dropout))
        return self.fc2(h_fc1)