In [14]:
from model import CNN, Classifier, ResidualBlock, ResNet
from pytorch_lightning import Trainer
import glob
from ghf import ActGHF
import pytorch_lightning as pl
import torch.nn as nn
from train import test_loader

import logging
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

list_model = {'GHF':ActGHF(), 'Logistic':nn.Sigmoid(), 'Tanh':nn.Tanh(), 'ReLU':nn.ReLU(),'Mish': nn.Mish(),'LeakyReLU':nn.LeakyReLU()}
version = 3
num_classes = 10

trainer = Trainer(
            accelerator="gpu", 
            devices=1, 
            callbacks=[
            pl.callbacks.ModelCheckpoint(
                monitor="val_acc",        # Metric to monitor
                mode="max",               # Save when max accuracy
                save_top_k=1,             # Save only the best model
                filename="{epoch}-{val_acc:.4f}",  # Include accuracy in filename
                save_last=False,          # Don't save final epoch if not best
                verbose=True              # Print when new best model is saved
            )
        ],)

for name, act in list_model.items():
    # Path to the checkpoint file
    checkpoint_path = f"logs/{name}/version_{version}/checkpoints"
    ckp_file = glob.glob(f'{checkpoint_path}/*.ckpt')
 
    # cnn_model = CNN(activation_fn=act)
    cnn_model = ResNet(ResidualBlock, [3, 3, 3], num_classes=num_classes, activation_fn=act)
    model = Classifier.load_from_checkpoint(ckp_file[0], model=cnn_model)

    print('-'*50)
    print(f'Accuracy with {name}: \n')
    results = trainer.test(model, dataloaders=test_loader)

    # 4 decimal precision
    for metric, value in results[0].items():
        print(f"{metric}: {value*100:.2f}")


--------------------------------------------------
Accuracy with GHF: 



Testing: |          | 0/? [00:00<?, ?it/s]

test_acc: 80.32
--------------------------------------------------
Accuracy with Logistic: 



Testing: |          | 0/? [00:00<?, ?it/s]

test_acc: 75.06
--------------------------------------------------
Accuracy with Tanh: 



Testing: |          | 0/? [00:00<?, ?it/s]

test_acc: 74.25
--------------------------------------------------
Accuracy with ReLU: 



Testing: |          | 0/? [00:00<?, ?it/s]

test_acc: 81.64
--------------------------------------------------
Accuracy with Mish: 



Testing: |          | 0/? [00:00<?, ?it/s]

test_acc: 82.28
--------------------------------------------------
Accuracy with LeakyReLU: 



Testing: |          | 0/? [00:00<?, ?it/s]

test_acc: 81.92


In [12]:
from train import test_loader, train_loader, val_loader
max_epochs = 50
act = ActGHF(t=0.6, m1=-1.001, m2=50)   # t=0.6, m1=-1.001, m2=50
trainer = Trainer(
            max_epochs=max_epochs,
            accelerator="gpu", 
            devices=1, 
            callbacks=[
            pl.callbacks.ModelCheckpoint(
                monitor="val_acc",        # Metric to monitor
                mode="max",               # Save when max accuracy
                save_top_k=1,             # Save only the best model
                filename="{epoch}-{val_acc:.4f}",  # Include accuracy in filename
                save_last=False,          # Don't save final epoch if not best
                verbose=True              # Print when new best model is saved
            )
        ],)

model = ResNet(ResidualBlock, [3, 3, 3], num_classes=num_classes, activation_fn=act)
classifier = Classifier(model, num_classes=num_classes)
trainer.fit(classifier, train_loader, val_loader)
trainer.test(classifier, dataloaders=test_loader)

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_acc': 0.7943000197410583}]

In [13]:
# Path to the checkpoint file
version = 41
checkpoint_path = f"lightning_logs/version_{version}/checkpoints"
ckp_file = glob.glob(f'{checkpoint_path}/*.ckpt')

# cnn_model = CNN(activation_fn=act)
cnn_model = ResNet(ResidualBlock, [3, 3, 3], num_classes=num_classes, activation_fn=act)
model = Classifier.load_from_checkpoint(ckp_file[0], model=cnn_model)

results = trainer.test(model, dataloaders=test_loader)

Testing: |          | 0/? [00:00<?, ?it/s]