In [None]:
import torch
import torchvision
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import numpy as np
import cv2
import imgaug.augmenters as iaa
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from dataset import CardiacDataset

In [None]:
# Folder that holds the processed images and the subjects files.
processed_path = "./processed_heart_detection"
# Labels file covering the train/val/test dataset.
labels_path = "./heart_detection_labels.csv"

In [None]:
# Image folders, one per split, all located under `processed_path`.
train_root_path = f"{processed_path}/train/"
val_root_path = f"{processed_path}/val/"
test_root_path = f"{processed_path}/test/"

# Matching subject-list (.npy) files, one per split.
train_subjects = f"{processed_path}/train_subjects.npy"
val_subjects = f"{processed_path}/val_subjects.npy"
test_subjects = f"{processed_path}/test_subjects.npy"

In [None]:
# Contrast augmentation with imgaug's default GammaContrast parameters.
gamma_jitter = iaa.GammaContrast()

# Geometric augmentation: scale in [0.8, 1.2], rotate in [-10, 10],
# translate by [-10, 10] pixels.
affine_jitter = iaa.Affine(
    scale=(0.8, 1.2),
    rotate=(-10, 10),
    translate_px=(-10, 10),
)

# Pipeline handed to the training dataset: contrast first, then the affine warp.
train_transforms = iaa.Sequential([gamma_jitter, affine_jitter])

In [None]:
# Only the training split receives the augmentation pipeline; the validation
# and test splits are built with no transform (None).
train_dataset = CardiacDataset(
    labels_path,
    train_subjects,
    train_root_path,
    train_transforms,
)
val_dataset = CardiacDataset(
    labels_path,
    val_subjects,
    val_root_path,
    None,
)
test_dataset = CardiacDataset(
    labels_path,
    test_subjects,
    test_root_path,
    None,
)

# Report the size of every split.
print(f"There are {len(train_dataset)} train images, {len(val_dataset)} val images, and {len(test_dataset)} test images")

In [None]:
batch_size = 8
workers = 6  # <- adjust based on your system's performance

# Arguments common to both loaders.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=workers,
    # DataLoader raises ValueError for persistent_workers=True when
    # num_workers == 0, so only request persistent workers if there are any.
    persistent_workers=workers > 0,
    pin_memory=True,
)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
#--------------------------------------------------------------------
#                         INITIALIZE RESNET18
#--------------------------------------------------------------------

