# IMPORTS

In [None]:
!pip install pydicom
!pip install nibabel

import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pydicom
import tensorflow as tf
import torch

from google.colab import drive, files
from ipywidgets import interact, IntSlider
from torch import nn, optim
from torch.nn import BCELoss
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
import gc
import torch.optim as optim
from tqdm import tqdm
from google.colab import drive


In [None]:
# Select the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Mount Google Drive; every sub-directory of `drive_dir` is treated as one scan folder.
drive.mount('/content/drive')
drive_dir = '/content/drive/MyDrive/lung'

# os.listdir() returns entries in arbitrary order, while later cells split the
# folders by position (e.g. data_folders[:24]); sorting keeps that split
# reproducible between sessions.
data_folders = sorted(
    path
    for path in (os.path.join(drive_dir, name) for name in os.listdir(drive_dir))
    if os.path.isdir(path)
)

# MODEL

In [None]:
class DownCNNBlock(nn.Module):
    """Encoder stage: two 3x3x3 convolutions (each followed by ReLU), then 2x2x2 max pooling.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
        use_padding: When true, the convolutions pad one voxel per side and so
            keep the spatial size; otherwise each convolution trims one voxel
            from every border.
    """

    def __init__(self, in_channels, out_channels, use_padding=False):
        super().__init__()
        conv_options = {"kernel_size": 3, "padding": 1 if use_padding else 0}

        # Layers are created in the order conv1, conv2 so parameter
        # initialisation and state_dict key order stay stable.
        self.conv1 = nn.Conv3d(in_channels, out_channels, **conv_options)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv3d(out_channels, out_channels, **conv_options)
        self.relu2 = nn.ReLU()

        # Halves every spatial dimension.
        self.pooling = nn.MaxPool3d(kernel_size=2, stride=2)

    def forward(self, input):
        """Return ``(pooled, features)``.

        ``features`` is the activation before pooling (used as the skip
        connection); ``pooled`` is the same tensor after max pooling.
        """
        features = self.relu1(self.conv1(input))
        features = self.relu2(self.conv2(features))
        return self.pooling(features), features

