In [1]:
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from lightning_models import AELSTMPredictor
from torchinfo import summary

In [2]:
# Seed torch's random number generators before the model is built so that
# torch-driven randomness repeats across "Restart Kernel -> Run All".
# NOTE(review): randomness drawn from NumPy or Python's `random` inside
# AELSTMPredictor (if any) is not covered by this call.
SEED = 42
torch.manual_seed(SEED)

# Values that must agree between the constructor arguments and the dummy
# input shape given to torchinfo; naming them once keeps the two in sync.
INPUT_SIZE = 5       # forwarded as `input_size`; last dimension of the summary input
TIMESTEP = 10        # forwarded as `timestep`; presumably the sequence dimension -- TODO confirm
SUMMARY_BATCH = 32   # batch dimension used only for the summary table below

model = AELSTMPredictor(data_path = 'Dataset/IKCO1.csv',
                        learning_rate = 3e-4,
                        weight_decay = 0,
                        split = [0.70, 0.15, 0.15],
                        batch_size = 128,
                        input_size = INPUT_SIZE,
                        code_size = 2,
                        intr_size = 3,
                        hidden_size = 128,
                        output_layer_size = 5,
                        num_layers = 3,
                        prob = 0.2,
                        timestep = TIMESTEP,)


# Last expression of the cell, so the layer/parameter table is rendered as output.
summary(model, input_size = (SUMMARY_BATCH, TIMESTEP, INPUT_SIZE))

Layer (type:depth-idx)                   Output Shape              Param #
AELSTMPredictor                          --                        --
├─AELSTM: 1-1                            [32, 5]                   --
│    └─Encoder: 2-1                      [32, 10, 2]               --
│    │    └─LSTM: 3-1                    [32, 10, 3]               120
│    │    └─LSTM: 3-2                    [32, 10, 2]               56
│    └─Decoder: 2-2                      [32, 10, 5]               --
│    │    └─LSTM: 3-3                    [32, 10, 3]               84
│    │    └─LSTM: 3-4                    [32, 10, 5]               200
│    └─PredictNextTimestep: 2-3          [32, 5]                   --
│    │    └─LSTM: 3-5                    [32, 10, 128]             331,776
│    │    └─Linear: 3-6                  [32, 5]                   645
│    │    └─Dropout: 3-7                 [32, 5]                   --
├─MSELoss: 1-2                           --                        --
├─MeanA

In [3]:
# Experiment logger: writes under 'aelstmpredictor-logs/' with run name
# 'aelstmpredictor' and a pinned version number of 0.
logger = TensorBoardLogger(
    save_dir='aelstmpredictor-logs/',
    name='aelstmpredictor',
    version=0,
)

# Stop training once 'val_loss' has gone 10 checks without improving;
# any decrease counts as an improvement because min_delta is 0.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    mode='min',
    min_delta=0.00,
    patience=10,
    verbose=False,
)

# Checkpointing is evaluated every epoch; files go to the 'aelstmpredictor'
# directory and are ranked by the lowest 'val_loss'.
checkpoint_callback = ModelCheckpoint(
    dirpath='aelstmpredictor',
    monitor='val_loss',
    mode='min',
    every_n_epochs=1,
)

In [4]:
# Build the trainer and run the fit loop.
#
# * `accelerator`/`devices` replace the `gpus` argument, which newer
#   PyTorch Lightning releases deprecate and later remove. The selection
#   is unchanged: one GPU when CUDA is available, otherwise the CPU.
# * `default_root_dir` now points at the same 'aelstmpredictor-logs/'
#   directory the TensorBoard logger uses; it previously named an unrelated
#   'lstmpredictor-logs/' directory.
trainer = Trainer(
    default_root_dir = 'aelstmpredictor-logs/',
    accelerator = ('gpu' if torch.cuda.is_available() else 'cpu'),
    devices = 1,
    callbacks = [checkpoint_callback, early_stop_callback],
    max_epochs = 100,   # upper bound; early stopping may end training sooner
    logger = logger)

trainer.fit(model = model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type              | Params
----------------------------------------------------------
0 | model               | AELSTM            | 332 K 
1 | loss                | MSELoss           | 0     
2 | mean_absolute_error | MeanAbsoluteError | 0     
3 | mean_squared_error  | MeanSquaredError  | 0     
----------------------------------------------------------
332 K     Trainable params
0         Non-trainable params
332 K     Total params
1.332     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


                                                                            

  rank_zero_warn(
  rank_zero_warn(


Epoch 99: 100%|██████████| 26/26 [00:21<00:00,  1.20it/s, loss=0.0243, v_num=0, train_loss=0.0232]


In [5]:
# Run the test loop and display the returned list of metric dicts.
# `model` is the module that was fitted above, so it is passed through the
# `model` argument rather than `datamodule`, which is meant for a
# LightningDataModule. `ckpt_path = 'best'` keeps evaluation on the best
# checkpoint tracked by the ModelCheckpoint callback, matching the checkpoint
# restore the previous call performed implicitly.
# NOTE(review): assumes the test dataloader is defined on the module itself -- confirm.
trainer.test(model = model, ckpt_path = 'best')

  rank_zero_warn(
Restoring states from the checkpoint path at C:\Users\Yegyanathan V\Desktop\Python\Deep Learning\AE-LSTM\aelstmpredictor\epoch=97-step=2156.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at C:\Users\Yegyanathan V\Desktop\Python\Deep Learning\AE-LSTM\aelstmpredictor\epoch=97-step=2156.ckpt
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 4/4 [00:00<00:00, 34.54it/s] 
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
           MAE              0.2007925808429718
           MSE              0.5661695599555969
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'MAE': 0.2007925808429718, 'MSE': 0.5661695599555969}]