<a href="https://colab.research.google.com/github/sajidcsecu/radioGenomic/blob/main/UnetinGPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip uninstall pydicom

Found existing installation: pydicom 3.0.1
Uninstalling pydicom-3.0.1:
  Would remove:
    /usr/local/bin/pydicom
    /usr/local/lib/python3.11/dist-packages/pydicom-3.0.1.dist-info/*
    /usr/local/lib/python3.11/dist-packages/pydicom/*
Proceed (Y/n)? y
  Successfully uninstalled pydicom-3.0.1


In [4]:
!pip install pydicom==2.4.3

Collecting pydicom==2.4.3
  Downloading pydicom-2.4.3-py3-none-any.whl.metadata (7.8 kB)
Downloading pydicom-2.4.3-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pydicom
Successfully installed pydicom-2.4.3


In [5]:
import pydicom
# Show which pydicom version the kernel actually imported. If this differs from the
# version pinned in the install cell above, restart the runtime so the newly
# installed package is the one that gets imported.
print(pydicom.__version__)

2.4.2


In [6]:
!pip install --upgrade pydicom-seg



In [1]:
# Sanity check that pydicom and its private `_storage_sopclass_uids` module import
# cleanly with the installed pydicom version (the module is imported again in the
# main import cell; presumably pydicom-seg relies on it — confirm).
import pydicom
import pydicom._storage_sopclass_uids
print("Fixed!")

Fixed!


In [3]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [4]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import pydicom
import pydicom_seg
import pydicom._storage_sopclass_uids
from torch.utils.data import Dataset,DataLoader
import os
from sklearn.model_selection import train_test_split
import torchvision.transforms as transforms
import SimpleITK as sitk
from torchinfo import summary
import torch.optim.lr_scheduler as lr_scheduler
from torch.cuda.amp import autocast, GradScaler
import csv
import time
import random
from tqdm import tqdm
from operator import add
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score
import cv2
# Turn off SimpleITK's global warning display so ITK warnings are not printed
# while reading image series and segmentations.
sitk.ProcessObject_SetGlobalWarningDisplay(False)

# 1. Data Preparation

In [None]:

class PatientDataset2DUNet(Dataset):
    """2D slice dataset built from per-patient CT series and DICOM-SEG masks.

    Every item is one axial slice ``(image, mask)``. Before ``transform`` is
    applied, both are float32 tensors of shape [1, H, W]. The image is min-max
    scaled to [0, 1] per slice.
    """

    def __init__(self, patients, metadata, transform = None, train=True):
        """
        Args:
            patients (list): Patient IDs, matched against ``metadata['Subject ID']``.
            metadata (DataFrame): Metadata with 'Subject ID', 'Modality' and
                'File Location' columns.
            transform (callable, optional): Applied separately to the image
                tensor and to the mask tensor in ``__getitem__``.
            train (bool): If True, keeps only slices whose mask has at least one
                non-zero pixel; otherwise keeps every slice.
        """
        self.patients = patients
        self.metadata = metadata
        self.transform = transform
        self.train = train
        # One (ct_dir, seg_dir, slice_index) tuple per dataset item.
        self.slices = self._extract_slices()

    def get_path(self, subject, modality):
        """Return the first 'File Location' among ``subject`` rows with ``modality``, or None."""
        subject_filtered = subject[subject['Modality'] == modality]
        return subject_filtered['File Location'].iloc[0] if not subject_filtered.empty else None

    @staticmethod
    def _normalize(image):
        """Min-max scale ``image`` to [0, 1]; a constant image maps to all zeros.

        The divisor is the value range (max - min), not the max. Dividing by the
        max does not bound the result for data with negative values such as CT
        Hounsfield units.
        """
        lo = image.min()
        value_range = image.max() - lo
        return (image - lo) / max(value_range, 1e-6)  # guard against divide-by-zero

    def _extract_slices(self):
        """Read each patient's CT and GTV-1 mask once and index the usable slices."""
        slices = []
        for patient in self.patients:
            print(f"Processing Patient: {patient}")
            subject = self.metadata[self.metadata['Subject ID'] == patient]

            img_path = self.get_path(subject, "CT")
            msk_path = self.get_path(subject, "SEG")

            if img_path and msk_path:
                img = self.read_ct_array(img_path)
                msk = self.read_seg_array(msk_path, "GTV-1")

                if img is not None and msk is not None:
                    image = sitk.GetArrayFromImage(img).astype(np.float32)
                    mask = sitk.GetArrayFromImage(msk).astype(np.float32)

                    # CT and SEG may hold a different number of slices; only index the overlap.
                    min_slices = min(image.shape[0], mask.shape[0])
                    image, mask = image[:min_slices], mask[:min_slices]

                    # Training keeps only slices containing foreground; evaluation keeps all.
                    slice_indices = np.arange(min_slices) if not self.train else np.where(np.any(mask, axis=(1, 2)))[0]

                    slices.extend((img_path, msk_path, int(i)) for i in slice_indices)

        return slices

    def __len__(self):
        return len(self.slices)

    def read_ct_array(self, path):
        """Read the DICOM series in directory ``path`` as a SimpleITK image (via GDCM)."""
        reader = sitk.ImageSeriesReader()
        reader.SetImageIO("GDCMImageIO")
        reader.SetFileNames(reader.GetGDCMSeriesFileNames(path))
        return reader.Execute()

    def read_seg_array(self, path, seg_type="GTV-1"):
        """Return the segment whose SegmentDescription equals ``seg_type`` from ``path/1-1.dcm``.

        Returns None (after printing the error) if the file cannot be read or has
        no matching segment.
        """
        try:
            segmentation = pydicom.dcmread(os.path.join(path, '1-1.dcm'))
            seg_df = pd.DataFrame({f: [s[f].value for s in segmentation.SegmentSequence] for f in ['SegmentNumber', 'SegmentDescription']})
            seg_number = seg_df.loc[seg_df['SegmentDescription'] == seg_type, 'SegmentNumber'].iloc[0]
            return pydicom_seg.SegmentReader().read(segmentation).segment_image(seg_number)
        except Exception as e:
            print(f"Error reading segmentation from {path}: {e}")
            return None

    def __getitem__(self, idx):
        img_path, mask_path, slice_idx = self.slices[idx]

        # NOTE: the full CT series and SEG file are re-read for every slice. This is
        # simple but I/O heavy; caching per patient would trade memory for speed.
        img = self.read_ct_array(img_path)
        msk = self.read_seg_array(mask_path, "GTV-1")

        if img is None or msk is None:
            print(f"Skipping invalid file: {img_path}, {mask_path}")
            return torch.zeros(1, 512, 512), torch.zeros(1, 512, 512)  # placeholder pair

        # Select the slice first, then cast only that slice. Shape: (H, W)
        image = sitk.GetArrayFromImage(img)[slice_idx].astype(np.float32)
        mask = sitk.GetArrayFromImage(msk)[slice_idx].astype(np.float32)

        image = self._normalize(image)

        # Add the channel dimension (C=1). Shape: [1, H, W]
        image, mask = (torch.from_numpy(arr).unsqueeze(0) for arr in (image, mask))
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask


if __name__ == "__main__":
    batch_size = 8

    # Output locations (OS-independent paths).
    output_dir = os.path.join(".", "Segmentation", "files")
    os.makedirs(output_dir, exist_ok=True)
    loss_result_file = os.path.join(output_dir, "results.csv")
    model_file = os.path.join(output_dir, "checkpoint.pth")
    test_result_path = os.path.join(output_dir, "results")

    # Load metadata.
    try:
        metadata_lung1 = pd.read_csv(os.path.join(".", "metadata", "metadata_lung1.csv"), sep=',', index_col=False)
    except FileNotFoundError as err:
        print("❌ ERROR: metadata_lung1.csv not found. Please check the file path.")
        # Raise instead of calling exit(): in a notebook kernel exit() does not stop
        # the running cell, so the code below would run without metadata_lung1.
        raise SystemExit(1) from err

    # One entry per patient.
    patient_list_lung1 = metadata_lung1["Subject ID"].unique().tolist()

    # Patients excluded from the run (listed as error cases).
    error_patients = ['LUNG1-128']
    patient_list_lung1 = [p for p in patient_list_lung1 if p not in error_patients]

    # 10% of patients go to validation, then 10% of the remainder to test.
    train_patient, valid_patient = train_test_split(patient_list_lung1, test_size=0.1, random_state=42)
    train_patient, test_patient = train_test_split(train_patient, test_size=0.1, random_state=42)

    print(f"Patients: Train={len(train_patient)}, Valid={len(valid_patient)}, Test={len(test_patient)}")

    # Build the slice datasets.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Loading Training Data...")
    train_dataset = PatientDataset2DUNet(train_patient, metadata_lung1, train=True)
    print("Loading Validation Data...")
    valid_dataset = PatientDataset2DUNet(valid_patient, metadata_lung1, train=False)
    print("Loading Test Data...")
    test_dataset = PatientDataset2DUNet(test_patient, metadata_lung1, train=False)

    # os.cpu_count() can return None. Pinned memory only helps host-to-GPU copies.
    num_workers = min(4, (os.cpu_count() or 1) // 2)
    pin_memory = device == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    print(f"Train: {len(train_dataset)}, Valid: {len(valid_dataset)}, Test: {len(test_dataset)}")

tensor([0.8994, 0.1209, 0.7563, 0.2896, 0.0101], dtype=torch.float64)

# 2. U-Net

In [None]:
class conv_block(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 keeps H and W unchanged."""

    def __init__(self, in_c, out_c):
        super().__init__()

        # Stage 1 maps in_c -> out_c channels.
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)

        # Stage 2 keeps out_c channels.
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)

        # A single parameter-free ReLU shared by both stages.
        self.relu = nn.ReLU()

    def forward(self, inputs):
        features = inputs
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            features = self.relu(norm(conv(features)))
        return features


class encoder_block(nn.Module):
    """U-Net contracting step: a conv_block followed by 2x2 max pooling.

    ``forward`` returns ``(skip, pooled)``. ``skip`` holds the full-resolution
    features kept for the decoder skip connection. ``pooled`` is the
    downsampled tensor passed to the next encoder level.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        skip = self.conv(inputs)
        return skip, self.pool(skip)


class decoder_block(nn.Module):
    """U-Net expanding step: 2x transposed-conv upsampling, skip concatenation, conv_block.

    Args:
        in_c: channels of the incoming (coarser) feature map.
        out_c: channels after upsampling. The skip tensor must also carry
            ``out_c`` channels, because the conv_block takes ``out_c + out_c``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(out_c + out_c, out_c)

    def forward(self, inputs, skip):
        x = self.up(inputs)
        # Max pooling floors odd sizes, so the 2x-upsampled map can be smaller than
        # the skip tensor. Zero-pad it to the skip's H/W so concatenation works for
        # any input resolution. When the shapes already match, this is a no-op.
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = nn.functional.pad(
                x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
            )
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)


class UNet(nn.Module):
    """Four-level 2D U-Net: 64-128-256-512 encoder channels and a 1024-channel bottleneck.

    Args:
        input_channel: number of input image channels.
        output_channel: number of channels produced by the final 1x1 conv. These
            are raw scores; no activation is applied in ``forward``.
        dropout: if > 0, dropout with this probability is applied to the
            bottleneck features; 0 disables it.
    """

    def __init__(self, input_channel, output_channel, dropout=0.0):
        super().__init__()

        # Encoder: the channel width doubles at every level.
        self.e1 = encoder_block(input_channel, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)

        # Bottleneck at the coarsest resolution.
        self.b = conv_block(512, 1024)

        # Decoder: mirrors the encoder, halving the width at every level.
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)

        # 1x1 projection to the requested number of output channels.
        self.outputs = nn.Conv2d(64, output_channel, kernel_size=1, padding=0)

        # Left as None when dropout is disabled, so forward skips it.
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None

    def forward(self, inputs):
        skips = []
        x = inputs
        for encoder in (self.e1, self.e2, self.e3, self.e4):
            skip, x = encoder(x)
            skips.append(skip)

        x = self.b(x)
        if self.dropout:
            x = self.dropout(x)

        # Deepest skip pairs with the first decoder.
        for decoder, skip in zip((self.d1, self.d2, self.d3, self.d4), reversed(skips)):
            x = decoder(x, skip)

        return self.outputs(x)


if __name__ == "__main__":
    # Smoke test: a (1, 1, 512, 512) input should give a (1, 1, 512, 512) output.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_image = torch.randn((1, 1, 512, 512), dtype=torch.float32)
    model = UNet(1, 1).to(device)
    input_image = input_image.to(device)
    # Only the output shape is needed, so skip autograd bookkeeping. Otherwise every
    # intermediate activation of the 512x512 forward pass is kept for backward.
    with torch.no_grad():
        out = model(input_image)
    print(out.shape)

## 2.1 Loss Function

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation, computed over all elements of the batch.

    ``forward`` applies a sigmoid to ``preds``, so pass raw logits.
    """
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        """
        preds: raw model outputs (logits); a sigmoid is applied below.
        targets: ground-truth binary masks with the same number of elements as preds.
        Returns a scalar tensor: 1 - soft Dice coefficient.
        """
        # Remove this sigmoid if the model already ends with a sigmoid (or equivalent) activation.
        probs = torch.sigmoid(preds)

        # reshape (not view) so non-contiguous tensors, e.g. after a transpose,
        # are flattened instead of raising a RuntimeError.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice_score = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)

        return 1 - dice_score


class DiceBCELoss(nn.Module):
    """Sum of BCE-with-logits and soft Dice loss for binary segmentation.

    ``preds`` must be raw logits. BCEWithLogitsLoss applies the sigmoid internally
    in a numerically stable way; the Dice term applies it explicitly.
    """
    def __init__(self, smooth=1e-6):
        super(DiceBCELoss, self).__init__()
        self.smooth = smooth
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, preds, targets):
        """Return ``BCEWithLogits(preds, targets) + (1 - soft Dice)`` as a scalar tensor."""
        # reshape (not view) also accepts non-contiguous tensors.
        logits = preds.reshape(-1)
        targets = targets.reshape(-1)

        # Dice term on probabilities; the sigmoid is computed once and reused.
        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum()
        dice_loss = 1 - (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)

        # BCE term on the logits themselves (sigmoid must NOT be applied here).
        bce_loss = self.bce(logits, targets)

        return bce_loss + dice_loss


torch.Size([1, 1, 512, 512])


# 3. Test

In [None]:
class UnetTest:
    """Evaluates a binary-segmentation model on a test loader.

    Prints averaged pixel metrics and FPS, and saves side-by-side PNGs
    (image | ground truth | prediction).
    """

    @staticmethod
    def _to_binary_flat(t, threshold=0.5):
        """Move ``t`` to NumPy, threshold it and flatten to a 1D bool array."""
        return (t.detach().cpu().numpy() > threshold).reshape(-1)

    def calculate_metrics(self, y_true, y_pred):
        """Return [Jaccard, F1, Recall, Precision, Accuracy] for one batch.

        y_true: ground-truth mask tensor. y_pred: predicted probability tensor.
        Both are binarized at 0.5 and flattened, so the scores are pixel-wise
        over the whole batch.
        """
        # Threshold the float values directly. Casting probabilities to uint8
        # first would truncate everything below 1.0 to 0 (all background).
        y_true = self._to_binary_flat(y_true)
        y_pred = self._to_binary_flat(y_pred)

        # zero_division=1 scores undefined ratios (e.g. both masks empty) as 1.
        score_jaccard = jaccard_score(y_true, y_pred, zero_division=1)
        score_f1 = f1_score(y_true, y_pred, zero_division=1)
        score_recall = recall_score(y_true, y_pred, zero_division=1)
        score_precision = precision_score(y_true, y_pred, zero_division=1)
        score_acc = accuracy_score(y_true, y_pred)

        return [score_jaccard, score_f1, score_recall, score_precision, score_acc]

    def save_result(self, test_result_path, image, org_mask, predicted_mask, sample_id):
        """Write ``sample_<sample_id>.png`` to ``test_result_path``.

        The PNG shows image | ground truth | prediction separated by grey bars.
        ``image`` values are expected in [0, 1]; both masks are binarized at 0.5.
        """
        os.makedirs(test_result_path, exist_ok=True)

        predicted_mask = predicted_mask.cpu().numpy().squeeze()
        org_mask = org_mask.cpu().numpy().squeeze()
        image = image.cpu().numpy().squeeze()

        predicted_mask = (predicted_mask > 0.5).astype(np.uint8) * 255
        org_mask = (org_mask > 0.5).astype(np.uint8) * 255
        # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
        image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)

        h = image.shape[0]
        separator = np.full((h, 10), 128, dtype=np.uint8)

        cat_images = np.concatenate([image, separator, org_mask, separator, predicted_mask], axis=1)

        file_name = os.path.join(test_result_path, f"sample_{sample_id}.png")
        success = cv2.imwrite(file_name, cat_images)

        if success:
            print(f"✅ Saved: {file_name}")
        else:
            print(f"❌ Failed to save image: {file_name}")

    def test(self, model, test_loader, test_result_path, device):
        """Evaluate ``model`` on ``test_loader``, print metrics and FPS, and save per-sample PNGs.

        Metrics are computed per batch (pixel-wise over the batch) and averaged
        over batches. FPS is images per second of model forward time.

        Returns:
            np.ndarray of averaged [Jaccard, F1, Recall, Precision, Accuracy],
            or None when the loader yields no batches.
        """
        metrics_score = np.zeros(5)  # [Jaccard, F1, Recall, Precision, Accuracy]
        time_taken = []
        num_images = 0
        model.eval()

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device, dtype=torch.float32)
                y = y.to(device, dtype=torch.float32)

                # CUDA launches are asynchronous. Synchronize around the forward pass
                # so the timer measures the computation itself.
                if x.is_cuda:
                    torch.cuda.synchronize()
                start_time = time.perf_counter()
                y_pred = torch.sigmoid(model(x))  # logits -> probabilities
                if x.is_cuda:
                    torch.cuda.synchronize()
                time_taken.append(time.perf_counter() - start_time)

                metrics_score += np.asarray(self.calculate_metrics(y, y_pred), dtype=float)

                # Use a running counter for file names. Deriving the id from
                # batch_index * current_batch_size collides when the last batch is
                # smaller, which overwrites earlier images.
                for idx in range(x.size(0)):
                    self.save_result(test_result_path, x[idx], y[idx], y_pred[idx], num_images)
                    num_images += 1

        num_batches = len(time_taken)
        if num_batches == 0:
            print("Total Images in Test Set: 0")
            return None

        avg_metrics = metrics_score / num_batches
        print(f"Total Images in Test Set: {num_images}")
        print(f"Jaccard: {avg_metrics[0]:.4f} - F1: {avg_metrics[1]:.4f} - Recall: {avg_metrics[2]:.4f} - Precision: {avg_metrics[3]:.4f} - Acc: {avg_metrics[4]:.4f}")

        total_time = sum(time_taken)
        fps = num_images / total_time if total_time > 0 else float("inf")
        print("FPS:", fps)
        return avg_metrics


# 4. Training

In [None]:
class EarlyStopping:
    """Signals when the validation loss has stopped improving.

    A loss counts as an improvement only if it beats the best loss seen so far
    by more than ``min_delta``. An improvement resets the counter. Calling the
    instance returns True once ``patience`` consecutive calls have passed
    without an improvement.
    """

    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0

    def __call__(self, val_loss):
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss, self.counter = val_loss, 0
        else:
            self.counter += 1

        should_stop = self.counter >= self.patience
        if should_stop:
            print(f"⛔ Early stopping triggered after {self.patience} epochs without improvement!")
        return should_stop

class UnetTrain:
    """Seeding, per-epoch train/validation loops and the full training driver for the UNet."""

    def seeding(self, seed):
        """Seed Python, NumPy and PyTorch RNGs and configure cuDNN for reproducible runs."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Seed every visible GPU, not just the current one.
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # cuDNN autotuning may pick different kernels from run to run.
        torch.backends.cudnn.benchmark = False

    def epoch_time(self, start_time, end_time):
        """Split the elapsed seconds between the two timestamps into (minutes, seconds), both truncated to int."""
        elapsed_time = end_time - start_time
        elapsed_mins = int(elapsed_time / 60)
        elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
        return elapsed_mins, elapsed_secs

    def train(self, model, loader, optimizer, loss_fn, device):
        """Run one training epoch over ``loader`` and return the mean per-batch loss."""
        epoch_loss = 0.0
        model.train()
        for x, y in loader:
            x = x.to(device, dtype=torch.float32)
            y = y.to(device, dtype=torch.float32)

            optimizer.zero_grad()
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        return epoch_loss / len(loader)

    def evaluate(self, model, loader, loss_fn, device):
        """Return the mean per-batch loss over ``loader`` (eval mode, no gradients)."""
        epoch_loss = 0.0
        model.eval()
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device, dtype=torch.float32)
                y = y.to(device, dtype=torch.float32)
                y_pred = model(x)
                loss = loss_fn(y_pred, y)
                epoch_loss += loss.item()

        return epoch_loss / len(loader)

    def execute(self, num_epochs, lr, train_loader, valid_loader, model_file, loss_result_path, device):
        """Train a fresh UNet and checkpoint the weights with the lowest validation loss.

        Args:
            num_epochs: maximum number of epochs; early stopping may end sooner.
            lr: initial AdamW learning rate.
            train_loader, valid_loader: loaders yielding (image, mask) batches.
            model_file: path the best checkpoint dict is written to with torch.save.
            loss_result_path: CSV path. One row per key ("train_loss",
                "valid_loss"), each holding the list of per-epoch losses.
            device: device the model and batches are moved to.

        Returns:
            dict with the per-epoch "train_loss" and "valid_loss" lists.
        """
        model = UNet(input_channel=1, output_channel=1, dropout=0.3).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr, weight_decay=1e-5)
        # Cosine warm restarts. scheduler.step() is called once per epoch, so T_0=10
        # means a first cycle of 10 epochs, and each later cycle doubles (T_mult=2).
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
        loss_fn = DiceBCELoss()

        early_stopping = EarlyStopping(patience=10, min_delta=0.001)
        best_valid_loss = float("inf")
        results = {"train_loss": [], "valid_loss": []}

        for epoch in tqdm(range(num_epochs)):
            start_time = time.time()

            train_loss = self.train(model, train_loader, optimizer, loss_fn, device)
            valid_loss = self.evaluate(model, valid_loader, loss_fn, device)
            # Advance the LR schedule. Without this call the scheduler never runs
            # and the learning rate stays at ``lr`` for the whole run.
            scheduler.step()

            if valid_loss < best_valid_loss:
                print(f"✅ Valid loss improved from {best_valid_loss:.4f} to {valid_loss:.4f}. Saving checkpoint.")
                best_valid_loss = valid_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': best_valid_loss,
                }, model_file)

            end_time = time.time()
            epoch_mins, epoch_secs = self.epoch_time(start_time, end_time)

            results["train_loss"].append(train_loss)
            results["valid_loss"].append(valid_loss)
            print(f"Epoch {epoch+1}: Time: {epoch_mins}m {epoch_secs}s, Train Loss: {train_loss:.3f}, Val Loss: {valid_loss:.3f}")

            if early_stopping(valid_loss):
                print("🛑 Stopping training early due to no improvement.")
                break

        with open(loss_result_path, "w", newline="") as file:
            writer = csv.writer(file)
            for key, value in results.items():
                writer.writerow([key, value])

        return results

if __name__ == "__main__":
    ut = UnetTrain()

    # Seed every RNG for reproducibility.
    ut.seeding(42)

    # Hyperparameters.
    batch_size = 8
    num_epochs = 1
    lr = 1e-4
    # Only the first MAX_PATIENTS subjects are used (quick debug run); set to None for all.
    MAX_PATIENTS = 10

    # Output locations (OS-independent paths).
    output_dir = os.path.join(".", "Segmentation", "files")
    os.makedirs(output_dir, exist_ok=True)
    loss_result_file = os.path.join(output_dir, "results.csv")
    model_file = os.path.join(output_dir, "checkpoint.pth")
    test_result_path = os.path.join(output_dir, "results")

    # Load metadata.
    try:
        metadata_lung1 = pd.read_csv(os.path.join(".", "metadata", "metadata_lung1.csv"), sep=',', index_col=False)
    except FileNotFoundError as err:
        print("❌ ERROR: metadata_lung1.csv not found. Please check the file path.")
        # Raise instead of calling exit(): in a notebook kernel exit() does not stop
        # the running cell, so the code below would run without metadata_lung1.
        raise SystemExit(1) from err

    patient_list_lung1 = metadata_lung1["Subject ID"].unique().tolist()[:MAX_PATIENTS]

    # Patients excluded from the run (listed as error cases).
    error_patients = ['LUNG1-128']
    patient_list_lung1 = [p for p in patient_list_lung1 if p not in error_patients]

    # 10% of patients go to validation, then 10% of the remainder to test.
    train_patient, valid_patient = train_test_split(patient_list_lung1, test_size=0.1, random_state=42)
    train_patient, test_patient = train_test_split(train_patient, test_size=0.1, random_state=42)

    print(f"Patients: Train={len(train_patient)}, Valid={len(valid_patient)}, Test={len(test_patient)}")

    # Build the slice datasets.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Loading Training Data...")
    train_dataset = PatientDataset2DUNet(train_patient, metadata_lung1, train=True)
    print("Loading Validation Data...")
    valid_dataset = PatientDataset2DUNet(valid_patient, metadata_lung1, train=False)
    print("Loading Test Data...")
    test_dataset = PatientDataset2DUNet(test_patient, metadata_lung1, train=False)

    # os.cpu_count() can return None. Pinned memory only helps host-to-GPU copies.
    num_workers = min(4, (os.cpu_count() or 1) // 2)
    pin_memory = device == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    print(f"Train: {len(train_dataset)}, Valid: {len(valid_dataset)}, Test: {len(test_dataset)}")

    # Train; the best-validation-loss checkpoint is written to model_file.
    ut.execute(num_epochs, lr, train_loader, valid_loader, model_file, loss_result_file, device)

    # Reload the best checkpoint for testing. NOTE(review): if no epoch ever improved
    # (e.g. NaN losses), this loads whatever checkpoint an earlier run left behind.
    model = UNet(1, 1).to(device)
    checkpoint = torch.load(model_file, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])

    print(summary(model, input_size=(1, 1, 512, 512)))

    utest = UnetTest()
    utest.test(model, test_loader, test_result_path, device)


## Data Preparation (for simulated data)

In [None]:
import pandas as pd
import numpy as np
from torch.utils.data import Dataset,DataLoader
import torch
import os
from sklearn.model_selection import train_test_split
import torchvision.transforms as transforms

class PatientDataset2DUNet(Dataset):
    """In-memory 2D segmentation dataset over NumPy arrays (used with the simulated data).

    NOTE: this reuses the name of the DICOM-backed ``PatientDataset2DUNet``
    defined earlier in the notebook. Once this cell runs, that name refers to
    this class.

    Args:
        image: array indexed as ``image[i, :, :]``, presumably shaped (N, H, W).
        mask: array indexed the same way, one mask per image.
        transform: optional callable applied separately to the image and mask tensors.
        train: stored as ``self.train``; not used by this class.

    Raises:
        ValueError: if ``image`` and ``mask`` hold a different number of samples.
    """

    def __init__(self, image, mask, transform = None, train=True):
        # Fail early. Otherwise a shorter mask surfaces later as an IndexError, and
        # extra masks beyond len(image) are silently ignored.
        if len(image) != len(mask):
            raise ValueError(
                f"image and mask must contain the same number of samples, got {len(image)} and {len(mask)}"
            )
        self.image = image
        self.mask = mask
        self.transform = transform
        self.train = train
        self.n_samples = len(self.image)

    def __getitem__(self, index):
        """Return ``(img, msk)`` tensors of shape [1, H, W]; dtypes follow the source arrays."""
        # Add the channel dimension (C=1).
        img, msk = (torch.from_numpy(arr[index, :, :]).unsqueeze(0) for arr in (self.image, self.mask))
        if self.transform:
            img = self.transform(img)
            msk = self.transform(msk)

        return img, msk

    def __len__(self):
        return self.n_samples



if __name__ == "__main__":
    batch_size = 4

    # Simulated data: float images in [0, 1) and random 0/1 masks, shaped (N, 512, 512).
    # A seeded generator keeps the demo reproducible across kernel restarts.
    rng = np.random.default_rng(42)
    train_image = rng.random((8, 512, 512))
    train_mask = rng.integers(0, 2, size=(8, 512, 512))
    valid_image = rng.random((1, 512, 512))
    valid_mask = rng.integers(0, 2, size=(1, 512, 512))
    test_image = rng.random((1, 512, 512))
    test_mask = rng.integers(0, 2, size=(1, 512, 512))

    print("Training Loading...")
    train_dataset = PatientDataset2DUNet(train_image, train_mask, train=True)
    print("Valid Loading...")
    valid_dataset = PatientDataset2DUNet(valid_image, valid_mask, train=False)
    print("Testing Loading...")
    test_dataset = PatientDataset2DUNet(test_image, test_mask, train=False)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Inspect one training batch; both tensors should be [batch_size, 1, 512, 512].
    images, masks = next(iter(train_loader))
    print(f"Batch Image Shape: {images.shape}")
    print(f"Batch Mask Shape: {masks.shape}")

    for image, mask in test_loader:
        print("Shape of image : ", image.shape)
        print("Shape of mask : ", mask.shape)

    for image, mask in valid_loader:
        print("Shape of image : ", image.shape)
        print("Shape of mask : ", mask.shape)

    print(f"Total images in Training Dataset: {len(train_dataset)}")
    print(f"Total images in Valid Dataset: {len(valid_dataset)}")
    print(f"Total images in Testing Dataset: {len(test_dataset)}")



Training Loading...
Valid Loading...
Testing Loading...
Batch Image Shape: torch.Size([4, 1, 512, 512])
Batch Mask Shape: torch.Size([4, 1, 512, 512])
Shape of image :  torch.Size([1, 1, 512, 512])
Shape of mask :  torch.Size([1, 1, 512, 512])
Shape of image :  torch.Size([1, 1, 512, 512])
Shape of mask :  torch.Size([1, 1, 512, 512])
Total images in Training Dataset: 8
Total images in Valid Dataset: 1
Total images in Testing Dataset: 1


In [None]:
!pip install torchinfo



In [None]:
if __name__ == "__main__":
    ut = UnetTrain()

    # Seed every RNG. Because this also seeds np.random, the simulated data below
    # is reproducible.
    ut.seeding(42)

    # Hyperparameters.
    batch_size = 2
    num_epochs = 1
    lr = 1e-4

    # Output locations (OS-independent paths).
    output_dir = os.path.join(".", "Segmentation", "files")
    os.makedirs(output_dir, exist_ok=True)
    loss_result_file = os.path.join(output_dir, "results.csv")
    model_file = os.path.join(output_dir, "checkpoint.pth")
    test_result_path = os.path.join(output_dir, "results")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Simulated data: float images in [0, 1) and random 0/1 masks, shaped (N, 512, 512).
    train_image = np.random.rand(8, 512, 512)
    train_mask = np.random.randint(2, size=(8, 512, 512))
    valid_image = np.random.rand(1, 512, 512)
    valid_mask = np.random.randint(2, size=(1, 512, 512))
    test_image = np.random.rand(1, 512, 512)
    test_mask = np.random.randint(2, size=(1, 512, 512))

    print("Training Loading...")
    train_dataset = PatientDataset2DUNet(train_image, train_mask, train=True)
    print("Valid Loading...")
    valid_dataset = PatientDataset2DUNet(valid_image, valid_mask, train=False)
    print("Testing Loading...")
    test_dataset = PatientDataset2DUNet(test_image, test_mask, train=False)

    # os.cpu_count() can return None; fall back to 0 workers in that case.
    num_workers = min(4, (os.cpu_count() or 1) // 2)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    print(f"Train: {len(train_dataset)}, Valid: {len(valid_dataset)}, Test: {len(test_dataset)}")

    # Train; the best-validation-loss checkpoint is written to model_file.
    ut.execute(num_epochs, lr, train_loader, valid_loader, model_file, loss_result_file, device)

    # Reload the best checkpoint for testing.
    model = UNet(1, 1).to(device)
    checkpoint = torch.load(model_file, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])

    print(summary(model, input_size=(1, 1, 512, 512)))

    utest = UnetTest()
    utest.test(model, test_loader, test_result_path, device)

Training Loading...
Valid Loading...
Testing Loading...
Train: 8, Valid: 1, Test: 1


  0%|          | 0/1 [00:00<?, ?it/s]

✅ Valid loss improved from inf to 1.1794. Saving checkpoint.


100%|██████████| 1/1 [00:07<00:00,  7.23s/it]

Epoch 1: Time: 0m 7s, Train Loss: 1.184, Val Loss: 1.179





Layer (type:depth-idx)                   Output Shape              Param #
UNet                                     [1, 1, 512, 512]          --
├─encoder_block: 1-1                     [1, 64, 512, 512]         --
│    └─conv_block: 2-1                   [1, 64, 512, 512]         --
│    │    └─Conv2d: 3-1                  [1, 64, 512, 512]         640
│    │    └─BatchNorm2d: 3-2             [1, 64, 512, 512]         128
│    │    └─ReLU: 3-3                    [1, 64, 512, 512]         --
│    │    └─Conv2d: 3-4                  [1, 64, 512, 512]         36,928
│    │    └─BatchNorm2d: 3-5             [1, 64, 512, 512]         128
│    │    └─ReLU: 3-6                    [1, 64, 512, 512]         --
│    └─MaxPool2d: 2-2                    [1, 64, 256, 256]         --
├─encoder_block: 1-2                     [1, 128, 256, 256]        --
│    └─conv_block: 2-3                   [1, 128, 256, 256]        --
│    │    └─Conv2d: 3-7                  [1, 128, 256, 256]        73,856
│   