In [1]:
# In a Jupyter cell, run this to install the necessary packages
!pip install nnunetv2

Collecting argparse (from unittest2->batchgenerators>=0.25.1->nnunetv2)
  Using cached argparse-1.4.0-py2.py3-none-any.whl.metadata (2.8 kB)
Using cached argparse-1.4.0-py2.py3-none-any.whl (23 kB)
Installing collected packages: argparse
Successfully installed argparse-1.4.0
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
!pip install SimpleITK pandas tqdm

# (Optional but recommended) Install hiddenlayer for network architecture plots
!pip install --upgrade git+https://github.com/FabianIsensee/hiddenlayer.git

# Set up the required nnU-Net environment variables.
# These paths tell nnU-Net where to find raw data, preprocessed data, and trained models.
import os

# Create directories for the project
os.makedirs("./nnUNet_raw", exist_ok=True)
os.makedirs("./nnUNet_preprocessed", exist_ok=True)
os.makedirs("./nnUNet_results", exist_ok=True)
os.makedirs("./data", exist_ok=True) # Assuming your data is here
os.makedirs("./my_custom_nnunet", exist_ok=True) # For our custom code

# Set the environment variables
os.environ['nnUNet_raw'] = os.path.abspath("./nnUNet_raw")
os.environ['nnUNet_preprocessed'] = os.path.abspath("./nnUNet_preprocessed")
os.environ['nnUNet_results'] = os.path.abspath("./nnUNet_results")

