In [1]:
import importlib
import sys
import os
# Put the project's source directory at the front of the import search path so
# that the project-local imports in the following cells (`config`, `data`,
# `dl_utils`, `models.*`) can be resolved -- presumably they live in `../src`.
# NOTE(review): the path is relative, so it is resolved against the kernel's
# current working directory; the notebook has to be started from its own folder.
sys.path.insert(0, '../src')

In [2]:
import torch
import torch.nn as nn
import os
import json
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import pyarrow.parquet as pq
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
import pytorch_lightning as pl
import sys
# from sklearn import *
from torchmetrics.classification import accuracy

In [3]:
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, DeviceStatsMonitor, TQDMProgressBar,EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.tuner import Tuner
from config import *
from data import data_utils
from data.dataset import ASL_DATASET
from dl_utils import get_dataloader

In [4]:
# Show which deep-learning framework is selected. The name is not defined in
# this notebook, so it presumably comes from the `from config import *` wildcard
# import. It is used below to build the model module path.
DL_FRAMEWORK

'pytorch'

In [5]:
# Name of the model class to load; also reused further down as the checkpoint
# filename prefix and as the TensorBoard run name.
class_name = "TransformerPredictor"

# Dotted path of the framework-specific module that contains that class.
module_name = f"models.{DL_FRAMEWORK}.models"

In [6]:
# Import the framework-specific model module by its dotted name and look the
# model class up by name, so the class is selected by configuration rather than
# by a hard-coded import statement.
module = importlib.import_module(module_name)
MyTransformerModel = getattr(module, class_name)
# Bare expression: shows the resolved class as the cell output.
MyTransformerModel

models.pytorch.models.TransformerPredictor

In [7]:
# Show the target device string; it is passed to `model.to(DEVICE)` further down.
# Not defined in this notebook -- presumably provided by `from config import *`.
DEVICE

'mps'

In [8]:
class MyProgressBar(TQDMProgressBar):
    """TQDM progress bar that silences the secondary bars on non-TTY output.

    The validation, prediction and test bars are created by the parent class
    as usual and are then disabled whenever ``sys.stdout`` is not attached to
    a terminal. The training bar is not overridden here and keeps the parent
    behaviour.
    """

    @staticmethod
    def _disable_unless_tty(bar):
        """Disable ``bar`` when stdout is not a terminal and return it.

        Args:
            bar: The tqdm bar created by the parent class.

        Returns:
            The same bar object, with ``disable`` set to ``True`` if
            ``sys.stdout.isatty()`` is false.
        """
        if not sys.stdout.isatty():
            bar.disable = True
        return bar

    def init_validation_tqdm(self):
        """Create the validation bar, disabled on non-TTY output."""
        return self._disable_unless_tty(super().init_validation_tqdm())

    def init_predict_tqdm(self):
        """Create the prediction bar, disabled on non-TTY output."""
        return self._disable_unless_tty(super().init_predict_tqdm())

    def init_test_tqdm(self):
        """Create the test bar, disabled on non-TTY output."""
        return self._disable_unless_tty(super().init_test_tqdm())

In [22]:
# Define parameters
# Keyword arguments forwarded verbatim to the model constructor (`**params`).
params = dict(
    # Encoder dimensions.
    d_model=192,
    n_head=8,
    dim_feedforward=512,
    num_layers=2,
    # Regularisation / normalisation settings.
    dropout=0.1,
    layer_norm_eps=1e-5,
    norm_first=True,
    # Input layout flag.
    batch_first=True,
    # Classification head and optimiser settings.
    num_classes=250,
    learning_rate=0.001,
)

In [23]:
# Build the project dataset with augmentation switched on.
# NOTE(review): the exact meaning of `augmentation_threshold` is defined in
# `data.dataset` and is not visible here -- presumably a probability/cut-off
# for applying augmentation; confirm.
# NOTE(review): augmentation is configured on the full dataset *before* it is
# split into train/val/test below -- verify that validation and test samples
# are not augmented as well.
asl_dataset = ASL_DATASET(augment=True, augmentation_threshold=0.3)

In [24]:
# Split the dataset into the three loaders that are later handed to
# `trainer.fit` (train / validation) and `trainer.test` (test).
# Despite the `_ds` suffix these objects are used as dataloaders below.
train_ds, val_ds, test_ds = data_utils.create_data_loaders(
    asl_dataset,
    batch_size=BATCH_SIZE,
    dl_framework=DL_FRAMEWORK,
    num_workers=4,
)

