# IMPORTS

In [None]:
!pip install pydicom
!pip install nibabel

from torch import nn
import torch
import torch.optim as optim
from google.colab import drive
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact, IntSlider
from torch.optim.lr_scheduler import StepLR
import pydicom
import nibabel as nib
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import gc
import torch.optim as optim
from torch.nn import BCELoss
from tqdm import tqdm
import tensorflow as tf

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#GIF SAVER

In [None]:
import imageio.v2 as imageio

def create_gif(display_function, frames_range, filename="visualization.gif", fps=5):
    """
    Creates a GIF from an interactive visualization.

    Every index in ``frames_range`` is rendered with ``display_function``, the
    current matplotlib figure is saved as ``frame_<index>.png`` in the working
    directory, and the PNGs are assembled into ``filename``. The temporary
    PNGs are removed afterwards, also when rendering or encoding raises.

    Args:
        display_function: The function that generates the visualization for a given slice index.
        frames_range: A range or list of slice indices to include in the GIF.
        filename: The name of the output GIF file.
        fps: Frames per second for the GIF.
    """

    frames = []
    try:
        for slice_index in frames_range:
            print(slice_index)
            # Call the display function to generate the plot
            display_function(slice_index)
            frame_path = f"frame_{slice_index}.png"
            # Register the path before saving so a partially written file is cleaned up too
            frames.append(frame_path)
            # NOTE(review): if display_function calls plt.show(), some backends close the
            # figure before savefig runs - confirm the saved frames are not blank.
            plt.savefig(frame_path)
            plt.close()

        imageio.mimsave(filename, [imageio.imread(f) for f in frames], fps=fps)
    finally:
        # Clean up temporary frame files; a repeated index maps to one file, so de-duplicate
        for f in set(frames):
            if os.path.exists(f):
                os.remove(f)


# MODEL

In [None]:
class DownCNNBlock(nn.Module):
    """Encoder stage: two 3x3x3 convolutions, each followed by ReLU, then 2x2x2 max pooling."""

    def __init__(self, in_channels, out_channels, use_padding=False):
        """Create the layers.

        Args:
            in_channels: Channel count of the incoming tensor.
            out_channels: Channel count produced by both convolutions.
            use_padding: When truthy, pad each convolution by 1 so the
                spatial size is preserved; otherwise no padding is used.
        """
        super().__init__()
        conv_kwargs = dict(kernel_size=(3, 3, 3), padding=int(bool(use_padding)))

        self.conv1 = nn.Conv3d(in_channels, out_channels, **conv_kwargs)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv3d(out_channels, out_channels, **conv_kwargs)
        self.relu2 = nn.ReLU()

        # Halves every spatial dimension
        self.pooling = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2)

    def forward(self, input):
        """Return ``(pooled, features)``; ``features`` is the pre-pooling activation used as the skip connection."""
        features = self.conv1(input)
        features = self.relu1(features)
        features = self.conv2(features)
        features = self.relu2(features)
        return self.pooling(features), features

