In [1]:
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchinfo import summary

from lightning_models import LSTMPredictor

In [2]:
# Seed the Python, NumPy and torch RNGs *before* the model is constructed so
# that weight initialisation is repeatable on "Restart Kernel -> Run All".
# `workers = True` also derives per-worker seeds for DataLoader workers.
SEED = 42
seed_everything(SEED, workers = True)

# These two dimensions are needed both by the model constructor and by the
# dummy input shape handed to `summary`; naming them once keeps the two in sync.
TIMESTEP = 10        # sequence length of one input window
INPUT_SIZE = 5       # number of features per time step
SUMMARY_BATCH = 32   # batch dimension used only for the torchinfo dry run

model = LSTMPredictor(data_path = 'Dataset/IKCO1.csv',
                      split = [0.70, 0.15, 0.15],   # presumably train/val/test fractions -- confirm in lightning_models
                      batch_size = 128,
                      learning_rate = 3e-4,
                      weight_decay = 0,
                      input_size = INPUT_SIZE,
                      hidden_size = 128,
                      output_layer_size = 5,
                      num_layers = 3,
                      prob = 0,                     # presumably the dropout probability (0 = disabled) -- confirm
                      timestep = TIMESTEP,)

# Dry-run the network on a (batch, timestep, features) dummy input and display
# the per-layer output shapes and parameter counts.
summary(model, input_size = (SUMMARY_BATCH, TIMESTEP, INPUT_SIZE))

Layer (type:depth-idx)                   Output Shape              Param #
LSTMPredictor                            --                        --
├─PredictNextTimestep: 1-1               [32, 5]                   --
│    └─LSTM: 2-1                         [32, 10, 128]             333,312
│    └─Linear: 2-2                       [32, 5]                   645
│    └─Dropout: 2-3                      [32, 5]                   --
├─MSELoss: 1-2                           --                        --
├─MeanAbsoluteError: 1-3                 --                        --
├─MeanSquaredError: 1-4                  --                        --
Total params: 333,957
Trainable params: 333,957
Non-trainable params: 0
Total mult-adds (M): 106.68
Input size (MB): 0.01
Forward/backward pass size (MB): 0.33
Params size (MB): 1.34
Estimated Total Size (MB): 1.67

In [3]:
# Checkpointing: evaluated every epoch, tracking the validation loss
# (lower is better); files are written under ./lstmpredictor.
checkpoint_callback = ModelCheckpoint(
    dirpath = 'lstmpredictor',
    monitor = 'val_loss',
    mode = 'min',
    every_n_epochs = 1,
)

# Early stopping: watch the same validation loss and halt training once it has
# failed to improve (by any amount, min_delta = 0) for 8 consecutive checks.
early_stop_callback = EarlyStopping(
    monitor = "val_loss",
    mode = "min",
    min_delta = 0.00,
    patience = 8,
    verbose = False,
)

# TensorBoard logging. The version is pinned to 0, so every run logs under
# the same lstmpredictor-logs/lstmpredictor/version_0 directory.
logger = TensorBoardLogger(
    'lstmpredictor-logs/',
    name = 'lstmpredictor',
    version = 0,
)

In [4]:
# Device selection uses `accelerator` / `devices` instead of the `gpus`
# argument, which is deprecated since PyTorch Lightning 1.7 and removed in 2.0.
# Behaviour is unchanged: one CUDA GPU when available, otherwise the CPU.
trainer = Trainer(
    default_root_dir = 'lstmpredictor-logs/',
    accelerator = ('gpu' if torch.cuda.is_available() else 'cpu'),
    devices = 1,
    callbacks = [checkpoint_callback, early_stop_callback],
    max_epochs = 100,   # upper bound; early stopping may end training sooner
    logger = logger)

# No dataloaders are passed here, so the module itself must supply the
# training / validation loaders.
trainer.fit(model = model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type                | Params
------------------------------------------------------------
0 | model               | PredictNextTimestep | 333 K 
1 | loss                | MSELoss             | 0     
2 | mean_absolute_error | MeanAbsoluteError   | 0     
3 | mean_squared_error  | MeanSquaredError    | 0     
------------------------------------------------------------
333 K     Trainable params
0         Non-trainable params
333 K     Total params
1.336     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


                                                                            

  rank_zero_warn(
  rank_zero_warn(


Epoch 17: 100%|██████████| 26/26 [00:21<00:00,  1.21it/s, loss=0.00241, v_num=0, train_loss=0.00224]


In [5]:
# Evaluate on the test split. `model` is a LightningModule, so it is passed as
# `model=` rather than `datamodule=` (which expects a LightningDataModule and
# only worked by duck-typing). `ckpt_path = 'best'` makes explicit that the
# best checkpoint recorded by ModelCheckpoint is restored before testing,
# instead of relying on the implicit fallback used when no model is given.
trainer.test(model = model, ckpt_path = 'best')

  rank_zero_warn(
Restoring states from the checkpoint path at C:\Users\Yegyanathan V\Desktop\Python\Deep Learning\AE-LSTM\lstmpredictor\epoch=9-step=220.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at C:\Users\Yegyanathan V\Desktop\Python\Deep Learning\AE-LSTM\lstmpredictor\epoch=9-step=220.ckpt
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 4/4 [00:00<00:00, 42.00it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
           MAE              0.1271575540304184
           MSE              0.5411623120307922
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'MAE': 0.1271575540304184, 'MSE': 0.5411623120307922}]

: 