In [11]:
import importlib
import sys
import os
sys.path.insert(0, '../src')

In [12]:
import torch
import torch.nn as nn
import os
import json
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import pyarrow.parquet as pq
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
import pytorch_lightning as pl
import sys
# from sklearn import *
from torchmetrics.classification import accuracy

In [13]:
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, DeviceStatsMonitor, TQDMProgressBar,EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.tuner import Tuner
from config import *
from data import data_utils
from data.dataset import ASL_DATASET
from dl_utils import get_dataloader

In [14]:
from src.models.models import LSTM_BASELINE_Model, LSTM_Predictor, TransformerPredictor

In [15]:
# Display the selected compute device. The name is not defined in this notebook;
# presumably it is exported by the `from config import *` above — verify.
DEVICE

'mps'

In [16]:
class MyProgressBar(TQDMProgressBar):
    """TQDM progress bar whose validation, predict and test bars are switched
    off whenever ``sys.stdout`` is not attached to a terminal.

    Only the three ``init_*_tqdm`` hooks below are overridden; every other
    hook (including the training bar) is inherited from ``TQDMProgressBar``.
    """

    @staticmethod
    def _disable_unless_tty(bar):
        """Set ``bar.disable`` when stdout is not a TTY and return the same bar."""
        if not sys.stdout.isatty():
            bar.disable = True
        return bar

    def init_validation_tqdm(self):
        """Return the parent's validation bar, disabled when stdout is not a TTY."""
        return self._disable_unless_tty(super().init_validation_tqdm())

    def init_predict_tqdm(self):
        """Return the parent's predict bar, disabled when stdout is not a TTY."""
        return self._disable_unless_tty(super().init_predict_tqdm())

    def init_test_tqdm(self):
        """Return the parent's test bar, disabled when stdout is not a TTY."""
        return self._disable_unless_tty(super().init_test_tqdm())

In [17]:
import random  # only needed for the seeding at the bottom of this cell

# Run configuration — everything a reader might want to tune lives here.
BATCH_SIZE = 256  # 2**8; per the author, the largest batch that fits on their GPU
num_workers = 4  # DataLoader worker processes (alternatives tried: os.cpu_count() // 2, or 0)
mod_name = "FIRST_TRANSFORMER_MODEL_2"  # checkpoint filename prefix and TensorBoard run name
DL_FRAMEWORK = "PYTORCH"  # forwarded to data_utils.create_data_loaders
SEED = 42

# Seed the Python, NumPy and PyTorch RNGs before any dataset, loader or model
# is created, so the stochastic parts of the run (e.g. augmentation, weight
# initialisation, dropout) start from a fixed state on Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [18]:
# Build the dataset with augmentation switched on (threshold 0.3).
# NOTE(review): this one augmented dataset object is what gets passed to
# data_utils.create_data_loaders for the train/val/test loaders — confirm that
# augmentation is not also applied to validation/test samples.
asl_dataset = ASL_DATASET(augment=True, augmentation_threshold=0.3)

In [19]:
# Build the three loaders used below for fitting, validation and testing.
# Despite the `_ds` suffix these are passed to the Trainer as dataloaders.
train_ds, val_ds, test_ds = data_utils.create_data_loaders(
    asl_dataset,
    batch_size=BATCH_SIZE,
    dl_framework=DL_FRAMEWORK,
    num_workers=num_workers,
)

In [20]:
# Report len() of the train, validation and test loaders, in that order.
# NOTE(review): these look like batch counts rather than sample counts (the
# training log advances 468 global steps per epoch) — confirm.
n_train, n_val, n_test = len(train_ds), len(val_ds), len(test_ds)
print(f'Got the lengths for Train-Dataset: {n_train}, {n_val}, {n_test}')

Got the lengths for Train-Dataset: 468, 28, 56


In [21]:
# Pull a single item from the training loader and keep its first element,
# which is what gets fed to the model in the next cell
# (presumably the input features, with labels following — confirm).
first_item = next(iter(train_ds))
batch = first_item[0]
batch.dtype

torch.float32

In [22]:
# Hyper-parameters of the transformer classifier, gathered in one place.
transformer_hparams = dict(
    d_model=192,
    n_head=8,
    dim_feedforward=512,
    dropout=0.3,
    layer_norm_eps=1e-6,
    norm_first=False,
    batch_first=True,
    num_layers=3,
    num_classes=250,
    learning_rate=LEARNING_RATE,
)
model = TransformerPredictor(**transformer_hparams)

# Smoke test: run one forward pass on the sample batch before the model is
# moved to DEVICE; the result is intentionally discarded.
model(batch)

