# Contralateral Model Pipeline


## Environment setup

In [None]:
import pandas as pd
import shutil
import random
from DicomRTTool.ReaderWriter import DicomReaderWriter, ROIAssociationClass
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import os
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, ConcatDataset, TensorDataset
import pickle
import json
import torch.nn.functional as F
import tqdm
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score
import seaborn as sns
from sklearn.preprocessing import label_binarize
import wandb

## Data preparation

### Load train/dev/test split lists

In [37]:
# Directory that holds the JSON files describing the train/dev/test split.
path = '/home/lam3654/MSAI_pneumonitis/lung_cancer_radiomics/train_dev_test_lists'

# Each split is stored in its own pair of variables. Loading the dev and test
# files into `train_data` / `train_labels` would overwrite the training split,
# leaving the test split under the "train" names.
with open(os.path.join(path, "small_train_data.json"), "r") as file:
    train_data = json.load(file)

with open(os.path.join(path, "small_train_labels.json"), "r") as file:
    train_labels = json.load(file)

with open(os.path.join(path, "small_dev_data.json"), "r") as file:
    dev_data = json.load(file)

with open(os.path.join(path, "small_dev_labels.json"), "r") as file:
    dev_labels = json.load(file)

with open(os.path.join(path, "small_test_data.json"), "r") as file:
    test_data = json.load(file)

with open(os.path.join(path, "small_test_labels.json"), "r") as file:
    test_labels = json.load(file)

### Load as numpy arrays

In [None]:
# Slice stride: within the contoured range, one slice out of every `skip`
# is kept.
skip = 4

# ROI names that are associated with the left-lung contour.
left_contours = [
    'l lung',
    'tru left lung',
    'lt lung 2',
    'lt lung',
    'lung_l',
    'l_lung',
    'left lung',
]

# ROI names that are associated with the right-lung contour.
right_contours = [
    'new right lung',
    'r_lung',
    'rt lung',
    'lung_r',
    'r lung',
    'right lung',
]

In [None]:
def lung_side(path_name, labels_df=None):
    """Report whether the labels table marks this patient's side as 'R'.

    The last component of ``path_name`` is used as the patient identifier.
    It is looked up in the ``anon_id`` column of the labels table and the
    ``side`` values of the matching rows are inspected.

    Args:
        path_name: Path of the patient's folder; its final component is the
            identifier searched for. A trailing slash is ignored.
        labels_df: Optional pre-loaded table with ``anon_id`` and ``side``
            columns. When omitted, the table is read from
            ``total_labels.xlsx`` on every call; pass a DataFrame to avoid
            re-reading the spreadsheet inside a loop.

    Returns:
        True if at least one matching row has side 'R'; False otherwise,
        including when no row matches the identifier.
    """
    if labels_df is None:
        # Load the Excel spreadsheet into a pandas DataFrame
        labels_df = pd.read_excel('/home/lam3654/MSAI_pneumonitis/total_labels.xlsx')

    # Column holding the identifier and column holding the value to retrieve
    unique_column = 'anon_id'
    retrieve_column = 'side'

    # The identifier is the last path component. Strip a trailing slash
    # first, otherwise the last component would be an empty string.
    search_value = path_name.rstrip('/').split('/')[-1]

    # Compare as text so an identifier column parsed as numbers still matches
    # the string taken from the path.
    matching_rows = labels_df[labels_df[unique_column].astype(str) == search_value]

    # Retrieve the corresponding values from the retrieve column
    side_values = matching_rows[retrieve_column]

    # `'R' in series` tests the Series' index labels, not its values, so the
    # values have to be compared explicitly.
    # NOTE(review): assumes the right side is encoded exactly as 'R'
    # (surrounding whitespace tolerated) - confirm against the spreadsheet.
    return bool(side_values.astype(str).str.strip().eq('R').any())


In [None]:
missed_contours = []  # scan folders for which no usable lung contour was found
skipped_cts = []      # scan folders whose processing raised a TypeError
total_arrays = []     # contour-masked 2D slices pooled over all scans
total_labels = []     # one label per entry of total_arrays (the scan's label)