In [25]:
# Report the length of each of the three splits, in train / val / test order.
# NOTE(review): the label only mentions "Train-Dataset" although three values
# are printed. The values are presumably batch counts rather than sample
# counts (the training log advances 468 steps per epoch, matching the first
# number) -- confirm.
print(f'Got the lengths for Train-Dataset: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}')

Got the lengths for Train-Dataset: 468, 28, 56


In [26]:
# Pull a single batch from the training loader and keep only its first element
# (presumably the input features; the remaining elements are discarded here).
batch = next(iter(train_ds))[0]
# Bare expression: shows the tensor shape as the cell output.
# NOTE(review): the meaning of each axis is not documented here -- confirm
# against `data.dataset`.
batch.shape

torch.Size([128, 32, 96, 2])

In [27]:
# Instantiate the dynamically resolved model class with the hyper-parameters
# defined above.
model = MyTransformerModel(**params)

# One forward pass on the sample batch as a smoke test; the result is discarded.
# NOTE(review): this runs before the model is moved to DEVICE and with
# gradient tracking enabled.
model(batch)

# Cast the parameters to float32, then move the model to the target device.
model = model.float()
model = model.to(DEVICE)

print(model)

TransformerPredictor(
  (criterion): CrossEntropyLoss()
  (accuracy): MulticlassAccuracy()
  (model): TransformerSequenceClassifier(
    (transformer): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
          )
          (linear1): Linear(in_features=192, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=192, bias=True)
          (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    )
    (output_layer): Linear(in_features=192, out_features=2

In [28]:
# Keep only the single best checkpoint, ranked by the highest "val_accuracy".
# The filename embeds the epoch and the rounded metric value.
checkpoint_callback = ModelCheckpoint(
    monitor="val_accuracy",
    mode="max",
    save_top_k=1,
    filename=class_name + "-{epoch:02d}-{val_accuracy:.2f}",
    verbose=True,
)

In [29]:
# TensorBoard logger: runs are stored under <ROOT_PATH>/lightning_logs/<class_name>/.
tb_logger = TensorBoardLogger(
    name=class_name,
    save_dir=os.path.join(ROOT_PATH, "lightning_logs"),
    # version=mod_name  # left disabled, so the logger picks the version itself
)

In [30]:
# Stop training once "val_accuracy" has failed to improve by at least
# `min_delta` for `patience` consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor='val_accuracy',
    mode='max',
    min_delta=0.005,
    patience=6,
    verbose=True,
)

In [31]:
# Callback that records the optimiser's learning rate; 'step' requests logging
# at every training step rather than once per epoch.
lr_monitor = LearningRateMonitor(logging_interval='step')

In [32]:
# Build the Lightning trainer with logging, checkpointing, early stopping,
# learning-rate monitoring and the customised progress bar.
trainer = pl.Trainer(
    enable_progress_bar=True,
    # "auto" lets Lightning pick the best available backend and fall back to
    # CPU. The previous hard-coded "gpu" raised on machines without a GPU
    # backend; on a GPU/MPS machine the selected device is the same as before.
    accelerator="auto",
    logger=tb_logger,
    callbacks=[
        DeviceStatsMonitor(),
        early_stop_callback,
        checkpoint_callback,
        MyProgressBar(),
        lr_monitor,
    ],
    # Upper bound only; early stopping normally ends the run sooner.
    max_epochs=100,
    # Debugging aids -- uncomment to shorten a run:
    # limit_train_batches=10,
    # limit_val_batches=0,
    # Skip the validation sanity check that Lightning runs before training.
    num_sanity_val_steps=0,
    # No profiling.
    profiler=None,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [33]:
# Run the training loop, validating on the validation loader. The checkpoint
# and early-stopping callbacks attached to the trainer both track "val_accuracy".
trainer.fit(model=model, train_dataloaders=train_ds, val_dataloaders=val_ds)



  | Name      | Type                          | Params
------------------------------------------------------------
0 | criterion | CrossEntropyLoss              | 0     
1 | accuracy  | MulticlassAccuracy            | 0     
2 | model     | TransformerSequenceClassifier | 741 K 
------------------------------------------------------------
741 K     Trainable params
0         Non-trainable params
741 K     Total params
2.965     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved. New best score: 0.173
Epoch 0, global step 468: 'val_accuracy' reached 0.17320 (best 0.17320), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/TransformerPredictor/version_1/checkpoints/TransformerPredictor-epoch=00-val_accuracy=0.17.ckpt' as top 1


EPOCH 0, Validation Accuracy: 0.17308409512043
 
EPOCH 0: Train accuracy: 0.06903800368309021
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.126 >= min_delta = 0.005. New best score: 0.299
Epoch 1, global step 936: 'val_accuracy' reached 0.29898 (best 0.29898), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/TransformerPredictor/version_1/checkpoints/TransformerPredictor-epoch=01-val_accuracy=0.30.ckpt' as top 1


EPOCH 1, Validation Accuracy: 0.29904794692993164
 
EPOCH 1: Train accuracy: 0.21435603499412537
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.082 >= min_delta = 0.005. New best score: 0.381
Epoch 2, global step 1404: 'val_accuracy' reached 0.38103 (best 0.38103), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/TransformerPredictor/version_1/checkpoints/TransformerPredictor-epoch=02-val_accuracy=0.38.ckpt' as top 1


EPOCH 2, Validation Accuracy: 0.3817809820175171
 
EPOCH 2: Train accuracy: 0.3444627523422241
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 3, global step 1872: 'val_accuracy' was not in top 1


EPOCH 3, Validation Accuracy: 0.31179824471473694
 
EPOCH 3: Train accuracy: 0.33515292406082153
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.133 >= min_delta = 0.005. New best score: 0.514
Epoch 4, global step 2340: 'val_accuracy' reached 0.51448 (best 0.51448), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/TransformerPredictor/version_1/checkpoints/TransformerPredictor-epoch=04-val_accuracy=0.51.ckpt' as top 1


EPOCH 4, Validation Accuracy: 0.5144920349121094
 
EPOCH 4: Train accuracy: 0.4358517527580261
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.014 >= min_delta = 0.005. New best score: 0.528
Epoch 5, global step 2808: 'val_accuracy' reached 0.52811 (best 0.52811), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/TransformerPredictor/version_1/checkpoints/TransformerPredictor-epoch=05-val_accuracy=0.53.ckpt' as top 1


EPOCH 5, Validation Accuracy: 0.5286712050437927
 
EPOCH 5: Train accuracy: 0.528845489025116
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.065 >= min_delta = 0.005. New best score: 0.593
Epoch 6, global step 3276: 'val_accuracy' reached 0.59284 (best 0.59284), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/TransformerPredictor/version_1/checkpoints/TransformerPredictor-epoch=06-val_accuracy=0.59.ckpt' as top 1


EPOCH 6, Validation Accuracy: 0.5925494432449341
 
EPOCH 6: Train accuracy: 0.5752254724502563
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.018 >= min_delta = 0.005. New best score: 0.610
Epoch 7, global step 3744: 'val_accuracy' reached 0.61045 (best 0.61045), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/TransformerPredictor/version_1/checkpoints/TransformerPredictor-epoch=07-val_accuracy=0.61.ckpt' as top 1


EPOCH 7, Validation Accuracy: 0.6103727221488953
 
EPOCH 7: Train accuracy: 0.6037247180938721
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.012 >= min_delta = 0.005. New best score: 0.622
Epoch 8, global step 4212: 'val_accuracy' reached 0.62237 (best 0.62237), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/TransformerPredictor/version_1/checkpoints/TransformerPredictor-epoch=08-val_accuracy=0.62.ckpt' as top 1


EPOCH 8, Validation Accuracy: 0.6226156949996948
 
EPOCH 8: Train accuracy: 0.6346408724784851
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 9, global step 4680: 'val_accuracy' reached 0.62663 (best 0.62663), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/TransformerPredictor/version_1/checkpoints/TransformerPredictor-epoch=09-val_accuracy=0.63.ckpt' as top 1


EPOCH 9, Validation Accuracy: 0.6281114816665649
 
EPOCH 9: Train accuracy: 0.657023012638092
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.040 >= min_delta = 0.005. New best score: 0.662
Epoch 10, global step 5148: 'val_accuracy' reached 0.66212 (best 0.66212), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/TransformerPredictor/version_1/checkpoints/TransformerPredictor-epoch=10-val_accuracy=0.66.ckpt' as top 1


EPOCH 10, Validation Accuracy: 0.6622024774551392
 
EPOCH 10: Train accuracy: 0.680515706539154
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.018 >= min_delta = 0.005. New best score: 0.680
Epoch 11, global step 5616: 'val_accuracy' reached 0.68030 (best 0.68030), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/TransformerPredictor/version_1/checkpoints/TransformerPredictor-epoch=11-val_accuracy=0.68.ckpt' as top 1


EPOCH 11, Validation Accuracy: 0.6813700795173645
 
EPOCH 11: Train accuracy: 0.7037332057952881
****************************************************************************************************


In [34]:
# Evaluate on the test loader using the best checkpoint tracked during
# `trainer.fit`. The call is the last expression of the cell, so its return
# value is shown as the cell output.
trainer.test(dataloaders=test_ds, ckpt_path="best")