In [None]:
class CardiacDetectionModel(pl.LightningModule):
    """ResNet-18 regressor producing four values per single-channel image.

    The torchvision ResNet-18 is modified in two places: the first convolution
    accepts 1 input channel, and the final fully connected layer emits 4
    outputs. Elsewhere in this notebook those outputs are drawn as the corners
    (x_min, y_min, x_max, y_max) of a bounding box.

    Training minimises the MSE between prediction and label. The mean absolute
    difference between the two is logged once per epoch as
    "Train Mean Offset" / "Val Mean Offset".
    """

    def __init__(self):
        super().__init__()

        self.model = torchvision.models.resnet18()
        # Replace the stock first convolution with a 1-input-channel version.
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Replace the classification head with a 4-output regression head.
        self.model.fc = torch.nn.Linear(in_features=512, out_features=4)

        self.loss_fn = torch.nn.MSELoss()

    def forward(self, data):
        """Run the backbone on `data` and return its 4-value output per sample."""
        return self.model(data)

    def _shared_step(self, batch, metric_name):
        """Compute the MSE loss for `batch` and log the mean absolute offset.

        Args:
            batch: an (image, label) pair as yielded by the data loader.
            metric_name: name under which the epoch-level offset is logged.

        Returns:
            The MSE loss tensor for the batch.
        """
        x_ray, label = batch
        label = label.float()
        pred = self(x_ray)
        loss = self.loss_fn(pred, label)

        # Mean absolute difference: first over the batch axis, then over outputs.
        mean_offset = torch.abs(pred - label).mean(0).mean()
        self.log(metric_name, mean_offset, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss; logs "Train Mean Offset"."""
        return self._shared_step(batch, "Train Mean Offset")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss; logs "Val Mean Offset"."""
        return self._shared_step(batch, "Val Mean Offset")

    def configure_optimizers(self):
        """Create the Adam optimizer (lr=1e-4) over the backbone parameters.

        The optimizer is built here, in the hook the Trainer calls, rather
        than in __init__.
        """
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        return [optimizer]


In [None]:
# Instantiate the detection model that will be passed to the trainer.
model = CardiacDetectionModel()

In [None]:
#--------------------------------------------------------------------
#                           TRAIN MODEL
#--------------------------------------------------------------------

In [None]:
# Rank checkpoints by the validation metric (lower is better) rather than the
# training metric, so the "best" checkpoint reflects held-out performance.
# save_top_k=140 keeps up to 140 checkpoints.
checkpoint_callback = ModelCheckpoint(
    monitor='Val Mean Offset',
    save_top_k=140,
    mode='min')

In [None]:
# Upper bound on the number of training epochs.
epochs = 140

# TensorBoard event files are written under ./logs_heart.
heart_logger = TensorBoardLogger("./logs_heart")

trainer = pl.Trainer(
    max_epochs=epochs,
    callbacks=checkpoint_callback,
    logger=heart_logger,
    log_every_n_steps=1,
)

In [None]:
# Fit the model on train_loader, using val_loader for validation.
trainer.fit(model, train_loader, val_loader)

In [None]:
#--------------------------------------------------------------------
#                             EVALUATE MODEL
#--------------------------------------------------------------------

In [None]:
# Choose model version to be tested.
model_version = "version_0"
# Checkpoint file name.
checkpoint = "epoch=108-step=5450.ckpt"

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Location of the checkpoint selected in the previous cell.
checkpoint_path = f"logs_heart/lightning_logs/{model_version}/checkpoints/{checkpoint}"

# Restore the trained weights, switch to inference mode, and move to `device`.
model = CardiacDetectionModel.load_from_checkpoint(checkpoint_path)
model.eval()
model.to(device)

In [None]:
preds = []
labels = []

# Gradients are not needed for evaluation.
with torch.no_grad():
    # Index explicitly over len(test_dataset) instead of iterating the dataset
    # object, which would depend on __getitem__ raising IndexError to stop.
    for idx in range(len(test_dataset)):
        data, label = test_dataset[idx]
        # Add a leading batch axis of size 1 before the forward pass.
        data = data.to(device).float().unsqueeze(0)
        pred = model(data)[0].cpu()
        preds.append(pred)
        # NOTE(review): torch.stack below requires every label to be a tensor.
        labels.append(label)

# One row per test sample.
preds = torch.stack(preds)
labels = torch.stack(labels)

In [None]:
# Absolute prediction error averaged over the samples (axis 0), leaving one
# value per output.
offset_all = torch.abs(preds - labels).mean(dim=0)
# Single summary number: the mean of the per-output errors.
offset = offset_all.mean()
print(f"Mean offset: {offset}, mean per axis: {offset_all}")

In [None]:
#--------------------------------------------------------------------
#                EXAMPLE: LABEL (RED) VS. PREDICTION (GREEN)
#--------------------------------------------------------------------

In [None]:
def _box_patch(box, color):
    """Build an unfilled rectangle patch from (x_min, y_min, x_max, y_max)."""
    x_min, y_min, x_max, y_max = (float(v) for v in box[:4])
    return patches.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                             linewidth=1, edgecolor=color, facecolor='none')


# Show four randomly chosen test images with the labelled box in red and the
# model's predicted box in green.
fig, axis = plt.subplots(2, 2)
for i in range(2):
    for j in range(2):
        random_index = np.random.randint(0, len(test_dataset))
        # Deliberately named `label`, not `labels`, so the stacked test-set
        # `labels` tensor built in the evaluation cell is not overwritten.
        x_ray, label = test_dataset[random_index]

        axis[i][j].imshow(x_ray[0], cmap="bone")
        axis[i][j].add_patch(_box_patch(label, "r"))

        # Forward pass on a batch of one, without tracking gradients.
        with torch.no_grad():
            batch = x_ray.to(device).float().unsqueeze(0)
            current_pred = model(batch)[0].cpu()

        axis[i][j].add_patch(_box_patch(current_pred, "g"))