class UpCNNBlock(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, two 3x3x3 convolutions.

    All layers are always constructed, whichever path ``forward`` takes, so
    the state_dict layout does not depend on the configuration.
    """

    def __init__(self, in_channels, out_channels, last_layer=False):
        """Create the layers.

        Args:
            in_channels: Used together with ``out_channels`` to size the layers
                (see the individual layer definitions below).
            out_channels: Channel count produced by the 3x3x3 convolutions.
                A value of 1 switches ``forward`` to the ``finalUpconv`` /
                ``conv1Final`` path.
            last_layer: When true, ``forward`` finishes with a 1x1x1
                convolution and a sigmoid.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        incoming_channels = in_channels - out_channels
        up_kernel, conv_kernel = (2, 2, 2), (3, 3, 3)

        # Upsampling (doubles every spatial dimension)
        self.upconv = nn.ConvTranspose3d(incoming_channels, incoming_channels, kernel_size=up_kernel, stride=2)
        self.finalUpconv = nn.ConvTranspose3d(in_channels, in_channels, kernel_size=up_kernel, stride=2)

        # Convolutions applied after the skip concatenation
        self.conv1 = nn.Conv3d(in_channels + out_channels, out_channels, kernel_size=conv_kernel, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=conv_kernel, padding=1)
        self.relu2 = nn.ReLU()
        self.conv1Final = nn.Conv3d(2 * in_channels, out_channels, kernel_size=conv_kernel, padding=1)

        self.last_layer = last_layer

        # Output head, only used when last_layer is set
        self.conv3 = nn.Conv3d(out_channels, 1, kernel_size=(1, 1, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, residual=None):
        """Upsample ``input``, concatenate ``residual`` on the channel axis, and convolve.

        A 4-D tensor is treated as unbatched (channel axis 0); otherwise the
        channel axis is 1.
        """
        single_channel_out = self.out_channels == 1

        upsample = self.finalUpconv if single_channel_out else self.upconv
        out = upsample(input)

        if residual is not None:
            channel_axis = 0 if out.dim() == 4 else 1
            out = torch.cat((out, residual), dim=channel_axis)

        if single_channel_out:
            # This path applies no ReLU after the first convolution
            out = self.conv1Final(out)
        else:
            out = self.relu1(self.conv1(out))
        out = self.relu2(self.conv2(out))

        if not self.last_layer:
            return out
        return self.sigmoid(self.conv3(out))


class UNet3D(nn.Module):
    """Three-level 3D U-Net built from DownCNNBlock / UpCNNBlock stages.

    The final decoder block is created with ``last_layer=True``, so the
    network output has already passed through a sigmoid.
    """

    def __init__(self, in_channels, level_channels=(16, 32, 64)):
        """Create the encoder and decoder stages.

        Args:
            in_channels: Channel count of the input volume.
                NOTE(review): the last decoder block is built as
                ``UpCNNBlock(level_channels[0], in_channels)``, and UpCNNBlock
                only takes its ``finalUpconv`` path when that value is 1, so
                only ``in_channels=1`` appears to be wired correctly - confirm
                before using other values.
            level_channels: Channel counts of the three encoder levels. An
                immutable tuple default avoids sharing one mutable list
                between calls; any indexable sequence of three ints works.
        """
        super(UNet3D, self).__init__()
        # Encoder ("analysis") path
        self.a_block1 = DownCNNBlock(in_channels, level_channels[0], use_padding=True)
        self.a_block2 = DownCNNBlock(level_channels[0], level_channels[1], use_padding=True)
        self.a_block3 = DownCNNBlock(level_channels[1], level_channels[2], use_padding=True)

        # Decoder ("synthesis") path
        self.s_block3 = UpCNNBlock(level_channels[2] + level_channels[1], level_channels[1])
        self.s_block2 = UpCNNBlock(level_channels[1] + level_channels[0], level_channels[0])
        self.s_block1 = UpCNNBlock(level_channels[0], in_channels, last_layer=True)

    def forward(self, input):
        """Run the encoder, then the decoder with skip connections from each encoder level."""
        out1, res1 = self.a_block1(input)
        out2, res2 = self.a_block2(out1)
        out3, res3 = self.a_block3(out2)

        out = self.s_block3(out3, res3)
        out = self.s_block2(out, res2)
        out = self.s_block1(out, res1)
        return out

In [None]:
def dice_loss(prediction, target, smooth=1e-5):
    """Return the soft Dice loss ``1 - (2*|P.T| + smooth) / (|P| + |T| + smooth)``.

    Args:
        prediction: Tensor of predicted values; flattened before use.
        target: Tensor with the same number of elements as ``prediction``.
        smooth: Added to numerator and denominator so the ratio is defined
            when both tensors sum to zero (the loss is then 0).

    Returns:
        A scalar tensor.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. permuted or sliced ones
    prediction = prediction.reshape(-1)
    target = target.reshape(-1)

    intersection = (prediction * target).sum()
    # Sum of both "sizes" (the Dice denominator), not a set union
    union = prediction.sum() + target.sum()

    dice = (2. * intersection + smooth) / (union + smooth)

    return 1.0 - dice

# DATA LOADING

In [None]:
# Mount Google Drive at /content/drive; the data and checkpoint paths used below live under it.
drive.mount('/content/drive')

In [None]:
# Locations of one case on Drive: a folder of DICOM slices (Input) and a NIfTI mask file (Mask).
input_dir = '/content/drive/MyDrive/BinaryAirwaySegmentation/1.3.6.1.4.1.14519.5.2.1.6279.6001.221945191226273284587353530424/Input'
mask_path = '/content/drive/MyDrive/BinaryAirwaySegmentation/1.3.6.1.4.1.14519.5.2.1.6279.6001.221945191226273284587353530424/Mask/1.3.6.1.4.1.14519.5.2.1.6279.6001.221945191226273284587353530424.nii.gz'

In [None]:
# Collect every .dcm file in the input folder and order the paths by the SliceLocation attribute.
dcm_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.dcm')]
dcm_files.sort(key=lambda f: pydicom.dcmread(f).SliceLocation)  # each file is parsed once here just for the sort key

# Read the pixel data slice by slice (each file is parsed a second time here).
input_volume = []
for file in dcm_files:
    ds = pydicom.dcmread(file)
    input_volume.append(ds.pixel_array)

input_volume = np.stack(input_volume, axis=0)  # slices are stacked along a new leading axis
mask_volume = nib.load(mask_path).get_fdata()
mask_volume = np.asarray(mask_volume)
# Re-orient the mask: rotate 90 degrees in the plane of axes (0, 1), flip axis 0, then move the
# last axis to the front - presumably so it lines up with input_volume's axis order; TODO confirm.
mask_volume = np.rot90(mask_volume, k=1, axes=(0, 1))
mask_volume = np.flip(mask_volume, axis=0)
mask_volume = np.moveaxis(mask_volume, -1, 0)

In [None]:
def display_slices(slice_index):
    """Plot slice ``slice_index`` (along axis 0) of input_volume next to the same slice of mask_volume."""
    _, (input_ax, mask_ax) = plt.subplots(1, 2, figsize=(10, 5))

    panels = (
        (input_ax, input_volume, f"Input Slice {slice_index}"),
        (mask_ax, mask_volume, f"Mask Slice {slice_index}"),
    )
    for ax, volume, title in panels:
        ax.imshow(volume[slice_index], cmap='gray')
        ax.set_title(title)

    plt.show()

# Slider over every slice index of the input volume
interact(
    display_slices,
    slice_index=IntSlider(min=0, max=input_volume.shape[0] - 1, step=1, value=0),
)
# create_gif(display_slices, range(input_volume.shape[0] - 1), filename="DATA.gif", fps=25)

# OLD DATASET

In [None]:
class MyDataset(Dataset):
    """Holds one input volume and its mask as float tensors and cuts cubic patches from them."""

    def __init__(self, input_volume, mask_volume):
        """Convert both NumPy volumes to float32 tensors.

        Args:
            input_volume: ndarray indexed as (depth, height, width).
            mask_volume: ndarray with the same layout as ``input_volume``.
        """
        self.input_volume = torch.from_numpy(input_volume).float()  # Shape: (depth, height, width)
        self.mask_volume = torch.from_numpy(mask_volume).float()  # Shape: (depth, height, width)

    def __len__(self):
        return 1  # Since we're loading the whole volume

    def __getitem__(self, idx):
        """Return the whole volume and mask with a leading channel axis; ``idx`` is ignored."""
        input_volume = self.input_volume[np.newaxis, ...]  # Shape: (1, depth, height, width)
        mask_volume = self.mask_volume[np.newaxis, ...]  # Shape: (1, depth, height, width)
        return input_volume, mask_volume

    def get_patch(self, center_coords, patch_size=128):
        """Cut a ``patch_size`` cube around ``center_coords`` from both volumes.

        Args:
            center_coords: ``(d, w, h)`` - note the order: the SECOND entry
                selects along the last axis (width) and the THIRD entry along
                the middle axis (height).
            patch_size: Edge length of the cube.

        Returns:
            ``(input_patch, mask_patch)``, each with a leading channel axis.
            No bounds checking is done: a centre too close to a border yields
            a smaller (or, for a negative start index, wrapped) patch.
        """
        d, w, h = center_coords
        d_start = d - patch_size // 2
        h_start = h - patch_size // 2
        w_start = w - patch_size // 2
        input_patch = self.input_volume[d_start : d_start + patch_size,
                                        h_start : h_start + patch_size,
                                        w_start : w_start + patch_size,]
        mask_patch = self.mask_volume[d_start : d_start + patch_size,
                                      h_start : h_start + patch_size,
                                      w_start : w_start + patch_size,]
        input_patch = input_patch[np.newaxis, ...]  # Shape: (1, depth, height, width)
        mask_patch = mask_patch[np.newaxis, ...]  # Shape: (1, depth, height, width)

        return input_patch, mask_patch

    def get_random_patch(self, patch_size=128):
        """Return a patch whose centre is drawn uniformly so the cube fits inside the volume."""
        d = np.random.randint(patch_size // 2, self.input_volume.shape[0] - patch_size // 2 - 1)
        h = np.random.randint(patch_size // 2, self.input_volume.shape[1] - patch_size // 2 - 1)
        w = np.random.randint(patch_size // 2, self.input_volume.shape[2] - patch_size // 2 - 1)

        # get_patch unpacks center_coords as (d, w, h); passing (d, h, w) applied the height
        # bound to the width axis and vice versa, truncating patches on non-square volumes.
        input_patch, mask_patch = self.get_patch(center_coords=(d, w, h), patch_size=patch_size)
        return input_patch, mask_patch

    def get_patches(self, num_patches, patch_size=128):
        """Draw ``num_patches`` random patches whose mask sums to more than 500.

        Returns:
            ``(input_patches, mask_patches)`` as NumPy arrays stacked along a
            new leading batch axis.

        Note:
            Sampling repeats until a patch passes the threshold, so this never
            returns if the mask has no such region.
        """
        input_patches = []
        mask_patches = []

        for _ in range(num_patches):
            while True:
                input_patch, mask_patch = self.get_random_patch(patch_size=patch_size)
                c = torch.sum(mask_patch)
                if c > 500:
                    break
            input_patches.append(input_patch)
            mask_patches.append(mask_patch)

        # Stack patches into a batch
        input_patches = np.stack(input_patches, axis=0)
        mask_patches = np.stack(mask_patches, axis=0)

        return input_patches, mask_patches

    def visualize_patch(self, input_patch, mask_patch, slice_index=16):  # Default to center slice
        """Show an interactive slice viewer for one (input, mask) patch pair.

        ``slice_index`` is accepted for compatibility but not used: the
        slider starts at the middle slice of the patch.
        """
        def display_slices(slice_index):
            fig, axes = plt.subplots(1, 2, figsize=(10, 5))

            # Display input patch slice
            axes[0].imshow(input_patch[0][slice_index], cmap='gray')
            axes[0].set_title(f"Input Patch Slice {slice_index}")

            # Display mask patch slice
            axes[1].imshow(mask_patch[0][slice_index], cmap='gray')
            axes[1].set_title(f"Mask Patch Slice {slice_index}")

            plt.show()

        # Create an interactive slider
        interact(display_slices, slice_index=IntSlider(min=0, max=input_patch.shape[1] - 1, step=1, value=input_patch.shape[1] // 2))
        # create_gif(display_slices, range(input_patch.shape[1] - 1), filename="PATCH.gif", fps=25)

# Create dataset instance
# dataset = MyDataset(input_volume, mask_volume)

In [None]:
# Locations of the case used for testing: a folder of DICOM slices and a NIfTI mask file.
test_input_dir = '/content/drive/MyDrive/BinaryAirwaySegmentation/1.3.6.1.4.1.14519.5.2.1.6279.6001.221945191226273284587353530424/Input'
test_mask_path = '/content/drive/MyDrive/BinaryAirwaySegmentation/1.3.6.1.4.1.14519.5.2.1.6279.6001.221945191226273284587353530424/Mask/1.3.6.1.4.1.14519.5.2.1.6279.6001.221945191226273284587353530424.nii.gz'

# Load data and create dataset object
dcm_files = [os.path.join(test_input_dir, f) for f in os.listdir(test_input_dir) if f.endswith('.dcm')]
dcm_files.sort(key=lambda f: pydicom.dcmread(f).SliceLocation)

test_input_volume = []
for file in dcm_files:
    ds = pydicom.dcmread(file)
    test_input_volume.append(ds.pixel_array)

# Stack the slices read above (previously this stacked the unrelated global `input_volume`)
test_input_volume = np.stack(test_input_volume, axis=0)
test_mask_volume = nib.load(test_mask_path).get_fdata()
test_mask_volume = np.asarray(test_mask_volume)
# Same re-orientation as applied to the training mask
test_mask_volume = np.rot90(test_mask_volume, k=1, axes=(0, 1))
test_mask_volume = np.flip(test_mask_volume, axis=0)
test_mask_volume = np.moveaxis(test_mask_volume, -1, 0)

test_dataset = MyDataset(test_input_volume, test_mask_volume)

In [None]:
# Build the 3D U-Net for single-channel input and move its parameters to `device`.
model = UNet3D(in_channels=1).to(device)

In [None]:
# Save the model: writes the current weights of `model` (its state_dict) to Drive.
# The commented-out last line restores the weights from the same file instead.
model_save_path = '/content/drive/MyDrive/MODEL.pth'  # Replace with your desired path and filename
torch.save(model.state_dict(), model_save_path)
# model.load_state_dict(torch.load(model_save_path))

In [None]:
# Draw one random patch that contains enough mask voxels and open the interactive slice viewer.
# NOTE(review): `dataset` must already exist when this cell runs; in this file it is only
# assigned in a later cell - confirm the intended execution order.
example_patches, example_masks = dataset.get_patches(num_patches=1)
example_patch = example_patches[0]
example_mask = example_masks[0]
dataset.visualize_patch(example_patch, example_mask)


In [None]:
# Optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# criterion = BCELoss()
# criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([150.0])).to(device)
criterion = dice_loss

# Training loop: each "epoch" draws one fresh batch of random patches and performs a single
# optimizer step on it.
num_epochs = 100
batch_size = 5
for epoch in range(num_epochs):
    model.train()
    # get_patches returns NumPy arrays, hence the from_numpy conversion below
    input_patches, mask_patches = dataset.get_patches(num_patches=batch_size)
    input_patches = torch.from_numpy(input_patches).to(device)
    mask_patches = torch.from_numpy(mask_patches).to(device)

    optimizer.zero_grad()  # Reset gradients
    output = model(input_patches)  # Forward pass
    loss = criterion(output, mask_patches)  # Calculate loss
    loss.backward()  # Backpropagate gradients
    optimizer.step()  # Update model weights

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

In [None]:
# Run the model on one fixed patch and browse output / mask / input side by side.
input_patch, mask_patch = dataset.get_patch((500,190,250), patch_size=128)
model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    # The model lives on `device`: feed the patch there and bring the result back to the CPU,
    # because the plotting code below needs CPU tensors (.numpy(), imshow).
    output = model(input_patch.to(device)).cpu()

# dataset.visualize_patch(input_patch, mask_patch)
o, m, i = output, mask_patch, input_patch
x = o[0,:].numpy()
z = i[0,:].numpy()

# Fixed colour ranges so the colour scale does not change from slice to slice
o_min, o_max = x.min(), x.max()
i_min, i_max = z.min(), z.max()

def display_slices2(slice_index):
    """Plot slice ``slice_index`` of the model output, the mask patch and the input patch."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))  # Adjust figsize for colorbar space
    # Display model output slice with colorbar
    im1 = axes[0].imshow(o[0][slice_index], cmap='jet',vmin=o_min, vmax=o_max)
    fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)  # Add colorbar to axes[0]
    axes[0].set_title(f"Output Slice {slice_index}")

    # Display mask_patch slice with colorbar
    im2 = axes[1].imshow(m[0][slice_index], cmap='jet',vmin=0.0, vmax=1.0)
    fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)  # Add colorbar to axes[1]
    axes[1].set_title(f"Mask Slice {slice_index}")

    # Display input_patch slice with colorbar
    im3 = axes[2].imshow(i[0][slice_index], cmap='jet',vmin=i_min, vmax=i_max)
    fig.colorbar(im3, ax=axes[2], fraction=0.046, pad=0.04)  # Add colorbar to axes[2]
    axes[2].set_title(f"Input Slice {slice_index}")

    plt.show()
