In [None]:
from FileLoader import load_file
import torch
import torchvision
from torchvision import transforms
import torchmetrics
import distutils.version
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Single-channel normalisation statistics shared by every split.
_NORM_MEAN = [0.49044]
_NORM_STD = [0.24787]

# Training pipeline: tensor conversion + normalisation, then random geometric
# augmentation (rotation within +/-10 deg, vertical shift up to 10%, zoom 0.8-1.2)
# and a random resized crop to 224x224.
# NOTE(review): Normalize runs before the affine/crop, so their zero fill lands in
# normalised space (i.e. the dataset mean, not black) — confirm this is intended.
train_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
        transforms.RandomAffine(degrees=(-10, 10), translate=(0, 0.1), scale=(0.8, 1.2)),
        transforms.RandomResizedCrop((224, 224), scale=(0.65, 1)),
    ]
)

# Evaluation pipeline: deterministic, conversion + normalisation only.
# NOTE(review): no resize here — presumably the preprocessed arrays are already
# 224x224; verify against the preprocessing step.
val_test_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

In [None]:
# Root folder of the preprocessed data; must contain train/, val/ and test/ subfolders.
processed_path = "./processed"

In [None]:
def _npy_split(split, transform):
    """Build a DatasetFolder over ``{processed_path}/{split}/`` reading ``npy`` files via ``load_file``."""
    return torchvision.datasets.DatasetFolder(
        f"{processed_path}/{split}/",
        loader=load_file,
        extensions="npy",
        transform=transform,
    )


# Only the training split gets random augmentation.
train_dataset = _npy_split("train", train_transforms)
val_dataset = _npy_split("val", val_test_transforms)
test_dataset = _npy_split("test", val_test_transforms)

print(f"There are {len(train_dataset)} train images, {len(val_dataset)} val images and {len(test_dataset)} test images")

In [None]:
# Check dataset balances: (class indices, sample counts) for train, val and test.
tuple(np.unique(split.targets, return_counts=True) for split in (train_dataset, val_dataset, test_dataset))

In [None]:
# CHECK DATASETS: show a 2x2 grid of random samples (first channel) with their labels.
dataset = train_dataset  # <- dataset to be tested

fig, axis = plt.subplots(2, 2, figsize=(9, 9))
for ax in axis.flat:
    sample_idx = np.random.randint(0, len(dataset))
    image, target = dataset[sample_idx]
    ax.imshow(image[0], cmap="bone")
    ax.set_title(f"Label:{target}")
        axis[i][j].set_title(f"Label:{label}")

In [None]:
batch_size = 16
workers = 6  # <- adjust based on your system's performance

# Settings shared by the train and val loaders.
# - persistent_workers is only valid with worker processes: DataLoader raises
#   ValueError for persistent_workers=True with num_workers=0.
# - pinned host memory only speeds up host->GPU copies, so enable it only when
#   CUDA is present.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=workers,
    persistent_workers=workers > 0,
    pin_memory=torch.cuda.is_available(),
)

train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
#--------------------------------------------------------------------
#                        INITIALIZE DENSENET121
#--------------------------------------------------------------------