# IMPORTANT: Add our custom code directory to the Python path
# This allows nnU-Net to find our custom trainer and model
import sys
sys.path.append(os.path.abspath("./my_custom_nnunet"))

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Collecting git+https://github.com/FabianIsensee/hiddenlayer.git
  Cloning https://github.com/FabianIsensee/hiddenlayer.git to /tmp/pip-req-build-4vwfdiyp
  Running command git clone --filter=blob:none --quiet https://github.com/FabianIsensee/hiddenlayer.git /tmp/pip-req-build-4vwfdiyp
  Resolved https://github.com/FabianIsensee/hiddenlayer.git to commit b7263b6dc4569da1b6dea5964e1eac78fa32fa77
  Preparing metadata (setup.py) ... [?25ldone
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[?25h

In [5]:
#Dataset Pre-processing cell
# Conversion of dataset
import os
import shutil
import json
from collections import OrderedDict
from tqdm import tqdm
from pathlib import Path

# --------------------
# Resolve dataset root
# --------------------

# Find the absolute path to "data" folder even if notebook runs from subfolder
#For Google COLAB /content/MyDrive/...
base_dir = (Path(__file__).parent if "__file__" in globals() else Path.cwd()) / "Data"

if not base_dir.exists():
    raise FileNotFoundError(f"Could not find data folder at {base_dir}")

print("Using dataset root:", base_dir)

# nnU-Net environment
nnunet_raw_dir = Path(os.environ['nnUNet_raw'])
dataset_id = 501
dataset_name = f"Dataset{dataset_id:03d}_Pancreas"
task_dir = nnunet_raw_dir / dataset_name

# Create nnU-Net dataset directories
images_tr_dir = task_dir / 'imagesTr'
labels_tr_dir = task_dir / 'labelsTr'
images_ts_dir = task_dir / 'imagesTs'

images_tr_dir.mkdir(parents=True, exist_ok=True)
labels_tr_dir.mkdir(parents=True, exist_ok=True)
images_ts_dir.mkdir(parents=True, exist_ok=True)

# --- Process Training and Validation Data ---
all_files = []
for split in ['train', 'validation']:
    split_dir = base_dir / split
    if not split_dir.exists():
        raise FileNotFoundError(f"Missing split folder: {split_dir}")
    for subtype_folder in split_dir.iterdir():
        if subtype_folder.is_dir() and 'subtype' in subtype_folder.name:
            subtype = int(subtype_folder.name.replace('subtype', ''))
            for f in subtype_folder.iterdir():
                all_files.append({
                    "path": f,
                    "subtype": subtype
                })

# Create a dictionary to store classification labels
classification_labels = {}
num_training_cases = 0

print("Processing training & validation sets...")
for file_info in tqdm(all_files):
    file_path = file_info['path']
    subtype = file_info['subtype']
    
    if '_0000.nii.gz' in file_path.name:  # It's an image file
        case_id = file_path.name.split('_0000.nii.gz')[0]
        new_name = f"{case_id}_0000.nii.gz"
        shutil.copy(file_path, images_tr_dir / new_name)

        classification_labels[case_id] = subtype
        num_training_cases += 1

    elif file_path.suffixes == ['.nii', '.gz'] and '_0000' not in file_path.stem:
        case_id = file_path.name.replace('.nii.gz', '')
        new_name = f"{case_id}.nii.gz"
        shutil.copy(file_path, labels_tr_dir / new_name)

# Save classification labels
with open(task_dir / 'classification_labels.json', 'w') as f:
    json.dump(classification_labels, f, indent=4)

# --- Process Test Data ---
print("\nProcessing test set...")
test_dir = base_dir / 'test'
if test_dir.exists():
    for f in tqdm(test_dir.iterdir()):
        if f.suffixes == ['.nii', '.gz']:
            shutil.copy(f, images_ts_dir / f.name)
else:
    print("⚠️ No test set found, skipping.")

# --- Create dataset.json ---
print("\nCreating dataset.json...")
dataset_json = OrderedDict()
dataset_json['channel_names'] = {"0": "CT"}
dataset_json['labels'] = {"background": 0, "pancreas": 1, "lesion": 2}
dataset_json['num_classification_classes'] = 3  # Subtypes 0,1,2
dataset_json['numTraining'] = num_training_cases
dataset_json['file_ending'] = ".nii.gz"

with open(task_dir / 'dataset.json', 'w') as f:
    json.dump(dataset_json, f, indent=4)

print(f"\nData preparation complete for {dataset_name} at {task_dir}")


Using dataset root: /workspace/Data
Processing training & validation sets...


100%|██████████| 576/576 [00:15<00:00, 37.06it/s]



Processing test set...


72it [00:03, 22.25it/s]


Creating dataset.json...

Data preparation complete for Dataset501_Pancreas at /workspace/nnUNet_raw/Dataset501_Pancreas





In [6]:
# nnUNet from Github
# pip install https://github.com/MIC-DKFZ/nnUNet.git
#nnUNet/nnunetv2/training/nnUNetTrainer/

In [7]:

%%writefile ./my_custom_nnunet/multitask_network.py


import torch
import torch.nn as nn
import torch.nn.functional as F

# Simple 3D UNet backbone for demonstration
class UNetBackbone(nn.Module):
    def __init__(self, in_channels=1, base_channels=32):
        super().__init__()
        self.enc1 = nn.Sequential(nn.Conv3d(in_channels, base_channels, 3, padding=1),
                                  nn.ReLU(),
                                  nn.Conv3d(base_channels, base_channels, 3, padding=1),
                                  nn.ReLU())
        self.pool = nn.MaxPool3d(2)
        self.enc2 = nn.Sequential(nn.Conv3d(base_channels, base_channels*2, 3, padding=1),
                                  nn.ReLU(),
                                  nn.Conv3d(base_channels*2, base_channels*2, 3, padding=1),
                                  nn.ReLU())

        self.center = nn.Sequential(nn.Conv3d(base_channels*2, base_channels*4, 3, padding=1),
                                    nn.ReLU(),
                                    nn.Conv3d(base_channels*4, base_channels*4, 3, padding=1),
                                    nn.ReLU())

        self.up2 = nn.ConvTranspose3d(base_channels*4, base_channels*2, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(nn.Conv3d(base_channels*4, base_channels*2, 3, padding=1),
                                  nn.ReLU(),
                                  nn.Conv3d(base_channels*2, base_channels*2, 3, padding=1),
                                  nn.ReLU())

        self.up1 = nn.ConvTranspose3d(base_channels*2, base_channels, kernel_size=2, stride=2)
        self.dec1 = nn.Sequential(nn.Conv3d(base_channels*2, base_channels, 3, padding=1),
                                  nn.ReLU(),
                                  nn.Conv3d(base_channels, base_channels, 3, padding=1),
                                  nn.ReLU())

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        center = self.center(self.pool(enc2))

        dec2 = self.dec2(torch.cat([self.up2(center), enc2], dim=1))
        dec1 = self.dec1(torch.cat([self.up1(dec2), enc1], dim=1))
        return dec1


# Multi-task UNet: segmentation + classification
class UNet_MultiTask(nn.Module):
    def __init__(self, in_channels=1, num_classes_seg=2, num_classes_clf=2):
        super().__init__()
        self.backbone = UNetBackbone(in_channels)
        base_channels = 32

        # Segmentation head
        self.seg_head = nn.Conv3d(base_channels, num_classes_seg, kernel_size=1)

        # Classification head (global average pooling -> linear)
        self.clf_head = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(base_channels, num_classes_clf)
        )

    def forward(self, x):
        features = self.backbone(x)
        seg_out = self.seg_head(features)
        clf_out = self.clf_head(features)
        return seg_out, clf_out



Writing ./my_custom_nnunet/multitask_network.py


In [None]:
#nnUNet/nnunetv2/training/nnUNetTrainer/

In [8]:
%%writefile ./my_custom_nnunet/nnUNetTrainer_MultiTask.py


import torch
import torch.nn as nn
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from .multitask_network import UNet_MultiTask
from nnunetv2.training.loss.compound_losses import DC_and_CE_loss

class nnUNetTrainer_MultiTask(nnUNetTrainer):
    def initialize_network(self):
        """Initialize multi-task UNet"""
        self.network = UNet_MultiTask(
            in_channels=self.num_input_channels,
            num_classes_seg=self.num_classes,
            num_classes_clf=2  # Change based on your dataset
        ).to(self.device)

    def compute_loss(self, x, y):
        """Compute combined loss for segmentation + classification"""
        # Expect y = (seg_target, clf_target)
        seg_target, clf_target = y
        seg_pred, clf_pred = self.network(x)

        seg_loss = DC_and_CE_loss(seg_pred, seg_target)
        clf_loss = nn.CrossEntropyLoss()(clf_pred, clf_target)

        total_loss = seg_loss + clf_loss
        return total_loss


Writing ./my_custom_nnunet/nnUNetTrainer_MultiTask.py


In [15]:
#COnverison from float to int
import nibabel as nib
import numpy as np
from pathlib import Path

labels_dir = Path("/workspace/nnUNet_raw/Dataset501_Pancreas/labelsTr") 

for file in labels_dir.glob("*.nii.gz"):
    img = nib.load(str(file))
    data = img.get_fdata()

    # Round floats to nearest int and cast
    data = np.rint(data).astype(np.int16)

    # Verify unique labels
    unique = np.unique(data)
    if not set(unique).issubset({0, 1, 2}):
        print(f"⚠️ Warning: {file.name} has unexpected labels {unique}")

    # Save back with same affine/header
    new_img = nib.Nifti1Image(data, img.affine, img.header)
    nib.save(new_img, str(file))

print("✅ All labels fixed to integer values {0,1,2}")


✅ All labels fixed to integer values {0,1,2}


In [16]:
#COnverison from float to int64

import nibabel as nib
import numpy as np
from pathlib import Path

labels_dir = Path("/workspace/nnUNet_raw/Dataset501_Pancreas/labelsTr")

for file in labels_dir.glob("*.nii.gz"):
    img = nib.load(str(file))
    data = img.get_fdata()
    
    # Round to nearest integer
    data = np.round(data).astype(np.int64)
    
    # Clip any possible out-of-range values just in case
    data = np.clip(data, 0, 2)
    
    # Verify unique labels
    unique = np.unique(data)
    if not set(unique).issubset({0, 1, 2}):
        print(f"⚠️ Warning: {file.name} has unexpected labels {unique}")
    
    # Save back with correct dtype
    new_img = nib.Nifti1Image(data, img.affine, img.header)
    nib.save(new_img, str(file))

print("✅ All labels fixed to exact integers 0, 1, 2 (np.int64)")


✅ All labels fixed to exact integers 0, 1, 2 (np.int64)


In [17]:
# In a Jupyter cell
!nnUNetv2_plan_and_preprocess -d {dataset_id} --verify_dataset_integrity

Fingerprint extraction...
Dataset501_Pancreas
Using <class 'nnunetv2.imageio.simpleitk_reader_writer.SimpleITKIO'> as reader/writer

####################
verify_dataset_integrity Done. 
If you didn't see any error messages then your dataset is most likely OK!
####################

Using <class 'nnunetv2.imageio.simpleitk_reader_writer.SimpleITKIO'> as reader/writer
100%|█████████████████████████████████████████| 288/288 [00:09<00:00, 31.97it/s]
Experiment planning...

############################
INFO: You are using the old nnU-Net default planner. We have updated our recommendations. Please consider using those instead! Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md
############################

Dropping 3d_lowres config because the image size difference to 3d_fullres is too small. 3d_fullres: [ 59.  117.  180.5], 3d_lowres: [59, 117, 180]
2D U-Net configuration:
{'data_identifier': 'nnUNetPlans_2d', 'preprocessor_name': 'DefaultPreproce

In [18]:
print(dataset_id)

501


In [22]:
# In a Jupyter cell
# Note: Training all 5 folds is recommended for best performance and ensembling.
# Here we train fold 0 as an example.
# To train all folds, you would run this command in a loop for fold in [0, 1, 2, 3, 4].
# 
#!nnUNetv2_train {dataset_id} 3d_fullres 0 -tr nnUNetTrainer_MultiTask --npz -num_epochs 10

!nnUNet_trainer_class_dir='/workspace/my_custom_nnunet' nnUNetv2_train {dataset_id} 3d_fullres 0 -tr nnUNetTrainer_MultiTask --npz

# !nnUNetv2_train 1 3d_fullres 0 -tr nnUNetTrainer_10epochs --num_gpus 1



############################
INFO: You are using the old nnU-Net default plans. We have updated our recommendations. Please consider using those instead! Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md
############################

Traceback (most recent call last):
  File "/usr/local/bin/nnUNetv2_train", line 8, in <module>
    sys.exit(run_training_entry())
             ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/nnunetv2/run/run_training.py", line 266, in run_training_entry
    run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
  File "/usr/local/lib/python3.11/dist-packages/nnunetv2/run/run_training.py", line 192, in run_training
    nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name,
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/loca

In [24]:
##trying t fx above cell here

import subprocess
import os

# Create a copy of the current environment variables
env = os.environ.copy()

# Set the path to the directory containing your custom trainer class
env['nnUNet_trainer_class_dir'] = '/workspace/my_custom_nnunet'

# Define the command as a list of strings
command = [
    'nnUNetv2_train',
    '501',  # Replace with {dataset_id} if it's a variable
    '3d_fullres',
    '0',
    '-tr',
    'nnUNetTrainer_MultiTask',
    '--npz'
]

# Run the command with the modified environment
try:
    subprocess.run(command, env=env, check=True)
    print("✅ Training command executed successfully.")
except subprocess.CalledProcessError as e:
    print(f"❌ An error occurred while running the command: {e}")


############################
INFO: You are using the old nnU-Net default plans. We have updated our recommendations. Please consider using those instead! Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md
############################



Traceback (most recent call last):
  File "/usr/local/bin/nnUNetv2_train", line 8, in <module>
    sys.exit(run_training_entry())
             ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/nnunetv2/run/run_training.py", line 266, in run_training_entry
    run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
  File "/usr/local/lib/python3.11/dist-packages/nnunetv2/run/run_training.py", line 192, in run_training
    nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name,
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/nnunetv2/run/run_training.py", line 42, in get_trainer_from_args
    raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in '
RuntimeError: Could not find requested nnunet trainer nnUNetTrainer_MultiTask in nnunetv2.training.

❌ An error occurred while running the command: Command '['nnUNetv2_train', '501', '3d_fullres', '0', '-tr', 'nnUNetTrainer_MultiTask', '--npz']' returned non-zero exit status 1.


In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from pathlib import Path

# Define the path to the trained model directory
model_dir = Path(os.environ['nnUNet_results']) / dataset_name / "nnUNetTrainer_MultiTask__nnUNetPlans__3d_fullres"

# Find the training plot for fold 0
progress_png_path = model_dir / "fold_0" / "progress.png"

if progress_png_path.exists():
    print(f"Displaying training graph from: {progress_png_path}")
    plt.figure(figsize=(15, 8))
    img = mpimg.imread(progress_png_path)
    plt.imshow(img)
    plt.axis('off')
    plt.show()
else:
    print(f"❌ Could not find training graph at {progress_png_path}")

In [None]:
import re
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from pathlib import Path

def parse_nnunet_log(log_file_path):
    """
    Parses an nnU-Net v2 training log file to extract key metrics per epoch.
    """
    epoch_data = []
    
    with open(log_file_path, 'r') as f:
        lines = f.readlines()

    current_epoch = -1
    epoch_metrics = {}

    # Regex patterns to find the data we need
    epoch_pattern = re.compile(r"Epoch (\d+)")
    train_loss_pattern = re.compile(r"train_loss ([\-\d\.]+)")
    val_loss_pattern = re.compile(r"val_loss ([\-\d\.]+)")
    ema_dice_pattern = re.compile(r"New best EMA pseudo Dice: ([\d\.]+)")

    for line in lines:
        epoch_match = epoch_pattern.search(line)
        if epoch_match:
            # When we find a new epoch, save the previous one's data
            if current_epoch != -1 and 'train_loss' in epoch_metrics:
                epoch_data.append(epoch_metrics)
            
            current_epoch = int(epoch_match.group(1))
            epoch_metrics = {'epoch': current_epoch}
            # Carry over the last known EMA Dice
            if epoch_data:
                epoch_metrics['ema_pseudo_dice'] = epoch_data[-1].get('ema_pseudo_dice')

        train_loss_match = train_loss_pattern.search(line)
        if train_loss_match:
            epoch_metrics['train_loss'] = float(train_loss_match.group(1))

        val_loss_match = val_loss_pattern.search(line)
        if val_loss_match:
            epoch_metrics['val_loss'] = float(val_loss_match.group(1))

        ema_dice_match = ema_dice_pattern.search(line)
        if ema_dice_match:
            epoch_metrics['ema_pseudo_dice'] = float(ema_dice_match.group(1))

    # Append the last epoch's data
    if 'train_loss' in epoch_metrics:
        epoch_data.append(epoch_metrics)
        
    return pd.DataFrame(epoch_data)

# --- Main Execution ---

# **CORRECTED PART**: Use the full path to the log file
# We build the path dynamically from the environment variables set earlier.
dataset_id = 501
dataset_name = f"Dataset{dataset_id:03d}_Pancreas"
model_folder_path = Path(os.environ['nnUNet_results']) / dataset_name / "nnUNetTrainer_MultiTask__nnUNetPlans__3d_fullres" / "fold_0"
log_filename = 'training_log_2025_9_9_22_44_55.txt'
LOG_FILE = model_folder_path / log_filename

if not LOG_FILE.exists():
    print(f"❌ ERROR: Log file not found at the specified path: {LOG_FILE}")
    print("Please verify the path and filename are correct.")
else:
    print(f"✅ Found log file: {LOG_FILE}")
    log_df = parse_nnunet_log(LOG_FILE)

    # --- Plotting ---
    sns.set_theme(style="whitegrid")
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

    # Plot Training and Validation Loss
    ax1.plot(log_df['epoch'], log_df['train_loss'], 'o-', label='Train Loss', color='b')
    ax1.plot(log_df['epoch'], log_df['val_loss'], 'o-', label='Validation Loss', color='r')
    ax1.set_ylabel('Loss (Dice + CE)')
    ax1.set_title('Training and Validation Loss per Epoch')
    ax1.legend()
    ax1.grid(True)

    # Plot EMA Pseudo Dice
    ax2.plot(log_df['epoch'], log_df['ema_pseudo_dice'], 'o-', label='EMA Pseudo Dice', color='g')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('EMA Pseudo Dice Score')
    ax2.set_title('Validation EMA Pseudo Dice Score per Epoch')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.show()

    # Print the final metrics from the dataframe
    print("\n--- Parsed Metrics Summary ---")
    print(log_df.tail())

In [None]:

#aved weights to /mnt/data/weights_epoch50.pth and loss plot to /mnt/data/loss.png")



import torch
import torch.nn as nn
import os
import nibabel as nib
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

# --- Ensure previous definitions are available ---
# Make sure the UNet2D class definition from your training script is in a previous cell.
# If not, you must redefine it here.
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )
    def forward(self,x): return self.net(x)

class UNet2D(nn.Module):
    def __init__(self, in_ch=1, base_ch=16, n_classes=3, n_subtypes=3):
        super().__init__()
        self.down1 = DoubleConv(in_ch, base_ch)
        self.pool = nn.MaxPool2d(2)
        self.down2 = DoubleConv(base_ch, base_ch*2)
        self.up1 = nn.ConvTranspose2d(base_ch*2, base_ch, 2, stride=2)
        self.conv_up = DoubleConv(base_ch*2, base_ch)
        self.seg_head = nn.Conv2d(base_ch, n_classes, 1)
        self.cls_pool = nn.AdaptiveAvgPool2d(1)
        self.cls_head = nn.Linear(base_ch, n_subtypes)
    def forward(self, x):
        d1 = self.down1(x)
        p1 = self.pool(d1)
        d2 = self.down2(p1)
        u1 = self.up1(d2)
        cat = torch.cat([u1, d1], dim=1)
        up = self.conv_up(cat)
        seg = self.seg_head(up)
        pooled = self.cls_pool(up).view(up.size(0), -1)
        cls = self.cls_head(pooled)
        return seg, cls

# --- New Dataset Class for Inference ---
# This dataset handles resizing to prevent the RuntimeError.
# It only deals with images, as test sets don't have masks.

# Define the resizing transform
infer_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)), # Resize to a fixed size, e.g., 224x224
    transforms.ToTensor(),
])

class InferenceSliceDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform
        self.slices = []  # List of (image_path, slice_idx)
        for img_p in self.image_paths:
            try:
                nii_img = nib.load(img_p)
                nz = nii_img.shape[2]
                for s in range(nz):
                    self.slices.append((img_p, s))
            except Exception as e:
                print(f"Warning: Could not read {img_p}. Error: {e}")

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
        img_p, s_idx = self.slices[idx]
        img_data = nib.load(img_p).get_fdata()[:, :, s_idx].astype(np.float32)
        
        # Normalize image
        img_data = (img_data - img_data.mean()) / (img_data.std() + 1e-8)
        
        # Apply the transform
        if self.transform:
            # The transform expects a PIL image, so we need to convert the format
            img_uint8 = ((img_data - img_data.min()) / (img_data.max() - img_data.min() + 1e-8) * 255).astype(np.uint8)
            img_tensor = self.transform(img_uint8)
        else:
            img_tensor = torch.from_numpy(img_data[np.newaxis, :, :])

        return img_tensor, img_p, s_idx

# --- Main Inference Logic ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. Load your trained model
model = UNet2D(in_ch=1, base_ch=16).to(device)
# weights_path = '/mnt/data/weights_epoch50.pth'
weights_path = '/nnUNet_results/Dataset501_Pancreas/nnUNetTrainer_MultiTask__nnUNetPlans__3d_fullres/fold_0/checkpoint_best.pth'
if os.path.exists(weights_path):
    model.load_state_dict(torch.load(weights_path))
    print(f"Successfully loaded weights from {weights_path}")