# Cast to float32 and move to the configured device, then show the architecture.
model = model.float()
model = model.to(DEVICE)
print(model)

TransformerPredictor(
  (model): TransformerSequenceClassifier(
    (transformer): TransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
          )
          (linear1): Linear(in_features=192, out_features=512, bias=True)
          (dropout): Dropout(p=0.3, inplace=False)
          (linear2): Linear(in_features=512, out_features=192, bias=True)
          (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
          (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
          (dropout1): Dropout(p=0.3, inplace=False)
          (dropout2): Dropout(p=0.3, inplace=False)
        )
      )
      (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    )
    (output_layer): Linear(in_features=192, out_features=250, bias=True)
  )
  (criterion): CrossEntropyLoss()
  (accuracy): Mu

In [23]:
# Keep only the single best checkpoint, ranked by the highest "val_accuracy".
# The file name embeds the run name, the epoch and the rounded accuracy.
ckpt_name_template = mod_name + "-{epoch:02d}-{val_accuracy:.2f}"
checkpoint_callback = ModelCheckpoint(
    monitor="val_accuracy",
    mode="max",
    save_top_k=1,
    filename=ckpt_name_template,
    verbose=True,
)

In [24]:
# TensorBoard logs go under <ROOT_PATH>/lightning_logs/<mod_name>/.
# `version` is left unset, so the logger picks the version sub-folder itself.
log_root = os.path.join(ROOT_PATH, "lightning_logs")
tb_logger = TensorBoardLogger(save_dir=log_root, name=mod_name)

In [25]:
# Stop training once "val_accuracy" has gone 6 validation checks without
# improving on its best value by at least 0.005.
early_stop_callback = EarlyStopping(
    monitor='val_accuracy',
    mode='max',
    min_delta=0.005,
    patience=6,
    verbose=True,
)

In [26]:
# Callback that records the learning rate to the logger; the interval is set to per-step.
lr_monitor = LearningRateMonitor(logging_interval='step')

In [27]:
# Callbacks handed to the Trainer, in the originally configured order.
training_callbacks = [
    DeviceStatsMonitor(),
    early_stop_callback,
    checkpoint_callback,
    MyProgressBar(),
    lr_monitor,
]

# For quick debugging runs, cap the work per epoch by additionally passing
# e.g. limit_train_batches=10 and/or limit_val_batches=0.
trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=100,
    logger=tb_logger,
    callbacks=training_callbacks,
    enable_progress_bar=True,
    num_sanity_val_steps=0,  # no validation sanity-check batches before training
    profiler=None,  # profiling disabled
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [28]:
# Fit the model on the training loader, validating on the validation loader.
# Early stopping and checkpointing are driven by the callbacks on `trainer`.
trainer.fit(model=model, train_dataloaders=train_ds, val_dataloaders=val_ds)



  | Name      | Type                          | Params
------------------------------------------------------------
0 | model     | TransformerSequenceClassifier | 1.1 M 
1 | criterion | CrossEntropyLoss              | 0     
2 | accuracy  | MulticlassAccuracy            | 0     
------------------------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.350     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved. New best score: 0.035
Epoch 0, global step 468: 'val_accuracy' reached 0.03464 (best 0.03464), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=00-val_accuracy=0.03.ckpt' as top 1


EPOCH 0, Validation Accuracy: 0.034826502203941345
 
EPOCH 0: Train accuracy: 0.016549669206142426
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.122 >= min_delta = 0.005. New best score: 0.157
Epoch 1, global step 936: 'val_accuracy' reached 0.15673 (best 0.15673), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=01-val_accuracy=0.16.ckpt' as top 1


EPOCH 1, Validation Accuracy: 0.15663893520832062
 
EPOCH 1: Train accuracy: 0.09451935440301895
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.135 >= min_delta = 0.005. New best score: 0.292
Epoch 2, global step 1404: 'val_accuracy' reached 0.29160 (best 0.29160), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=02-val_accuracy=0.29.ckpt' as top 1


EPOCH 2, Validation Accuracy: 0.2936282753944397
 
EPOCH 2: Train accuracy: 0.21381966769695282
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.089 >= min_delta = 0.005. New best score: 0.381
Epoch 3, global step 1872: 'val_accuracy' reached 0.38075 (best 0.38075), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=03-val_accuracy=0.38.ckpt' as top 1


EPOCH 3, Validation Accuracy: 0.38333672285079956
 
EPOCH 3: Train accuracy: 0.334629625082016
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.108 >= min_delta = 0.005. New best score: 0.489
Epoch 4, global step 2340: 'val_accuracy' reached 0.48864 (best 0.48864), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=04-val_accuracy=0.49.ckpt' as top 1