for i, Dicom_path in enumerate(train_data):
    try:
        Dicom_reader = DicomReaderWriter(description='Examples', arg_max=True)
        Dicom_reader.walk_through_folders(Dicom_path)

        # When lung_side reports True the left-lung ROI is requested,
        # otherwise the right-lung ROI.
        if lung_side(Dicom_path):
            Contour_names = ['llung']
            associations = [ROIAssociationClass('llung', left_contours)]
        else:
            Contour_names = ['rlung']
            associations = [ROIAssociationClass('rlung', right_contours)]
        Dicom_reader.set_contour_names_and_associations(contour_names=Contour_names, associations=associations)

        indexes = Dicom_reader.which_indexes_have_all_rois()
        if not indexes:
            missed_contours.append(Dicom_path)
            continue

        # Use the last index that has the requested ROI.
        pt_indx = indexes[-1]
        Dicom_reader.set_index(pt_indx)
        Dicom_reader.get_images_and_mask()

        image = Dicom_reader.ArrayDicom  # image array
        mask = Dicom_reader.mask  # mask array

        # Indexes along the first axis where the contour is present.
        slice_locations = np.unique(np.where(mask != 0)[0])
        if slice_locations.size == 0:
            # An ROI was matched but its mask is empty: record the scan as
            # missed instead of failing on the indexing below.
            missed_contours.append(Dicom_path)
            continue
        slice_start = slice_locations[0]  # first slice of contour
        slice_end = slice_locations[-1]  # last slice of contour

        # Keep every `skip`-th slice of the contoured range, counting from 1
        # at slice_start, so the first kept slice is slice_start + skip - 1.
        first_kept = slice_start + skip - 1
        single_arrays = [
            np.multiply(img_arr, contour_arr)
            for img_arr, contour_arr in zip(
                image[first_kept:slice_end + 1:skip],
                mask[first_kept:slice_end + 1:skip],
            )
        ]

        # Every slice inherits the label of the scan it came from. extend()
        # avoids re-copying the growing lists on each scan.
        total_arrays.extend(single_arrays)
        total_labels.extend([train_labels[i]] * len(single_arrays))

    except TypeError:
        print("skip this dataset")
        skipped_cts.append(Dicom_path)
        continue

### Save array and label lists

In [None]:
# Folder where the extracted slice arrays and their label lists are saved
# and later loaded from.
np_folder_path = '/home/lam3654/MSAI_pneumonitis/data/pneumonitis_np'

In [None]:
# Write the pooled training slices as a .npy file and their labels as JSON,
# both inside np_folder_path.
train_arrays_file = os.path.join(np_folder_path, "ipsi_train_arrays.npy")
train_labels_file = os.path.join(np_folder_path, "ipsi_train_labels_np.json")

np.save(train_arrays_file, total_arrays)

with open(train_labels_file, "w") as file:
    json.dump(total_labels, file)

## Build Dataset