In [None]:
class PneumoniaModelDenseNet121(pl.LightningModule):
    """Binary pneumonia classifier built on a single-channel DenseNet-121.

    The torchvision DenseNet-121 is constructed without pretrained weights.
    Its first convolution is replaced to accept 1-channel input, and its
    classifier is replaced to emit a single logit. Training uses
    ``BCEWithLogitsLoss`` with a positive-class weight and Adam.

    Args:
        weight: ``pos_weight`` for ``BCEWithLogitsLoss``. The default
            20672/6012 is presumably the negative/positive sample ratio of the
            training split — TODO confirm against the class-balance check.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
    """

    def __init__(self, weight=(20672/6012), lr=1e-4, weight_decay=1e-3):
        super().__init__()
        # Store the constructor args in every checkpoint so load_from_checkpoint
        # rebuilds the model with the same settings instead of the defaults.
        # Checkpoints written without hparams still load, using the defaults.
        self.save_hyperparameters()

        self.model = torchvision.models.densenet121()
        # 1-channel input stem.
        # NOTE(review): a 3x3 kernel with padding 3 pads more than "same"
        # padding (1) would; the stock stem is 7x7 with padding 3. Left as-is
        # so that existing checkpoints keep loading.
        self.model.features.conv0 = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(3, 3), bias=False)
        # Single-logit head; 1024 matches in_features of DenseNet-121's stock classifier.
        self.model.classifier = torch.nn.Linear(in_features=1024, out_features=1)

        # pos_weight is registered as a buffer, so it follows the module across devices.
        self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([weight]))

        # Separate metric objects keep train and val accumulation independent.
        self.train_acc = torchmetrics.classification.BinaryAccuracy()
        self.val_acc = torchmetrics.classification.BinaryAccuracy()

    def forward(self, data):
        """Return raw logits of shape (batch, 1) for ``data``."""
        return self.model(data)

    def _shared_step(self, batch, metric, prefix):
        """Compute the weighted BCE loss for ``batch``, update ``metric`` and log both.

        Logged keys are ``"{prefix} Acc"`` and ``"{prefix} Loss"``, each per step
        and per epoch. Lightning suffixes them with ``_step`` / ``_epoch``.
        """
        x_ray, label = batch
        label = label.float()
        # Drop the singleton logit dimension so predictions line up with the 1-D labels.
        pred = self(x_ray)[:, 0]

        loss = self.loss_fn(pred, label)
        metric(torch.sigmoid(pred), label.int())

        self.log(f"{prefix} Acc", metric, on_step=True, on_epoch=True)
        self.log(f"{prefix} Loss", loss, on_step=True, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch; logs "Train Acc" / "Train Loss"."""
        return self._shared_step(batch, self.train_acc, "Train")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch; logs "Val Acc" / "Val Loss"."""
        return self._shared_step(batch, self.val_acc, "Val")

    def configure_optimizers(self):
        """Create Adam over the network parameters (Lightning's hook for optimizer setup)."""
        return torch.optim.Adam(
            self.model.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )

In [None]:
# Fresh model (no pretrained weights) using the class's default loss weighting.
model = PneumoniaModelDenseNet121()

In [None]:
#--------------------------------------------------------------------
#                           TRAIN MODEL
#--------------------------------------------------------------------

In [None]:
# Rank checkpoints by epoch-level validation accuracy (higher is better).
# NOTE(review): save_top_k=70 equals `epochs` in the trainer cell, so in practice
# every epoch's checkpoint is kept.
checkpoint_callback = ModelCheckpoint(
    monitor="Val Acc_epoch",
    mode="max",
    save_top_k=70,
)

In [None]:
epochs = 70

# TensorBoard runs (and the checkpoints written alongside them) go under ./logs_densenet121.
tb_logger = TensorBoardLogger(save_dir="./logs_densenet121")
trainer = pl.Trainer(
    max_epochs=epochs,
    logger=tb_logger,
    callbacks=checkpoint_callback,
    log_every_n_steps=100,
)

In [None]:
# Train on the training split, validating once per epoch on the val split.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
#--------------------------------------------------------------------
#                             EVALUATE MODEL
#--------------------------------------------------------------------

In [None]:
# Which training run and which checkpoint file inside it to evaluate.
model_version = "version_0"  # run folder under logs_densenet121/lightning_logs/
checkpoint = "epoch=58-step=83662.ckpt"  # file inside that run's checkpoints/ folder

In [None]:
# Restore the selected checkpoint and prepare it for inference.
checkpoint_path = f"logs_densenet121/lightning_logs/{model_version}/checkpoints/{checkpoint}"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Explicit map_location lets a checkpoint written on a GPU be restored on a
# CPU-only machine (and vice versa) instead of relying on version-dependent defaults.
model = PneumoniaModelDenseNet121.load_from_checkpoint(checkpoint_path, map_location=device)
model.eval()  # switch BatchNorm (and any dropout) to evaluation behaviour
model.to(device);

In [None]:
# Score the whole test split in batches rather than one image at a time.
# shuffle=False keeps preds/labels aligned with test_dataset indices.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)

pred_batches = []
label_batches = []

with torch.no_grad():
    for data, label in tqdm(test_loader):
        data = data.to(device).float()
        # Logits are (batch, 1): keep the single column and map to probabilities.
        pred_batches.append(torch.sigmoid(model(data)[:, 0]).cpu())
        label_batches.append(label)

preds = torch.cat(pred_batches)          # 1-D float probabilities, one per test image
labels = torch.cat(label_batches).int()  # 1-D int class indices, same order

In [None]:
# Decision threshold applied to the sigmoid probabilities when binarising predictions.
thr = 0.48

In [None]:
# Test-set metrics, all evaluated at the same decision threshold `thr`.
acc = torchmetrics.classification.BinaryAccuracy(threshold=thr)(preds, labels)
precision = torchmetrics.classification.BinaryPrecision(threshold=thr)(preds, labels)
recall = torchmetrics.classification.BinaryRecall(threshold=thr)(preds, labels)
f1 = torchmetrics.classification.BinaryF1Score(threshold=thr)(preds, labels)
# 2x2 counts: rows = true label (0, 1), columns = predicted label (0, 1).
cm = torchmetrics.classification.BinaryConfusionMatrix(threshold=thr)(preds, labels)

# .item() prints plain numbers rather than tensor reprs; every line uses "Name: value".
print(f"Accuracy: {acc.item():.4f}")
print(f"Precision: {precision.item():.4f}")
print(f"Recall: {recall.item():.4f}")
print(f"F1 Score: {f1.item():.4f}")
print(f"Confusion Matrix:\n{cm}")

In [None]:
# Spot-check: 9 random test images titled with the thresholded prediction and the true label.
fig, axis = plt.subplots(3, 3, figsize=(9, 9))

for ax in axis.flat:
    rnd_idx = np.random.randint(0, len(preds))
    # Binarise with the same rule as the metrics (probability > thr), not a
    # hard-coded 0.5, so the titles agree with the reported confusion matrix.
    pred_label = int(preds[rnd_idx] > thr)
    # int() avoids titles like "tensor(1, dtype=torch.int32)".
    true_label = int(labels[rnd_idx])
    ax.imshow(test_dataset[rnd_idx][0][0], cmap="bone")
    ax.set_title(f"Pred:{pred_label}, Label:{true_label}")
    ax.axis("off")