In [None]:
! pip install tqdm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
from tqdm import tqdm

In [None]:
# Print GPU status via nvidia-smi to confirm a CUDA device is visible before training.
! nvidia-smi

In [None]:
# Use the GPU when one is available; fall back to CPU so the notebook still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the endpoint mask, segmentation and cropped image of a single patient for inspection.
data_path = "data/25226366/cropped/"
patient_no = 10
file_stem = f"{data_path}pat{patient_no}_cropped"

# Suffix order matches the unpacking order: endpoints, segmentation, image.
data_end, data_seg, data_cropped = (
    np.array(nib.load(f"{file_stem}{suffix}.nii.gz").get_fdata())
    for suffix in ("_seg_endpoints", "_seg", "")
)

for label, volume in (("endpoint", data_end), ("cropped", data_cropped), ("segmented", data_seg)):
    print(f"{label} data shape: {volume.shape}")

In [None]:
# Quick visual check of one 2D slice (index 50 along axis 0) of the cropped volume.
fig, ax = plt.subplots()
ax.imshow(data_cropped[50], cmap="gray")
ax.set_title("Cropped volume, index 50 along axis 0")
ax.set_xlabel("axis 2 (voxels)")
ax.set_ylabel("axis 1 (voxels)")
plt.show()

In [None]:
# Collect (image, per-label mask stack) pairs for every patient.
data_path = "data/25226366/cropped/"
NUM_PATIENTS = 60
# Foreground labels 1..8 become mask channels 0..7; label 0 gets no channel.
SEG_LABELS = range(1, 9)

dataset = []
for patient_idx in range(NUM_PATIENTS):
    prefix = f"{data_path}pat{patient_idx}_cropped"
    image = torch.tensor(nib.load(prefix + ".nii.gz").get_fdata(), dtype=torch.float32)
    seg = torch.tensor(nib.load(prefix + "_seg.nii.gz").get_fdata(), dtype=torch.float32)
    # One-hot channels (1.0 where the voxel carries that label). Storing the label value
    # itself would not be a valid class-probability target for CrossEntropyLoss.
    seg_onehot = torch.stack([(seg == float(label)).float() for label in SEG_LABELS], dim=0)
    dataset.append((image, seg_onehot))

In [None]:
def pad_to_cube(tensor):
    """
    Zero-pad the last three dimensions of a tensor so that they are all equal.

    Accepts a 3D tensor (D, H, W) or a tensor with extra leading dimensions such as
    (C, D, H, W); only the trailing three dimensions are padded.

    Padding is split as evenly as possible. When a dimension needs an odd number of
    voxels, the extra one goes at the end, so the result is exactly cubic. The
    original content starts at offset ``(max_dim - size) // 2`` along each axis.

    Args:
        tensor: tensor with at least 3 dimensions.

    Returns:
        A new tensor whose last three dimensions all equal ``max(D, H, W)``.

    Raises:
        ValueError: if ``tensor`` has fewer than 3 dimensions.
    """
    if tensor.dim() < 3:
        raise ValueError(f"pad_to_cube expects at least 3 dims, got shape {tuple(tensor.shape)}")
    d, h, w = tensor.shape[-3:]
    max_dim = max(d, h, w)

    def _split(size):
        # (before, after) amounts; `after` absorbs the odd voxel so the total reaches max_dim.
        total = max_dim - size
        return total // 2, total - total // 2

    # F.pad expects pairs starting from the LAST dimension: (W_before, W_after, H_..., D_...).
    pad = (*_split(w), *_split(h), *_split(d))
    return F.pad(tensor, pad, mode='constant', value=0)

def resize_3d_tensor(tensor, target_size=(128, 128, 128), mode='trilinear'):
    """
    Resize a 3D volume (D, H, W) to ``target_size`` with ``F.interpolate``.

    Args:
        tensor: 3D tensor (D, H, W), e.g. the output of ``pad_to_cube``.
        target_size: output spatial size (D, H, W).
        mode: interpolation mode for ``F.interpolate``. Use 'nearest' for label
            masks so that no fractional values are created.

    Returns:
        Tensor of shape (1, 1, *target_size). The singleton batch and channel
        dimensions are kept; callers squeeze them as needed.
    """
    # F.interpolate expects (N, C, D, H, W) for 3D resizing.
    volume = tensor.unsqueeze(0).unsqueeze(0)
    # align_corners is only valid for the interpolating modes; passing it together with
    # 'nearest'/'area' makes F.interpolate raise a ValueError.
    align_corners = False if mode in ('linear', 'bilinear', 'bicubic', 'trilinear') else None
    return F.interpolate(volume, size=target_size, mode=mode, align_corners=align_corners)

