In [1]:
import torch
import pandas as pd
import os
import sys
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
from torchvision.models import efficientnet_b1
import torch.nn as nn
import yacs

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
notebook_dir = os.getcwd()
project_dir = os.path.dirname(notebook_dir)

# Add the src directory to sys.path
src_dir = os.path.join(project_dir, 'src')
if src_dir not in sys.path:
    sys.path.append(src_dir)

from utils import generate_class_weights, show_batch_images
from data import ChestXray14Dataset, build_transform_classification
from trainers import MultiLabelLightningModule
from models import set_model

## load dataframes

In [3]:
data_path = "/cluster/home/taheeraa/datasets/chestxray-14"
images_path = data_path + "/images"
path_to_labels = '/cluster/home/taheeraa/code/BenchmarkTransformers/dataset'
file_path_train = path_to_labels + '/Xray14_train_official.txt'
file_path_val = path_to_labels + '/Xray14_val_official.txt'
file_path_test = path_to_labels + '/Xray14_test_official.txt'

num_labels = 14

In [4]:
train_transforms, val_transforms, test_transforms = build_transform_classification(
        normalize="chestx-ray",
        test_augment=False,
        add_transforms=True
    )

train_dataset = ChestXray14Dataset(images_path=images_path, file_path=file_path_train,
                                   augment=train_transforms, num_class=num_labels)
val_dataset = ChestXray14Dataset(images_path=images_path, file_path=file_path_val,
                                 augment=val_transforms, num_class=num_labels)
test_dataset = ChestXray14Dataset(images_path=images_path, file_path=file_path_test,
                                  augment=test_transforms, num_class=num_labels)

## load dataset and transforms

In [5]:
img_size = 256
num_workers = 4
pin_memory = False
batch_size = 32

## load model

not all keys in ckpt file matches with the actual model?

In [6]:
model_name = 'resnet50'
experiment_name = 'testing-notebook'

In [13]:
#import timm
#model = timm.create_model('resnet50', num_classes=num_labels)

from torchvision.models import resnet50

def classifying_head(in_features: int, num_labels: int):
    return nn.Sequential(
        nn.Dropout(p=0.2),
        nn.Linear(in_features=in_features, out_features=128),
        nn.ReLU(),
        nn.BatchNorm1d(num_features=128),
        nn.Linear(128, num_labels),
    )

model = resnet50(weights='IMAGENET1K_V2')
img_size = int(224)
input_features = model.fc.in_features

model.fc = classifying_head(input_features, num_labels)
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [19]:
pretrained_weights = "/cluster/home/taheeraa/code/master-thesis/01-multi-label/output/v1-experiments/005-less-extensive-classifier-head/2024-05-08-14:30:21-resnet50-bce-14-multi-label-e35-bs64-lr0.0005-baseline-with-less-extensive-classifier-head-no-aug-omg/model_checkpoints/lightning_logs/version_0/checkpoints/epoch=1-step=2354.ckpt"

In [20]:
checkpoint = torch.load(pretrained_weights)
checkpoint.keys()

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers'])

In [36]:
model_state_dict_ckpt = checkpoint['state_dict']
updated_state_dict = {key.replace('model.', ''): value for key, value in model_state_dict_ckpt.items()}

In [31]:
model.load_state_dict(updated_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.downsample.0.weight', 'layer1.0.downsample.1.weight', 'layer1.0.downsample.1.bias', 'layer1.0.downsample.1.running_mean', 'layer1.0.downsample.1.running_var', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.conv2.weight', 'layer1.1.bn2.weight', 'layer1.1.bn2.bias', 'layer1.1.bn2.running_mean', 'layer1.1.bn2.running_var', 'layer1.1.conv3.weight', 'layer1.1.bn3.weight', 'layer1.1.bn3.bias', 'layer1.1.bn3.runni

## creating pytorch lightning module

In [33]:
criterion = nn.BCEWithLogitsLoss()
learning_rate = 0.001
optimizer_func = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler_func = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_func, mode='max', factor=0.1, patience=3, verbose=True)
logger = None

In [34]:
training_module = MultiLabelLightningModule(
    model=model,
    criterion=criterion,
    learning_rate=learning_rate,
    num_labels=num_labels,
    labels=labels,
    optimizer_func=optimizer_func,
    scheduler_func=scheduler_func,
    file_logger=logger,
    root_path=root_path,
    model_name=model_name,
    experiment_name=experiment_name,
    img_size=img_size,
    model_ckpts_folder="checkpoints/",
)