# Create an interactive slider whose range follows the actual patch depth
max_index = int(output.shape[1] - 1)
interact(display_slices2, slice_index=IntSlider(min=0, max=max_index, step=1, value=0))
# create_gif(display_slices2, range(127), filename="OUTPUT.gif", fps=25)


In [None]:
# Evaluate `model` on fixed patches and plot a 2x2 voxel-level confusion matrix.
model.eval()

# Number of patches to test on
num_patches = 1 # You can adjust this number

# Rows are true labels, columns are predicted labels
confusion_matrix = torch.zeros(2, 2, dtype=torch.int32)  # 2x2 for binary classification

# Test loop
with torch.no_grad():
    for _ in range(num_patches):
        # One fixed example patch
        input_patch, mask_patch = dataset.get_patch((500,190,250), patch_size=128)

        # Make prediction. UNet3D's last block already applies a sigmoid, so `output` holds
        # probabilities; a second sigmoid would map them into [0.5, 0.73] and push almost
        # every voxel over the 0.5 threshold.
        output = model(input_patch.float().to(device))

        # Convert to binary predictions (threshold at 0.5)
        predictions = (output > 0.5).int().cpu().numpy().flatten()
        ground_truth = mask_patch.int().cpu().numpy().flatten()

        # Update confusion matrix with vectorised counts instead of a per-voxel Python loop.
        # Assumes mask labels are 0/1; other label values are not counted.
        for true_label in range(2):
            for pred_label in range(2):
                hits = ((ground_truth == true_label) & (predictions == pred_label)).sum()
                confusion_matrix[true_label, pred_label] += int(hits)