else:
    print(f"Error: Weights file not found at {weights_path}. Please ensure training was completed.")

model.eval()

# 2. Prepare the test data
test_data_root = 'data/test'
test_imgs = sorted([os.path.join(test_data_root, f) for f in os.listdir(test_data_root) if f.endswith('_0000.nii.gz')])
test_ds = InferenceSliceDataset(test_imgs, transform=infer_transform)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=0)

# 3. Run inference
results = {}
with torch.no_grad():
    for x, paths, slice_indices in tqdm(test_loader, desc="Running Inference"):
        x = x.to(device).float()
        seg_logits, cls_logits = model(x)
        
        # Get predictions for the batch
        seg_preds = torch.argmax(seg_logits, dim=1).cpu().numpy()
        cls_preds = torch.argmax(cls_logits, dim=1).cpu().numpy()
        
        # Store results slice by slice
        for i in range(len(paths)):
            path = paths[i]
            s_idx = slice_indices[i].item()
            
            if path not in results:
                original_img = nib.load(path)
                results[path] = {
                    'seg_volume': np.zeros(original_img.shape, dtype=np.uint8),
                    'cls_votes': [],
                    'affine': original_img.affine,
                    'header': original_img.header
                }
            
            # The segmentation needs to be resized back to its original dimensions
            # Since this is complex, we will save the resized segmentation for now.
            # A more advanced pipeline would handle this.
            # For now, we will just store the classification votes.
            results[path]['cls_votes'].append(cls_preds[i])

