In [1]:
import os
import pickle
import random
from tqdm import tqdm

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision.transforms import Compose, RandomHorizontalFlip, ColorJitter, RandomAffine, RandomErasing, ToTensor, Resize
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score

import numpy as np
import pandas as pd
import pydicom

In [2]:
BATCH_SIZE = 128

In [3]:
BASEDIR = '../rsna-2023-abdominal-trauma-detection'

TRAIN_IMG_PATH = os.path.join(BASEDIR, 'train_images')
TRAIN_META_PATH = os.path.join(BASEDIR, 'train_series_meta.csv')
TEST_IMG_PATH = os.path.join(BASEDIR, 'test_images')
TEST_META_PATH = os.path.join(BASEDIR, 'test_series_meta.csv')

TRAIN_LABEL_PATH = os.path.join(BASEDIR, 'train.csv')

In [4]:
def fetch_img_paths_png():
    img_paths = []
    
    ppp = '../rsna-2023-png/train_images/'
    # ppp = '/kaggle/input/rsna-abdominal-trauma-detection-png-pt1'
    
    all_pngs = sorted(os.listdir(ppp))
    all_pngs = [os.path.join(ppp, d) for d in all_pngs]
    
    cur_ps = []
    png = all_pngs[0]
    prev = png[:png.rfind('_')]
    
    for png in tqdm(all_pngs):
        patient_series = png[:png.rfind('_')]
        if prev == patient_series:
            cur_ps.append(png)
        else:
            img_paths.append(cur_ps)
            cur_ps = [png]
        prev = patient_series

    if cur_ps:  # to make sure the last group is added too
        img_paths.append(cur_ps)
    
    return img_paths

def preprocess_png(png_path):
    img = cv2.imread(png_path)
    img = cv2.resize(img, (512, 512))
    greyscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)/255
    return greyscale

## Dataloader

In [5]:
def interpolate_channels(img_tensor):
    # Get the current number of channels
    C, H, W = img_tensor.shape

    # Initialize the output tensor
    output = torch.zeros((160, H, W))

    # Handle the edge case when C is 1
    if C == 1:
        for i in range(160):
            output[i] = img_tensor[0]
        return output

    # Handle the edge case when C is 2
    if C == 2:
        for i in range(80):
            output[i] = img_tensor[0]
        for i in range(80, 160):
            output[i] = img_tensor[1]
        return output

    # If channels are already 80 or more, return the original image
    if C >= 160:
        return img_tensor

    # Set the first and last channels
    output[0] = img_tensor[0]
    output[159] = img_tensor[-1]

    # Calculate the step for even spacing
    step = 158 / (C - 2)

    # Evenly space the remaining original channels in the range 1-78
    for i in range(1, C - 1):
        output[int(1 + i * step)] = img_tensor[i]

    # Perform linear interpolation
    for i in range(1, 159):
        if output[i].sum() == 0:
            left = i - 1
            right = i + 1
            while output[left].sum() == 0:
                left -= 1
            while output[right].sum() == 0:
                right += 1

            alpha = (i - left) / (right - left)

            output[i] = (1 - alpha) * output[left] + alpha * output[right]

    return output



class AbdominalData(Dataset):
    def __init__(self, df_path=TRAIN_LABEL_PATH, max_channel=4):
        super().__init__()
        
        # collect all the image instance paths
        self.img_paths = fetch_img_paths_png()
        self.max_channel = max_channel
                
        df = pd.read_csv(df_path, index_col='patient_id')
        self.df_dict = df.to_dict(orient='index')
        for key, value in self.df_dict.items():
            self.df_dict[key] = list(value.values())
            
        df_meta = pd.read_csv(TRAIN_META_PATH)
        df_meta['ps'] = df_meta['patient_id'].astype(str) + "_" + df_meta['series_id'].astype(str)
        self.df_meta_dict = df_meta.set_index('ps')['incomplete_organ'].to_dict()

        
    def __len__(self):
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        dicom_images = self.img_paths[idx]
        
        patient_id = int(dicom_images[0].split('/')[-1].split('_')[0])
        series_id = int(dicom_images[0].split('/')[-1].split('_')[1])
        
        images = []
        
        for d in dicom_images:
            image = preprocess_png(d)
            images.append(image)
        
        images = np.stack(images)
        image = torch.tensor(images, dtype = torch.float32).unsqueeze(dim = 1)
        image = self.transform(image).squeeze(dim = 1) # torch.Size([1727, 512, 512])
        image = interpolate_channels(image)
        center_idx = image.shape[0] // 2
        image = image[center_idx-80:center_idx+80:4]
                
        label = self.df_dict[patient_id]
        #incomplete_organ = self.df_meta_dict[ps]

        # labels
        # bowel = np.argmax(label[0:2], keepdims = False )
        # extravasation = np.argmax(label[2:4], keepdims = False)
        # kidney = np.argmax(label[4:7], keepdims = False)
        # liver = np.argmax(label[7:10], keepdims = False)
        # spleen = np.argmax(label[10:], keepdims = False)
        bowel = label[0:2]
        extravasation = label[2:4]
        kidney = label[4:7]
        liver = label[7:10]
        spleen = label[10:13]
                
        return image, {
            'bowel': bowel,
            'extravasation': extravasation,
            'kidney': kidney,
            'liver': liver,
            'spleen': spleen,
            # 'incomplete_organ': incomplete_organ
        }