# Display confusion matrix
plt.figure(figsize=(8, 6))
plt.imshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()

# Add labels to the plot
classes = ['Background', 'Airway']
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)

# Add text annotations inside the plot. The loop variables are deliberately not called
# `i`/`j`: `i` is a notebook global (the input patch) read by the slice viewers.
thresh = confusion_matrix.max() / 2.
for row, col in np.ndindex(confusion_matrix.shape):
    plt.text(col, row, format(confusion_matrix[row, col].item(), 'd'),
             horizontalalignment="center",
             color="white" if confusion_matrix[row, col] > thresh else "black")

plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()

In [None]:
# Run the model on a second fixed patch and browse output / mask / input side by side.
input_patch, mask_patch = dataset.get_patch((125,150,150), patch_size=128)
with torch.no_grad():
    # Feed the patch on the model's device and bring the result back to the CPU for plotting.
    # No extra sigmoid: UNet3D's last block already applies one.
    output2 = model(input_patch.to(device)).cpu()

# dataset.visualize_patch(input_patch, mask_patch)
o2, m2, i2 = output2, mask_patch, input_patch
x2 = o2[0,:].numpy()
z2 = i2[0,:].numpy()

# Fixed colour ranges so the colour scale does not change from slice to slice
o_min2, o_max2 = x2.min(), x2.max()
i_min2, i_max2 = z2.min(), z2.max()