# 4. Process and save results
output_dir = "results"
os.makedirs(output_dir, exist_ok=True)
classification_results = []

print("\nProcessing and saving results...")
for path, data in results.items():
    # Final classification is the most frequent vote across all slices
    final_subtype = np.bincount(data['cls_votes']).argmax()
    
    # Filename for submission
    base_name = os.path.basename(path).replace('_0000.nii.gz', '.nii.gz')
    classification_results.append({'Names': base_name, 'Subtype': final_subtype})
    
    # For now, we are skipping saving the segmentation masks as it requires
    # a reverse-resize operation which complicates the script. The primary
    # goal here is to get classification results without the RuntimeError.

# 5. Save classification CSV
csv_path = os.path.join(output_dir, 'subtype_results.csv')
df = pd.DataFrame(classification_results)
df.to_csv(csv_path, index=False)

print(f"✅ Inference complete. Classification results saved to {csv_path}")
print(df.head())

In [None]:
import nibabel as nib
import matplotlib.pyplot as plt
import random
import os
from pathlib import Path
import numpy as np

# --- Add these lines to define the necessary variables ---
# Define the folder where your test images are located
TEST_FOLDER = Path("./data/test") 
# Define the folder where your segmentation results were saved
OUTPUT_FOLDER = Path("./results") 
# Get the list of test files
test_files = sorted(list(TEST_FOLDER.glob('*.nii.gz')))
# --- End of added lines ---

