In [1]:
import sys
from torch.utils.data import DataLoader
import os
import timm
import torch
from pytorch_lightning import Trainer

In [2]:
# The project's local packages live in <parent of the working directory>/src.
# That folder has to be on sys.path before the two project imports below can
# be resolved, which is why those imports sit in this cell.
notebook_dir = os.getcwd()
project_dir = os.path.dirname(notebook_dir)
src_dir = os.path.join(project_dir, 'src')

# Append only once so re-running the cell does not pile up duplicate entries.
src_already_on_path = src_dir in sys.path
if not src_already_on_path:
    sys.path.append(src_dir)

# Project-local modules, resolved through src_dir.
from data import ChestXray14Dataset, build_transform_classification
from trainers import MultiLabelLightningModule

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


In [3]:
# --- Split files ------------------------------------------------------------
# Text files for the train / val / test splits, all in one folder.
path_to_labels = '/cluster/home/taheeraa/code/BenchmarkTransformers/dataset'
file_path_train = f'{path_to_labels}/Xray14_train_official.txt'
file_path_val = f'{path_to_labels}/Xray14_val_official.txt'
file_path_test = f'{path_to_labels}/Xray14_test_official.txt'

# --- Checkpoint to evaluate -------------------------------------------------
checkpoint_path = (
    '/cluster/home/taheeraa/code/BenchmarkTransformers/Models/Classification'
    '/ChestXray14/swin_base_simmim/swin_base_simmim_run_0.pth.tar'
)

# --- Image data -------------------------------------------------------------
data_path = '/cluster/home/taheeraa/datasets/chestxray-14'
images_path = f'{data_path}/images'

# --- DataLoader settings ----------------------------------------------------
batch_size = 32
num_workers = 4
pin_memory = False

# --- Model / optimisation settings ------------------------------------------
num_labels = 14
learning_rate = 0.01

# One name per output of the classifier (num_labels entries).
# NOTE(review): presumably this order has to match the label columns in the
# split files -- verify against the dataset class.
labels = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration",
    "Mass", "Nodule", "Pneumonia", "Pneumothorax",
    "Consolidation", "Edema", "Emphysema", "Fibrosis",
    "Pleural_Thickening", "Hernia",
]

In [4]:
# Image transform pipeline for evaluation, built by the project helper with
# the "chestx-ray" normalisation preset in "test" mode.
test_transforms = build_transform_classification(normalize="chestx-ray", mode="test")

In [5]:
# Test split: images are read from images_path, the split is described by
# file_path_test, and the evaluation transforms are passed as `augment`.
test_dataset = ChestXray14Dataset(
    images_path=images_path,
    file_path=file_path_test,
    augment=test_transforms,
    num_class=num_labels,
)


In [6]:
# Batched iterator over the test split.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    # Keep the dataset order: nothing is trained here, so no shuffling.
    shuffle=False,
    # Loading behaviour comes from the configuration cell.
    num_workers=num_workers,
    pin_memory=pin_memory,
)

In [7]:
# Swin-Base (patch 4, window 7, 224 px) with one output per label.
# No `pretrained` argument is passed here; the weights to evaluate are loaded
# from checkpoint_path in a later cell.
model = timm.create_model(
    'swin_base_patch4_window7_224',
    num_classes=num_labels,
)

# Loss on raw logits (sigmoid is applied inside the loss).
criterion = torch.nn.BCEWithLogitsLoss()

# Optimizer and LR schedule; they are handed to the Lightning module in the
# next cell.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [8]:
# Wrap the backbone in the project's Lightning module so that a Lightning
# Trainer can drive the evaluation.
# NOTE(review): `optimizer_func` / `scheduler_func` are given already-built
# optimizer / scheduler instances here -- confirm that is what the module
# expects (the parameter names suggest factories).
training_module = MultiLabelLightningModule(
    # network and objective
    model=model,
    criterion=criterion,
    # label space
    num_labels=num_labels,
    labels=labels,
    # optimisation
    learning_rate=learning_rate,
    optimizer_func=optimizer,
    scheduler_func=scheduler,
    # bookkeeping
    model_ckpts_folder='model_ckpts_folder',
    model_name='swin',
    experiment_name='eval',
    img_size=224,
)

In [11]:
# Restore the fine-tuned weights from disk.
# map_location="cpu" deserialises every tensor onto the CPU regardless of the
# device it was saved from, so this cell also works on a machine without that
# device and does not allocate extra GPU memory; load_state_dict then copies
# the values into the already-existing parameters.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
ckpt_state_dict = checkpoint['state_dict']

# Non-strict load into the Lightning wrapper: keys that do not match are
# skipped instead of raising, so report them rather than discarding the
# result silently. (Variable name kept as-is because other cells may use it.)
incompatiable_keys = training_module.load_state_dict(ckpt_state_dict, strict=False)
if incompatiable_keys.missing_keys or incompatiable_keys.unexpected_keys:
    print(
        "training_module non-strict load: "
        f"{len(incompatiable_keys.missing_keys)} missing key(s), "
        f"{len(incompatiable_keys.unexpected_keys)} unexpected key(s)"
    )

# Strict load into the bare backbone: raises if the checkpoint does not match
# the architecture. Left as the last expression so its result is displayed.
model.load_state_dict(ckpt_state_dict)

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): Sequential(
    (0): BasicLayer(
      dim=128, input_resolution=(56, 56), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=128, out_features=384, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=128, out_features=128, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=512, bias=True)
     

In [12]:
# Put the backbone into evaluation mode before running the test loop.
model.eval()

# Lightning Trainer with default settings; it runs the module's test loop
# over the test loader and returns the logged metrics.
pl_trainer = Trainer()
results = pl_trainer.test(model=training_module, dataloaders=test_loader)

# NOTE(review): the output recorded with this cell shows a mean AUROC of
# about 0.41 (below 0.5) -- verify that the checkpoint weights actually ended
# up in the module under test and that the label order matches.

# Print the results
print("Test Results:", results)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing: |          | 0/? [00:00<?, ?it/s]

{'loss': 0.1948736827005632, 'f1': 0.062350154606974684, 'f1_micro': 0.2512708240095526, 'auroc': 0.41177231178153306}
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
            Test metric                       DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
            test_auroc                     0.4118366837501526
   test_auroc_atelectasis_epoch            0.6233771443367004
   test_auroc_cardiomegaly_epoch           0.3357979357242584
  test_auroc_consolidation_epoch            0.513145387172699
      test_auroc_edema_epoch               0.3694961667060852
     test_auroc_effusion_epoch             0.6794830560684204
    test_auroc_emphysema_epoch             0.33530178666114807
     test_auroc_fibrosis_epoch             0.24285291135311127
      test_auroc_hernia_epoch              0.04066082835197449
   test_auroc_infil