class UpCNNBlock(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, two 3x3x3 convs.

    Layers for two channel layouts are always constructed; ``forward`` picks
    one of them based on ``out_channels`` (not on ``last_layer``):

    * ``out_channels != 1``: ``upconv`` expects ``in_channels - out_channels``
      input channels and ``conv1`` expects ``in_channels + out_channels``
      channels after concatenation, i.e. a skip tensor with
      ``2 * out_channels`` channels.
    * ``out_channels == 1``: ``finalUpconv`` expects ``in_channels`` input
      channels and ``conv1Final`` expects ``2 * in_channels`` channels after
      concatenation, i.e. a skip tensor with ``in_channels`` channels.

    Args:
        in_channels: Channel bookkeeping value described above.
        out_channels: Channel count after the two 3x3x3 convolutions.
        last_layer: When true, a 1x1x1 convolution to one channel and a
            sigmoid are applied at the end of ``forward``.
    """

    def __init__(self, in_channels, out_channels, last_layer=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # NOTE: attribute names and creation order are kept as-is so existing
        # checkpoints (state_dict keys) remain loadable.
        upsample_options = {"kernel_size": 2, "stride": 2}
        carried = in_channels - out_channels
        self.upconv = nn.ConvTranspose3d(carried, carried, **upsample_options)
        self.finalUpconv = nn.ConvTranspose3d(in_channels, in_channels, **upsample_options)

        # 3x3x3 convolutions with padding 1 keep the spatial size.
        conv_options = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv3d(in_channels + out_channels, out_channels, **conv_options)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv3d(out_channels, out_channels, **conv_options)
        self.relu2 = nn.ReLU()
        self.conv1Final = nn.Conv3d(2 * in_channels, out_channels, **conv_options)

        self.last_layer = last_layer

        # Output head; built unconditionally but only used when last_layer is set.
        self.conv3 = nn.Conv3d(out_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, residual=None):
        """Upsample ``input`` by 2, concatenate ``residual`` and convolve.

        ``residual`` is optional in the signature, but the convolution widths
        described on the class only line up when it is supplied.
        """
        single_channel = self.out_channels == 1

        upsample = self.finalUpconv if single_channel else self.upconv
        out = upsample(input)

        if residual is not None:
            # A 4-D tensor is unbatched (C, D, H, W): channels are on axis 0.
            channel_axis = 0 if out.dim() == 4 else 1
            out = torch.cat((out, residual), dim=channel_axis)

        if single_channel:
            # This path applies no ReLU after its first convolution.
            out = self.conv1Final(out)
        else:
            out = self.relu1(self.conv1(out))
        out = self.relu2(self.conv2(out))

        if not self.last_layer:
            return out
        return self.sigmoid(self.conv3(out))


class UNet3D(nn.Module):
    """Five-level 3D U-Net producing a single-channel sigmoid map.

    The encoder is five ``DownCNNBlock`` stages (padded convolutions, each
    followed by 2x pooling); the decoder is five ``UpCNNBlock`` stages fed
    with the matching pre-pooling activations as skip connections.

    Constraints that follow from the wiring below:

    * Each spatial dimension of the input must be a multiple of 32, otherwise
      an upsampled tensor and its skip tensor differ in size and cannot be
      concatenated.
    * Every entry of ``level_channels`` must be twice the previous one; the
      decoder's convolution widths only match the concatenated tensors under
      that doubling.

    Args:
        in_channels: Channel count of the input volume.
        level_channels: Feature widths of the five encoder levels.
    """

    def __init__(self, in_channels, level_channels=(16, 32, 64, 128, 256)):
        super().__init__()
        # Immutable default (a list default would be shared between calls);
        # any indexable sequence is still accepted.
        widths = tuple(level_channels)

        # Encoder ("analysis") path.
        self.a_block1 = DownCNNBlock(in_channels, widths[0], use_padding=True)
        self.a_block2 = DownCNNBlock(widths[0], widths[1], use_padding=True)
        self.a_block3 = DownCNNBlock(widths[1], widths[2], use_padding=True)
        self.a_block4 = DownCNNBlock(widths[2], widths[3], use_padding=True)
        self.a_block5 = DownCNNBlock(widths[3], widths[4], use_padding=True)

        # Decoder ("synthesis") path.
        self.s_block5 = UpCNNBlock(widths[4] + widths[3], widths[3])
        self.s_block4 = UpCNNBlock(widths[3] + widths[2], widths[2])
        self.s_block3 = UpCNNBlock(widths[2] + widths[1], widths[1])
        self.s_block2 = UpCNNBlock(widths[1] + widths[0], widths[0])
        # The last block's width is pinned to 1 rather than `in_channels`:
        # UpCNNBlock only takes its full-width upsampling path when
        # out_channels == 1, so any other value made forward() fail with a
        # channel mismatch. For in_channels == 1 this is identical to before.
        self.s_block1 = UpCNNBlock(widths[0], 1, last_layer=True)

    def forward(self, input):
        """Return the sigmoid output for ``input`` (same spatial size, one channel)."""
        out1, res1 = self.a_block1(input)
        out2, res2 = self.a_block2(out1)
        out3, res3 = self.a_block3(out2)
        out4, res4 = self.a_block4(out3)
        out5, res5 = self.a_block5(out4)

        # Walk back up, pairing each decoder stage with the skip tensor of
        # the encoder stage at the same resolution.
        out = self.s_block5(out5, res5)
        out = self.s_block4(out, res4)
        out = self.s_block3(out, res3)
        out = self.s_block2(out, res2)
        out = self.s_block1(out, res1)
        return out

In [None]:
def dice_loss(prediction, target, smooth=1e-5):
    """Return the soft Dice loss ``1 - Dice`` computed over all elements.

    Args:
        prediction: Tensor of predicted values (e.g. sigmoid outputs).
        target: Tensor with the same number of elements as ``prediction``.
        smooth: Added to numerator and denominator so that two all-zero
            tensors give a Dice score of 1 (loss 0) instead of 0/0.

    Returns:
        A 0-dim tensor holding ``1 - (2*sum(p*t) + smooth) / (sum(p) + sum(t) + smooth)``.
    """
    # reshape() flattens like view() but also accepts non-contiguous tensors,
    # such as patches sliced out of a larger volume, where view() raises.
    prediction = prediction.reshape(-1)
    target = target.reshape(-1)

    intersection = (prediction * target).sum()
    # Sum of both totals: the Dice denominator, not a set union.
    denominator = prediction.sum() + target.sum()

    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    return 1.0 - dice

# DATA PROCESSING

In [None]:
class DICOMCacheDataset:
    """Load (CT volume, mask) tensor pairs and cache them on disk as ``.pt`` files.

    The cache key is the name of the directory that contains ``input_dir``;
    tensors are stored at ``<cache_dir>/<key>/input.pt`` and
    ``<cache_dir>/<key>/mask.pt``.

    Args:
        data_folders: Sequence of scan folders (only its length is used here).
        cache_dir: Root directory of the tensor cache.
    """

    def __init__(self, data_folders, cache_dir='numpy_dicom_images'):
        self.data_folders = data_folders
        self.cache_dir = cache_dir

    def __len__(self):
        # Previously read an attribute (`scans`) that is never assigned.
        return len(self.data_folders)

    def _cache_paths(self, input_dir):
        """Return the (input, mask) cache file paths for ``input_dir``."""
        # normpath drops a trailing separator; without it dirname() would
        # return input_dir itself and every scan would share one cache key.
        folder_name = os.path.basename(os.path.dirname(os.path.normpath(input_dir)))
        folder = os.path.join(self.cache_dir, folder_name)
        return os.path.join(folder, "input.pt"), os.path.join(folder, "mask.pt")

    @staticmethod
    def _load_cached(path):
        """Load one cached tensor onto the CPU."""
        with open(path, 'rb') as f:
            return torch.load(f, map_location=torch.device('cpu'), weights_only=True)

    @staticmethod
    def _read_dicom_volume(input_dir):
        """Stack the ``.dcm`` slices of ``input_dir`` ordered by SliceLocation."""
        dcm_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.dcm')]
        if not dcm_files:
            raise ValueError(f"No .dcm files found in {input_dir}")

        # Parse each file once and keep only what is needed (the original
        # parsed every file twice: once for the sort key, once for pixels).
        slices = []
        for file_path in tqdm(dcm_files, desc="Loading DICOM files"):
            ds = pydicom.dcmread(file_path)
            slices.append((ds.SliceLocation, ds.pixel_array))
        slices.sort(key=lambda item: item[0])

        volume = np.stack([pixels for _, pixels in slices], axis=0)
        # Cast in NumPy first: torch.from_numpy may reject some NumPy dtypes
        # (presumably unsigned 16-bit pixels on older PyTorch - confirm
        # against the installed version).
        return torch.from_numpy(volume.astype(np.float32))

    @staticmethod
    def _read_mask_volume(mask_path):
        """Load the NIfTI mask and reorder its axes to match the DICOM stack."""
        mask_volume = np.asarray(nib.load(mask_path).get_fdata())
        # For a 3-D mask, rot90 followed by the axis-0 flip swaps the first
        # two axes; moveaxis then brings the last axis to the front, so
        # result[k, i, j] == nifti[j, i, k].
        mask_volume = np.rot90(mask_volume, k=1, axes=(0, 1))
        mask_volume = np.flip(mask_volume, axis=0)
        mask_volume = np.moveaxis(mask_volume, -1, 0)
        return torch.from_numpy(mask_volume).float()

    def load_and_cache_volumes(self, input_dir, mask_path):
        """Return ``(input_volume, mask_volume)`` as float CPU tensors.

        Both tensors are read from the cache when both cache files exist;
        otherwise they are built from the DICOM directory and the NIfTI mask
        and written to the cache before being returned.
        """
        input_cache_path, mask_cache_path = self._cache_paths(input_dir)

        if os.path.exists(input_cache_path) and os.path.exists(mask_cache_path):
            return self._load_cached(input_cache_path), self._load_cached(mask_cache_path)

        input_volume = self._read_dicom_volume(input_dir)
        mask_volume = self._read_mask_volume(mask_path)

        # Both files live in the same directory, so one makedirs suffices.
        os.makedirs(os.path.dirname(input_cache_path), exist_ok=True)
        torch.save(input_volume, input_cache_path)
        torch.save(mask_volume, mask_cache_path)
        return input_volume, mask_volume

    def get_volume(self, input_dir, mask_path):
        """Alias of :meth:`load_and_cache_volumes`, which already checks the cache."""
        return self.load_and_cache_volumes(input_dir, mask_path)

In [None]:
class GPUPatchLoader(Dataset):
    """Map-style dataset that yields random cubic (input, mask) patches.

    Each volume in ``data_folders`` contributes ``patches_per_vol`` items.
    Patches are returned with a leading channel axis, i.e. shape
    ``(1, P, P, P)`` for an even ``patch_size`` ``P``. This class itself
    performs no device transfers.

    NOTE(review): index ``EXAMPLE_INDEX`` is a sentinel. ``loader[101]`` always
    returns one fixed, un-augmented example patch instead of the item that
    position would normally map to. A loop over ``range(len(loader))`` with
    more than 101 items therefore also receives that example patch.

    Args:
        data_folders: Scan folders, each holding ``Input/`` and ``Mask/``.
        patches_per_vol: Number of dataset items per scan folder.
        patch_size: Edge length of the cubic patch.
        cache_dir: Forwarded to ``DICOMCacheDataset``.
    """

    # Fixed example used for monitoring; see the class note above.
    EXAMPLE_INDEX = 101
    EXAMPLE_INPUT_DIR = '/content/drive/MyDrive/lung/1.3.6.1.4.1.14519.5.2.1.6279.6001.221945191226273284587353530424/Input'
    EXAMPLE_MASK_PATH = '/content/drive/MyDrive/lung/1.3.6.1.4.1.14519.5.2.1.6279.6001.221945191226273284587353530424/Mask/1.3.6.1.4.1.14519.5.2.1.6279.6001.221945191226273284587353530424.nii.gz'
    EXAMPLE_COORDS = (480, 240, 253)

    def __init__(self, data_folders, patches_per_vol=5, patch_size=256, cache_dir='numpy_dicom_images'):
        self.patches_per_vol = patches_per_vol
        self.patch_size = patch_size
        self.cache = DICOMCacheDataset(data_folders, cache_dir)
        self.data = data_folders

    def __len__(self):
        return len(self.data) * self.patches_per_vol

    def __getitem__(self, idx):
        """Return an augmented random patch pair, or the fixed example for the sentinel index."""
        if idx == self.EXAMPLE_INDEX:
            example_volume, example_volume_mask = self.cache.load_and_cache_volumes(
                self.EXAMPLE_INPUT_DIR, self.EXAMPLE_MASK_PATH)
            return self.get_patch(example_volume, example_volume_mask, self.EXAMPLE_COORDS)

        # Consecutive indices map to the same volume, patches_per_vol at a time.
        folder = self.data[idx // self.patches_per_vol]
        input_dir = os.path.join(folder, 'Input')
        mask_dir = os.path.join(folder, 'Mask')
        # The first directory entry is taken as the mask file.
        mask_path = os.path.join(mask_dir, os.listdir(mask_dir)[0])

        input_volume, mask_volume = self.cache.load_and_cache_volumes(input_dir, mask_path)
        print(folder)
        print(input_volume.shape)
        input_patch, mask_patch = self.get_patch(input_volume, mask_volume)
        return self.augment_patch(input_patch, mask_patch)

    @staticmethod
    def calculate_lung_bounding_box(mask_volume, margin=30):
        """Return the padded bounding box of the non-zero voxels of a 3-D mask.

        Declared as a static method so it can be called on the class or on an
        instance (the original had no ``self`` and failed on instances).

        Args:
            mask_volume: 3-D tensor; non-zero voxels mark the region.
            margin: Voxels added on every side, clamped to the volume bounds.

        Returns:
            ``(min_d, min_h, min_w, max_d, max_h, max_w)`` as Python ints,
            with inclusive maxima.

        Raises:
            ValueError: If the mask has no non-zero voxel.
        """
        nonzero_indices = torch.nonzero(mask_volume)
        if nonzero_indices.shape[0] == 0:
            raise ValueError("mask_volume has no non-zero voxels; bounding box is undefined")

        # tolist() yields plain ints, so the result no longer mixes ints and
        # 0-dim tensors depending on which side of max()/min() wins.
        lows = nonzero_indices.min(dim=0).values.tolist()
        highs = nonzero_indices.max(dim=0).values.tolist()

        min_d, min_h, min_w = (max(0, low - margin) for low in lows)
        max_d, max_h, max_w = (
            min(size - 1, high + margin)
            for size, high in zip(mask_volume.shape, highs)
        )
        return min_d, min_h, min_w, max_d, max_h, max_w

    @staticmethod
    def _crop(volume, center, half):
        """Slice the cube ``[c - half, c + half)`` around ``center`` out of ``volume``."""
        d, h, w = center
        return volume[d - half:d + half, h - half:h + half, w - half:w + half]

    def get_patch(self, input_volume, mask_volume, coords=None):
        """Cut a cubic patch out of both volumes.

        With ``coords`` given, the patch is centred there; no bounds checking
        is performed, so out-of-range centres produce truncated patches.
        Without ``coords`` a centre is drawn at random; a draw is accepted
        when the mask patch sums to at least 500, otherwise it is still
        accepted with probability 0.25 and redrawn the rest of the time.

        Returns:
            ``(input_patch, mask_patch)``, each with a leading channel axis.
            Both are views into the given volumes.

        Raises:
            ValueError: If a random patch is requested from a volume that is
                smaller than the patch along some axis.
        """
        half = self.patch_size // 2
        mask_threshold = 500

        if coords is not None:
            center = coords
        else:
            extents = tuple(input_volume.shape[:3])
            if any(size < 2 * half for size in extents):
                raise ValueError(
                    f"volume of shape {extents} is smaller than patch_size={self.patch_size}")
            while True:
                # Valid centres are half .. size - half inclusive; randint's
                # upper bound is exclusive, hence the + 1. (The previous bound
                # of size - half - 1 failed when size == patch_size.)
                center = tuple(
                    torch.randint(half, size - half + 1, (1,)).item() for size in extents
                )
                candidate_mask = self._crop(mask_volume, center, half)
                if torch.sum(candidate_mask) >= mask_threshold or torch.rand(1).item() >= 0.75:
                    break

        input_patch = self._crop(input_volume, center, half).unsqueeze(0)
        mask_patch = self._crop(mask_volume, center, half).unsqueeze(0)
        return input_patch, mask_patch

    def augment_patch(self, input_patch, mask_patch):
        """Randomly flip or black out a cube of a ``(1, D, H, W)`` patch pair.

        With probability 0.05 both patches are flipped along one random
        spatial axis; with probability 0.9 the same 32-voxel cube is zeroed in
        both; otherwise they are returned unchanged. The arguments are never
        modified in place.
        """
        flip, blackout = 0.05, 0.9
        choice = torch.rand(1).item()
        if choice < flip:
            axis = torch.randint(1, 4, [1,]).item()
            input_patch = torch.flip(input_patch, dims=[axis])
            mask_patch = torch.flip(mask_patch, dims=[axis])
        elif choice < flip + blackout:
            blackout_size = 32
            # One start offset per spatial axis, drawn in D, H, W order.
            d_start, h_start, w_start = (
                torch.randint(blackout_size // 2, input_patch.shape[axis] - blackout_size, (1,)).item()
                for axis in (1, 2, 3)
            )
            region = (
                slice(None),
                slice(d_start, d_start + blackout_size),
                slice(h_start, h_start + blackout_size),
                slice(w_start, w_start + blackout_size),
            )
            # Clone first: the patches are usually views into the full
            # volume, and zeroing them in place would write into it.
            input_patch = input_patch.clone()
            mask_patch = mask_patch.clone()
            input_patch[region] = 0
            mask_patch[region] = 0
        return input_patch, mask_patch

def visualize_patch(input_patch, mask_patch, slice_index=64):
    """Show one slice of an input patch and of its mask side by side.

    Args:
        input_patch: Tensor indexed as ``[channel][slice]``; channel 0 is shown.
        mask_patch: Tensor with the same layout as ``input_patch``.
        slice_index: Index along the first axis after the channel axis.
    """
    # No plt.clf() here: with no figure open it would create an empty figure
    # that is never closed (only the subplots figure below is).
    fig, axes = plt.subplots(1, 2, figsize=(6, 5))

    panels = (("Input", input_patch), ("Mask", mask_patch))
    for ax, (label, patch) in zip(axes, panels):
        ax.imshow(patch[0][slice_index].cpu().numpy(), cmap='gray')
        ax.set_title(f"{label} Patch Slice {slice_index}")

    plt.show()
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

# TRAINING

In [None]:
# Disabled sanity check: uncomment to build a small loader over the first ten
# scan folders and display the first few patches with visualize_patch().
# patch_loader = GPUPatchLoader(data_folders[:10], patches_per_vol=5, patch_size=128)

# batch_size = 4
# for i in range(batch_size):
#   input_patch, mask_patch = patch_loader[i]
#   visualize_patch(input_patch, mask_patch)
#   del input_patch
#   del mask_patch
# del patch_loader
# gc.collect()

# Show which scan folders were discovered on the mounted drive.
print(data_folders)

In [None]:
# Warm the on-disk tensor cache: requesting one patch per scan folder forces
# each of the first 30 volumes to be loaded (and cached if it was not yet).
dummy_patch_loader = GPUPatchLoader(
    data_folders[:30],
    patches_per_vol=1,
    patch_size=256,
)
for i in range(len(dummy_patch_loader)):
    print(f"loading in data[{i}]")
    input_patch, mask_patch = dummy_patch_loader[i]
    # The patches themselves are not needed; drop them right away.
    del input_patch, mask_patch
del dummy_patch_loader
gc.collect()
# Optional: archive the cache and download it from Colab.
# !zip -r /content/file.zip /content/numpy_dicom_images
# files.download("/content/file.zip")

In [None]:
from types import NoneType

# --- Model, optimiser, loss and data ---------------------------------------
model_GPU = UNet3D(in_channels=1).to(device)
optimizer = optim.Adam(model_GPU.parameters(), lr=1e-4)
criterion = dice_loss

patch_loader = GPUPatchLoader(data_folders[:24], patches_per_vol=5, patch_size=256)

# --- Run configuration ------------------------------------------------------
num_epochs = 1000
slice_index = 64  # slice of the example patch shown in the progress figure

loss_file_path = "loss_values.txt"
checkpoint_dir = '/content/drive/MyDrive/MyCheckpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# The fixed monitoring patch is fetched lazily after the first epoch.
example_patch = None

# --- History ----------------------------------------------------------------
train_losses = []   # mean training loss per epoch
dice_scores = []
patch_losses = []   # loss on the fixed example patch per epoch

for epoch in range(num_epochs):
    # ---- Training pass: one patch per step (batch size 1) ----
    model_GPU.train()
    epoch_loss = 0

    for i in tqdm(range(len(patch_loader)), desc=f"Epoch {epoch + 1}/{num_epochs}"):
        # Add the batch axis and move both tensors to the training device.
        input_patch, mask_patch = (t.unsqueeze(0).to(device) for t in patch_loader[i])

        optimizer.zero_grad()
        output = model_GPU(input_patch)
        loss = criterion(output, mask_patch)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_epoch_loss = epoch_loss / len(patch_loader)
    train_losses.append(avg_epoch_loss)

    # ---- Monitoring pass on the fixed example patch ----
    model_GPU.eval()
    with torch.no_grad():
        if example_patch is None:
            # Index 101 is the loader's sentinel for its fixed example patch.
            example_patch, example_mask = (t.unsqueeze(0).to(device) for t in patch_loader[101])

        output = model_GPU(example_patch)
        patch_loss = criterion(output, example_mask).item()
        patch_losses.append(patch_loss)

    # Everything below (logging, figure, checkpoint) runs every 10th epoch only.
    if epoch % 10 != 0:
        continue

    with open(loss_file_path, 'a') as loss_file:
        loss_file.write(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_epoch_loss}\n")
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_epoch_loss}')

    output_np = output.detach().cpu().numpy()
    mask_np = example_mask.detach().cpu().numpy()
    epochs_so_far = range(1, epoch + 2)

    fig, axes = plt.subplots(2, 3, figsize=(18, 8), gridspec_kw={'width_ratios': [1, 1, 1.5], 'height_ratios': [1, 1]})

    # Top row, left and middle: prediction and ground truth for one slice.
    for column, (volume_np, title) in enumerate(((output_np, "Output"), (mask_np, "True Mask"))):
        panel = axes[0, column]
        panel.imshow(volume_np[0, 0, slice_index], cmap='jet')
        panel.set_title(f"{title} (Slice {slice_index})")
        panel.axis('off')

    # Top row, right: loss on the example patch over time.
    axes[0, 2].plot(epochs_so_far, patch_losses)
    axes[0, 2].set_title("Patch Performance Over Time")
    axes[0, 2].set_xlabel("Epoch")
    axes[0, 2].set_ylabel("Dice Loss")

    # Bottom row: replace the three small axes by one wide training-loss plot.
    gs = axes[1, 0].get_gridspec()
    for ax in axes[1, :]:
        ax.remove()

    axbig = fig.add_subplot(gs[1, :])
    axbig.plot(epochs_so_far, train_losses)
    axbig.set_title("Overall Model Performance Over Time")
    axbig.set_xlabel("Epoch")
    axbig.set_ylabel("Dice Loss")

    plt.tight_layout()
    plt.show()
    plt.close(fig)

    checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch + 1}.pth')
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model_GPU.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, checkpoint_path)
    print(f'Checkpoint saved at epoch {epoch + 1} to {checkpoint_path}')



In [None]:
# Evaluate the trained model on scan folders 30-34, five random patches each.
# NOTE(review): items come from GPUPatchLoader.__getitem__, so they pass
# through its random augmentation as well - confirm that is intended here.
test_patch_loader = GPUPatchLoader(data_folders[30:35], patches_per_vol=5, patch_size=256)
if len(test_patch_loader) == 0:
    # Fewer than 31 scan folders: fail clearly instead of dividing by zero below.
    raise ValueError("No test volumes: data_folders[30:35] is empty")

model_GPU.eval()

with torch.no_grad():
    total_dice_loss = 0
    for i in tqdm(range(len(test_patch_loader)), desc="Testing"):
        input_patch, mask_patch = test_patch_loader[i]
        # Add the batch axis and move to the evaluation device.
        input_patch = input_patch.unsqueeze(0).to(device)
        mask_patch = mask_patch.unsqueeze(0).to(device)

        output = model_GPU(input_patch)
        dice_loss_value = dice_loss(output, mask_patch).item()
        total_dice_loss += dice_loss_value

    average_dice_loss = total_dice_loss / len(test_patch_loader)
    print(f"Average Dice Loss on Test Set: {average_dice_loss}")