In [None]:
class CTDataset(Dataset):
    """Dataset that pairs each image with the label stored at the same index."""

    def __init__(self, data, labels):
        """Keep references to the images and their index-aligned labels.

        Args:
            data: Indexable collection of images; its length is the dataset
                size.
            labels: Indexable collection of labels, read with the same
                indices as ``data``.
        """
        self.data, self.labels = data, labels

    def __len__(self):
        """Return the number of entries in ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        return self.data[index], self.labels[index]

In [None]:
# Read the training slices (.npy) and their labels (JSON) from np_folder_path.
train_arrays_path = os.path.join(np_folder_path, "ipsi_train_arrays.npy")
train_array_list = np.load(train_arrays_path, allow_pickle=True)

train_labels_path = os.path.join(np_folder_path, "ipsi_train_labels_np.json")
with open(train_labels_path, "r") as file:
    train_labels_list = json.load(file)

In [None]:
# Wrap the training slices and labels in a dataset and serve them in
# shuffled mini-batches of 4.
dataset = CTDataset(data=train_array_list, labels=train_labels_list)
data_loader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

In [None]:
class CNNClassifier(nn.Module):
    """Small CNN that maps 2D images to class logits.

    Architecture: two 3x3 convolutions (32 and 64 channels), each followed by
    ReLU and 2x2 max pooling, then a 512-unit hidden layer and an output
    layer.

    Args:
        input_size: Spatial size of the input images, either an int for
            square images or a ``(height, width)`` pair. The default of 512
            gives the flattened size ``64 * 128 * 128``.
        num_classes: Number of output logits.
        in_channels: Number of channels of the input images.

    Raises:
        ValueError: If either spatial dimension is smaller than 4, which
            would leave nothing after the two poolings.
    """

    def __init__(self, input_size=512, num_classes=2, in_channels=1):
        super().__init__()
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        if height < 4 or width < 4:
            raise ValueError("input_size must be at least 4 in each dimension")

        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # The padded 3x3 convolutions keep the spatial size; each of the two
        # 2x2 poolings halves it (rounding down).
        self.flat_features = 64 * (height // 2 // 2) * (width // 2 // 2)
        self.fc1 = nn.Linear(self.flat_features, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)`` whose
                spatial size matches the ``input_size`` given at construction.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. A `view(-1, N)` with a fixed N would silently
        # regroup the batch dimension when the input size is not the expected
        # one; this way a mismatch fails loudly in fc1.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

## Training

In [None]:
# Check if a GPU is available and if not, default to CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Training on {device}")

# Initialize the model, loss function, and optimizer
model = CNNClassifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
accuracy_list = []  # training accuracy of each epoch
loss_list = []      # sample-weighted mean training loss of each epoch

# Logits of every batch, stored detached and on the CPU so the autograd graph
# of each batch is not kept alive for the whole run.
output_list = []
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    true_labels = []
    pred_labels = []
    running_loss = 0.0
    seen_samples = 0
    for images, labels in data_loader:
        # Insert the channel axis at position 1 and cast to float32 before
        # the forward pass.
        images = images.to(device).unsqueeze(1).float()
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        loss.backward()
        optimizer.step()

        output_list.append(outputs.detach().cpu())
        preds = outputs.argmax(dim=1)  # Get the predicted classes
        true_labels.extend(labels.cpu().numpy())
        pred_labels.extend(preds.cpu().numpy())

        # Accumulate the loss weighted by batch size so the epoch value is a
        # mean over samples instead of just the last batch's loss.
        batch_samples = labels.size(0)
        running_loss += loss.item() * batch_samples
        seen_samples += batch_samples

    if seen_samples == 0:
        raise RuntimeError("data_loader yielded no batches; nothing to train on")

    # Convert to numpy arrays for use with sklearn
    true_labels = np.array(true_labels)
    pred_labels = np.array(pred_labels)
    accuracy = accuracy_score(true_labels, pred_labels)
    epoch_loss = running_loss / seen_samples
    accuracy_list.append(accuracy)
    loss_list.append(epoch_loss)
    # Print the loss after each epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# Epoch indices used as the x axis of both curves.
x_list = list(range(num_epochs))

# Top panel: training loss per epoch.
loss_axes = plt.subplot(2, 1, 1)  # (rows, columns, subplot index)
loss_axes.plot(x_list, loss_list)
loss_axes.set_xlabel('Epochs')
loss_axes.set_ylabel('Loss')
loss_axes.set_title('Training loss over epochs')

# Bottom panel: training accuracy per epoch.
accuracy_axes = plt.subplot(2, 1, 2)  # (rows, columns, subplot index)
accuracy_axes.plot(x_list, accuracy_list)
accuracy_axes.set_xlabel('Epochs')
accuracy_axes.set_ylabel('Accuracy')
accuracy_axes.set_title('Training accuracy over epochs')

# Adjust the spacing between the panels, then display the figure.
plt.tight_layout()
plt.show()


## Testing

### Prep dataset

In [None]:
# Read the test-split scan folders and their labels from the split directory.
test_data_file = os.path.join(path, "small_test_data.json")
with open(test_data_file, "r") as file:
    test_data = json.load(file)

test_labels_file = os.path.join(path, "small_test_labels.json")
with open(test_labels_file, "r") as file:
    test_labels = json.load(file)

In [None]:
missed_test_contours = []  # scan folders for which no usable lung contour was found
skipped_test_cts = []      # scan folders whose processing raised a TypeError
total_test_arrays = []     # contour-masked 2D slices pooled over all scans
total_test_labels = []     # one label per entry of total_test_arrays

for i, Dicom_path in enumerate(test_data):
    try:
        Dicom_reader = DicomReaderWriter(description='Examples', arg_max=True)
        Dicom_reader.walk_through_folders(Dicom_path)

        # When lung_side reports True the left-lung ROI is requested,
        # otherwise the right-lung ROI.
        if lung_side(Dicom_path):
            Contour_names = ['llung']
            associations = [ROIAssociationClass('llung', left_contours)]
        else:
            Contour_names = ['rlung']
            associations = [ROIAssociationClass('rlung', right_contours)]
        Dicom_reader.set_contour_names_and_associations(contour_names=Contour_names, associations=associations)

        indexes = Dicom_reader.which_indexes_have_all_rois()
        if not indexes:
            missed_test_contours.append(Dicom_path)
            continue

        # Use the last index that has the requested ROI.
        pt_indx = indexes[-1]
        Dicom_reader.set_index(pt_indx)
        Dicom_reader.get_images_and_mask()

        image = Dicom_reader.ArrayDicom  # image array
        mask = Dicom_reader.mask  # mask array

        # Indexes along the first axis where the contour is present.
        slice_locations = np.unique(np.where(mask != 0)[0])
        if slice_locations.size == 0:
            # An ROI was matched but its mask is empty: record the scan as
            # missed instead of failing on the indexing below.
            missed_test_contours.append(Dicom_path)
            continue
        slice_start = slice_locations[0]  # first slice of contour
        slice_end = slice_locations[-1]  # last slice of contour

        # Keep every `skip`-th slice of the contoured range, counting from 1
        # at slice_start, so the first kept slice is slice_start + skip - 1.
        first_kept = slice_start + skip - 1
        single_arrays = [
            np.multiply(img_arr, contour_arr)
            for img_arr, contour_arr in zip(
                image[first_kept:slice_end + 1:skip],
                mask[first_kept:slice_end + 1:skip],
            )
        ]

        # Every slice inherits the label of the scan it came from. extend()
        # avoids re-copying the growing lists on each scan.
        total_test_arrays.extend(single_arrays)
        total_test_labels.extend([test_labels[i]] * len(single_arrays))

    except TypeError:
        print("skip this dataset")
        skipped_test_cts.append(Dicom_path)
        continue

### Save arrays

In [None]:
# Write the pooled test slices as a .npy file and their labels as JSON,
# both inside np_folder_path.
test_arrays_file = os.path.join(np_folder_path, "ipsi_test_arrays.npy")
test_labels_file = os.path.join(np_folder_path, "ipsi_test_labels_np.json")

np.save(test_arrays_file, total_test_arrays)

with open(test_labels_file, "w") as file:
    json.dump(total_test_labels, file)

### Load dataset

In [None]:
# Read the test slices (.npy) and their labels (JSON) from np_folder_path.
test_arrays_path = os.path.join(np_folder_path, "ipsi_test_arrays.npy")
test_array_list = np.load(test_arrays_path, allow_pickle=True)

test_labels_path = os.path.join(np_folder_path, "ipsi_test_labels_np.json")
with open(test_labels_path, "r") as file:
    test_labels_list = json.load(file)

In [None]:
# Wrap the test slices and labels in a dataset and serve them in shuffled
# mini-batches of 4.
test_dataset = CTDataset(data=test_array_list, labels=test_labels_list)
test_data_loader = DataLoader(dataset=test_dataset, batch_size=4, shuffle=True)

### Evaluate on the test set

In [None]:
model = model.to(device)
model.eval()  # Set the model to evaluation mode

true_labels = []
pred_labels = []
outputs_list = []     # raw logits of every test sample
positive_scores = []  # softmax probability of class 1 for every test sample

# Loop through the test data. Gradients are not needed for evaluation, so
# graph construction is disabled.
with torch.no_grad():
    for inputs, labels in test_data_loader:
        # Insert the channel axis at position 1 and cast to float32 before
        # the forward pass.
        inputs = inputs.to(device).unsqueeze(1).float()

        # Forward pass
        outputs = model(inputs)
        preds = outputs.argmax(dim=1)  # Get the predicted classes
        probs = F.softmax(outputs, dim=1)

        true_labels.extend(labels.cpu().numpy())
        pred_labels.extend(preds.cpu().numpy())
        outputs_list.extend(outputs.cpu().numpy())
        positive_scores.extend(probs[:, 1].cpu().numpy())

# Convert to numpy arrays for use with sklearn
true_labels = np.array(true_labels)
pred_labels = np.array(pred_labels)
positive_scores = np.array(positive_scores)

# Compute ROC AUC from the class-1 probabilities. Using the hard predicted
# labels as scores reduces the ROC curve to a single operating point.
if np.unique(true_labels).size < 2:
    # ROC AUC is undefined when only one class is present in the labels.
    roc_auc = float('nan')
else:
    roc_auc = roc_auc_score(true_labels, positive_scores)

# Compute accuracy
accuracy = accuracy_score(true_labels, pred_labels)

# Compute confusion matrix; the labels are pinned so the matrix is always
# 2x2, even if one class is absent from the labels and predictions.
cm = confusion_matrix(true_labels, pred_labels, labels=[0, 1])

# Plot confusion matrix
sns.heatmap(cm, annot=True, cmap='Blues', fmt='g')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

print(f"Accuracy: {accuracy}")
print(f"ROC AUC: {roc_auc}")