EPOCH 4, Validation Accuracy: 0.4883151054382324
 
EPOCH 4: Train accuracy: 0.43657970428466797
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.062 >= min_delta = 0.005. New best score: 0.551
Epoch 5, global step 2808: 'val_accuracy' reached 0.55054 (best 0.55054), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=05-val_accuracy=0.55.ckpt' as top 1


EPOCH 5, Validation Accuracy: 0.5514999628067017
 
EPOCH 5: Train accuracy: 0.5170468688011169
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.020 >= min_delta = 0.005. New best score: 0.571
Epoch 6, global step 3276: 'val_accuracy' reached 0.57070 (best 0.57070), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=06-val_accuracy=0.57.ckpt' as top 1


EPOCH 6, Validation Accuracy: 0.572882890701294
 
EPOCH 6: Train accuracy: 0.5743610858917236
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.034 >= min_delta = 0.005. New best score: 0.605
Epoch 7, global step 3744: 'val_accuracy' reached 0.60505 (best 0.60505), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=07-val_accuracy=0.61.ckpt' as top 1


EPOCH 7, Validation Accuracy: 0.6061198711395264
 
EPOCH 7: Train accuracy: 0.621048092842102
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.035 >= min_delta = 0.005. New best score: 0.640
Epoch 8, global step 4212: 'val_accuracy' reached 0.63998 (best 0.63998), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=08-val_accuracy=0.64.ckpt' as top 1


EPOCH 8, Validation Accuracy: 0.6409632563591003
 
EPOCH 8: Train accuracy: 0.656510591506958
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.014 >= min_delta = 0.005. New best score: 0.654
Epoch 9, global step 4680: 'val_accuracy' reached 0.65417 (best 0.65417), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=09-val_accuracy=0.65.ckpt' as top 1


EPOCH 9, Validation Accuracy: 0.6551761627197266
 
EPOCH 9: Train accuracy: 0.6831437945365906
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.028 >= min_delta = 0.005. New best score: 0.683
Epoch 10, global step 5148: 'val_accuracy' reached 0.68257 (best 0.68257), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=10-val_accuracy=0.68.ckpt' as top 1


EPOCH 10, Validation Accuracy: 0.6836022138595581
 
EPOCH 10: Train accuracy: 0.7049322724342346
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.016 >= min_delta = 0.005. New best score: 0.699
Epoch 11, global step 5616: 'val_accuracy' reached 0.69903 (best 0.69903), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=11-val_accuracy=0.70.ckpt' as top 1


EPOCH 11, Validation Accuracy: 0.7010958194732666
 
EPOCH 11: Train accuracy: 0.725311279296875
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 12, global step 6084: 'val_accuracy' was not in top 1


EPOCH 12, Validation Accuracy: 0.6947798132896423
 
EPOCH 12: Train accuracy: 0.7438212633132935
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.012 >= min_delta = 0.005. New best score: 0.711
Epoch 13, global step 6552: 'val_accuracy' reached 0.71068 (best 0.71068), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=13-val_accuracy=0.71.ckpt' as top 1


EPOCH 13, Validation Accuracy: 0.711749255657196
 
EPOCH 13: Train accuracy: 0.7608405351638794
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.016 >= min_delta = 0.005. New best score: 0.727
Epoch 14, global step 7020: 'val_accuracy' reached 0.72686 (best 0.72686), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=14-val_accuracy=0.73.ckpt' as top 1


EPOCH 14, Validation Accuracy: 0.7279154062271118
 
EPOCH 14: Train accuracy: 0.773835301399231
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 15, global step 7488: 'val_accuracy' reached 0.72856 (best 0.72856), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=15-val_accuracy=0.73.ckpt' as top 1


EPOCH 15, Validation Accuracy: 0.7290652394294739
 
EPOCH 15: Train accuracy: 0.7840132117271423
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.006 >= min_delta = 0.005. New best score: 0.733
Epoch 16, global step 7956: 'val_accuracy' reached 0.73282 (best 0.73282), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=16-val_accuracy=0.73.ckpt' as top 1


EPOCH 16, Validation Accuracy: 0.7340368032455444
 
EPOCH 16: Train accuracy: 0.7953646779060364
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 17, global step 8424: 'val_accuracy' was not in top 1


EPOCH 17, Validation Accuracy: 0.7298685312271118
 
EPOCH 17: Train accuracy: 0.8067269921302795
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.014 >= min_delta = 0.005. New best score: 0.746
Epoch 18, global step 8892: 'val_accuracy' reached 0.74645 (best 0.74645), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=18-val_accuracy=0.75.ckpt' as top 1


