In [1]:
import os
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt

# For managing large outputs
import pandas as pd


In [2]:
# Assuming your images and labels are in the following directories
images_path = 'F:\\Repositories\\FLARE22Train\\images'  # Update with your actual path
labels_path = 'F:\\Repositories\\FLARE22Train\\labels'  # Update with your actual path

# List all files in the directories
image_files = sorted(os.listdir(images_path))
label_files = sorted(os.listdir(labels_path))

print(f"Total number of images: {len(image_files)}")
print(f"Total number of labels: {len(label_files)}")


Total number of images: 50
Total number of labels: 50


In [3]:
# Create lists to store the dimensions of images and labels
image_shapes = []
label_shapes = []

# Loop through all image and label files to get their dimensions
for img_file, lbl_file in zip(image_files, label_files):
    img_path = os.path.join(images_path, img_file)
    lbl_path = os.path.join(labels_path, lbl_file)

    # Load the NIfTI files
    img = nib.load(img_path).get_fdata()
    lbl = nib.load(lbl_path).get_fdata()

    # Store the shapes
    image_shapes.append(img.shape)
    label_shapes.append(lbl.shape)

# Convert to a DataFrame for better visualization
df = pd.DataFrame({
    'Image File': image_files,
    'Image Shape': image_shapes,
    'Label File': label_files,
    'Label Shape': label_shapes
})

# Display the DataFrame
df


Unnamed: 0,Image File,Image Shape,Label File,Label Shape
0,FLARE22_Tr_0001_0000.nii.gz,"(512, 512, 110)",FLARE22_Tr_0001.nii.gz,"(512, 512, 110)"
1,FLARE22_Tr_0002_0000.nii.gz,"(512, 512, 107)",FLARE22_Tr_0002.nii.gz,"(512, 512, 107)"
2,FLARE22_Tr_0003_0000.nii.gz,"(512, 512, 101)",FLARE22_Tr_0003.nii.gz,"(512, 512, 101)"
3,FLARE22_Tr_0004_0000.nii.gz,"(512, 512, 87)",FLARE22_Tr_0004.nii.gz,"(512, 512, 87)"
4,FLARE22_Tr_0005_0000.nii.gz,"(512, 512, 85)",FLARE22_Tr_0005.nii.gz,"(512, 512, 85)"
5,FLARE22_Tr_0006_0000.nii.gz,"(512, 512, 93)",FLARE22_Tr_0006.nii.gz,"(512, 512, 93)"
6,FLARE22_Tr_0007_0000.nii.gz,"(512, 512, 105)",FLARE22_Tr_0007.nii.gz,"(512, 512, 105)"
7,FLARE22_Tr_0008_0000.nii.gz,"(512, 512, 99)",FLARE22_Tr_0008.nii.gz,"(512, 512, 99)"
8,FLARE22_Tr_0009_0000.nii.gz,"(512, 512, 87)",FLARE22_Tr_0009.nii.gz,"(512, 512, 87)"
9,FLARE22_Tr_0010_0000.nii.gz,"(512, 512, 96)",FLARE22_Tr_0010.nii.gz,"(512, 512, 96)"


In [4]:
# Get the unique shapes for images and labels
unique_image_shapes = df['Image Shape'].value_counts()
unique_label_shapes = df['Label Shape'].value_counts()

print("Unique Image Shapes and Their Counts:")
print(unique_image_shapes)

print("\nUnique Label Shapes and Their Counts:")
print(unique_label_shapes)


Unique Image Shapes and Their Counts:
Image Shape
(512, 512, 93)     6
(512, 512, 98)     4
(512, 512, 101)    3
(512, 512, 87)     3
(512, 512, 85)     3
(512, 512, 99)     3
(512, 512, 95)     2
(512, 512, 103)    2
(512, 512, 100)    2
(512, 512, 107)    2
(512, 512, 84)     2
(512, 512, 89)     2
(512, 512, 113)    2
(512, 512, 102)    2
(512, 512, 109)    1
(512, 512, 79)     1
(512, 512, 83)     1
(512, 512, 94)     1
(512, 512, 110)    1
(512, 512, 104)    1
(512, 512, 92)     1
(512, 512, 71)     1
(512, 512, 91)     1
(512, 512, 96)     1
(512, 512, 105)    1
(512, 512, 108)    1
Name: count, dtype: int64