def display_slices3(slice_index):
    """Plot slice ``slice_index`` of the model output, the mask patch and the input patch."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))  # Adjust figsize for colorbar space
    # Display model output slice with colorbar
    im1 = axes[0].imshow(o2[0][slice_index], cmap='jet',vmin=o_min2, vmax=o_max2)
    fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)  # Add colorbar to axes[0]
    axes[0].set_title(f"Output Slice {slice_index}")

    # Display mask_patch slice with colorbar
    im2 = axes[1].imshow(m2[0][slice_index], cmap='jet',vmin=0.0, vmax=1.0)
    fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)  # Add colorbar to axes[1]
    axes[1].set_title(f"Mask Slice {slice_index}")

    # Display input_patch slice with colorbar
    im3 = axes[2].imshow(i2[0][slice_index], cmap='jet',vmin=i_min2, vmax=i_max2)
    fig.colorbar(im3, ax=axes[2], fraction=0.046, pad=0.04)  # Add colorbar to axes[2]
    axes[2].set_title(f"Input Slice {slice_index}")

    plt.show()

# Create an interactive slider whose range follows the actual patch depth
max_index = int(output2.shape[1] - 1)
interact(display_slices3, slice_index=IntSlider(min=0, max=max_index, step=1, value=0))

# NEW DATASET

In [None]:
class PatchLoader:
    """Loads volumes from sub-folders of a directory and keeps random cubic patches in memory.

    Each case folder is expected to contain an ``Input`` folder of ``.dcm``
    files and a ``Mask`` folder whose first entry is the mask file.
    """

    def __init__(self, drive_dir, num_patches_per_volume=5, patch_size=128):
        """Discover the case folders and immediately extract all patches.

        Args:
            drive_dir: Directory whose sub-folders are the individual cases.
            num_patches_per_volume: Patches to draw from every volume.
            patch_size: Edge length of the cubic patches.
        """
        self.drive_dir = drive_dir
        self.num_patches_per_volume = num_patches_per_volume
        self.patch_size = patch_size
        self.data_folders = self.get_data_folders()

        self.input_patches = []
        self.mask_patches = []
        self.load_and_extract_patches()

    def __len__(self):
        return len(self.input_patches)

    def __getitem__(self, idx):
        """Return the idx-th (input, mask) patch pair as float32 tensors."""
        input_patch = torch.from_numpy(self.input_patches[idx]).float()
        mask_patch = torch.from_numpy(self.mask_patches[idx]).float()
        return input_patch, mask_patch

    def visualize_sample_patches(self, num_samples=5):
        """Show an interactive viewer for ``num_samples`` randomly chosen patch pairs."""
        indices = np.random.choice(len(self.input_patches), num_samples, replace=False)
        sample_input_patches = [self.input_patches[i] for i in indices]
        sample_mask_patches = [self.mask_patches[i] for i in indices]

        def display_slices(slice_index):
            num_cols = num_samples
            num_rows = 2
            # squeeze=False keeps `axes` two-dimensional even when num_samples == 1
            fig, axes = plt.subplots(num_rows, num_cols, figsize=(10, 5 * num_rows), squeeze=False)

            for i in range(num_samples):
                axes[0, i].imshow(sample_input_patches[i][0][slice_index], cmap='gray')
                axes[0, i].set_title(f"Input, Patch {i + 1}, Slice {slice_index}")

                axes[1, i].imshow(sample_mask_patches[i][0][slice_index], cmap='gray')
                axes[1, i].set_title(f"Mask, Patch {i + 1}, Slice {slice_index}")

            plt.tight_layout()
            plt.show()

        # Create interactive slider
        slice_slider = IntSlider(min=0, max=sample_input_patches[0].shape[1] - 1, step=1, value=0)
        interact(display_slices, slice_index=slice_slider)

    def get_patches(self):
        """Return the lists of extracted input patches and mask patches."""
        return self.input_patches, self.mask_patches

    def get_data_folders(self):
        """Return up to 30 case folders found directly under ``drive_dir``.

        os.listdir order is arbitrary, so the names are sorted to make the
        selected subset reproducible between runs.
        """
        data_folders = [
            os.path.join(self.drive_dir, f)
            for f in sorted(os.listdir(self.drive_dir))
            if os.path.isdir(os.path.join(self.drive_dir, f))
        ]
        return data_folders[:30]

    def load_and_extract_patches(self):
        """Load every case volume, draw patches from it, and store them on the instance.

        For each patch, with probability 0.75 sampling is repeated (at most
        1000 times) until the mask patch sums to more than 500; the last
        sample is kept even when no attempt succeeded. Otherwise a single
        unconditional random patch is taken.
        """
        input_patches = []
        mask_patches = []

        for folder in tqdm(self.data_folders, desc="Processing Volumes"):
            print(f"\n{folder}")
            input_dir = os.path.join(folder, 'Input')
            mask_path = os.path.join(folder, 'Mask', os.listdir(os.path.join(folder, 'Mask'))[0])  # Assuming only one mask file
            dcm_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.dcm')]
            dcm_files.sort(key=lambda f: pydicom.dcmread(f).SliceLocation)

            input_volume = []
            for file in dcm_files:
                ds = pydicom.dcmread(file)
                input_volume.append(ds.pixel_array)
            input_volume = np.stack(input_volume, axis=0)

            mask_volume = nib.load(mask_path).get_fdata()
            mask_volume = np.asarray(mask_volume)
            mask_volume = np.rot90(mask_volume, k=1, axes=(0, 1))
            mask_volume = np.flip(mask_volume, axis=0)
            mask_volume = np.moveaxis(mask_volume, -1, 0)

            for _ in tqdm(range(self.num_patches_per_volume)):
                count = 0
                if np.random.rand() < 0.75:
                    while True:
                        input_patch, mask_patch = self.get_random_patch_from_volume(input_volume, mask_volume)
                        c = np.sum(mask_patch)
                        if c > 500:
                            break
                        count += 1
                        if count == 1000:
                            break
                else:
                    input_patch, mask_patch = self.get_random_patch_from_volume(input_volume, mask_volume)

                input_patches.append(input_patch)
                mask_patches.append(mask_patch)
            # The stored patches are copies, so dropping the volumes here really frees them
            del input_volume, mask_volume
            gc.collect()
            print(f"total nums of patches is {len(input_patches)}")

        self.input_patches, self.mask_patches = input_patches, mask_patches

    def get_random_patch_from_volume(self, input_volume, mask_volume):
        """Cut one random ``patch_size`` cube at the same location from both volumes.

        Returns:
            ``(input_patch, mask_patch)``, each a contiguous copy with a
            leading channel axis. Basic slicing alone would return views that
            keep the whole source volume alive for as long as the patch is
            stored.
        """
        patch_size = self.patch_size
        half = patch_size // 2

        d = np.random.randint(half, input_volume.shape[0] - half - 1)
        h = np.random.randint(half, input_volume.shape[1] - half - 1)
        w = np.random.randint(half, input_volume.shape[2] - half - 1)

        window = (
            slice(d - half, d + half),
            slice(h - half, h + half),
            slice(w - half, w + half),
        )

        input_patch = input_volume[window][np.newaxis, ...].copy()  # Shape: (1, depth, height, width)
        mask_patch = mask_volume[window][np.newaxis, ...].copy()  # Shape: (1, depth, height, width)

        return input_patch, mask_patch


In [None]:
# Build the patch loader from the case folders under drive_dir. Construction already reads
# the volumes and extracts all patches, so this cell can take a while.
drive_dir = '/content/drive/MyDrive/lung'
patch_loader = PatchLoader(drive_dir)

In [None]:
input_patches, mask_patches = patch_loader.get_patches()
# Stack into a single float32 ndarray before handing over to torch: calling torch.tensor() on a
# Python list of ndarrays is very slow (PyTorch itself warns about it).
input_patches = torch.from_numpy(np.stack(input_patches).astype(np.float32)).to(device)
mask_patches = torch.from_numpy(np.stack(mask_patches).astype(np.float32)).to(device)
patch_loader.visualize_sample_patches()

In [None]:
# Re-create the network and restore previously saved weights.
model2 = UNet3D(in_channels=1)
# map_location lets a checkpoint that was saved from a GPU session load on a CPU-only runtime
model2.load_state_dict(torch.load('/content/drive/MyDrive/NEW_MODEL2.pth', map_location=device))
model2.to(device)

In [None]:
# Training loop
num_epochs = 1000
batch_size = 10  # Adjust as needed
optimizer = optim.Adam(model2.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
# criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([100.0])).to(device)
criterion = dice_loss
save_epochs = 5

for epoch in range(num_epochs):
    model2.train()

    # Create shuffled indices for this epoch
    num_samples = len(input_patches)
    num_batches = (num_samples + batch_size - 1) // batch_size  # counts a final partial batch too
    indices = torch.randperm(num_samples)
    # Checkpoint at the start of every save_epochs-th epoch, i.e. with the weights as they
    # are after `epoch` completed epochs
    if (epoch + 1) % save_epochs == 0:
        model_save_path = '/content/drive/MyDrive/NEW_MODEL2.pth'  # Replace with your desired path and filename
        torch.save(model2.state_dict(), model_save_path)
        print(f"Model saved at epoch {epoch}\n")

    for batch_start in range(0, num_samples, batch_size):
        # Get batch indices
        batch_indices = indices[batch_start : batch_start + batch_size]

        # Get input and mask batches (a no-op move when the patches already live on `device`)
        input_batch = input_patches[batch_indices].to(device)
        mask_batch = mask_patches[batch_indices].to(device)

        # Training steps:
        optimizer.zero_grad()
        output = model2(input_batch)
        loss = criterion(output, mask_batch)
        loss.backward()
        optimizer.step()
        print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_start // batch_size + 1}/{num_batches}], Loss: {loss.item():.4f}')

    # Step the scheduler once per epoch. Stepping it per batch made StepLR divide the learning
    # rate by 10 every 10 batches, driving it to effectively zero within a few epochs.
    scheduler.step()

In [None]:
# Write the current weights of model2 (its state_dict) to Drive, overwriting the checkpoint file.
model_save_path = '/content/drive/MyDrive/NEW_MODEL2.pth'  # Replace with your desired path and filename
torch.save(model2.state_dict(), model_save_path)

In [None]:
# Load one case from the `lung` folder (DICOM slices + NIfTI mask) and wrap it in a MyDataset.
input_dir = '/content/drive/MyDrive/lung/1.3.6.1.4.1.14519.5.2.1.6279.6001.221945191226273284587353530424/Input'
mask_path = '/content/drive/MyDrive/lung/1.3.6.1.4.1.14519.5.2.1.6279.6001.221945191226273284587353530424/Mask/1.3.6.1.4.1.14519.5.2.1.6279.6001.221945191226273284587353530424.nii.gz'

# Order the slice files by their SliceLocation attribute
dcm_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.dcm')]
dcm_files.sort(key=lambda f: pydicom.dcmread(f).SliceLocation)

# Read the pixel data slice by slice
input_volume = []
for file in dcm_files:
    ds = pydicom.dcmread(file)
    input_volume.append(ds.pixel_array)

input_volume = np.stack(input_volume, axis=0)  # slices are stacked along a new leading axis
mask_volume = nib.load(mask_path).get_fdata()
mask_volume = np.asarray(mask_volume)
# Re-orient the mask: rotate 90 degrees in the plane of axes (0, 1), flip axis 0, then move the
# last axis to the front - presumably so it lines up with input_volume's axis order; TODO confirm.
mask_volume = np.rot90(mask_volume, k=1, axes=(0, 1))
mask_volume = np.flip(mask_volume, axis=0)
mask_volume = np.moveaxis(mask_volume, -1, 0)

dataset = MyDataset(input_volume, mask_volume)

In [None]:
# Run model2 on one fixed patch and browse output / mask / input side by side.
input_patch, mask_patch = dataset.get_patch((496,190,253), patch_size=128)
model2.eval()  # Set the model to evaluation mode
with torch.no_grad():
    # Feed the patch on the model's device and bring the result back to the CPU for plotting.
    # No extra sigmoid is needed: UNet3D's last block already applies one.
    output = model2(input_patch.to(device)).cpu()

# dataset.visualize_patch(input_patch, mask_patch)
o, m, i = output, mask_patch, input_patch
x = o[0,:].numpy()
z = i[0,:].numpy()

# Fixed colour ranges so the colour scale does not change from slice to slice
o_min, o_max = x.min(), x.max()
i_min, i_max = z.min(), z.max()

def display_slices(slice_index):
    """Plot slice ``slice_index`` of the model output, the mask patch and the input patch."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))  # Adjust figsize for colorbar space
    # Display model output slice with colorbar
    im1 = axes[0].imshow(o[0][slice_index], cmap='jet',vmin=o_min, vmax=o_max)
    fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)  # Add colorbar to axes[0]
    axes[0].set_title(f"Output Slice {slice_index}")

    # Display mask_patch slice with colorbar
    im2 = axes[1].imshow(m[0][slice_index], cmap='jet',vmin=0.0, vmax=1.0)
    fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)  # Add colorbar to axes[1]
    axes[1].set_title(f"Mask Slice {slice_index}")

    # Display input_patch slice with colorbar
    im3 = axes[2].imshow(i[0][slice_index], cmap='jet',vmin=i_min, vmax=i_max)
    fig.colorbar(im3, ax=axes[2], fraction=0.046, pad=0.04)  # Add colorbar to axes[2]
    axes[2].set_title(f"Input Slice {slice_index}")

    plt.show()