EPOCH 18, Validation Accuracy: 0.7474296689033508
 
EPOCH 18: Train accuracy: 0.8151064515113831
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 19, global step 9360: 'val_accuracy' was not in top 1


EPOCH 19, Validation Accuracy: 0.7427033185958862
 
EPOCH 19: Train accuracy: 0.8251281976699829
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 20, global step 9828: 'val_accuracy' reached 0.74673 (best 0.74673), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=20-val_accuracy=0.75.ckpt' as top 1


EPOCH 20, Validation Accuracy: 0.7484949827194214
 
EPOCH 20: Train accuracy: 0.8310579657554626
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.008 >= min_delta = 0.005. New best score: 0.754
Epoch 21, global step 10296: 'val_accuracy' reached 0.75440 (best 0.75440), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=21-val_accuracy=0.75.ckpt' as top 1


EPOCH 21, Validation Accuracy: 0.7560284733772278
 
EPOCH 21: Train accuracy: 0.8373644351959229
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 22, global step 10764: 'val_accuracy' reached 0.75610 (best 0.75610), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=22-val_accuracy=0.76.ckpt' as top 1


EPOCH 22, Validation Accuracy: 0.7566542029380798
 
EPOCH 22: Train accuracy: 0.8438284397125244
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 23, global step 11232: 'val_accuracy' reached 0.75781 (best 0.75781), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=23-val_accuracy=0.76.ckpt' as top 1


EPOCH 23, Validation Accuracy: 0.7583283185958862
 
EPOCH 23: Train accuracy: 0.8490179181098938
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.005 >= min_delta = 0.005. New best score: 0.760
Epoch 24, global step 11700: 'val_accuracy' reached 0.75951 (best 0.75951), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=24-val_accuracy=0.76.ckpt' as top 1


EPOCH 24, Validation Accuracy: 0.7607886791229248
 
EPOCH 24: Train accuracy: 0.8533704876899719
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Metric val_accuracy improved by 0.007 >= min_delta = 0.005. New best score: 0.766
Epoch 25, global step 12168: 'val_accuracy' reached 0.76633 (best 0.76633), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=25-val_accuracy=0.77.ckpt' as top 1


EPOCH 25, Validation Accuracy: 0.7666988968849182
 
EPOCH 25: Train accuracy: 0.8591304421424866
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 26, global step 12636: 'val_accuracy' was not in top 1


EPOCH 26, Validation Accuracy: 0.7649909257888794
 
EPOCH 26: Train accuracy: 0.8624807000160217
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 27, global step 13104: 'val_accuracy' was not in top 1


EPOCH 27, Validation Accuracy: 0.7655996680259705
 
EPOCH 27: Train accuracy: 0.8657236099243164
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 28, global step 13572: 'val_accuracy' reached 0.76917 (best 0.76917), saving model to '/Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=28-val_accuracy=0.77.ckpt' as top 1


EPOCH 28, Validation Accuracy: 0.7692269086837769
 
EPOCH 28: Train accuracy: 0.8691900372505188
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 29, global step 14040: 'val_accuracy' was not in top 1


EPOCH 29, Validation Accuracy: 0.7702752947807312
 
EPOCH 29: Train accuracy: 0.8723908066749573
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Epoch 30, global step 14508: 'val_accuracy' was not in top 1


EPOCH 30, Validation Accuracy: 0.76807701587677
 
EPOCH 30: Train accuracy: 0.8756888508796692
****************************************************************************************************


Validation: 0it [00:00, ?it/s]

Monitored metric val_accuracy did not improve in the last 6 records. Best score: 0.766. Signaling Trainer to stop.
Epoch 31, global step 14976: 'val_accuracy' was not in top 1


EPOCH 31, Validation Accuracy: 0.7689310312271118
 
EPOCH 31: Train accuracy: 0.8786638379096985
****************************************************************************************************


In [29]:
# Evaluate on the test loader with the "best" checkpoint tracked by the trainer;
# the returned list of metric dicts is displayed as the cell output.
trainer.test(dataloaders=test_ds, ckpt_path="best")

Restoring states from the checkpoint path at /Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=28-val_accuracy=0.77.ckpt
Loaded model weights from the checkpoint at /Users/tgdimas1/git/CAS-AML-FINAL-PROJECT/notebooks/../src/../lightning_logs/FIRST_TRANSFORMER_MODEL_2/version_8/checkpoints/FIRST_TRANSFORMER_MODEL_2-epoch=28-val_accuracy=0.77.ckpt


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.9785623550415039, 'test_accuracy': 0.7791028022766113}]