# Select a random test image and its corresponding predicted segmentation
if not test_files:
    print("❌ No test files found in the specified directory.")
else:
    random_test_file = random.choice(test_files)
    image_path = random_test_file
    seg_path = OUTPUT_FOLDER / random_test_file.name.replace('_0000.nii.gz', '.nii.gz')

    print(f"Visualizing image: {image_path.name}")
    print(f"Segmentation: {seg_path.name}")

    if not seg_path.exists():
        print("❌ Segmentation file not found. Skipping visualization.")
    else:
        # Load the NIfTI files
        img_nib = nib.load(image_path)
        img_data = img_nib.get_fdata()
        seg_nib = nib.load(seg_path)
        seg_data = seg_nib.get_fdata()

        # Find a good slice to display (one with a segmentation)
        slice_indices = np.where(np.sum(seg_data, axis=(0, 1)) > 0)[0]
        if len(slice_indices) > 0:
            mid_slice = slice_indices[len(slice_indices) // 2]
        else:
            mid_slice = img_data.shape[2] // 2

        # Plot
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        
        axes[0].imshow(np.rot90(img_data[:, :, mid_slice]), cmap='gray')
        axes[0].set_title(f"Original Image (Slice {mid_slice})")
        axes[0].axis('off')

        axes[1].imshow(np.rot90(img_data[:, :, mid_slice]), cmap='gray')
        # Use a masked array to only show non-zero labels
        seg_masked = np.ma.masked_where(seg_data[:, :, mid_slice] == 0, seg_data[:, :, mid_slice])
        axes[1].imshow(np.rot90(seg_masked), alpha=0.6, cmap='viridis') # 'viridis' shows classes in different colors
        axes[1].set_title("Segmentation Overlay")
        axes[1].axis('off')
        
        plt.tight_layout()
        plt.show()

In [None]:
!zip -r -q results.zip ./results
print("✅ Created results.zip containing all test segmentations and the subtype_results.csv file.")

In [None]:
# Inference script using nnUNetPredictor (v2 compatible)
import torch
import SimpleITK as sitk
import numpy as np
import pandas as pd
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from tqdm import tqdm
import os

# --- Configuration ---
DATASET_ID = 501
TEST_FOLDER = './data/test'       # Folder containing your test NIfTI files
OUTPUT_FOLDER = './results'       # Folder to save segmentations
SUBMISSION_CSV = os.path.join(OUTPUT_FOLDER, 'subtype_results.csv')

os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Find the trained model folder
model_folder = os.path.join(
    os.environ['nnUNet_results'],
    f"Dataset{DATASET_ID:03d}_Pancreas",
    "nnUNetTrainer_MultiTask__nnUNetPlans__3d_fullres"
)
print(f"Using model from: {model_folder}")

# --- Initialize Predictor ---
predictor = nnUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=True,
    perform_everything_on_device=True,
    device=torch.device('cuda'),
    verbose=False,
    verbose_preprocessing=False,
    allow_tqdm=True
)