## Net

In [6]:
data = AbdominalData()
train_size = int(0.8 * len(data))
val_size = len(data) - train_size
train_data, val_data = random_split(data, [train_size, val_size])
train_dataloader = DataLoader(train_data, batch_size = BATCH_SIZE, shuffle = True)
val_dataloader = DataLoader(val_data, batch_size = BATCH_SIZE, shuffle = False)

100%|██████████| 1500653/1500653 [00:00<00:00, 1918512.06it/s]


In [8]:
test = torch.randn(1, 512, 512)
image = interpolate_channels(test)
print(image.shape)
center_idx = image.shape[0] // 2
image = image[center_idx-80:center_idx+80:4]
print(image.shape)

img, _ = data[0]
print(img.shape)
# train_dataloader = next(iter(train_dataloader))
# print(train_dataloader[0].shape)
# y = train_dataloader[1]
# for k in y:
#     y[k] = torch.stack(y[k]).transpose(0, 1)
# y
# from IPython.display import clear_output, display
# import time
# import matplotlib.pyplot as plt

# print(data[1620])
# l = data[1620][0].shape[0]
# for i, img in enumerate(data[1620][0]):    
#     plt.figure(figsize=(5, 5))
#     plt.imshow(img, cmap="gray")
#     plt.title(f"Frame {i}/{l}")
#     plt.axis(False)
#     plt.show()
#     clear_output(wait=True)
#     time.sleep(0.01)

torch.Size([160, 512, 512])
torch.Size([40, 512, 512])
torch.Size([40, 512, 512])


In [8]:
from train_test_utils import create_writer
from RSNA_model import RSNA_model

device = torch.device('cuda:5')

unet = RSNA_model().to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr = 1e-3)
writer = create_writer("kaggle", "unet-vit", "v0")

criterion_bowel = nn.BCEWithLogitsLoss().to(device)
criterion_extravasation = nn.BCEWithLogitsLoss().to(device)
criterion_kidney = nn.CrossEntropyLoss().to(device)
criterion_liver = nn.CrossEntropyLoss().to(device)
criterion_spleen = nn.CrossEntropyLoss().to(device)

[INFO] Created SummaryWriter, saving to: runs/2023-10-12/kaggle/unet-vit/v0


In [9]:

from typing import Dict, List, Tuple
import pdb

def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               optimizer: torch.optim.Optimizer, scheduler: None,
               device: torch.device) -> Tuple[float, float]:

    # Put model in train mode
    model.train()

    # Setup train loss and train accuracy values
    train_loss, train_acc = 0, 0

    # Loop through data loader data batches
    for batch, (X, y) in enumerate(dataloader):
        # Send data to target device
        X = X.to(device)
        for k in y:
            y[k] = torch.stack(y[k]).transpose(0, 1).to(dtype=torch.float32)
            y[k] = y[k].to(device)

        # 1. Forward pass
        y_pred = model(X)

        # 2. Calculate  and accumulate loss
        loss_b = criterion_bowel(y_pred[0], y["bowel"])
        loss_e = criterion_extravasation(y_pred[1], y["extravasation"])
        loss_k = criterion_kidney(y_pred[2], y["kidney"].argmax(dim=1))
        loss_l = criterion_liver(y_pred[3], y["liver"].argmax(dim=1))
        loss_s = criterion_spleen(y_pred[4], y["spleen"].argmax(dim=1))
        
        total_loss = loss_b + loss_e + loss_k + loss_l + loss_s
        train_loss += total_loss.item()

        # 3. Optimizer zero grad
        optimizer.zero_grad()

        # 4. Loss backward
        total_loss.backward()

        # 5. Optimizer step
        optimizer.step()

        if scheduler:
            scheduler.step()

        # Calculate and accumulate accuracy metric across all batches
        acc_b = (torch.argmax(y_pred[0], dim=1) == torch.argmax(y["bowel"], dim=1)).sum().item()
        acc_e = (torch.argmax(y_pred[1], dim=1) == torch.argmax(y["extravasation"], dim=1)).sum().item()
        acc_k = (torch.argmax(y_pred[2], dim=1) == torch.argmax(y["kidney"], dim=1)).sum().item()
        acc_l = (torch.argmax(y_pred[3], dim=1) == torch.argmax(y["liver"], dim=1)).sum().item()
        acc_s = (torch.argmax(y_pred[4], dim=1) == torch.argmax(y["spleen"], dim=1)).sum().item()
        
        train_acc += (acc_b + acc_e + acc_k + acc_l + acc_s) / 5

    # Adjust metrics to get average loss and accuracy per batch
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    return train_loss, train_acc