# Create an interactive slider whose range follows the actual patch depth
max_index = int(output.shape[1] - 1)
interact(display_slices, slice_index=IntSlider(min=0, max=max_index, step=1, value=0))

# RUN ALL


In [None]:
!pip install pydicom
!pip install nibabel

from torch import nn
import torch
import torch.optim as optim
from google.colab import drive
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact, IntSlider
from torch.optim.lr_scheduler import StepLR
import pydicom
import nibabel as nib
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import gc
import torch.optim as optim
from torch.nn import BCELoss
from tqdm import tqdm
import tensorflow as tf

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class DownCNNBlock(nn.Module):
    """Encoder stage: two 3x3x3 convolutions, each followed by ReLU, then 2x2x2 max pooling."""

    def __init__(self, in_channels, out_channels, use_padding=False):
        """Create the layers.

        Args:
            in_channels: Channel count of the incoming tensor.
            out_channels: Channel count produced by both convolutions.
            use_padding: When truthy, pad each convolution by 1 so the
                spatial size is preserved; otherwise no padding is used.
        """
        super().__init__()
        conv_kwargs = dict(kernel_size=(3, 3, 3), padding=int(bool(use_padding)))

        self.conv1 = nn.Conv3d(in_channels, out_channels, **conv_kwargs)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv3d(out_channels, out_channels, **conv_kwargs)
        self.relu2 = nn.ReLU()

        # Halves every spatial dimension
        self.pooling = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2)

    def forward(self, input):
        """Return ``(pooled, features)``; ``features`` is the pre-pooling activation used as the skip connection."""
        features = self.conv1(input)
        features = self.relu1(features)
        features = self.conv2(features)
        features = self.relu2(features)
        return self.pooling(features), features