def preprocess_3d_tensor(tensor, target_size=(128, 128, 128), mode='trilinear'):
    """
    Pad a 3D volume (D, H, W) to a cube, then resize it to ``target_size``.

    Args:
        tensor: 3D tensor (D, H, W).
        target_size: output spatial size, forwarded to ``resize_3d_tensor``.
        mode: interpolation mode, forwarded to ``resize_3d_tensor``
            (use 'nearest' for label masks).

    Returns:
        The result of ``resize_3d_tensor``: shape (1, 1, *target_size).
    """
    cube = pad_to_cube(tensor)
    return resize_3d_tensor(cube, target_size, mode=mode)
def preprocess_layers(tensor, target_size=(128, 128, 128), mode='trilinear'):
    """
    Apply ``preprocess_3d_tensor`` independently to every channel of a (C, D, H, W) stack.

    Args:
        tensor: stack of 3D volumes indexed along dim 0 (e.g. one mask per label).
        target_size: output spatial size for each channel.
        mode: interpolation mode passed through. 'nearest' keeps mask values
            binary; the default 'trilinear' blends neighbouring voxels.

    Returns:
        Tensor of shape (C, 1, 1, *target_size): per-channel results stacked on dim 0.
    """
    # Iterating a tensor yields its slices along dim 0, i.e. one volume per channel.
    return torch.stack(
        [preprocess_3d_tensor(channel, target_size, mode) for channel in tensor],
        dim=0,
    )

In [None]:
# Shape of the first patient's image volume before padding/resizing.
dataset[0][0].shape

In [None]:
# Preview the padded + resized image of the first patient.
im_tran_1 = preprocess_3d_tensor(dataset[0][0])
print(im_tran_1.shape)  # (1, 1, D, H, W)
# Depth index 0 may lie inside the zero padding added by pad_to_cube; the middle slice shows content.
mid_slice = im_tran_1.shape[2] // 2
fig, ax = plt.subplots()
ax.imshow(im_tran_1[0, 0, mid_slice], cmap="gray")
ax.set_title(f"Preprocessed patient 0, depth slice {mid_slice}")
plt.show()

In [None]:
class DoubleConv3D(nn.Module):
    """
    Two stacked (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) stages.

    ``padding=1`` keeps the spatial size unchanged. The convolutions omit their bias
    because each is immediately followed by a BatchNorm3d with its own shift.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in_channels -> out_channels, second stage keeps out_channels.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class UNet3D(nn.Module):
    """
    3D U-Net for volumetric segmentation.

    The network has three parts:
    - an encoder of DoubleConv3D stages, each followed by 2x max-pooling;
    - a dropout-regularised bottleneck;
    - a decoder of transposed-conv upsampling, skip concatenation and DoubleConv3D.

    Args:
        in_channels: channels of the input volume.
        num_classes: channels of the output logits (one per class).
        dropout: Dropout3d probability applied after the bottleneck.
        features: encoder channel widths, shallowest first. The bottleneck uses
            ``2 * features[-1]`` channels and the decoder mirrors the encoder. The
            default reproduces the 64-128-256-512 (+1024 bottleneck) layout with
            identical module names; smaller widths reduce memory for large 3D inputs.

    Raises:
        ValueError: if ``features`` is empty.

    ``forward`` takes (N, in_channels, D, H, W) and returns logits with
    ``num_classes`` channels. When an upsampled tensor and its skip connection
    differ in size (odd spatial dims), the upsampled tensor is zero-padded to match.
    """

    def __init__(self, in_channels, num_classes, dropout=0.3, features=(64, 128, 256, 512)):
        super().__init__()
        widths = tuple(features)
        if not widths:
            raise ValueError("features must contain at least one channel width")

        # Encoder: in_channels -> widths[0] -> widths[1] -> ...
        self.encoder = nn.ModuleList(
            DoubleConv3D(c_in, c_out)
            for c_in, c_out in zip((in_channels,) + widths[:-1], widths)
        )

        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

        bottleneck_width = 2 * widths[-1]
        self.bottleneck = nn.Sequential(
            DoubleConv3D(widths[-1], bottleneck_width),
            nn.Dropout3d(p=dropout)  # Regularization in the bottleneck
        )

        decoder_widths = widths[::-1]  # deepest level first
        # Built before the decoder convs so module construction order matches the
        # original fixed-width layout.
        self.upconvs = nn.ModuleList(
            nn.ConvTranspose3d(c_in, c_out, kernel_size=2, stride=2)
            for c_in, c_out in zip((bottleneck_width,) + decoder_widths[:-1], decoder_widths)
        )

        # Each decoder stage receives [skip, upsampled] concatenated, both of width c.
        self.decoder = nn.ModuleList(DoubleConv3D(2 * c, c) for c in decoder_widths)

        self.final_conv = nn.Conv3d(widths[0], num_classes, kernel_size=1)

        self._initialize_weights()

    def forward(self, x):
        skip_connections = []
        for down in self.encoder:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        skip_connections = skip_connections[::-1]
        for idx in range(len(self.upconvs)):
            x = self.upconvs[idx](x)
            skip_connection = skip_connections[idx]

            # Pooling floors odd sizes, so the upsampled tensor can be smaller than the skip;
            # pad it (split before/after) so concatenation along channels is possible.
            if x.shape != skip_connection.shape:
                diff_d = skip_connection.shape[2] - x.shape[2]
                diff_h = skip_connection.shape[3] - x.shape[3]
                diff_w = skip_connection.shape[4] - x.shape[4]
                x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                              diff_h // 2, diff_h - diff_h // 2,
                              diff_d // 2, diff_d - diff_d // 2])

            x = torch.cat((skip_connection, x), dim=1)
            x = self.decoder[idx](x)

        return self.final_conv(x)

    def _initialize_weights(self):
        """Xavier-uniform init for (transposed) conv weights; BatchNorm scale 1, shift 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d):
                nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