predictor.initialize_from_trained_model_folder(
    model_folder,
    use_folds=(0,),  # change if you trained multiple folds
    checkpoint_name='checkpoint_final.pth'
)

# --- Run Inference Loop ---
test_files = [f for f in os.listdir(TEST_FOLDER) if f.endswith('.nii.gz')]
classification_results = []

print("Starting inference...")

for f in tqdm(test_files):
    input_file = os.path.join(TEST_FOLDER, f)
    
    # Run sequential inference
    ret = predictor.predict_from_files_sequential(
        [[input_file]],  # list of lists
        OUTPUT_FOLDER,
        save_probabilities=False,
        overwrite=True,
        folder_with_segs_from_prev_stage=None
    )

    # Check if segmentation was returned
    if len(ret) == 0:
        print(f"No segmentation produced for {f}, skipping...")
        classification_results.append({
            'Names': f,
            'Subtype': -1  # or another placeholder for missing prediction
        })
        continue

    seg_data = ret[0]
    
    # Save segmentation mask
    if isinstance(seg_data, np.ndarray):
        import SimpleITK as sitk
        seg_itk = sitk.GetImageFromArray(seg_data)
        img_sitk = sitk.ReadImage(input_file)
        seg_itk.CopyInformation(img_sitk)
        seg_file = os.path.join(OUTPUT_FOLDER, f.replace('_0000', ''))
        sitk.WriteImage(seg_itk, seg_file)
    else:
        seg_file = os.path.join(OUTPUT_FOLDER, f.replace('_0000', ''))

    # Extract classification
    if hasattr(seg_data, 'get') and 'logits' in seg_data:
        cls_logits = seg_data['logits']
        predicted_subtype = int(torch.argmax(cls_logits, dim=1).cpu())
    else:
        predicted_subtype = -1

    classification_results.append({
        'Names': os.path.basename(seg_file),
        'Subtype': predicted_subtype
    })