class UpCNNBlock(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, two 3x3x3 convolutions.

    All layers are always constructed, whichever path ``forward`` takes, so
    the state_dict layout does not depend on the configuration.
    """

    def __init__(self, in_channels, out_channels, last_layer=False):
        """Create the layers.

        Args:
            in_channels: Used together with ``out_channels`` to size the layers
                (see the individual layer definitions below).
            out_channels: Channel count produced by the 3x3x3 convolutions.
                A value of 1 switches ``forward`` to the ``finalUpconv`` /
                ``conv1Final`` path.
            last_layer: When true, ``forward`` finishes with a 1x1x1
                convolution and a sigmoid.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        incoming_channels = in_channels - out_channels
        up_kernel, conv_kernel = (2, 2, 2), (3, 3, 3)

        # Upsampling (doubles every spatial dimension)
        self.upconv = nn.ConvTranspose3d(incoming_channels, incoming_channels, kernel_size=up_kernel, stride=2)
        self.finalUpconv = nn.ConvTranspose3d(in_channels, in_channels, kernel_size=up_kernel, stride=2)

        # Convolutions applied after the skip concatenation
        self.conv1 = nn.Conv3d(in_channels + out_channels, out_channels, kernel_size=conv_kernel, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=conv_kernel, padding=1)
        self.relu2 = nn.ReLU()
        self.conv1Final = nn.Conv3d(2 * in_channels, out_channels, kernel_size=conv_kernel, padding=1)

        self.last_layer = last_layer

        # Output head, only used when last_layer is set
        self.conv3 = nn.Conv3d(out_channels, 1, kernel_size=(1, 1, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, residual=None):
        """Upsample ``input``, concatenate ``residual`` on the channel axis, and convolve.

        A 4-D tensor is treated as unbatched (channel axis 0); otherwise the
        channel axis is 1.
        """
        single_channel_out = self.out_channels == 1

        upsample = self.finalUpconv if single_channel_out else self.upconv
        out = upsample(input)

        if residual is not None:
            channel_axis = 0 if out.dim() == 4 else 1
            out = torch.cat((out, residual), dim=channel_axis)

        if single_channel_out:
            # This path applies no ReLU after the first convolution
            out = self.conv1Final(out)
        else:
            out = self.relu1(self.conv1(out))
        out = self.relu2(self.conv2(out))

        if not self.last_layer:
            return out
        return self.sigmoid(self.conv3(out))


class UNet3D(nn.Module):
    """Three-level 3D U-Net built from DownCNNBlock and UpCNNBlock modules.

    Args:
        in_channels: Channel count of the input volume. It is also passed as
            ``out_channels`` to the last decoder block.
        level_channels: Channel widths of the three encoder levels; only the
            first three entries are read.

    NOTE(review): ``UpCNNBlock.forward`` selects its final path by testing
    ``out_channels == 1``, so the last decoder block only takes that path when
    ``in_channels`` is 1 - confirm before using multi-channel inputs.
    """

    # The default is a tuple so no mutable object is shared between instances.
    def __init__(self, in_channels, level_channels=(16, 32, 64)):
        super(UNet3D, self).__init__()
        # Encoder path.
        self.a_block1 = DownCNNBlock(in_channels, level_channels[0], use_padding=True)
        self.a_block2 = DownCNNBlock(level_channels[0], level_channels[1], use_padding=True)
        self.a_block3 = DownCNNBlock(level_channels[1], level_channels[2], use_padding=True)

        # Decoder path; attribute names are kept so saved checkpoints still load.
        self.s_block3 = UpCNNBlock(level_channels[2] + level_channels[1], level_channels[1])
        self.s_block2 = UpCNNBlock(level_channels[1] + level_channels[0], level_channels[0])
        self.s_block1 = UpCNNBlock(level_channels[0], in_channels, last_layer=True)

    def forward(self, input):
        """Run the encoder blocks, then the decoder blocks with their skips.

        Each encoder block returns a pair: the first element feeds the next
        level, the second is handed to the matching decoder block as its
        ``residual`` argument.

        Args:
            input: Input tensor for the first encoder block.

        Returns:
            The output of the last decoder block.
        """
        out1, res1 = self.a_block1(input)
        out2, res2 = self.a_block2(out1)
        out3, res3 = self.a_block3(out2)

        out = self.s_block3(out3, res3)
        out = self.s_block2(out, res2)
        out = self.s_block1(out, res1)
        return out

def dice_loss(prediction, target, smooth=1e-5):
    """Return the soft Dice loss ``1 - Dice`` computed over all elements.

    Both tensors are flattened, so the loss is a single value for the whole
    batch rather than a per-sample average.

    Args:
        prediction: Tensor of predicted values.
        target: Tensor of reference values; expected to hold the same number
            of elements as ``prediction``.
        smooth: Constant added to numerator and denominator so the ratio stays
            defined when both tensors sum to zero.

    Returns:
        A scalar tensor equal to ``1 - (2 * sum(p * t) + smooth) /
        (sum(p) + sum(t) + smooth)``.
    """
    # reshape() instead of view(): view() raises on non-contiguous tensors
    # (e.g. transposed or sliced ones), reshape() copies only when required.
    prediction = prediction.reshape(-1)
    target = target.reshape(-1)

    intersection = (prediction * target).sum()
    total = prediction.sum() + target.sum()

    dice = (2. * intersection + smooth) / (total + smooth)

    return 1.0 - dice

In [None]:
# Mount Google Drive at /content/drive; the dataset directory and the model
# checkpoint used by the cells below are read from paths under this mount point.
drive.mount('/content/drive')

In [None]:
class PatchLoader:
    """Extract random cubic patches from paired DICOM volumes and NIfTI masks.

    Each immediate sub-folder of ``drive_dir`` is read as one case: the
    ``.dcm`` files of its ``Input`` directory are stacked into a volume and
    the first file of its ``Mask`` directory is loaded as the matching mask.
    All patches are extracted in the constructor and kept in memory as NumPy
    arrays of shape ``(1, patch_size, patch_size, patch_size)``.

    Args:
        drive_dir: Directory whose sub-folders are the individual cases.
        num_patches_per_volume: Number of patches drawn from every case.
        patch_size: Edge length of a patch in voxels.
        max_volumes: Upper bound on the number of case folders processed
            (``None`` processes all of them).
    """

    # Share of patches for which a minimum amount of mask content is required.
    _FOREGROUND_PROBABILITY = 0.75
    # A foreground patch is accepted once its mask sums to more than this.
    _MIN_FOREGROUND_SUM = 500
    # Give up looking for a foreground patch after this many draws.
    _MAX_SAMPLING_ATTEMPTS = 1000

    def __init__(self, drive_dir, num_patches_per_volume=5, patch_size=128, max_volumes=30):
        self.drive_dir = drive_dir
        self.num_patches_per_volume = num_patches_per_volume
        self.patch_size = patch_size
        self.max_volumes = max_volumes
        self.data_folders = self.get_data_folders()

        self.input_patches = []
        self.mask_patches = []
        self.load_and_extract_patches()

    def __len__(self):
        """Return the number of extracted patches."""
        return len(self.input_patches)

    def __getitem__(self, idx):
        """Return the ``(input, mask)`` pair at ``idx`` as float32 tensors."""
        return (
            self._to_float_tensor(self.input_patches[idx]),
            self._to_float_tensor(self.mask_patches[idx]),
        )

    @staticmethod
    def _to_float_tensor(patch):
        """Convert a NumPy patch into a float32 tensor."""
        # Converting in NumPy first yields a contiguous float32 array, which
        # torch.from_numpy always accepts; it rejects negatively-strided views
        # and, on some PyTorch releases, unsigned 16-bit data.
        return torch.from_numpy(np.ascontiguousarray(patch, dtype=np.float32))

    def visualize_sample_patches(self, num_samples=5):
        """Show random patches and their masks with an interactive slice slider.

        Args:
            num_samples: Number of patches to display; capped at the number
                of available patches.

        Raises:
            ValueError: If no patches have been extracted.
        """
        available = len(self.input_patches)
        if available == 0:
            raise ValueError("No patches available to visualize.")
        num_samples = min(num_samples, available)

        indices = np.random.choice(available, num_samples, replace=False)
        sample_input_patches = [self.input_patches[i] for i in indices]
        sample_mask_patches = [self.mask_patches[i] for i in indices]

        def display_slices(slice_index):
            num_cols = num_samples
            num_rows = 2
            # squeeze=False keeps `axes` two-dimensional for a single column.
            _, axes = plt.subplots(num_rows, num_cols, figsize=(10, 5 * num_rows), squeeze=False)

            for i in range(num_samples):
                axes[0, i].imshow(sample_input_patches[i][0][slice_index], cmap='gray')
                axes[0, i].set_title(f"Input, Patch {i + 1}, Slice {slice_index}")

                axes[1, i].imshow(sample_mask_patches[i][0][slice_index], cmap='gray')
                axes[1, i].set_title(f"Mask, Patch {i + 1}, Slice {slice_index}")

            plt.tight_layout()
            plt.show()

        # Create interactive slider over the first spatial axis of a patch
        slice_slider = IntSlider(min=0, max=sample_input_patches[0].shape[1] - 1, step=1, value=0)
        interact(display_slices, slice_index=slice_slider)

    def get_patches(self):
        """Return the lists ``(input_patches, mask_patches)``."""
        return self.input_patches, self.mask_patches

    def get_data_folders(self):
        """Return the case folders under ``drive_dir``, at most ``max_volumes``.

        The folders are sorted because ``os.listdir`` returns entries in
        arbitrary order, which would make the selected subset vary between runs.
        """
        root = self.drive_dir
        folders = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if os.path.isdir(os.path.join(root, name))
        )
        return folders[:self.max_volumes]

    def load_and_extract_patches(self):
        """Load every case folder and fill ``input_patches`` / ``mask_patches``."""
        input_patches = []
        mask_patches = []

        for folder in tqdm(self.data_folders, desc="Processing Volumes"):
            print(f"\n{folder}")
            input_volume = self._load_input_volume(os.path.join(folder, 'Input'))
            mask_volume = self._load_mask_volume(os.path.join(folder, 'Mask'))

            for _ in tqdm(range(self.num_patches_per_volume)):
                input_patch, mask_patch = self._sample_patch(input_volume, mask_volume)
                input_patches.append(input_patch)
                mask_patches.append(mask_patch)

            # The stored patches are copies, so the full volumes can be freed here.
            del input_volume, mask_volume
            gc.collect()
            print(f"total nums of patches is {len(input_patches)}")

        self.input_patches, self.mask_patches = input_patches, mask_patches

    @staticmethod
    def _load_input_volume(input_dir):
        """Stack the ``.dcm`` files of ``input_dir`` ordered by SliceLocation."""
        dcm_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.dcm')]
        # Header-only reads suffice for ordering, so pixel data is decoded
        # once per file instead of twice.
        dcm_files.sort(key=lambda path: pydicom.dcmread(path, stop_before_pixels=True).SliceLocation)
        return np.stack([pydicom.dcmread(path).pixel_array for path in dcm_files], axis=0)

    @staticmethod
    def _load_mask_volume(mask_dir):
        """Load the first file of ``mask_dir`` and reorient it."""
        # Assuming only one mask file; sorting makes the choice deterministic otherwise.
        mask_name = sorted(os.listdir(mask_dir))[0]
        mask_volume = np.asarray(nib.load(os.path.join(mask_dir, mask_name)).get_fdata())
        # Rotate in the plane of axes 0/1, flip axis 0, then move the last
        # axis to the front - presumably to match the DICOM stack's
        # (slice, row, column) layout; verify against the data.
        mask_volume = np.rot90(mask_volume, k=1, axes=(0, 1))
        mask_volume = np.flip(mask_volume, axis=0)
        return np.moveaxis(mask_volume, -1, 0)

    def _sample_patch(self, input_volume, mask_volume):
        """Draw one patch pair, preferring mask content most of the time."""
        if np.random.rand() >= self._FOREGROUND_PROBABILITY:
            return self.get_random_patch_from_volume(input_volume, mask_volume)

        # Candidates are tested as views; only the retained one is copied.
        # The last candidate is kept when no draw passes the threshold.
        for _ in range(self._MAX_SAMPLING_ATTEMPTS):
            input_view, mask_view = self._random_patch_views(input_volume, mask_volume)
            if np.sum(mask_view) > self._MIN_FOREGROUND_SUM:
                break
        return input_view.copy(), mask_view.copy()

    def _random_patch_views(self, input_volume, mask_volume):
        """Return views of one random window with a leading channel axis."""
        size = self.patch_size
        starts = []
        for extent in input_volume.shape[:3]:
            if extent < size:
                raise ValueError(
                    f"Volume of shape {input_volume.shape} is smaller than patch size {size}."
                )
            # randint's upper bound is exclusive; +1 lets the window reach the far edge.
            starts.append(np.random.randint(0, extent - size + 1))

        window = tuple(slice(start, start + size) for start in starts)
        # Shape: (1, depth, height, width)
        return input_volume[window][np.newaxis, ...], mask_volume[window][np.newaxis, ...]

    def get_random_patch_from_volume(self, input_volume, mask_volume):
        """Cut the same random cubic window out of both volumes.

        The window is ``patch_size`` voxels along each of the first three
        axes and may touch every border of the volume. The returned arrays are
        contiguous copies, so they neither keep the source volumes alive nor
        inherit negative strides from flipped views.

        Args:
            input_volume: Image volume indexed (depth, height, width).
            mask_volume: Mask volume indexed like ``input_volume``.

        Returns:
            ``(input_patch, mask_patch)``, each with a leading channel axis.

        Raises:
            ValueError: If the volume is smaller than ``patch_size`` along any
                of its first three axes.
        """
        input_view, mask_view = self._random_patch_views(input_volume, mask_volume)
        return input_view.copy(), mask_view.copy()


In [None]:
# Root directory on the mounted Drive; PatchLoader treats each of its
# sub-folders as one volume.
drive_dir = '/content/drive/MyDrive/lung'
# Constructing the loader immediately reads the volumes and extracts the patches.
patch_loader = PatchLoader(drive_dir)

In [None]:
# Gather all patches into two float32 tensors on the training device.
# The lists are first packed into a single NumPy array each: calling
# torch.tensor() directly on a Python list of ndarrays converts element by
# element and is very slow for volumes of this size.
input_patches, mask_patches = patch_loader.get_patches()
input_patches = torch.from_numpy(np.asarray(input_patches, dtype=np.float32)).to(device)
mask_patches = torch.from_numpy(np.asarray(mask_patches, dtype=np.float32)).to(device)

In [None]:
# Rebuild the network and resume from the checkpoint stored on Drive.
model2 = UNet3D(in_channels=1)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU runtime still loads when only the CPU is available.
model2.load_state_dict(torch.load('/content/drive/MyDrive/NEW_MODEL2.pth', map_location=device))
model2.to(device)

In [None]:
# Training loop
num_epochs = 1000
batch_size = 10  # Adjust as needed
optimizer = optim.Adam(model2.parameters(), lr=1e-3)
# StepLR multiplies the learning rate by `gamma` every `step_size` scheduler steps.
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
# Alternative loss kept for reference:
# criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([100.0])).to(device)
criterion = dice_loss
save_epochs = 5

for epoch in range(num_epochs):
    model2.train()

    # Advance the schedule once per epoch. Stepping it after every batch made
    # the learning rate shrink tenfold every 10 batches, i.e. to effectively
    # zero within the first few epochs. The first epoch is skipped so that
    # optimizer.step() has always run before scheduler.step().
    if epoch > 0:
        scheduler.step()

    # Create shuffled indices for this epoch
    num_samples = len(input_patches)
    # Ceiling division so a final partial batch is counted in the progress log.
    num_batches = (num_samples + batch_size - 1) // batch_size
    indices = torch.randperm(num_samples)
    # Periodic checkpoint, written before this epoch's updates are applied.
    if (epoch + 1) % save_epochs == 0:
        model_save_path = '/content/drive/MyDrive/NEW_MODEL2.pth'  # Replace with your desired path and filename
        torch.save(model2.state_dict(), model_save_path)
        print(f"Model saved at epoch {epoch}\n")

    for batch_start in range(0, num_samples, batch_size):
        # Get batch indices
        batch_indices = indices[batch_start : batch_start + batch_size]

        # Get input and mask batches; .to(device) returns the tensor unchanged
        # when it already lives on the target device.
        input_batch = input_patches[batch_indices].to(device)
        mask_batch = mask_patches[batch_indices].to(device)

        # Training steps:
        optimizer.zero_grad()
        output = model2(input_batch)
        loss = criterion(output, mask_batch)
        loss.backward()
        optimizer.step()
        print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_start // batch_size + 1}/{num_batches}], Loss: {loss.item():.4f}')