Unique Label Shapes and Their Counts:
Label Shape
(512, 512, 93)     6
(512, 512, 98)     4
(512, 512, 101)    3
(512, 512, 87)     3
(512, 512, 85)     3
(512, 512, 99)     3
(512, 512, 95)     2
(512, 512, 103)    2
(512, 512, 100)    2
(512, 512, 107)    2
(512, 512, 84)     2
(512, 512, 89)     2
(512, 512, 113)    2
(512, 512, 102)    2
(512, 512, 109)    1
(512, 512, 7

In [5]:
import SimpleITK as sitk
import os

def resample_image(image, target_shape=(128, 128, 128), is_label=False):
    # Get the original spacing and size
    original_spacing = image.GetSpacing()
    original_size = image.GetSize()

    # Compute the target spacing to achieve the desired target shape
    target_spacing = [
        (original_size[i] * original_spacing[i]) / target_shape[i] for i in range(3)
    ]

    # Resample the image
    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(target_spacing)
    resample.SetSize(target_shape)
    resample.SetOutputDirection(image.GetDirection())
    resample.SetOutputOrigin(image.GetOrigin())
    resample.SetTransform(sitk.Transform())

    if is_label:
        resample.SetInterpolator(sitk.sitkNearestNeighbor)  # Use nearest neighbor for label resampling
    else:
        resample.SetInterpolator(sitk.sitkLinear)  # Use linear interpolation for image resampling

    return resample.Execute(image)


In [6]:
# Directories to save resampled images and labels
resampled_images_path = 'F:\\Repositories\\FLARE22Train\\resampled-images' 
resampled_labels_path = 'F:\\Repositories\\FLARE22Train\\resampled-labels' 

os.makedirs(resampled_images_path, exist_ok=True)
os.makedirs(resampled_labels_path, exist_ok=True)

target_shape = (128, 128, 128)  # New uniform shape compatible with V-Net

for img_file, lbl_file in zip(image_files, label_files):
    img_path = os.path.join(images_path, img_file)
    lbl_path = os.path.join(labels_path, lbl_file)

    # Load image and label
    img = sitk.ReadImage(img_path)
    lbl = sitk.ReadImage(lbl_path)

    # Resample image and label
    img_resampled = resample_image(img, target_shape=target_shape, is_label=False)
    lbl_resampled = resample_image(lbl, target_shape=target_shape, is_label=True)

    # Save resampled images and labels
    sitk.WriteImage(img_resampled, os.path.join(resampled_images_path, img_file))
    sitk.WriteImage(lbl_resampled, os.path.join(resampled_labels_path, lbl_file))

print("Resampling and resizing complete.")


Resampling and resizing complete.


In [1]:
import os
import torch
import SimpleITK as sitk

# Assuming your images and labels are in the following directories
images_path = 'F:\\Repositories\\FLARE22Train\\images'  # Update with your actual path
labels_path = 'F:\\Repositories\\FLARE22Train\\labels'  # Update with your actual path

# Resampled directories
resampled_images_dir = 'F:\\Repositories\\FLARE22Train\\resampled_images'
resampled_labels_dir = 'F:\\Repositories\\FLARE22Train\\resampled_labels'

# Create directories if they don't exist
os.makedirs(resampled_images_dir, exist_ok=True)
os.makedirs(resampled_labels_dir, exist_ok=True)

# List all files in the directories
image_files = sorted(os.listdir(images_path))
label_files = sorted(os.listdir(labels_path))

print(f"Total number of images: {len(image_files)}")
print(f"Total number of labels: {len(label_files)}")

# Define the function to extract and remap labels
def extract_and_remap_labels(labels):
    new_labels = torch.zeros_like(labels)
    new_labels[labels == 1] = 1  # Liver
    new_labels[labels == 2] = 2  # Right Kidney
    new_labels[labels == 13] = 3  # Left Kidney
    new_labels[labels == 3] = 4  # Spleen
    return new_labels

# Function to resample images and labels
def resample_image(image, target_shape=(128, 128, 128), is_label=False):
    original_spacing = image.GetSpacing()
    original_size = image.GetSize()
    target_spacing = [
        (original_size[i] * original_spacing[i]) / target_shape[i] for i in range(3)
    ]

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(target_spacing)
    resample.SetSize(target_shape)
    resample.SetOutputDirection(image.GetDirection())
    resample.SetOutputOrigin(image.GetOrigin())
    resample.SetTransform(sitk.Transform())

    if is_label:
        resample.SetInterpolator(sitk.sitkNearestNeighbor)
    else:
        resample.SetInterpolator(sitk.sitkLinear)

    return resample.Execute(image)

# Process each image and label pair
for img_file, lbl_file in zip(image_files, label_files):
    img_path = os.path.join(images_path, img_file)
    lbl_path = os.path.join(labels_path, lbl_file)

    image = sitk.ReadImage(img_path)
    label = sitk.ReadImage(lbl_path)

    label_tensor = torch.tensor(sitk.GetArrayFromImage(label).astype(int))
    remapped_labels_tensor = extract_and_remap_labels(label_tensor)

    remapped_label_image = sitk.GetImageFromArray(remapped_labels_tensor.numpy())
    remapped_label_image.CopyInformation(label)

    resampled_image = resample_image(image, target_shape=(128, 128, 128), is_label=False)
    resampled_label = resample_image(remapped_label_image, target_shape=(128, 128, 128), is_label=True)

    sitk.WriteImage(resampled_image, os.path.join(resampled_images_dir, img_file))
    sitk.WriteImage(resampled_label, os.path.join(resampled_labels_dir, lbl_file))

print("Processing and resampling complete.")


Total number of images: 50
Total number of labels: 50
Processing and resampling complete.


In [3]:
# Assuming your images and labels are in the following directories
re_images_path ='F:\\Repositories\\FLARE22Train\\resampled_images' # Update with your actual path
re_labels_path = 'F:\\Repositories\\FLARE22Train\\resampled_labels' # Update with your actual path

# List all files in the directories
image_files = sorted(os.listdir(images_path))
label_files = sorted(os.listdir(labels_path))

print(f"Total number of images: {len(image_files)}")
print(f"Total number of labels: {len(label_files)}")


Total number of images: 50
Total number of labels: 50


In [5]:
import os
import pandas as pd
import nibabel as nib

# Assuming your images and labels are in the following directories
re_images_path = 'F:\\Repositories\\FLARE22Train\\resampled_images'  # Update with your actual path
re_labels_path = 'F:\\Repositories\\FLARE22Train\\resampled_labels'  # Update with your actual path

# List all files in the directories
image_files = sorted(os.listdir(re_images_path))
label_files = sorted(os.listdir(re_labels_path))

# Create lists to store the dimensions of images and labels
image_shapes = []
label_shapes = []

# Loop through all image and label files to get their dimensions
for img_file, lbl_file in zip(image_files, label_files):
    img_path = os.path.join(re_images_path, img_file)
    lbl_path = os.path.join(re_labels_path, lbl_file)

    # Load the NIfTI files
    img = nib.load(img_path).get_fdata()
    lbl = nib.load(lbl_path).get_fdata()

    # Store the shapes
    image_shapes.append(img.shape)
    label_shapes.append(lbl.shape)

# Convert to a DataFrame for better visualization
re_df = pd.DataFrame({
    'Image File': image_files,
    'Image Shape': image_shapes,
    'Label File': label_files,
    'Label Shape': label_shapes
})

# Display the DataFrame
print(re_df)

# Get the unique shapes for images and labels
unique_image_shapes = re_df['Image Shape'].value_counts()
unique_label_shapes = re_df['Label Shape'].value_counts()

print("\nUnique Image Shapes and Their Counts:")
print(unique_image_shapes)

print("\nUnique Label Shapes and Their Counts:")
print(unique_label_shapes)


                     Image File      Image Shape              Label File  \
0   FLARE22_Tr_0001_0000.nii.gz  (128, 128, 128)  FLARE22_Tr_0001.nii.gz   
1   FLARE22_Tr_0002_0000.nii.gz  (128, 128, 128)  FLARE22_Tr_0002.nii.gz   
2   FLARE22_Tr_0003_0000.nii.gz  (128, 128, 128)  FLARE22_Tr_0003.nii.gz   
3   FLARE22_Tr_0004_0000.nii.gz  (128, 128, 128)  FLARE22_Tr_0004.nii.gz   
4   FLARE22_Tr_0005_0000.nii.gz  (128, 128, 128)  FLARE22_Tr_0005.nii.gz   
5   FLARE22_Tr_0006_0000.nii.gz  (128, 128, 128)  FLARE22_Tr_0006.nii.gz   
6   FLARE22_Tr_0007_0000.nii.gz  (128, 128, 128)  FLARE22_Tr_0007.nii.gz   
7   FLARE22_Tr_0008_0000.nii.gz  (128, 128, 128)  FLARE22_Tr_0008.nii.gz   
8   FLARE22_Tr_0009_0000.nii.gz  (128, 128, 128)  FLARE22_Tr_0009.nii.gz   
9   FLARE22_Tr_0010_0000.nii.gz  (128, 128, 128)  FLARE22_Tr_0010.nii.gz   
10  FLARE22_Tr_0011_0000.nii.gz  (128, 128, 128)  FLARE22_Tr_0011.nii.gz   
11  FLARE22_Tr_0012_0000.nii.gz  (128, 128, 128)  FLARE22_Tr_0012.nii.gz   
12  FLARE22_

In [7]:
import numpy as np
import SimpleITK as sitk

def normalize_ct_scan(image, lower_bound=-1000, upper_bound=400):
    """
    Normalize the intensity values of a CT scan by clipping and scaling.

    Parameters:
    - image: SimpleITK image object.
    - lower_bound: Minimum HU value to clip.
    - upper_bound: Maximum HU value to clip.

    Returns:
    - Normalized SimpleITK image object.
    """
    # Convert the SimpleITK image to a NumPy array
    image_array = sitk.GetArrayFromImage(image)

    # Clip the intensity values
    image_array = np.clip(image_array, lower_bound, upper_bound)

    # Normalize the values to the range [0, 1]
    image_array = (image_array - lower_bound) / (upper_bound - lower_bound)

    # Convert back to SimpleITK image
    normalized_image = sitk.GetImageFromArray(image_array)

    # Copy the metadata from the original image
    normalized_image.CopyInformation(image)

    return normalized_image


In [9]:
# Directories where resampled data is stored
resampled_images_path ='F:\\Repositories\\FLARE22Train\\resampled_images'
normalized_images_path ='F:\\Repositories\\FLARE22Train\\normalised_images'

os.makedirs(normalized_images_path, exist_ok=True)

for img_file in image_files:
    img_path = os.path.join(resampled_images_path, img_file)

    # Load the resampled image
    img = sitk.ReadImage(img_path)

    # Normalize the image
    normalized_img = normalize_ct_scan(img)

    # Save the normalized image
    sitk.WriteImage(normalized_img, os.path.join(normalized_images_path, img_file))

print("Normalization of resampled CT scans complete.")


Normalization of resampled CT scans complete.


In [11]:
# Assuming your images and labels are in the following directories
nprm_images_path = 'F:\\Repositories\\FLARE22Train\\normalised_images' # Update with your actual path
# re_labels_path = '/content/drive/My Drive/FLARE22Train/resampled_labels/'  # Update with your actual path

# List all files in the directories
image_files = sorted(os.listdir(images_path))
# label_files = sorted(os.listdir(labels_path))

print(f"Total number of images: {len(image_files)}")
# print(f"Total number of labels: {len(label_files)}")


Total number of images: 50


In [15]:
import os
import numpy as np
from sklearn.model_selection import train_test_split

# Paths to your normalized images and labels
normalized_images_path = 'F:\\Repositories\\FLARE22Train\\normalised_images'
labels_path = 'F:\\Repositories\\FLARE22Train\\resampled_labels'

# List all files in the directories
image_files = sorted(os.listdir(normalized_images_path))
label_files = sorted(os.listdir(labels_path))

# Print filenames to identify mismatches
for img_file, lbl_file in zip(image_files, label_files):
    # Remove the '_0000' suffix from image filenames
    expected_lbl_file = img_file.rsplit('_', 1)[0] + '.nii.gz'
    if expected_lbl_file != lbl_file:
        print(f"Mismatch: {img_file} -> {expected_lbl_file} (expected) vs {lbl_file} (actual)")

# Ensure image files match label files
assert len(image_files) == len(label_files)
assert all([img_file.rsplit('_', 1)[0] + '.nii.gz' == lbl_file for img_file, lbl_file in zip(image_files, label_files)])

# Split the dataset into training (70%), validation (15%), and test (15%) sets
train_imgs, test_imgs, train_lbls, test_lbls = train_test_split(image_files, label_files, test_size=0.3, random_state=42)
val_imgs, test_imgs, val_lbls, test_lbls = train_test_split(test_imgs, test_imgs, test_size=0.5, random_state=42)

print(f"Training set: {len(train_imgs)} images")
print(f"Validation set: {len(val_imgs)} images")
print(f"Test set: {len(test_imgs)} images")


Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md



Training set: 35 images
Validation set: 7 images
Test set: 8 images


In [17]:
import torch
from torch.utils.data import Dataset, DataLoader
import SimpleITK as sitk

class CTScanDataset(Dataset):
    def __init__(self, image_files, label_files, image_dir, label_dir, transform=None):
        self.image_files = image_files
        self.label_files = label_files
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load image and label
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        lbl_path = os.path.join(self.label_dir, self.label_files[idx])

        image = sitk.ReadImage(img_path)
        label = sitk.ReadImage(lbl_path)

        # Convert to numpy arrays
        image = sitk.GetArrayFromImage(image).astype(np.float32)
        label = sitk.GetArrayFromImage(label).astype(np.int64)

        # Optional: Apply additional transformations here
        if self.transform:
            # Add your transform logic here
            image, label = self.transform(image, label)

        # Convert to PyTorch tensors
        image = torch.from_numpy(image).unsqueeze(0)  # Add channel dimension
        label = torch.from_numpy(label)

        return image, label

# Example usage
# Define directories and file lists
train_dataset = CTScanDataset(train_imgs, train_lbls, normalized_images_path, labels_path)
val_dataset = CTScanDataset(val_imgs, val_lbls, normalized_images_path, labels_path)
test_dataset = CTScanDataset(test_imgs, test_lbls, normalized_images_path, labels_path)


In [23]:
import torch
from torch.utils.data import DataLoader

# Parameters
batch_size = 1  # Choose based on your CPU memory capacity
num_workers = 0  # Set to 0 for debugging (loads data in the main process)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Example: Checking a batch from the training DataLoader
for i, (images, labels) in enumerate(train_loader):
    try:
        # Ensure tensors are on the CPU
        images = images.cpu()
        labels = labels.cpu()
        
        print(f"Batch {i+1}:")
        print(f"Image batch shape: {images.shape}")
        print(f"Label batch shape: {labels.shape}")
        break  # Only check the first batch
    except Exception as e:
        print(f"Error at batch {i+1}: {e}")


Batch 1:
Image batch shape: torch.Size([1, 1, 128, 128, 128])
Label batch shape: torch.Size([1, 128, 128, 128])


In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VNet(nn.Module):
    def __init__(self, num_classes=4):
        super(VNet, self).__init__()

        # Define the encoder (downsampling path)
        self.enc1 = self.conv_block(1, 16)
        self.enc2 = self.conv_block(16, 32)
        self.enc3 = self.conv_block(32, 64)
        self.enc4 = self.conv_block(64, 128)

        # Bottleneck
        self.bottleneck = self.conv_block(128, 256)

        # Define the decoder (upsampling path)
        self.upconv4 = self.upconv(256, 128)
        self.dec4 = self.conv_block(256, 128)
        self.upconv3 = self.upconv(128, 64)
        self.dec3 = self.conv_block(128, 64)
        self.upconv2 = self.upconv(64, 32)
        self.dec2 = self.conv_block(64, 32)
        self.upconv1 = self.upconv(32, 16)
        self.dec1 = self.conv_block(32, 16)

        # Output layer
        self.out_conv = nn.Conv3d(16, num_classes, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def upconv(self, in_channels, out_channels):
        return nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(F.max_pool3d(e1, 2))
        e3 = self.enc3(F.max_pool3d(e2, 2))
        e4 = self.enc4(F.max_pool3d(e3, 2))

        # Bottleneck
        b = self.bottleneck(F.max_pool3d(e4, 2))

        # Decoder
        d4 = self.upconv4(b)
        d4 = torch.cat((d4, e4), dim=1)
        d4 = self.dec4(d4)
        d3 = self.upconv3(d4)
        d3 = torch.cat((d3, e3), dim=1)
        d3 = self.dec3(d3)
        d2 = self.upconv2(d3)
        d2 = torch.cat((d2, e2), dim=1)
        d2 = self.dec2(d2)
        d1 = self.upconv1(d2)
        d1 = torch.cat((d1, e1), dim=1)
        d1 = self.dec1(d1)

        # Output
        out = self.out_conv(d1)
        return out

# Check if CUDA is available; if not, use CPU
device = torch.device("cpu")

# Initialize the model and move it to the appropriate device
model = VNet(num_classes=4).to(device)

# Example of model summary (optional)
print(model)

# Assuming you have some input data
input_data = torch.randn(1, 1, 64, 64, 64).to(device)  # Example input tensor

# Forward pass
output = model(input_data)

print(f"Output shape: {output.shape}")


VNet(
  (enc1): Sequential(
    (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (4): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (enc2): Sequential(
    (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (enc3): Sequential(
    (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(64, eps=1e-0

In [41]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F

# Define the VNet model
class VNet(nn.Module):
    def __init__(self, num_classes=4):
        super(VNet, self).__init__()

        # Define the encoder (downsampling path)
        self.enc1 = self.conv_block(1, 16)
        self.enc2 = self.conv_block(16, 32)
        self.enc3 = self.conv_block(32, 64)
        self.enc4 = self.conv_block(64, 128)

        # Bottleneck
        self.bottleneck = self.conv_block(128, 256)

        # Define the decoder (upsampling path)
        self.upconv4 = self.upconv(256, 128)
        self.dec4 = self.conv_block(256, 128)
        self.upconv3 = self.upconv(128, 64)
        self.dec3 = self.conv_block(128, 64)
        self.upconv2 = self.upconv(64, 32)
        self.dec2 = self.conv_block(64, 32)
        self.upconv1 = self.upconv(32, 16)
        self.dec1 = self.conv_block(32, 16)

        # Output layer
        self.out_conv = nn.Conv3d(16, num_classes, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def upconv(self, in_channels, out_channels):
        return nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(F.max_pool3d(e1, 2))
        e3 = self.enc3(F.max_pool3d(e2, 2))
        e4 = self.enc4(F.max_pool3d(e3, 2))

        # Bottleneck
        b = self.bottleneck(F.max_pool3d(e4, 2))

        # Decoder
        d4 = self.upconv4(b)
        d4 = torch.cat((d4, e4), dim=1)
        d4 = self.dec4(d4)
        d3 = self.upconv3(d4)
        d3 = torch.cat((d3, e3), dim=1)
        d3 = self.dec3(d3)
        d2 = self.upconv2(d3)
        d2 = torch.cat((d2, e2), dim=1)
        d2 = self.dec2(d2)
        d1 = self.upconv1(d2)
        d1 = torch.cat((d1, e1), dim=1)
        d1 = self.dec1(d1)

        # Output
        out = self.out_conv(d1)
        return out

# Define the Dice Loss
class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-5):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, outputs, targets):
        outputs = torch.softmax(outputs, dim=1)
        targets_one_hot = F.one_hot(targets, num_classes=outputs.shape[1]).permute(0, 4, 1, 2, 3).float()
        
        intersection = (outputs * targets_one_hot).sum(dim=(2, 3, 4))
        union = outputs.sum(dim=(2, 3, 4)) + targets_one_hot.sum(dim=(2, 3, 4))
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

# Initialize the model
model = VNet(num_classes=4)

# Define the Cross-Entropy Loss and Dice Loss
cross_entropy_loss = nn.CrossEntropyLoss()
dice_loss = DiceLoss()

# Initialize the optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Learning rate scheduler to reduce the learning rate when a metric has stopped improving
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Number of epochs to train
num_epochs = 50

# Lists to store training and validation losses
train_loss_list = []
val_loss_list = []

# Best validation loss for checkpointing
best_val_loss = float('inf')

def check_and_adjust_labels(labels, num_classes):
    min_label = labels.min().item()
    max_label = labels.max().item()

    if min_label < 0 or max_label >= num_classes:
        print(f"Original label range: min {min_label}, max {max_label}")

        # Adjust labels if necessary
        labels = torch.clamp(labels, min=0, max=num_classes - 1)
        print(f"Adjusted label range: min {labels.min().item()}, max {labels.max().item()}")
    
    return labels

# Training loop
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.cpu(), labels.cpu()  # Move to CPU

        # Check and adjust labels
        labels = check_and_adjust_labels(labels, num_classes=4)
        
        # Forward pass
        outputs = model(images)

        # Compute the losses
        try:
            loss_dice = dice_loss(outputs, labels)
            loss_ce = cross_entropy_loss(outputs, labels)
            loss = loss_dice + loss_ce
        except RuntimeError as e:
            print(f"RuntimeError: {e}")
            print(f"Output shape: {outputs.shape}")
            print(f"Label shape: {labels.shape}")
            raise e

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    train_loss_list.append(epoch_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}")

    # Validation step
    model.eval()  # Set the model to evaluation mode
    val_loss = 0.0
    with torch.no_grad():
        for val_images, val_labels in val_loader:
            val_images, val_labels = val_images.cpu(), val_labels.cpu()  # Move to CPU

            # Check and adjust labels
            val_labels = check_and_adjust_labels(val_labels, num_classes=4)
            
            val_outputs = model(val_images)

            val_loss_dice = dice_loss(val_outputs, val_labels)
            val_loss_ce = cross_entropy_loss(val_outputs, val_labels)
            val_loss += (val_loss_dice + val_loss_ce).item()

    val_loss /= len(val_loader)
    val_loss_list.append(val_loss)
    print(f"Validation Loss: {val_loss:.4f}")

    # Update the learning rate scheduler
    scheduler.step(val_loss)


Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range: min 0, max 4
Adjusted label range: min 0, max 3
Original label range

RuntimeError: Exception thrown in SimpleITK ImageFileReader_Execute: D:\a\1\sitk\Code\IO\src\sitkImageReaderBase.cxx:97:
sitk::ERROR: The file "F:\Repositories\FLARE22Train\resampled_labels\FLARE22_Tr_0004_0000.nii.gz" does not exist.