def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              device: torch.device) -> Tuple[float, float]:

    # Put model in eval mode
    model.eval()

    # Setup test loss and test accuracy values
    test_loss, test_acc = 0, 0

    # Turn on inference context manager
    with torch.inference_mode():
        # Loop through DataLoader batches
        for batch, (X, y) in enumerate(dataloader):
            # Send data to target device
            X = X.to(device)
            for k in y:
                y[k] = torch.stack(y[k]).transpose(0, 1).to(dtype=torch.float32) # [B, 2/3]        
                y[k] = y[k].to(device)

            # 1. Forward pass
            test_pred_logits = model(X)

            # 2. Calculate and accumulate loss
            loss_b = criterion_bowel(test_pred_logits[0], y["bowel"])
            loss_e = criterion_extravasation(test_pred_logits[1], y["extravasation"])
            loss_k = criterion_kidney(test_pred_logits[2], y["kidney"].argmax(dim=1))
            loss_l = criterion_liver(test_pred_logits[3], y["liver"].argmax(dim=1))
            loss_s = criterion_spleen(test_pred_logits[4], y["spleen"].argmax(dim=1))
            
            total_loss = loss_b + loss_e + loss_k + loss_l + loss_s
            test_loss += total_loss.item()

            # Calculate and accumulate accuracy
            acc_b = (torch.argmax(test_pred_logits[0], dim=1) == torch.argmax(y["bowel"], dim=1)).sum().item()
            acc_e = (torch.argmax(test_pred_logits[1], dim=1) == torch.argmax(y["extravasation"], dim=1)).sum().item()
            acc_k = (torch.argmax(test_pred_logits[2], dim=1) == torch.argmax(y["kidney"], dim=1)).sum().item()
            acc_l = (torch.argmax(test_pred_logits[3], dim=1) == torch.argmax(y["liver"], dim=1)).sum().item()
            acc_s = (torch.argmax(test_pred_logits[4], dim=1) == torch.argmax(y["spleen"], dim=1)).sum().item()
            
            test_acc += (acc_b + acc_e + acc_k + acc_l + acc_s) / 5

    # Adjust metrics to get average loss and accuracy per batch
    test_loss = test_loss / len(dataloader)
    test_acc = test_acc / len(dataloader)
    return test_loss, test_acc


# Add writer parameter to train()
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          scheduler: None,
          epochs: int,
          device: torch.device,
          writer: torch.utils.tensorboard.writer.SummaryWriter # new parameter to take in a writer
          ) -> Dict[str, List]:

    # Create empty results dictionary
    results = {"train_loss": [],
               "train_acc": [],
               "test_loss": [],
               "test_acc": []
    }

    # Loop through training and testing steps for a number of epochs
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                          dataloader=train_dataloader,
                                          optimizer=optimizer,
                                          scheduler = scheduler,
                                          device=device)
        test_loss, test_acc = test_step(model=model,
          dataloader=test_dataloader,
          device=device)

        # Print out what's happening
        print(
          f"Epoch: {epoch+1} | "
          f"train_loss: {train_loss:.4f} | "
          f"train_acc: {train_acc:.4f} | "
          f"test_loss: {test_loss:.4f} | "
          f"test_acc: {test_acc:.4f}"
        )

        # Update results dictionary
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)


        ### New: Use the writer parameter to track experiments ###
        # See if there's a writer, if so, log to it
        if writer:
            writer.add_scalar(tag = "Train Loss", scalar_value = train_loss, global_step = epoch)
            writer.add_scalar(tag = "Test Loss", scalar_value = test_loss, global_step = epoch)
            writer.add_scalar(tag = "Train Acc", scalar_value = train_acc, global_step = epoch)
            writer.add_scalar(tag = "Test Acc", scalar_value = test_acc, global_step = epoch)

            writer.add_graph(model = model,
                             input_to_model = torch.randn(32, 3, 224, 224).to(device))


    if writer:
        writer.close()

    return results

import matplotlib.pyplot as plt
# Plot loss curves of a model
def plot_loss_curves(results):
    loss = results["train_loss"]
    test_loss = results["test_loss"]

    accuracy = results["train_acc"]
    test_accuracy = results["test_acc"]

    epochs = range(len(results["train_loss"]))

    plt.figure(figsize=(15, 7))

    # Plot loss
    plt.subplot(1, 2, 1)
    plt.plot(epochs, loss, label="train_loss")
    plt.plot(epochs, test_loss, label="test_loss")
    plt.title("Loss")
    plt.xlabel("Epochs")
    plt.legend()

    # Plot accuracy
    plt.subplot(1, 2, 2)
    plt.plot(epochs, accuracy, label="train_accuracy")
    plt.plot(epochs, test_accuracy, label="test_accuracy")
    plt.title("Accuracy")
    plt.xlabel("Epochs")
    plt.legend()

In [10]:
train(unet, train_dataloader, 
    val_dataloader, 
    optimizer,
    scheduler=None,
    epochs=1,
    device=device,
    writer=writer
)

torch.save(obj=unet.state_dict(), f="./unet.pth")

  0%|          | 0/1 [00:00<?, ?it/s]

: 