In [None]:
# Shape the first mask stack will have after preprocessing. This uses the same transform
# as the dataset_transformed cell, because that variable does not exist yet at this point
# of a top-to-bottom run.
preprocess_layers(dataset[0][1]).squeeze(1).squeeze(1).shape

In [None]:
# Pad/resize every patient, then hold out the patients after `split` for testing.
split = 50  # first `split` patients train, the remaining ones test
dataset_transformed = [
    (
        preprocess_3d_tensor(image).squeeze(0),         # (1, 128, 128, 128): single input channel
        preprocess_layers(masks).squeeze(1).squeeze(1),  # (C, 128, 128, 128): one channel per mask
    )
    for image, masks in dataset
]
assert 0 < split < len(dataset_transformed), "split must leave patients for both train and test"
train_dataloader = DataLoader(dataset_transformed[:split], batch_size=1, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps test runs reproducible.
test_dataloader = DataLoader(dataset_transformed[split:], batch_size=1, shuffle=False)

In [None]:
# One input channel (the image volume) -> 8 logit channels, one per mask channel.
model_0 = UNet3D(in_channels=1, num_classes=8)
model_0 = model_0.to(device)

In [None]:
# Adam with a fixed learning rate; CrossEntropyLoss over the per-voxel class channels.
lr_0 = 0.0001
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_0.parameters(), lr=lr_0)

In [None]:
# Input shape seen by the model for one sample (before batching).
first_image, _first_masks = dataset_transformed[0]
print(first_image.shape)

In [None]:
# Release unused memory cached by PyTorch's CUDA allocator before training.
torch.cuda.empty_cache()

In [None]:
# Train model_0 on train_dataloader; record the mean training loss of each epoch in loss_list.
loss_list = []
num_epochs = 50

for epoch in range(num_epochs):
    # Ensure BatchNorm3d/Dropout3d are in training mode even if an evaluation cell ran before.
    model_0.train()
    epoch_loss = 0.0  # sum of batch losses this epoch
    num_batches = 0
    correct_pixels = 0
    total_pixels = 0
    for X, Y in tqdm(train_dataloader, desc=f"Epoch {epoch}"):
        X = X.to(device)  # input volumes, (B, 1, D, H, W)
        # Y stays float with the same shape as the logits (B, C, D, H, W). CrossEntropyLoss
        # then treats it as per-voxel class probabilities rather than class indices.
        Y = Y.to(device)
        optimizer.zero_grad()
        Y_hat = model_0(X)  # logits, (B, C, D, H, W)
        loss = loss_fn(Y_hat, Y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

        # Voxel accuracy: predicted class vs. the target's dominant channel, both (B, D, H, W).
        # NOTE(review): voxels whose target is all zeros (no foreground label) map to
        # channel 0 here; consider adding an explicit background channel.
        with torch.no_grad():
            pred_classes = Y_hat.argmax(dim=1)
            target_classes = Y.argmax(dim=1)
            correct_pixels += (pred_classes == target_classes).sum().item()
            total_pixels += target_classes.numel()

    avg_loss = epoch_loss / max(num_batches, 1)  # guard against an empty loader
    loss_list.append(avg_loss)
    accuracy = correct_pixels / max(total_pixels, 1)
    print(f"\tEpoch {epoch} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")


In [None]:
# Held-out evaluation: mean cross-entropy of model_0 over test_dataloader.
# (Replaces a stub that passed the builtin `input` and an undefined `target`.)
model_0.eval()
test_loss_sum = 0.0
with torch.no_grad():
    for X_test, Y_test in test_dataloader:
        logits = model_0(X_test.to(device))
        test_loss_sum += F.cross_entropy(logits, Y_test.to(device)).item()
test_loss = test_loss_sum / max(len(test_dataloader), 1)
test_loss