# --- Save Classification CSV ---
df = pd.DataFrame(classification_results)
df.to_csv(SUBMISSION_CSV, index=False)

print(f"\nInference complete! Results saved in {OUTPUT_FOLDER}")
print(f"Classification CSV saved at: {SUBMISSION_CSV}")
df.head()


In [None]:
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
import os

output_dir = "/home/usama/uw_akash2/nnUNet_results/Dataset501_Pancreas/nnUNetTrainer_MultiTask__nnUNetPlans__3d_fullres/inference_results"

# Use the single output file
seg_file = os.path.join(output_dir, "quiz_.nii.gz")

seg_img = nib.load(seg_file)
seg_data = seg_img.get_fdata()

# If you have the original image to overlay
input_file = "/home/usama/uw_akash2/data/test/quiz_037_0000.nii.gz"
img_nib = nib.load(input_file)
img_data = img_nib.get_fdata()

# Visualization (middle slice)
mid_slice = img_data.shape[2] // 2

plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.imshow(img_data[:, :, mid_slice], cmap='gray')
plt.title("Original Image (slice {})".format(mid_slice))
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(img_data[:, :, mid_slice], cmap='gray')
plt.imshow(seg_data[:, :, mid_slice], alpha=0.5, cmap='jet')
plt.title("Segmentation Overlay")
plt.axis('off')

plt.show()
